* updated client_test to solve lint error * modified golangci yml to solve linter issues * minor change
574 lines
16 KiB
Go
574 lines
16 KiB
Go
package openai //nolint:testpackage // testing private field
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"github.com/sashabaranov/go-openai/internal/test"
|
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
|
)
|
|
|
|
var errTestRequestBuilderFailed = errors.New("test request builder failed")
|
|
|
|
type failingRequestBuilder struct{}
|
|
|
|
func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any, _ http.Header) (*http.Request, error) {
|
|
return nil, errTestRequestBuilderFailed
|
|
}
|
|
|
|
func TestClient(t *testing.T) {
|
|
const mockToken = "mock token"
|
|
client := NewClient(mockToken)
|
|
if client.config.authToken != mockToken {
|
|
t.Errorf("Client does not contain proper token")
|
|
}
|
|
|
|
const mockOrg = "mock org"
|
|
client = NewOrgClient(mockToken, mockOrg)
|
|
if client.config.authToken != mockToken {
|
|
t.Errorf("Client does not contain proper token")
|
|
}
|
|
if client.config.OrgID != mockOrg {
|
|
t.Errorf("Client does not contain proper orgID")
|
|
}
|
|
}
|
|
|
|
func TestDecodeResponse(t *testing.T) {
|
|
stringInput := ""
|
|
|
|
testCases := []struct {
|
|
name string
|
|
value interface{}
|
|
expected interface{}
|
|
body io.Reader
|
|
hasError bool
|
|
}{
|
|
{
|
|
name: "nil input",
|
|
value: nil,
|
|
body: bytes.NewReader([]byte("")),
|
|
expected: nil,
|
|
},
|
|
{
|
|
name: "string input",
|
|
value: &stringInput,
|
|
body: bytes.NewReader([]byte("test")),
|
|
expected: "test",
|
|
},
|
|
{
|
|
name: "map input",
|
|
value: &map[string]interface{}{},
|
|
body: bytes.NewReader([]byte(`{"test": "test"}`)),
|
|
expected: map[string]interface{}{
|
|
"test": "test",
|
|
},
|
|
},
|
|
{
|
|
name: "reader return error",
|
|
value: &stringInput,
|
|
body: &errorReader{err: errors.New("dummy")},
|
|
hasError: true,
|
|
},
|
|
{
|
|
name: "audio text input",
|
|
value: &audioTextResponse{},
|
|
body: bytes.NewReader([]byte("test")),
|
|
expected: audioTextResponse{
|
|
Text: "test",
|
|
},
|
|
},
|
|
}
|
|
|
|
assertEqual := func(t *testing.T, expected, actual interface{}) {
|
|
t.Helper()
|
|
if expected == actual {
|
|
return
|
|
}
|
|
v := reflect.ValueOf(actual).Elem().Interface()
|
|
if !reflect.DeepEqual(v, expected) {
|
|
t.Fatalf("Unexpected value: %v, expected: %v", v, expected)
|
|
}
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
err := decodeResponse(tc.body, tc.value)
|
|
if tc.hasError {
|
|
checks.HasError(t, err, "Unexpected nil error")
|
|
return
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
assertEqual(t, tc.expected, tc.value)
|
|
})
|
|
}
|
|
}
|
|
|
|
type errorReader struct {
|
|
err error
|
|
}
|
|
|
|
func (e *errorReader) Read(_ []byte) (n int, err error) {
|
|
return 0, e.err
|
|
}
|
|
|
|
func TestHandleErrorResp(t *testing.T) {
|
|
// var errRes *ErrorResponse
|
|
var errRes ErrorResponse
|
|
var reqErr RequestError
|
|
t.Log(errRes, errRes.Error)
|
|
if errRes.Error != nil {
|
|
reqErr.Err = errRes.Error
|
|
}
|
|
t.Log(fmt.Errorf("error, %w", &reqErr))
|
|
t.Log(errRes.Error, "nil pointer check Pass")
|
|
|
|
const mockToken = "mock token"
|
|
client := NewClient(mockToken)
|
|
|
|
testCases := []struct {
|
|
name string
|
|
httpCode int
|
|
httpStatus string
|
|
contentType string
|
|
body io.Reader
|
|
expected string
|
|
}{
|
|
{
|
|
name: "401 Invalid Authentication",
|
|
httpCode: http.StatusUnauthorized,
|
|
contentType: "application/json",
|
|
body: bytes.NewReader([]byte(
|
|
`{
|
|
"error":{
|
|
"message":"You didn't provide an API key. ....",
|
|
"type":"invalid_request_error",
|
|
"param":null,
|
|
"code":null
|
|
}
|
|
}`,
|
|
)),
|
|
expected: "error, status code: 401, status: , message: You didn't provide an API key. ....",
|
|
},
|
|
{
|
|
name: "401 Azure Access Denied",
|
|
httpCode: http.StatusUnauthorized,
|
|
contentType: "application/json",
|
|
body: bytes.NewReader([]byte(
|
|
`{
|
|
"error":{
|
|
"code":"AccessDenied",
|
|
"message":"Access denied due to Virtual Network/Firewall rules."
|
|
}
|
|
}`,
|
|
)),
|
|
expected: "error, status code: 401, status: , message: Access denied due to Virtual Network/Firewall rules.",
|
|
},
|
|
{
|
|
name: "503 Model Overloaded",
|
|
httpCode: http.StatusServiceUnavailable,
|
|
contentType: "application/json",
|
|
body: bytes.NewReader([]byte(`
|
|
{
|
|
"error":{
|
|
"message":"That model...",
|
|
"type":"server_error",
|
|
"param":null,
|
|
"code":null
|
|
}
|
|
}`)),
|
|
expected: "error, status code: 503, status: , message: That model...",
|
|
},
|
|
{
|
|
name: "503 no message (Unknown response)",
|
|
httpCode: http.StatusServiceUnavailable,
|
|
contentType: "application/json",
|
|
body: bytes.NewReader([]byte(`
|
|
{
|
|
"error":{}
|
|
}`)),
|
|
expected: `error, status code: 503, status: , message: , body:
|
|
{
|
|
"error":{}
|
|
}`,
|
|
},
|
|
{
|
|
name: "413 Request Entity Too Large",
|
|
httpCode: http.StatusRequestEntityTooLarge,
|
|
contentType: "text/html",
|
|
body: bytes.NewReader([]byte(`
|
|
<html>
|
|
<head><title>413 Request Entity Too Large</title></head>
|
|
<body>
|
|
<center><h1>413 Request Entity Too Large</h1></center>
|
|
<hr><center>nginx</center>
|
|
</body>
|
|
</html>`)),
|
|
expected: `error, status code: 413, status: , message: invalid character '<' looking for beginning of value, body:
|
|
<html>
|
|
<head><title>413 Request Entity Too Large</title></head>
|
|
<body>
|
|
<center><h1>413 Request Entity Too Large</h1></center>
|
|
<hr><center>nginx</center>
|
|
</body>
|
|
</html>`,
|
|
},
|
|
{
|
|
name: "errorReader",
|
|
httpCode: http.StatusRequestEntityTooLarge,
|
|
contentType: "text/html",
|
|
body: &errorReader{err: errors.New("errorReader")},
|
|
expected: "error, reading response body: errorReader",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
testCase := &http.Response{
|
|
Header: map[string][]string{
|
|
"Content-Type": {tc.contentType},
|
|
},
|
|
}
|
|
testCase.StatusCode = tc.httpCode
|
|
testCase.Body = io.NopCloser(tc.body)
|
|
err := client.handleErrorResp(testCase)
|
|
t.Log(err.Error())
|
|
if err.Error() != tc.expected {
|
|
t.Errorf("Unexpected error: %v , expected: %s", err, tc.expected)
|
|
t.Fail()
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestClientReturnsRequestBuilderErrors(t *testing.T) {
|
|
config := DefaultConfig(test.GetTestToken())
|
|
client := NewClientWithConfig(config)
|
|
client.requestBuilder = &failingRequestBuilder{}
|
|
ctx := context.Background()
|
|
|
|
type TestCase struct {
|
|
Name string
|
|
TestFunc func() (any, error)
|
|
}
|
|
|
|
testCases := []TestCase{
|
|
{"CreateCompletion", func() (any, error) {
|
|
return client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"})
|
|
}},
|
|
{"CreateCompletionStream", func() (any, error) {
|
|
return client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""})
|
|
}},
|
|
{"CreateChatCompletion", func() (any, error) {
|
|
return client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
|
}},
|
|
{"CreateChatCompletionStream", func() (any, error) {
|
|
return client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
|
}},
|
|
{"CreateFineTune", func() (any, error) {
|
|
return client.CreateFineTune(ctx, FineTuneRequest{})
|
|
}},
|
|
{"ListFineTunes", func() (any, error) {
|
|
return client.ListFineTunes(ctx)
|
|
}},
|
|
{"CancelFineTune", func() (any, error) {
|
|
return client.CancelFineTune(ctx, "")
|
|
}},
|
|
{"GetFineTune", func() (any, error) {
|
|
return client.GetFineTune(ctx, "")
|
|
}},
|
|
{"DeleteFineTune", func() (any, error) {
|
|
return client.DeleteFineTune(ctx, "")
|
|
}},
|
|
{"ListFineTuneEvents", func() (any, error) {
|
|
return client.ListFineTuneEvents(ctx, "")
|
|
}},
|
|
{"CreateFineTuningJob", func() (any, error) {
|
|
return client.CreateFineTuningJob(ctx, FineTuningJobRequest{})
|
|
}},
|
|
{"CancelFineTuningJob", func() (any, error) {
|
|
return client.CancelFineTuningJob(ctx, "")
|
|
}},
|
|
{"RetrieveFineTuningJob", func() (any, error) {
|
|
return client.RetrieveFineTuningJob(ctx, "")
|
|
}},
|
|
{"ListFineTuningJobEvents", func() (any, error) {
|
|
return client.ListFineTuningJobEvents(ctx, "")
|
|
}},
|
|
{"Moderations", func() (any, error) {
|
|
return client.Moderations(ctx, ModerationRequest{})
|
|
}},
|
|
{"Edits", func() (any, error) {
|
|
return client.Edits(ctx, EditsRequest{})
|
|
}},
|
|
{"CreateEmbeddings", func() (any, error) {
|
|
return client.CreateEmbeddings(ctx, EmbeddingRequest{})
|
|
}},
|
|
{"CreateImage", func() (any, error) {
|
|
return client.CreateImage(ctx, ImageRequest{})
|
|
}},
|
|
{"CreateFileBytes", func() (any, error) {
|
|
return client.CreateFileBytes(ctx, FileBytesRequest{})
|
|
}},
|
|
{"DeleteFile", func() (any, error) {
|
|
return nil, client.DeleteFile(ctx, "")
|
|
}},
|
|
{"GetFile", func() (any, error) {
|
|
return client.GetFile(ctx, "")
|
|
}},
|
|
{"GetFileContent", func() (any, error) {
|
|
return client.GetFileContent(ctx, "")
|
|
}},
|
|
{"ListFiles", func() (any, error) {
|
|
return client.ListFiles(ctx)
|
|
}},
|
|
{"ListEngines", func() (any, error) {
|
|
return client.ListEngines(ctx)
|
|
}},
|
|
{"GetEngine", func() (any, error) {
|
|
return client.GetEngine(ctx, "")
|
|
}},
|
|
{"ListModels", func() (any, error) {
|
|
return client.ListModels(ctx)
|
|
}},
|
|
{"GetModel", func() (any, error) {
|
|
return client.GetModel(ctx, "text-davinci-003")
|
|
}},
|
|
{"DeleteFineTuneModel", func() (any, error) {
|
|
return client.DeleteFineTuneModel(ctx, "")
|
|
}},
|
|
{"CreateAssistant", func() (any, error) {
|
|
return client.CreateAssistant(ctx, AssistantRequest{})
|
|
}},
|
|
{"RetrieveAssistant", func() (any, error) {
|
|
return client.RetrieveAssistant(ctx, "")
|
|
}},
|
|
{"ModifyAssistant", func() (any, error) {
|
|
return client.ModifyAssistant(ctx, "", AssistantRequest{})
|
|
}},
|
|
{"DeleteAssistant", func() (any, error) {
|
|
return client.DeleteAssistant(ctx, "")
|
|
}},
|
|
{"ListAssistants", func() (any, error) {
|
|
return client.ListAssistants(ctx, nil, nil, nil, nil)
|
|
}},
|
|
{"CreateAssistantFile", func() (any, error) {
|
|
return client.CreateAssistantFile(ctx, "", AssistantFileRequest{})
|
|
}},
|
|
{"ListAssistantFiles", func() (any, error) {
|
|
return client.ListAssistantFiles(ctx, "", nil, nil, nil, nil)
|
|
}},
|
|
{"RetrieveAssistantFile", func() (any, error) {
|
|
return client.RetrieveAssistantFile(ctx, "", "")
|
|
}},
|
|
{"DeleteAssistantFile", func() (any, error) {
|
|
return nil, client.DeleteAssistantFile(ctx, "", "")
|
|
}},
|
|
{"CreateMessage", func() (any, error) {
|
|
return client.CreateMessage(ctx, "", MessageRequest{})
|
|
}},
|
|
{"ListMessage", func() (any, error) {
|
|
return client.ListMessage(ctx, "", nil, nil, nil, nil, nil)
|
|
}},
|
|
{"RetrieveMessage", func() (any, error) {
|
|
return client.RetrieveMessage(ctx, "", "")
|
|
}},
|
|
{"ModifyMessage", func() (any, error) {
|
|
return client.ModifyMessage(ctx, "", "", nil)
|
|
}},
|
|
{"DeleteMessage", func() (any, error) {
|
|
return client.DeleteMessage(ctx, "", "")
|
|
}},
|
|
{"RetrieveMessageFile", func() (any, error) {
|
|
return client.RetrieveMessageFile(ctx, "", "", "")
|
|
}},
|
|
{"ListMessageFiles", func() (any, error) {
|
|
return client.ListMessageFiles(ctx, "", "")
|
|
}},
|
|
{"CreateThread", func() (any, error) {
|
|
return client.CreateThread(ctx, ThreadRequest{})
|
|
}},
|
|
{"RetrieveThread", func() (any, error) {
|
|
return client.RetrieveThread(ctx, "")
|
|
}},
|
|
{"ModifyThread", func() (any, error) {
|
|
return client.ModifyThread(ctx, "", ModifyThreadRequest{})
|
|
}},
|
|
{"DeleteThread", func() (any, error) {
|
|
return client.DeleteThread(ctx, "")
|
|
}},
|
|
{"CreateRun", func() (any, error) {
|
|
return client.CreateRun(ctx, "", RunRequest{})
|
|
}},
|
|
{"RetrieveRun", func() (any, error) {
|
|
return client.RetrieveRun(ctx, "", "")
|
|
}},
|
|
{"ModifyRun", func() (any, error) {
|
|
return client.ModifyRun(ctx, "", "", RunModifyRequest{})
|
|
}},
|
|
{"ListRuns", func() (any, error) {
|
|
return client.ListRuns(ctx, "", Pagination{})
|
|
}},
|
|
{"SubmitToolOutputs", func() (any, error) {
|
|
return client.SubmitToolOutputs(ctx, "", "", SubmitToolOutputsRequest{})
|
|
}},
|
|
{"CancelRun", func() (any, error) {
|
|
return client.CancelRun(ctx, "", "")
|
|
}},
|
|
{"CreateThreadAndRun", func() (any, error) {
|
|
return client.CreateThreadAndRun(ctx, CreateThreadAndRunRequest{})
|
|
}},
|
|
{"RetrieveRunStep", func() (any, error) {
|
|
return client.RetrieveRunStep(ctx, "", "", "")
|
|
}},
|
|
{"ListRunSteps", func() (any, error) {
|
|
return client.ListRunSteps(ctx, "", "", Pagination{})
|
|
}},
|
|
{"CreateSpeech", func() (any, error) {
|
|
return client.CreateSpeech(ctx, CreateSpeechRequest{Model: TTSModel1, Voice: VoiceAlloy})
|
|
}},
|
|
{"CreateBatch", func() (any, error) {
|
|
return client.CreateBatch(ctx, CreateBatchRequest{})
|
|
}},
|
|
{"CreateBatchWithUploadFile", func() (any, error) {
|
|
return client.CreateBatchWithUploadFile(ctx, CreateBatchWithUploadFileRequest{})
|
|
}},
|
|
{"RetrieveBatch", func() (any, error) {
|
|
return client.RetrieveBatch(ctx, "")
|
|
}},
|
|
{"CancelBatch", func() (any, error) { return client.CancelBatch(ctx, "") }},
|
|
{"ListBatch", func() (any, error) { return client.ListBatch(ctx, nil, nil) }},
|
|
}
|
|
|
|
for _, testCase := range testCases {
|
|
_, err := testCase.TestFunc()
|
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
t.Fatalf("%s did not return error when request builder failed: %v", testCase.Name, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) {
|
|
config := DefaultConfig(test.GetTestToken())
|
|
client := NewClientWithConfig(config)
|
|
client.requestBuilder = &failingRequestBuilder{}
|
|
ctx := context.Background()
|
|
_, err := client.CreateCompletion(ctx, CompletionRequest{Prompt: 1})
|
|
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
|
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
}
|
|
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1})
|
|
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
|
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestClient_suffixWithAPIVersion(t *testing.T) {
|
|
type fields struct {
|
|
apiVersion string
|
|
}
|
|
type args struct {
|
|
suffix string
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
args args
|
|
want string
|
|
wantPanic string
|
|
}{
|
|
{
|
|
"",
|
|
fields{apiVersion: "2023-05"},
|
|
args{suffix: "/assistants"},
|
|
"/assistants?api-version=2023-05",
|
|
"",
|
|
},
|
|
{
|
|
"",
|
|
fields{apiVersion: "2023-05"},
|
|
args{suffix: "/assistants?limit=5"},
|
|
"/assistants?api-version=2023-05&limit=5",
|
|
"",
|
|
},
|
|
{
|
|
"",
|
|
fields{apiVersion: "2023-05"},
|
|
args{suffix: "123:assistants?limit=5"},
|
|
"/assistants?api-version=2023-05&limit=5",
|
|
"failed to parse url suffix",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
c := &Client{
|
|
config: ClientConfig{APIVersion: tt.fields.apiVersion},
|
|
}
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
// Check if the panic message matches the expected panic message
|
|
if rStr, ok := r.(string); ok {
|
|
if rStr != tt.wantPanic {
|
|
t.Errorf("suffixWithAPIVersion() = %v, want %v", rStr, tt.wantPanic)
|
|
}
|
|
} else {
|
|
// If the panic is not a string, log it
|
|
t.Errorf("suffixWithAPIVersion() panicked with non-string value: %v", r)
|
|
}
|
|
}
|
|
}()
|
|
if got := c.suffixWithAPIVersion(tt.args.suffix); got != tt.want {
|
|
t.Errorf("suffixWithAPIVersion() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestClient_baseURLWithAzureDeployment(t *testing.T) {
|
|
type args struct {
|
|
baseURL string
|
|
suffix string
|
|
model string
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
wantNewBaseURL string
|
|
}{
|
|
{
|
|
"",
|
|
args{baseURL: "https://test.openai.azure.com/", suffix: assistantsSuffix, model: GPT4oMini},
|
|
"https://test.openai.azure.com/openai",
|
|
},
|
|
{
|
|
"",
|
|
args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: GPT4oMini},
|
|
"https://test.openai.azure.com/openai/deployments/gpt-4o-mini",
|
|
},
|
|
{
|
|
"",
|
|
args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: ""},
|
|
"https://test.openai.azure.com/openai/deployments/UNKNOWN",
|
|
},
|
|
}
|
|
client := NewClient("")
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if gotNewBaseURL := client.baseURLWithAzureDeployment(
|
|
tt.args.baseURL,
|
|
tt.args.suffix,
|
|
tt.args.model,
|
|
); gotNewBaseURL != tt.wantNewBaseURL {
|
|
t.Errorf("baseURLWithAzureDeployment() = %v, want %v", gotNewBaseURL, tt.wantNewBaseURL)
|
|
}
|
|
})
|
|
}
|
|
}
|