From 7a2915a37dae714f40a4b5575fbf98430fe1d6aa Mon Sep 17 00:00:00 2001 From: Oleksandr Redko Date: Fri, 31 Jan 2025 20:55:41 +0200 Subject: [PATCH] Simplify tests with T.TempDir (#929) --- .golangci.yml | 1 + audio_api_test.go | 10 ++------- audio_test.go | 8 ++----- image_api_test.go | 42 +++++++++++------------------------ internal/form_builder_test.go | 17 ++++---------- internal/test/helpers.go | 10 --------- openai_test.go | 2 +- speech_test.go | 4 +--- 8 files changed, 24 insertions(+), 70 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 724cb73..9d22d9b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -206,6 +206,7 @@ linters: - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters + - usetesting # Reports uses of functions with replacement inside the testing package - wastedassign # wastedassign finds wasted assignment statements. - whitespace # Tool for detection of leading and trailing whitespace ## you may want to enable diff --git a/audio_api_test.go b/audio_api_test.go index c245984..6c6a356 100644 --- a/audio_api_test.go +++ b/audio_api_test.go @@ -40,12 +40,9 @@ func TestAudio(t *testing.T) { ctx := context.Background() - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) req := openai.AudioRequest{ @@ -90,12 +87,9 @@ func TestAudioWithOptionalArgs(t *testing.T) { ctx := context.Background() - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) req := openai.AudioRequest{ diff --git a/audio_test.go b/audio_test.go index 235931f..9f32d54 100644 --- a/audio_test.go +++ b/audio_test.go @@ -13,9 +13,7 @@ import ( ) func TestAudioWithFailingFormBuilder(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) req := AudioRequest{ @@ -63,9 +61,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { func TestCreateFileField(t *testing.T) { t.Run("createFileField failing file", func(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) req := AudioRequest{ diff --git a/image_api_test.go b/image_api_test.go index 48416b1..f6057b7 100644 --- a/image_api_test.go +++ b/image_api_test.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "os" + "path/filepath" "testing" "time" @@ -86,24 +87,17 @@ func TestImageEdit(t *testing.T) { defer teardown() server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) - origin, err := os.Create("image.png") + origin, err := os.Create(filepath.Join(t.TempDir(), "image.png")) if err != nil { - t.Error("open origin file error") - return + t.Fatalf("open origin file error: %v", err) } + defer origin.Close() - mask, err := os.Create("mask.png") + mask, err := os.Create(filepath.Join(t.TempDir(), "mask.png")) if err != nil { - t.Error("open mask file error") - return + t.Fatalf("open mask file error: %v", err) } - - defer func() { - mask.Close() - origin.Close() - os.Remove("mask.png") - os.Remove("image.png") - }() + defer mask.Close() _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ Image: origin, @@ -121,16 +115,11 @@ func TestImageEditWithoutMask(t *testing.T) { defer teardown() server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) - origin, err := os.Create("image.png") + origin, err := os.Create(filepath.Join(t.TempDir(), "image.png")) if err != nil { - t.Error("open origin file error") - return + t.Fatalf("open origin file error: %v", err) } - - defer func() { - origin.Close() - os.Remove("image.png") - }() + defer origin.Close() _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ Image: origin, @@ -178,16 +167,11 @@ func TestImageVariation(t *testing.T) { defer teardown() server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint) - origin, err := os.Create("image.png") + origin, err := os.Create(filepath.Join(t.TempDir(), "image.png")) if err != nil { - t.Error("open origin file error") - return + t.Fatalf("open origin file error: %v", err) } - - defer func() { - origin.Close() - os.Remove("image.png") - }() + defer origin.Close() _, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{ Image: origin, diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index d3faf99..8df989e 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -1,7 +1,6 @@ package openai //nolint:testpackage // testing private field import ( - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "bytes" @@ -20,15 +19,11 @@ func (*failingWriter) Write([]byte) (int, error) { } func TestFormBuilderWithFailingWriter(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - - file, err := os.CreateTemp(dir, "") + file, err := os.CreateTemp(t.TempDir(), "") if err != nil { - t.Errorf("Error creating tmp file: %v", err) + t.Fatalf("Error creating tmp file: %v", err) } defer file.Close() - defer os.Remove(file.Name()) builder := NewFormBuilder(&failingWriter{}) err = builder.CreateFormFile("file", file) @@ -36,15 +31,11 @@ func TestFormBuilderWithFailingWriter(t *testing.T) { } func TestFormBuilderWithClosedFile(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - - file, err := os.CreateTemp(dir, "") + file, err := os.CreateTemp(t.TempDir(), "") if err != nil { - t.Errorf("Error creating tmp file: %v", err) + t.Fatalf("Error creating tmp file: %v", err) } file.Close() - defer os.Remove(file.Name()) body := &bytes.Buffer{} builder := NewFormBuilder(body) diff --git a/internal/test/helpers.go b/internal/test/helpers.go index 0e63ae8..dc5fa66 100644 --- a/internal/test/helpers.go +++ b/internal/test/helpers.go @@ -19,16 +19,6 @@ func CreateTestFile(t *testing.T, path string) { file.Close() } -// CreateTestDirectory creates a temporary folder which will be deleted when cleanup is called. -func CreateTestDirectory(t *testing.T) (path string, cleanup func()) { - t.Helper() - - path, err := os.MkdirTemp(os.TempDir(), "") - checks.NoError(t, err) - - return path, func() { os.RemoveAll(path) } -} - // TokenRoundTripper is a struct that implements the RoundTripper // interface, specifically to handle the authentication token by adding a token // to the request header. We need this because the API requires that each diff --git a/openai_test.go b/openai_test.go index 729d888..48a00b9 100644 --- a/openai_test.go +++ b/openai_test.go @@ -31,7 +31,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea // This function approximates based on the rule of thumb stated by OpenAI: // https://beta.openai.com/tokenizer // -// TODO: implement an actual tokenizer for GPT-3 and Codex (once available) +// TODO: implement an actual tokenizer for GPT-3 and Codex (once available). func numTokens(s string) int { return int(float32(len(s)) / 4) } diff --git a/speech_test.go b/speech_test.go index f1e405c..67a3fea 100644 --- a/speech_test.go +++ b/speech_test.go @@ -21,10 +21,8 @@ func TestSpeechIntegration(t *testing.T) { defer teardown() server.RegisterHandler("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) { - dir, cleanup := test.CreateTestDirectory(t) - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) - defer cleanup() // audio endpoints only accept POST requests if r.Method != "POST" {