update image api *os.File to io.Reader (#994)
* update image api *os.File to io.Reader * update code style * add reader test * supplementary reader test * update the reader in the form builder test * add commnet * update comment * update code style
This commit is contained in:
43
image.go
43
image.go
@@ -3,8 +3,8 @@ package openai
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
@@ -134,15 +134,15 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons
|
||||
|
||||
// ImageEditRequest represents the request structure for the image API.
|
||||
type ImageEditRequest struct {
|
||||
Image *os.File `json:"image,omitempty"`
|
||||
Mask *os.File `json:"mask,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Image io.Reader `json:"image,omitempty"`
|
||||
Mask io.Reader `json:"mask,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API.
|
||||
@@ -150,15 +150,16 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
|
||||
body := &bytes.Buffer{}
|
||||
builder := c.createFormBuilder(body)
|
||||
|
||||
// image
|
||||
err = builder.CreateFormFile("image", request.Image)
|
||||
// image, filename is not required
|
||||
err = builder.CreateFormFileReader("image", request.Image, "")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// mask, it is optional
|
||||
if request.Mask != nil {
|
||||
err = builder.CreateFormFile("mask", request.Mask)
|
||||
// mask, filename is not required
|
||||
err = builder.CreateFormFileReader("mask", request.Mask, "")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -206,12 +207,12 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
|
||||
|
||||
// ImageVariRequest represents the request structure for the image API.
|
||||
type ImageVariRequest struct {
|
||||
Image *os.File `json:"image,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Image io.Reader `json:"image,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API.
|
||||
@@ -220,8 +221,8 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
|
||||
body := &bytes.Buffer{}
|
||||
builder := c.createFormBuilder(body)
|
||||
|
||||
// image
|
||||
err = builder.CreateFormFile("image", request.Image)
|
||||
// image, filename is not required
|
||||
err = builder.CreateFormFileReader("image", request.Image, "")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -54,13 +54,13 @@ func TestImageFormBuilderFailures(t *testing.T) {
|
||||
}
|
||||
|
||||
mockFailedErr := fmt.Errorf("mock form builder fail")
|
||||
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
|
||||
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
|
||||
return mockFailedErr
|
||||
}
|
||||
_, err := client.CreateEditImage(ctx, req)
|
||||
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
|
||||
|
||||
mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error {
|
||||
mockBuilder.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error {
|
||||
if name == "mask" {
|
||||
return mockFailedErr
|
||||
}
|
||||
@@ -119,13 +119,13 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
|
||||
req := ImageVariRequest{}
|
||||
|
||||
mockFailedErr := fmt.Errorf("mock form builder fail")
|
||||
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
|
||||
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
|
||||
return mockFailedErr
|
||||
}
|
||||
_, err := client.CreateVariImage(ctx, req)
|
||||
checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
|
||||
|
||||
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
|
||||
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/textproto"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type FormBuilder interface {
|
||||
@@ -30,8 +32,37 @@ func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) er
|
||||
return fb.createFormFile(fieldname, file, file.Name())
|
||||
}
|
||||
|
||||
var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
|
||||
|
||||
func escapeQuotes(s string) string {
|
||||
return quoteEscaper.Replace(s)
|
||||
}
|
||||
|
||||
// CreateFormFileReader creates a form field with a file reader.
|
||||
// The filename in parameters can be an empty string.
|
||||
// The filename in Content-Disposition is required, But it can be an empty string.
|
||||
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
|
||||
return fb.createFormFile(fieldname, r, path.Base(filename))
|
||||
h := make(textproto.MIMEHeader)
|
||||
h.Set(
|
||||
"Content-Disposition",
|
||||
fmt.Sprintf(
|
||||
`form-data; name="%s"; filename="%s"`,
|
||||
escapeQuotes(fieldname),
|
||||
escapeQuotes(filepath.Base(filename)),
|
||||
),
|
||||
)
|
||||
|
||||
fieldWriter, err := fb.writer.CreatePart(h)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(fieldWriter, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {
|
||||
|
||||
@@ -43,3 +43,32 @@ func TestFormBuilderWithClosedFile(t *testing.T) {
|
||||
checks.HasError(t, err, "formbuilder should return error if file is closed")
|
||||
checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed")
|
||||
}
|
||||
|
||||
type failingReader struct {
|
||||
}
|
||||
|
||||
var errMockFailingReaderError = errors.New("mock reader failed")
|
||||
|
||||
func (*failingReader) Read([]byte) (int, error) {
|
||||
return 0, errMockFailingReaderError
|
||||
}
|
||||
|
||||
func TestFormBuilderWithReader(t *testing.T) {
|
||||
file, err := os.CreateTemp(t.TempDir(), "")
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating tmp file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
builder := NewFormBuilder(&failingWriter{})
|
||||
err = builder.CreateFormFileReader("file", file, file.Name())
|
||||
checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails")
|
||||
|
||||
builder = NewFormBuilder(&bytes.Buffer{})
|
||||
reader := &failingReader{}
|
||||
err = builder.CreateFormFileReader("file", reader, "")
|
||||
checks.ErrorIs(t, err, errMockFailingReaderError, "formbuilder should return error if copy reader fails")
|
||||
|
||||
successReader := &bytes.Buffer{}
|
||||
err = builder.CreateFormFileReader("file", successReader, "")
|
||||
checks.NoError(t, err, "formbuilder should not return error")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user