Move form_builder into internal pkg. (#311)
* Move form_uilder into internal pkg. * Fix import of audio.go * Reorganize. * Fix import. * Fix --------- Co-authored-by: JoyShi <joy.shi@sap.com>
This commit is contained in:
20
audio.go
20
audio.go
@@ -6,6 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
utils "github.com/sashabaranov/go-openai/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI.
|
// Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI.
|
||||||
@@ -72,7 +74,7 @@ func (c *Client) callAudioAPI(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return AudioResponse{}, err
|
return AudioResponse{}, err
|
||||||
}
|
}
|
||||||
req.Header.Add("Content-Type", builder.formDataContentType())
|
req.Header.Add("Content-Type", builder.FormDataContentType())
|
||||||
|
|
||||||
if request.HasJSONResponse() {
|
if request.HasJSONResponse() {
|
||||||
err = c.sendRequest(req, &response)
|
err = c.sendRequest(req, &response)
|
||||||
@@ -92,26 +94,26 @@ func (r AudioRequest) HasJSONResponse() bool {
|
|||||||
|
|
||||||
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
|
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
|
||||||
// audio processing.
|
// audio processing.
|
||||||
func audioMultipartForm(request AudioRequest, b formBuilder) error {
|
func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
|
||||||
f, err := os.Open(request.FilePath)
|
f, err := os.Open(request.FilePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("opening audio file: %w", err)
|
return fmt.Errorf("opening audio file: %w", err)
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
err = b.createFormFile("file", f)
|
err = b.CreateFormFile("file", f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("creating form file: %w", err)
|
return fmt.Errorf("creating form file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = b.writeField("model", request.Model)
|
err = b.WriteField("model", request.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("writing model name: %w", err)
|
return fmt.Errorf("writing model name: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a form field for the prompt (if provided)
|
// Create a form field for the prompt (if provided)
|
||||||
if request.Prompt != "" {
|
if request.Prompt != "" {
|
||||||
err = b.writeField("prompt", request.Prompt)
|
err = b.WriteField("prompt", request.Prompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("writing prompt: %w", err)
|
return fmt.Errorf("writing prompt: %w", err)
|
||||||
}
|
}
|
||||||
@@ -119,7 +121,7 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error {
|
|||||||
|
|
||||||
// Create a form field for the format (if provided)
|
// Create a form field for the format (if provided)
|
||||||
if request.Format != "" {
|
if request.Format != "" {
|
||||||
err = b.writeField("response_format", string(request.Format))
|
err = b.WriteField("response_format", string(request.Format))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("writing format: %w", err)
|
return fmt.Errorf("writing format: %w", err)
|
||||||
}
|
}
|
||||||
@@ -127,7 +129,7 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error {
|
|||||||
|
|
||||||
// Create a form field for the temperature (if provided)
|
// Create a form field for the temperature (if provided)
|
||||||
if request.Temperature != 0 {
|
if request.Temperature != 0 {
|
||||||
err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature))
|
err = b.WriteField("temperature", fmt.Sprintf("%.2f", request.Temperature))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("writing temperature: %w", err)
|
return fmt.Errorf("writing temperature: %w", err)
|
||||||
}
|
}
|
||||||
@@ -135,12 +137,12 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error {
|
|||||||
|
|
||||||
// Create a form field for the language (if provided)
|
// Create a form field for the language (if provided)
|
||||||
if request.Language != "" {
|
if request.Language != "" {
|
||||||
err = b.writeField("language", request.Language)
|
err = b.WriteField("language", request.Language)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("writing language: %w", err)
|
return fmt.Errorf("writing language: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the multipart writer
|
// Close the multipart writer
|
||||||
return b.close()
|
return b.Close()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
utils "github.com/sashabaranov/go-openai/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client is OpenAI GPT-3 API client.
|
// Client is OpenAI GPT-3 API client.
|
||||||
@@ -14,7 +16,7 @@ type Client struct {
|
|||||||
config ClientConfig
|
config ClientConfig
|
||||||
|
|
||||||
requestBuilder requestBuilder
|
requestBuilder requestBuilder
|
||||||
createFormBuilder func(io.Writer) formBuilder
|
createFormBuilder func(io.Writer) utils.FormBuilder
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates new OpenAI API client.
|
// NewClient creates new OpenAI API client.
|
||||||
@@ -28,8 +30,8 @@ func NewClientWithConfig(config ClientConfig) *Client {
|
|||||||
return &Client{
|
return &Client{
|
||||||
config: config,
|
config: config,
|
||||||
requestBuilder: newRequestBuilder(),
|
requestBuilder: newRequestBuilder(),
|
||||||
createFormBuilder: func(body io.Writer) formBuilder {
|
createFormBuilder: func(body io.Writer) utils.FormBuilder {
|
||||||
return newFormBuilder(body)
|
return utils.NewFormBuilder(body)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
8
files.go
8
files.go
@@ -36,7 +36,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
|
|||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
builder := c.createFormBuilder(&b)
|
builder := c.createFormBuilder(&b)
|
||||||
|
|
||||||
err = builder.writeField("purpose", request.Purpose)
|
err = builder.WriteField("purpose", request.Purpose)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -46,12 +46,12 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = builder.createFormFile("file", fileData)
|
err = builder.CreateFormFile("file", fileData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = builder.close()
|
err = builder.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -61,7 +61,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Content-Type", builder.formDataContentType())
|
req.Header.Set("Content-Type", builder.FormDataContentType())
|
||||||
|
|
||||||
err = c.sendRequest(req, &file)
|
err = c.sendRequest(req, &file)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package openai //nolint:testpackage // testing private field
|
package openai //nolint:testpackage // testing private field
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
. "github.com/sashabaranov/go-openai/internal"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
@@ -85,7 +86,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) {
|
|||||||
config.BaseURL = ""
|
config.BaseURL = ""
|
||||||
client := NewClientWithConfig(config)
|
client := NewClientWithConfig(config)
|
||||||
mockBuilder := &mockFormBuilder{}
|
mockBuilder := &mockFormBuilder{}
|
||||||
client.createFormBuilder = func(io.Writer) formBuilder {
|
client.createFormBuilder = func(io.Writer) FormBuilder {
|
||||||
return mockBuilder
|
return mockBuilder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,49 +0,0 @@
|
|||||||
package openai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"mime/multipart"
|
|
||||||
"os"
|
|
||||||
)
|
|
||||||
|
|
||||||
type formBuilder interface {
|
|
||||||
createFormFile(fieldname string, file *os.File) error
|
|
||||||
writeField(fieldname, value string) error
|
|
||||||
close() error
|
|
||||||
formDataContentType() string
|
|
||||||
}
|
|
||||||
|
|
||||||
type defaultFormBuilder struct {
|
|
||||||
writer *multipart.Writer
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFormBuilder(body io.Writer) *defaultFormBuilder {
|
|
||||||
return &defaultFormBuilder{
|
|
||||||
writer: multipart.NewWriter(body),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fb *defaultFormBuilder) createFormFile(fieldname string, file *os.File) error {
|
|
||||||
fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = io.Copy(fieldWriter, file)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fb *defaultFormBuilder) writeField(fieldname, value string) error {
|
|
||||||
return fb.writer.WriteField(fieldname, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fb *defaultFormBuilder) close() error {
|
|
||||||
return fb.writer.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fb *defaultFormBuilder) formDataContentType() string {
|
|
||||||
return fb.writer.FormDataContentType()
|
|
||||||
}
|
|
||||||
28
image.go
28
image.go
@@ -69,40 +69,40 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
|
|||||||
builder := c.createFormBuilder(body)
|
builder := c.createFormBuilder(body)
|
||||||
|
|
||||||
// image
|
// image
|
||||||
err = builder.createFormFile("image", request.Image)
|
err = builder.CreateFormFile("image", request.Image)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// mask, it is optional
|
// mask, it is optional
|
||||||
if request.Mask != nil {
|
if request.Mask != nil {
|
||||||
err = builder.createFormFile("mask", request.Mask)
|
err = builder.CreateFormFile("mask", request.Mask)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = builder.writeField("prompt", request.Prompt)
|
err = builder.WriteField("prompt", request.Prompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = builder.writeField("n", strconv.Itoa(request.N))
|
err = builder.WriteField("n", strconv.Itoa(request.N))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = builder.writeField("size", request.Size)
|
err = builder.WriteField("size", request.Size)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = builder.writeField("response_format", request.ResponseFormat)
|
err = builder.WriteField("response_format", request.ResponseFormat)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = builder.close()
|
err = builder.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -113,7 +113,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Content-Type", builder.formDataContentType())
|
req.Header.Set("Content-Type", builder.FormDataContentType())
|
||||||
err = c.sendRequest(req, &response)
|
err = c.sendRequest(req, &response)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -133,27 +133,27 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
|
|||||||
builder := c.createFormBuilder(body)
|
builder := c.createFormBuilder(body)
|
||||||
|
|
||||||
// image
|
// image
|
||||||
err = builder.createFormFile("image", request.Image)
|
err = builder.CreateFormFile("image", request.Image)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = builder.writeField("n", strconv.Itoa(request.N))
|
err = builder.WriteField("n", strconv.Itoa(request.N))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = builder.writeField("size", request.Size)
|
err = builder.WriteField("size", request.Size)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = builder.writeField("response_format", request.ResponseFormat)
|
err = builder.WriteField("response_format", request.ResponseFormat)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = builder.close()
|
err = builder.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -165,7 +165,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Content-Type", builder.formDataContentType())
|
req.Header.Set("Content-Type", builder.FormDataContentType())
|
||||||
err = c.sendRequest(req, &response)
|
err = c.sendRequest(req, &response)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package openai //nolint:testpackage // testing private field
|
package openai //nolint:testpackage // testing private field
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
utils "github.com/sashabaranov/go-openai/internal"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
@@ -268,19 +269,19 @@ type mockFormBuilder struct {
|
|||||||
mockClose func() error
|
mockClose func() error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fb *mockFormBuilder) createFormFile(fieldname string, file *os.File) error {
|
func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
|
||||||
return fb.mockCreateFormFile(fieldname, file)
|
return fb.mockCreateFormFile(fieldname, file)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fb *mockFormBuilder) writeField(fieldname, value string) error {
|
func (fb *mockFormBuilder) WriteField(fieldname, value string) error {
|
||||||
return fb.mockWriteField(fieldname, value)
|
return fb.mockWriteField(fieldname, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fb *mockFormBuilder) close() error {
|
func (fb *mockFormBuilder) Close() error {
|
||||||
return fb.mockClose()
|
return fb.mockClose()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fb *mockFormBuilder) formDataContentType() string {
|
func (fb *mockFormBuilder) FormDataContentType() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -290,7 +291,7 @@ func TestImageFormBuilderFailures(t *testing.T) {
|
|||||||
client := NewClientWithConfig(config)
|
client := NewClientWithConfig(config)
|
||||||
|
|
||||||
mockBuilder := &mockFormBuilder{}
|
mockBuilder := &mockFormBuilder{}
|
||||||
client.createFormBuilder = func(io.Writer) formBuilder {
|
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
|
||||||
return mockBuilder
|
return mockBuilder
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
@@ -357,7 +358,7 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
|
|||||||
client := NewClientWithConfig(config)
|
client := NewClientWithConfig(config)
|
||||||
|
|
||||||
mockBuilder := &mockFormBuilder{}
|
mockBuilder := &mockFormBuilder{}
|
||||||
client.createFormBuilder = func(io.Writer) formBuilder {
|
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
|
||||||
return mockBuilder
|
return mockBuilder
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|||||||
49
internal/form_builder.go
Normal file
49
internal/form_builder.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FormBuilder interface {
|
||||||
|
CreateFormFile(fieldname string, file *os.File) error
|
||||||
|
WriteField(fieldname, value string) error
|
||||||
|
Close() error
|
||||||
|
FormDataContentType() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type DefaultFormBuilder struct {
|
||||||
|
writer *multipart.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFormBuilder(body io.Writer) *DefaultFormBuilder {
|
||||||
|
return &DefaultFormBuilder{
|
||||||
|
writer: multipart.NewWriter(body),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
|
||||||
|
fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = io.Copy(fieldWriter, file)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error {
|
||||||
|
return fb.writer.WriteField(fieldname, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fb *DefaultFormBuilder) Close() error {
|
||||||
|
return fb.writer.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fb *DefaultFormBuilder) FormDataContentType() string {
|
||||||
|
return fb.writer.FormDataContentType()
|
||||||
|
}
|
||||||
@@ -30,8 +30,8 @@ func TestFormBuilderWithFailingWriter(t *testing.T) {
|
|||||||
defer file.Close()
|
defer file.Close()
|
||||||
defer os.Remove(file.Name())
|
defer os.Remove(file.Name())
|
||||||
|
|
||||||
builder := newFormBuilder(&failingWriter{})
|
builder := NewFormBuilder(&failingWriter{})
|
||||||
err = builder.createFormFile("file", file)
|
err = builder.CreateFormFile("file", file)
|
||||||
checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails")
|
checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -47,8 +47,8 @@ func TestFormBuilderWithClosedFile(t *testing.T) {
|
|||||||
defer os.Remove(file.Name())
|
defer os.Remove(file.Name())
|
||||||
|
|
||||||
body := &bytes.Buffer{}
|
body := &bytes.Buffer{}
|
||||||
builder := newFormBuilder(body)
|
builder := NewFormBuilder(body)
|
||||||
err = builder.createFormFile("file", file)
|
err = builder.CreateFormFile("file", file)
|
||||||
checks.HasError(t, err, "formbuilder should return error if file is closed")
|
checks.HasError(t, err, "formbuilder should return error if file is closed")
|
||||||
checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed")
|
checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed")
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user