diff --git a/api_test.go b/api_test.go index a5a0d12..478a274 100644 --- a/api_test.go +++ b/api_test.go @@ -2,6 +2,7 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" "context" "errors" @@ -20,25 +21,17 @@ func TestAPI(t *testing.T) { c := NewClient(apiToken) ctx := context.Background() _, err = c.ListEngines(ctx) - if err != nil { - t.Fatalf("ListEngines error: %v", err) - } + checks.NoError(t, err, "ListEngines error") _, err = c.GetEngine(ctx, "davinci") - if err != nil { - t.Fatalf("GetEngine error: %v", err) - } + checks.NoError(t, err, "GetEngine error") fileRes, err := c.ListFiles(ctx) - if err != nil { - t.Fatalf("ListFiles error: %v", err) - } + checks.NoError(t, err, "ListFiles error") if len(fileRes.Files) > 0 { _, err = c.GetFile(ctx, fileRes.Files[0].ID) - if err != nil { - t.Fatalf("GetFile error: %v", err) - } + checks.NoError(t, err, "GetFile error") } // else skip embeddingReq := EmbeddingRequest{ @@ -49,9 +42,7 @@ func TestAPI(t *testing.T) { Model: AdaSearchQuery, } _, err = c.CreateEmbeddings(ctx, embeddingReq) - if err != nil { - t.Fatalf("Embedding error: %v", err) - } + checks.NoError(t, err, "Embedding error") _, err = c.CreateChatCompletion( ctx, @@ -66,9 +57,7 @@ func TestAPI(t *testing.T) { }, ) - if err != nil { - t.Errorf("CreateChatCompletion (without name) returned error: %v", err) - } + checks.NoError(t, err, "CreateChatCompletion (without name) returned error") _, err = c.CreateChatCompletion( ctx, @@ -83,10 +72,7 @@ func TestAPI(t *testing.T) { }, }, ) - - if err != nil { - t.Errorf("CreateChatCompletion (with name) returned error: %v", err) - } + checks.NoError(t, err, "CreateChatCompletion (with name) returned error") stream, err := c.CreateCompletionStream(ctx, CompletionRequest{ Prompt: "Ex falso quodlibet", @@ -94,9 +80,7 @@ func TestAPI(t *testing.T) { MaxTokens: 5, Stream: true, }) - if err != nil { - t.Errorf("CreateCompletionStream returned error: %v", err) - } + checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() counter := 0 @@ -126,9 +110,7 @@ func TestAPIError(t *testing.T) { c := NewClient(apiToken + "_invalid") ctx := context.Background() _, err = c.ListEngines(ctx) - if err == nil { - t.Fatal("ListEngines did not fail") - } + checks.NoError(t, err, "ListEngines did not fail") var apiErr *APIError if !errors.As(err, &apiErr) { @@ -154,9 +136,7 @@ func TestRequestError(t *testing.T) { c := NewClientWithConfig(config) ctx := context.Background() _, err = c.ListEngines(ctx) - if err == nil { - t.Fatal("ListEngines request did not fail") - } + checks.HasError(t, err, "ListEngines did not fail") var reqErr *RequestError if !errors.As(err, &reqErr) { diff --git a/audio_test.go b/audio_test.go index 8a9b578..0870848 100644 --- a/audio_test.go +++ b/audio_test.go @@ -13,6 +13,7 @@ import ( . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" "context" "testing" @@ -62,9 +63,7 @@ func TestAudio(t *testing.T) { Model: "whisper-3", } _, err = tc.createFn(ctx, req) - if err != nil { - t.Fatalf("audio API error: %v", err) - } + checks.NoError(t, err, "audio API error") }) } } @@ -115,9 +114,7 @@ func TestAudioWithOptionalArgs(t *testing.T) { Language: "zh", } _, err = tc.createFn(ctx, req) - if err != nil { - t.Fatalf("audio API error: %v", err) - } + checks.NoError(t, err, "audio API error") }) } } @@ -125,9 +122,8 @@ func TestAudioWithOptionalArgs(t *testing.T) { // createTestFile creates a fake file with "hello" as the content. func createTestFile(t *testing.T, path string) { file, err := os.Create(path) - if err != nil { - t.Fatalf("failed to create file %v", err) - } + checks.NoError(t, err, "failed to create file") + if _, err = file.WriteString("hello"); err != nil { t.Fatalf("failed to write to file %v", err) } @@ -139,9 +135,7 @@ func createTestDirectory(t *testing.T) (path string, cleanup func()) { t.Helper() path, err := os.MkdirTemp(os.TempDir(), "") - if err != nil { - t.Fatal(err) - } + checks.NoError(t, err) return path, func() { os.RemoveAll(path) } } diff --git a/chat_stream_test.go b/chat_stream_test.go index a21ceee..24046db 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -3,6 +3,7 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" "context" "encoding/json" @@ -55,9 +56,7 @@ func TestCreateChatCompletionStream(t *testing.T) { dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) _, err := w.Write(dataBytes) - if err != nil { - t.Errorf("Write error: %s", err) - } + checks.NoError(t, err, "Write error") })) defer server.Close() @@ -85,9 +84,7 @@ func TestCreateChatCompletionStream(t *testing.T) { } stream, err := client.CreateChatCompletionStream(ctx, request) - if err != nil { - t.Errorf("CreateCompletionStream returned error: %v", err) - } + checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() expectedResponses := []ChatCompletionStreamResponse{ @@ -126,9 +123,7 @@ func TestCreateChatCompletionStream(t *testing.T) { t.Logf("%d: %s", ix, string(b)) receivedResponse, streamErr := stream.Recv() - if streamErr != nil { - t.Errorf("stream.Recv() failed: %v", streamErr) - } + checks.NoError(t, streamErr, "stream.Recv() failed") if !compareChatResponses(expectedResponse, receivedResponse) { t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) } @@ -140,6 +135,8 @@ func TestCreateChatCompletionStream(t *testing.T) { } _, streamErr = stream.Recv() + + checks.ErrorIs(t, streamErr, io.EOF, "stream.Recv() did not return EOF when the stream is finished") if !errors.Is(streamErr, io.EOF) { t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr) } @@ -166,9 +163,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) { } _, err := w.Write(dataBytes) - if err != nil { - t.Errorf("Write error: %s", err) - } + checks.NoError(t, err, "Write error") })) defer server.Close() @@ -196,15 +191,12 @@ func TestCreateChatCompletionStreamError(t *testing.T) { } stream, err := client.CreateChatCompletionStream(ctx, request) - if err != nil { - t.Errorf("CreateCompletionStream returned error: %v", err) - } + checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() _, streamErr := stream.Recv() - if streamErr == nil { - t.Errorf("stream.Recv() did not return error") - } + checks.HasError(t, streamErr, "stream.Recv() did not return error") + var apiErr *APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError") diff --git a/chat_test.go b/chat_test.go index 2d569a4..ce302a6 100644 --- a/chat_test.go +++ b/chat_test.go @@ -3,10 +3,10 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" "context" "encoding/json" - "errors" "fmt" "io" "net/http" @@ -33,9 +33,8 @@ func TestChatCompletionsWrongModel(t *testing.T) { }, } _, err := client.CreateChatCompletion(ctx, req) - if !errors.Is(err, ErrChatCompletionInvalidModel) { - t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err) - } + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, ErrChatCompletionInvalidModel, msg) } func TestChatCompletionsWithStream(t *testing.T) { @@ -48,9 +47,7 @@ func TestChatCompletionsWithStream(t *testing.T) { Stream: true, } _, err := client.CreateChatCompletion(ctx, req) - if !errors.Is(err, ErrChatCompletionStreamNotSupported) { - t.Fatalf("CreateChatCompletion didn't return ErrChatCompletionStreamNotSupported error") - } + checks.ErrorIs(t, err, ErrChatCompletionStreamNotSupported, "unexpected error") } // TestCompletions Tests the completions endpoint of the API using the mocked server. @@ -79,9 +76,7 @@ func TestChatCompletions(t *testing.T) { }, } _, err = client.CreateChatCompletion(ctx, req) - if err != nil { - t.Fatalf("CreateChatCompletion error: %v", err) - } + checks.NoError(t, err, "CreateChatCompletion error") } // handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server. diff --git a/completion_test.go b/completion_test.go index daa02e3..ce95faf 100644 --- a/completion_test.go +++ b/completion_test.go @@ -3,6 +3,7 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" "context" "encoding/json" @@ -66,9 +67,7 @@ func TestCompletions(t *testing.T) { } req.Prompt = "Lorem ipsum" _, err = client.CreateCompletion(ctx, req) - if err != nil { - t.Fatalf("CreateCompletion error: %v", err) - } + checks.NoError(t, err, "CreateCompletion error") } // handleCompletionEndpoint Handles the completion endpoint by the test server. diff --git a/edits_test.go b/edits_test.go index 6a16f7c..fa6c128 100644 --- a/edits_test.go +++ b/edits_test.go @@ -3,6 +3,7 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" "context" "encoding/json" @@ -40,9 +41,7 @@ func TestEdits(t *testing.T) { N: 3, } response, err := client.Edits(ctx, editReq) - if err != nil { - t.Fatalf("Edits error: %v", err) - } + checks.NoError(t, err, "Edits error") if len(response.Choices) != editReq.N { t.Fatalf("edits does not properly return the correct number of choices") } diff --git a/embeddings_test.go b/embeddings_test.go index 2aa48c5..0259cea 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -2,6 +2,7 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" "bytes" "encoding/json" @@ -38,9 +39,7 @@ func TestEmbedding(t *testing.T) { // marshal embeddingReq to JSON and confirm that the model field equals // the AdaSearchQuery type marshaled, err := json.Marshal(embeddingReq) - if err != nil { - t.Fatalf("Could not marshal embedding request: %v", err) - } + checks.NoError(t, err, "Could not marshal embedding request") if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { t.Fatalf("Expected embedding request to contain model field") } diff --git a/error_accumulator_test.go b/error_accumulator_test.go index 4dabc1e..637bf36 100644 --- a/error_accumulator_test.go +++ b/error_accumulator_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" ) var ( @@ -81,16 +82,13 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) { ctx := context.Background() stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) - if err != nil { - t.Fatal(err) - } + checks.NoError(t, err) + stream.errAccumulator = &defaultErrorAccumulator{ buffer: &failingErrorBuffer{}, unmarshaler: &jsonUnmarshaler{}, } _, err = stream.Recv() - if !errors.Is(err, errTestErrorAccumulatorWriteFailed) { - t.Fatalf("Did not return error when write failed: %v", err) - } + checks.ErrorIs(t, err, errTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error()) } diff --git a/files_test.go b/files_test.go index 6a78ce1..3e8dfc4 100644 --- a/files_test.go +++ b/files_test.go @@ -3,6 +3,7 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" "context" "encoding/json" @@ -33,9 +34,7 @@ func TestFileUpload(t *testing.T) { Purpose: "fine-tune", } _, err = client.CreateFile(ctx, req) - if err != nil { - t.Fatalf("CreateFile error: %v", err) - } + checks.NoError(t, err, "CreateFile erro") } // handleCreateFile Handles the images endpoint by the test server. diff --git a/fine_tunes_test.go b/fine_tunes_test.go index 1f6f967..c602549 100644 --- a/fine_tunes_test.go +++ b/fine_tunes_test.go @@ -3,6 +3,7 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" "context" "encoding/json" @@ -70,32 +71,20 @@ func TestFineTunes(t *testing.T) { ctx := context.Background() _, err = client.ListFineTunes(ctx) - if err != nil { - t.Fatalf("ListFineTunes error: %v", err) - } + checks.NoError(t, err, "ListFineTunes error") _, err = client.CreateFineTune(ctx, FineTuneRequest{}) - if err != nil { - t.Fatalf("CreateFineTune error: %v", err) - } + checks.NoError(t, err, "CreateFineTune error") _, err = client.CancelFineTune(ctx, testFineTuneID) - if err != nil { - t.Fatalf("CancelFineTune error: %v", err) - } + checks.NoError(t, err, "CancelFineTune error") _, err = client.GetFineTune(ctx, testFineTuneID) - if err != nil { - t.Fatalf("GetFineTune error: %v", err) - } + checks.NoError(t, err, "GetFineTune error") _, err = client.DeleteFineTune(ctx, testFineTuneID) - if err != nil { - t.Fatalf("DeleteFineTune error: %v", err) - } + checks.NoError(t, err, "DeleteFineTune error") _, err = client.ListFineTuneEvents(ctx, testFineTuneID) - if err != nil { - t.Fatalf("ListFineTuneEvents error: %v", err) - } + checks.NoError(t, err, "ListFineTuneEvents error") } diff --git a/image_test.go b/image_test.go index b7949c8..9917b78 100644 --- a/image_test.go +++ b/image_test.go @@ -3,6 +3,7 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" "context" "encoding/json" @@ -31,9 +32,7 @@ func TestImages(t *testing.T) { req := ImageRequest{} req.Prompt = "Lorem ipsum" _, err = client.CreateImage(ctx, req) - if err != nil { - t.Fatalf("CreateImage error: %v", err) - } + checks.NoError(t, err, "CreateImage error") } // handleImageEndpoint Handles the images endpoint by the test server. @@ -127,9 +126,7 @@ func TestImageEdit(t *testing.T) { Size: CreateImageSize1024x1024, } _, err = client.CreateEditImage(ctx, req) - if err != nil { - t.Fatalf("CreateImage error: %v", err) - } + checks.NoError(t, err, "CreateImage error") } func TestImageEditWithoutMask(t *testing.T) { @@ -164,9 +161,7 @@ func TestImageEditWithoutMask(t *testing.T) { Size: CreateImageSize1024x1024, } _, err = client.CreateEditImage(ctx, req) - if err != nil { - t.Fatalf("CreateImage error: %v", err) - } + checks.NoError(t, err, "CreateImage error") } // handleEditImageEndpoint Handles the images endpoint by the test server. @@ -231,9 +226,7 @@ func TestImageVariation(t *testing.T) { Size: CreateImageSize1024x1024, } _, err = client.CreateVariImage(ctx, req) - if err != nil { - t.Fatalf("CreateImage error: %v", err) - } + checks.NoError(t, err, "CreateImage error") } // handleVariateImageEndpoint Handles the images endpoint by the test server. diff --git a/internal/test/checks/checks.go b/internal/test/checks/checks.go new file mode 100644 index 0000000..7133691 --- /dev/null +++ b/internal/test/checks/checks.go @@ -0,0 +1,48 @@ +package checks + +import ( + "errors" + "testing" +) + +func NoError(t *testing.T, err error, message ...string) { + t.Helper() + if err != nil { + t.Error(err, message) + } +} + +func HasError(t *testing.T, err error, message ...string) { + t.Helper() + if err == nil { + t.Error(err, message) + } +} + +func ErrorIs(t *testing.T, err, target error, msg ...string) { + t.Helper() + if !errors.Is(err, target) { + t.Fatal(msg) + } +} + +func ErrorIsF(t *testing.T, err, target error, format string, msg ...string) { + t.Helper() + if !errors.Is(err, target) { + t.Fatalf(format, msg) + } +} + +func ErrorIsNot(t *testing.T, err, target error, msg ...string) { + t.Helper() + if errors.Is(err, target) { + t.Fatal(msg) + } +} + +func ErrorIsNotf(t *testing.T, err, target error, format string, msg ...string) { + t.Helper() + if errors.Is(err, target) { + t.Fatalf(format, msg) + } +} diff --git a/models_test.go b/models_test.go index 972a5fe..dad59be 100644 --- a/models_test.go +++ b/models_test.go @@ -3,6 +3,7 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" "context" "encoding/json" @@ -27,9 +28,7 @@ func TestListModels(t *testing.T) { ctx := context.Background() _, err = client.ListModels(ctx) - if err != nil { - t.Fatalf("ListModels error: %v", err) - } + checks.NoError(t, err, "ListModels error") } // handleModelsEndpoint Handles the models endpoint by the test server. diff --git a/moderation_test.go b/moderation_test.go index f501245..3535bc8 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -3,6 +3,7 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" "context" "encoding/json" @@ -37,9 +38,7 @@ func TestModerations(t *testing.T) { Input: "I want to kill them.", } _, err = client.Moderations(ctx, moderationReq) - if err != nil { - t.Fatalf("Moderation error: %v", err) - } + checks.NoError(t, err, "Moderation error") } // handleModerationEndpoint Handles the moderation endpoint by the test server. diff --git a/stream_test.go b/stream_test.go index 7d01ebd..a80504d 100644 --- a/stream_test.go +++ b/stream_test.go @@ -3,6 +3,7 @@ package openai_test import ( . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" "context" "errors" @@ -49,9 +50,7 @@ func TestCreateCompletionStream(t *testing.T) { dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) _, err := w.Write(dataBytes) - if err != nil { - t.Errorf("Write error: %s", err) - } + checks.NoError(t, err, "Write error") })) defer server.Close() @@ -74,9 +73,7 @@ func TestCreateCompletionStream(t *testing.T) { } stream, err := client.CreateCompletionStream(ctx, request) - if err != nil { - t.Errorf("CreateCompletionStream returned error: %v", err) - } + checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() expectedResponses := []CompletionResponse{ @@ -138,9 +135,7 @@ func TestCreateCompletionStreamError(t *testing.T) { } _, err := w.Write(dataBytes) - if err != nil { - t.Errorf("Write error: %s", err) - } + checks.NoError(t, err, "Write error") })) defer server.Close() @@ -163,15 +158,12 @@ func TestCreateCompletionStreamError(t *testing.T) { } stream, err := client.CreateCompletionStream(ctx, request) - if err != nil { - t.Errorf("CreateCompletionStream returned error: %v", err) - } + checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() _, streamErr := stream.Recv() - if streamErr == nil { - t.Errorf("stream.Recv() did not return error") - } + checks.HasError(t, streamErr, "stream.Recv() did not return error") + var apiErr *APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError")