package openai_test

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/sashabaranov/go-openai"
	"github.com/sashabaranov/go-openai/internal/test/checks"
	"github.com/sashabaranov/go-openai/jsonschema"
)

const (
	xCustomHeader      = "X-CUSTOM-HEADER"
	xCustomHeaderValue = "test"
)

var rateLimitHeaders = map[string]any{
	"x-ratelimit-limit-requests":     60,
	"x-ratelimit-limit-tokens":       150000,
	"x-ratelimit-remaining-requests": 59,
	"x-ratelimit-remaining-tokens":   149984,
	"x-ratelimit-reset-requests":     "1s",
	"x-ratelimit-reset-tokens":       "6m0s",
}

func TestChatCompletionsWrongModel(t *testing.T) {
	config := openai.DefaultConfig("whatever")
	config.BaseURL = "http://localhost/v1"
	client := openai.NewClientWithConfig(config)
	ctx := context.Background()

	req := openai.ChatCompletionRequest{
		MaxTokens: 5,
		Model:     "ada",
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: "Hello!",
			},
		},
	}
	_, err := client.CreateChatCompletion(ctx, req)
	msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err)
	checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg)
}

func TestChatRequestOmitEmpty(t *testing.T) {
	data, err := json.Marshal(openai.ChatCompletionRequest{
		// We set Model because it is required, so omitempty doesn't apply to it.
		Model: "gpt-4",
	})
	checks.NoError(t, err)

	// Messages is also required, so it isn't omitted either.
	const expected = `{"model":"gpt-4","messages":null}`
	if string(data) != expected {
		t.Errorf("expected JSON with all empty fields to be %v but was %v", expected, string(data))
	}
}

func TestChatCompletionsWithStream(t *testing.T) {
	config := openai.DefaultConfig("whatever")
	config.BaseURL = "http://localhost/v1"
	client := openai.NewClientWithConfig(config)
	ctx := context.Background()

	req := openai.ChatCompletionRequest{
		Stream: true,
	}
	_, err := client.CreateChatCompletion(ctx, req)
	checks.ErrorIs(t, err, openai.ErrChatCompletionStreamNotSupported, "unexpected error")
}

// TestChatCompletions tests the chat completions endpoint of the API using the mocked server.
func TestChatCompletions(t *testing.T) {
	client, server, teardown := setupOpenAITestServer()
	defer teardown()
	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		MaxTokens: 5,
		Model:     openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: "Hello!",
			},
		},
	})
	checks.NoError(t, err, "CreateChatCompletion error")
}

// TestChatCompletionsWithHeaders tests that custom headers set by the mocked server
// are exposed on the chat completion response.
func TestChatCompletionsWithHeaders(t *testing.T) {
	client, server, teardown := setupOpenAITestServer()
	defer teardown()
	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		MaxTokens: 5,
		Model:     openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: "Hello!",
			},
		},
	})
	checks.NoError(t, err, "CreateChatCompletion error")

	if resp.Header().Get(xCustomHeader) != xCustomHeaderValue {
		t.Errorf("expected header %s to be %s", xCustomHeader, xCustomHeaderValue)
	}
}
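
// The sketch below is illustrative only and is not executed by the suite; it
// shows how a caller might read the parsed rate-limit headers that the next
// test asserts against. Field names assume the library's RateLimitHeaders struct.
//
//	resp, err := client.CreateChatCompletion(ctx, req)
//	if err == nil {
//		headers := resp.GetRateLimitHeaders()
//		fmt.Printf("remaining requests: %d, reset in %s\n",
//			headers.RemainingRequests, headers.ResetRequests.String())
//	}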

// TestChatCompletionsWithRateLimitHeaders tests the chat completions endpoint of the API
// using the mocked server and verifies that rate-limit headers are parsed correctly.
func TestChatCompletionsWithRateLimitHeaders(t *testing.T) {
	client, server, teardown := setupOpenAITestServer()
	defer teardown()
	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		MaxTokens: 5,
		Model:     openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: "Hello!",
			},
		},
	})
	checks.NoError(t, err, "CreateChatCompletion error")

	headers := resp.GetRateLimitHeaders()
	resetRequests := headers.ResetRequests.String()
	if resetRequests != rateLimitHeaders["x-ratelimit-reset-requests"] {
		t.Errorf("expected resetRequests %s to be %s", resetRequests, rateLimitHeaders["x-ratelimit-reset-requests"])
	}
	resetRequestsTime := headers.ResetRequests.Time()
	if resetRequestsTime.Before(time.Now()) {
		t.Errorf("unexpected reset requests: %v", resetRequestsTime)
	}

	bs1, _ := json.Marshal(headers)
	bs2, _ := json.Marshal(rateLimitHeaders)
	if string(bs1) != string(bs2) {
		t.Errorf("expected rate limit headers %s, got %s", bs2, bs1)
	}
}

// TestChatCompletionsFunctions tests including a function call.
func TestChatCompletionsFunctions(t *testing.T) {
	client, server, teardown := setupOpenAITestServer()
	defer teardown()
	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
	t.Run("bytes", func(t *testing.T) {
		//nolint:lll
		msg := json.RawMessage(`{"properties":{"count":{"type":"integer","description":"total number of words in sentence"},"words":{"items":{"type":"string"},"type":"array","description":"list of words in sentence"}},"type":"object","required":["count","words"]}`)
		_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
			MaxTokens: 5,
			Model:     openai.GPT3Dot5Turbo0613,
			Messages: []openai.ChatCompletionMessage{
				{
					Role:    openai.ChatMessageRoleUser,
					Content: "Hello!",
				},
			},
			Functions: []openai.FunctionDefinition{{
				Name:       "test",
				Parameters: &msg,
			}},
		})
		checks.NoError(t, err, "CreateChatCompletion with functions error")
	})
	t.Run("struct", func(t *testing.T) {
		type testMessage struct {
			Count int      `json:"count"`
			Words []string `json:"words"`
		}
		msg := testMessage{
			Count: 2,
			Words: []string{"hello", "world"},
		}
		_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
			MaxTokens: 5,
			Model:     openai.GPT3Dot5Turbo0613,
			Messages: []openai.ChatCompletionMessage{
				{
					Role:    openai.ChatMessageRoleUser,
					Content: "Hello!",
				},
			},
			Functions: []openai.FunctionDefinition{{
				Name:       "test",
				Parameters: &msg,
			}},
		})
		checks.NoError(t, err, "CreateChatCompletion with functions error")
	})
	t.Run("JSONSchemaDefinition", func(t *testing.T) {
		_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
			MaxTokens: 5,
			Model:     openai.GPT3Dot5Turbo0613,
			Messages: []openai.ChatCompletionMessage{
				{
					Role:    openai.ChatMessageRoleUser,
					Content: "Hello!",
				},
			},
			Functions: []openai.FunctionDefinition{{
				Name: "test",
				Parameters: &jsonschema.Definition{
					Type: jsonschema.Object,
					Properties: map[string]jsonschema.Definition{
						"count": {
							Type:        jsonschema.Number,
							Description: "total number of words in sentence",
						},
						"words": {
							Type:        jsonschema.Array,
							Description: "list of words in sentence",
							Items: &jsonschema.Definition{
								Type: jsonschema.String,
							},
						},
						"enumTest": {
							Type: jsonschema.String,
							Enum: []string{"hello", "world"},
						},
					},
				},
			}},
		})
		checks.NoError(t, err, "CreateChatCompletion with functions error")
	})
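
	// For reference, a jsonschema.Definition like the one above is marshaled
	// into a standard JSON Schema object. A rough sketch of the resulting
	// shape, based on the struct's json tags (an illustration, not something
	// these tests assert):
	//
	//	{
	//	  "type": "object",
	//	  "properties": {
	//	    "count": {"type": "number", "description": "total number of words in sentence"},
	//	    "words": {"type": "array", "description": "list of words in sentence", "items": {"type": "string"}},
	//	    "enumTest": {"type": "string", "enum": ["hello", "world"]}
	//	  }
	//	}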

	t.Run("JSONSchemaDefinitionWithFunctionDefine", func(t *testing.T) {
		// This is a compatibility check: FunctionDefine is the deprecated alias
		// of FunctionDefinition and should still be accepted.
		_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
			MaxTokens: 5,
			Model:     openai.GPT3Dot5Turbo0613,
			Messages: []openai.ChatCompletionMessage{
				{
					Role:    openai.ChatMessageRoleUser,
					Content: "Hello!",
				},
			},
			Functions: []openai.FunctionDefine{{
				Name: "test",
				Parameters: &jsonschema.Definition{
					Type: jsonschema.Object,
					Properties: map[string]jsonschema.Definition{
						"count": {
							Type:        jsonschema.Number,
							Description: "total number of words in sentence",
						},
						"words": {
							Type:        jsonschema.Array,
							Description: "list of words in sentence",
							Items: &jsonschema.Definition{
								Type: jsonschema.String,
							},
						},
						"enumTest": {
							Type: jsonschema.String,
							Enum: []string{"hello", "world"},
						},
					},
				},
			}},
		})
		checks.NoError(t, err, "CreateChatCompletion with functions error")
	})
}

func TestAzureChatCompletions(t *testing.T) {
	client, server, teardown := setupAzureTestServer()
	defer teardown()
	server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint)

	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		MaxTokens: 5,
		Model:     openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: "Hello!",
			},
		},
	})
	checks.NoError(t, err, "CreateAzureChatCompletion error")
}

// handleChatCompletionEndpoint handles the chat completion endpoint for the test server.
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
	var err error
	var resBytes []byte

	// completions only accepts POST requests
	if r.Method != "POST" {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var completionReq openai.ChatCompletionRequest
	if completionReq, err = getChatCompletionBody(r); err != nil {
		http.Error(w, "could not read request", http.StatusInternalServerError)
		return
	}
	res := openai.ChatCompletionResponse{
		ID:      strconv.Itoa(int(time.Now().Unix())),
		Object:  "test-object",
		Created: time.Now().Unix(),
		// It would be nice to validate Model during testing, but
		// this may not be possible with how much upkeep
		// would be required / wouldn't make much sense.
		Model: completionReq.Model,
	}
	// create completions
	n := completionReq.N
	if n == 0 {
		n = 1
	}
	for i := 0; i < n; i++ {
		// if there are functions, include them
		if len(completionReq.Functions) > 0 {
			var fcb []byte
			b := completionReq.Functions[0].Parameters
			fcb, err = json.Marshal(b)
			if err != nil {
				http.Error(w, "could not marshal function parameters", http.StatusInternalServerError)
				return
			}
			res.Choices = append(res.Choices, openai.ChatCompletionChoice{
				Message: openai.ChatCompletionMessage{
					Role: openai.ChatMessageRoleFunction,
					// this is valid json so it should be fine
					FunctionCall: &openai.FunctionCall{
						Name:      completionReq.Functions[0].Name,
						Arguments: string(fcb),
					},
				},
				Index: i,
			})
			continue
		}
		// otherwise, generate a dummy completion of length completionReq.MaxTokens
		completionStr := strings.Repeat("a", completionReq.MaxTokens)
		res.Choices = append(res.Choices, openai.ChatCompletionChoice{
			Message: openai.ChatCompletionMessage{
				Role:    openai.ChatMessageRoleAssistant,
				Content: completionStr,
			},
			Index: i,
		})
	}
	inputTokens := numTokens(completionReq.Messages[0].Content) * n
	completionTokens := completionReq.MaxTokens * n
	res.Usage = openai.Usage{
		PromptTokens:     inputTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      inputTokens + completionTokens,
	}
	resBytes, _ = json.Marshal(res)
	w.Header().Set(xCustomHeader, xCustomHeaderValue)
	for k, v := range rateLimitHeaders {
		switch val := v.(type) {
		case int:
			w.Header().Set(k, strconv.Itoa(val))
		default:
			w.Header().Set(k, fmt.Sprintf("%s", v))
		}
	}
	fmt.Fprintln(w, string(resBytes))
}
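
// For reference, a minimal request body that getChatCompletionBody (below)
// can decode; field names follow ChatCompletionRequest's json tags (shown as
// an illustration, not asserted anywhere):
//
//	{"model":"gpt-3.5-turbo","max_tokens":5,"messages":[{"role":"user","content":"Hello!"}]}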

// getChatCompletionBody decodes a chat completion request from the request body.
func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) {
	completion := openai.ChatCompletionRequest{}
	// read the request body
	reqBody, err := io.ReadAll(r.Body)
	if err != nil {
		return openai.ChatCompletionRequest{}, err
	}
	err = json.Unmarshal(reqBody, &completion)
	if err != nil {
		return openai.ChatCompletionRequest{}, err
	}
	return completion, nil
}

func TestFinishReason(t *testing.T) {
	c := &openai.ChatCompletionChoice{
		FinishReason: openai.FinishReasonNull,
	}
	resBytes, _ := json.Marshal(c)
	if !strings.Contains(string(resBytes), `"finish_reason":null`) {
		t.Error("null should not be quoted")
	}

	c.FinishReason = ""

	resBytes, _ = json.Marshal(c)
	if !strings.Contains(string(resBytes), `"finish_reason":null`) {
		t.Error("null should not be quoted")
	}

	otherReasons := []openai.FinishReason{
		openai.FinishReasonStop,
		openai.FinishReasonLength,
		openai.FinishReasonFunctionCall,
		openai.FinishReasonContentFilter,
	}
	for _, r := range otherReasons {
		c.FinishReason = r
		resBytes, _ = json.Marshal(c)
		if !strings.Contains(string(resBytes), fmt.Sprintf(`"finish_reason":"%s"`, r)) {
			t.Errorf("%s should be quoted", r)
		}
	}
}
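
// A minimal sketch of the marshaling behavior TestFinishReason asserts,
// assuming the FinishReason MarshalJSON implementation exercised above:
//
//	b, _ := json.Marshal(openai.ChatCompletionChoice{FinishReason: openai.FinishReasonStop})
//	// b contains `"finish_reason":"stop"`; FinishReasonNull (and "") marshal to null.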