From fa694c61c2196e471e7bedc689b546aa52caf527 Mon Sep 17 00:00:00 2001 From: Mariano Darc Date: Mon, 5 Jun 2023 08:07:13 +0200 Subject: [PATCH] Implement optional io.Reader in AudioRequest (#303) (#265) (#331) * Implement optional io.Reader in AudioRequest (#303) (#265) * Fix err shadowing * Add test to cover AudioRequest io.Reader usage * Add additional test cases to cover AudioRequest io.Reader usage * Add test to cover opening the file specified in an AudioRequest --- audio.go | 45 +++++++++++++++++++++------ audio_test.go | 66 ++++++++++++++++++++++++++++++++++++++-- image_test.go | 11 +++++-- internal/form_builder.go | 20 ++++++++++-- 4 files changed, 124 insertions(+), 18 deletions(-) diff --git a/audio.go b/audio.go index bf23653..20e865f 100644 --- a/audio.go +++ b/audio.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "io" "net/http" "os" @@ -27,8 +28,14 @@ const ( // AudioRequest represents a request structure for audio API. // ResponseFormat is not supported for now. We only return JSON text, which may be sufficient. type AudioRequest struct { - Model string - FilePath string + Model string + + // FilePath is either an existing file in your filesystem or a filename representing the contents of Reader. + FilePath string + + // Reader is an optional io.Reader when you do not want to use an existing file. + Reader io.Reader + Prompt string // For translation, it should be in English Temperature float32 Language string // For translation, just do not use it. It seems "en" works, not confirmed... @@ -95,15 +102,9 @@ func (r AudioRequest) HasJSONResponse() bool { // audioMultipartForm creates a form with audio file contents and the name of the model to use for // audio processing. func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error { - f, err := os.Open(request.FilePath) + err := createFileField(request, b) if err != nil { - return fmt.Errorf("opening audio file: %w", err) - } - defer f.Close() - - err = b.CreateFormFile("file", f) - if err != nil { - return fmt.Errorf("creating form file: %w", err) + return err } err = b.WriteField("model", request.Model) @@ -146,3 +147,27 @@ func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error { // Close the multipart writer return b.Close() } + +// createFileField creates the "file" form field from either an existing file or by using the reader. +func createFileField(request AudioRequest, b utils.FormBuilder) error { + if request.Reader != nil { + err := b.CreateFormFileReader("file", request.Reader, request.FilePath) + if err != nil { + return fmt.Errorf("creating form using reader: %w", err) + } + return nil + } + + f, err := os.Open(request.FilePath) + if err != nil { + return fmt.Errorf("opening audio file: %w", err) + } + defer f.Close() + + err = b.CreateFormFile("file", f) + if err != nil { + return fmt.Errorf("creating form file: %w", err) + } + + return nil +} diff --git a/audio_test.go b/audio_test.go index daf51f2..6452e2e 100644 --- a/audio_test.go +++ b/audio_test.go @@ -2,6 +2,7 @@ package openai //nolint:testpackage // testing private field import ( "bytes" + "context" "errors" "fmt" "io" @@ -11,12 +12,10 @@ import ( "os" "path/filepath" "strings" + "testing" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" - - "context" - "testing" ) // TestAudio Tests the transcription and translation endpoints of the API using the mocked server. @@ -65,6 +64,16 @@ func TestAudio(t *testing.T) { _, err = tc.createFn(ctx, req) checks.NoError(t, err, "audio API error") }) + + t.Run(tc.name+" (with reader)", func(t *testing.T) { + req := AudioRequest{ + FilePath: "fake.webm", + Reader: bytes.NewBuffer([]byte(`some webm binary data`)), + Model: "whisper-3", + } + _, err = tc.createFn(ctx, req) + checks.NoError(t, err, "audio API error") + }) } } @@ -213,3 +222,54 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails") } } + +func TestCreateFileField(t *testing.T) { + t.Run("createFileField failing file", func(t *testing.T) { + dir, cleanup := test.CreateTestDirectory(t) + defer cleanup() + path := filepath.Join(dir, "fake.mp3") + test.CreateTestFile(t, path) + + req := AudioRequest{ + FilePath: path, + } + + mockFailedErr := fmt.Errorf("mock form builder fail") + mockBuilder := &mockFormBuilder{ + mockCreateFormFile: func(string, *os.File) error { + return mockFailedErr + }, + } + + err := createFileField(req, mockBuilder) + checks.ErrorIs(t, err, mockFailedErr, "createFileField using a file should return error if form builder fails") + }) + + t.Run("createFileField failing reader", func(t *testing.T) { + req := AudioRequest{ + FilePath: "test.wav", + Reader: bytes.NewBuffer([]byte(`wav test contents`)), + } + + mockFailedErr := fmt.Errorf("mock form builder fail") + mockBuilder := &mockFormBuilder{ + mockCreateFormFileReader: func(string, io.Reader, string) error { + return mockFailedErr + }, + } + + err := createFileField(req, mockBuilder) + checks.ErrorIs(t, err, mockFailedErr, "createFileField using a reader should return error if form builder fails") + }) + + t.Run("createFileField failing open", func(t *testing.T) { + req := AudioRequest{ + FilePath: "non_existing_file.wav", + } + + mockBuilder := &mockFormBuilder{} + + err := createFileField(req, mockBuilder) + checks.HasError(t, err, "createFileField using file should return error when open file fails") + }) +} diff --git a/image_test.go b/image_test.go index 5cf6a26..ca9faed 100644 --- a/image_test.go +++ b/image_test.go @@ -264,15 +264,20 @@ func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { } type mockFormBuilder struct { - mockCreateFormFile func(string, *os.File) error - mockWriteField func(string, string) error - mockClose func() error + mockCreateFormFile func(string, *os.File) error + mockCreateFormFileReader func(string, io.Reader, string) error + mockWriteField func(string, string) error + mockClose func() error } func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error { return fb.mockCreateFormFile(fieldname, file) } +func (fb *mockFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { + return fb.mockCreateFormFileReader(fieldname, r, filename) +} + func (fb *mockFormBuilder) WriteField(fieldname, value string) error { return fb.mockWriteField(fieldname, value) } diff --git a/internal/form_builder.go b/internal/form_builder.go index 359dd7e..2224fad 100644 --- a/internal/form_builder.go +++ b/internal/form_builder.go @@ -1,13 +1,16 @@ package openai import ( + "fmt" "io" "mime/multipart" "os" + "path" ) type FormBuilder interface { CreateFormFile(fieldname string, file *os.File) error + CreateFormFileReader(fieldname string, r io.Reader, filename string) error WriteField(fieldname, value string) error Close() error FormDataContentType() string @@ -24,15 +27,28 @@ func NewFormBuilder(body io.Writer) *DefaultFormBuilder { } func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error { - fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name()) + return fb.createFormFile(fieldname, file, file.Name()) +} + +func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { + return fb.createFormFile(fieldname, r, path.Base(filename)) +} + +func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error { + if filename == "" { + return fmt.Errorf("filename cannot be empty") + } + + fieldWriter, err := fb.writer.CreateFormFile(fieldname, filename) if err != nil { return err } - _, err = io.Copy(fieldWriter, file) + _, err = io.Copy(fieldWriter, r) if err != nil { return err } + return nil }