Compare commits


10 Commits

Author SHA1 Message Date
Mazyar Yousefiniyae shad
a62919e8c6 ref: add image url support to messages (#933)
* ref: add image url support to messages

* fix linter error

* fix linter error
2025-02-09 18:36:44 +00:00
rory malcolm
2054db016c Add support for O3-mini (#930)
* Add support for O3-mini

- Add support for the o3 mini set of models, including tests that match the constraints in OpenAI's API docs (https://platform.openai.com/docs/models#o3-mini).

* Deprecate and refactor

- Deprecate `ErrO1BetaLimitationsLogprobs` and `ErrO1BetaLimitationsOther`

- Implement `validationRequestForReasoningModels`, which works on both o1 & o3, and has per-model-type restrictions on functionality (eg, o3 class are allowed function calls and system messages, o1 isn't)

* Move reasoning validation to `reasoning_validator.go`

- Add a `NewReasoningValidator` which exposes a `Validate()` method for a given request

- Also adds a test for chat streams

* Final nits
2025-02-06 14:53:19 +00:00
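A minimal sketch of what this commit enables from the client side, mirroring the request shape used in the new tests further down this diff; the API-key handling and prompt text are illustrative, not part of the change:

package main

import (
    "context"
    "fmt"
    "os"

    openai "github.com/sashabaranov/go-openai"
)

func main() {
    client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
    resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
        Model:               openai.O3Mini,
        MaxCompletionTokens: 1000, // reasoning models reject MaxTokens
        Messages: []openai.ChatCompletionMessage{
            {Role: openai.ChatMessageRoleUser, Content: "Hello!"},
        },
    })
    if err != nil {
        fmt.Println(err)
        return
    }
    fmt.Println(resp.Choices[0].Message.Content)
}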
saileshd1402
45aa99607b Make "Content" field in "ChatCompletionMessage" omitempty (#926) 2025-01-31 19:05:29 +00:00
Trevor Creech
9823a8bbbd Chat Completion API: add ReasoningEffort and new o1 models (#928)
* add reasoning_effort param

* add o1 model

* fix lint
2025-01-31 18:57:57 +00:00
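A short sketch of the new parameter, assuming a client built as in the previous example; the accepted values ("low", "medium", "high") come from the field comment added to chat.go in this changeset:

req := openai.ChatCompletionRequest{
    Model:               openai.O1, // constant added in this changeset
    MaxCompletionTokens: 500,
    ReasoningEffort:     "high",
    Messages: []openai.ChatCompletionMessage{
        {Role: openai.ChatMessageRoleUser, Content: "Work through this step by step."},
    },
}
resp, err := client.CreateChatCompletion(context.Background(), req)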
Oleksandr Redko
7a2915a37d Simplify tests with T.TempDir (#929) 2025-01-31 18:55:41 +00:00
Sabuhi Gurbani
2a0ff5ac63 Added additional_messages (#914) 2024-12-27 10:01:16 +00:00
Alex Baranov
56a9acf86f Ignore test.mp3 (#913) 2024-12-08 13:16:48 +00:00
Tim Misiak
af5355f5b1 Fix ID field to be optional (#911)
The ID field is not always present for streaming responses. Without omitempty, the entire ToolCall struct will be missing.
2024-12-08 13:12:05 +00:00
Qiying Wang
c203ca001f feat: add RecvRaw (#896) 2024-11-30 10:29:05 +00:00
Liu Shuang
21fa42c18d feat: add gpt-4o-2024-11-20 model (#905) 2024-11-30 09:39:47 +00:00
19 changed files with 519 additions and 192 deletions

.gitignore

@@ -17,3 +17,6 @@
 # Auth token for tests
 .openai-token
 .idea
+
+# Generated by tests
+test.mp3


@@ -206,6 +206,7 @@ linters:
     - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
     - unconvert # Remove unnecessary type conversions
     - unparam # Reports unused function parameters
+    - usetesting # Reports uses of functions with replacement inside the testing package
     - wastedassign # wastedassign finds wasted assignment statements.
     - whitespace # Tool for detection of leading and trailing whitespace
     ## you may want to enable


@@ -40,12 +40,9 @@ func TestAudio(t *testing.T) {
 	ctx := context.Background()
-	dir, cleanup := test.CreateTestDirectory(t)
-	defer cleanup()
 	for _, tc := range testcases {
 		t.Run(tc.name, func(t *testing.T) {
-			path := filepath.Join(dir, "fake.mp3")
+			path := filepath.Join(t.TempDir(), "fake.mp3")
 			test.CreateTestFile(t, path)
 			req := openai.AudioRequest{
@@ -90,12 +87,9 @@ func TestAudioWithOptionalArgs(t *testing.T) {
 	ctx := context.Background()
-	dir, cleanup := test.CreateTestDirectory(t)
-	defer cleanup()
 	for _, tc := range testcases {
 		t.Run(tc.name, func(t *testing.T) {
-			path := filepath.Join(dir, "fake.mp3")
+			path := filepath.Join(t.TempDir(), "fake.mp3")
 			test.CreateTestFile(t, path)
 			req := openai.AudioRequest{


@@ -13,9 +13,7 @@ import (
 )
 func TestAudioWithFailingFormBuilder(t *testing.T) {
-	dir, cleanup := test.CreateTestDirectory(t)
-	defer cleanup()
-	path := filepath.Join(dir, "fake.mp3")
+	path := filepath.Join(t.TempDir(), "fake.mp3")
 	test.CreateTestFile(t, path)
 	req := AudioRequest{
@@ -63,9 +61,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
 func TestCreateFileField(t *testing.T) {
 	t.Run("createFileField failing file", func(t *testing.T) {
-		dir, cleanup := test.CreateTestDirectory(t)
-		defer cleanup()
-		path := filepath.Join(dir, "fake.mp3")
+		path := filepath.Join(t.TempDir(), "fake.mp3")
 		test.CreateTestFile(t, path)
 		req := AudioRequest{

chat.go

@@ -93,7 +93,7 @@ type ChatMessagePart struct {
 type ChatCompletionMessage struct {
 	Role         string `json:"role"`
-	Content      string `json:"content"`
+	Content      string `json:"content,omitempty"`
 	Refusal      string `json:"refusal,omitempty"`
 	MultiContent []ChatMessagePart
@@ -132,7 +132,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
 	msg := struct {
 		Role         string `json:"role"`
-		Content      string `json:"content"`
+		Content      string `json:"content,omitempty"`
 		Refusal      string `json:"refusal,omitempty"`
 		MultiContent []ChatMessagePart `json:"-"`
 		Name         string `json:"name,omitempty"`
@@ -146,7 +146,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
 func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
 	msg := struct {
 		Role         string `json:"role"`
-		Content      string `json:"content"`
+		Content      string `json:"content,omitempty"`
 		Refusal      string `json:"refusal,omitempty"`
 		MultiContent []ChatMessagePart
 		Name         string `json:"name,omitempty"`
@@ -179,7 +179,7 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
 type ToolCall struct {
 	// Index is not nil only in chat completion chunk object
 	Index    *int         `json:"index,omitempty"`
-	ID       string       `json:"id"`
+	ID       string       `json:"id,omitempty"`
 	Type     ToolType     `json:"type"`
 	Function FunctionCall `json:"function"`
 }
@@ -258,6 +258,8 @@ type ChatCompletionRequest struct {
 	// Store can be set to true to store the output of this completion request for use in distillations and evals.
 	// https://platform.openai.com/docs/api-reference/chat/create#chat-create-store
 	Store bool `json:"store,omitempty"`
+	// Controls effort on reasoning for reasoning models. It can be set to "low", "medium", or "high".
+	ReasoningEffort string `json:"reasoning_effort,omitempty"`
 	// Metadata to store with the completion.
 	Metadata map[string]string `json:"metadata,omitempty"`
 }
@@ -390,7 +392,8 @@ func (c *Client) CreateChatCompletion(
 		return
 	}
-	if err = validateRequestForO1Models(request); err != nil {
+	reasoningValidator := NewReasoningValidator()
+	if err = reasoningValidator.Validate(request); err != nil {
 		return
 	}
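The practical effect of the omitempty changes above, in a small sketch that matches the updated serialization test later in this diff; json.Marshal goes through the custom MarshalJSON shown in the hunk:

msg := openai.ChatCompletionMessage{Role: openai.ChatMessageRoleUser}
b, err := json.Marshal(msg)
if err != nil {
    panic(err)
}
fmt.Println(string(b)) // {"role":"user"} rather than {"role":"user","content":""}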


@@ -80,7 +80,8 @@ func (c *Client) CreateChatCompletionStream(
 	}
 	request.Stream = true
-	if err = validateRequestForO1Models(request); err != nil {
+	reasoningValidator := NewReasoningValidator()
+	if err = reasoningValidator.Validate(request); err != nil {
 		return
 	}


@@ -792,6 +792,173 @@ func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
return true
}
func TestCreateChatCompletionStreamWithReasoningModel(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
dataBytes := []byte{}
//nolint:lll
dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}`)...)
dataBytes = append(dataBytes, []byte("\n\n")...)
//nolint:lll
dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}`)...)
dataBytes = append(dataBytes, []byte("\n\n")...)
//nolint:lll
dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" from"},"finish_reason":null}]}`)...)
dataBytes = append(dataBytes, []byte("\n\n")...)
//nolint:lll
dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" O3Mini"},"finish_reason":null}]}`)...)
dataBytes = append(dataBytes, []byte("\n\n")...)
//nolint:lll
dataBytes = append(dataBytes, []byte(`data: {"id":"5","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)...)
dataBytes = append(dataBytes, []byte("\n\n")...)
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
})
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxCompletionTokens: 2000,
Model: openai.O3Mini20250131,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
expectedResponses := []openai.ChatCompletionStreamResponse{
{
ID: "1",
Object: "chat.completion.chunk",
Created: 1729585728,
Model: openai.O3Mini20250131,
SystemFingerprint: "fp_mini",
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{
Role: "assistant",
},
},
},
},
{
ID: "2",
Object: "chat.completion.chunk",
Created: 1729585728,
Model: openai.O3Mini20250131,
SystemFingerprint: "fp_mini",
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{
Content: "Hello",
},
},
},
},
{
ID: "3",
Object: "chat.completion.chunk",
Created: 1729585728,
Model: openai.O3Mini20250131,
SystemFingerprint: "fp_mini",
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{
Content: " from",
},
},
},
},
{
ID: "4",
Object: "chat.completion.chunk",
Created: 1729585728,
Model: openai.O3Mini20250131,
SystemFingerprint: "fp_mini",
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{
Content: " O3Mini",
},
},
},
},
{
ID: "5",
Object: "chat.completion.chunk",
Created: 1729585728,
Model: openai.O3Mini20250131,
SystemFingerprint: "fp_mini",
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{},
FinishReason: "stop",
},
},
},
}
for ix, expectedResponse := range expectedResponses {
b, _ := json.Marshal(expectedResponse)
t.Logf("%d: %s", ix, string(b))
receivedResponse, streamErr := stream.Recv()
checks.NoError(t, streamErr, "stream.Recv() failed")
if !compareChatResponses(expectedResponse, receivedResponse) {
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
}
}
_, streamErr := stream.Recv()
if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
}
}
func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) {
client, _, _ := setupOpenAITestServer()
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 100, // This will trigger the validator to fail
Model: openai.O3Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
if stream != nil {
t.Error("Expected nil stream when validation fails")
stream.Close()
}
if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated, got: %v", err)
}
}
func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool {
if c1.Index != c2.Index {
return false


@@ -64,7 +64,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
 				MaxTokens: 5,
 				Model:     openai.O1Preview,
 			},
-			expectedError: openai.ErrO1MaxTokensDeprecated,
+			expectedError: openai.ErrReasoningModelMaxTokensDeprecated,
 		},
 		{
 			name: "o1-mini_MaxTokens_deprecated",
@@ -72,7 +72,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
 				MaxTokens: 5,
 				Model:     openai.O1Mini,
 			},
-			expectedError: openai.ErrO1MaxTokensDeprecated,
+			expectedError: openai.ErrReasoningModelMaxTokensDeprecated,
 		},
 	}
@@ -104,7 +104,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				LogProbs: true,
 				Model:    openai.O1Preview,
 			},
-			expectedError: openai.ErrO1BetaLimitationsLogprobs,
+			expectedError: openai.ErrReasoningModelLimitationsLogprobs,
 		},
 		{
 			name: "message_type_unsupported",
@@ -155,7 +155,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				},
 				Temperature: float32(2),
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 		{
 			name: "set_top_unsupported",
@@ -173,7 +173,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				Temperature: float32(1),
 				TopP:        float32(0.1),
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 		{
 			name: "set_n_unsupported",
@@ -192,7 +192,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				TopP: float32(1),
 				N:    2,
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 		{
 			name: "set_presence_penalty_unsupported",
@@ -209,7 +209,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				},
 				PresencePenalty: float32(1),
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 		{
 			name: "set_frequency_penalty_unsupported",
@@ -226,7 +226,127 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				},
 				FrequencyPenalty: float32(0.1),
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
ctx := context.Background()
_, err := client.CreateChatCompletion(ctx, tt.in)
checks.HasError(t, err)
msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err)
checks.ErrorIs(t, err, tt.expectedError, msg)
})
}
}
func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) {
tests := []struct {
name string
in openai.ChatCompletionRequest
expectedError error
}{
{
name: "log_probs_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
LogProbs: true,
Model: openai.O3Mini,
},
expectedError: openai.ErrReasoningModelLimitationsLogprobs,
},
{
name: "set_temperature_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O3Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(2),
},
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_top_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O3Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(1),
TopP: float32(0.1),
},
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_n_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O3Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(1),
TopP: float32(1),
N: 2,
},
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_presence_penalty_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O3Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
PresencePenalty: float32(1),
},
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_frequency_penalty_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O3Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
FrequencyPenalty: float32(0.1),
},
expectedError: openai.ErrReasoningModelLimitationsOther,
},
}
@@ -308,6 +428,23 @@ func TestO1ModelChatCompletions(t *testing.T) {
checks.NoError(t, err, "CreateChatCompletion error")
}
func TestO3ModelChatCompletions(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
Model: openai.O3Mini,
MaxCompletionTokens: 1000,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
})
checks.NoError(t, err, "CreateChatCompletion error")
}
// TestCompletions Tests the completions endpoint of the API using the mocked server.
func TestChatCompletionsWithHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
@@ -631,7 +768,7 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
 		t.Fatalf("Unexpected error")
 	}
 	res = strings.ReplaceAll(string(s), " ", "")
-	if res != `{"role":"user","content":""}` {
+	if res != `{"role":"user"}` {
 		t.Fatalf("invalid message: %s", string(s))
 	}
 }


@@ -2,24 +2,9 @@ package openai
 import (
 	"context"
-	"errors"
 	"net/http"
 )
-var (
-	ErrO1MaxTokensDeprecated                   = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll
-	ErrCompletionUnsupportedModel              = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
-	ErrCompletionStreamNotSupported            = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
-	ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll
-)
-var (
-	ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll
-	ErrO1BetaLimitationsTools        = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll
-	ErrO1BetaLimitationsLogprobs     = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
-	ErrO1BetaLimitationsOther        = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
-)
 // GPT3 Defines the models provided by OpenAI to use when generating
 // completions from OpenAI.
 // GPT3 Models are designed for text-based tasks. For code-specific
@@ -29,6 +14,10 @@ const (
 	O1Mini20240912    = "o1-mini-2024-09-12"
 	O1Preview         = "o1-preview"
 	O1Preview20240912 = "o1-preview-2024-09-12"
+	O1                = "o1"
+	O120241217        = "o1-2024-12-17"
+	O3Mini            = "o3-mini"
+	O3Mini20250131    = "o3-mini-2025-01-31"
 	GPT432K0613       = "gpt-4-32k-0613"
 	GPT432K0314       = "gpt-4-32k-0314"
 	GPT432K           = "gpt-4-32k"
@@ -37,6 +26,7 @@ const (
 	GPT4o             = "gpt-4o"
 	GPT4o20240513     = "gpt-4o-2024-05-13"
 	GPT4o20240806     = "gpt-4o-2024-08-06"
+	GPT4o20241120     = "gpt-4o-2024-11-20"
 	GPT4oLatest       = "chatgpt-4o-latest"
 	GPT4oMini         = "gpt-4o-mini"
 	GPT4oMini20240718 = "gpt-4o-mini-2024-07-18"
@@ -93,21 +83,14 @@ const (
 	CodexCodeDavinci001 = "code-davinci-001"
 )
-// O1SeriesModels List of new Series of OpenAI models.
-// Some old api attributes not supported.
-var O1SeriesModels = map[string]struct{}{
-	O1Mini:            {},
-	O1Mini20240912:    {},
-	O1Preview:         {},
-	O1Preview20240912: {},
-}
 var disabledModelsForEndpoints = map[string]map[string]bool{
 	"/completions": {
 		O1Mini:            true,
 		O1Mini20240912:    true,
 		O1Preview:         true,
 		O1Preview20240912: true,
+		O3Mini:            true,
+		O3Mini20250131:    true,
 		GPT3Dot5Turbo:     true,
 		GPT3Dot5Turbo0301: true,
 		GPT3Dot5Turbo0613: true,
@@ -119,6 +102,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
 		GPT4o:             true,
 		GPT4o20240513:     true,
 		GPT4o20240806:     true,
+		GPT4o20241120:     true,
 		GPT4oLatest:       true,
 		GPT4oMini:         true,
 		GPT4oMini20240718: true,
@@ -179,64 +163,6 @@ func checkPromptType(prompt any) bool {
 	return true // all items in the slice are string, so it is []string
 }
-var unsupportedToolsForO1Models = map[ToolType]struct{}{
-	ToolTypeFunction: {},
-}
-var availableMessageRoleForO1Models = map[string]struct{}{
-	ChatMessageRoleUser:      {},
-	ChatMessageRoleAssistant: {},
-}
-// validateRequestForO1Models checks for deprecated fields of OpenAI models.
-func validateRequestForO1Models(request ChatCompletionRequest) error {
-	if _, found := O1SeriesModels[request.Model]; !found {
-		return nil
-	}
-	if request.MaxTokens > 0 {
-		return ErrO1MaxTokensDeprecated
-	}
-	// Logprobs: not supported.
-	if request.LogProbs {
-		return ErrO1BetaLimitationsLogprobs
-	}
-	// Message types: user and assistant messages only, system messages are not supported.
-	for _, m := range request.Messages {
-		if _, found := availableMessageRoleForO1Models[m.Role]; !found {
-			return ErrO1BetaLimitationsMessageTypes
-		}
-	}
-	// Tools: tools, function calling, and response format parameters are not supported
-	for _, t := range request.Tools {
-		if _, found := unsupportedToolsForO1Models[t.Type]; found {
-			return ErrO1BetaLimitationsTools
-		}
-	}
-	// Other: temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0.
-	if request.Temperature > 0 && request.Temperature != 1 {
-		return ErrO1BetaLimitationsOther
-	}
-	if request.TopP > 0 && request.TopP != 1 {
-		return ErrO1BetaLimitationsOther
-	}
-	if request.N > 0 && request.N != 1 {
-		return ErrO1BetaLimitationsOther
-	}
-	if request.PresencePenalty > 0 {
-		return ErrO1BetaLimitationsOther
-	}
-	if request.FrequencyPenalty > 0 {
-		return ErrO1BetaLimitationsOther
-	}
-	return nil
-}
 // CompletionRequest represents a request structure for completion API.
 type CompletionRequest struct {
 	Model string `json:"model"`


@@ -7,6 +7,7 @@ import (
 	"io"
 	"net/http"
 	"os"
+	"path/filepath"
 	"testing"
 	"time"
@@ -86,24 +87,17 @@ func TestImageEdit(t *testing.T) {
 	defer teardown()
 	server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
-	origin, err := os.Create("image.png")
+	origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
 	if err != nil {
-		t.Error("open origin file error")
-		return
+		t.Fatalf("open origin file error: %v", err)
 	}
+	defer origin.Close()
-	mask, err := os.Create("mask.png")
+	mask, err := os.Create(filepath.Join(t.TempDir(), "mask.png"))
 	if err != nil {
-		t.Error("open mask file error")
-		return
+		t.Fatalf("open mask file error: %v", err)
 	}
+	defer mask.Close()
-	defer func() {
-		mask.Close()
-		origin.Close()
-		os.Remove("mask.png")
-		os.Remove("image.png")
-	}()
 	_, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
 		Image: origin,
@@ -121,16 +115,11 @@ func TestImageEditWithoutMask(t *testing.T) {
 	defer teardown()
 	server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
-	origin, err := os.Create("image.png")
+	origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
 	if err != nil {
-		t.Error("open origin file error")
-		return
+		t.Fatalf("open origin file error: %v", err)
 	}
+	defer origin.Close()
-	defer func() {
-		origin.Close()
-		os.Remove("image.png")
-	}()
 	_, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
 		Image: origin,
@@ -178,16 +167,11 @@ func TestImageVariation(t *testing.T) {
 	defer teardown()
 	server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint)
-	origin, err := os.Create("image.png")
+	origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
 	if err != nil {
-		t.Error("open origin file error")
-		return
+		t.Fatalf("open origin file error: %v", err)
 	}
+	defer origin.Close()
-	defer func() {
-		origin.Close()
-		os.Remove("image.png")
-	}()
 	_, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{
 		Image: origin,


@@ -1,7 +1,6 @@
 package openai //nolint:testpackage // testing private field
 import (
-	"github.com/sashabaranov/go-openai/internal/test"
 	"github.com/sashabaranov/go-openai/internal/test/checks"
 	"bytes"
@@ -20,15 +19,11 @@ func (*failingWriter) Write([]byte) (int, error) {
 }
 func TestFormBuilderWithFailingWriter(t *testing.T) {
-	dir, cleanup := test.CreateTestDirectory(t)
-	defer cleanup()
-	file, err := os.CreateTemp(dir, "")
+	file, err := os.CreateTemp(t.TempDir(), "")
 	if err != nil {
-		t.Errorf("Error creating tmp file: %v", err)
+		t.Fatalf("Error creating tmp file: %v", err)
 	}
 	defer file.Close()
-	defer os.Remove(file.Name())
 	builder := NewFormBuilder(&failingWriter{})
 	err = builder.CreateFormFile("file", file)
@@ -36,15 +31,11 @@ func TestFormBuilderWithFailingWriter(t *testing.T) {
 }
 func TestFormBuilderWithClosedFile(t *testing.T) {
-	dir, cleanup := test.CreateTestDirectory(t)
-	defer cleanup()
-	file, err := os.CreateTemp(dir, "")
+	file, err := os.CreateTemp(t.TempDir(), "")
 	if err != nil {
-		t.Errorf("Error creating tmp file: %v", err)
+		t.Fatalf("Error creating tmp file: %v", err)
 	}
 	file.Close()
-	defer os.Remove(file.Name())
 	body := &bytes.Buffer{}
 	builder := NewFormBuilder(body)


@@ -19,16 +19,6 @@ func CreateTestFile(t *testing.T, path string) {
 	file.Close()
 }
-// CreateTestDirectory creates a temporary folder which will be deleted when cleanup is called.
-func CreateTestDirectory(t *testing.T) (path string, cleanup func()) {
-	t.Helper()
-	path, err := os.MkdirTemp(os.TempDir(), "")
-	checks.NoError(t, err)
-	return path, func() { os.RemoveAll(path) }
-}
 // TokenRoundTripper is a struct that implements the RoundTripper
 // interface, specifically to handle the authentication token by adding a token
 // to the request header. We need this because the API requires that each
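The replacement pattern used throughout these test diffs, shown as a standalone sketch (the test name is illustrative): t.TempDir hands each test an automatically cleaned directory, which is what makes the CreateTestDirectory helper above removable.

func TestSomethingWithTempFile(t *testing.T) {
    // The directory from t.TempDir is deleted automatically when the
    // test finishes, so no cleanup closure or deferred removal is needed.
    path := filepath.Join(t.TempDir(), "fake.mp3")
    test.CreateTestFile(t, path)
    // ... exercise code that reads from path ...
}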


@@ -41,6 +41,7 @@ type MessageContent struct {
 	Type      string       `json:"type"`
 	Text      *MessageText `json:"text,omitempty"`
 	ImageFile *ImageFile   `json:"image_file,omitempty"`
+	ImageURL  *ImageURL    `json:"image_url,omitempty"`
 }
 type MessageText struct {
 	Value string `json:"value"`
@@ -51,6 +52,11 @@ type ImageFile struct {
 	FileID string `json:"file_id"`
 }
+type ImageURL struct {
+	URL    string `json:"url"`
+	Detail string `json:"detail"`
+}
 type MessageRequest struct {
 	Role    string `json:"role"`
 	Content string `json:"content"`
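A hedged sketch of building an image-URL content part with the new type; the Type string and Detail values follow the OpenAI Assistants API documentation and should be read as assumptions, not values pinned down by this diff:

content := openai.MessageContent{
    Type: "image_url",
    ImageURL: &openai.ImageURL{
        URL:    "https://example.com/diagram.png", // placeholder URL
        Detail: "high",                            // per API docs: "low", "high", or "auto"
    },
}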


@@ -29,9 +29,9 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea
 // numTokens Returns the number of GPT-3 encoded tokens in the given text.
 // This function approximates based on the rule of thumb stated by OpenAI:
-// https://beta.openai.com/tokenizer
+// https://beta.openai.com/tokenizer/
 //
-// TODO: implement an actual tokenizer for GPT-3 and Codex (once available)
+// TODO: implement an actual tokenizer for GPT-3 and Codex (once available).
 func numTokens(s string) int {
 	return int(float32(len(s)) / 4)
 }

reasoning_validator.go (new file)

@@ -0,0 +1,111 @@
package openai
import (
"errors"
"strings"
)
var (
// Deprecated: use ErrReasoningModelMaxTokensDeprecated instead.
ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll
ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll
)
var (
ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll
ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll
// Deprecated: use ErrReasoningModelLimitations* instead.
ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
)
var (
//nolint:lll
ErrReasoningModelMaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens")
ErrReasoningModelLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
)
var unsupportedToolsForO1Models = map[ToolType]struct{}{
ToolTypeFunction: {},
}
var availableMessageRoleForO1Models = map[string]struct{}{
ChatMessageRoleUser: {},
ChatMessageRoleAssistant: {},
}
// ReasoningValidator handles validation for o-series model requests.
type ReasoningValidator struct{}
// NewReasoningValidator creates a new validator for o-series models.
func NewReasoningValidator() *ReasoningValidator {
return &ReasoningValidator{}
}
// Validate performs all validation checks for o-series models.
func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
o1Series := strings.HasPrefix(request.Model, "o1")
o3Series := strings.HasPrefix(request.Model, "o3")
if !o1Series && !o3Series {
return nil
}
if err := v.validateReasoningModelParams(request); err != nil {
return err
}
if o1Series {
if err := v.validateO1Specific(request); err != nil {
return err
}
}
return nil
}
// validateReasoningModelParams checks reasoning model parameters.
func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletionRequest) error {
if request.MaxTokens > 0 {
return ErrReasoningModelMaxTokensDeprecated
}
if request.LogProbs {
return ErrReasoningModelLimitationsLogprobs
}
if request.Temperature > 0 && request.Temperature != 1 {
return ErrReasoningModelLimitationsOther
}
if request.TopP > 0 && request.TopP != 1 {
return ErrReasoningModelLimitationsOther
}
if request.N > 0 && request.N != 1 {
return ErrReasoningModelLimitationsOther
}
if request.PresencePenalty > 0 {
return ErrReasoningModelLimitationsOther
}
if request.FrequencyPenalty > 0 {
return ErrReasoningModelLimitationsOther
}
return nil
}
// validateO1Specific checks O1-specific limitations.
func (v *ReasoningValidator) validateO1Specific(request ChatCompletionRequest) error {
for _, m := range request.Messages {
if _, found := availableMessageRoleForO1Models[m.Role]; !found {
return ErrO1BetaLimitationsMessageTypes
}
}
for _, t := range request.Tools {
if _, found := unsupportedToolsForO1Models[t.Type]; found {
return ErrO1BetaLimitationsTools
}
}
return nil
}
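The validator is wired into CreateChatCompletion and CreateChatCompletionStream (see the chat.go and stream hunks above), but it can also be called directly; a minimal sketch:

validator := openai.NewReasoningValidator()
req := openai.ChatCompletionRequest{
    Model:     openai.O3Mini,
    MaxTokens: 100, // deprecated for reasoning models
}
if err := validator.Validate(req); err != nil {
    // err is openai.ErrReasoningModelMaxTokensDeprecated here
    fmt.Println(err)
}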

run.go

@@ -83,12 +83,13 @@ const (
 )
 type RunRequest struct {
 	AssistantID            string `json:"assistant_id"`
 	Model                  string `json:"model,omitempty"`
 	Instructions           string `json:"instructions,omitempty"`
 	AdditionalInstructions string `json:"additional_instructions,omitempty"`
-	Tools                  []Tool         `json:"tools,omitempty"`
-	Metadata               map[string]any `json:"metadata,omitempty"`
+	AdditionalMessages     []ThreadMessage `json:"additional_messages,omitempty"`
+	Tools                  []Tool          `json:"tools,omitempty"`
+	Metadata               map[string]any  `json:"metadata,omitempty"`
 	// Sampling temperature between 0 and 2. Higher values like 0.8 are more random.
 	// lower values are more focused and deterministic.
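A sketch of the new field in use, assuming an existing assistant and thread; threadID and assistantID are placeholders, and the message text is illustrative:

run, err := client.CreateRun(context.Background(), threadID, openai.RunRequest{
    AssistantID: assistantID,
    AdditionalMessages: []openai.ThreadMessage{
        {
            Role:    openai.ThreadMessageRoleUser,
            Content: "Please also take this follow-up into account.",
        },
    },
})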


@@ -21,10 +21,8 @@ func TestSpeechIntegration(t *testing.T) {
 	defer teardown()
 	server.RegisterHandler("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) {
-		dir, cleanup := test.CreateTestDirectory(t)
-		path := filepath.Join(dir, "fake.mp3")
+		path := filepath.Join(t.TempDir(), "fake.mp3")
 		test.CreateTestFile(t, path)
-		defer cleanup()
 		// audio endpoints only accept POST requests
 		if r.Method != "POST" {


@@ -32,17 +32,28 @@ type streamReader[T streamable] struct {
 }
 func (stream *streamReader[T]) Recv() (response T, err error) {
-	if stream.isFinished {
-		err = io.EOF
+	rawLine, err := stream.RecvRaw()
+	if err != nil {
 		return
 	}
-	response, err = stream.processLines()
-	return
+	err = stream.unmarshaler.Unmarshal(rawLine, &response)
+	if err != nil {
+		return
+	}
+	return response, nil
+}
+func (stream *streamReader[T]) RecvRaw() ([]byte, error) {
+	if stream.isFinished {
+		return nil, io.EOF
+	}
+	return stream.processLines()
 }
 //nolint:gocognit
-func (stream *streamReader[T]) processLines() (T, error) {
+func (stream *streamReader[T]) processLines() ([]byte, error) {
 	var (
 		emptyMessagesCount uint
 		hasErrorPrefix     bool
@@ -53,9 +64,9 @@ func (stream *streamReader[T]) processLines() (T, error) {
 		if readErr != nil || hasErrorPrefix {
 			respErr := stream.unmarshalError()
 			if respErr != nil {
-				return *new(T), fmt.Errorf("error, %w", respErr.Error)
+				return nil, fmt.Errorf("error, %w", respErr.Error)
 			}
-			return *new(T), readErr
+			return nil, readErr
 		}
 		noSpaceLine := bytes.TrimSpace(rawLine)
@@ -68,11 +79,11 @@ func (stream *streamReader[T]) processLines() (T, error) {
 		}
 		writeErr := stream.errAccumulator.Write(noSpaceLine)
 		if writeErr != nil {
-			return *new(T), writeErr
+			return nil, writeErr
 		}
 		emptyMessagesCount++
 		if emptyMessagesCount > stream.emptyMessagesLimit {
-			return *new(T), ErrTooManyEmptyStreamMessages
+			return nil, ErrTooManyEmptyStreamMessages
 		}
 		continue
@@ -81,16 +92,10 @@ func (stream *streamReader[T]) processLines() (T, error) {
 		noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
 		if string(noPrefixLine) == "[DONE]" {
 			stream.isFinished = true
-			return *new(T), io.EOF
+			return nil, io.EOF
 		}
-		var response T
-		unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response)
-		if unmarshalErr != nil {
-			return *new(T), unmarshalErr
-		}
-		return response, nil
+		return noPrefixLine, nil
 	}
 }
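After this refactor, Recv is a thin unmarshaling wrapper over RecvRaw, so callers can consume the raw SSE payloads directly; a hedged sketch (client and request construction omitted):

func consumeRaw(client *openai.Client, req openai.ChatCompletionRequest) error {
    stream, err := client.CreateChatCompletionStream(context.Background(), req)
    if err != nil {
        return err
    }
    defer stream.Close()
    for {
        raw, recvErr := stream.RecvRaw() // one "data:" payload, prefix stripped
        if errors.Is(recvErr, io.EOF) {
            return nil // stream ended with [DONE]
        }
        if recvErr != nil {
            return recvErr
        }
        fmt.Printf("raw chunk: %s\n", raw)
    }
}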


@@ -63,3 +63,16 @@ func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) {
_, err := stream.Recv()
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
}
func TestStreamReaderRecvRaw(t *testing.T) {
stream := &streamReader[ChatCompletionStreamResponse]{
reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"key\": \"value\"}\n"))),
}
rawLine, err := stream.RecvRaw()
if err != nil {
t.Fatalf("Did not return raw line: %v", err)
}
if !bytes.Equal(rawLine, []byte("{\"key\": \"value\"}")) {
t.Fatalf("Did not return raw line: %v", string(rawLine))
}
}