Add support for O3-mini (#930)
* Add support for O3-mini
  - Add support for the o3-mini set of models, including tests that match the constraints in OpenAI's API docs (https://platform.openai.com/docs/models#o3-mini).
* Deprecate and refactor
  - Deprecate `ErrO1BetaLimitationsLogprobs` and `ErrO1BetaLimitationsOther`.
  - Implement `validationRequestForReasoningModels`, which works on both o1 & o3 and applies per-model-type restrictions on functionality (e.g., the o3 class allows function calls and system messages, o1 doesn't).
* Move reasoning validation to `reasoning_validator.go`
  - Add a `NewReasoningValidator`, which exposes a `Validate()` method for a given request.
  - Also add a test for chat streams.
* Final nits
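As a quick orientation before the diff: a minimal sketch of how the refactored validation surfaces to callers. This snippet is illustrative only and not part of the change; the import path and the token are placeholder assumptions.

```go
package main

import (
	"context"
	"errors"
	"fmt"

	openai "github.com/sashabaranov/go-openai" // assumed import path for this package
)

func main() {
	client := openai.NewClient("your-api-token") // placeholder token

	// Reasoning models reject MaxTokens (MaxCompletionTokens must be used instead),
	// so the validator short-circuits this request before it ever reaches the API.
	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model:     openai.O3Mini,
		MaxTokens: 100,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	})
	if errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
		fmt.Println("rejected client-side: use MaxCompletionTokens for o-series models")
	}
}
```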
chat.go

@@ -392,7 +392,8 @@ func (c *Client) CreateChatCompletion(
 		return
 	}
 
-	if err = validateRequestForO1Models(request); err != nil {
+	reasoningValidator := NewReasoningValidator()
+	if err = reasoningValidator.Validate(request); err != nil {
 		return
 	}
@@ -80,7 +80,8 @@ func (c *Client) CreateChatCompletionStream(
 	}
 
 	request.Stream = true
-	if err = validateRequestForO1Models(request); err != nil {
+	reasoningValidator := NewReasoningValidator()
+	if err = reasoningValidator.Validate(request); err != nil {
 		return
 	}
@@ -792,6 +792,173 @@ func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
 	return true
 }
 
+func TestCreateChatCompletionStreamWithReasoningModel(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+
+		dataBytes := []byte{}
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" from"},"finish_reason":null}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" O3Mini"},"finish_reason":null}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"5","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
+
+		_, err := w.Write(dataBytes)
+		checks.NoError(t, err, "Write error")
+	})
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
+		MaxCompletionTokens: 2000,
+		Model:               openai.O3Mini20250131,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+	})
+	checks.NoError(t, err, "CreateCompletionStream returned error")
+	defer stream.Close()
+
+	expectedResponses := []openai.ChatCompletionStreamResponse{
+		{
+			ID:                "1",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Role: "assistant",
+					},
+				},
+			},
+		},
+		{
+			ID:                "2",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: "Hello",
+					},
+				},
+			},
+		},
+		{
+			ID:                "3",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: " from",
+					},
+				},
+			},
+		},
+		{
+			ID:                "4",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: " O3Mini",
+					},
+				},
+			},
+		},
+		{
+			ID:                "5",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index:        0,
+					Delta:        openai.ChatCompletionStreamChoiceDelta{},
+					FinishReason: "stop",
+				},
+			},
+		},
+	}
+
+	for ix, expectedResponse := range expectedResponses {
+		b, _ := json.Marshal(expectedResponse)
+		t.Logf("%d: %s", ix, string(b))
+
+		receivedResponse, streamErr := stream.Recv()
+		checks.NoError(t, streamErr, "stream.Recv() failed")
+		if !compareChatResponses(expectedResponse, receivedResponse) {
+			t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
+		}
+	}
+
+	_, streamErr := stream.Recv()
+	if !errors.Is(streamErr, io.EOF) {
+		t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
+	}
+}
+
+func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) {
+	client, _, _ := setupOpenAITestServer()
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
+		MaxTokens: 100, // This will trigger the validator to fail
+		Model:     openai.O3Mini,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+	})
+
+	if stream != nil {
+		t.Error("Expected nil stream when validation fails")
+		stream.Close()
+	}
+
+	if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
+		t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated, got: %v", err)
+	}
+}
+
 func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool {
 	if c1.Index != c2.Index {
 		return false
chat_test.go
@@ -64,7 +64,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
 				MaxTokens: 5,
 				Model:     openai.O1Preview,
 			},
-			expectedError: openai.ErrO1MaxTokensDeprecated,
+			expectedError: openai.ErrReasoningModelMaxTokensDeprecated,
 		},
 		{
 			name: "o1-mini_MaxTokens_deprecated",
@@ -72,7 +72,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
 				MaxTokens: 5,
 				Model:     openai.O1Mini,
 			},
-			expectedError: openai.ErrO1MaxTokensDeprecated,
+			expectedError: openai.ErrReasoningModelMaxTokensDeprecated,
 		},
 	}
 
@@ -104,7 +104,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				LogProbs: true,
 				Model:    openai.O1Preview,
 			},
-			expectedError: openai.ErrO1BetaLimitationsLogprobs,
+			expectedError: openai.ErrReasoningModelLimitationsLogprobs,
 		},
 		{
 			name: "message_type_unsupported",
@@ -155,7 +155,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				},
 				Temperature: float32(2),
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 		{
 			name: "set_top_unsupported",
@@ -173,7 +173,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				Temperature: float32(1),
 				TopP:        float32(0.1),
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 		{
 			name: "set_n_unsupported",
@@ -192,7 +192,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				TopP: float32(1),
 				N:    2,
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 		{
 			name: "set_presence_penalty_unsupported",
@@ -209,7 +209,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				},
 				PresencePenalty: float32(1),
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 		{
 			name: "set_frequency_penalty_unsupported",
@@ -226,7 +226,127 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				},
 				FrequencyPenalty: float32(0.1),
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 	}
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			config := openai.DefaultConfig("whatever")
 			config.BaseURL = "http://localhost/v1"
 			client := openai.NewClientWithConfig(config)
 			ctx := context.Background()
 
 			_, err := client.CreateChatCompletion(ctx, tt.in)
 			checks.HasError(t, err)
 			msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err)
 			checks.ErrorIs(t, err, tt.expectedError, msg)
 		})
 	}
 }
 
+func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) {
+	tests := []struct {
+		name          string
+		in            openai.ChatCompletionRequest
+		expectedError error
+	}{
+		{
+			name: "log_probs_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				LogProbs:            true,
+				Model:               openai.O3Mini,
+			},
+			expectedError: openai.ErrReasoningModelLimitationsLogprobs,
+		},
+		{
+			name: "set_temperature_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				Model:               openai.O3Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				Temperature: float32(2),
+			},
+			expectedError: openai.ErrReasoningModelLimitationsOther,
+		},
+		{
+			name: "set_top_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				Model:               openai.O3Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				Temperature: float32(1),
+				TopP:        float32(0.1),
+			},
+			expectedError: openai.ErrReasoningModelLimitationsOther,
+		},
+		{
+			name: "set_n_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				Model:               openai.O3Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				Temperature: float32(1),
+				TopP:        float32(1),
+				N:           2,
+			},
+			expectedError: openai.ErrReasoningModelLimitationsOther,
+		},
+		{
+			name: "set_presence_penalty_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				Model:               openai.O3Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				PresencePenalty: float32(1),
+			},
+			expectedError: openai.ErrReasoningModelLimitationsOther,
+		},
+		{
+			name: "set_frequency_penalty_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				Model:               openai.O3Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				FrequencyPenalty: float32(0.1),
+			},
+			expectedError: openai.ErrReasoningModelLimitationsOther,
+		},
+	}
 
@@ -308,6 +428,23 @@ func TestO1ModelChatCompletions(t *testing.T) {
 	checks.NoError(t, err, "CreateChatCompletion error")
 }
 
+func TestO3ModelChatCompletions(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
+	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
+		Model:               openai.O3Mini,
+		MaxCompletionTokens: 1000,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+	})
+	checks.NoError(t, err, "CreateChatCompletion error")
+}
+
 // TestCompletions Tests the completions endpoint of the API using the mocked server.
 func TestChatCompletionsWithHeaders(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
@@ -2,24 +2,9 @@ package openai
 
 import (
 	"context"
 	"errors"
 	"net/http"
 )
 
-var (
-	ErrO1MaxTokensDeprecated                   = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll
-	ErrCompletionUnsupportedModel              = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
-	ErrCompletionStreamNotSupported            = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
-	ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll
-)
-
-var (
-	ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll
-	ErrO1BetaLimitationsTools        = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll
-	ErrO1BetaLimitationsLogprobs     = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
-	ErrO1BetaLimitationsOther        = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
-)
-
 // GPT3 Defines the models provided by OpenAI to use when generating
 // completions from OpenAI.
 // GPT3 Models are designed for text-based tasks. For code-specific
@@ -31,6 +16,8 @@ const (
 	O1Preview20240912 = "o1-preview-2024-09-12"
 	O1                = "o1"
 	O120241217        = "o1-2024-12-17"
+	O3Mini            = "o3-mini"
+	O3Mini20250131    = "o3-mini-2025-01-31"
 	GPT432K0613       = "gpt-4-32k-0613"
 	GPT432K0314       = "gpt-4-32k-0314"
 	GPT432K           = "gpt-4-32k"
@@ -96,21 +83,14 @@ const (
 	CodexCodeDavinci001 = "code-davinci-001"
 )
 
-// O1SeriesModels List of new Series of OpenAI models.
-// Some old api attributes not supported.
-var O1SeriesModels = map[string]struct{}{
-	O1Mini:            {},
-	O1Mini20240912:    {},
-	O1Preview:         {},
-	O1Preview20240912: {},
-}
-
 var disabledModelsForEndpoints = map[string]map[string]bool{
 	"/completions": {
 		O1Mini:            true,
 		O1Mini20240912:    true,
 		O1Preview:         true,
 		O1Preview20240912: true,
+		O3Mini:            true,
+		O3Mini20250131:    true,
 		GPT3Dot5Turbo:     true,
 		GPT3Dot5Turbo0301: true,
 		GPT3Dot5Turbo0613: true,
@@ -183,64 +163,6 @@ func checkPromptType(prompt any) bool {
 	return true // all items in the slice are string, so it is []string
 }
 
-var unsupportedToolsForO1Models = map[ToolType]struct{}{
-	ToolTypeFunction: {},
-}
-
-var availableMessageRoleForO1Models = map[string]struct{}{
-	ChatMessageRoleUser:      {},
-	ChatMessageRoleAssistant: {},
-}
-
-// validateRequestForO1Models checks for deprecated fields of OpenAI models.
-func validateRequestForO1Models(request ChatCompletionRequest) error {
-	if _, found := O1SeriesModels[request.Model]; !found {
-		return nil
-	}
-
-	if request.MaxTokens > 0 {
-		return ErrO1MaxTokensDeprecated
-	}
-
-	// Logprobs: not supported.
-	if request.LogProbs {
-		return ErrO1BetaLimitationsLogprobs
-	}
-
-	// Message types: user and assistant messages only, system messages are not supported.
-	for _, m := range request.Messages {
-		if _, found := availableMessageRoleForO1Models[m.Role]; !found {
-			return ErrO1BetaLimitationsMessageTypes
-		}
-	}
-
-	// Tools: tools, function calling, and response format parameters are not supported
-	for _, t := range request.Tools {
-		if _, found := unsupportedToolsForO1Models[t.Type]; found {
-			return ErrO1BetaLimitationsTools
-		}
-	}
-
-	// Other: temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0.
-	if request.Temperature > 0 && request.Temperature != 1 {
-		return ErrO1BetaLimitationsOther
-	}
-	if request.TopP > 0 && request.TopP != 1 {
-		return ErrO1BetaLimitationsOther
-	}
-	if request.N > 0 && request.N != 1 {
-		return ErrO1BetaLimitationsOther
-	}
-	if request.PresencePenalty > 0 {
-		return ErrO1BetaLimitationsOther
-	}
-	if request.FrequencyPenalty > 0 {
-		return ErrO1BetaLimitationsOther
-	}
-
-	return nil
-}
-
 // CompletionRequest represents a request structure for completion API.
 type CompletionRequest struct {
 	Model string `json:"model"`
reasoning_validator.go (new file)
@@ -0,0 +1,111 @@
+package openai
+
+import (
+	"errors"
+	"strings"
+)
+
+var (
+	// Deprecated: use ErrReasoningModelMaxTokensDeprecated instead.
+	ErrO1MaxTokensDeprecated                   = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll
+	ErrCompletionUnsupportedModel              = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
+	ErrCompletionStreamNotSupported            = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
+	ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll
+)
+
+var (
+	ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll
+	ErrO1BetaLimitationsTools        = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll
+	// Deprecated: use ErrReasoningModelLimitations* instead.
+	ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
+	ErrO1BetaLimitationsOther    = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
+)
+
+var (
+	//nolint:lll
+	ErrReasoningModelMaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens")
+	ErrReasoningModelLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
+	ErrReasoningModelLimitationsOther    = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
+)
+
+var unsupportedToolsForO1Models = map[ToolType]struct{}{
+	ToolTypeFunction: {},
+}
+
+var availableMessageRoleForO1Models = map[string]struct{}{
+	ChatMessageRoleUser:      {},
+	ChatMessageRoleAssistant: {},
+}
+
+// ReasoningValidator handles validation for o-series model requests.
+type ReasoningValidator struct{}
+
+// NewReasoningValidator creates a new validator for o-series models.
+func NewReasoningValidator() *ReasoningValidator {
+	return &ReasoningValidator{}
+}
+
+// Validate performs all validation checks for o-series models.
+func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
+	o1Series := strings.HasPrefix(request.Model, "o1")
+	o3Series := strings.HasPrefix(request.Model, "o3")
+
+	if !o1Series && !o3Series {
+		return nil
+	}
+
+	if err := v.validateReasoningModelParams(request); err != nil {
+		return err
+	}
+
+	if o1Series {
+		if err := v.validateO1Specific(request); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// validateReasoningModelParams checks reasoning model parameters.
+func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletionRequest) error {
+	if request.MaxTokens > 0 {
+		return ErrReasoningModelMaxTokensDeprecated
+	}
+	if request.LogProbs {
+		return ErrReasoningModelLimitationsLogprobs
+	}
+	if request.Temperature > 0 && request.Temperature != 1 {
+		return ErrReasoningModelLimitationsOther
+	}
+	if request.TopP > 0 && request.TopP != 1 {
+		return ErrReasoningModelLimitationsOther
+	}
+	if request.N > 0 && request.N != 1 {
+		return ErrReasoningModelLimitationsOther
+	}
+	if request.PresencePenalty > 0 {
+		return ErrReasoningModelLimitationsOther
+	}
+	if request.FrequencyPenalty > 0 {
+		return ErrReasoningModelLimitationsOther
+	}
+
+	return nil
+}
+
+// validateO1Specific checks O1-specific limitations.
+func (v *ReasoningValidator) validateO1Specific(request ChatCompletionRequest) error {
+	for _, m := range request.Messages {
+		if _, found := availableMessageRoleForO1Models[m.Role]; !found {
+			return ErrO1BetaLimitationsMessageTypes
+		}
+	}
+
+	for _, t := range request.Tools {
+		if _, found := unsupportedToolsForO1Models[t.Type]; found {
+			return ErrO1BetaLimitationsTools
+		}
+	}
+	return nil
+}