Move form_builder into internal pkg. (#311)

* Move form_uilder into internal pkg.

* Fix import of audio.go

* Reorganize.

* Fix import.

* Fix

---------

Co-authored-by: JoyShi <joy.shi@sap.com>
This commit is contained in:
JoyShi
2023-05-17 04:38:09 +08:00
committed by GitHub
parent 83d03fca52
commit 21eef5bc8d
9 changed files with 96 additions and 90 deletions

View File

@@ -6,6 +6,8 @@ import (
"fmt"
"net/http"
"os"
utils "github.com/sashabaranov/go-openai/internal"
)
// Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI.
@@ -72,7 +74,7 @@ func (c *Client) callAudioAPI(
if err != nil {
return AudioResponse{}, err
}
req.Header.Add("Content-Type", builder.formDataContentType())
req.Header.Add("Content-Type", builder.FormDataContentType())
if request.HasJSONResponse() {
err = c.sendRequest(req, &response)
@@ -92,26 +94,26 @@ func (r AudioRequest) HasJSONResponse() bool {
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
// audio processing.
func audioMultipartForm(request AudioRequest, b formBuilder) error {
func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
f, err := os.Open(request.FilePath)
if err != nil {
return fmt.Errorf("opening audio file: %w", err)
}
defer f.Close()
err = b.createFormFile("file", f)
err = b.CreateFormFile("file", f)
if err != nil {
return fmt.Errorf("creating form file: %w", err)
}
err = b.writeField("model", request.Model)
err = b.WriteField("model", request.Model)
if err != nil {
return fmt.Errorf("writing model name: %w", err)
}
// Create a form field for the prompt (if provided)
if request.Prompt != "" {
err = b.writeField("prompt", request.Prompt)
err = b.WriteField("prompt", request.Prompt)
if err != nil {
return fmt.Errorf("writing prompt: %w", err)
}
@@ -119,7 +121,7 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error {
// Create a form field for the format (if provided)
if request.Format != "" {
err = b.writeField("response_format", string(request.Format))
err = b.WriteField("response_format", string(request.Format))
if err != nil {
return fmt.Errorf("writing format: %w", err)
}
@@ -127,7 +129,7 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error {
// Create a form field for the temperature (if provided)
if request.Temperature != 0 {
err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature))
err = b.WriteField("temperature", fmt.Sprintf("%.2f", request.Temperature))
if err != nil {
return fmt.Errorf("writing temperature: %w", err)
}
@@ -135,12 +137,12 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error {
// Create a form field for the language (if provided)
if request.Language != "" {
err = b.writeField("language", request.Language)
err = b.WriteField("language", request.Language)
if err != nil {
return fmt.Errorf("writing language: %w", err)
}
}
// Close the multipart writer
return b.close()
return b.Close()
}

View File

@@ -7,6 +7,8 @@ import (
"io"
"net/http"
"strings"
utils "github.com/sashabaranov/go-openai/internal"
)
// Client is OpenAI GPT-3 API client.
@@ -14,7 +16,7 @@ type Client struct {
config ClientConfig
requestBuilder requestBuilder
createFormBuilder func(io.Writer) formBuilder
createFormBuilder func(io.Writer) utils.FormBuilder
}
// NewClient creates new OpenAI API client.
@@ -28,8 +30,8 @@ func NewClientWithConfig(config ClientConfig) *Client {
return &Client{
config: config,
requestBuilder: newRequestBuilder(),
createFormBuilder: func(body io.Writer) formBuilder {
return newFormBuilder(body)
createFormBuilder: func(body io.Writer) utils.FormBuilder {
return utils.NewFormBuilder(body)
},
}
}

View File

@@ -36,7 +36,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
var b bytes.Buffer
builder := c.createFormBuilder(&b)
err = builder.writeField("purpose", request.Purpose)
err = builder.WriteField("purpose", request.Purpose)
if err != nil {
return
}
@@ -46,12 +46,12 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
return
}
err = builder.createFormFile("file", fileData)
err = builder.CreateFormFile("file", fileData)
if err != nil {
return
}
err = builder.close()
err = builder.Close()
if err != nil {
return
}
@@ -61,7 +61,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
return
}
req.Header.Set("Content-Type", builder.formDataContentType())
req.Header.Set("Content-Type", builder.FormDataContentType())
err = c.sendRequest(req, &file)

View File

@@ -1,6 +1,7 @@
package openai //nolint:testpackage // testing private field
import (
. "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
@@ -85,7 +86,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) {
config.BaseURL = ""
client := NewClientWithConfig(config)
mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) formBuilder {
client.createFormBuilder = func(io.Writer) FormBuilder {
return mockBuilder
}

View File

@@ -1,49 +0,0 @@
package openai
import (
"io"
"mime/multipart"
"os"
)
type formBuilder interface {
createFormFile(fieldname string, file *os.File) error
writeField(fieldname, value string) error
close() error
formDataContentType() string
}
type defaultFormBuilder struct {
writer *multipart.Writer
}
func newFormBuilder(body io.Writer) *defaultFormBuilder {
return &defaultFormBuilder{
writer: multipart.NewWriter(body),
}
}
func (fb *defaultFormBuilder) createFormFile(fieldname string, file *os.File) error {
fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name())
if err != nil {
return err
}
_, err = io.Copy(fieldWriter, file)
if err != nil {
return err
}
return nil
}
func (fb *defaultFormBuilder) writeField(fieldname, value string) error {
return fb.writer.WriteField(fieldname, value)
}
func (fb *defaultFormBuilder) close() error {
return fb.writer.Close()
}
func (fb *defaultFormBuilder) formDataContentType() string {
return fb.writer.FormDataContentType()
}

View File

@@ -69,40 +69,40 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
builder := c.createFormBuilder(body)
// image
err = builder.createFormFile("image", request.Image)
err = builder.CreateFormFile("image", request.Image)
if err != nil {
return
}
// mask, it is optional
if request.Mask != nil {
err = builder.createFormFile("mask", request.Mask)
err = builder.CreateFormFile("mask", request.Mask)
if err != nil {
return
}
}
err = builder.writeField("prompt", request.Prompt)
err = builder.WriteField("prompt", request.Prompt)
if err != nil {
return
}
err = builder.writeField("n", strconv.Itoa(request.N))
err = builder.WriteField("n", strconv.Itoa(request.N))
if err != nil {
return
}
err = builder.writeField("size", request.Size)
err = builder.WriteField("size", request.Size)
if err != nil {
return
}
err = builder.writeField("response_format", request.ResponseFormat)
err = builder.WriteField("response_format", request.ResponseFormat)
if err != nil {
return
}
err = builder.close()
err = builder.Close()
if err != nil {
return
}
@@ -113,7 +113,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
return
}
req.Header.Set("Content-Type", builder.formDataContentType())
req.Header.Set("Content-Type", builder.FormDataContentType())
err = c.sendRequest(req, &response)
return
}
@@ -133,27 +133,27 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
builder := c.createFormBuilder(body)
// image
err = builder.createFormFile("image", request.Image)
err = builder.CreateFormFile("image", request.Image)
if err != nil {
return
}
err = builder.writeField("n", strconv.Itoa(request.N))
err = builder.WriteField("n", strconv.Itoa(request.N))
if err != nil {
return
}
err = builder.writeField("size", request.Size)
err = builder.WriteField("size", request.Size)
if err != nil {
return
}
err = builder.writeField("response_format", request.ResponseFormat)
err = builder.WriteField("response_format", request.ResponseFormat)
if err != nil {
return
}
err = builder.close()
err = builder.Close()
if err != nil {
return
}
@@ -165,7 +165,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
return
}
req.Header.Set("Content-Type", builder.formDataContentType())
req.Header.Set("Content-Type", builder.FormDataContentType())
err = c.sendRequest(req, &response)
return
}

View File

@@ -1,6 +1,7 @@
package openai //nolint:testpackage // testing private field
import (
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
@@ -268,19 +269,19 @@ type mockFormBuilder struct {
mockClose func() error
}
func (fb *mockFormBuilder) createFormFile(fieldname string, file *os.File) error {
func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
return fb.mockCreateFormFile(fieldname, file)
}
func (fb *mockFormBuilder) writeField(fieldname, value string) error {
func (fb *mockFormBuilder) WriteField(fieldname, value string) error {
return fb.mockWriteField(fieldname, value)
}
func (fb *mockFormBuilder) close() error {
func (fb *mockFormBuilder) Close() error {
return fb.mockClose()
}
func (fb *mockFormBuilder) formDataContentType() string {
func (fb *mockFormBuilder) FormDataContentType() string {
return ""
}
@@ -290,7 +291,7 @@ func TestImageFormBuilderFailures(t *testing.T) {
client := NewClientWithConfig(config)
mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) formBuilder {
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
return mockBuilder
}
ctx := context.Background()
@@ -357,7 +358,7 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
client := NewClientWithConfig(config)
mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) formBuilder {
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
return mockBuilder
}
ctx := context.Background()

49
internal/form_builder.go Normal file
View File

@@ -0,0 +1,49 @@
package openai
import (
"io"
"mime/multipart"
"os"
)
type FormBuilder interface {
CreateFormFile(fieldname string, file *os.File) error
WriteField(fieldname, value string) error
Close() error
FormDataContentType() string
}
type DefaultFormBuilder struct {
writer *multipart.Writer
}
func NewFormBuilder(body io.Writer) *DefaultFormBuilder {
return &DefaultFormBuilder{
writer: multipart.NewWriter(body),
}
}
func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name())
if err != nil {
return err
}
_, err = io.Copy(fieldWriter, file)
if err != nil {
return err
}
return nil
}
func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error {
return fb.writer.WriteField(fieldname, value)
}
func (fb *DefaultFormBuilder) Close() error {
return fb.writer.Close()
}
func (fb *DefaultFormBuilder) FormDataContentType() string {
return fb.writer.FormDataContentType()
}

View File

@@ -30,8 +30,8 @@ func TestFormBuilderWithFailingWriter(t *testing.T) {
defer file.Close()
defer os.Remove(file.Name())
builder := newFormBuilder(&failingWriter{})
err = builder.createFormFile("file", file)
builder := NewFormBuilder(&failingWriter{})
err = builder.CreateFormFile("file", file)
checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails")
}
@@ -47,8 +47,8 @@ func TestFormBuilderWithClosedFile(t *testing.T) {
defer os.Remove(file.Name())
body := &bytes.Buffer{}
builder := newFormBuilder(body)
err = builder.createFormFile("file", file)
builder := NewFormBuilder(body)
err = builder.CreateFormFile("file", file)
checks.HasError(t, err, "formbuilder should return error if file is closed")
checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed")
}