Compare commits

..

27 Commits

Author SHA1 Message Date
VaalaCat
2436e7afb8 feat: add reasoning format 2025-06-15 12:59:44 +00:00
VaalaCat
67f3b169df feat: add include_reasoning 2025-06-15 12:59:08 +00:00
VaalaCat
3640274cd1 feat: change repo name 2025-06-15 12:58:45 +00:00
JT A.
ff9d83a485 skip json field (#1009)
* skip json field

* backfill some coverage and tests
2025-05-29 11:31:35 +01:00
Axb12
8c65b35c57 update image api *os.File to io.Reader (#994)
* update image api *os.File to io.Reader

* update code style

* add reader test

* supplementary reader test

* update the reader in the form builder test

* add commnet

* update comment

* update code style
2025-05-20 14:45:40 +01:00
Alex Baranov
4d2e7ab29d fix lint (#998) 2025-05-13 12:59:06 +01:00
Justa
6aaa732296 add ChatTemplateKwargs to ChatCompletionRequest (#980)
Co-authored-by: Justa <justa.cai@akuvox.com>
2025-05-13 12:52:44 +01:00
Pedro Chaparro
0116f2994d feat: add support for image generation using gpt-image-1 (#971)
* feat: add gpt-image-1 support

* feat: add example to generate image using gpt-image-1 model

* style: missing period in comments

* feat: add missing fields to example

* docs: add GPT Image 1 to README

* revert: keep `examples/images/main.go` unchanged

* docs: remove unnecessary newline from example in README file
2025-05-13 12:51:08 +01:00
Alex Baranov
8ba38f6ba1 remove backup file (#996) 2025-05-13 12:44:16 +01:00
Alex Baranov
6181facea7 update codecov action, pass token (#987) 2025-05-04 15:45:40 +01:00
Alex Baranov
77ccac8d34 Upgrade golangci-lint to 2.1.5 (#986)
* Upgrade golangci-lint to 2.1.5

* update action
2025-05-03 22:39:47 +01:00
Alex Baranov
5ea214a188 Improve unit test coverage (#984)
* add tests for config

* add audio tests

* lint

* lint

* lint
2025-05-03 22:25:14 +01:00
Ben Katz
d65f0cb54e Fix: Corrected typo in O4Mini20250416 model name and endpoint map. (#981) 2025-05-03 21:44:48 +01:00
Daniel Peng
93a611cf4f Add Prediction field (#970)
* Add Prediction field to ChatCompletionRequest

* Include prediction tokens in response
2025-04-29 14:38:27 +01:00
Oleksandr Redko
6836cf6a6f Remove redundant typecheck linter (#955) 2025-04-29 14:36:38 +01:00
Sean McGinnis
da5f9bc9bc Add CompletionRequest.StreamOptions (#959)
The legacy completion API supports a `stream_options` object when
`stream` is set to true [0]. This adds a StreamOptions property to the
CompletionRequest struct to support this setting.

[0] https://platform.openai.com/docs/api-reference/completions/create#completions-create-stream_options

Signed-off-by: Sean McGinnis <sean.mcginnis@gmail.com>
2025-04-29 14:35:26 +01:00
rory malcolm
bb5bc27567 Add support for 4o-mini and 3o (#968)
- This adds supports, and tests, for the 3o and 4o-mini class of models
2025-04-29 14:34:33 +01:00
Zhongxian Pan
4cccc6c934 Adapt different stream data prefix, with or without space (#945) 2025-04-29 14:29:15 +01:00
goodenough
306fbbbe6f Add support for reasoning_content field in chat completion messages for DeepSeek R1 (#925)
* support deepseek field "reasoning_content"

* support deepseek field "reasoning_content"

* Comment ends in a period (godot)

* add comment on field reasoning_content

* fix go lint error

* chore: trigger CI

* make field "content" in MarshalJSON function omitempty

* remove reasoning_content in TestO1ModelChatCompletions func

* feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses.

* feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses.

* feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses.
2025-04-29 14:24:45 +01:00
netr
658beda2ba feat: Add missing TTS models and voices (#958)
* feat: Add missing TTS models and voices

* feat: Add new instruction field to create speech request

- From docs: Control the voice of your generated audio with additional instructions. Does not work with tts-1 or tts-1-hd.

* fix: add canary-tts back to SpeechModel
2025-04-26 11:13:43 +01:00
Takahiro Ikeuchi
d68a683815 feat: add new GPT-4.1 model variants to completion.go (#966)
* feat: add new GPT-4.1 model variants to completion.go

* feat: add tests for unsupported models in completion endpoint

* fix: add missing periods to test function comments in completion_test.go
2025-04-23 22:50:47 +01:00
JT A.
e99eb54c9d add enum tag to jsonschema (#962)
* fix jsonschema tests

* ensure all run during PR Github Action

* add test for struct to schema

* add support for enum tag

* support nullable tag
2025-04-13 19:00:48 +01:00
Liu Shuang
74d6449f22 feat: add gpt-4.5-preview models (#947) 2025-03-04 08:26:59 +00:00
Alex Baranov
261721bfdb Fix linter (#943)
* fix lint

* remove linters
2025-02-25 16:56:35 +00:00
Dan Ackerson
be2e2387d4 feat: add Anthropic API support with custom version header (#934)
* feat: add Anthropic API support with custom version header

* refactor: use switch statement for API type header handling

* refactor: add OpenAI & AzureAD types to be exhaustive

* Update client.go

need explicit fallthrough in empty case statements

* constant for APIVersion; addtl tests
2025-02-25 11:03:38 +00:00
Liu Shuang
85f578b865 fix: remove validateO1Specific (#939)
* fix: remove validateO1Specific

* update golangci-lint-action version

* fix actions

* fix actions

* fix actions

* fix actions

* remove some o1 test
2025-02-17 11:29:18 +00:00
Liu Shuang
c0a9a75fe0 feat: add developer role (#936) 2025-02-12 15:05:44 +00:00
63 changed files with 1487 additions and 682 deletions

View File

@@ -8,7 +8,7 @@ assignees: ''
--- ---
Your issue may already be reported! Your issue may already be reported!
Please search on the [issue tracker](https://github.com/sashabaranov/go-openai/issues) before creating one. Please search on the [issue tracker](https://git.vaala.cloud/VaalaCat/go-openai/issues) before creating one.
**Describe the bug** **Describe the bug**
A clear and concise description of what the bug is. If it's an API-related bug, please provide relevant endpoint(s). A clear and concise description of what the bug is. If it's an API-related bug, please provide relevant endpoint(s).

View File

@@ -8,7 +8,7 @@ assignees: ''
--- ---
Your issue may already be reported! Your issue may already be reported!
Please search on the [issue tracker](https://github.com/sashabaranov/go-openai/issues) before creating one. Please search on the [issue tracker](https://git.vaala.cloud/VaalaCat/go-openai/issues) before creating one.
**Is your feature request related to a problem? Please describe.** **Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

View File

@@ -1,5 +1,5 @@
A similar PR may already be submitted! A similar PR may already be submitted!
Please search among the [Pull request](https://github.com/sashabaranov/go-openai/pulls) before creating one. Please search among the [Pull request](https://git.vaala.cloud/VaalaCat/go-openai/pulls) before creating one.
If your changes introduce breaking changes, please prefix the title of your pull request with "[BREAKING_CHANGES]". This allows for clear identification of such changes in the 'What's Changed' section on the release page, making it developer-friendly. If your changes introduce breaking changes, please prefix the title of your pull request with "[BREAKING_CHANGES]". This allows for clear identification of such changes in the 'What's Changed' section on the release page, making it developer-friendly.

View File

@@ -13,15 +13,17 @@ jobs:
- name: Setup Go - name: Setup Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version: '1.21' go-version: '1.24'
- name: Run vet - name: Run vet
run: | run: |
go vet . go vet .
- name: Run golangci-lint - name: Run golangci-lint
uses: golangci/golangci-lint-action@v4 uses: golangci/golangci-lint-action@v7
with: with:
version: latest version: v2.1.5
- name: Run tests - name: Run tests
run: go test -race -covermode=atomic -coverprofile=coverage.out -v . run: go test -race -covermode=atomic -coverprofile=coverage.out -v ./...
- name: Upload coverage reports to Codecov - name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4 uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -1,66 +1,94 @@
## Golden config for golangci-lint v1.47.3 version: "2"
# linters:
# This is the best config for golangci-lint based on my experience and opinion. default: none
# It is very strict, but not extremely strict. enable:
# Feel free to adopt and change it for your needs. - asciicheck
- bidichk
run: - bodyclose
# Timeout for analysis, e.g. 30s, 5m. - contextcheck
# Default: 1m - cyclop
timeout: 3m - dupl
- durationcheck
- errcheck
# This file contains only configs which differ from defaults. - errname
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml - errorlint
linters-settings: - exhaustive
- forbidigo
- funlen
- gochecknoinits
- gocognit
- goconst
- gocritic
- gocyclo
- godot
- gomoddirectives
- gomodguard
- goprintffuncname
- gosec
- govet
- ineffassign
- lll
- makezero
- mnd
- nestif
- nilerr
- nilnil
- nolintlint
- nosprintfhostport
- predeclared
- promlinter
- revive
- rowserrcheck
- sqlclosecheck
- staticcheck
- testpackage
- tparallel
- unconvert
- unparam
- unused
- usetesting
- wastedassign
- whitespace
settings:
cyclop: cyclop:
# The maximal code complexity to report.
# Default: 10
max-complexity: 30 max-complexity: 30
# The maximal average package complexity. package-average: 10
# If it's higher than 0.0 (float) the check is enabled
# Default: 0.0
package-average: 10.0
errcheck: errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: true check-type-assertions: true
funlen: funlen:
# Checks the number of lines in a function.
# If lower than 0, disable the check.
# Default: 60
lines: 100 lines: 100
# Checks the number of statements in a function.
# If lower than 0, disable the check.
# Default: 40
statements: 50 statements: 50
gocognit: gocognit:
# Minimal code complexity to report
# Default: 30 (but we recommend 10-20)
min-complexity: 20 min-complexity: 20
gocritic: gocritic:
# Settings passed to gocritic.
# The settings key is the name of a supported gocritic checker.
# The list of supported checkers can be find in https://go-critic.github.io/overview.
settings: settings:
captLocal: captLocal:
# Whether to restrict checker to params only.
# Default: true
paramsOnly: false paramsOnly: false
underef: underef:
# Whether to skip (*x).method() calls where x is a pointer receiver.
# Default: true
skipRecvDeref: false skipRecvDeref: false
gomodguard:
blocked:
modules:
- github.com/golang/protobuf:
recommendations:
- google.golang.org/protobuf
reason: see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules
- github.com/satori/go.uuid:
recommendations:
- github.com/google/uuid
reason: satori's package is not maintained
- github.com/gofrs/uuid:
recommendations:
- github.com/google/uuid
reason: 'see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw'
govet:
disable:
- fieldalignment
enable-all: true
settings:
shadow:
strict: true
mnd: mnd:
# List of function patterns to exclude from analysis.
# Values always ignored: `time.Date`
# Default: []
ignored-functions: ignored-functions:
- os.Chmod - os.Chmod
- os.Mkdir - os.Mkdir
@@ -76,194 +104,44 @@ linters-settings:
- strconv.ParseFloat - strconv.ParseFloat
- strconv.ParseInt - strconv.ParseInt
- strconv.ParseUint - strconv.ParseUint
gomodguard:
blocked:
# List of blocked modules.
# Default: []
modules:
- github.com/golang/protobuf:
recommendations:
- google.golang.org/protobuf
reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules"
- github.com/satori/go.uuid:
recommendations:
- github.com/google/uuid
reason: "satori's package is not maintained"
- github.com/gofrs/uuid:
recommendations:
- github.com/google/uuid
reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw"
govet:
# Enable all analyzers.
# Default: false
enable-all: true
# Disable analyzers by name.
# Run `go tool vet help` to see all analyzers.
# Default: []
disable:
- fieldalignment # too strict
# Settings per analyzer.
settings:
shadow:
# Whether to be strict about shadowing; can be noisy.
# Default: false
strict: true
nakedret: nakedret:
# Make an issue if func has more lines of code than this setting, and it has naked returns.
# Default: 30
max-func-lines: 0 max-func-lines: 0
nolintlint: nolintlint:
# Exclude following linters from requiring an explanation.
# Default: []
allow-no-explanation: [ funlen, gocognit, lll ]
# Enable to require an explanation of nonzero length after each nolint directive.
# Default: false
require-explanation: true require-explanation: true
# Enable to require nolint directives to mention the specific linter being suppressed.
# Default: false
require-specific: true require-specific: true
allow-no-explanation:
- funlen
- gocognit
- lll
rowserrcheck: rowserrcheck:
# database/sql is always checked
# Default: []
packages: packages:
- github.com/jmoiron/sqlx - github.com/jmoiron/sqlx
exclusions:
tenv: generated: lax
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures. presets:
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked. - comments
# Default: false - common-false-positives
all: true - legacy
- std-error-handling
varcheck: rules:
# Check usage of exported fields and variables. - linters:
# Default: false - forbidigo
exported-fields: false # default false # TODO: enable after fixing false positives - mnd
- revive
path : ^examples/.*\.go$
linters: - linters:
disable-all: true - lll
enable: source: ^//\s*go:generate\s
## enabled by default - linters:
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - godot
- gosimple # Linter for Go source code that specializes in simplifying a code source: (noinspection|TODO)
- govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - linters:
- ineffassign # Detects when assignments to existing variables are not used - gocritic
- staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks source: //noinspection
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - linters:
- unused # Checks Go code for unused constants, variables, functions and types - errorlint
## disabled by default source: ^\s+if _, ok := err\.\([^.]+\.InternalError\); ok {
# - asasalint # Check for pass []any as any in variadic func(...any) - linters:
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
- bidichk # Checks for dangerous unicode character sequences
- bodyclose # checks whether HTTP response body is closed successfully
- contextcheck # check the function whether use a non-inherited context
- cyclop # checks function and package cyclomatic complexity
- dupl # Tool for code clone detection
- durationcheck # check for two durations multiplied together
- errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error.
- errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13.
# Removed execinquery (deprecated). execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds
- exhaustive # check exhaustiveness of enum switch statements
- exportloopref # checks for pointers to enclosing loop variables
- forbidigo # Forbids identifiers
- funlen # Tool for detection of long functions
# - gochecknoglobals # check that no global variables exist
- gochecknoinits # Checks that no init functions are present in Go code
- gocognit # Computes and checks the cognitive complexity of functions
- goconst # Finds repeated strings that could be replaced by a constant
- gocritic # Provides diagnostics that check for bugs, performance and style issues.
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt.
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- goprintffuncname # Checks that printf-like functions are named with f at the end
- gosec # Inspects source code for security problems
- lll # Reports long lines
- makezero # Finds slice declarations with non-zero initial length
# - nakedret # Finds naked returns in functions greater than a specified function length
- mnd # An analyzer to detect magic numbers.
- nestif # Reports deeply nested if statements
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
- nilnil # Checks that there is no simultaneous return of nil error and an invalid value.
# - noctx # noctx finds sending http request without context.Context
- nolintlint # Reports ill-formed or insufficient nolint directives
# - nonamedreturns # Reports all named returns
- nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL.
- predeclared # find code that shadows one of Go's predeclared identifiers
- promlinter # Check Prometheus metrics naming via promlint
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
- rowserrcheck # checks whether Err of rows is checked successfully
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
- stylecheck # Stylecheck is a replacement for golint
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
- testpackage # linter that makes you use a separate _test package
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- unconvert # Remove unnecessary type conversions
- unparam # Reports unused function parameters
- usetesting # Reports uses of functions with replacement inside the testing package
- wastedassign # wastedassign finds wasted assignment statements.
- whitespace # Tool for detection of leading and trailing whitespace
## you may want to enable
#- decorder # check declaration order and count of types, constants, variables and functions
#- exhaustruct # Checks if all structure fields are initialized
#- goheader # Checks is file header matches to pattern
#- ireturn # Accept Interfaces, Return Concrete Types
#- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated
#- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope
#- wrapcheck # Checks that errors returned from external packages are wrapped
## disabled
#- containedctx # containedctx is a linter that detects struct contained context.Context field
#- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages
#- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
#- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted.
#- forcetypeassert # [replaced by errcheck] finds forced type assertions
#- gci # Gci controls golang package import order and makes it always deterministic.
#- godox # Tool for detection of FIXME, TODO and other comment keywords
#- goerr113 # [too strict] Golang linter to check the errors handling expressions
#- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
#- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed.
#- grouper # An analyzer to analyze expression groups.
#- ifshort # Checks that your code uses short syntax for if-statements whenever possible
#- importas # Enforces consistent import aliases
#- maintidx # maintidx measures the maintainability index of each function.
#- misspell # [useless] Finds commonly misspelled English words in comments
#- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity
#- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14
#- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test
#- tagliatelle # Checks the struct tags.
#- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
#- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines!
## deprecated
#- exhaustivestruct # [deprecated, replaced by exhaustruct] Checks if all struct's fields are initialized
#- golint # [deprecated, replaced by revive] Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes
#- interfacer # [deprecated] Linter that suggests narrower interface types
#- maligned # [deprecated, replaced by govet fieldalignment] Tool to detect Go structs that would take less memory if their fields were sorted
#- scopelint # [deprecated, replaced by exportloopref] Scopelint checks for unpinned variables in go programs
issues:
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 50
exclude-rules:
- source: "^//\\s*go:generate\\s"
linters: [ lll ]
- source: "(noinspection|TODO)"
linters: [ godot ]
- source: "//noinspection"
linters: [ gocritic ]
- source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {"
linters: [ errorlint ]
- path: "_test\\.go"
linters:
- bodyclose - bodyclose
- dupl - dupl
- funlen - funlen
@@ -271,3 +149,20 @@ issues:
- gosec - gosec
- noctx - noctx
- wrapcheck - wrapcheck
- staticcheck
path: _test\.go
paths:
- third_party$
- builtin$
- examples$
issues:
max-same-issues: 50
formatters:
enable:
- goimports
exclusions:
generated: lax
paths:
- third_party$
- builtin$
- examples$

View File

@@ -1,22 +1,22 @@
# Contributing Guidelines # Contributing Guidelines
## Overview ## Overview
Thank you for your interest in contributing to the "Go OpenAI" project! By following this guideline, we hope to ensure that your contributions are made smoothly and efficiently. The Go OpenAI project is licensed under the [Apache 2.0 License](https://github.com/sashabaranov/go-openai/blob/master/LICENSE), and we welcome contributions through GitHub pull requests. Thank you for your interest in contributing to the "Go OpenAI" project! By following this guideline, we hope to ensure that your contributions are made smoothly and efficiently. The Go OpenAI project is licensed under the [Apache 2.0 License](https://git.vaala.cloud/VaalaCat/go-openai/blob/master/LICENSE), and we welcome contributions through GitHub pull requests.
## Reporting Bugs ## Reporting Bugs
If you discover a bug, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to see if the issue has already been reported. If you're reporting a new issue, please use the "Bug report" template and provide detailed information about the problem, including steps to reproduce it. If you discover a bug, first check the [GitHub Issues page](https://git.vaala.cloud/VaalaCat/go-openai/issues) to see if the issue has already been reported. If you're reporting a new issue, please use the "Bug report" template and provide detailed information about the problem, including steps to reproduce it.
## Suggesting Features ## Suggesting Features
If you want to suggest a new feature or improvement, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to ensure a similar suggestion hasn't already been made. Use the "Feature request" template to provide a detailed description of your suggestion. If you want to suggest a new feature or improvement, first check the [GitHub Issues page](https://git.vaala.cloud/VaalaCat/go-openai/issues) to ensure a similar suggestion hasn't already been made. Use the "Feature request" template to provide a detailed description of your suggestion.
## Reporting Vulnerabilities ## Reporting Vulnerabilities
If you identify a security concern, please use the "Report a security vulnerability" template on the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to share the details. This report will only be viewable to repository maintainers. You will be credited if the advisory is published. If you identify a security concern, please use the "Report a security vulnerability" template on the [GitHub Issues page](https://git.vaala.cloud/VaalaCat/go-openai/issues) to share the details. This report will only be viewable to repository maintainers. You will be credited if the advisory is published.
## Questions for Users ## Questions for Users
If you have questions, please utilize [StackOverflow](https://stackoverflow.com/) or the [GitHub Discussions page](https://github.com/sashabaranov/go-openai/discussions). If you have questions, please utilize [StackOverflow](https://stackoverflow.com/) or the [GitHub Discussions page](https://git.vaala.cloud/VaalaCat/go-openai/discussions).
## Contributing Code ## Contributing Code
There might already be a similar pull requests submitted! Please search for [pull requests](https://github.com/sashabaranov/go-openai/pulls) before creating one. There might already be a similar pull requests submitted! Please search for [pull requests](https://git.vaala.cloud/VaalaCat/go-openai/pulls) before creating one.
### Requirements for Merging a Pull Request ### Requirements for Merging a Pull Request

README.md (106 changes)
View File

@@ -1,19 +1,19 @@
# Go OpenAI # Go OpenAI
[![Go Reference](https://pkg.go.dev/badge/github.com/sashabaranov/go-openai.svg)](https://pkg.go.dev/github.com/sashabaranov/go-openai) [![Go Reference](https://pkg.go.dev/badge/git.vaala.cloud/VaalaCat/go-openai.svg)](https://pkg.go.dev/git.vaala.cloud/VaalaCat/go-openai)
[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-openai)](https://goreportcard.com/report/github.com/sashabaranov/go-openai) [![Go Report Card](https://goreportcard.com/badge/git.vaala.cloud/VaalaCat/go-openai)](https://goreportcard.com/report/git.vaala.cloud/VaalaCat/go-openai)
[![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai) [![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai)
This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). We support: This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). We support:
* ChatGPT 4o, o1 * ChatGPT 4o, o1
* GPT-3, GPT-4 * GPT-3, GPT-4
* DALL·E 2, DALL·E 3 * DALL·E 2, DALL·E 3, GPT Image 1
* Whisper * Whisper
## Installation ## Installation
``` ```
go get github.com/sashabaranov/go-openai go get git.vaala.cloud/VaalaCat/go-openai
``` ```
Currently, go-openai requires Go version 1.18 or greater. Currently, go-openai requires Go version 1.18 or greater.
@@ -28,7 +28,7 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
openai "github.com/sashabaranov/go-openai" openai "git.vaala.cloud/VaalaCat/go-openai"
) )
func main() { func main() {
@@ -80,7 +80,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
openai "github.com/sashabaranov/go-openai" openai "git.vaala.cloud/VaalaCat/go-openai"
) )
func main() { func main() {
@@ -133,7 +133,7 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
openai "github.com/sashabaranov/go-openai" openai "git.vaala.cloud/VaalaCat/go-openai"
) )
func main() { func main() {
@@ -166,7 +166,7 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
openai "github.com/sashabaranov/go-openai" openai "git.vaala.cloud/VaalaCat/go-openai"
) )
func main() { func main() {
@@ -215,7 +215,7 @@ import (
"context" "context"
"fmt" "fmt"
openai "github.com/sashabaranov/go-openai" openai "git.vaala.cloud/VaalaCat/go-openai"
) )
func main() { func main() {
@@ -247,7 +247,7 @@ import (
"fmt" "fmt"
"os" "os"
openai "github.com/sashabaranov/go-openai" openai "git.vaala.cloud/VaalaCat/go-openai"
) )
func main() { func main() {
@@ -288,7 +288,7 @@ import (
"context" "context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
openai "github.com/sashabaranov/go-openai" openai "git.vaala.cloud/VaalaCat/go-openai"
"image/png" "image/png"
"os" "os"
) )
@@ -357,6 +357,66 @@ func main() {
``` ```
</details> </details>
<details>
<summary>GPT Image 1 image generation</summary>
```go
package main
import (
"context"
"encoding/base64"
"fmt"
"os"
openai "github.com/sashabaranov/go-openai"
)
func main() {
c := openai.NewClient("your token")
ctx := context.Background()
req := openai.ImageRequest{
Prompt: "Parrot on a skateboard performing a trick. Large bold text \"SKATE MASTER\" banner at the bottom of the image. Cartoon style, natural light, high detail, 1:1 aspect ratio.",
Background: openai.CreateImageBackgroundOpaque,
Model: openai.CreateImageModelGptImage1,
Size: openai.CreateImageSize1024x1024,
N: 1,
Quality: openai.CreateImageQualityLow,
OutputCompression: 100,
OutputFormat: openai.CreateImageOutputFormatJPEG,
// Moderation: openai.CreateImageModerationLow,
// User: "",
}
resp, err := c.CreateImage(ctx, req)
if err != nil {
fmt.Printf("Image creation Image generation with GPT Image 1error: %v\n", err)
return
}
fmt.Println("Image Base64:", resp.Data[0].B64JSON)
// Decode the base64 data
imgBytes, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON)
if err != nil {
fmt.Printf("Base64 decode error: %v\n", err)
return
}
// Write image to file
outputPath := "generated_image.jpg"
err = os.WriteFile(outputPath, imgBytes, 0644)
if err != nil {
fmt.Printf("Failed to write image file: %v\n", err)
return
}
fmt.Printf("The image was saved as %s\n", outputPath)
}
```
</details>
<details> <details>
<summary>Configuring proxy</summary> <summary>Configuring proxy</summary>
@@ -376,7 +436,7 @@ config.HTTPClient = &http.Client{
c := openai.NewClientWithConfig(config) c := openai.NewClientWithConfig(config)
``` ```
See also: https://pkg.go.dev/github.com/sashabaranov/go-openai#ClientConfig See also: https://pkg.go.dev/git.vaala.cloud/VaalaCat/go-openai#ClientConfig
</details> </details>
<details> <details>
@@ -392,7 +452,7 @@ import (
"os" "os"
"strings" "strings"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
) )
func main() { func main() {
@@ -446,7 +506,7 @@ import (
"context" "context"
"fmt" "fmt"
openai "github.com/sashabaranov/go-openai" openai "git.vaala.cloud/VaalaCat/go-openai"
) )
func main() { func main() {
@@ -492,7 +552,7 @@ package main
import ( import (
"context" "context"
"log" "log"
openai "github.com/sashabaranov/go-openai" openai "git.vaala.cloud/VaalaCat/go-openai"
) )
@@ -549,7 +609,7 @@ import (
"context" "context"
"fmt" "fmt"
openai "github.com/sashabaranov/go-openai" openai "git.vaala.cloud/VaalaCat/go-openai"
) )
func main() { func main() {
@@ -680,7 +740,7 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
) )
func main() { func main() {
@@ -755,8 +815,8 @@ import (
"fmt" "fmt"
"log" "log"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/jsonschema" "git.vaala.cloud/VaalaCat/go-openai/jsonschema"
) )
func main() { func main() {
@@ -828,7 +888,7 @@ Due to the factors mentioned above, different answers may be returned even for t
By adopting these strategies, you can expect more consistent results. By adopting these strategies, you can expect more consistent results.
**Related Issues:** **Related Issues:**
[omitempty option of request struct will generate incorrect request when parameter is 0.](https://github.com/sashabaranov/go-openai/issues/9) [omitempty option of request struct will generate incorrect request when parameter is 0.](https://git.vaala.cloud/VaalaCat/go-openai/issues/9)
### Does Go OpenAI provide a method to count tokens? ### Does Go OpenAI provide a method to count tokens?
@@ -839,15 +899,15 @@ For counting tokens, you might find the following links helpful:
- [How to count tokens with tiktoken](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) - [How to count tokens with tiktoken](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb)
**Related Issues:** **Related Issues:**
[Is it possible to join the implementation of GPT3 Tokenizer](https://github.com/sashabaranov/go-openai/issues/62) [Is it possible to join the implementation of GPT3 Tokenizer](https://git.vaala.cloud/VaalaCat/go-openai/issues/62)
## Contributing ## Contributing
By following [Contributing Guidelines](https://github.com/sashabaranov/go-openai/blob/master/CONTRIBUTING.md), we hope to ensure that your contributions are made smoothly and efficiently. By following [Contributing Guidelines](https://git.vaala.cloud/VaalaCat/go-openai/blob/master/CONTRIBUTING.md), we hope to ensure that your contributions are made smoothly and efficiently.
## Thank you ## Thank you
We want to take a moment to express our deepest gratitude to the [contributors](https://github.com/sashabaranov/go-openai/graphs/contributors) and sponsors of this project: We want to take a moment to express our deepest gratitude to the [contributors](https://git.vaala.cloud/VaalaCat/go-openai/graphs/contributors) and sponsors of this project:
- [Carson Kahn](https://carsonkahn.com) of [Spindle AI](https://spindleai.com) - [Carson Kahn](https://carsonkahn.com) of [Spindle AI](https://spindleai.com)
To all of you: thank you. You've helped us achieve more than we ever imagined possible. Can't wait to see where we go next, together! To all of you: thank you. You've helped us achieve more than we ever imagined possible. Can't wait to see where we go next, together!

View File

@@ -10,9 +10,9 @@ import (
"os" "os"
"testing" "testing"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
"github.com/sashabaranov/go-openai/jsonschema" "git.vaala.cloud/VaalaCat/go-openai/jsonschema"
) )
func TestAPI(t *testing.T) { func TestAPI(t *testing.T) {

View File

@@ -3,8 +3,8 @@ package openai_test
import ( import (
"context" "context"
openai "github.com/sashabaranov/go-openai" openai "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
"encoding/json" "encoding/json"
"fmt" "fmt"

View File

@@ -8,7 +8,7 @@ import (
"net/http" "net/http"
"os" "os"
utils "github.com/sashabaranov/go-openai/internal" utils "git.vaala.cloud/VaalaCat/go-openai/internal"
) )
// Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI. // Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI.

View File

@@ -12,9 +12,9 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test" "git.vaala.cloud/VaalaCat/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server. // TestAudio Tests the transcription and translation endpoints of the API using the mocked server.

View File

@@ -2,14 +2,17 @@ package openai //nolint:testpackage // testing private field
import ( import (
"bytes" "bytes"
"context"
"errors"
"fmt" "fmt"
"io" "io"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/sashabaranov/go-openai/internal/test" "git.vaala.cloud/VaalaCat/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
func TestAudioWithFailingFormBuilder(t *testing.T) { func TestAudioWithFailingFormBuilder(t *testing.T) {
@@ -107,3 +110,131 @@ func TestCreateFileField(t *testing.T) {
checks.HasError(t, err, "createFileField using file should return error when open file fails") checks.HasError(t, err, "createFileField using file should return error when open file fails")
}) })
} }
// failingFormBuilder always returns an error when creating form files.
type failingFormBuilder struct{ err error }
func (f *failingFormBuilder) CreateFormFile(_ string, _ *os.File) error {
return f.err
}
func (f *failingFormBuilder) CreateFormFileReader(_ string, _ io.Reader, _ string) error {
return f.err
}
func (f *failingFormBuilder) WriteField(_, _ string) error {
return nil
}
func (f *failingFormBuilder) Close() error {
return nil
}
func (f *failingFormBuilder) FormDataContentType() string {
return "multipart/form-data"
}
// failingAudioRequestBuilder simulates an error during HTTP request construction.
type failingAudioRequestBuilder struct{ err error }
func (f *failingAudioRequestBuilder) Build(
_ context.Context,
_, _ string,
_ any,
_ http.Header,
) (*http.Request, error) {
return nil, f.err
}
// errorHTTPClient always returns an error when making HTTP calls.
type errorHTTPClient struct{ err error }
func (e *errorHTTPClient) Do(_ *http.Request) (*http.Response, error) {
return nil, e.err
}
func TestCallAudioAPIMultipartFormError(t *testing.T) {
client := NewClient("test-token")
errForm := errors.New("mock create form file failure")
// Override form builder to force an error during multipart form creation.
client.createFormBuilder = func(_ io.Writer) utils.FormBuilder {
return &failingFormBuilder{err: errForm}
}
// Provide a reader so createFileField uses the reader path (no file open).
req := AudioRequest{FilePath: "fake.mp3", Reader: bytes.NewBuffer([]byte("dummy")), Model: Whisper1}
_, err := client.callAudioAPI(context.Background(), req, "transcriptions")
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, errForm) {
t.Errorf("expected error %v, got %v", errForm, err)
}
}
func TestCallAudioAPINewRequestError(t *testing.T) {
client := NewClient("test-token")
// Create a real temp file so multipart form succeeds.
tmp := t.TempDir()
path := filepath.Join(tmp, "file.mp3")
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
errBuild := errors.New("mock build failure")
client.requestBuilder = &failingAudioRequestBuilder{err: errBuild}
req := AudioRequest{FilePath: path, Model: Whisper1}
_, err := client.callAudioAPI(context.Background(), req, "translations")
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, errBuild) {
t.Errorf("expected error %v, got %v", errBuild, err)
}
}
func TestCallAudioAPISendRequestErrorJSON(t *testing.T) {
client := NewClient("test-token")
// Create a real temp file so multipart form succeeds.
tmp := t.TempDir()
path := filepath.Join(tmp, "file.mp3")
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
errHTTP := errors.New("mock HTTPClient failure")
// Override HTTP client to simulate a network error.
client.config.HTTPClient = &errorHTTPClient{err: errHTTP}
req := AudioRequest{FilePath: path, Model: Whisper1}
_, err := client.callAudioAPI(context.Background(), req, "transcriptions")
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, errHTTP) {
t.Errorf("expected error %v, got %v", errHTTP, err)
}
}
func TestCallAudioAPISendRequestErrorText(t *testing.T) {
client := NewClient("test-token")
tmp := t.TempDir()
path := filepath.Join(tmp, "file.mp3")
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
errHTTP := errors.New("mock HTTPClient failure")
client.config.HTTPClient = &errorHTTPClient{err: errHTTP}
// Use a non-JSON response format to exercise the text path.
req := AudioRequest{FilePath: path, Model: Whisper1, Format: AudioResponseFormatText}
_, err := client.callAudioAPI(context.Background(), req, "translations")
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, errHTTP) {
t.Errorf("expected error %v, got %v", errHTTP, err)
}
}

View File

@@ -7,8 +7,8 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
func TestUploadBatchFile(t *testing.T) { func TestUploadBatchFile(t *testing.T) {

chat.go (27 changes)
View File

@@ -14,6 +14,7 @@ const (
ChatMessageRoleAssistant = "assistant" ChatMessageRoleAssistant = "assistant"
ChatMessageRoleFunction = "function" ChatMessageRoleFunction = "function"
ChatMessageRoleTool = "tool" ChatMessageRoleTool = "tool"
ChatMessageRoleDeveloper = "developer"
) )
const chatCompletionsSuffix = "/chat/completions" const chatCompletionsSuffix = "/chat/completions"
@@ -103,6 +104,12 @@ type ChatCompletionMessage struct {
// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb // - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
// This property is used for the "reasoning" feature supported by deepseek-reasoner
// which is not in the official documentation.
// the doc from deepseek:
// - https://api-docs.deepseek.com/api/create-chat-completion#responses
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"`
// For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls. // For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls.
@@ -123,6 +130,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
Refusal string `json:"refusal,omitempty"` Refusal string `json:"refusal,omitempty"`
MultiContent []ChatMessagePart `json:"content,omitempty"` MultiContent []ChatMessagePart `json:"content,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"`
@@ -136,6 +144,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
Refusal string `json:"refusal,omitempty"` Refusal string `json:"refusal,omitempty"`
MultiContent []ChatMessagePart `json:"-"` MultiContent []ChatMessagePart `json:"-"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"`
@@ -146,10 +155,11 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
msg := struct { msg := struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content,omitempty"` Content string `json:"content"`
Refusal string `json:"refusal,omitempty"` Refusal string `json:"refusal,omitempty"`
MultiContent []ChatMessagePart MultiContent []ChatMessagePart
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"`
@@ -165,6 +175,7 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
Refusal string `json:"refusal,omitempty"` Refusal string `json:"refusal,omitempty"`
MultiContent []ChatMessagePart `json:"content"` MultiContent []ChatMessagePart `json:"content"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"`
@@ -262,6 +273,15 @@ type ChatCompletionRequest struct {
ReasoningEffort string `json:"reasoning_effort,omitempty"` ReasoningEffort string `json:"reasoning_effort,omitempty"`
// Metadata to store with the completion. // Metadata to store with the completion.
Metadata map[string]string `json:"metadata,omitempty"` Metadata map[string]string `json:"metadata,omitempty"`
IncludeReasoning *bool `json:"include_reasoning,omitempty"`
ReasoningFormat *string `json:"reasoning_format,omitempty"`
// Configuration for a predicted output.
Prediction *Prediction `json:"prediction,omitempty"`
// ChatTemplateKwargs provides a way to add non-standard parameters to the request body.
// Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
// Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false}
// https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes
ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"`
} }
type StreamOptions struct { type StreamOptions struct {
@@ -329,6 +349,11 @@ type LogProbs struct {
Content []LogProb `json:"content"` Content []LogProb `json:"content"`
} }
type Prediction struct {
Content string `json:"content"`
Type string `json:"type"`
}
type FinishReason string type FinishReason string
const ( const (
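
For orientation, here is a minimal sketch (not part of the diff) of how the new `ChatCompletionRequest` fields introduced above could be combined in one request. Only the field names come from the diff; the `deepseek-reasoner` model string and the `"parsed"` reasoning-format value are illustrative assumptions that depend on the backend.

```go
package main

import (
	"context"
	"fmt"

	openai "git.vaala.cloud/VaalaCat/go-openai"
)

func main() {
	client := openai.NewClient("your token")

	includeReasoning := true
	reasoningFormat := "parsed" // assumed value; depends on the backend

	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: "deepseek-reasoner", // illustrative reasoning-capable model
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleDeveloper, Content: "Answer concisely."},
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
		// Reasoning controls added by this fork.
		IncludeReasoning: &includeReasoning,
		ReasoningFormat:  &reasoningFormat,
		// Non-standard template parameters, e.g. Qwen3 thinking mode.
		ChatTemplateKwargs: map[string]any{"enable_thinking": false},
		// Predicted output, counted via accepted/rejected prediction tokens.
		Prediction: &openai.Prediction{Type: "content", Content: "Hello! How can I help you?"},
	})
	if err != nil {
		fmt.Printf("ChatCompletion error: %v\n", err)
		return
	}

	// reasoning_content is populated by reasoning backends such as deepseek-reasoner.
	fmt.Println("reasoning:", resp.Choices[0].Message.ReasoningContent)
	fmt.Println("answer:", resp.Choices[0].Message.Content)
}
```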

View File

@@ -11,6 +11,12 @@ type ChatCompletionStreamChoiceDelta struct {
FunctionCall *FunctionCall `json:"function_call,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Refusal string `json:"refusal,omitempty"` Refusal string `json:"refusal,omitempty"`
// This property is used for the "reasoning" feature supported by deepseek-reasoner
// which is not in the official documentation.
// the doc from deepseek:
// - https://api-docs.deepseek.com/api/create-chat-completion#responses
ReasoningContent string `json:"reasoning_content,omitempty"`
} }
type ChatCompletionStreamChoiceLogprobs struct { type ChatCompletionStreamChoiceLogprobs struct {
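
A brief sketch (assumed usage, not taken from the diff) of consuming the streamed `reasoning_content` delta added above alongside the normal content delta; the model string is illustrative.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"io"

	openai "git.vaala.cloud/VaalaCat/go-openai"
)

func main() {
	client := openai.NewClient("your token")

	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
		Model: "deepseek-reasoner", // illustrative reasoning-capable model
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
		Stream: true,
	})
	if err != nil {
		fmt.Printf("stream error: %v\n", err)
		return
	}
	defer stream.Close()

	for {
		resp, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			return // stream finished
		}
		if err != nil {
			fmt.Printf("recv error: %v\n", err)
			return
		}
		if len(resp.Choices) == 0 {
			continue
		}
		delta := resp.Choices[0].Delta
		if delta.ReasoningContent != "" {
			fmt.Print(delta.ReasoningContent) // reasoning tokens, if provided
		}
		fmt.Print(delta.Content) // answer tokens
	}
}
```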

View File

@@ -10,8 +10,8 @@ import (
"strconv" "strconv"
"testing" "testing"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
func TestChatCompletionsStreamWrongModel(t *testing.T) { func TestChatCompletionsStreamWrongModel(t *testing.T) {
@@ -959,6 +959,56 @@ func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) {
} }
} }
func TestCreateChatCompletionStreamO3ReasoningValidatorFails(t *testing.T) {
client, _, _ := setupOpenAITestServer()
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 100, // This will trigger the validator to fail
Model: openai.O3,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
if stream != nil {
t.Error("Expected nil stream when validation fails")
stream.Close()
}
if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O3, got: %v", err)
}
}
func TestCreateChatCompletionStreamO4MiniReasoningValidatorFails(t *testing.T) {
client, _, _ := setupOpenAITestServer()
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 100, // This will trigger the validator to fail
Model: openai.O4Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
if stream != nil {
t.Error("Expected nil stream when validation fails")
stream.Close()
}
if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O4Mini, got: %v", err)
}
}
func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool { func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool {
if c1.Index != c2.Index { if c1.Index != c2.Index {
return false return false

View File

@@ -12,9 +12,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
"github.com/sashabaranov/go-openai/jsonschema" "git.vaala.cloud/VaalaCat/go-openai/jsonschema"
) )
const ( const (
@@ -106,40 +106,6 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
}, },
expectedError: openai.ErrReasoningModelLimitationsLogprobs, expectedError: openai.ErrReasoningModelLimitationsLogprobs,
}, },
{
name: "message_type_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
},
},
},
expectedError: openai.ErrO1BetaLimitationsMessageTypes,
},
{
name: "tool_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Tools: []openai.Tool{
{
Type: openai.ToolTypeFunction,
},
},
},
expectedError: openai.ErrO1BetaLimitationsTools,
},
{ {
name: "set_temperature_unsupported", name: "set_temperature_unsupported",
in: openai.ChatCompletionRequest{ in: openai.ChatCompletionRequest{
@@ -445,6 +411,23 @@ func TestO3ModelChatCompletions(t *testing.T) {
checks.NoError(t, err, "CreateChatCompletion error") checks.NoError(t, err, "CreateChatCompletion error")
} }
func TestDeepseekR1ModelChatCompletions(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", handleDeepseekR1ChatCompletionEndpoint)
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
Model: "deepseek-reasoner",
MaxCompletionTokens: 100,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
})
checks.NoError(t, err, "CreateChatCompletion error")
}
// TestCompletions Tests the completions endpoint of the API using the mocked server. // TestCompletions Tests the completions endpoint of the API using the mocked server.
func TestChatCompletionsWithHeaders(t *testing.T) { func TestChatCompletionsWithHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
@@ -856,6 +839,68 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
} }
func handleDeepseekR1ChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte
// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var completionReq openai.ChatCompletionRequest
if completionReq, err = getChatCompletionBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
res := openai.ChatCompletionResponse{
ID: strconv.Itoa(int(time.Now().Unix())),
Object: "test-object",
Created: time.Now().Unix(),
// would be nice to validate Model during testing, but
// this may not be possible with how much upkeep
// would be required / wouldn't make much sense
Model: completionReq.Model,
}
// create completions
n := completionReq.N
if n == 0 {
n = 1
}
if completionReq.MaxCompletionTokens == 0 {
completionReq.MaxCompletionTokens = 1000
}
for i := 0; i < n; i++ {
reasoningContent := "User says hello! And I need to reply"
completionStr := strings.Repeat("a", completionReq.MaxCompletionTokens-numTokens(reasoningContent))
res.Choices = append(res.Choices, openai.ChatCompletionChoice{
Message: openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
ReasoningContent: reasoningContent,
Content: completionStr,
},
Index: i,
})
}
inputTokens := numTokens(completionReq.Messages[0].Content) * n
completionTokens := completionReq.MaxTokens * n
res.Usage = openai.Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens,
}
resBytes, _ = json.Marshal(res)
w.Header().Set(xCustomHeader, xCustomHeaderValue)
for k, v := range rateLimitHeaders {
switch val := v.(type) {
case int:
w.Header().Set(k, strconv.Itoa(val))
default:
w.Header().Set(k, fmt.Sprintf("%s", v))
}
}
fmt.Fprintln(w, string(resBytes))
}
// getChatCompletionBody Returns the body of the request to create a completion. // getChatCompletionBody Returns the body of the request to create a completion.
func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) { func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) {
completion := openai.ChatCompletionRequest{} completion := openai.ChatCompletionRequest{}

View File

@@ -10,7 +10,7 @@ import (
 	"net/url"
 	"strings"

-	utils "github.com/sashabaranov/go-openai/internal"
+	utils "git.vaala.cloud/VaalaCat/go-openai/internal"
 )

 // Client is OpenAI GPT-3 API client.
@@ -182,13 +182,21 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
 func (c *Client) setCommonHeaders(req *http.Request) {
 	// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
-	// Azure API Key authentication
-	if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure {
-		req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
-	} else if c.config.authToken != "" {
-		// OpenAI or Azure AD authentication
-		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
+	switch c.config.APIType {
+	case APITypeAzure, APITypeCloudflareAzure:
+		// Azure API Key authentication
+		req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
+	case APITypeAnthropic:
+		// https://docs.anthropic.com/en/api/versioning
+		req.Header.Set("anthropic-version", c.config.APIVersion)
+	case APITypeOpenAI, APITypeAzureAD:
+		fallthrough
+	default:
+		if c.config.authToken != "" {
+			req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
+		}
 	}
 	if c.config.OrgID != "" {
 		req.Header.Set("OpenAI-Organization", c.config.OrgID)
 	}

View File

@@ -10,8 +10,8 @@ import (
 	"reflect"
 	"testing"

-	"github.com/sashabaranov/go-openai/internal/test"
-	"github.com/sashabaranov/go-openai/internal/test/checks"
+	"git.vaala.cloud/VaalaCat/go-openai/internal/test"
+	"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
 )

 var errTestRequestBuilderFailed = errors.New("test request builder failed")
@@ -39,6 +39,21 @@ func TestClient(t *testing.T) {
 	}
 }
func TestSetCommonHeadersAnthropic(t *testing.T) {
config := DefaultAnthropicConfig("mock-token", "")
client := NewClientWithConfig(config)
req, err := http.NewRequest("GET", "http://example.com", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
client.setCommonHeaders(req)
if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion {
t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got)
}
}
func TestDecodeResponse(t *testing.T) {
stringInput := ""

View File

@@ -15,6 +15,8 @@ type Usage struct {
 type CompletionTokensDetails struct {
 	AudioTokens              int `json:"audio_tokens"`
 	ReasoningTokens          int `json:"reasoning_tokens"`
+	AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
+	RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
 }

 // PromptTokensDetails Breakdown of tokens used in the prompt.

View File

@@ -16,8 +16,12 @@ const (
 	O1Preview20240912 = "o1-preview-2024-09-12"
 	O1                = "o1"
 	O120241217        = "o1-2024-12-17"
+	O3                = "o3"
+	O320250416        = "o3-2025-04-16"
 	O3Mini            = "o3-mini"
 	O3Mini20250131    = "o3-mini-2025-01-31"
+	O4Mini            = "o4-mini"
+	O4Mini20250416    = "o4-mini-2025-04-16"
 	GPT432K0613       = "gpt-4-32k-0613"
 	GPT432K0314       = "gpt-4-32k-0314"
 	GPT432K           = "gpt-4-32k"
@@ -37,6 +41,14 @@ const (
 	GPT4TurboPreview        = "gpt-4-turbo-preview"
 	GPT4VisionPreview       = "gpt-4-vision-preview"
 	GPT4                    = "gpt-4"
+	GPT4Dot1                = "gpt-4.1"
+	GPT4Dot120250414        = "gpt-4.1-2025-04-14"
+	GPT4Dot1Mini            = "gpt-4.1-mini"
+	GPT4Dot1Mini20250414    = "gpt-4.1-mini-2025-04-14"
+	GPT4Dot1Nano            = "gpt-4.1-nano"
+	GPT4Dot1Nano20250414    = "gpt-4.1-nano-2025-04-14"
+	GPT4Dot5Preview         = "gpt-4.5-preview"
+	GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27"
 	GPT3Dot5Turbo0125       = "gpt-3.5-turbo-0125"
 	GPT3Dot5Turbo1106       = "gpt-3.5-turbo-1106"
 	GPT3Dot5Turbo0613       = "gpt-3.5-turbo-0613"
@@ -91,6 +103,10 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
 		O1Preview20240912: true,
 		O3Mini:            true,
 		O3Mini20250131:    true,
+		O4Mini:            true,
+		O4Mini20250416:    true,
+		O3:                true,
+		O320250416:        true,
 		GPT3Dot5Turbo:     true,
 		GPT3Dot5Turbo0301: true,
 		GPT3Dot5Turbo0613: true,
@@ -99,6 +115,8 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
 		GPT3Dot5Turbo16K:        true,
 		GPT3Dot5Turbo16K0613:    true,
 		GPT4:                    true,
+		GPT4Dot5Preview:         true,
+		GPT4Dot5Preview20250227: true,
 		GPT4o:                   true,
 		GPT4o20240513:           true,
 		GPT4o20240806:           true,
@@ -117,6 +135,13 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
 		GPT432K:              true,
 		GPT432K0314:          true,
 		GPT432K0613:          true,
+		O1:                   true,
+		GPT4Dot1:             true,
+		GPT4Dot120250414:     true,
+		GPT4Dot1Mini:         true,
+		GPT4Dot1Mini20250414: true,
+		GPT4Dot1Nano:         true,
+		GPT4Dot1Nano20250414: true,
 	},
 	chatCompletionsSuffix: {
 		CodexCodeDavinci002: true,
@@ -190,6 +215,8 @@ type CompletionRequest struct {
 	Temperature float32 `json:"temperature,omitempty"`
 	TopP        float32 `json:"top_p,omitempty"`
 	User        string  `json:"user,omitempty"`
+	// Options for streaming response. Only set this when you set stream: true.
+	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
 }

 // CompletionChoice represents one of possible completions.
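Since the legacy completions endpoint now accepts stream_options, a minimal sketch of a streaming request that asks for a final usage chunk; the model choice and the renamed module path are illustrative assumptions:

package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"

	openai "git.vaala.cloud/VaalaCat/go-openai"
)

func main() {
	client := openai.NewClient("your-api-key")

	stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{
		Model:     openai.GPT3Dot5TurboInstruct,
		Prompt:    "Say hello",
		MaxTokens: 16,
		Stream:    true,
		// New field: ask the server to append a final chunk carrying token usage.
		StreamOptions: &openai.StreamOptions{IncludeUsage: true},
	})
	if err != nil {
		log.Fatal(err)
	}
	defer stream.Close()

	for {
		chunk, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			log.Fatal(err)
		}
		// The usage-only final chunk has no choices, so guard the index.
		if len(chunk.Choices) > 0 {
			fmt.Print(chunk.Choices[0].Text)
		}
	}
}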

View File

@@ -12,8 +12,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
func TestCompletionsWrongModel(t *testing.T) { func TestCompletionsWrongModel(t *testing.T) {
@@ -33,6 +33,42 @@ func TestCompletionsWrongModel(t *testing.T) {
} }
} }
// TestCompletionsWrongModelO3 Tests the completions endpoint with O3 model which is not supported.
func TestCompletionsWrongModelO3(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: openai.O3,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O3, but returned: %v", err)
}
}
// TestCompletionsWrongModelO4Mini Tests the completions endpoint with O4Mini model which is not supported.
func TestCompletionsWrongModelO4Mini(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: openai.O4Mini,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O4Mini, but returned: %v", err)
}
}
func TestCompletionWithStream(t *testing.T) { func TestCompletionWithStream(t *testing.T) {
config := openai.DefaultConfig("whatever") config := openai.DefaultConfig("whatever")
client := openai.NewClientWithConfig(config) client := openai.NewClientWithConfig(config)
@@ -181,3 +217,86 @@ func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) {
} }
return completion, nil return completion, nil
} }
// TestCompletionWithO1Model Tests that O1 model is not supported for completion endpoint.
func TestCompletionWithO1Model(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: openai.O1,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O1 model, but returned: %v", err)
}
}
// TestCompletionWithGPT4DotModels Tests that newer GPT4 models are not supported for completion endpoint.
func TestCompletionWithGPT4DotModels(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
models := []string{
openai.GPT4Dot1,
openai.GPT4Dot120250414,
openai.GPT4Dot1Mini,
openai.GPT4Dot1Mini20250414,
openai.GPT4Dot1Nano,
openai.GPT4Dot1Nano20250414,
openai.GPT4Dot5Preview,
openai.GPT4Dot5Preview20250227,
}
for _, model := range models {
t.Run(model, func(t *testing.T) {
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: model,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err)
}
})
}
}
// TestCompletionWithGPT4oModels Tests that GPT4o models are not supported for completion endpoint.
func TestCompletionWithGPT4oModels(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
models := []string{
openai.GPT4o,
openai.GPT4o20240513,
openai.GPT4o20240806,
openai.GPT4o20241120,
openai.GPT4oLatest,
openai.GPT4oMini,
openai.GPT4oMini20240718,
}
for _, model := range models {
t.Run(model, func(t *testing.T) {
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: model,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err)
}
})
}
}

View File

@@ -11,6 +11,8 @@ const (
 	azureAPIPrefix         = "openai"
 	azureDeploymentsPrefix = "deployments"
+
+	AnthropicAPIVersion = "2023-06-01"
 )

 type APIType string
@@ -20,6 +22,7 @@ const (
 	APITypeAzure           APIType = "AZURE"
 	APITypeAzureAD         APIType = "AZURE_AD"
 	APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE"
+	APITypeAnthropic       APIType = "ANTHROPIC"
 )

 const AzureAPIKeyHeader = "api-key"
@@ -37,7 +40,7 @@ type ClientConfig struct {
 	BaseURL    string
 	OrgID      string
 	APIType    APIType
-	APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
+	APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic
 	AssistantVersion     string
 	AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
 	HTTPClient           HTTPDoer
@@ -76,6 +79,23 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
 	}
 }

+func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig {
+	if baseURL == "" {
+		baseURL = "https://api.anthropic.com/v1"
+	}
+	return ClientConfig{
+		authToken:  apiKey,
+		BaseURL:    baseURL,
+		OrgID:      "",
+		APIType:    APITypeAnthropic,
+		APIVersion: AnthropicAPIVersion,
+
+		HTTPClient: &http.Client{},
+
+		EmptyMessagesLimit: defaultEmptyMessagesLimit,
+	}
+}
+
 func (ClientConfig) String() string {
 	return "<OpenAI API ClientConfig>"
 }
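A minimal sketch of wiring the new Anthropic config into a client. It only demonstrates that the anthropic-version header is populated automatically by setCommonHeaders; whether a given Anthropic endpoint accepts OpenAI-shaped requests is outside this diff:

package main

import (
	"fmt"

	openai "git.vaala.cloud/VaalaCat/go-openai"
)

func main() {
	// An empty baseURL falls back to https://api.anthropic.com/v1.
	cfg := openai.DefaultAnthropicConfig("your-anthropic-api-key", "")
	client := openai.NewClientWithConfig(cfg)

	// Requests from this client send "anthropic-version: 2023-06-01"
	// instead of a Bearer Authorization header.
	fmt.Println(cfg.APIType, cfg.APIVersion, cfg.BaseURL)
	_ = client // use with an Anthropic-compatible endpoint
}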

View File

@@ -3,7 +3,7 @@ package openai_test
import ( import (
"testing" "testing"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
) )
func TestGetAzureDeploymentByModel(t *testing.T) { func TestGetAzureDeploymentByModel(t *testing.T) {
@@ -60,3 +60,64 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
}) })
} }
} }
func TestDefaultAnthropicConfig(t *testing.T) {
apiKey := "test-key"
baseURL := "https://api.anthropic.com/v1"
config := openai.DefaultAnthropicConfig(apiKey, baseURL)
if config.APIType != openai.APITypeAnthropic {
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
}
if config.APIVersion != openai.AnthropicAPIVersion {
t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion)
}
if config.BaseURL != baseURL {
t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL)
}
if config.EmptyMessagesLimit != 300 {
t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit)
}
}
func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) {
config := openai.DefaultAnthropicConfig("", "")
if config.APIType != openai.APITypeAnthropic {
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
}
if config.APIVersion != openai.AnthropicAPIVersion {
t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion)
}
expectedBaseURL := "https://api.anthropic.com/v1"
if config.BaseURL != expectedBaseURL {
t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL)
}
}
func TestClientConfigString(t *testing.T) {
// String() should always return the constant value
conf := openai.DefaultConfig("dummy-token")
expected := "<OpenAI API ClientConfig>"
got := conf.String()
if got != expected {
t.Errorf("ClientConfig.String() = %q; want %q", got, expected)
}
}
func TestGetAzureDeploymentByModel_NoMapper(t *testing.T) {
// On a zero-value or DefaultConfig, AzureModelMapperFunc is nil,
// so GetAzureDeploymentByModel should just return the input model.
conf := openai.DefaultConfig("dummy-token")
model := "some-model"
got := conf.GetAzureDeploymentByModel(model)
if got != model {
t.Errorf("GetAzureDeploymentByModel(%q) = %q; want %q", model, got, model)
}
}

View File

@@ -9,8 +9,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
// TestEdits Tests the edits endpoint of the API using the mocked server. // TestEdits Tests the edits endpoint of the API using the mocked server.

View File

@@ -11,8 +11,8 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
func TestEmbedding(t *testing.T) { func TestEmbedding(t *testing.T) {

View File

@@ -7,8 +7,8 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
// TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server. // TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server.

View File

@@ -54,7 +54,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
err = json.Unmarshal(rawMap["message"], &e.Message) err = json.Unmarshal(rawMap["message"], &e.Message)
if err != nil { if err != nil {
// If the parameter field of a function call is invalid as a JSON schema // If the parameter field of a function call is invalid as a JSON schema
// refs: https://github.com/sashabaranov/go-openai/issues/381 // refs: https://git.vaala.cloud/VaalaCat/go-openai/issues/381
var messages []string var messages []string
err = json.Unmarshal(rawMap["message"], &messages) err = json.Unmarshal(rawMap["message"], &messages)
if err != nil { if err != nil {
@@ -64,7 +64,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
} }
// optional fields for azure openai // optional fields for azure openai
// refs: https://github.com/sashabaranov/go-openai/issues/343 // refs: https://git.vaala.cloud/VaalaCat/go-openai/issues/343
if _, ok := rawMap["type"]; ok { if _, ok := rawMap["type"]; ok {
err = json.Unmarshal(rawMap["type"], &e.Type) err = json.Unmarshal(rawMap["type"], &e.Type)
if err != nil { if err != nil {

View File

@@ -6,7 +6,7 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
) )
func TestAPIErrorUnmarshalJSON(t *testing.T) { func TestAPIErrorUnmarshalJSON(t *testing.T) {

View File

@@ -11,7 +11,7 @@ import (
"net/url" "net/url"
"os" "os"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
) )
func Example() { func Example() {

View File

@@ -6,7 +6,7 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
) )
func main() { func main() {

View File

@@ -5,8 +5,8 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/jsonschema" "git.vaala.cloud/VaalaCat/go-openai/jsonschema"
) )
func main() { func main() {

View File

@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
) )
func main() { func main() {

View File

@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
) )
func main() { func main() {

View File

@@ -6,7 +6,7 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
) )
func main() { func main() {

View File

@@ -12,8 +12,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
func TestFileBytesUpload(t *testing.T) { func TestFileBytesUpload(t *testing.T) {

View File

@@ -7,8 +7,8 @@ import (
"os" "os"
"testing" "testing"
utils "github.com/sashabaranov/go-openai/internal" utils "git.vaala.cloud/VaalaCat/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
func TestFileBytesUploadWithFailingFormBuilder(t *testing.T) { func TestFileBytesUploadWithFailingFormBuilder(t *testing.T) {

View File

@@ -7,8 +7,8 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
const testFineTuneID = "fine-tune-id" const testFineTuneID = "fine-tune-id"

View File

@@ -7,8 +7,8 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
const testFineTuninigJobID = "fine-tuning-job-id" const testFineTuninigJobID = "fine-tuning-job-id"

go.mod
View File

@@ -1,3 +1,3 @@
-module github.com/sashabaranov/go-openai
+module git.vaala.cloud/VaalaCat/go-openai

 go 1.18

View File

@@ -3,8 +3,8 @@ package openai
 import (
 	"bytes"
 	"context"
+	"io"
 	"net/http"
-	"os"
 	"strconv"
 )
@@ -13,31 +13,62 @@ const (
 	CreateImageSize256x256   = "256x256"
 	CreateImageSize512x512   = "512x512"
 	CreateImageSize1024x1024 = "1024x1024"

 	// dall-e-3 supported only.
 	CreateImageSize1792x1024 = "1792x1024"
 	CreateImageSize1024x1792 = "1024x1792"
+
+	// gpt-image-1 supported only.
+	CreateImageSize1536x1024 = "1536x1024" // Landscape
+	CreateImageSize1024x1536 = "1024x1536" // Portrait
 )

 const (
-	CreateImageResponseFormatURL     = "url"
+	// dall-e-2 and dall-e-3 only.
 	CreateImageResponseFormatB64JSON = "b64_json"
+	CreateImageResponseFormatURL     = "url"
 )

 const (
 	CreateImageModelDallE2    = "dall-e-2"
 	CreateImageModelDallE3    = "dall-e-3"
+	CreateImageModelGptImage1 = "gpt-image-1"
 )

 const (
 	CreateImageQualityHD       = "hd"
 	CreateImageQualityStandard = "standard"
+
+	// gpt-image-1 only.
+	CreateImageQualityHigh   = "high"
+	CreateImageQualityMedium = "medium"
+	CreateImageQualityLow    = "low"
 )

 const (
+	// dall-e-3 only.
 	CreateImageStyleVivid   = "vivid"
 	CreateImageStyleNatural = "natural"
 )
+
+const (
+	// gpt-image-1 only.
+	CreateImageBackgroundTransparent = "transparent"
+	CreateImageBackgroundOpaque      = "opaque"
+)
+
+const (
+	// gpt-image-1 only.
+	CreateImageModerationLow = "low"
+)
+
+const (
+	// gpt-image-1 only.
+	CreateImageOutputFormatPNG  = "png"
+	CreateImageOutputFormatJPEG = "jpeg"
+	CreateImageOutputFormatWEBP = "webp"
+)

 // ImageRequest represents the request structure for the image API.
 type ImageRequest struct {
 	Prompt string `json:"prompt,omitempty"`
@@ -48,16 +79,35 @@ type ImageRequest struct {
 	Style          string `json:"style,omitempty"`
 	ResponseFormat string `json:"response_format,omitempty"`
 	User           string `json:"user,omitempty"`
+	Background        string `json:"background,omitempty"`
+	Moderation        string `json:"moderation,omitempty"`
+	OutputCompression int    `json:"output_compression,omitempty"`
+	OutputFormat      string `json:"output_format,omitempty"`
 }

 // ImageResponse represents a response structure for image API.
 type ImageResponse struct {
 	Created int64                    `json:"created,omitempty"`
 	Data    []ImageResponseDataInner `json:"data,omitempty"`
+	Usage   ImageResponseUsage       `json:"usage,omitempty"`

 	httpHeader
 }

+// ImageResponseInputTokensDetails represents the token breakdown for input tokens.
+type ImageResponseInputTokensDetails struct {
+	TextTokens  int `json:"text_tokens,omitempty"`
+	ImageTokens int `json:"image_tokens,omitempty"`
+}
+
+// ImageResponseUsage represents the token usage information for image API.
+type ImageResponseUsage struct {
+	TotalTokens        int                             `json:"total_tokens,omitempty"`
+	InputTokens        int                             `json:"input_tokens,omitempty"`
+	OutputTokens       int                             `json:"output_tokens,omitempty"`
+	InputTokensDetails ImageResponseInputTokensDetails `json:"input_tokens_details,omitempty"`
+}
+
 // ImageResponseDataInner represents a response data structure for image API.
 type ImageResponseDataInner struct {
 	URL string `json:"url,omitempty"`
@@ -84,13 +134,15 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons
 // ImageEditRequest represents the request structure for the image API.
 type ImageEditRequest struct {
-	Image *os.File `json:"image,omitempty"`
-	Mask  *os.File `json:"mask,omitempty"`
+	Image io.Reader `json:"image,omitempty"`
+	Mask  io.Reader `json:"mask,omitempty"`
 	Prompt         string `json:"prompt,omitempty"`
 	Model          string `json:"model,omitempty"`
 	N              int    `json:"n,omitempty"`
 	Size           string `json:"size,omitempty"`
 	ResponseFormat string `json:"response_format,omitempty"`
+	Quality        string `json:"quality,omitempty"`
+	User           string `json:"user,omitempty"`
 }

 // CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API.
@@ -98,15 +150,16 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
 	body := &bytes.Buffer{}
 	builder := c.createFormBuilder(body)

-	// image
-	err = builder.CreateFormFile("image", request.Image)
+	// image, filename is not required
+	err = builder.CreateFormFileReader("image", request.Image, "")
 	if err != nil {
 		return
 	}

 	// mask, it is optional
 	if request.Mask != nil {
-		err = builder.CreateFormFile("mask", request.Mask)
+		// mask, filename is not required
+		err = builder.CreateFormFileReader("mask", request.Mask, "")
 		if err != nil {
 			return
 		}
@@ -154,11 +207,12 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
 // ImageVariRequest represents the request structure for the image API.
 type ImageVariRequest struct {
-	Image *os.File `json:"image,omitempty"`
+	Image io.Reader `json:"image,omitempty"`
 	Model          string `json:"model,omitempty"`
 	N              int    `json:"n,omitempty"`
 	Size           string `json:"size,omitempty"`
 	ResponseFormat string `json:"response_format,omitempty"`
+	User           string `json:"user,omitempty"`
 }

 // CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API.
@@ -167,8 +221,8 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
 	body := &bytes.Buffer{}
 	builder := c.createFormBuilder(body)

-	// image
-	err = builder.CreateFormFile("image", request.Image)
+	// image, filename is not required
+	err = builder.CreateFormFileReader("image", request.Image, "")
 	if err != nil {
 		return
 	}
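Taken together, these changes enable gpt-image-1 generation and reader-based edits. A minimal sketch (the prompt, output path and key are placeholders; gpt-image-1 returns base64 data, so no response_format is set):

package main

import (
	"context"
	"encoding/base64"
	"log"
	"os"

	openai "git.vaala.cloud/VaalaCat/go-openai"
)

func main() {
	client := openai.NewClient("your-api-key")

	resp, err := client.CreateImage(context.Background(), openai.ImageRequest{
		Model:        openai.CreateImageModelGptImage1,
		Prompt:       "a watercolor fox in a forest",
		Size:         openai.CreateImageSize1536x1024,
		Quality:      openai.CreateImageQualityMedium,
		Background:   openai.CreateImageBackgroundTransparent,
		OutputFormat: openai.CreateImageOutputFormatPNG,
	})
	if err != nil {
		log.Fatal(err)
	}

	// gpt-image-1 responses carry base64 payloads and token usage.
	img, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON)
	if err != nil {
		log.Fatal(err)
	}
	if err := os.WriteFile("fox.png", img, 0o644); err != nil {
		log.Fatal(err)
	}
	log.Printf("total image tokens: %d", resp.Usage.TotalTokens)

	// ImageEditRequest.Image is now an io.Reader, so any reader works, e.g.:
	// f, _ := os.Open("fox.png")
	// _, _ = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
	// 	Image: f, Prompt: "add falling leaves", Model: openai.CreateImageModelGptImage1,
	// })
}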

View File

@@ -11,8 +11,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
func TestImages(t *testing.T) { func TestImages(t *testing.T) {

View File

@@ -1,8 +1,8 @@
 package openai //nolint:testpackage // testing private field

 import (
-	utils "github.com/sashabaranov/go-openai/internal"
-	"github.com/sashabaranov/go-openai/internal/test/checks"
+	utils "git.vaala.cloud/VaalaCat/go-openai/internal"
+	"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"

 	"context"
 	"fmt"
@@ -54,13 +54,13 @@ func TestImageFormBuilderFailures(t *testing.T) {
 	}

 	mockFailedErr := fmt.Errorf("mock form builder fail")
-	mockBuilder.mockCreateFormFile = func(string, *os.File) error {
+	mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
 		return mockFailedErr
 	}

 	_, err := client.CreateEditImage(ctx, req)
 	checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")

-	mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error {
+	mockBuilder.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error {
 		if name == "mask" {
 			return mockFailedErr
 		}
@@ -119,13 +119,13 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
 	req := ImageVariRequest{}

 	mockFailedErr := fmt.Errorf("mock form builder fail")
-	mockBuilder.mockCreateFormFile = func(string, *os.File) error {
+	mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
 		return mockFailedErr
 	}

 	_, err := client.CreateVariImage(ctx, req)
 	checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")

-	mockBuilder.mockCreateFormFile = func(string, *os.File) error {
+	mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
 		return nil
 	}

View File

@@ -5,8 +5,8 @@ import (
"errors" "errors"
"testing" "testing"
utils "github.com/sashabaranov/go-openai/internal" utils "git.vaala.cloud/VaalaCat/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test" "git.vaala.cloud/VaalaCat/go-openai/internal/test"
) )
func TestErrorAccumulatorBytes(t *testing.T) { func TestErrorAccumulatorBytes(t *testing.T) {

View File

@@ -4,8 +4,10 @@ import (
 	"fmt"
 	"io"
 	"mime/multipart"
+	"net/textproto"
 	"os"
-	"path"
+	"path/filepath"
+	"strings"
 )

 type FormBuilder interface {
@@ -30,8 +32,37 @@ func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) er
 	return fb.createFormFile(fieldname, file, file.Name())
 }

+var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
+
+func escapeQuotes(s string) string {
+	return quoteEscaper.Replace(s)
+}
+
+// CreateFormFileReader creates a form field with a file reader.
+// The filename in parameters can be an empty string.
+// The filename in Content-Disposition is required, but it can be an empty string.
 func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
-	return fb.createFormFile(fieldname, r, path.Base(filename))
+	h := make(textproto.MIMEHeader)
+	h.Set(
+		"Content-Disposition",
+		fmt.Sprintf(
+			`form-data; name="%s"; filename="%s"`,
+			escapeQuotes(fieldname),
+			escapeQuotes(filepath.Base(filename)),
+		),
+	)
+
+	fieldWriter, err := fb.writer.CreatePart(h)
+	if err != nil {
+		return err
+	}
+
+	_, err = io.Copy(fieldWriter, r)
+	if err != nil {
+		return err
+	}
+
+	return nil
 }

 func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {
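Because the builder lives in an internal package, here is a standalone sketch of the same technique using only the standard library, mirroring what CreateFormFileReader now does; the field name and payload are illustrative, not part of the library's public API:

package main

import (
	"bytes"
	"fmt"
	"io"
	"mime/multipart"
	"net/textproto"
	"strings"
)

func main() {
	var buf bytes.Buffer
	w := multipart.NewWriter(&buf)

	// Build the part header by hand so the filename may legitimately be empty.
	h := make(textproto.MIMEHeader)
	h.Set("Content-Disposition", `form-data; name="image"; filename=""`)

	part, err := w.CreatePart(h)
	if err != nil {
		panic(err)
	}
	// Any io.Reader works as the payload source, not just *os.File.
	if _, err := io.Copy(part, strings.NewReader("fake image bytes")); err != nil {
		panic(err)
	}
	w.Close()

	fmt.Println(buf.String()) // shows the multipart boundary and the part headers
}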

View File

@@ -1,7 +1,7 @@
package openai //nolint:testpackage // testing private field package openai //nolint:testpackage // testing private field
import ( import (
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
"bytes" "bytes"
"errors" "errors"
@@ -43,3 +43,32 @@ func TestFormBuilderWithClosedFile(t *testing.T) {
checks.HasError(t, err, "formbuilder should return error if file is closed") checks.HasError(t, err, "formbuilder should return error if file is closed")
checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed") checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed")
} }
type failingReader struct {
}
var errMockFailingReaderError = errors.New("mock reader failed")
func (*failingReader) Read([]byte) (int, error) {
return 0, errMockFailingReaderError
}
func TestFormBuilderWithReader(t *testing.T) {
file, err := os.CreateTemp(t.TempDir(), "")
if err != nil {
t.Fatalf("Error creating tmp file: %v", err)
}
defer file.Close()
builder := NewFormBuilder(&failingWriter{})
err = builder.CreateFormFileReader("file", file, file.Name())
checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails")
builder = NewFormBuilder(&bytes.Buffer{})
reader := &failingReader{}
err = builder.CreateFormFileReader("file", reader, "")
checks.ErrorIs(t, err, errMockFailingReaderError, "formbuilder should return error if copy reader fails")
successReader := &bytes.Buffer{}
err = builder.CreateFormFileReader("file", successReader, "")
checks.NoError(t, err, "formbuilder should not return error")
}

View File

@@ -1,7 +1,7 @@
package test package test
import ( import (
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
"net/http" "net/http"
"os" "os"

View File

@@ -46,6 +46,8 @@ type Definition struct {
 	// additionalProperties: false
 	// additionalProperties: jsonschema.Definition{Type: jsonschema.String}
 	AdditionalProperties any `json:"additionalProperties,omitempty"`
+	// Whether the schema is nullable or not.
+	Nullable bool `json:"nullable,omitempty"`
 }

 func (d *Definition) MarshalJSON() ([]byte, error) {
@@ -124,9 +126,12 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
 		}
 		jsonTag := field.Tag.Get("json")
 		var required = true
-		if jsonTag == "" {
+		switch {
+		case jsonTag == "-":
+			continue
+		case jsonTag == "":
 			jsonTag = field.Name
-		} else if strings.HasSuffix(jsonTag, ",omitempty") {
+		case strings.HasSuffix(jsonTag, ",omitempty"):
 			jsonTag = strings.TrimSuffix(jsonTag, ",omitempty")
 			required = false
 		}
@@ -139,6 +144,16 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
 		if description != "" {
 			item.Description = description
 		}
+		enum := field.Tag.Get("enum")
+		if enum != "" {
+			item.Enum = strings.Split(enum, ",")
+		}
+
+		if n := field.Tag.Get("nullable"); n != "" {
+			nullable, _ := strconv.ParseBool(n)
+			item.Nullable = nullable
+		}
+
 		properties[jsonTag] = *item

 		if s := field.Tag.Get("required"); s != "" {

View File

@@ -5,7 +5,7 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/sashabaranov/go-openai/jsonschema" "git.vaala.cloud/VaalaCat/go-openai/jsonschema"
) )
func TestDefinition_MarshalJSON(t *testing.T) { func TestDefinition_MarshalJSON(t *testing.T) {
@@ -17,7 +17,7 @@ func TestDefinition_MarshalJSON(t *testing.T) {
{ {
name: "Test with empty Definition", name: "Test with empty Definition",
def: jsonschema.Definition{}, def: jsonschema.Definition{},
want: `{"properties":{}}`, want: `{}`,
}, },
{ {
name: "Test with Definition properties set", name: "Test with Definition properties set",
@@ -35,11 +35,10 @@ func TestDefinition_MarshalJSON(t *testing.T) {
"description":"A string type", "description":"A string type",
"properties":{ "properties":{
"name":{ "name":{
"type":"string", "type":"string"
"properties":{}
} }
} }
}`, }`,
}, },
{ {
name: "Test with nested Definition properties", name: "Test with nested Definition properties",
@@ -66,17 +65,15 @@ func TestDefinition_MarshalJSON(t *testing.T) {
"type":"object", "type":"object",
"properties":{ "properties":{
"name":{ "name":{
"type":"string", "type":"string"
"properties":{}
}, },
"age":{ "age":{
"type":"integer", "type":"integer"
"properties":{}
} }
} }
} }
} }
}`, }`,
}, },
{ {
name: "Test with complex nested Definition", name: "Test with complex nested Definition",
@@ -114,30 +111,26 @@ func TestDefinition_MarshalJSON(t *testing.T) {
"type":"object", "type":"object",
"properties":{ "properties":{
"name":{ "name":{
"type":"string", "type":"string"
"properties":{}
}, },
"age":{ "age":{
"type":"integer", "type":"integer"
"properties":{}
}, },
"address":{ "address":{
"type":"object", "type":"object",
"properties":{ "properties":{
"city":{ "city":{
"type":"string", "type":"string"
"properties":{}
}, },
"country":{ "country":{
"type":"string", "type":"string"
"properties":{}
} }
} }
} }
} }
} }
} }
}`, }`,
}, },
{ {
name: "Test with Array type Definition", name: "Test with Array type Definition",
@@ -155,18 +148,14 @@ func TestDefinition_MarshalJSON(t *testing.T) {
want: `{ want: `{
"type":"array", "type":"array",
"items":{ "items":{
"type":"string", "type":"string"
"properties":{
}
}, },
"properties":{ "properties":{
"name":{ "name":{
"type":"string", "type":"string"
"properties":{}
} }
} }
}`, }`,
}, },
} }
@@ -193,6 +182,232 @@ func TestDefinition_MarshalJSON(t *testing.T) {
} }
} }
func TestStructToSchema(t *testing.T) {
tests := []struct {
name string
in any
want string
}{
{
name: "Test with empty struct",
in: struct{}{},
want: `{
"type":"object",
"additionalProperties":false
}`,
},
{
name: "Test with struct containing many fields",
in: struct {
Name string `json:"name"`
Age int `json:"age"`
Active bool `json:"active"`
Height float64 `json:"height"`
Cities []struct {
Name string `json:"name"`
State string `json:"state"`
} `json:"cities"`
}{
Name: "John Doe",
Age: 30,
Cities: []struct {
Name string `json:"name"`
State string `json:"state"`
}{
{Name: "New York", State: "NY"},
{Name: "Los Angeles", State: "CA"},
},
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string"
},
"age":{
"type":"integer"
},
"active":{
"type":"boolean"
},
"height":{
"type":"number"
},
"cities":{
"type":"array",
"items":{
"additionalProperties":false,
"type":"object",
"properties":{
"name":{
"type":"string"
},
"state":{
"type":"string"
}
},
"required":["name","state"]
}
}
},
"required":["name","age","active","height","cities"],
"additionalProperties":false
}`,
},
{
name: "Test with description tag",
in: struct {
Name string `json:"name" description:"The name of the person"`
}{
Name: "John Doe",
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string",
"description":"The name of the person"
}
},
"required":["name"],
"additionalProperties":false
}`,
},
{
name: "Test with required tag",
in: struct {
Name string `json:"name" required:"false"`
}{
Name: "John Doe",
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string"
}
},
"additionalProperties":false
}`,
},
{
name: "Test with enum tag",
in: struct {
Color string `json:"color" enum:"red,green,blue"`
}{
Color: "red",
},
want: `{
"type":"object",
"properties":{
"color":{
"type":"string",
"enum":["red","green","blue"]
}
},
"required":["color"],
"additionalProperties":false
}`,
},
{
name: "Test with nullable tag",
in: struct {
Name *string `json:"name" nullable:"true"`
}{
Name: nil,
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string",
"nullable":true
}
},
"required":["name"],
"additionalProperties":false
}`,
},
{
name: "Test with exclude mark",
in: struct {
Name string `json:"-"`
}{
Name: "Name",
},
want: `{
"type":"object",
"additionalProperties":false
}`,
},
{
name: "Test with no json tag",
in: struct {
Name string
}{
Name: "",
},
want: `{
"type":"object",
"properties":{
"Name":{
"type":"string"
}
},
"required":["Name"],
"additionalProperties":false
}`,
},
{
name: "Test with omitempty tag",
in: struct {
Name string `json:"name,omitempty"`
}{
Name: "",
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string"
}
},
"additionalProperties":false
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wantBytes := []byte(tt.want)
schema, err := jsonschema.GenerateSchemaForType(tt.in)
if err != nil {
t.Errorf("Failed to generate schema: error = %v", err)
return
}
var want map[string]interface{}
err = json.Unmarshal(wantBytes, &want)
if err != nil {
t.Errorf("Failed to Unmarshal JSON: error = %v", err)
return
}
got := structToMap(t, schema)
gotPtr := structToMap(t, &schema)
if !reflect.DeepEqual(got, want) {
t.Errorf("MarshalJSON() got = %v, want %v", got, want)
}
if !reflect.DeepEqual(gotPtr, want) {
t.Errorf("MarshalJSON() gotPtr = %v, want %v", gotPtr, want)
}
})
}
}
func structToMap(t *testing.T, v any) map[string]any { func structToMap(t *testing.T, v any) map[string]any {
t.Helper() t.Helper()
gotBytes, err := json.Marshal(v) gotBytes, err := json.Marshal(v)

View File

@@ -3,7 +3,7 @@ package jsonschema_test
import ( import (
"testing" "testing"
"github.com/sashabaranov/go-openai/jsonschema" "git.vaala.cloud/VaalaCat/go-openai/jsonschema"
) )
func Test_Validate(t *testing.T) { func Test_Validate(t *testing.T) {

View File

@@ -7,9 +7,9 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test" "git.vaala.cloud/VaalaCat/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
var emptyStr = "" var emptyStr = ""

View File

@@ -9,8 +9,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
const testFineTuneModelID = "fine-tune-model-id" const testFineTuneModelID = "fine-tune-model-id"
@@ -47,6 +47,24 @@ func TestGetModel(t *testing.T) {
checks.NoError(t, err, "GetModel error") checks.NoError(t, err, "GetModel error")
} }
// TestGetModelO3 Tests the retrieve O3 model endpoint of the API using the mocked server.
func TestGetModelO3(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/models/o3", handleGetModelEndpoint)
_, err := client.GetModel(context.Background(), "o3")
checks.NoError(t, err, "GetModel error for O3")
}
// TestGetModelO4Mini Tests the retrieve O4Mini model endpoint of the API using the mocked server.
func TestGetModelO4Mini(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/models/o4-mini", handleGetModelEndpoint)
_, err := client.GetModel(context.Background(), "o4-mini")
checks.NoError(t, err, "GetModel error for O4Mini")
}
func TestAzureGetModel(t *testing.T) { func TestAzureGetModel(t *testing.T) {
client, server, teardown := setupAzureTestServer() client, server, teardown := setupAzureTestServer()
defer teardown() defer teardown()

View File

@@ -11,8 +11,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
// TestModeration Tests the moderations endpoint of the API using the mocked server. // TestModeration Tests the moderations endpoint of the API using the mocked server.

View File

@@ -1,8 +1,8 @@
package openai_test package openai_test
import ( import (
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test" "git.vaala.cloud/VaalaCat/go-openai/internal/test"
) )
func setupOpenAITestServer() (client *openai.Client, server *test.ServerTest, teardown func()) { func setupOpenAITestServer() (client *openai.Client, server *test.ServerTest, teardown func()) {
@@ -29,7 +29,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea
// numTokens Returns the number of GPT-3 encoded tokens in the given text. // numTokens Returns the number of GPT-3 encoded tokens in the given text.
// This function approximates based on the rule of thumb stated by OpenAI: // This function approximates based on the rule of thumb stated by OpenAI:
// https://beta.openai.com/tokenizer/ // https://beta.openai.com/tokenizer.
// //
// TODO: implement an actual tokenizer for GPT-3 and Codex (once available). // TODO: implement an actual tokenizer for GPT-3 and Codex (once available).
func numTokens(s string) int { func numTokens(s string) int {

View File

@@ -28,15 +28,6 @@ var (
 	ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
 )

-var unsupportedToolsForO1Models = map[ToolType]struct{}{
-	ToolTypeFunction: {},
-}
-
-var availableMessageRoleForO1Models = map[string]struct{}{
-	ChatMessageRoleUser:      {},
-	ChatMessageRoleAssistant: {},
-}
-
 // ReasoningValidator handles validation for o-series model requests.
 type ReasoningValidator struct{}

@@ -49,8 +40,9 @@ func NewReasoningValidator() *ReasoningValidator {
 func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
 	o1Series := strings.HasPrefix(request.Model, "o1")
 	o3Series := strings.HasPrefix(request.Model, "o3")
+	o4Series := strings.HasPrefix(request.Model, "o4")

-	if !o1Series && !o3Series {
+	if !o1Series && !o3Series && !o4Series {
 		return nil
 	}
@@ -58,12 +50,6 @@ func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
 		return err
 	}

-	if o1Series {
-		if err := v.validateO1Specific(request); err != nil {
-			return err
-		}
-	}
-
 	return nil
 }
@@ -93,19 +79,3 @@ func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletion

 	return nil
 }
-
-// validateO1Specific checks O1-specific limitations.
-func (v *ReasoningValidator) validateO1Specific(request ChatCompletionRequest) error {
-	for _, m := range request.Messages {
-		if _, found := availableMessageRoleForO1Models[m.Role]; !found {
-			return ErrO1BetaLimitationsMessageTypes
-		}
-	}
-
-	for _, t := range request.Tools {
-		if _, found := unsupportedToolsForO1Models[t.Type]; found {
-			return ErrO1BetaLimitationsTools
-		}
-	}
-	return nil
-}
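A minimal sketch of the relaxed validator: o4-prefixed models now go through the same checks as o1/o3, and the o1-only role and tool restrictions are gone. The exact error returned for max_tokens is an assumption about the shared parameter checks, so the example only distinguishes nil from non-nil:

package main

import (
	"fmt"

	openai "git.vaala.cloud/VaalaCat/go-openai"
)

func main() {
	v := openai.NewReasoningValidator()

	// o4-mini is now recognised as an o-series model...
	err := v.Validate(openai.ChatCompletionRequest{
		Model:     openai.O4Mini,
		MaxTokens: 100, // deprecated for reasoning models, expected to be rejected
	})
	fmt.Println("o4-mini with MaxTokens:", err)

	// ...while a request using MaxCompletionTokens passes the shared checks.
	err = v.Validate(openai.ChatCompletionRequest{
		Model:               openai.O4Mini,
		MaxCompletionTokens: 100,
	})
	fmt.Println("o4-mini with MaxCompletionTokens:", err)
}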

View File

@@ -3,8 +3,8 @@ package openai_test
import ( import (
"context" "context"
openai "github.com/sashabaranov/go-openai" openai "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
"encoding/json" "encoding/json"
"fmt" "fmt"

View File

@@ -11,17 +11,22 @@ const (
 	TTSModel1         SpeechModel = "tts-1"
 	TTSModel1HD       SpeechModel = "tts-1-hd"
 	TTSModelCanary    SpeechModel = "canary-tts"
+	TTSModelGPT4oMini SpeechModel = "gpt-4o-mini-tts"
 )

 type SpeechVoice string

 const (
 	VoiceAlloy   SpeechVoice = "alloy"
+	VoiceAsh     SpeechVoice = "ash"
+	VoiceBallad  SpeechVoice = "ballad"
+	VoiceCoral   SpeechVoice = "coral"
 	VoiceEcho    SpeechVoice = "echo"
 	VoiceFable   SpeechVoice = "fable"
 	VoiceOnyx    SpeechVoice = "onyx"
 	VoiceNova    SpeechVoice = "nova"
 	VoiceShimmer SpeechVoice = "shimmer"
+	VoiceVerse   SpeechVoice = "verse"
 )

 type SpeechResponseFormat string
@@ -39,6 +44,7 @@ type CreateSpeechRequest struct {
 	Model          SpeechModel          `json:"model"`
 	Input          string               `json:"input"`
 	Voice          SpeechVoice          `json:"voice"`
+	Instructions   string               `json:"instructions,omitempty"`    // Optional, doesn't work with tts-1 or tts-1-hd.
 	ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3
 	Speed          float64              `json:"speed,omitempty"`           // Optional, default to 1.0
 }
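A short sketch using the new gpt-4o-mini-tts model, the coral voice and the instructions field; the input text, output path and key are placeholders:

package main

import (
	"context"
	"io"
	"log"
	"os"

	openai "git.vaala.cloud/VaalaCat/go-openai"
)

func main() {
	client := openai.NewClient("your-api-key")

	resp, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{
		Model:          openai.TTSModelGPT4oMini,
		Input:          "Your order is on its way.",
		Voice:          openai.VoiceCoral,
		Instructions:   "Speak in a calm, friendly tone.", // ignored by tts-1 / tts-1-hd
		ResponseFormat: openai.SpeechResponseFormatMp3,
	})
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Close()

	out, err := os.Create("speech.mp3")
	if err != nil {
		log.Fatal(err)
	}
	defer out.Close()

	if _, err := io.Copy(out, resp); err != nil {
		log.Fatal(err)
	}
}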

View File

@@ -11,9 +11,9 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test" "git.vaala.cloud/VaalaCat/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
func TestSpeechIntegration(t *testing.T) { func TestSpeechIntegration(t *testing.T) {

View File

@@ -6,13 +6,14 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"regexp"

-	utils "github.com/sashabaranov/go-openai/internal"
+	utils "git.vaala.cloud/VaalaCat/go-openai/internal"
 )

 var (
-	headerData  = []byte("data: ")
-	errorPrefix = []byte(`data: {"error":`)
+	headerData  = regexp.MustCompile(`^data:\s*`)
+	errorPrefix = regexp.MustCompile(`^data:\s*{"error":`)
 )

 type streamable interface {
@@ -70,12 +71,12 @@ func (stream *streamReader[T]) processLines() ([]byte, error) {
 		}

 		noSpaceLine := bytes.TrimSpace(rawLine)
-		if bytes.HasPrefix(noSpaceLine, errorPrefix) {
+		if errorPrefix.Match(noSpaceLine) {
 			hasErrorPrefix = true
 		}
-		if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix {
+		if !headerData.Match(noSpaceLine) || hasErrorPrefix {
 			if hasErrorPrefix {
-				noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData)
+				noSpaceLine = headerData.ReplaceAll(noSpaceLine, nil)
 			}
 			writeErr := stream.errAccumulator.Write(noSpaceLine)
 			if writeErr != nil {
@@ -89,7 +90,7 @@ func (stream *streamReader[T]) processLines() ([]byte, error) {
 			continue
 		}

-		noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
+		noPrefixLine := headerData.ReplaceAll(noSpaceLine, nil)
 		if string(noPrefixLine) == "[DONE]" {
 			stream.isFinished = true
 			return nil, io.EOF
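A standalone sketch of why the switch from a fixed []byte prefix to a regexp matters: SSE providers that emit "data:" with no trailing space, or with several, now match as well (the sample lines are illustrative):

package main

import (
	"fmt"
	"regexp"
)

func main() {
	headerData := regexp.MustCompile(`^data:\s*`)

	for _, line := range []string{
		`data: {"choices":[]}`, // OpenAI-style, with a space
		`data:{"choices":[]}`,  // some proxies/providers omit the space
		`data:   [DONE]`,       // extra whitespace is tolerated as well
	} {
		fmt.Printf("match=%v payload=%q\n",
			headerData.MatchString(line),
			headerData.ReplaceAllString(line, ""))
	}
}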

View File

@@ -6,9 +6,9 @@ import (
"errors" "errors"
"testing" "testing"
utils "github.com/sashabaranov/go-openai/internal" utils "git.vaala.cloud/VaalaCat/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test" "git.vaala.cloud/VaalaCat/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
var errTestUnmarshalerFailed = errors.New("test unmarshaler failed") var errTestUnmarshalerFailed = errors.New("test unmarshaler failed")

View File

@@ -10,8 +10,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/sashabaranov/go-openai" "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
func TestCompletionsStreamWrongModel(t *testing.T) { func TestCompletionsStreamWrongModel(t *testing.T) {

View File

@@ -7,8 +7,8 @@ import (
"net/http" "net/http"
"testing" "testing"
openai "github.com/sashabaranov/go-openai" openai "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
) )
// TestThread Tests the thread endpoint of the API using the mocked server. // TestThread Tests the thread endpoint of the API using the mocked server.

View File

@@ -3,8 +3,8 @@ package openai_test
import ( import (
"context" "context"
openai "github.com/sashabaranov/go-openai" openai "git.vaala.cloud/VaalaCat/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
"encoding/json" "encoding/json"
"fmt" "fmt"