Move form_builder into internal pkg. (#311)
* Move form_uilder into internal pkg. * Fix import of audio.go * Reorganize. * Fix import. * Fix --------- Co-authored-by: JoyShi <joy.shi@sap.com>
This commit is contained in:
20
audio.go
20
audio.go
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
)
|
||||
|
||||
// Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI.
|
||||
@@ -72,7 +74,7 @@ func (c *Client) callAudioAPI(
|
||||
if err != nil {
|
||||
return AudioResponse{}, err
|
||||
}
|
||||
req.Header.Add("Content-Type", builder.formDataContentType())
|
||||
req.Header.Add("Content-Type", builder.FormDataContentType())
|
||||
|
||||
if request.HasJSONResponse() {
|
||||
err = c.sendRequest(req, &response)
|
||||
@@ -92,26 +94,26 @@ func (r AudioRequest) HasJSONResponse() bool {
|
||||
|
||||
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
|
||||
// audio processing.
|
||||
func audioMultipartForm(request AudioRequest, b formBuilder) error {
|
||||
func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
|
||||
f, err := os.Open(request.FilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening audio file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
err = b.createFormFile("file", f)
|
||||
err = b.CreateFormFile("file", f)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating form file: %w", err)
|
||||
}
|
||||
|
||||
err = b.writeField("model", request.Model)
|
||||
err = b.WriteField("model", request.Model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing model name: %w", err)
|
||||
}
|
||||
|
||||
// Create a form field for the prompt (if provided)
|
||||
if request.Prompt != "" {
|
||||
err = b.writeField("prompt", request.Prompt)
|
||||
err = b.WriteField("prompt", request.Prompt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing prompt: %w", err)
|
||||
}
|
||||
@@ -119,7 +121,7 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error {
|
||||
|
||||
// Create a form field for the format (if provided)
|
||||
if request.Format != "" {
|
||||
err = b.writeField("response_format", string(request.Format))
|
||||
err = b.WriteField("response_format", string(request.Format))
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing format: %w", err)
|
||||
}
|
||||
@@ -127,7 +129,7 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error {
|
||||
|
||||
// Create a form field for the temperature (if provided)
|
||||
if request.Temperature != 0 {
|
||||
err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature))
|
||||
err = b.WriteField("temperature", fmt.Sprintf("%.2f", request.Temperature))
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing temperature: %w", err)
|
||||
}
|
||||
@@ -135,12 +137,12 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error {
|
||||
|
||||
// Create a form field for the language (if provided)
|
||||
if request.Language != "" {
|
||||
err = b.writeField("language", request.Language)
|
||||
err = b.WriteField("language", request.Language)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing language: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Close the multipart writer
|
||||
return b.close()
|
||||
return b.Close()
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
)
|
||||
|
||||
// Client is OpenAI GPT-3 API client.
|
||||
@@ -14,7 +16,7 @@ type Client struct {
|
||||
config ClientConfig
|
||||
|
||||
requestBuilder requestBuilder
|
||||
createFormBuilder func(io.Writer) formBuilder
|
||||
createFormBuilder func(io.Writer) utils.FormBuilder
|
||||
}
|
||||
|
||||
// NewClient creates new OpenAI API client.
|
||||
@@ -28,8 +30,8 @@ func NewClientWithConfig(config ClientConfig) *Client {
|
||||
return &Client{
|
||||
config: config,
|
||||
requestBuilder: newRequestBuilder(),
|
||||
createFormBuilder: func(body io.Writer) formBuilder {
|
||||
return newFormBuilder(body)
|
||||
createFormBuilder: func(body io.Writer) utils.FormBuilder {
|
||||
return utils.NewFormBuilder(body)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
8
files.go
8
files.go
@@ -36,7 +36,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
|
||||
var b bytes.Buffer
|
||||
builder := c.createFormBuilder(&b)
|
||||
|
||||
err = builder.writeField("purpose", request.Purpose)
|
||||
err = builder.WriteField("purpose", request.Purpose)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -46,12 +46,12 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
|
||||
return
|
||||
}
|
||||
|
||||
err = builder.createFormFile("file", fileData)
|
||||
err = builder.CreateFormFile("file", fileData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = builder.close()
|
||||
err = builder.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -61,7 +61,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", builder.formDataContentType())
|
||||
req.Header.Set("Content-Type", builder.FormDataContentType())
|
||||
|
||||
err = c.sendRequest(req, &file)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package openai //nolint:testpackage // testing private field
|
||||
|
||||
import (
|
||||
. "github.com/sashabaranov/go-openai/internal"
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
|
||||
@@ -85,7 +86,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) {
|
||||
config.BaseURL = ""
|
||||
client := NewClientWithConfig(config)
|
||||
mockBuilder := &mockFormBuilder{}
|
||||
client.createFormBuilder = func(io.Writer) formBuilder {
|
||||
client.createFormBuilder = func(io.Writer) FormBuilder {
|
||||
return mockBuilder
|
||||
}
|
||||
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
)
|
||||
|
||||
type formBuilder interface {
|
||||
createFormFile(fieldname string, file *os.File) error
|
||||
writeField(fieldname, value string) error
|
||||
close() error
|
||||
formDataContentType() string
|
||||
}
|
||||
|
||||
type defaultFormBuilder struct {
|
||||
writer *multipart.Writer
|
||||
}
|
||||
|
||||
func newFormBuilder(body io.Writer) *defaultFormBuilder {
|
||||
return &defaultFormBuilder{
|
||||
writer: multipart.NewWriter(body),
|
||||
}
|
||||
}
|
||||
|
||||
func (fb *defaultFormBuilder) createFormFile(fieldname string, file *os.File) error {
|
||||
fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(fieldWriter, file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fb *defaultFormBuilder) writeField(fieldname, value string) error {
|
||||
return fb.writer.WriteField(fieldname, value)
|
||||
}
|
||||
|
||||
func (fb *defaultFormBuilder) close() error {
|
||||
return fb.writer.Close()
|
||||
}
|
||||
|
||||
func (fb *defaultFormBuilder) formDataContentType() string {
|
||||
return fb.writer.FormDataContentType()
|
||||
}
|
||||
28
image.go
28
image.go
@@ -69,40 +69,40 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
|
||||
builder := c.createFormBuilder(body)
|
||||
|
||||
// image
|
||||
err = builder.createFormFile("image", request.Image)
|
||||
err = builder.CreateFormFile("image", request.Image)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// mask, it is optional
|
||||
if request.Mask != nil {
|
||||
err = builder.createFormFile("mask", request.Mask)
|
||||
err = builder.CreateFormFile("mask", request.Mask)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = builder.writeField("prompt", request.Prompt)
|
||||
err = builder.WriteField("prompt", request.Prompt)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = builder.writeField("n", strconv.Itoa(request.N))
|
||||
err = builder.WriteField("n", strconv.Itoa(request.N))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = builder.writeField("size", request.Size)
|
||||
err = builder.WriteField("size", request.Size)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = builder.writeField("response_format", request.ResponseFormat)
|
||||
err = builder.WriteField("response_format", request.ResponseFormat)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = builder.close()
|
||||
err = builder.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -113,7 +113,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", builder.formDataContentType())
|
||||
req.Header.Set("Content-Type", builder.FormDataContentType())
|
||||
err = c.sendRequest(req, &response)
|
||||
return
|
||||
}
|
||||
@@ -133,27 +133,27 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
|
||||
builder := c.createFormBuilder(body)
|
||||
|
||||
// image
|
||||
err = builder.createFormFile("image", request.Image)
|
||||
err = builder.CreateFormFile("image", request.Image)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = builder.writeField("n", strconv.Itoa(request.N))
|
||||
err = builder.WriteField("n", strconv.Itoa(request.N))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = builder.writeField("size", request.Size)
|
||||
err = builder.WriteField("size", request.Size)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = builder.writeField("response_format", request.ResponseFormat)
|
||||
err = builder.WriteField("response_format", request.ResponseFormat)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = builder.close()
|
||||
err = builder.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -165,7 +165,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", builder.formDataContentType())
|
||||
req.Header.Set("Content-Type", builder.FormDataContentType())
|
||||
err = c.sendRequest(req, &response)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package openai //nolint:testpackage // testing private field
|
||||
|
||||
import (
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
|
||||
@@ -268,19 +269,19 @@ type mockFormBuilder struct {
|
||||
mockClose func() error
|
||||
}
|
||||
|
||||
func (fb *mockFormBuilder) createFormFile(fieldname string, file *os.File) error {
|
||||
func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
|
||||
return fb.mockCreateFormFile(fieldname, file)
|
||||
}
|
||||
|
||||
func (fb *mockFormBuilder) writeField(fieldname, value string) error {
|
||||
func (fb *mockFormBuilder) WriteField(fieldname, value string) error {
|
||||
return fb.mockWriteField(fieldname, value)
|
||||
}
|
||||
|
||||
func (fb *mockFormBuilder) close() error {
|
||||
func (fb *mockFormBuilder) Close() error {
|
||||
return fb.mockClose()
|
||||
}
|
||||
|
||||
func (fb *mockFormBuilder) formDataContentType() string {
|
||||
func (fb *mockFormBuilder) FormDataContentType() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -290,7 +291,7 @@ func TestImageFormBuilderFailures(t *testing.T) {
|
||||
client := NewClientWithConfig(config)
|
||||
|
||||
mockBuilder := &mockFormBuilder{}
|
||||
client.createFormBuilder = func(io.Writer) formBuilder {
|
||||
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
|
||||
return mockBuilder
|
||||
}
|
||||
ctx := context.Background()
|
||||
@@ -357,7 +358,7 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
|
||||
client := NewClientWithConfig(config)
|
||||
|
||||
mockBuilder := &mockFormBuilder{}
|
||||
client.createFormBuilder = func(io.Writer) formBuilder {
|
||||
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
|
||||
return mockBuilder
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
49
internal/form_builder.go
Normal file
49
internal/form_builder.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
)
|
||||
|
||||
type FormBuilder interface {
|
||||
CreateFormFile(fieldname string, file *os.File) error
|
||||
WriteField(fieldname, value string) error
|
||||
Close() error
|
||||
FormDataContentType() string
|
||||
}
|
||||
|
||||
type DefaultFormBuilder struct {
|
||||
writer *multipart.Writer
|
||||
}
|
||||
|
||||
func NewFormBuilder(body io.Writer) *DefaultFormBuilder {
|
||||
return &DefaultFormBuilder{
|
||||
writer: multipart.NewWriter(body),
|
||||
}
|
||||
}
|
||||
|
||||
func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
|
||||
fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(fieldWriter, file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error {
|
||||
return fb.writer.WriteField(fieldname, value)
|
||||
}
|
||||
|
||||
func (fb *DefaultFormBuilder) Close() error {
|
||||
return fb.writer.Close()
|
||||
}
|
||||
|
||||
func (fb *DefaultFormBuilder) FormDataContentType() string {
|
||||
return fb.writer.FormDataContentType()
|
||||
}
|
||||
@@ -30,8 +30,8 @@ func TestFormBuilderWithFailingWriter(t *testing.T) {
|
||||
defer file.Close()
|
||||
defer os.Remove(file.Name())
|
||||
|
||||
builder := newFormBuilder(&failingWriter{})
|
||||
err = builder.createFormFile("file", file)
|
||||
builder := NewFormBuilder(&failingWriter{})
|
||||
err = builder.CreateFormFile("file", file)
|
||||
checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails")
|
||||
}
|
||||
|
||||
@@ -47,8 +47,8 @@ func TestFormBuilderWithClosedFile(t *testing.T) {
|
||||
defer os.Remove(file.Name())
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
builder := newFormBuilder(body)
|
||||
err = builder.createFormFile("file", file)
|
||||
builder := NewFormBuilder(body)
|
||||
err = builder.CreateFormFile("file", file)
|
||||
checks.HasError(t, err, "formbuilder should return error if file is closed")
|
||||
checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed")
|
||||
}
|
||||
Reference in New Issue
Block a user