Compare commits

10 commits: 74ed75f291 ... a62919e8c6

| SHA1 |
|---|
| a62919e8c6 |
| 2054db016c |
| 45aa99607b |
| 9823a8bbbd |
| 7a2915a37d |
| 2a0ff5ac63 |
| 56a9acf86f |
| af5355f5b1 |
| c203ca001f |
| 21fa42c18d |
.gitignore (vendored): 3 changes

```diff
@@ -17,3 +17,6 @@
 # Auth token for tests
 .openai-token
 .idea
+
+# Generated by tests
+test.mp3
```
```diff
@@ -206,6 +206,7 @@ linters:
     - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
     - unconvert # Remove unnecessary type conversions
     - unparam # Reports unused function parameters
+    - usetesting # Reports uses of functions with replacement inside the testing package
    - wastedassign # wastedassign finds wasted assignment statements.
    - whitespace # Tool for detection of leading and trailing whitespace
    ## you may want to enable
```
```diff
@@ -40,12 +40,9 @@ func TestAudio(t *testing.T) {
 
 	ctx := context.Background()
 
-	dir, cleanup := test.CreateTestDirectory(t)
-	defer cleanup()
-
 	for _, tc := range testcases {
 		t.Run(tc.name, func(t *testing.T) {
-			path := filepath.Join(dir, "fake.mp3")
+			path := filepath.Join(t.TempDir(), "fake.mp3")
 			test.CreateTestFile(t, path)
 
 			req := openai.AudioRequest{
@@ -90,12 +87,9 @@ func TestAudioWithOptionalArgs(t *testing.T) {
 
 	ctx := context.Background()
 
-	dir, cleanup := test.CreateTestDirectory(t)
-	defer cleanup()
-
 	for _, tc := range testcases {
 		t.Run(tc.name, func(t *testing.T) {
-			path := filepath.Join(dir, "fake.mp3")
+			path := filepath.Join(t.TempDir(), "fake.mp3")
 			test.CreateTestFile(t, path)
 
 			req := openai.AudioRequest{
```
```diff
@@ -13,9 +13,7 @@ import (
 )
 
 func TestAudioWithFailingFormBuilder(t *testing.T) {
-	dir, cleanup := test.CreateTestDirectory(t)
-	defer cleanup()
-	path := filepath.Join(dir, "fake.mp3")
+	path := filepath.Join(t.TempDir(), "fake.mp3")
 	test.CreateTestFile(t, path)
 
 	req := AudioRequest{
@@ -63,9 +61,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
 
 func TestCreateFileField(t *testing.T) {
 	t.Run("createFileField failing file", func(t *testing.T) {
-		dir, cleanup := test.CreateTestDirectory(t)
-		defer cleanup()
-		path := filepath.Join(dir, "fake.mp3")
+		path := filepath.Join(t.TempDir(), "fake.mp3")
 		test.CreateTestFile(t, path)
 
 		req := AudioRequest{
```
chat.go: 13 changes

```diff
@@ -93,7 +93,7 @@ type ChatMessagePart struct {
 
 type ChatCompletionMessage struct {
 	Role         string `json:"role"`
-	Content      string `json:"content"`
+	Content      string `json:"content,omitempty"`
 	Refusal      string `json:"refusal,omitempty"`
 	MultiContent []ChatMessagePart
 
@@ -132,7 +132,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
 
 	msg := struct {
 		Role         string            `json:"role"`
-		Content      string            `json:"content"`
+		Content      string            `json:"content,omitempty"`
 		Refusal      string            `json:"refusal,omitempty"`
 		MultiContent []ChatMessagePart `json:"-"`
 		Name         string            `json:"name,omitempty"`
@@ -146,7 +146,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
 func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
 	msg := struct {
 		Role         string `json:"role"`
-		Content      string `json:"content"`
+		Content      string `json:"content,omitempty"`
 		Refusal      string `json:"refusal,omitempty"`
 		MultiContent []ChatMessagePart
 		Name         string `json:"name,omitempty"`
@@ -179,7 +179,7 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
 type ToolCall struct {
 	// Index is not nil only in chat completion chunk object
 	Index    *int         `json:"index,omitempty"`
-	ID       string       `json:"id"`
+	ID       string       `json:"id,omitempty"`
 	Type     ToolType     `json:"type"`
 	Function FunctionCall `json:"function"`
 }
@@ -258,6 +258,8 @@ type ChatCompletionRequest struct {
 	// Store can be set to true to store the output of this completion request for use in distillations and evals.
 	// https://platform.openai.com/docs/api-reference/chat/create#chat-create-store
 	Store bool `json:"store,omitempty"`
+	// Controls effort on reasoning for reasoning models. It can be set to "low", "medium", or "high".
+	ReasoningEffort string `json:"reasoning_effort,omitempty"`
 	// Metadata to store with the completion.
 	Metadata map[string]string `json:"metadata,omitempty"`
 }
@@ -390,7 +392,8 @@ func (c *Client) CreateChatCompletion(
 		return
 	}
 
-	if err = validateRequestForO1Models(request); err != nil {
+	reasoningValidator := NewReasoningValidator()
+	if err = reasoningValidator.Validate(request); err != nil {
 		return
 	}
 
```
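Taken together, the chat.go changes add one request knob and swap the o1-only validation for the new reasoning validator. A minimal sketch of a request that exercises the new field, assuming an already-constructed `client`; the model constant and token budget are illustrative choices, not taken from the diff:

```go
// Sketch only: exercises the ReasoningEffort field added above.
req := openai.ChatCompletionRequest{
	Model:               openai.O3Mini,
	MaxCompletionTokens: 1000,  // MaxTokens is rejected for o-series models
	ReasoningEffort:     "low", // "low", "medium", or "high"
	Messages: []openai.ChatCompletionMessage{
		{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
	},
}
resp, err := client.CreateChatCompletion(context.Background(), req)
if err != nil {
	log.Fatal(err) // validation failures from the reasoning validator surface here
}
fmt.Println(resp.Choices[0].Message.Content)
```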
```diff
@@ -80,7 +80,8 @@ func (c *Client) CreateChatCompletionStream(
 	}
 
 	request.Stream = true
-	if err = validateRequestForO1Models(request); err != nil {
+	reasoningValidator := NewReasoningValidator()
+	if err = reasoningValidator.Validate(request); err != nil {
 		return
 	}
 
```
```diff
@@ -792,6 +792,173 @@ func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
 	return true
 }
 
+func TestCreateChatCompletionStreamWithReasoningModel(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+
+		dataBytes := []byte{}
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" from"},"finish_reason":null}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" O3Mini"},"finish_reason":null}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"5","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
+
+		_, err := w.Write(dataBytes)
+		checks.NoError(t, err, "Write error")
+	})
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
+		MaxCompletionTokens: 2000,
+		Model:               openai.O3Mini20250131,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+	})
+	checks.NoError(t, err, "CreateCompletionStream returned error")
+	defer stream.Close()
+
+	expectedResponses := []openai.ChatCompletionStreamResponse{
+		{
+			ID:                "1",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Role: "assistant",
+					},
+				},
+			},
+		},
+		{
+			ID:                "2",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: "Hello",
+					},
+				},
+			},
+		},
+		{
+			ID:                "3",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: " from",
+					},
+				},
+			},
+		},
+		{
+			ID:                "4",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: " O3Mini",
+					},
+				},
+			},
+		},
+		{
+			ID:                "5",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index:        0,
+					Delta:        openai.ChatCompletionStreamChoiceDelta{},
+					FinishReason: "stop",
+				},
+			},
+		},
+	}
+
+	for ix, expectedResponse := range expectedResponses {
+		b, _ := json.Marshal(expectedResponse)
+		t.Logf("%d: %s", ix, string(b))
+
+		receivedResponse, streamErr := stream.Recv()
+		checks.NoError(t, streamErr, "stream.Recv() failed")
+		if !compareChatResponses(expectedResponse, receivedResponse) {
+			t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
+		}
+	}
+
+	_, streamErr := stream.Recv()
+	if !errors.Is(streamErr, io.EOF) {
+		t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
+	}
+}
+
+func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) {
+	client, _, _ := setupOpenAITestServer()
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
+		MaxTokens: 100, // This will trigger the validator to fail
+		Model:     openai.O3Mini,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+	})
+
+	if stream != nil {
+		t.Error("Expected nil stream when validation fails")
+		stream.Close()
+	}
+
+	if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
+		t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated, got: %v", err)
+	}
+}
+
 func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool {
 	if c1.Index != c2.Index {
 		return false
```
chat_test.go: 155 changes

```diff
@@ -64,7 +64,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
 				MaxTokens: 5,
 				Model:     openai.O1Preview,
 			},
-			expectedError: openai.ErrO1MaxTokensDeprecated,
+			expectedError: openai.ErrReasoningModelMaxTokensDeprecated,
 		},
 		{
 			name: "o1-mini_MaxTokens_deprecated",
@@ -72,7 +72,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
 				MaxTokens: 5,
 				Model:     openai.O1Mini,
 			},
-			expectedError: openai.ErrO1MaxTokensDeprecated,
+			expectedError: openai.ErrReasoningModelMaxTokensDeprecated,
 		},
 	}
 
@@ -104,7 +104,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				LogProbs: true,
 				Model:    openai.O1Preview,
 			},
-			expectedError: openai.ErrO1BetaLimitationsLogprobs,
+			expectedError: openai.ErrReasoningModelLimitationsLogprobs,
 		},
 		{
 			name: "message_type_unsupported",
@@ -155,7 +155,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				},
 				Temperature: float32(2),
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 		{
 			name: "set_top_unsupported",
@@ -173,7 +173,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				Temperature: float32(1),
 				TopP:        float32(0.1),
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 		{
 			name: "set_n_unsupported",
@@ -192,7 +192,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				TopP: float32(1),
 				N:    2,
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 		{
 			name: "set_presence_penalty_unsupported",
@@ -209,7 +209,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				},
 				PresencePenalty: float32(1),
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 		{
 			name: "set_frequency_penalty_unsupported",
@@ -226,7 +226,127 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 				},
 				FrequencyPenalty: float32(0.1),
 			},
-			expectedError: openai.ErrO1BetaLimitationsOther,
+			expectedError: openai.ErrReasoningModelLimitationsOther,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			config := openai.DefaultConfig("whatever")
+			config.BaseURL = "http://localhost/v1"
+			client := openai.NewClientWithConfig(config)
+			ctx := context.Background()
+
+			_, err := client.CreateChatCompletion(ctx, tt.in)
+			checks.HasError(t, err)
+			msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err)
+			checks.ErrorIs(t, err, tt.expectedError, msg)
+		})
+	}
+}
+
+func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) {
+	tests := []struct {
+		name          string
+		in            openai.ChatCompletionRequest
+		expectedError error
+	}{
+		{
+			name: "log_probs_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				LogProbs:            true,
+				Model:               openai.O3Mini,
+			},
+			expectedError: openai.ErrReasoningModelLimitationsLogprobs,
+		},
+		{
+			name: "set_temperature_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				Model:               openai.O3Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				Temperature: float32(2),
+			},
+			expectedError: openai.ErrReasoningModelLimitationsOther,
+		},
+		{
+			name: "set_top_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				Model:               openai.O3Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				Temperature: float32(1),
+				TopP:        float32(0.1),
+			},
+			expectedError: openai.ErrReasoningModelLimitationsOther,
+		},
+		{
+			name: "set_n_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				Model:               openai.O3Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				Temperature: float32(1),
+				TopP:        float32(1),
+				N:           2,
+			},
+			expectedError: openai.ErrReasoningModelLimitationsOther,
+		},
+		{
+			name: "set_presence_penalty_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				Model:               openai.O3Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				PresencePenalty: float32(1),
+			},
+			expectedError: openai.ErrReasoningModelLimitationsOther,
+		},
+		{
+			name: "set_frequency_penalty_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionTokens: 1000,
+				Model:               openai.O3Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				FrequencyPenalty: float32(0.1),
+			},
+			expectedError: openai.ErrReasoningModelLimitationsOther,
 		},
 	}
 
@@ -308,6 +428,23 @@ func TestO1ModelChatCompletions(t *testing.T) {
 	checks.NoError(t, err, "CreateChatCompletion error")
 }
 
+func TestO3ModelChatCompletions(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
+	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
+		Model:               openai.O3Mini,
+		MaxCompletionTokens: 1000,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+	})
+	checks.NoError(t, err, "CreateChatCompletion error")
+}
+
 // TestCompletions Tests the completions endpoint of the API using the mocked server.
 func TestChatCompletionsWithHeaders(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
@@ -631,7 +768,7 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
 		t.Fatalf("Unexpected error")
 	}
 	res = strings.ReplaceAll(string(s), " ", "")
-	if res != `{"role":"user","content":""}` {
+	if res != `{"role":"user"}` {
 		t.Fatalf("invalid message: %s", string(s))
 	}
 }
```
```diff
@@ -2,24 +2,9 @@ package openai
 
 import (
 	"context"
-	"errors"
 	"net/http"
 )
 
-var (
-	ErrO1MaxTokensDeprecated                   = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll
-	ErrCompletionUnsupportedModel              = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
-	ErrCompletionStreamNotSupported            = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
-	ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll
-)
-
-var (
-	ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll
-	ErrO1BetaLimitationsTools        = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll
-	ErrO1BetaLimitationsLogprobs     = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
-	ErrO1BetaLimitationsOther        = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
-)
-
 // GPT3 Defines the models provided by OpenAI to use when generating
 // completions from OpenAI.
 // GPT3 Models are designed for text-based tasks. For code-specific
@@ -29,6 +14,10 @@ const (
 	O1Mini20240912    = "o1-mini-2024-09-12"
 	O1Preview         = "o1-preview"
 	O1Preview20240912 = "o1-preview-2024-09-12"
+	O1                = "o1"
+	O120241217        = "o1-2024-12-17"
+	O3Mini            = "o3-mini"
+	O3Mini20250131    = "o3-mini-2025-01-31"
 	GPT432K0613       = "gpt-4-32k-0613"
 	GPT432K0314       = "gpt-4-32k-0314"
 	GPT432K           = "gpt-4-32k"
@@ -37,6 +26,7 @@ const (
 	GPT4o             = "gpt-4o"
 	GPT4o20240513     = "gpt-4o-2024-05-13"
 	GPT4o20240806     = "gpt-4o-2024-08-06"
+	GPT4o20241120     = "gpt-4o-2024-11-20"
 	GPT4oLatest       = "chatgpt-4o-latest"
 	GPT4oMini         = "gpt-4o-mini"
 	GPT4oMini20240718 = "gpt-4o-mini-2024-07-18"
@@ -93,21 +83,14 @@ const (
 	CodexCodeDavinci001 = "code-davinci-001"
 )
 
-// O1SeriesModels List of new Series of OpenAI models.
-// Some old api attributes not supported.
-var O1SeriesModels = map[string]struct{}{
-	O1Mini:            {},
-	O1Mini20240912:    {},
-	O1Preview:         {},
-	O1Preview20240912: {},
-}
-
 var disabledModelsForEndpoints = map[string]map[string]bool{
 	"/completions": {
 		O1Mini:            true,
 		O1Mini20240912:    true,
 		O1Preview:         true,
 		O1Preview20240912: true,
+		O3Mini:            true,
+		O3Mini20250131:    true,
 		GPT3Dot5Turbo:     true,
 		GPT3Dot5Turbo0301: true,
 		GPT3Dot5Turbo0613: true,
@@ -119,6 +102,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
 		GPT4o:             true,
 		GPT4o20240513:     true,
 		GPT4o20240806:     true,
+		GPT4o20241120:     true,
 		GPT4oLatest:       true,
 		GPT4oMini:         true,
 		GPT4oMini20240718: true,
@@ -179,64 +163,6 @@ func checkPromptType(prompt any) bool {
 	return true // all items in the slice are string, so it is []string
 }
 
-var unsupportedToolsForO1Models = map[ToolType]struct{}{
-	ToolTypeFunction: {},
-}
-
-var availableMessageRoleForO1Models = map[string]struct{}{
-	ChatMessageRoleUser:      {},
-	ChatMessageRoleAssistant: {},
-}
-
-// validateRequestForO1Models checks for deprecated fields of OpenAI models.
-func validateRequestForO1Models(request ChatCompletionRequest) error {
-	if _, found := O1SeriesModels[request.Model]; !found {
-		return nil
-	}
-
-	if request.MaxTokens > 0 {
-		return ErrO1MaxTokensDeprecated
-	}
-
-	// Logprobs: not supported.
-	if request.LogProbs {
-		return ErrO1BetaLimitationsLogprobs
-	}
-
-	// Message types: user and assistant messages only, system messages are not supported.
-	for _, m := range request.Messages {
-		if _, found := availableMessageRoleForO1Models[m.Role]; !found {
-			return ErrO1BetaLimitationsMessageTypes
-		}
-	}
-
-	// Tools: tools, function calling, and response format parameters are not supported
-	for _, t := range request.Tools {
-		if _, found := unsupportedToolsForO1Models[t.Type]; found {
-			return ErrO1BetaLimitationsTools
-		}
-	}
-
-	// Other: temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0.
-	if request.Temperature > 0 && request.Temperature != 1 {
-		return ErrO1BetaLimitationsOther
-	}
-	if request.TopP > 0 && request.TopP != 1 {
-		return ErrO1BetaLimitationsOther
-	}
-	if request.N > 0 && request.N != 1 {
-		return ErrO1BetaLimitationsOther
-	}
-	if request.PresencePenalty > 0 {
-		return ErrO1BetaLimitationsOther
-	}
-	if request.FrequencyPenalty > 0 {
-		return ErrO1BetaLimitationsOther
-	}
-
-	return nil
-}
-
 // CompletionRequest represents a request structure for completion API.
 type CompletionRequest struct {
 	Model string `json:"model"`
```
```diff
@@ -7,6 +7,7 @@ import (
 	"io"
 	"net/http"
 	"os"
+	"path/filepath"
 	"testing"
 	"time"
 
@@ -86,24 +87,17 @@ func TestImageEdit(t *testing.T) {
 	defer teardown()
 	server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
 
-	origin, err := os.Create("image.png")
+	origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
 	if err != nil {
-		t.Error("open origin file error")
-		return
+		t.Fatalf("open origin file error: %v", err)
 	}
+	defer origin.Close()
 
-	mask, err := os.Create("mask.png")
+	mask, err := os.Create(filepath.Join(t.TempDir(), "mask.png"))
 	if err != nil {
-		t.Error("open mask file error")
-		return
+		t.Fatalf("open mask file error: %v", err)
 	}
-	defer func() {
-		mask.Close()
-		origin.Close()
-		os.Remove("mask.png")
-		os.Remove("image.png")
-	}()
+	defer mask.Close()
 
 	_, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
 		Image: origin,
@@ -121,16 +115,11 @@ func TestImageEditWithoutMask(t *testing.T) {
 	defer teardown()
 	server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
 
-	origin, err := os.Create("image.png")
+	origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
 	if err != nil {
-		t.Error("open origin file error")
-		return
+		t.Fatalf("open origin file error: %v", err)
 	}
-	defer func() {
-		origin.Close()
-		os.Remove("image.png")
-	}()
+	defer origin.Close()
 
 	_, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
 		Image: origin,
@@ -178,16 +167,11 @@ func TestImageVariation(t *testing.T) {
 	defer teardown()
 	server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint)
 
-	origin, err := os.Create("image.png")
+	origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
 	if err != nil {
-		t.Error("open origin file error")
-		return
+		t.Fatalf("open origin file error: %v", err)
 	}
-	defer func() {
-		origin.Close()
-		os.Remove("image.png")
-	}()
+	defer origin.Close()
 
 	_, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{
 		Image: origin,
```
```diff
@@ -1,7 +1,6 @@
 package openai //nolint:testpackage // testing private field
 
 import (
-	"github.com/sashabaranov/go-openai/internal/test"
 	"github.com/sashabaranov/go-openai/internal/test/checks"
 
 	"bytes"
@@ -20,15 +19,11 @@ func (*failingWriter) Write([]byte) (int, error) {
 }
 
 func TestFormBuilderWithFailingWriter(t *testing.T) {
-	dir, cleanup := test.CreateTestDirectory(t)
-	defer cleanup()
-
-	file, err := os.CreateTemp(dir, "")
+	file, err := os.CreateTemp(t.TempDir(), "")
 	if err != nil {
-		t.Errorf("Error creating tmp file: %v", err)
+		t.Fatalf("Error creating tmp file: %v", err)
 	}
 	defer file.Close()
-	defer os.Remove(file.Name())
 
 	builder := NewFormBuilder(&failingWriter{})
 	err = builder.CreateFormFile("file", file)
@@ -36,15 +31,11 @@ func TestFormBuilderWithFailingWriter(t *testing.T) {
 }
 
 func TestFormBuilderWithClosedFile(t *testing.T) {
-	dir, cleanup := test.CreateTestDirectory(t)
-	defer cleanup()
-
-	file, err := os.CreateTemp(dir, "")
+	file, err := os.CreateTemp(t.TempDir(), "")
 	if err != nil {
-		t.Errorf("Error creating tmp file: %v", err)
+		t.Fatalf("Error creating tmp file: %v", err)
 	}
 	file.Close()
-	defer os.Remove(file.Name())
 
 	body := &bytes.Buffer{}
 	builder := NewFormBuilder(body)
```
```diff
@@ -19,16 +19,6 @@ func CreateTestFile(t *testing.T, path string) {
 	file.Close()
 }
 
-// CreateTestDirectory creates a temporary folder which will be deleted when cleanup is called.
-func CreateTestDirectory(t *testing.T) (path string, cleanup func()) {
-	t.Helper()
-
-	path, err := os.MkdirTemp(os.TempDir(), "")
-	checks.NoError(t, err)
-
-	return path, func() { os.RemoveAll(path) }
-}
-
 // TokenRoundTripper is a struct that implements the RoundTripper
 // interface, specifically to handle the authentication token by adding a token
 // to the request header. We need this because the API requires that each
```
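The removed helper is superseded by the standard library: `t.TempDir()` hands each test its own directory and the testing framework deletes it automatically, which is why every `defer cleanup()` in the hunks above disappears. A before/after sketch (the test name is hypothetical):

```go
func TestExample(t *testing.T) {
	// Before: the removed helper, with cleanup left to the caller.
	//   dir, cleanup := test.CreateTestDirectory(t)
	//   defer cleanup()
	//   path := filepath.Join(dir, "fake.mp3")

	// After: the testing package owns the directory's lifecycle; no cleanup code.
	path := filepath.Join(t.TempDir(), "fake.mp3")
	test.CreateTestFile(t, path)
}
```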
```diff
@@ -41,6 +41,7 @@ type MessageContent struct {
 	Type      string       `json:"type"`
 	Text      *MessageText `json:"text,omitempty"`
 	ImageFile *ImageFile   `json:"image_file,omitempty"`
+	ImageURL  *ImageURL    `json:"image_url,omitempty"`
 }
 type MessageText struct {
 	Value string `json:"value"`
@@ -51,6 +52,11 @@ type ImageFile struct {
 	FileID string `json:"file_id"`
 }
 
+type ImageURL struct {
+	URL    string `json:"url"`
+	Detail string `json:"detail"`
+}
+
 type MessageRequest struct {
 	Role    string `json:"role"`
 	Content string `json:"content"`
```
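With the new content part, a message can point at an external image instead of an uploaded file. A sketch of constructing one; the URL and detail values are illustrative, and the `Type` string is an assumption by analogy with the existing `image_file` part:

```go
// Sketch: a message content part using the ImageURL type added above.
part := openai.MessageContent{
	Type: "image_url", // assumed by analogy with the image_file part naming
	ImageURL: &openai.ImageURL{
		URL:    "https://example.com/diagram.png",
		Detail: "high",
	},
}
```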
```diff
@@ -29,9 +29,9 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea
 
 // numTokens Returns the number of GPT-3 encoded tokens in the given text.
 // This function approximates based on the rule of thumb stated by OpenAI:
-// https://beta.openai.com/tokenizer
+// https://beta.openai.com/tokenizer/
 //
-// TODO: implement an actual tokenizer for GPT-3 and Codex (once available)
+// TODO: implement an actual tokenizer for GPT-3 and Codex (once available).
 func numTokens(s string) int {
 	return int(float32(len(s)) / 4)
 }
```
reasoning_validator.go: new file, 111 lines

```go
package openai

import (
	"errors"
	"strings"
)

var (
	// Deprecated: use ErrReasoningModelMaxTokensDeprecated instead.
	ErrO1MaxTokensDeprecated                   = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll
	ErrCompletionUnsupportedModel              = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
	ErrCompletionStreamNotSupported            = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
	ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll
)

var (
	ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll
	ErrO1BetaLimitationsTools        = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll
	// Deprecated: use ErrReasoningModelLimitations* instead.
	ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
	ErrO1BetaLimitationsOther    = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
)

var (
	//nolint:lll
	ErrReasoningModelMaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens")
	ErrReasoningModelLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
	ErrReasoningModelLimitationsOther    = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
)

var unsupportedToolsForO1Models = map[ToolType]struct{}{
	ToolTypeFunction: {},
}

var availableMessageRoleForO1Models = map[string]struct{}{
	ChatMessageRoleUser:      {},
	ChatMessageRoleAssistant: {},
}

// ReasoningValidator handles validation for o-series model requests.
type ReasoningValidator struct{}

// NewReasoningValidator creates a new validator for o-series models.
func NewReasoningValidator() *ReasoningValidator {
	return &ReasoningValidator{}
}

// Validate performs all validation checks for o-series models.
func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
	o1Series := strings.HasPrefix(request.Model, "o1")
	o3Series := strings.HasPrefix(request.Model, "o3")

	if !o1Series && !o3Series {
		return nil
	}

	if err := v.validateReasoningModelParams(request); err != nil {
		return err
	}

	if o1Series {
		if err := v.validateO1Specific(request); err != nil {
			return err
		}
	}

	return nil
}

// validateReasoningModelParams checks reasoning model parameters.
func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletionRequest) error {
	if request.MaxTokens > 0 {
		return ErrReasoningModelMaxTokensDeprecated
	}
	if request.LogProbs {
		return ErrReasoningModelLimitationsLogprobs
	}
	if request.Temperature > 0 && request.Temperature != 1 {
		return ErrReasoningModelLimitationsOther
	}
	if request.TopP > 0 && request.TopP != 1 {
		return ErrReasoningModelLimitationsOther
	}
	if request.N > 0 && request.N != 1 {
		return ErrReasoningModelLimitationsOther
	}
	if request.PresencePenalty > 0 {
		return ErrReasoningModelLimitationsOther
	}
	if request.FrequencyPenalty > 0 {
		return ErrReasoningModelLimitationsOther
	}

	return nil
}

// validateO1Specific checks O1-specific limitations.
func (v *ReasoningValidator) validateO1Specific(request ChatCompletionRequest) error {
	for _, m := range request.Messages {
		if _, found := availableMessageRoleForO1Models[m.Role]; !found {
			return ErrO1BetaLimitationsMessageTypes
		}
	}

	for _, t := range request.Tools {
		if _, found := unsupportedToolsForO1Models[t.Type]; found {
			return ErrO1BetaLimitationsTools
		}
	}
	return nil
}
```
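This validator is what CreateChatCompletion and CreateChatCompletionStream now call before issuing a request (see the chat.go and chat_stream.go hunks above). Standalone use looks like this sketch; the request values are illustrative:

```go
// Sketch: pre-flight validation of a reasoning-model request.
v := openai.NewReasoningValidator()
req := openai.ChatCompletionRequest{
	Model:     openai.O1Mini,
	MaxTokens: 100, // deprecated for o-series; MaxCompletionTokens replaces it
}
if err := v.Validate(req); err != nil {
	// Here err is openai.ErrReasoningModelMaxTokensDeprecated.
	fmt.Println(err)
}
```

Note the split the file encodes: the parameter limits (max tokens, logprobs, temperature, top_p, n, penalties) apply to both o1 and o3 models, while the message-role and tool restrictions remain o1-only.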
run.go: 1 change

```diff
@@ -87,6 +87,7 @@ type RunRequest struct {
 	Model                  string          `json:"model,omitempty"`
 	Instructions           string          `json:"instructions,omitempty"`
 	AdditionalInstructions string          `json:"additional_instructions,omitempty"`
+	AdditionalMessages     []ThreadMessage `json:"additional_messages,omitempty"`
 	Tools                  []Tool          `json:"tools,omitempty"`
 	Metadata               map[string]any  `json:"metadata,omitempty"`
 
```
```diff
@@ -21,10 +21,8 @@ func TestSpeechIntegration(t *testing.T) {
 	defer teardown()
 
 	server.RegisterHandler("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) {
-		dir, cleanup := test.CreateTestDirectory(t)
-		path := filepath.Join(dir, "fake.mp3")
+		path := filepath.Join(t.TempDir(), "fake.mp3")
 		test.CreateTestFile(t, path)
-		defer cleanup()
 
 		// audio endpoints only accept POST requests
 		if r.Method != "POST" {
```
```diff
@@ -32,17 +32,28 @@ type streamReader[T streamable] struct {
 }
 
 func (stream *streamReader[T]) Recv() (response T, err error) {
-	if stream.isFinished {
-		err = io.EOF
+	rawLine, err := stream.RecvRaw()
+	if err != nil {
 		return
 	}
 
-	response, err = stream.processLines()
+	err = stream.unmarshaler.Unmarshal(rawLine, &response)
+	if err != nil {
 		return
+	}
+	return response, nil
+}
+
+func (stream *streamReader[T]) RecvRaw() ([]byte, error) {
+	if stream.isFinished {
+		return nil, io.EOF
+	}
+
+	return stream.processLines()
 }
 
 //nolint:gocognit
-func (stream *streamReader[T]) processLines() (T, error) {
+func (stream *streamReader[T]) processLines() ([]byte, error) {
 	var (
 		emptyMessagesCount uint
 		hasErrorPrefix     bool
@@ -53,9 +64,9 @@ func (stream *streamReader[T]) processLines() (T, error) {
 		if readErr != nil || hasErrorPrefix {
 			respErr := stream.unmarshalError()
 			if respErr != nil {
-				return *new(T), fmt.Errorf("error, %w", respErr.Error)
+				return nil, fmt.Errorf("error, %w", respErr.Error)
 			}
-			return *new(T), readErr
+			return nil, readErr
 		}
 
 		noSpaceLine := bytes.TrimSpace(rawLine)
@@ -68,11 +79,11 @@ func (stream *streamReader[T]) processLines() (T, error) {
 			}
 			writeErr := stream.errAccumulator.Write(noSpaceLine)
 			if writeErr != nil {
-				return *new(T), writeErr
+				return nil, writeErr
 			}
 			emptyMessagesCount++
 			if emptyMessagesCount > stream.emptyMessagesLimit {
-				return *new(T), ErrTooManyEmptyStreamMessages
+				return nil, ErrTooManyEmptyStreamMessages
 			}
 
 			continue
@@ -81,16 +92,10 @@ func (stream *streamReader[T]) processLines() (T, error) {
 		noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
 		if string(noPrefixLine) == "[DONE]" {
 			stream.isFinished = true
-			return *new(T), io.EOF
+			return nil, io.EOF
 		}
 
-		var response T
-		unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response)
-		if unmarshalErr != nil {
-			return *new(T), unmarshalErr
-		}
-
-		return response, nil
+		return noPrefixLine, nil
 	}
 }
 
```
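After this refactor, `Recv` is a thin typed wrapper: `processLines` returns the raw SSE payload bytes (after `data: ` prefix stripping and `[DONE]` handling), `RecvRaw` exposes them, and `Recv` just unmarshals. A sketch of consuming raw payloads directly, for logging or custom decoding; it assumes `stream` was obtained from one of the Create*Stream methods and that `RecvRaw` is reachable through the embedded reader, as the test below exercises:

```go
// Sketch: reading raw SSE payload bytes instead of typed chunks.
for {
	raw, err := stream.RecvRaw()
	if errors.Is(err, io.EOF) {
		break // the server sent "data: [DONE]"
	}
	if err != nil {
		return err
	}
	fmt.Printf("raw chunk: %s\n", raw)
}
```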
```diff
@@ -63,3 +63,16 @@ func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) {
 	_, err := stream.Recv()
 	checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
 }
+
+func TestStreamReaderRecvRaw(t *testing.T) {
+	stream := &streamReader[ChatCompletionStreamResponse]{
+		reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"key\": \"value\"}\n"))),
+	}
+	rawLine, err := stream.RecvRaw()
+	if err != nil {
+		t.Fatalf("Did not return raw line: %v", err)
+	}
+	if !bytes.Equal(rawLine, []byte("{\"key\": \"value\"}")) {
+		t.Fatalf("Did not return raw line: %v", string(rawLine))
+	}
+}
```