@@ -9,7 +9,6 @@ import (
 	"os"
 	"testing"
 
-	. "github.com/sashabaranov/go-openai"
 	"github.com/sashabaranov/go-openai/internal/test/checks"
 	"github.com/sashabaranov/go-openai/jsonschema"
 )
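The single change above repeats across every file in this diff: the dot import of the library is dropped in favour of a plain named import, so each exported identifier picks up an `openai.` qualifier. A minimal sketch of the pattern, using names that appear in the hunks below:

package openai_test

import "github.com/sashabaranov/go-openai" // package name: openai

// Before, with the dot import, this read:
//   req := ChatCompletionRequest{Model: GPT3Dot5Turbo}
// After, every exported name is qualified explicitly.
var req = openai.ChatCompletionRequest{
	Model: openai.GPT3Dot5Turbo,
}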

@@ -12,7 +12,7 @@ import (
 	"strings"
 	"testing"
 
-	. "github.com/sashabaranov/go-openai"
+	"github.com/sashabaranov/go-openai"
 	"github.com/sashabaranov/go-openai/internal/test"
 	"github.com/sashabaranov/go-openai/internal/test/checks"
 )
@@ -26,7 +26,7 @@ func TestAudio(t *testing.T) {
 
 	testcases := []struct {
 		name     string
-		createFn func(context.Context, AudioRequest) (AudioResponse, error)
+		createFn func(context.Context, openai.AudioRequest) (openai.AudioResponse, error)
 	}{
 		{
 			"transcribe",
@@ -48,7 +48,7 @@ func TestAudio(t *testing.T) {
 			path := filepath.Join(dir, "fake.mp3")
 			test.CreateTestFile(t, path)
 
-			req := AudioRequest{
+			req := openai.AudioRequest{
 				FilePath: path,
 				Model:    "whisper-3",
 			}
@@ -57,7 +57,7 @@ func TestAudio(t *testing.T) {
 		})
 
 		t.Run(tc.name+" (with reader)", func(t *testing.T) {
-			req := AudioRequest{
+			req := openai.AudioRequest{
 				FilePath: "fake.webm",
 				Reader:   bytes.NewBuffer([]byte(`some webm binary data`)),
 				Model:    "whisper-3",
@@ -76,7 +76,7 @@ func TestAudioWithOptionalArgs(t *testing.T) {
 
 	testcases := []struct {
 		name     string
-		createFn func(context.Context, AudioRequest) (AudioResponse, error)
+		createFn func(context.Context, openai.AudioRequest) (openai.AudioResponse, error)
 	}{
 		{
 			"transcribe",
@@ -98,13 +98,13 @@ func TestAudioWithOptionalArgs(t *testing.T) {
 			path := filepath.Join(dir, "fake.mp3")
 			test.CreateTestFile(t, path)
 
-			req := AudioRequest{
+			req := openai.AudioRequest{
 				FilePath:    path,
 				Model:       "whisper-3",
 				Prompt:      "用简体中文",
 				Temperature: 0.5,
 				Language:    "zh",
-				Format:      AudioResponseFormatSRT,
+				Format:      openai.AudioResponseFormatSRT,
 			}
 			_, err := tc.createFn(ctx, req)
 			checks.NoError(t, err, "audio API error")

@@ -40,7 +40,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
 	}
 
 	var failForField string
-	mockBuilder.mockWriteField = func(fieldname, value string) error {
+	mockBuilder.mockWriteField = func(fieldname, _ string) error {
 		if fieldname == failForField {
 			return mockFailedErr
 		}
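The same sweep blanks out callback parameters that are never read: `func(w http.ResponseWriter, r *http.Request)` becomes `func(w http.ResponseWriter, _ *http.Request)` wherever the request is ignored, and the form-builder mock above gets the same treatment for its unused `value` argument. The blank identifier keeps the signature the callback site requires while making the dead parameter explicit (presumably to satisfy an unused-parameter linter; the diff itself does not name one). A small self-contained sketch:

package main

import "net/http"

// The handler must match http.HandlerFunc, but this test stub only
// writes headers, so the request parameter becomes the blank identifier.
func stubHandler(w http.ResponseWriter, _ *http.Request) {
	w.Header().Set("Content-Type", "text/event-stream")
}

func main() {
	http.HandleFunc("/v1/chat/completions", stubHandler)
}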

@@ -10,28 +10,28 @@ import (
 	"strconv"
 	"testing"
 
-	. "github.com/sashabaranov/go-openai"
+	"github.com/sashabaranov/go-openai"
 	"github.com/sashabaranov/go-openai/internal/test/checks"
 )
 
 func TestChatCompletionsStreamWrongModel(t *testing.T) {
-	config := DefaultConfig("whatever")
+	config := openai.DefaultConfig("whatever")
 	config.BaseURL = "http://localhost/v1"
-	client := NewClientWithConfig(config)
+	client := openai.NewClientWithConfig(config)
 	ctx := context.Background()
 
-	req := ChatCompletionRequest{
+	req := openai.ChatCompletionRequest{
 		MaxTokens: 5,
 		Model:     "ada",
-		Messages: []ChatCompletionMessage{
+		Messages: []openai.ChatCompletionMessage{
 			{
-				Role:    ChatMessageRoleUser,
+				Role:    openai.ChatMessageRoleUser,
 				Content: "Hello!",
 			},
 		},
 	}
 	_, err := client.CreateChatCompletionStream(ctx, req)
-	if !errors.Is(err, ErrChatCompletionInvalidModel) {
+	if !errors.Is(err, openai.ErrChatCompletionInvalidModel) {
 		t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err)
 	}
 }
@@ -39,7 +39,7 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) {
 func TestCreateChatCompletionStream(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
 	defer teardown()
-	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
 		w.Header().Set("Content-Type", "text/event-stream")
 
 		// Send test responses
@@ -61,12 +61,12 @@ func TestCreateChatCompletionStream(t *testing.T) {
 		checks.NoError(t, err, "Write error")
 	})
 
-	stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
 		MaxTokens: 5,
-		Model:     GPT3Dot5Turbo,
-		Messages: []ChatCompletionMessage{
+		Model:     openai.GPT3Dot5Turbo,
+		Messages: []openai.ChatCompletionMessage{
 			{
-				Role:    ChatMessageRoleUser,
+				Role:    openai.ChatMessageRoleUser,
 				Content: "Hello!",
 			},
 		},
@@ -75,15 +75,15 @@ func TestCreateChatCompletionStream(t *testing.T) {
 	checks.NoError(t, err, "CreateCompletionStream returned error")
 	defer stream.Close()
 
-	expectedResponses := []ChatCompletionStreamResponse{
+	expectedResponses := []openai.ChatCompletionStreamResponse{
 		{
 			ID:      "1",
 			Object:  "completion",
 			Created: 1598069254,
-			Model:   GPT3Dot5Turbo,
-			Choices: []ChatCompletionStreamChoice{
+			Model:   openai.GPT3Dot5Turbo,
+			Choices: []openai.ChatCompletionStreamChoice{
 				{
-					Delta: ChatCompletionStreamChoiceDelta{
+					Delta: openai.ChatCompletionStreamChoiceDelta{
 						Content: "response1",
 					},
 					FinishReason: "max_tokens",
@@ -94,10 +94,10 @@ func TestCreateChatCompletionStream(t *testing.T) {
 			ID:      "2",
 			Object:  "completion",
 			Created: 1598069255,
-			Model:   GPT3Dot5Turbo,
-			Choices: []ChatCompletionStreamChoice{
+			Model:   openai.GPT3Dot5Turbo,
+			Choices: []openai.ChatCompletionStreamChoice{
 				{
-					Delta: ChatCompletionStreamChoiceDelta{
+					Delta: openai.ChatCompletionStreamChoiceDelta{
 						Content: "response2",
 					},
 					FinishReason: "max_tokens",
@@ -133,7 +133,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
 func TestCreateChatCompletionStreamError(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
 	defer teardown()
-	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
 		w.Header().Set("Content-Type", "text/event-stream")
 
 		// Send test responses
@@ -156,12 +156,12 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
 		checks.NoError(t, err, "Write error")
 	})
 
-	stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
 		MaxTokens: 5,
-		Model:     GPT3Dot5Turbo,
-		Messages: []ChatCompletionMessage{
+		Model:     openai.GPT3Dot5Turbo,
+		Messages: []openai.ChatCompletionMessage{
 			{
-				Role:    ChatMessageRoleUser,
+				Role:    openai.ChatMessageRoleUser,
 				Content: "Hello!",
 			},
 		},
@@ -173,7 +173,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
 	_, streamErr := stream.Recv()
 	checks.HasError(t, streamErr, "stream.Recv() did not return error")
 
-	var apiErr *APIError
+	var apiErr *openai.APIError
 	if !errors.As(streamErr, &apiErr) {
 		t.Errorf("stream.Recv() did not return APIError")
 	}
@@ -183,7 +183,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
 func TestCreateChatCompletionStreamWithHeaders(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
 	defer teardown()
-	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
 		w.Header().Set("Content-Type", "text/event-stream")
 		w.Header().Set(xCustomHeader, xCustomHeaderValue)
 
@@ -196,12 +196,12 @@ func TestCreateChatCompletionStreamWithHeaders(t *testing.T) {
 		checks.NoError(t, err, "Write error")
 	})
 
-	stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
 		MaxTokens: 5,
-		Model:     GPT3Dot5Turbo,
-		Messages: []ChatCompletionMessage{
+		Model:     openai.GPT3Dot5Turbo,
+		Messages: []openai.ChatCompletionMessage{
 			{
-				Role:    ChatMessageRoleUser,
+				Role:    openai.ChatMessageRoleUser,
 				Content: "Hello!",
 			},
 		},
@@ -219,7 +219,7 @@ func TestCreateChatCompletionStreamWithHeaders(t *testing.T) {
 func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
 	defer teardown()
-	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
 		w.Header().Set("Content-Type", "text/event-stream")
 		for k, v := range rateLimitHeaders {
 			switch val := v.(type) {
@@ -239,12 +239,12 @@ func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) {
 		checks.NoError(t, err, "Write error")
 	})
 
-	stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
 		MaxTokens: 5,
-		Model:     GPT3Dot5Turbo,
-		Messages: []ChatCompletionMessage{
+		Model:     openai.GPT3Dot5Turbo,
+		Messages: []openai.ChatCompletionMessage{
 			{
-				Role:    ChatMessageRoleUser,
+				Role:    openai.ChatMessageRoleUser,
 				Content: "Hello!",
 			},
 		},
@@ -264,7 +264,7 @@ func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) {
 func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
 	defer teardown()
-	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
 		w.Header().Set("Content-Type", "text/event-stream")
 
 		// Send test responses
@@ -276,12 +276,12 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
 		checks.NoError(t, err, "Write error")
 	})
 
-	stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
 		MaxTokens: 5,
-		Model:     GPT3Dot5Turbo,
-		Messages: []ChatCompletionMessage{
+		Model:     openai.GPT3Dot5Turbo,
+		Messages: []openai.ChatCompletionMessage{
 			{
-				Role:    ChatMessageRoleUser,
+				Role:    openai.ChatMessageRoleUser,
 				Content: "Hello!",
 			},
 		},
@@ -293,7 +293,7 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
 	_, streamErr := stream.Recv()
 	checks.HasError(t, streamErr, "stream.Recv() did not return error")
 
-	var apiErr *APIError
+	var apiErr *openai.APIError
 	if !errors.As(streamErr, &apiErr) {
 		t.Errorf("stream.Recv() did not return APIError")
 	}
@@ -303,7 +303,7 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
 func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
 	defer teardown()
-	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
 		w.Header().Set("Content-Type", "application/json")
 		w.WriteHeader(429)
 
@@ -317,18 +317,18 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
 		_, err := w.Write(dataBytes)
 		checks.NoError(t, err, "Write error")
 	})
-	_, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
+	_, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
 		MaxTokens: 5,
-		Model:     GPT3Dot5Turbo,
-		Messages: []ChatCompletionMessage{
+		Model:     openai.GPT3Dot5Turbo,
+		Messages: []openai.ChatCompletionMessage{
 			{
-				Role:    ChatMessageRoleUser,
+				Role:    openai.ChatMessageRoleUser,
 				Content: "Hello!",
 			},
 		},
 		Stream: true,
 	})
-	var apiErr *APIError
+	var apiErr *openai.APIError
 	if !errors.As(err, &apiErr) {
 		t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError")
 	}
@@ -345,7 +345,7 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
 	client, server, teardown := setupAzureTestServer()
 	defer teardown()
 	server.RegisterHandler("/openai/deployments/gpt-35-turbo/chat/completions",
-		func(w http.ResponseWriter, r *http.Request) {
+		func(w http.ResponseWriter, _ *http.Request) {
 			w.Header().Set("Content-Type", "application/json")
 			w.WriteHeader(http.StatusTooManyRequests)
 			// Send test responses
@@ -355,13 +355,13 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
 		checks.NoError(t, err, "Write error")
 	})
 
-	apiErr := &APIError{}
-	_, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
+	apiErr := &openai.APIError{}
+	_, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
 		MaxTokens: 5,
-		Model:     GPT3Dot5Turbo,
-		Messages: []ChatCompletionMessage{
+		Model:     openai.GPT3Dot5Turbo,
+		Messages: []openai.ChatCompletionMessage{
 			{
-				Role:    ChatMessageRoleUser,
+				Role:    openai.ChatMessageRoleUser,
 				Content: "Hello!",
 			},
 		},
@@ -387,7 +387,7 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
 }
 
 // Helper funcs.
-func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
+func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
 	if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
 		return false
 	}
@@ -402,7 +402,7 @@ func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
 	return true
 }
 
-func compareChatStreamResponseChoices(c1, c2 ChatCompletionStreamChoice) bool {
+func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool {
 	if c1.Index != c2.Index {
 		return false
 	}

--- a/chat_test.go
+++ b/chat_test.go
@@ -11,7 +11,7 @@ import (
 	"testing"
 	"time"
 
-	. "github.com/sashabaranov/go-openai"
+	"github.com/sashabaranov/go-openai"
 	"github.com/sashabaranov/go-openai/internal/test/checks"
 	"github.com/sashabaranov/go-openai/jsonschema"
 )
@@ -21,8 +21,7 @@ const (
 	xCustomHeaderValue = "test"
 )
 
-var (
-	rateLimitHeaders = map[string]any{
+var rateLimitHeaders = map[string]any{
 	"x-ratelimit-limit-requests":     60,
 	"x-ratelimit-limit-tokens":       150000,
 	"x-ratelimit-remaining-requests": 59,
@@ -30,40 +29,39 @@ var (
 	"x-ratelimit-reset-requests":     "1s",
 	"x-ratelimit-reset-tokens":       "6m0s",
 }
-)
 
 func TestChatCompletionsWrongModel(t *testing.T) {
-	config := DefaultConfig("whatever")
+	config := openai.DefaultConfig("whatever")
 	config.BaseURL = "http://localhost/v1"
-	client := NewClientWithConfig(config)
+	client := openai.NewClientWithConfig(config)
 	ctx := context.Background()
 
-	req := ChatCompletionRequest{
+	req := openai.ChatCompletionRequest{
 		MaxTokens: 5,
 		Model:     "ada",
-		Messages: []ChatCompletionMessage{
+		Messages: []openai.ChatCompletionMessage{
 			{
-				Role:    ChatMessageRoleUser,
+				Role:    openai.ChatMessageRoleUser,
 				Content: "Hello!",
 			},
 		},
 	}
 	_, err := client.CreateChatCompletion(ctx, req)
 	msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err)
-	checks.ErrorIs(t, err, ErrChatCompletionInvalidModel, msg)
+	checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg)
 }
 
 func TestChatCompletionsWithStream(t *testing.T) {
-	config := DefaultConfig("whatever")
+	config := openai.DefaultConfig("whatever")
 	config.BaseURL = "http://localhost/v1"
-	client := NewClientWithConfig(config)
+	client := openai.NewClientWithConfig(config)
 	ctx := context.Background()
 
-	req := ChatCompletionRequest{
+	req := openai.ChatCompletionRequest{
 		Stream: true,
 	}
 	_, err := client.CreateChatCompletion(ctx, req)
-	checks.ErrorIs(t, err, ErrChatCompletionStreamNotSupported, "unexpected error")
+	checks.ErrorIs(t, err, openai.ErrChatCompletionStreamNotSupported, "unexpected error")
 }
 
 // TestCompletions Tests the completions endpoint of the API using the mocked server.
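Alongside the renames, the two hunks above also flatten a `var ( ... )` block that wrapped a single declaration into a standalone `var`, deleting the now-unbalanced closing parenthesis. A minimal before/after sketch of that Go idiom:

package main

// Before: a parenthesised block around one declaration.
//
//	var (
//		rateLimitHeaders = map[string]any{"x-ratelimit-limit-requests": 60}
//	)

// After: the single declaration stands on its own.
var rateLimitHeaders = map[string]any{"x-ratelimit-limit-requests": 60}

func main() { _ = rateLimitHeaders }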

@@ -71,12 +69,12 @@ func TestChatCompletions(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
 	defer teardown()
 	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
-	_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
+	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
 		MaxTokens: 5,
-		Model:     GPT3Dot5Turbo,
-		Messages: []ChatCompletionMessage{
+		Model:     openai.GPT3Dot5Turbo,
+		Messages: []openai.ChatCompletionMessage{
 			{
-				Role:    ChatMessageRoleUser,
+				Role:    openai.ChatMessageRoleUser,
 				Content: "Hello!",
 			},
 		},
@@ -89,12 +87,12 @@ func TestChatCompletionsWithHeaders(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
 	defer teardown()
 	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
-	resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
+	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
 		MaxTokens: 5,
-		Model:     GPT3Dot5Turbo,
-		Messages: []ChatCompletionMessage{
+		Model:     openai.GPT3Dot5Turbo,
+		Messages: []openai.ChatCompletionMessage{
 			{
-				Role:    ChatMessageRoleUser,
+				Role:    openai.ChatMessageRoleUser,
 				Content: "Hello!",
 			},
 		},
@@ -113,12 +111,12 @@ func TestChatCompletionsWithRateLimitHeaders(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
 	defer teardown()
 	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
-	resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
+	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
 		MaxTokens: 5,
-		Model:     GPT3Dot5Turbo,
-		Messages: []ChatCompletionMessage{
+		Model:     openai.GPT3Dot5Turbo,
+		Messages: []openai.ChatCompletionMessage{
 			{
-				Role:    ChatMessageRoleUser,
+				Role:    openai.ChatMessageRoleUser,
 				Content: "Hello!",
 			},
 		},
@@ -150,16 +148,16 @@ func TestChatCompletionsFunctions(t *testing.T) {
 	t.Run("bytes", func(t *testing.T) {
 		//nolint:lll
 		msg := json.RawMessage(`{"properties":{"count":{"type":"integer","description":"total number of words in sentence"},"words":{"items":{"type":"string"},"type":"array","description":"list of words in sentence"}},"type":"object","required":["count","words"]}`)
-		_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
+		_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
 			MaxTokens: 5,
-			Model:     GPT3Dot5Turbo0613,
-			Messages: []ChatCompletionMessage{
+			Model:     openai.GPT3Dot5Turbo0613,
+			Messages: []openai.ChatCompletionMessage{
 				{
-					Role:    ChatMessageRoleUser,
+					Role:    openai.ChatMessageRoleUser,
 					Content: "Hello!",
 				},
 			},
-			Functions: []FunctionDefinition{{
+			Functions: []openai.FunctionDefinition{{
 				Name:       "test",
 				Parameters: &msg,
 			}},
@@ -175,16 +173,16 @@ func TestChatCompletionsFunctions(t *testing.T) {
 			Count: 2,
 			Words: []string{"hello", "world"},
 		}
-		_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
+		_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
 			MaxTokens: 5,
-			Model:     GPT3Dot5Turbo0613,
-			Messages: []ChatCompletionMessage{
+			Model:     openai.GPT3Dot5Turbo0613,
+			Messages: []openai.ChatCompletionMessage{
 				{
-					Role:    ChatMessageRoleUser,
+					Role:    openai.ChatMessageRoleUser,
 					Content: "Hello!",
 				},
 			},
-			Functions: []FunctionDefinition{{
+			Functions: []openai.FunctionDefinition{{
 				Name:       "test",
 				Parameters: &msg,
 			}},
@@ -192,16 +190,16 @@ func TestChatCompletionsFunctions(t *testing.T) {
 		checks.NoError(t, err, "CreateChatCompletion with functions error")
 	})
 	t.Run("JSONSchemaDefinition", func(t *testing.T) {
-		_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
+		_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
 			MaxTokens: 5,
-			Model:     GPT3Dot5Turbo0613,
-			Messages: []ChatCompletionMessage{
+			Model:     openai.GPT3Dot5Turbo0613,
+			Messages: []openai.ChatCompletionMessage{
 				{
-					Role:    ChatMessageRoleUser,
+					Role:    openai.ChatMessageRoleUser,
 					Content: "Hello!",
 				},
 			},
-			Functions: []FunctionDefinition{{
+			Functions: []openai.FunctionDefinition{{
 				Name: "test",
 				Parameters: &jsonschema.Definition{
 					Type: jsonschema.Object,
@@ -229,16 +227,16 @@ func TestChatCompletionsFunctions(t *testing.T) {
 	})
 	t.Run("JSONSchemaDefinitionWithFunctionDefine", func(t *testing.T) {
 		// this is a compatibility check
-		_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
+		_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
 			MaxTokens: 5,
-			Model:     GPT3Dot5Turbo0613,
-			Messages: []ChatCompletionMessage{
+			Model:     openai.GPT3Dot5Turbo0613,
+			Messages: []openai.ChatCompletionMessage{
 				{
-					Role:    ChatMessageRoleUser,
+					Role:    openai.ChatMessageRoleUser,
 					Content: "Hello!",
 				},
 			},
-			Functions: []FunctionDefine{{
+			Functions: []openai.FunctionDefine{{
 				Name: "test",
 				Parameters: &jsonschema.Definition{
 					Type: jsonschema.Object,
@@ -271,12 +269,12 @@ func TestAzureChatCompletions(t *testing.T) {
 	defer teardown()
 	server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint)
 
-	_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
+	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
 		MaxTokens: 5,
-		Model:     GPT3Dot5Turbo,
-		Messages: []ChatCompletionMessage{
+		Model:     openai.GPT3Dot5Turbo,
+		Messages: []openai.ChatCompletionMessage{
 			{
-				Role:    ChatMessageRoleUser,
+				Role:    openai.ChatMessageRoleUser,
 				Content: "Hello!",
 			},
 		},
@@ -293,12 +291,12 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 	if r.Method != "POST" {
 		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
 	}
-	var completionReq ChatCompletionRequest
+	var completionReq openai.ChatCompletionRequest
 	if completionReq, err = getChatCompletionBody(r); err != nil {
 		http.Error(w, "could not read request", http.StatusInternalServerError)
 		return
 	}
-	res := ChatCompletionResponse{
+	res := openai.ChatCompletionResponse{
 		ID:      strconv.Itoa(int(time.Now().Unix())),
 		Object:  "test-object",
 		Created: time.Now().Unix(),
@@ -323,11 +321,11 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 
-		res.Choices = append(res.Choices, ChatCompletionChoice{
-			Message: ChatCompletionMessage{
-				Role: ChatMessageRoleFunction,
+		res.Choices = append(res.Choices, openai.ChatCompletionChoice{
+			Message: openai.ChatCompletionMessage{
+				Role: openai.ChatMessageRoleFunction,
 				// this is valid json so it should be fine
-				FunctionCall: &FunctionCall{
+				FunctionCall: &openai.FunctionCall{
 					Name:      completionReq.Functions[0].Name,
 					Arguments: string(fcb),
 				},
@@ -339,9 +337,9 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 		// generate a random string of length completionReq.Length
 		completionStr := strings.Repeat("a", completionReq.MaxTokens)
 
-		res.Choices = append(res.Choices, ChatCompletionChoice{
-			Message: ChatCompletionMessage{
-				Role:    ChatMessageRoleAssistant,
+		res.Choices = append(res.Choices, openai.ChatCompletionChoice{
+			Message: openai.ChatCompletionMessage{
+				Role:    openai.ChatMessageRoleAssistant,
 				Content: completionStr,
 			},
 			Index: i,
@@ -349,7 +347,7 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 	}
 	inputTokens := numTokens(completionReq.Messages[0].Content) * n
 	completionTokens := completionReq.MaxTokens * n
-	res.Usage = Usage{
+	res.Usage = openai.Usage{
 		PromptTokens:     inputTokens,
 		CompletionTokens: completionTokens,
 		TotalTokens:      inputTokens + completionTokens,
@@ -368,23 +366,23 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 }
 
 // getChatCompletionBody Returns the body of the request to create a completion.
-func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) {
-	completion := ChatCompletionRequest{}
+func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) {
+	completion := openai.ChatCompletionRequest{}
 	// read the request body
 	reqBody, err := io.ReadAll(r.Body)
 	if err != nil {
-		return ChatCompletionRequest{}, err
+		return openai.ChatCompletionRequest{}, err
 	}
 	err = json.Unmarshal(reqBody, &completion)
 	if err != nil {
-		return ChatCompletionRequest{}, err
+		return openai.ChatCompletionRequest{}, err
 	}
 	return completion, nil
 }
 
 func TestFinishReason(t *testing.T) {
-	c := &ChatCompletionChoice{
-		FinishReason: FinishReasonNull,
+	c := &openai.ChatCompletionChoice{
+		FinishReason: openai.FinishReasonNull,
 	}
 	resBytes, _ := json.Marshal(c)
 	if !strings.Contains(string(resBytes), `"finish_reason":null`) {
@@ -398,11 +396,11 @@ func TestFinishReason(t *testing.T) {
 		t.Error("null should not be quoted")
 	}
 
-	otherReasons := []FinishReason{
-		FinishReasonStop,
-		FinishReasonLength,
-		FinishReasonFunctionCall,
-		FinishReasonContentFilter,
+	otherReasons := []openai.FinishReason{
+		openai.FinishReasonStop,
+		openai.FinishReasonLength,
+		openai.FinishReasonFunctionCall,
+		openai.FinishReasonContentFilter,
 	}
 	for _, r := range otherReasons {
 		c.FinishReason = r

@@ -1,9 +1,6 @@
 package openai_test
 
 import (
-	. "github.com/sashabaranov/go-openai"
-	"github.com/sashabaranov/go-openai/internal/test/checks"
-
 	"context"
 	"encoding/json"
 	"errors"
@@ -14,33 +11,36 @@ import (
 	"strings"
 	"testing"
 	"time"
+
+	"github.com/sashabaranov/go-openai"
+	"github.com/sashabaranov/go-openai/internal/test/checks"
 )
 
 func TestCompletionsWrongModel(t *testing.T) {
-	config := DefaultConfig("whatever")
+	config := openai.DefaultConfig("whatever")
 	config.BaseURL = "http://localhost/v1"
-	client := NewClientWithConfig(config)
+	client := openai.NewClientWithConfig(config)
 
 	_, err := client.CreateCompletion(
 		context.Background(),
-		CompletionRequest{
+		openai.CompletionRequest{
 			MaxTokens: 5,
-			Model:     GPT3Dot5Turbo,
+			Model:     openai.GPT3Dot5Turbo,
 		},
 	)
-	if !errors.Is(err, ErrCompletionUnsupportedModel) {
+	if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
 		t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err)
 	}
 }
 
 func TestCompletionWithStream(t *testing.T) {
-	config := DefaultConfig("whatever")
-	client := NewClientWithConfig(config)
+	config := openai.DefaultConfig("whatever")
+	client := openai.NewClientWithConfig(config)
 
 	ctx := context.Background()
-	req := CompletionRequest{Stream: true}
+	req := openai.CompletionRequest{Stream: true}
 	_, err := client.CreateCompletion(ctx, req)
-	if !errors.Is(err, ErrCompletionStreamNotSupported) {
+	if !errors.Is(err, openai.ErrCompletionStreamNotSupported) {
 		t.Fatalf("CreateCompletion didn't return ErrCompletionStreamNotSupported")
 	}
 }
@@ -50,7 +50,7 @@ func TestCompletions(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
 	defer teardown()
 	server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
-	req := CompletionRequest{
+	req := openai.CompletionRequest{
 		MaxTokens: 5,
 		Model:     "ada",
 		Prompt:    "Lorem ipsum",
@@ -68,12 +68,12 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 	if r.Method != "POST" {
 		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
 	}
-	var completionReq CompletionRequest
+	var completionReq openai.CompletionRequest
 	if completionReq, err = getCompletionBody(r); err != nil {
 		http.Error(w, "could not read request", http.StatusInternalServerError)
 		return
 	}
-	res := CompletionResponse{
+	res := openai.CompletionResponse{
 		ID:      strconv.Itoa(int(time.Now().Unix())),
 		Object:  "test-object",
 		Created: time.Now().Unix(),
@@ -93,14 +93,14 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 		if completionReq.Echo {
 			completionStr = completionReq.Prompt.(string) + completionStr
 		}
-		res.Choices = append(res.Choices, CompletionChoice{
+		res.Choices = append(res.Choices, openai.CompletionChoice{
 			Text:  completionStr,
 			Index: i,
 		})
 	}
 	inputTokens := numTokens(completionReq.Prompt.(string)) * n
 	completionTokens := completionReq.MaxTokens * n
-	res.Usage = Usage{
+	res.Usage = openai.Usage{
 		PromptTokens:     inputTokens,
 		CompletionTokens: completionTokens,
 		TotalTokens:      inputTokens + completionTokens,
@@ -110,16 +110,16 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 }
 
 // getCompletionBody Returns the body of the request to create a completion.
-func getCompletionBody(r *http.Request) (CompletionRequest, error) {
-	completion := CompletionRequest{}
+func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) {
+	completion := openai.CompletionRequest{}
 	// read the request body
 	reqBody, err := io.ReadAll(r.Body)
 	if err != nil {
-		return CompletionRequest{}, err
+		return openai.CompletionRequest{}, err
 	}
 	err = json.Unmarshal(reqBody, &completion)
 	if err != nil {
-		return CompletionRequest{}, err
+		return openai.CompletionRequest{}, err
 	}
 	return completion, nil
 }

@@ -3,7 +3,7 @@ package openai_test
 import (
 	"testing"
 
-	. "github.com/sashabaranov/go-openai"
+	"github.com/sashabaranov/go-openai"
 )
 
 func TestGetAzureDeploymentByModel(t *testing.T) {
@@ -49,7 +49,7 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
 
 	for _, c := range cases {
 		t.Run(c.Model, func(t *testing.T) {
-			conf := DefaultAzureConfig("", "https://test.openai.azure.com/")
+			conf := openai.DefaultAzureConfig("", "https://test.openai.azure.com/")
 			if c.AzureModelMapperFunc != nil {
 				conf.AzureModelMapperFunc = c.AzureModelMapperFunc
 			}

@@ -1,9 +1,6 @@
 package openai_test
 
 import (
-	. "github.com/sashabaranov/go-openai"
-	"github.com/sashabaranov/go-openai/internal/test/checks"
-
 	"context"
 	"encoding/json"
 	"fmt"
@@ -11,6 +8,9 @@ import (
 	"net/http"
 	"testing"
 	"time"
+
+	"github.com/sashabaranov/go-openai"
+	"github.com/sashabaranov/go-openai/internal/test/checks"
 )
 
 // TestEdits Tests the edits endpoint of the API using the mocked server.
@@ -20,7 +20,7 @@ func TestEdits(t *testing.T) {
 	server.RegisterHandler("/v1/edits", handleEditEndpoint)
 	// create an edit request
 	model := "ada"
-	editReq := EditsRequest{
+	editReq := openai.EditsRequest{
 		Model: &model,
 		Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " +
 			"sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" +
@@ -45,14 +45,14 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) {
 	if r.Method != "POST" {
 		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
 	}
-	var editReq EditsRequest
+	var editReq openai.EditsRequest
 	editReq, err = getEditBody(r)
 	if err != nil {
 		http.Error(w, "could not read request", http.StatusInternalServerError)
 		return
 	}
 	// create a response
-	res := EditsResponse{
+	res := openai.EditsResponse{
 		Object:  "test-object",
 		Created: time.Now().Unix(),
 	}
@@ -62,12 +62,12 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) {
 	completionTokens := int(float32(len(editString))/4) * editReq.N
 	for i := 0; i < editReq.N; i++ {
 		// instruction will be hidden and only seen by OpenAI
-		res.Choices = append(res.Choices, EditsChoice{
+		res.Choices = append(res.Choices, openai.EditsChoice{
 			Text:  editReq.Input + editString,
 			Index: i,
 		})
 	}
-	res.Usage = Usage{
+	res.Usage = openai.Usage{
 		PromptTokens:     inputTokens,
 		CompletionTokens: completionTokens,
 		TotalTokens:      inputTokens + completionTokens,
@@ -77,16 +77,16 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) {
 }
 
 // getEditBody Returns the body of the request to create an edit.
-func getEditBody(r *http.Request) (EditsRequest, error) {
-	edit := EditsRequest{}
+func getEditBody(r *http.Request) (openai.EditsRequest, error) {
+	edit := openai.EditsRequest{}
 	// read the request body
 	reqBody, err := io.ReadAll(r.Body)
 	if err != nil {
-		return EditsRequest{}, err
+		return openai.EditsRequest{}, err
 	}
 	err = json.Unmarshal(reqBody, &edit)
 	if err != nil {
-		return EditsRequest{}, err
+		return openai.EditsRequest{}, err
 	}
 	return edit, nil
 }
@@ -11,32 +11,32 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
. "github.com/sashabaranov/go-openai"
|
"github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestEmbedding(t *testing.T) {
|
func TestEmbedding(t *testing.T) {
|
||||||
embeddedModels := []EmbeddingModel{
|
embeddedModels := []openai.EmbeddingModel{
|
||||||
AdaSimilarity,
|
openai.AdaSimilarity,
|
||||||
BabbageSimilarity,
|
openai.BabbageSimilarity,
|
||||||
CurieSimilarity,
|
openai.CurieSimilarity,
|
||||||
DavinciSimilarity,
|
openai.DavinciSimilarity,
|
||||||
AdaSearchDocument,
|
openai.AdaSearchDocument,
|
||||||
AdaSearchQuery,
|
openai.AdaSearchQuery,
|
||||||
BabbageSearchDocument,
|
openai.BabbageSearchDocument,
|
||||||
BabbageSearchQuery,
|
openai.BabbageSearchQuery,
|
||||||
CurieSearchDocument,
|
openai.CurieSearchDocument,
|
||||||
CurieSearchQuery,
|
openai.CurieSearchQuery,
|
||||||
DavinciSearchDocument,
|
openai.DavinciSearchDocument,
|
||||||
DavinciSearchQuery,
|
openai.DavinciSearchQuery,
|
||||||
AdaCodeSearchCode,
|
openai.AdaCodeSearchCode,
|
||||||
AdaCodeSearchText,
|
openai.AdaCodeSearchText,
|
||||||
BabbageCodeSearchCode,
|
openai.BabbageCodeSearchCode,
|
||||||
BabbageCodeSearchText,
|
openai.BabbageCodeSearchText,
|
||||||
}
|
}
|
||||||
for _, model := range embeddedModels {
|
for _, model := range embeddedModels {
|
||||||
// test embedding request with strings (simple embedding request)
|
// test embedding request with strings (simple embedding request)
|
||||||
embeddingReq := EmbeddingRequest{
|
embeddingReq := openai.EmbeddingRequest{
|
||||||
Input: []string{
|
Input: []string{
|
||||||
"The food was delicious and the waiter",
|
"The food was delicious and the waiter",
|
||||||
"Other examples of embedding request",
|
"Other examples of embedding request",
|
||||||
@@ -52,7 +52,7 @@ func TestEmbedding(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// test embedding request with strings
|
// test embedding request with strings
|
||||||
embeddingReqStrings := EmbeddingRequestStrings{
|
embeddingReqStrings := openai.EmbeddingRequestStrings{
|
||||||
Input: []string{
|
Input: []string{
|
||||||
"The food was delicious and the waiter",
|
"The food was delicious and the waiter",
|
||||||
"Other examples of embedding request",
|
"Other examples of embedding request",
|
||||||
@@ -66,7 +66,7 @@ func TestEmbedding(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// test embedding request with tokens
|
// test embedding request with tokens
|
||||||
embeddingReqTokens := EmbeddingRequestTokens{
|
embeddingReqTokens := openai.EmbeddingRequestTokens{
|
||||||
Input: [][]int{
|
Input: [][]int{
|
||||||
{464, 2057, 373, 12625, 290, 262, 46612},
|
{464, 2057, 373, 12625, 290, 262, 46612},
|
||||||
{6395, 6096, 286, 11525, 12083, 2581},
|
{6395, 6096, 286, 11525, 12083, 2581},
|
||||||
@@ -82,17 +82,17 @@ func TestEmbedding(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEmbeddingModel(t *testing.T) {
|
func TestEmbeddingModel(t *testing.T) {
|
||||||
var em EmbeddingModel
|
var em openai.EmbeddingModel
|
||||||
err := em.UnmarshalText([]byte("text-similarity-ada-001"))
|
err := em.UnmarshalText([]byte("text-similarity-ada-001"))
|
||||||
checks.NoError(t, err, "Could not marshal embedding model")
|
checks.NoError(t, err, "Could not marshal embedding model")
|
||||||
|
|
||||||
if em != AdaSimilarity {
|
if em != openai.AdaSimilarity {
|
||||||
t.Errorf("Model is not equal to AdaSimilarity")
|
t.Errorf("Model is not equal to AdaSimilarity")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = em.UnmarshalText([]byte("some-non-existent-model"))
|
err = em.UnmarshalText([]byte("some-non-existent-model"))
|
||||||
checks.NoError(t, err, "Could not marshal embedding model")
|
checks.NoError(t, err, "Could not marshal embedding model")
|
||||||
if em != Unknown {
|
if em != openai.Unknown {
|
||||||
t.Errorf("Model is not equal to Unknown")
|
t.Errorf("Model is not equal to Unknown")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -101,12 +101,12 @@ func TestEmbeddingEndpoint(t *testing.T) {
|
|||||||
client, server, teardown := setupOpenAITestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
defer teardown()
|
defer teardown()
|
||||||
|
|
||||||
sampleEmbeddings := []Embedding{
|
sampleEmbeddings := []openai.Embedding{
|
||||||
{Embedding: []float32{1.23, 4.56, 7.89}},
|
{Embedding: []float32{1.23, 4.56, 7.89}},
|
||||||
{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
|
{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
|
||||||
}
|
}
|
||||||
|
|
||||||
sampleBase64Embeddings := []Base64Embedding{
|
sampleBase64Embeddings := []openai.Base64Embedding{
|
||||||
{Embedding: "pHCdP4XrkUDhevxA"},
|
{Embedding: "pHCdP4XrkUDhevxA"},
|
||||||
{Embedding: "/1jku0G/rLvA/EI8"},
|
{Embedding: "/1jku0G/rLvA/EI8"},
|
||||||
}
|
}
|
||||||
@@ -115,7 +115,7 @@ func TestEmbeddingEndpoint(t *testing.T) {
|
|||||||
"/v1/embeddings",
|
"/v1/embeddings",
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
var req struct {
|
var req struct {
|
||||||
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format"`
|
EncodingFormat openai.EmbeddingEncodingFormat `json:"encoding_format"`
|
||||||
User string `json:"user"`
|
User string `json:"user"`
|
||||||
}
|
}
|
||||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
@@ -125,16 +125,16 @@ func TestEmbeddingEndpoint(t *testing.T) {
|
|||||||
case req.User == "invalid":
|
case req.User == "invalid":
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
case req.EncodingFormat == EmbeddingEncodingFormatBase64:
|
case req.EncodingFormat == openai.EmbeddingEncodingFormatBase64:
|
||||||
resBytes, _ = json.Marshal(EmbeddingResponseBase64{Data: sampleBase64Embeddings})
|
resBytes, _ = json.Marshal(openai.EmbeddingResponseBase64{Data: sampleBase64Embeddings})
|
||||||
default:
|
default:
|
||||||
resBytes, _ = json.Marshal(EmbeddingResponse{Data: sampleEmbeddings})
|
resBytes, _ = json.Marshal(openai.EmbeddingResponse{Data: sampleEmbeddings})
|
||||||
}
|
}
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
// test create embeddings with the generic request type (simple embedding request)
|
// test create embeddings with the generic request type (simple embedding request)
|
||||||
res, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{})
|
res, err := client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{})
|
||||||
checks.NoError(t, err, "CreateEmbeddings error")
|
checks.NoError(t, err, "CreateEmbeddings error")
|
||||||
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
|
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
|
||||||
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
|
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
|
||||||
@@ -143,8 +143,8 @@ func TestEmbeddingEndpoint(t *testing.T) {
|
|||||||
// test create embeddings with base64 encoding format
|
// test create embeddings with base64 encoding format
|
||||||
res, err = client.CreateEmbeddings(
|
res, err = client.CreateEmbeddings(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
EmbeddingRequest{
|
openai.EmbeddingRequest{
|
||||||
EncodingFormat: EmbeddingEncodingFormatBase64,
|
EncodingFormat: openai.EmbeddingEncodingFormatBase64,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
checks.NoError(t, err, "CreateEmbeddings error")
|
checks.NoError(t, err, "CreateEmbeddings error")
|
||||||
@@ -153,23 +153,23 @@ func TestEmbeddingEndpoint(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// test create embeddings with strings
|
// test create embeddings with strings
|
||||||
res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{})
|
res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestStrings{})
|
||||||
checks.NoError(t, err, "CreateEmbeddings strings error")
|
checks.NoError(t, err, "CreateEmbeddings strings error")
|
||||||
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
|
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
|
||||||
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
|
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// test create embeddings with tokens
|
// test create embeddings with tokens
|
||||||
res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{})
|
res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestTokens{})
|
||||||
checks.NoError(t, err, "CreateEmbeddings tokens error")
|
checks.NoError(t, err, "CreateEmbeddings tokens error")
|
||||||
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
|
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
|
||||||
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
|
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// test failed sendRequest
|
// test failed sendRequest
|
||||||
_, err = client.CreateEmbeddings(context.Background(), EmbeddingRequest{
|
_, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
|
||||||
User: "invalid",
|
User: "invalid",
|
||||||
EncodingFormat: EmbeddingEncodingFormatBase64,
|
EncodingFormat: openai.EmbeddingEncodingFormatBase64,
|
||||||
})
|
})
|
||||||
checks.HasError(t, err, "CreateEmbeddings error")
|
checks.HasError(t, err, "CreateEmbeddings error")
|
||||||
}
|
}
|
||||||
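
A minimal sketch of the three request shapes the endpoint test above drives
through a single handler: CreateEmbeddings accepts all of them, and only the
typing of Input differs. The model constant openai.AdaEmbeddingV2 and the
token IDs are assumptions for illustration; the diff itself only shows empty
request literals.

package example

import (
	"context"

	"github.com/sashabaranov/go-openai"
)

func createAllRequestShapes(ctx context.Context, client *openai.Client) error {
	// EmbeddingRequest: Input is loosely typed.
	if _, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
		Input: "hello world",
		Model: openai.AdaEmbeddingV2, // assumed model constant
	}); err != nil {
		return err
	}
	// EmbeddingRequestStrings: a strongly typed batch of strings.
	if _, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequestStrings{
		Input: []string{"hello", "world"},
		Model: openai.AdaEmbeddingV2,
	}); err != nil {
		return err
	}
	// EmbeddingRequestTokens: pre-tokenized input, one token slice per item.
	_, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequestTokens{
		Input: [][]int{{15339, 1917}}, // arbitrary example token IDs
		Model: openai.AdaEmbeddingV2,
	})
	return err
}
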
@@ -177,26 +177,26 @@ func TestEmbeddingEndpoint(t *testing.T) {
|
|||||||
func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
|
func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Object string
|
Object string
|
||||||
Data []Base64Embedding
|
Data []openai.Base64Embedding
|
||||||
Model EmbeddingModel
|
Model openai.EmbeddingModel
|
||||||
Usage Usage
|
Usage openai.Usage
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
fields fields
|
fields fields
|
||||||
want EmbeddingResponse
|
want openai.EmbeddingResponse
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "test embedding response base64 to embedding response",
|
name: "test embedding response base64 to embedding response",
|
||||||
fields: fields{
|
fields: fields{
|
||||||
Data: []Base64Embedding{
|
Data: []openai.Base64Embedding{
|
||||||
{Embedding: "pHCdP4XrkUDhevxA"},
|
{Embedding: "pHCdP4XrkUDhevxA"},
|
||||||
{Embedding: "/1jku0G/rLvA/EI8"},
|
{Embedding: "/1jku0G/rLvA/EI8"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: EmbeddingResponse{
|
want: openai.EmbeddingResponse{
|
||||||
Data: []Embedding{
|
Data: []openai.Embedding{
|
||||||
{Embedding: []float32{1.23, 4.56, 7.89}},
|
{Embedding: []float32{1.23, 4.56, 7.89}},
|
||||||
{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
|
{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
|
||||||
},
|
},
|
||||||
@@ -206,19 +206,19 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Invalid embedding",
|
name: "Invalid embedding",
|
||||||
fields: fields{
|
fields: fields{
|
||||||
Data: []Base64Embedding{
|
Data: []openai.Base64Embedding{
|
||||||
{
|
{
|
||||||
Embedding: "----",
|
Embedding: "----",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: EmbeddingResponse{},
|
want: openai.EmbeddingResponse{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
r := &EmbeddingResponseBase64{
|
r := &openai.EmbeddingResponseBase64{
|
||||||
Object: tt.fields.Object,
|
Object: tt.fields.Object,
|
||||||
Data: tt.fields.Data,
|
Data: tt.fields.Data,
|
||||||
Model: tt.fields.Model,
|
Model: tt.fields.Model,
|
||||||
@@ -237,8 +237,8 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDotProduct(t *testing.T) {
|
func TestDotProduct(t *testing.T) {
|
||||||
v1 := &Embedding{Embedding: []float32{1, 2, 3}}
|
v1 := &openai.Embedding{Embedding: []float32{1, 2, 3}}
|
||||||
v2 := &Embedding{Embedding: []float32{2, 4, 6}}
|
v2 := &openai.Embedding{Embedding: []float32{2, 4, 6}}
|
||||||
expected := float32(28.0)
|
expected := float32(28.0)
|
||||||
|
|
||||||
result, err := v1.DotProduct(v2)
|
result, err := v1.DotProduct(v2)
|
||||||
@@ -250,8 +250,8 @@ func TestDotProduct(t *testing.T) {
|
|||||||
t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result)
|
t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
v1 = &Embedding{Embedding: []float32{1, 0, 0}}
|
v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}}
|
||||||
v2 = &Embedding{Embedding: []float32{0, 1, 0}}
|
v2 = &openai.Embedding{Embedding: []float32{0, 1, 0}}
|
||||||
expected = float32(0.0)
|
expected = float32(0.0)
|
||||||
|
|
||||||
result, err = v1.DotProduct(v2)
|
result, err = v1.DotProduct(v2)
|
||||||
@@ -264,10 +264,10 @@ func TestDotProduct(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Test for VectorLengthMismatchError
|
// Test for VectorLengthMismatchError
|
||||||
v1 = &Embedding{Embedding: []float32{1, 0, 0}}
|
v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}}
|
||||||
v2 = &Embedding{Embedding: []float32{0, 1}}
|
v2 = &openai.Embedding{Embedding: []float32{0, 1}}
|
||||||
_, err = v1.DotProduct(v2)
|
_, err = v1.DotProduct(v2)
|
||||||
if !errors.Is(err, ErrVectorLengthMismatch) {
|
if !errors.Is(err, openai.ErrVectorLengthMismatch) {
|
||||||
t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err)
|
t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
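
The fixtures in these embedding tests pair base64 payloads with float32
slices; the payloads are runs of little-endian IEEE-754 float32 values. A
standalone sketch of that decoding, using only the standard library (the
helper name is invented for illustration):

package example

import (
	"encoding/base64"
	"encoding/binary"
	"math"
)

// decodeBase64Embedding base64-decodes an embedding payload and reads the
// bytes as consecutive little-endian IEEE-754 float32 values.
func decodeBase64Embedding(s string) ([]float32, error) {
	raw, err := base64.StdEncoding.DecodeString(s)
	if err != nil {
		return nil, err
	}
	out := make([]float32, 0, len(raw)/4)
	for i := 0; i+4 <= len(raw); i += 4 {
		bits := binary.LittleEndian.Uint32(raw[i : i+4])
		out = append(out, math.Float32frombits(bits))
	}
	return out, nil
}

Decoding "pHCdP4XrkUDhevxA" this way yields []float32{1.23, 4.56, 7.89},
matching the first fixture pair above.
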
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
. "github.com/sashabaranov/go-openai"
|
"github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -15,8 +15,8 @@ import (
|
|||||||
func TestGetEngine(t *testing.T) {
|
func TestGetEngine(t *testing.T) {
|
||||||
client, server, teardown := setupOpenAITestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
defer teardown()
|
defer teardown()
|
||||||
server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, r *http.Request) {
|
server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
resBytes, _ := json.Marshal(Engine{})
|
resBytes, _ := json.Marshal(openai.Engine{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
})
|
})
|
||||||
_, err := client.GetEngine(context.Background(), "text-davinci-003")
|
_, err := client.GetEngine(context.Background(), "text-davinci-003")
|
||||||
@@ -27,8 +27,8 @@ func TestGetEngine(t *testing.T) {
|
|||||||
func TestListEngines(t *testing.T) {
|
func TestListEngines(t *testing.T) {
|
||||||
client, server, teardown := setupOpenAITestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
defer teardown()
|
defer teardown()
|
||||||
server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) {
|
server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
resBytes, _ := json.Marshal(EnginesList{})
|
resBytes, _ := json.Marshal(openai.EnginesList{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
})
|
})
|
||||||
_, err := client.ListEngines(context.Background())
|
_, err := client.ListEngines(context.Background())
|
||||||
@@ -38,7 +38,7 @@ func TestListEngines(t *testing.T) {
|
|||||||
func TestListEnginesReturnError(t *testing.T) {
|
func TestListEnginesReturnError(t *testing.T) {
|
||||||
client, server, teardown := setupOpenAITestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
defer teardown()
|
defer teardown()
|
||||||
server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) {
|
server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
w.WriteHeader(http.StatusTeapot)
|
w.WriteHeader(http.StatusTeapot)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
. "github.com/sashabaranov/go-openai"
|
"github.com/sashabaranov/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
||||||
@@ -14,7 +14,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
response string
|
response string
|
||||||
hasError bool
|
hasError bool
|
||||||
checkFunc func(t *testing.T, apiErr APIError)
|
checkFunc func(t *testing.T, apiErr openai.APIError)
|
||||||
}
|
}
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
// testcase for message field
|
// testcase for message field
|
||||||
@@ -22,7 +22,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
|||||||
name: "parse succeeds when the message is string",
|
name: "parse succeeds when the message is string",
|
||||||
response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`,
|
response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`,
|
||||||
hasError: false,
|
hasError: false,
|
||||||
checkFunc: func(t *testing.T, apiErr APIError) {
|
checkFunc: func(t *testing.T, apiErr openai.APIError) {
|
||||||
assertAPIErrorMessage(t, apiErr, "foo")
|
assertAPIErrorMessage(t, apiErr, "foo")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -30,7 +30,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
|||||||
name: "parse succeeds when the message is array with single item",
|
name: "parse succeeds when the message is array with single item",
|
||||||
response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`,
|
response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`,
|
||||||
hasError: false,
|
hasError: false,
|
||||||
checkFunc: func(t *testing.T, apiErr APIError) {
|
checkFunc: func(t *testing.T, apiErr openai.APIError) {
|
||||||
assertAPIErrorMessage(t, apiErr, "foo")
|
assertAPIErrorMessage(t, apiErr, "foo")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -38,7 +38,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
|||||||
name: "parse succeeds when the message is array with multiple items",
|
name: "parse succeeds when the message is array with multiple items",
|
||||||
response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`,
|
response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`,
|
||||||
hasError: false,
|
hasError: false,
|
||||||
checkFunc: func(t *testing.T, apiErr APIError) {
|
checkFunc: func(t *testing.T, apiErr openai.APIError) {
|
||||||
assertAPIErrorMessage(t, apiErr, "foo, bar, baz")
|
assertAPIErrorMessage(t, apiErr, "foo, bar, baz")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -46,7 +46,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
|||||||
name: "parse succeeds when the message is empty array",
|
name: "parse succeeds when the message is empty array",
|
||||||
response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`,
|
response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`,
|
||||||
hasError: false,
|
hasError: false,
|
||||||
checkFunc: func(t *testing.T, apiErr APIError) {
|
checkFunc: func(t *testing.T, apiErr openai.APIError) {
|
||||||
assertAPIErrorMessage(t, apiErr, "")
|
assertAPIErrorMessage(t, apiErr, "")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -54,7 +54,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
|||||||
name: "parse succeeds when the message is null",
|
name: "parse succeeds when the message is null",
|
||||||
response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`,
|
response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`,
|
||||||
hasError: false,
|
hasError: false,
|
||||||
checkFunc: func(t *testing.T, apiErr APIError) {
|
checkFunc: func(t *testing.T, apiErr openai.APIError) {
|
||||||
assertAPIErrorMessage(t, apiErr, "")
|
assertAPIErrorMessage(t, apiErr, "")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -89,23 +89,23 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}`,
|
}`,
|
||||||
hasError: false,
|
hasError: false,
|
||||||
checkFunc: func(t *testing.T, apiErr APIError) {
|
checkFunc: func(t *testing.T, apiErr openai.APIError) {
|
||||||
assertAPIErrorInnerError(t, apiErr, &InnerError{
|
assertAPIErrorInnerError(t, apiErr, &openai.InnerError{
|
||||||
Code: "ResponsibleAIPolicyViolation",
|
Code: "ResponsibleAIPolicyViolation",
|
||||||
ContentFilterResults: ContentFilterResults{
|
ContentFilterResults: openai.ContentFilterResults{
|
||||||
Hate: Hate{
|
Hate: openai.Hate{
|
||||||
Filtered: false,
|
Filtered: false,
|
||||||
Severity: "safe",
|
Severity: "safe",
|
||||||
},
|
},
|
||||||
SelfHarm: SelfHarm{
|
SelfHarm: openai.SelfHarm{
|
||||||
Filtered: false,
|
Filtered: false,
|
||||||
Severity: "safe",
|
Severity: "safe",
|
||||||
},
|
},
|
||||||
Sexual: Sexual{
|
Sexual: openai.Sexual{
|
||||||
Filtered: true,
|
Filtered: true,
|
||||||
Severity: "medium",
|
Severity: "medium",
|
||||||
},
|
},
|
||||||
Violence: Violence{
|
Violence: openai.Violence{
|
||||||
Filtered: false,
|
Filtered: false,
|
||||||
Severity: "safe",
|
Severity: "safe",
|
||||||
},
|
},
|
||||||
@@ -117,16 +117,16 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
|||||||
name: "parse succeeds when the innerError is empty (Azure Openai)",
|
name: "parse succeeds when the innerError is empty (Azure Openai)",
|
||||||
response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": {}}`,
|
response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": {}}`,
|
||||||
hasError: false,
|
hasError: false,
|
||||||
checkFunc: func(t *testing.T, apiErr APIError) {
|
checkFunc: func(t *testing.T, apiErr openai.APIError) {
|
||||||
assertAPIErrorInnerError(t, apiErr, &InnerError{})
|
assertAPIErrorInnerError(t, apiErr, &openai.InnerError{})
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "parse succeeds when the innerError is not InnerError struct (Azure Openai)",
|
name: "parse succeeds when the innerError is not InnerError struct (Azure Openai)",
|
||||||
response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": "test"}`,
|
response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": "test"}`,
|
||||||
hasError: true,
|
hasError: true,
|
||||||
checkFunc: func(t *testing.T, apiErr APIError) {
|
checkFunc: func(t *testing.T, apiErr openai.APIError) {
|
||||||
assertAPIErrorInnerError(t, apiErr, &InnerError{})
|
assertAPIErrorInnerError(t, apiErr, &openai.InnerError{})
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -159,7 +159,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
|||||||
name: "parse succeeds when the code is int",
|
name: "parse succeeds when the code is int",
|
||||||
response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`,
|
response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`,
|
||||||
hasError: false,
|
hasError: false,
|
||||||
checkFunc: func(t *testing.T, apiErr APIError) {
|
checkFunc: func(t *testing.T, apiErr openai.APIError) {
|
||||||
assertAPIErrorCode(t, apiErr, 418)
|
assertAPIErrorCode(t, apiErr, 418)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -167,7 +167,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
|||||||
name: "parse succeeds when the code is string",
|
name: "parse succeeds when the code is string",
|
||||||
response: `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}`,
|
response: `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}`,
|
||||||
hasError: false,
|
hasError: false,
|
||||||
checkFunc: func(t *testing.T, apiErr APIError) {
|
checkFunc: func(t *testing.T, apiErr openai.APIError) {
|
||||||
assertAPIErrorCode(t, apiErr, "teapot")
|
assertAPIErrorCode(t, apiErr, "teapot")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -175,7 +175,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
|||||||
name: "parse succeeds when the code is not exists",
|
name: "parse succeeds when the code is not exists",
|
||||||
response: `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`,
|
response: `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`,
|
||||||
hasError: false,
|
hasError: false,
|
||||||
checkFunc: func(t *testing.T, apiErr APIError) {
|
checkFunc: func(t *testing.T, apiErr openai.APIError) {
|
||||||
assertAPIErrorCode(t, apiErr, nil)
|
assertAPIErrorCode(t, apiErr, nil)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -196,7 +196,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
|||||||
name: "parse failed when the response is invalid json",
|
name: "parse failed when the response is invalid json",
|
||||||
response: `--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`,
|
response: `--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`,
|
||||||
hasError: true,
|
hasError: true,
|
||||||
checkFunc: func(t *testing.T, apiErr APIError) {
|
checkFunc: func(t *testing.T, apiErr openai.APIError) {
|
||||||
assertAPIErrorCode(t, apiErr, nil)
|
assertAPIErrorCode(t, apiErr, nil)
|
||||||
assertAPIErrorMessage(t, apiErr, "")
|
assertAPIErrorMessage(t, apiErr, "")
|
||||||
assertAPIErrorParam(t, apiErr, nil)
|
assertAPIErrorParam(t, apiErr, nil)
|
||||||
@@ -206,7 +206,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
var apiErr APIError
|
var apiErr openai.APIError
|
||||||
err := apiErr.UnmarshalJSON([]byte(tc.response))
|
err := apiErr.UnmarshalJSON([]byte(tc.response))
|
||||||
if (err != nil) != tc.hasError {
|
if (err != nil) != tc.hasError {
|
||||||
t.Errorf("Unexpected error: %v", err)
|
t.Errorf("Unexpected error: %v", err)
|
||||||
@@ -218,19 +218,19 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertAPIErrorMessage(t *testing.T, apiErr APIError, expected string) {
|
func assertAPIErrorMessage(t *testing.T, apiErr openai.APIError, expected string) {
|
||||||
if apiErr.Message != expected {
|
if apiErr.Message != expected {
|
||||||
t.Errorf("Unexpected APIError message: %v; expected: %s", apiErr, expected)
|
t.Errorf("Unexpected APIError message: %v; expected: %s", apiErr, expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertAPIErrorInnerError(t *testing.T, apiErr APIError, expected interface{}) {
|
func assertAPIErrorInnerError(t *testing.T, apiErr openai.APIError, expected interface{}) {
|
||||||
if !reflect.DeepEqual(apiErr.InnerError, expected) {
|
if !reflect.DeepEqual(apiErr.InnerError, expected) {
|
||||||
t.Errorf("Unexpected APIError InnerError: %v; expected: %v; ", apiErr, expected)
|
t.Errorf("Unexpected APIError InnerError: %v; expected: %v; ", apiErr, expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) {
|
func assertAPIErrorCode(t *testing.T, apiErr openai.APIError, expected interface{}) {
|
||||||
switch v := apiErr.Code.(type) {
|
switch v := apiErr.Code.(type) {
|
||||||
case int:
|
case int:
|
||||||
if v != expected {
|
if v != expected {
|
||||||
@@ -246,25 +246,25 @@ func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertAPIErrorParam(t *testing.T, apiErr APIError, expected *string) {
|
func assertAPIErrorParam(t *testing.T, apiErr openai.APIError, expected *string) {
|
||||||
if apiErr.Param != expected {
|
if apiErr.Param != expected {
|
||||||
t.Errorf("Unexpected APIError param: %v; expected: %s", apiErr, *expected)
|
t.Errorf("Unexpected APIError param: %v; expected: %s", apiErr, *expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertAPIErrorType(t *testing.T, apiErr APIError, typ string) {
|
func assertAPIErrorType(t *testing.T, apiErr openai.APIError, typ string) {
|
||||||
if apiErr.Type != typ {
|
if apiErr.Type != typ {
|
||||||
t.Errorf("Unexpected API type: %v; expected: %s", apiErr, typ)
|
t.Errorf("Unexpected API type: %v; expected: %s", apiErr, typ)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRequestError(t *testing.T) {
|
func TestRequestError(t *testing.T) {
|
||||||
var err error = &RequestError{
|
var err error = &openai.RequestError{
|
||||||
HTTPStatusCode: http.StatusTeapot,
|
HTTPStatusCode: http.StatusTeapot,
|
||||||
Err: errors.New("i am a teapot"),
|
Err: errors.New("i am a teapot"),
|
||||||
}
|
}
|
||||||
|
|
||||||
var reqErr *RequestError
|
var reqErr *openai.RequestError
|
||||||
if !errors.As(err, &reqErr) {
|
if !errors.As(err, &reqErr) {
|
||||||
t.Fatalf("Error is not a RequestError: %+v", err)
|
t.Fatalf("Error is not a RequestError: %+v", err)
|
||||||
}
|
}
|
||||||
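
The assertions above pin down APIError's loosely typed Code field, which may
arrive from the server as an int or a string. A minimal caller-side sketch of
the same unwrapping (the helper name is invented for illustration):

package example

import (
	"errors"
	"fmt"

	"github.com/sashabaranov/go-openai"
)

// describeAPIError renders a non-nil client error, unwrapping
// *openai.APIError when one is present in the chain.
func describeAPIError(err error) string {
	var apiErr *openai.APIError
	if !errors.As(err, &apiErr) {
		return err.Error()
	}
	switch code := apiErr.Code.(type) {
	case int:
		return fmt.Sprintf("api error %d: %s", code, apiErr.Message)
	case string:
		return fmt.Sprintf("api error %q: %s", code, apiErr.Message)
	default: // code absent (null) or an unexpected type
		return apiErr.Message
	}
}
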
|
|||||||
@@ -28,7 +28,6 @@ func Example() {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("ChatCompletion error: %v\n", err)
|
fmt.Printf("ChatCompletion error: %v\n", err)
|
||||||
return
|
return
|
||||||
@@ -319,7 +318,6 @@ func ExampleDefaultAzureConfig() {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("ChatCompletion error: %v\n", err)
|
fmt.Printf("ChatCompletion error: %v\n", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
. "github.com/sashabaranov/go-openai"
|
"github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,7 +20,7 @@ func TestFileUpload(t *testing.T) {
|
|||||||
client, server, teardown := setupOpenAITestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
defer teardown()
|
defer teardown()
|
||||||
server.RegisterHandler("/v1/files", handleCreateFile)
|
server.RegisterHandler("/v1/files", handleCreateFile)
|
||||||
req := FileRequest{
|
req := openai.FileRequest{
|
||||||
FileName: "test.go",
|
FileName: "test.go",
|
||||||
FilePath: "client.go",
|
FilePath: "client.go",
|
||||||
Purpose: "fine-tune",
|
Purpose: "fine-tune",
|
||||||
@@ -57,7 +57,7 @@ func handleCreateFile(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
|
||||||
var fileReq = File{
|
fileReq := openai.File{
|
||||||
Bytes: int(header.Size),
|
Bytes: int(header.Size),
|
||||||
ID: strconv.Itoa(int(time.Now().Unix())),
|
ID: strconv.Itoa(int(time.Now().Unix())),
|
||||||
FileName: header.Filename,
|
FileName: header.Filename,
|
||||||
@@ -82,7 +82,7 @@ func TestListFile(t *testing.T) {
|
|||||||
client, server, teardown := setupOpenAITestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
defer teardown()
|
defer teardown()
|
||||||
server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) {
|
server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) {
|
||||||
resBytes, _ := json.Marshal(FilesList{})
|
resBytes, _ := json.Marshal(openai.FilesList{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
})
|
})
|
||||||
_, err := client.ListFiles(context.Background())
|
_, err := client.ListFiles(context.Background())
|
||||||
@@ -93,7 +93,7 @@ func TestGetFile(t *testing.T) {
|
|||||||
client, server, teardown := setupOpenAITestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
defer teardown()
|
defer teardown()
|
||||||
server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {
|
server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {
|
||||||
resBytes, _ := json.Marshal(File{})
|
resBytes, _ := json.Marshal(openai.File{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
})
|
})
|
||||||
_, err := client.GetFile(context.Background(), "deadbeef")
|
_, err := client.GetFile(context.Background(), "deadbeef")
|
||||||
@@ -148,7 +148,7 @@ func TestGetFileContentReturnError(t *testing.T) {
|
|||||||
t.Fatal("Did not return error")
|
t.Fatal("Did not return error")
|
||||||
}
|
}
|
||||||
|
|
||||||
apiErr := &APIError{}
|
apiErr := &openai.APIError{}
|
||||||
if !errors.As(err, &apiErr) {
|
if !errors.As(err, &apiErr) {
|
||||||
t.Fatalf("Did not return APIError: %+v\n", apiErr)
|
t.Fatalf("Did not return APIError: %+v\n", apiErr)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
package openai //nolint:testpackage // testing private field
|
package openai //nolint:testpackage // testing private field
|
||||||
|
|
||||||
import (
|
import (
|
||||||
utils "github.com/sashabaranov/go-openai/internal"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
utils "github.com/sashabaranov/go-openai/internal"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestFileUploadWithFailingFormBuilder(t *testing.T) {
|
func TestFileUploadWithFailingFormBuilder(t *testing.T) {
|
||||||
|
|||||||
@@ -115,6 +115,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r
|
|||||||
// This API will be officially deprecated on January 4th, 2024.
|
// This API will be officially deprecated on January 4th, 2024.
|
||||||
// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go.
|
// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go.
|
||||||
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
|
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
|
||||||
|
//nolint:goconst // Decreases readability
|
||||||
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"))
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
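
The deprecation comment above points at the replacement API in
fine_tuning_job.go. A minimal migration sketch; the request field values are
borrowed from the test fixtures later in this diff, and the exact
FineTuningJobRequest field set is an assumption:

package example

import (
	"context"
	"fmt"

	"github.com/sashabaranov/go-openai"
)

func migrateFineTune(ctx context.Context, client *openai.Client) error {
	// Deprecated /fine-tunes call being replaced:
	//	_, err := client.CreateFineTune(ctx, openai.FineTuneRequest{TrainingFile: "file-abc123"})

	// Replacement: the /fine_tuning/jobs endpoint from fine_tuning_job.go.
	job, err := client.CreateFineTuningJob(ctx, openai.FineTuningJobRequest{
		TrainingFile: "file-abc123",
		Model:        "davinci-002",
	})
	if err != nil {
		return err
	}
	fmt.Println(job.ID, job.Status)
	return nil
}
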
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
package openai_test
|
package openai_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
const testFineTuneID = "fine-tune-id"
|
const testFineTuneID = "fine-tune-id"
|
||||||
@@ -22,9 +22,9 @@ func TestFineTunes(t *testing.T) {
|
|||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
var resBytes []byte
|
var resBytes []byte
|
||||||
if r.Method == http.MethodGet {
|
if r.Method == http.MethodGet {
|
||||||
resBytes, _ = json.Marshal(FineTuneList{})
|
resBytes, _ = json.Marshal(openai.FineTuneList{})
|
||||||
} else {
|
} else {
|
||||||
resBytes, _ = json.Marshal(FineTune{})
|
resBytes, _ = json.Marshal(openai.FineTune{})
|
||||||
}
|
}
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
},
|
},
|
||||||
@@ -32,8 +32,8 @@ func TestFineTunes(t *testing.T) {
|
|||||||
|
|
||||||
server.RegisterHandler(
|
server.RegisterHandler(
|
||||||
"/v1/fine-tunes/"+testFineTuneID+"/cancel",
|
"/v1/fine-tunes/"+testFineTuneID+"/cancel",
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
resBytes, _ := json.Marshal(FineTune{})
|
resBytes, _ := json.Marshal(openai.FineTune{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -43,9 +43,9 @@ func TestFineTunes(t *testing.T) {
|
|||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
var resBytes []byte
|
var resBytes []byte
|
||||||
if r.Method == http.MethodDelete {
|
if r.Method == http.MethodDelete {
|
||||||
resBytes, _ = json.Marshal(FineTuneDeleteResponse{})
|
resBytes, _ = json.Marshal(openai.FineTuneDeleteResponse{})
|
||||||
} else {
|
} else {
|
||||||
resBytes, _ = json.Marshal(FineTune{})
|
resBytes, _ = json.Marshal(openai.FineTune{})
|
||||||
}
|
}
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
},
|
},
|
||||||
@@ -53,8 +53,8 @@ func TestFineTunes(t *testing.T) {
|
|||||||
|
|
||||||
server.RegisterHandler(
|
server.RegisterHandler(
|
||||||
"/v1/fine-tunes/"+testFineTuneID+"/events",
|
"/v1/fine-tunes/"+testFineTuneID+"/events",
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
resBytes, _ := json.Marshal(FineTuneEventList{})
|
resBytes, _ := json.Marshal(openai.FineTuneEventList{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -64,7 +64,7 @@ func TestFineTunes(t *testing.T) {
|
|||||||
_, err := client.ListFineTunes(ctx)
|
_, err := client.ListFineTunes(ctx)
|
||||||
checks.NoError(t, err, "ListFineTunes error")
|
checks.NoError(t, err, "ListFineTunes error")
|
||||||
|
|
||||||
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
|
_, err = client.CreateFineTune(ctx, openai.FineTuneRequest{})
|
||||||
checks.NoError(t, err, "CreateFineTune error")
|
checks.NoError(t, err, "CreateFineTune error")
|
||||||
|
|
||||||
_, err = client.CancelFineTune(ctx, testFineTuneID)
|
_, err = client.CancelFineTune(ctx, testFineTuneID)
|
||||||
|
|||||||
@@ -2,14 +2,13 @@ package openai_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
. "github.com/sashabaranov/go-openai"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
|
||||||
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
const testFineTuninigJobID = "fine-tuning-job-id"
|
const testFineTuninigJobID = "fine-tuning-job-id"
|
||||||
@@ -20,8 +19,8 @@ func TestFineTuningJob(t *testing.T) {
|
|||||||
defer teardown()
|
defer teardown()
|
||||||
server.RegisterHandler(
|
server.RegisterHandler(
|
||||||
"/v1/fine_tuning/jobs",
|
"/v1/fine_tuning/jobs",
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
resBytes, _ := json.Marshal(FineTuningJob{
|
resBytes, _ := json.Marshal(openai.FineTuningJob{
|
||||||
Object: "fine_tuning.job",
|
Object: "fine_tuning.job",
|
||||||
ID: testFineTuninigJobID,
|
ID: testFineTuninigJobID,
|
||||||
Model: "davinci-002",
|
Model: "davinci-002",
|
||||||
@@ -33,7 +32,7 @@ func TestFineTuningJob(t *testing.T) {
|
|||||||
Status: "succeeded",
|
Status: "succeeded",
|
||||||
ValidationFile: "",
|
ValidationFile: "",
|
||||||
TrainingFile: "file-abc123",
|
TrainingFile: "file-abc123",
|
||||||
Hyperparameters: Hyperparameters{
|
Hyperparameters: openai.Hyperparameters{
|
||||||
Epochs: "auto",
|
Epochs: "auto",
|
||||||
},
|
},
|
||||||
TrainedTokens: 5768,
|
TrainedTokens: 5768,
|
||||||
@@ -44,32 +43,32 @@ func TestFineTuningJob(t *testing.T) {
|
|||||||
|
|
||||||
server.RegisterHandler(
|
server.RegisterHandler(
|
||||||
"/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel",
|
"/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel",
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
resBytes, _ := json.Marshal(FineTuningJob{})
|
resBytes, _ := json.Marshal(openai.FineTuningJob{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
server.RegisterHandler(
|
server.RegisterHandler(
|
||||||
"/v1/fine_tuning/jobs/"+testFineTuninigJobID,
|
"/v1/fine_tuning/jobs/"+testFineTuninigJobID,
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
var resBytes []byte
|
var resBytes []byte
|
||||||
resBytes, _ = json.Marshal(FineTuningJob{})
|
resBytes, _ = json.Marshal(openai.FineTuningJob{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
server.RegisterHandler(
|
server.RegisterHandler(
|
||||||
"/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events",
|
"/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events",
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
resBytes, _ := json.Marshal(FineTuningJobEventList{})
|
resBytes, _ := json.Marshal(openai.FineTuningJobEventList{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
_, err := client.CreateFineTuningJob(ctx, FineTuningJobRequest{})
|
_, err := client.CreateFineTuningJob(ctx, openai.FineTuningJobRequest{})
|
||||||
checks.NoError(t, err, "CreateFineTuningJob error")
|
checks.NoError(t, err, "CreateFineTuningJob error")
|
||||||
|
|
||||||
_, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID)
|
_, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID)
|
||||||
@@ -84,22 +83,22 @@ func TestFineTuningJob(t *testing.T) {
|
|||||||
_, err = client.ListFineTuningJobEvents(
|
_, err = client.ListFineTuningJobEvents(
|
||||||
ctx,
|
ctx,
|
||||||
testFineTuninigJobID,
|
testFineTuninigJobID,
|
||||||
ListFineTuningJobEventsWithAfter("last-event-id"),
|
openai.ListFineTuningJobEventsWithAfter("last-event-id"),
|
||||||
)
|
)
|
||||||
checks.NoError(t, err, "ListFineTuningJobEvents error")
|
checks.NoError(t, err, "ListFineTuningJobEvents error")
|
||||||
|
|
||||||
_, err = client.ListFineTuningJobEvents(
|
_, err = client.ListFineTuningJobEvents(
|
||||||
ctx,
|
ctx,
|
||||||
testFineTuninigJobID,
|
testFineTuninigJobID,
|
||||||
ListFineTuningJobEventsWithLimit(10),
|
openai.ListFineTuningJobEventsWithLimit(10),
|
||||||
)
|
)
|
||||||
checks.NoError(t, err, "ListFineTuningJobEvents error")
|
checks.NoError(t, err, "ListFineTuningJobEvents error")
|
||||||
|
|
||||||
_, err = client.ListFineTuningJobEvents(
|
_, err = client.ListFineTuningJobEvents(
|
||||||
ctx,
|
ctx,
|
||||||
testFineTuninigJobID,
|
testFineTuninigJobID,
|
||||||
ListFineTuningJobEventsWithAfter("last-event-id"),
|
openai.ListFineTuningJobEventsWithAfter("last-event-id"),
|
||||||
ListFineTuningJobEventsWithLimit(10),
|
openai.ListFineTuningJobEventsWithLimit(10),
|
||||||
)
|
)
|
||||||
checks.NoError(t, err, "ListFineTuningJobEvents error")
|
checks.NoError(t, err, "ListFineTuningJobEvents error")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
package openai_test
|
package openai_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -12,13 +9,16 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestImages(t *testing.T) {
|
func TestImages(t *testing.T) {
|
||||||
client, server, teardown := setupOpenAITestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
defer teardown()
|
defer teardown()
|
||||||
server.RegisterHandler("/v1/images/generations", handleImageEndpoint)
|
server.RegisterHandler("/v1/images/generations", handleImageEndpoint)
|
||||||
_, err := client.CreateImage(context.Background(), ImageRequest{
|
_, err := client.CreateImage(context.Background(), openai.ImageRequest{
|
||||||
Prompt: "Lorem ipsum",
|
Prompt: "Lorem ipsum",
|
||||||
})
|
})
|
||||||
checks.NoError(t, err, "CreateImage error")
|
checks.NoError(t, err, "CreateImage error")
|
||||||
@@ -33,20 +33,20 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
|
|||||||
if r.Method != "POST" {
|
if r.Method != "POST" {
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
}
|
}
|
||||||
var imageReq ImageRequest
|
var imageReq openai.ImageRequest
|
||||||
if imageReq, err = getImageBody(r); err != nil {
|
if imageReq, err = getImageBody(r); err != nil {
|
||||||
http.Error(w, "could not read request", http.StatusInternalServerError)
|
http.Error(w, "could not read request", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
res := ImageResponse{
|
res := openai.ImageResponse{
|
||||||
Created: time.Now().Unix(),
|
Created: time.Now().Unix(),
|
||||||
}
|
}
|
||||||
for i := 0; i < imageReq.N; i++ {
|
for i := 0; i < imageReq.N; i++ {
|
||||||
imageData := ImageResponseDataInner{}
|
imageData := openai.ImageResponseDataInner{}
|
||||||
switch imageReq.ResponseFormat {
|
switch imageReq.ResponseFormat {
|
||||||
case CreateImageResponseFormatURL, "":
|
case openai.CreateImageResponseFormatURL, "":
|
||||||
imageData.URL = "https://example.com/image.png"
|
imageData.URL = "https://example.com/image.png"
|
||||||
case CreateImageResponseFormatB64JSON:
|
case openai.CreateImageResponseFormatB64JSON:
|
||||||
// This decodes to "{}" in base64.
|
// This decodes to "{}" in base64.
|
||||||
imageData.B64JSON = "e30K"
|
imageData.B64JSON = "e30K"
|
||||||
default:
|
default:
|
||||||
@@ -60,16 +60,16 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getImageBody returns the body of the request to create an image.
|
// getImageBody returns the body of the request to create an image.
|
||||||
func getImageBody(r *http.Request) (ImageRequest, error) {
|
func getImageBody(r *http.Request) (openai.ImageRequest, error) {
|
||||||
image := ImageRequest{}
|
image := openai.ImageRequest{}
|
||||||
// read the request body
|
// read the request body
|
||||||
reqBody, err := io.ReadAll(r.Body)
|
reqBody, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ImageRequest{}, err
|
return openai.ImageRequest{}, err
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(reqBody, &image)
|
err = json.Unmarshal(reqBody, &image)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ImageRequest{}, err
|
return openai.ImageRequest{}, err
|
||||||
}
|
}
|
||||||
return image, nil
|
return image, nil
|
||||||
}
|
}
|
||||||
@@ -98,13 +98,13 @@ func TestImageEdit(t *testing.T) {
|
|||||||
os.Remove("image.png")
|
os.Remove("image.png")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err = client.CreateEditImage(context.Background(), ImageEditRequest{
|
_, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
|
||||||
Image: origin,
|
Image: origin,
|
||||||
Mask: mask,
|
Mask: mask,
|
||||||
Prompt: "There is a turtle in the pool",
|
Prompt: "There is a turtle in the pool",
|
||||||
N: 3,
|
N: 3,
|
||||||
Size: CreateImageSize1024x1024,
|
Size: openai.CreateImageSize1024x1024,
|
||||||
ResponseFormat: CreateImageResponseFormatURL,
|
ResponseFormat: openai.CreateImageResponseFormatURL,
|
||||||
})
|
})
|
||||||
checks.NoError(t, err, "CreateImage error")
|
checks.NoError(t, err, "CreateImage error")
|
||||||
}
|
}
|
||||||
@@ -125,12 +125,12 @@ func TestImageEditWithoutMask(t *testing.T) {
|
|||||||
os.Remove("image.png")
|
os.Remove("image.png")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err = client.CreateEditImage(context.Background(), ImageEditRequest{
|
_, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
|
||||||
Image: origin,
|
Image: origin,
|
||||||
Prompt: "There is a turtle in the pool",
|
Prompt: "There is a turtle in the pool",
|
||||||
N: 3,
|
N: 3,
|
||||||
Size: CreateImageSize1024x1024,
|
Size: openai.CreateImageSize1024x1024,
|
||||||
ResponseFormat: CreateImageResponseFormatURL,
|
ResponseFormat: openai.CreateImageResponseFormatURL,
|
||||||
})
|
})
|
||||||
checks.NoError(t, err, "CreateImage error")
|
checks.NoError(t, err, "CreateImage error")
|
||||||
}
|
}
|
||||||
@@ -144,9 +144,9 @@ func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) {
|
|||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
}
|
}
|
||||||
|
|
||||||
responses := ImageResponse{
|
responses := openai.ImageResponse{
|
||||||
Created: time.Now().Unix(),
|
Created: time.Now().Unix(),
|
||||||
Data: []ImageResponseDataInner{
|
Data: []openai.ImageResponseDataInner{
|
||||||
{
|
{
|
||||||
URL: "test-url1",
|
URL: "test-url1",
|
||||||
B64JSON: "",
|
B64JSON: "",
|
||||||
@@ -182,11 +182,11 @@ func TestImageVariation(t *testing.T) {
|
|||||||
os.Remove("image.png")
|
os.Remove("image.png")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err = client.CreateVariImage(context.Background(), ImageVariRequest{
|
_, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{
|
||||||
Image: origin,
|
Image: origin,
|
||||||
N: 3,
|
N: 3,
|
||||||
Size: CreateImageSize1024x1024,
|
Size: openai.CreateImageSize1024x1024,
|
||||||
ResponseFormat: CreateImageResponseFormatURL,
|
ResponseFormat: openai.CreateImageResponseFormatURL,
|
||||||
})
|
})
|
||||||
checks.NoError(t, err, "CreateImage error")
|
checks.NoError(t, err, "CreateImage error")
|
||||||
}
|
}
|
||||||
@@ -200,9 +200,9 @@ func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
|
|||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
}
|
}
|
||||||
|
|
||||||
responses := ImageResponse{
|
responses := openai.ImageResponse{
|
||||||
Created: time.Now().Unix(),
|
Created: time.Now().Unix(),
|
||||||
Data: []ImageResponseDataInner{
|
Data: []openai.ImageResponseDataInner{
|
||||||
{
|
{
|
||||||
URL: "test-url1",
|
URL: "test-url1",
|
||||||
B64JSON: "",
|
B64JSON: "",
|
||||||
|
|||||||
@@ -5,28 +5,28 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
. "github.com/sashabaranov/go-openai/jsonschema"
|
"github.com/sashabaranov/go-openai/jsonschema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDefinition_MarshalJSON(t *testing.T) {
|
func TestDefinition_MarshalJSON(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
def Definition
|
def jsonschema.Definition
|
||||||
want string
|
want string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Test with empty Definition",
|
name: "Test with empty Definition",
|
||||||
def: Definition{},
|
def: jsonschema.Definition{},
|
||||||
want: `{"properties":{}}`,
|
want: `{"properties":{}}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Test with Definition properties set",
|
name: "Test with Definition properties set",
|
||||||
def: Definition{
|
def: jsonschema.Definition{
|
||||||
Type: String,
|
Type: jsonschema.String,
|
||||||
Description: "A string type",
|
Description: "A string type",
|
||||||
Properties: map[string]Definition{
|
Properties: map[string]jsonschema.Definition{
|
||||||
"name": {
|
"name": {
|
||||||
Type: String,
|
Type: jsonschema.String,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -43,17 +43,17 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Test with nested Definition properties",
|
name: "Test with nested Definition properties",
|
||||||
def: Definition{
|
def: jsonschema.Definition{
|
||||||
Type: Object,
|
Type: jsonschema.Object,
|
||||||
Properties: map[string]Definition{
|
Properties: map[string]jsonschema.Definition{
|
||||||
"user": {
|
"user": {
|
||||||
Type: Object,
|
Type: jsonschema.Object,
|
||||||
Properties: map[string]Definition{
|
Properties: map[string]jsonschema.Definition{
|
||||||
"name": {
|
"name": {
|
||||||
Type: String,
|
Type: jsonschema.String,
|
||||||
},
|
},
|
||||||
"age": {
|
"age": {
|
||||||
Type: Integer,
|
Type: jsonschema.Integer,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -80,26 +80,26 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Test with complex nested Definition",
|
name: "Test with complex nested Definition",
|
||||||
def: Definition{
|
def: jsonschema.Definition{
|
||||||
Type: Object,
|
Type: jsonschema.Object,
|
||||||
Properties: map[string]Definition{
|
Properties: map[string]jsonschema.Definition{
|
||||||
"user": {
|
"user": {
|
||||||
Type: Object,
|
Type: jsonschema.Object,
|
||||||
Properties: map[string]Definition{
|
Properties: map[string]jsonschema.Definition{
|
||||||
"name": {
|
"name": {
|
||||||
Type: String,
|
Type: jsonschema.String,
|
||||||
},
|
},
|
||||||
"age": {
|
"age": {
|
||||||
Type: Integer,
|
Type: jsonschema.Integer,
|
||||||
},
|
},
|
||||||
"address": {
|
"address": {
|
||||||
Type: Object,
|
Type: jsonschema.Object,
|
||||||
Properties: map[string]Definition{
|
Properties: map[string]jsonschema.Definition{
|
||||||
"city": {
|
"city": {
|
||||||
Type: String,
|
Type: jsonschema.String,
|
||||||
},
|
},
|
||||||
"country": {
|
"country": {
|
||||||
Type: String,
|
Type: jsonschema.String,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -141,14 +141,14 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Test with Array type Definition",
|
name: "Test with Array type Definition",
|
||||||
def: Definition{
|
def: jsonschema.Definition{
|
||||||
Type: Array,
|
Type: jsonschema.Array,
|
||||||
Items: &Definition{
|
Items: &jsonschema.Definition{
|
||||||
Type: String,
|
Type: jsonschema.String,
|
||||||
},
|
},
|
||||||
Properties: map[string]Definition{
|
Properties: map[string]jsonschema.Definition{
|
||||||
"name": {
|
"name": {
|
||||||
Type: String,
|
Type: jsonschema.String,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
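
A minimal sketch of building and marshaling one of these definitions outside
the test table, using only identifiers that appear in this diff; per the
first table entry, an empty Definition renders as {"properties":{}}:

package example

import (
	"encoding/json"
	"fmt"

	"github.com/sashabaranov/go-openai/jsonschema"
)

func printSchema() error {
	def := jsonschema.Definition{
		Type: jsonschema.Object,
		Properties: map[string]jsonschema.Definition{
			"name": {Type: jsonschema.String, Description: "A string type"},
			"age":  {Type: jsonschema.Integer},
		},
	}
	// json.Marshal goes through Definition.MarshalJSON, the method the
	// table-driven test above exercises.
	b, err := json.Marshal(def)
	if err != nil {
		return err
	}
	fmt.Println(string(b))
	return nil
}
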
|
|||||||
@@ -1,17 +1,16 @@
|
|||||||
package openai_test
|
package openai_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
. "github.com/sashabaranov/go-openai"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
||||||
"testing"
|
"testing"
|
|
"time"
|
||||||
|
||||||
|
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
const testFineTuneModelID = "fine-tune-model-id"
|
const testFineTuneModelID = "fine-tune-model-id"
|
||||||
@@ -35,7 +34,7 @@ func TestAzureListModels(t *testing.T) {
|
|||||||
|
|
||||||
// handleListModelsEndpoint handles the list models endpoint for the test server.
|
// handleListModelsEndpoint handles the list models endpoint for the test server.
|
||||||
func handleListModelsEndpoint(w http.ResponseWriter, _ *http.Request) {
|
func handleListModelsEndpoint(w http.ResponseWriter, _ *http.Request) {
|
||||||
resBytes, _ := json.Marshal(ModelsList{})
|
resBytes, _ := json.Marshal(openai.ModelsList{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,7 +57,7 @@ func TestAzureGetModel(t *testing.T) {
|
|||||||
|
|
||||||
// handleGetModelEndpoint handles the get model endpoint for the test server.
|
// handleGetModelEndpoint handles the get model endpoint for the test server.
|
||||||
func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) {
|
func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) {
|
||||||
resBytes, _ := json.Marshal(Model{})
|
resBytes, _ := json.Marshal(openai.Model{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,6 +89,6 @@ func TestDeleteFineTuneModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func handleDeleteFineTuneModelEndpoint(w http.ResponseWriter, _ *http.Request) {
|
func handleDeleteFineTuneModelEndpoint(w http.ResponseWriter, _ *http.Request) {
|
||||||
resBytes, _ := json.Marshal(FineTuneModelDeleteResponse{})
|
resBytes, _ := json.Marshal(openai.FineTuneModelDeleteResponse{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
package openai_test
|
package openai_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -13,6 +10,9 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestModerations tests the moderations endpoint of the API using the mocked server.
|
// TestModerations tests the moderations endpoint of the API using the mocked server.
|
||||||
@@ -20,8 +20,8 @@ func TestModerations(t *testing.T) {
|
|||||||
client, server, teardown := setupOpenAITestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
defer teardown()
|
defer teardown()
|
||||||
server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
|
server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
|
||||||
_, err := client.Moderations(context.Background(), ModerationRequest{
|
_, err := client.Moderations(context.Background(), openai.ModerationRequest{
|
||||||
Model: ModerationTextStable,
|
Model: openai.ModerationTextStable,
|
||||||
Input: "I want to kill them.",
|
Input: "I want to kill them.",
|
||||||
})
|
})
|
||||||
checks.NoError(t, err, "Moderation error")
|
checks.NoError(t, err, "Moderation error")
|
||||||
@@ -34,16 +34,16 @@ func TestModerationsWithDifferentModelOptions(t *testing.T) {
|
|||||||
expect error
|
expect error
|
||||||
}
|
}
|
||||||
modelOptions = append(modelOptions,
|
modelOptions = append(modelOptions,
|
||||||
getModerationModelTestOption(GPT3Dot5Turbo, ErrModerationInvalidModel),
|
getModerationModelTestOption(openai.GPT3Dot5Turbo, openai.ErrModerationInvalidModel),
|
||||||
getModerationModelTestOption(ModerationTextStable, nil),
|
getModerationModelTestOption(openai.ModerationTextStable, nil),
|
||||||
getModerationModelTestOption(ModerationTextLatest, nil),
|
getModerationModelTestOption(openai.ModerationTextLatest, nil),
|
||||||
getModerationModelTestOption("", nil),
|
getModerationModelTestOption("", nil),
|
||||||
)
|
)
|
||||||
client, server, teardown := setupOpenAITestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
defer teardown()
|
defer teardown()
|
||||||
server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
|
server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
|
||||||
for _, modelTest := range modelOptions {
|
for _, modelTest := range modelOptions {
|
||||||
_, err := client.Moderations(context.Background(), ModerationRequest{
|
_, err := client.Moderations(context.Background(), openai.ModerationRequest{
|
||||||
Model: modelTest.model,
|
Model: modelTest.model,
|
||||||
Input: "I want to kill them.",
|
Input: "I want to kill them.",
|
||||||
})
|
})
|
||||||
@@ -71,32 +71,32 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
 	if r.Method != "POST" {
 		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
 	}
-	var moderationReq ModerationRequest
+	var moderationReq openai.ModerationRequest
 	if moderationReq, err = getModerationBody(r); err != nil {
 		http.Error(w, "could not read request", http.StatusInternalServerError)
 		return
 	}
 
-	resCat := ResultCategories{}
-	resCatScore := ResultCategoryScores{}
+	resCat := openai.ResultCategories{}
+	resCatScore := openai.ResultCategoryScores{}
 	switch {
 	case strings.Contains(moderationReq.Input, "kill"):
-		resCat = ResultCategories{Violence: true}
-		resCatScore = ResultCategoryScores{Violence: 1}
+		resCat = openai.ResultCategories{Violence: true}
+		resCatScore = openai.ResultCategoryScores{Violence: 1}
 	case strings.Contains(moderationReq.Input, "hate"):
-		resCat = ResultCategories{Hate: true}
-		resCatScore = ResultCategoryScores{Hate: 1}
+		resCat = openai.ResultCategories{Hate: true}
+		resCatScore = openai.ResultCategoryScores{Hate: 1}
 	case strings.Contains(moderationReq.Input, "suicide"):
-		resCat = ResultCategories{SelfHarm: true}
-		resCatScore = ResultCategoryScores{SelfHarm: 1}
+		resCat = openai.ResultCategories{SelfHarm: true}
+		resCatScore = openai.ResultCategoryScores{SelfHarm: 1}
 	case strings.Contains(moderationReq.Input, "porn"):
-		resCat = ResultCategories{Sexual: true}
-		resCatScore = ResultCategoryScores{Sexual: 1}
+		resCat = openai.ResultCategories{Sexual: true}
+		resCatScore = openai.ResultCategoryScores{Sexual: 1}
 	}
 
-	result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true}
+	result := openai.Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true}
 
-	res := ModerationResponse{
+	res := openai.ModerationResponse{
 		ID: strconv.Itoa(int(time.Now().Unix())),
 		Model: moderationReq.Model,
 	}
@@ -107,16 +107,16 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
 }
 
 // getModerationBody Returns the body of the request to do a moderation.
-func getModerationBody(r *http.Request) (ModerationRequest, error) {
-	moderation := ModerationRequest{}
+func getModerationBody(r *http.Request) (openai.ModerationRequest, error) {
+	moderation := openai.ModerationRequest{}
 	// read the request body
 	reqBody, err := io.ReadAll(r.Body)
 	if err != nil {
-		return ModerationRequest{}, err
+		return openai.ModerationRequest{}, err
 	}
 	err = json.Unmarshal(reqBody, &moderation)
 	if err != nil {
-		return ModerationRequest{}, err
+		return openai.ModerationRequest{}, err
 	}
 	return moderation, nil
 }
@@ -1,29 +1,29 @@
 package openai_test
 
 import (
-	. "github.com/sashabaranov/go-openai"
+	"github.com/sashabaranov/go-openai"
 	"github.com/sashabaranov/go-openai/internal/test"
 )
 
-func setupOpenAITestServer() (client *Client, server *test.ServerTest, teardown func()) {
+func setupOpenAITestServer() (client *openai.Client, server *test.ServerTest, teardown func()) {
 	server = test.NewTestServer()
 	ts := server.OpenAITestServer()
 	ts.Start()
 	teardown = ts.Close
-	config := DefaultConfig(test.GetTestToken())
+	config := openai.DefaultConfig(test.GetTestToken())
 	config.BaseURL = ts.URL + "/v1"
-	client = NewClientWithConfig(config)
+	client = openai.NewClientWithConfig(config)
 	return
 }
 
-func setupAzureTestServer() (client *Client, server *test.ServerTest, teardown func()) {
+func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, teardown func()) {
 	server = test.NewTestServer()
 	ts := server.OpenAITestServer()
 	ts.Start()
 	teardown = ts.Close
-	config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
+	config := openai.DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
 	config.BaseURL = ts.URL
-	client = NewClientWithConfig(config)
+	client = openai.NewClientWithConfig(config)
 	return
 }
 
@@ -10,23 +10,23 @@ import (
 	"testing"
 	"time"
 
-	. "github.com/sashabaranov/go-openai"
+	"github.com/sashabaranov/go-openai"
 	"github.com/sashabaranov/go-openai/internal/test/checks"
 )
 
 func TestCompletionsStreamWrongModel(t *testing.T) {
-	config := DefaultConfig("whatever")
+	config := openai.DefaultConfig("whatever")
 	config.BaseURL = "http://localhost/v1"
-	client := NewClientWithConfig(config)
+	client := openai.NewClientWithConfig(config)
 
 	_, err := client.CreateCompletionStream(
 		context.Background(),
-		CompletionRequest{
+		openai.CompletionRequest{
 			MaxTokens: 5,
-			Model: GPT3Dot5Turbo,
+			Model: openai.GPT3Dot5Turbo,
 		},
 	)
-	if !errors.Is(err, ErrCompletionUnsupportedModel) {
+	if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
 		t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err)
 	}
 }
@@ -56,7 +56,7 @@ func TestCreateCompletionStream(t *testing.T) {
 		checks.NoError(t, err, "Write error")
 	})
 
-	stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
+	stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{
 		Prompt: "Ex falso quodlibet",
 		Model: "text-davinci-002",
 		MaxTokens: 10,
@@ -65,20 +65,20 @@ func TestCreateCompletionStream(t *testing.T) {
 	checks.NoError(t, err, "CreateCompletionStream returned error")
 	defer stream.Close()
 
-	expectedResponses := []CompletionResponse{
+	expectedResponses := []openai.CompletionResponse{
 		{
 			ID: "1",
 			Object: "completion",
 			Created: 1598069254,
 			Model: "text-davinci-002",
-			Choices: []CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}},
+			Choices: []openai.CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}},
 		},
 		{
 			ID: "2",
 			Object: "completion",
 			Created: 1598069255,
 			Model: "text-davinci-002",
-			Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}},
+			Choices: []openai.CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}},
 		},
 	}
 
@@ -129,9 +129,9 @@ func TestCreateCompletionStreamError(t *testing.T) {
 		checks.NoError(t, err, "Write error")
 	})
 
-	stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
+	stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{
 		MaxTokens: 5,
-		Model: GPT3TextDavinci003,
+		Model: openai.GPT3TextDavinci003,
 		Prompt: "Hello!",
 		Stream: true,
 	})
@@ -141,7 +141,7 @@ func TestCreateCompletionStreamError(t *testing.T) {
 	_, streamErr := stream.Recv()
 	checks.HasError(t, streamErr, "stream.Recv() did not return error")
 
-	var apiErr *APIError
+	var apiErr *openai.APIError
 	if !errors.As(streamErr, &apiErr) {
 		t.Errorf("stream.Recv() did not return APIError")
 	}
@@ -166,10 +166,10 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) {
 		checks.NoError(t, err, "Write error")
 	})
 
-	var apiErr *APIError
-	_, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
+	var apiErr *openai.APIError
+	_, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{
 		MaxTokens: 5,
-		Model: GPT3Ada,
+		Model: openai.GPT3Ada,
 		Prompt: "Hello!",
 		Stream: true,
 	})
@@ -209,7 +209,7 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) {
 		checks.NoError(t, err, "Write error")
 	})
 
-	stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
+	stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{
 		Prompt: "Ex falso quodlibet",
 		Model: "text-davinci-002",
 		MaxTokens: 10,
@@ -220,7 +220,7 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) {
 
 	_, _ = stream.Recv()
 	_, streamErr := stream.Recv()
-	if !errors.Is(streamErr, ErrTooManyEmptyStreamMessages) {
+	if !errors.Is(streamErr, openai.ErrTooManyEmptyStreamMessages) {
 		t.Errorf("TestCreateCompletionStreamTooManyEmptyStreamMessagesError did not return ErrTooManyEmptyStreamMessages")
 	}
 }
@@ -244,7 +244,7 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) {
 		checks.NoError(t, err, "Write error")
 	})
 
-	stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
+	stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{
 		Prompt: "Ex falso quodlibet",
 		Model: "text-davinci-002",
 		MaxTokens: 10,
@@ -285,7 +285,7 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) {
 		checks.NoError(t, err, "Write error")
 	})
 
-	stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
+	stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{
 		Prompt: "Ex falso quodlibet",
 		Model: "text-davinci-002",
 		MaxTokens: 10,
@@ -312,7 +312,7 @@ func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) {
 	ctx, cancel := context.WithTimeout(ctx, time.Nanosecond)
 	defer cancel()
 
-	_, err := client.CreateCompletionStream(ctx, CompletionRequest{
+	_, err := client.CreateCompletionStream(ctx, openai.CompletionRequest{
 		Prompt: "Ex falso quodlibet",
 		Model: "text-davinci-002",
 		MaxTokens: 10,
@@ -327,7 +327,7 @@ func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) {
 }
 
 // Helper funcs.
-func compareResponses(r1, r2 CompletionResponse) bool {
+func compareResponses(r1, r2 openai.CompletionResponse) bool {
 	if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
 		return false
 	}
@@ -342,7 +342,7 @@ func compareResponses(r1, r2 CompletionResponse) bool {
 	return true
 }
 
-func compareResponseChoices(c1, c2 CompletionChoice) bool {
+func compareResponseChoices(c1, c2 openai.CompletionChoice) bool {
 	if c1.Text != c2.Text || c1.FinishReason != c2.FinishReason {
 		return false
 	}