* Implement optional io.Reader in AudioRequest (#303) (#265) * Fix err shadowing * Add test to cover AudioRequest io.Reader usage * Add additional test cases to cover AudioRequest io.Reader usage * Add test to cover opening the file specified in an AudioRequest
This commit is contained in:
45
audio.go
45
audio.go
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
@@ -27,8 +28,14 @@ const (
|
|||||||
// AudioRequest represents a request structure for audio API.
|
// AudioRequest represents a request structure for audio API.
|
||||||
// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
|
// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
|
||||||
type AudioRequest struct {
|
type AudioRequest struct {
|
||||||
Model string
|
Model string
|
||||||
FilePath string
|
|
||||||
|
// FilePath is either an existing file in your filesystem or a filename representing the contents of Reader.
|
||||||
|
FilePath string
|
||||||
|
|
||||||
|
// Reader is an optional io.Reader when you do not want to use an existing file.
|
||||||
|
Reader io.Reader
|
||||||
|
|
||||||
Prompt string // For translation, it should be in English
|
Prompt string // For translation, it should be in English
|
||||||
Temperature float32
|
Temperature float32
|
||||||
Language string // For translation, just do not use it. It seems "en" works, not confirmed...
|
Language string // For translation, just do not use it. It seems "en" works, not confirmed...
|
||||||
@@ -95,15 +102,9 @@ func (r AudioRequest) HasJSONResponse() bool {
|
|||||||
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
|
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
|
||||||
// audio processing.
|
// audio processing.
|
||||||
func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
|
func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
|
||||||
f, err := os.Open(request.FilePath)
|
err := createFileField(request, b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("opening audio file: %w", err)
|
return err
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
err = b.CreateFormFile("file", f)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("creating form file: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = b.WriteField("model", request.Model)
|
err = b.WriteField("model", request.Model)
|
||||||
@@ -146,3 +147,27 @@ func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
|
|||||||
// Close the multipart writer
|
// Close the multipart writer
|
||||||
return b.Close()
|
return b.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createFileField creates the "file" form field from either an existing file or by using the reader.
|
||||||
|
func createFileField(request AudioRequest, b utils.FormBuilder) error {
|
||||||
|
if request.Reader != nil {
|
||||||
|
err := b.CreateFormFileReader("file", request.Reader, request.FilePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating form using reader: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Open(request.FilePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("opening audio file: %w", err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
err = b.CreateFormFile("file", f)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating form file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package openai //nolint:testpackage // testing private field
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -11,12 +12,10 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
|
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
|
||||||
@@ -65,6 +64,16 @@ func TestAudio(t *testing.T) {
|
|||||||
_, err = tc.createFn(ctx, req)
|
_, err = tc.createFn(ctx, req)
|
||||||
checks.NoError(t, err, "audio API error")
|
checks.NoError(t, err, "audio API error")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run(tc.name+" (with reader)", func(t *testing.T) {
|
||||||
|
req := AudioRequest{
|
||||||
|
FilePath: "fake.webm",
|
||||||
|
Reader: bytes.NewBuffer([]byte(`some webm binary data`)),
|
||||||
|
Model: "whisper-3",
|
||||||
|
}
|
||||||
|
_, err = tc.createFn(ctx, req)
|
||||||
|
checks.NoError(t, err, "audio API error")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -213,3 +222,54 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
|
|||||||
checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails")
|
checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateFileField(t *testing.T) {
|
||||||
|
t.Run("createFileField failing file", func(t *testing.T) {
|
||||||
|
dir, cleanup := test.CreateTestDirectory(t)
|
||||||
|
defer cleanup()
|
||||||
|
path := filepath.Join(dir, "fake.mp3")
|
||||||
|
test.CreateTestFile(t, path)
|
||||||
|
|
||||||
|
req := AudioRequest{
|
||||||
|
FilePath: path,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockFailedErr := fmt.Errorf("mock form builder fail")
|
||||||
|
mockBuilder := &mockFormBuilder{
|
||||||
|
mockCreateFormFile: func(string, *os.File) error {
|
||||||
|
return mockFailedErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := createFileField(req, mockBuilder)
|
||||||
|
checks.ErrorIs(t, err, mockFailedErr, "createFileField using a file should return error if form builder fails")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("createFileField failing reader", func(t *testing.T) {
|
||||||
|
req := AudioRequest{
|
||||||
|
FilePath: "test.wav",
|
||||||
|
Reader: bytes.NewBuffer([]byte(`wav test contents`)),
|
||||||
|
}
|
||||||
|
|
||||||
|
mockFailedErr := fmt.Errorf("mock form builder fail")
|
||||||
|
mockBuilder := &mockFormBuilder{
|
||||||
|
mockCreateFormFileReader: func(string, io.Reader, string) error {
|
||||||
|
return mockFailedErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := createFileField(req, mockBuilder)
|
||||||
|
checks.ErrorIs(t, err, mockFailedErr, "createFileField using a reader should return error if form builder fails")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("createFileField failing open", func(t *testing.T) {
|
||||||
|
req := AudioRequest{
|
||||||
|
FilePath: "non_existing_file.wav",
|
||||||
|
}
|
||||||
|
|
||||||
|
mockBuilder := &mockFormBuilder{}
|
||||||
|
|
||||||
|
err := createFileField(req, mockBuilder)
|
||||||
|
checks.HasError(t, err, "createFileField using file should return error when open file fails")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -264,15 +264,20 @@ func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type mockFormBuilder struct {
|
type mockFormBuilder struct {
|
||||||
mockCreateFormFile func(string, *os.File) error
|
mockCreateFormFile func(string, *os.File) error
|
||||||
mockWriteField func(string, string) error
|
mockCreateFormFileReader func(string, io.Reader, string) error
|
||||||
mockClose func() error
|
mockWriteField func(string, string) error
|
||||||
|
mockClose func() error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
|
func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
|
||||||
return fb.mockCreateFormFile(fieldname, file)
|
return fb.mockCreateFormFile(fieldname, file)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (fb *mockFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
|
||||||
|
return fb.mockCreateFormFileReader(fieldname, r, filename)
|
||||||
|
}
|
||||||
|
|
||||||
func (fb *mockFormBuilder) WriteField(fieldname, value string) error {
|
func (fb *mockFormBuilder) WriteField(fieldname, value string) error {
|
||||||
return fb.mockWriteField(fieldname, value)
|
return fb.mockWriteField(fieldname, value)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"os"
|
"os"
|
||||||
|
"path"
|
||||||
)
|
)
|
||||||
|
|
||||||
type FormBuilder interface {
|
type FormBuilder interface {
|
||||||
CreateFormFile(fieldname string, file *os.File) error
|
CreateFormFile(fieldname string, file *os.File) error
|
||||||
|
CreateFormFileReader(fieldname string, r io.Reader, filename string) error
|
||||||
WriteField(fieldname, value string) error
|
WriteField(fieldname, value string) error
|
||||||
Close() error
|
Close() error
|
||||||
FormDataContentType() string
|
FormDataContentType() string
|
||||||
@@ -24,15 +27,28 @@ func NewFormBuilder(body io.Writer) *DefaultFormBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
|
func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
|
||||||
fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name())
|
return fb.createFormFile(fieldname, file, file.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
|
||||||
|
return fb.createFormFile(fieldname, r, path.Base(filename))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {
|
||||||
|
if filename == "" {
|
||||||
|
return fmt.Errorf("filename cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldWriter, err := fb.writer.CreateFormFile(fieldname, filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = io.Copy(fieldWriter, file)
|
_, err = io.Copy(fieldWriter, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user