diff --git a/chat.go b/chat.go
index 8ea7238..ce24fa3 100644
--- a/chat.go
+++ b/chat.go
@@ -392,7 +392,8 @@ func (c *Client) CreateChatCompletion(
 		return
 	}
 
-	if err = validateRequestForO1Models(request); err != nil {
+	reasoningValidator := NewReasoningValidator()
+	if err = reasoningValidator.Validate(request); err != nil {
 		return
 	}
 
diff --git a/chat_stream.go b/chat_stream.go
index 58b2651..525b445 100644
--- a/chat_stream.go
+++ b/chat_stream.go
@@ -80,7 +80,8 @@ func (c *Client) CreateChatCompletionStream(
 	}
 	request.Stream = true
 
-	if err = validateRequestForO1Models(request); err != nil {
+	reasoningValidator := NewReasoningValidator()
+	if err = reasoningValidator.Validate(request); err != nil {
 		return
 	}
 
diff --git a/chat_stream_test.go b/chat_stream_test.go
index 28a9acf..4d992e4 100644
--- a/chat_stream_test.go
+++ b/chat_stream_test.go
@@ -792,6 +792,173 @@ func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
 	return true
 }
 
+func TestCreateChatCompletionStreamWithReasoningModel(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+
+		dataBytes := []byte{}
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" from"},"finish_reason":null}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" O3Mini"},"finish_reason":null}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"5","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
+
+		_, err := w.Write(dataBytes)
+		checks.NoError(t, err, "Write error")
+	})
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
+		MaxCompletionTokens: 2000,
+		Model:               openai.O3Mini20250131,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+	})
+	checks.NoError(t, err, "CreateCompletionStream returned error")
+	defer stream.Close()
+
+	expectedResponses := []openai.ChatCompletionStreamResponse{
+		{
+			ID:                "1",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Role: "assistant",
+					},
+				},
+			},
+		},
+		{
+			ID:                "2",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: "Hello",
+					},
+				},
+			},
+		},
+		{
+			ID:                "3",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: " from",
+					},
+				},
+			},
+		},
+		{
+			ID:                "4",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: " O3Mini",
+					},
+				},
+			},
+		},
+		{
+			ID:                "5",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index:        0,
+					Delta:        openai.ChatCompletionStreamChoiceDelta{},
+					FinishReason: "stop",
+				},
+			},
+		},
+	}
+
+	for ix, expectedResponse := range expectedResponses {
+		b, _ := json.Marshal(expectedResponse)
+		t.Logf("%d: %s", ix, string(b))
+
+		receivedResponse, streamErr := stream.Recv()
+		checks.NoError(t, streamErr, "stream.Recv() failed")
+		if !compareChatResponses(expectedResponse, receivedResponse) {
+			t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
+		}
+	}
+
+	_, streamErr := stream.Recv()
+	if !errors.Is(streamErr, io.EOF) {
+		t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
+	}
+}
+
+func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) {
+	client, _, _ := setupOpenAITestServer()
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
+		MaxTokens: 100, // This will trigger the validator to fail
+		Model:     openai.O3Mini,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+	})
+
+	if stream != nil {
+		t.Error("Expected nil stream when validation fails")
+		stream.Close()
+	}
+
+	if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
+		t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated, got: %v", err)
+	}
+}
+
 func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool {
 	if c1.Index != c2.Index {
 		return false
diff --git a/chat_test.go b/chat_test.go
index cea549c..fc6c4a9 100644
--- a/chat_test.go
+++ b/chat_test.go
@@ -64,7 +64,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
 				MaxTokens: 5,
 				Model:     openai.O1Preview,
 			},
-			expectedError: openai.ErrO1MaxTokensDeprecated,
+			expectedError: openai.ErrReasoningModelMaxTokensDeprecated,
 		},
 		{
 			name: "o1-mini_MaxTokens_deprecated",
@@ -72,7 +72,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
 				MaxTokens: 5,
 				Model:     openai.O1Mini,
 			},
-			expectedError: openai.ErrO1MaxTokensDeprecated,
+			expectedError: openai.ErrReasoningModelMaxTokensDeprecated,
 		},
 	}
 
@@ -104,7 +104,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				LogProbs: true,
 				Model:    openai.O1Preview,
 			},
-			expectedError: openai.ErrO1BetaLimitationsLogprobs,
+			expectedError: openai.ErrReasoningModelLimitationsLogprobs,
 		},
 		{
 			name: "message_type_unsupported",
@@ -155,7 +155,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				},
 				Temperature: float32(2),
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 		{
 			name: "set_top_unsupported",
@@ -173,7 +173,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				Temperature: float32(1),
 				TopP:        float32(0.1),
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 		{
 			name: "set_n_unsupported",
@@ -192,7 +192,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				TopP:        float32(1),
 				N:           2,
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 		{
 			name: "set_presence_penalty_unsupported",
@@ -209,7 +209,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				},
 				PresencePenalty: float32(1),
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 		{
 			name: "set_frequency_penalty_unsupported",
@@ -226,7 +226,127 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				},
 				FrequencyPenalty: float32(0.1),
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			config := openai.DefaultConfig("whatever")
+			config.BaseURL = "http://localhost/v1"
+			client := openai.NewClientWithConfig(config)
+			ctx := context.Background()
+
+			_, err := client.CreateChatCompletion(ctx, tt.in)
+			checks.HasError(t, err)
+			msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err)
+			checks.ErrorIs(t, err, tt.expectedError, msg)
+		})
+	}
+}
+
+func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) {
+	tests := []struct {
+		name          string
+		in            openai.ChatCompletionRequest
+		expectedError error
+	}{
+		{
+			name: "log_probs_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				LogProbs:            true,
+				Model:               openai.O3Mini,
+			},
+			expectedError: openai.ErrReasoningModelLimitationsLogprobs,
+		},
+		{
+			name: "set_temperature_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				Model:               openai.O3Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				Temperature: float32(2),
+			},
+			expectedError: openai.ErrReasoningModelLimitationsOther,
+		},
+		{
+			name: "set_top_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				Model:               openai.O3Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				Temperature: float32(1),
+				TopP:        float32(0.1),
+			},
+			expectedError: openai.ErrReasoningModelLimitationsOther,
+		},
+		{
+			name: "set_n_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				Model:               openai.O3Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				Temperature: float32(1),
+				TopP:        float32(1),
+				N:           2,
+			},
+			expectedError: openai.ErrReasoningModelLimitationsOther,
+		},
+		{
+			name: "set_presence_penalty_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				Model:               openai.O3Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				PresencePenalty: float32(1),
+			},
+			expectedError: openai.ErrReasoningModelLimitationsOther,
+		},
+		{
+			name: "set_frequency_penalty_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				Model:               openai.O3Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				FrequencyPenalty: float32(0.1),
+			},
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 	}
 
@@ -308,6 +428,23 @@ func TestO1ModelChatCompletions(t *testing.T) {
 	checks.NoError(t, err, "CreateChatCompletion error")
 }
 
+func TestO3ModelChatCompletions(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
+	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
+		Model:               openai.O3Mini,
+		MaxCompletionTokens: 1000,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+	})
+	checks.NoError(t, err, "CreateChatCompletion error")
+}
+
 // TestCompletions Tests the completions endpoint of the API using the mocked server.
 func TestChatCompletionsWithHeaders(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
diff --git a/completion.go b/completion.go
index 6272468..1985293 100644
--- a/completion.go
+++ b/completion.go
@@ -2,24 +2,9 @@ package openai
 
 import (
 	"context"
-	"errors"
 	"net/http"
 )
 
-var (
-	ErrO1MaxTokensDeprecated                   = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll
-	ErrCompletionUnsupportedModel              = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
-	ErrCompletionStreamNotSupported            = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
-	ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll
-)
-
-var (
-	ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll
-	ErrO1BetaLimitationsTools        = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll
-	ErrO1BetaLimitationsLogprobs     = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
-	ErrO1BetaLimitationsOther        = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
-)
-
 // GPT3 Defines the models provided by OpenAI to use when generating
 // completions from OpenAI.
 // GPT3 Models are designed for text-based tasks. For code-specific
@@ -31,6 +16,8 @@ const (
 	O1Preview20240912 = "o1-preview-2024-09-12"
 	O1                = "o1"
 	O120241217        = "o1-2024-12-17"
+	O3Mini            = "o3-mini"
+	O3Mini20250131    = "o3-mini-2025-01-31"
 	GPT432K0613       = "gpt-4-32k-0613"
 	GPT432K0314       = "gpt-4-32k-0314"
 	GPT432K           = "gpt-4-32k"
@@ -96,21 +83,14 @@ const (
 	CodexCodeDavinci001 = "code-davinci-001"
 )
 
-// O1SeriesModels List of new Series of OpenAI models.
-// Some old api attributes not supported.
-var O1SeriesModels = map[string]struct{}{
-	O1Mini:            {},
-	O1Mini20240912:    {},
-	O1Preview:         {},
-	O1Preview20240912: {},
-}
-
 var disabledModelsForEndpoints = map[string]map[string]bool{
 	"/completions": {
 		O1Mini:            true,
 		O1Mini20240912:    true,
 		O1Preview:         true,
 		O1Preview20240912: true,
+		O3Mini:            true,
+		O3Mini20250131:    true,
 		GPT3Dot5Turbo:     true,
 		GPT3Dot5Turbo0301: true,
 		GPT3Dot5Turbo0613: true,
@@ -183,64 +163,6 @@ func checkPromptType(prompt any) bool {
 	return true // all items in the slice are string, so it is []string
 }
 
-var unsupportedToolsForO1Models = map[ToolType]struct{}{
-	ToolTypeFunction: {},
-}
-
-var availableMessageRoleForO1Models = map[string]struct{}{
-	ChatMessageRoleUser:      {},
-	ChatMessageRoleAssistant: {},
-}
-
-// validateRequestForO1Models checks for deprecated fields of OpenAI models.
-func validateRequestForO1Models(request ChatCompletionRequest) error {
-	if _, found := O1SeriesModels[request.Model]; !found {
-		return nil
-	}
-
-	if request.MaxTokens > 0 {
-		return ErrO1MaxTokensDeprecated
-	}
-
-	// Logprobs: not supported.
-	if request.LogProbs {
-		return ErrO1BetaLimitationsLogprobs
-	}
-
-	// Message types: user and assistant messages only, system messages are not supported.
-	for _, m := range request.Messages {
-		if _, found := availableMessageRoleForO1Models[m.Role]; !found {
-			return ErrO1BetaLimitationsMessageTypes
-		}
-	}
-
-	// Tools: tools, function calling, and response format parameters are not supported
-	for _, t := range request.Tools {
-		if _, found := unsupportedToolsForO1Models[t.Type]; found {
-			return ErrO1BetaLimitationsTools
-		}
-	}
-
-	// Other: temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0.
-	if request.Temperature > 0 && request.Temperature != 1 {
-		return ErrO1BetaLimitationsOther
-	}
-	if request.TopP > 0 && request.TopP != 1 {
-		return ErrO1BetaLimitationsOther
-	}
-	if request.N > 0 && request.N != 1 {
-		return ErrO1BetaLimitationsOther
-	}
-	if request.PresencePenalty > 0 {
-		return ErrO1BetaLimitationsOther
-	}
-	if request.FrequencyPenalty > 0 {
-		return ErrO1BetaLimitationsOther
-	}
-
-	return nil
-}
-
 // CompletionRequest represents a request structure for completion API.
 type CompletionRequest struct {
 	Model string `json:"model"`
diff --git a/reasoning_validator.go b/reasoning_validator.go
new file mode 100644
index 0000000..42a9fbd
--- /dev/null
+++ b/reasoning_validator.go
@@ -0,0 +1,111 @@
+package openai
+
+import (
+	"errors"
+	"strings"
+)
+
+var (
+	// Deprecated: use ErrReasoningModelMaxTokensDeprecated instead.
+	ErrO1MaxTokensDeprecated                   = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll
+	ErrCompletionUnsupportedModel              = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
+	ErrCompletionStreamNotSupported            = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
+	ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll
+)
+
+var (
+	ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll
+	ErrO1BetaLimitationsTools        = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll
+	// Deprecated: use ErrReasoningModelLimitations* instead.
+	ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
+	ErrO1BetaLimitationsOther    = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
+)
+
+var (
+	//nolint:lll
+	ErrReasoningModelMaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens")
+	ErrReasoningModelLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
+	ErrReasoningModelLimitationsOther    = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
+)
+
+var unsupportedToolsForO1Models = map[ToolType]struct{}{
+	ToolTypeFunction: {},
+}
+
+var availableMessageRoleForO1Models = map[string]struct{}{
+	ChatMessageRoleUser:      {},
+	ChatMessageRoleAssistant: {},
+}
+
+// ReasoningValidator handles validation for o-series model requests.
+type ReasoningValidator struct{}
+
+// NewReasoningValidator creates a new validator for o-series models.
+func NewReasoningValidator() *ReasoningValidator {
+	return &ReasoningValidator{}
+}
+
+// Validate performs all validation checks for o-series models.
+func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
+	o1Series := strings.HasPrefix(request.Model, "o1")
+	o3Series := strings.HasPrefix(request.Model, "o3")
+
+	if !o1Series && !o3Series {
+		return nil
+	}
+
+	if err := v.validateReasoningModelParams(request); err != nil {
+		return err
+	}
+
+	if o1Series {
+		if err := v.validateO1Specific(request); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// validateReasoningModelParams checks reasoning model parameters.
+func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletionRequest) error {
+	if request.MaxTokens > 0 {
+		return ErrReasoningModelMaxTokensDeprecated
+	}
+	if request.LogProbs {
+		return ErrReasoningModelLimitationsLogprobs
+	}
+	if request.Temperature > 0 && request.Temperature != 1 {
+		return ErrReasoningModelLimitationsOther
+	}
+	if request.TopP > 0 && request.TopP != 1 {
+		return ErrReasoningModelLimitationsOther
+	}
+	if request.N > 0 && request.N != 1 {
+		return ErrReasoningModelLimitationsOther
+	}
+	if request.PresencePenalty > 0 {
+		return ErrReasoningModelLimitationsOther
+	}
+	if request.FrequencyPenalty > 0 {
+		return ErrReasoningModelLimitationsOther
+	}
+
+	return nil
+}
+
+// validateO1Specific checks O1-specific limitations.
+func (v *ReasoningValidator) validateO1Specific(request ChatCompletionRequest) error {
+	for _, m := range request.Messages {
+		if _, found := availableMessageRoleForO1Models[m.Role]; !found {
+			return ErrO1BetaLimitationsMessageTypes
+		}
+	}
+
+	for _, t := range request.Tools {
+		if _, found := unsupportedToolsForO1Models[t.Type]; found {
+			return ErrO1BetaLimitationsTools
+		}
+	}
+	return nil
+}
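Usage sketch (not part of the patch above): the snippet below shows how the ReasoningValidator and the o3-mini constants introduced by this change might be exercised from calling code. The identifiers NewReasoningValidator, Validate, O3Mini, MaxCompletionTokens, ChatCompletionRequest, and ErrReasoningModelMaxTokensDeprecated are taken from the diff itself; the module path github.com/sashabaranov/go-openai and the placeholder API key are assumptions for illustration only.

package main

import (
	"context"
	"errors"
	"fmt"

	openai "github.com/sashabaranov/go-openai" // assumed module path
)

func main() {
	// Running the validator directly mirrors the pre-flight check that
	// CreateChatCompletion and CreateChatCompletionStream now perform internally.
	v := openai.NewReasoningValidator()

	bad := openai.ChatCompletionRequest{
		Model:     openai.O3Mini,
		MaxTokens: 100, // deprecated for reasoning models
	}
	if err := v.Validate(bad); errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
		fmt.Println("o-series models reject MaxTokens; set MaxCompletionTokens instead")
	}

	// A request shaped like the new tests: MaxCompletionTokens instead of MaxTokens.
	client := openai.NewClient("your-api-key") // placeholder credential
	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model:               openai.O3Mini,
		MaxCompletionTokens: 1000,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	})
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	fmt.Println(resp.Choices[0].Message.Content)
}

Standalone use of the validator is optional, since the client methods in the diff already invoke it before sending any request.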