update image api *os.File to io.Reader (#994)

* update image api *os.File to io.Reader

* update code style

* add reader test

* supplementary reader test

* update the reader in the form builder test

* add commnet

* update comment

* update code style
This commit is contained in:
Axb12
2025-05-20 21:45:40 +08:00
committed by GitHub
parent 4d2e7ab29d
commit 8c65b35c57
4 changed files with 88 additions and 27 deletions

View File

@@ -3,8 +3,8 @@ package openai
import ( import (
"bytes" "bytes"
"context" "context"
"io"
"net/http" "net/http"
"os"
"strconv" "strconv"
) )
@@ -134,8 +134,8 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons
// ImageEditRequest represents the request structure for the image API. // ImageEditRequest represents the request structure for the image API.
type ImageEditRequest struct { type ImageEditRequest struct {
Image *os.File `json:"image,omitempty"` Image io.Reader `json:"image,omitempty"`
Mask *os.File `json:"mask,omitempty"` Mask io.Reader `json:"mask,omitempty"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
N int `json:"n,omitempty"` N int `json:"n,omitempty"`
@@ -150,15 +150,16 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
body := &bytes.Buffer{} body := &bytes.Buffer{}
builder := c.createFormBuilder(body) builder := c.createFormBuilder(body)
// image // image, filename is not required
err = builder.CreateFormFile("image", request.Image) err = builder.CreateFormFileReader("image", request.Image, "")
if err != nil { if err != nil {
return return
} }
// mask, it is optional // mask, it is optional
if request.Mask != nil { if request.Mask != nil {
err = builder.CreateFormFile("mask", request.Mask) // mask, filename is not required
err = builder.CreateFormFileReader("mask", request.Mask, "")
if err != nil { if err != nil {
return return
} }
@@ -206,7 +207,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
// ImageVariRequest represents the request structure for the image API. // ImageVariRequest represents the request structure for the image API.
type ImageVariRequest struct { type ImageVariRequest struct {
Image *os.File `json:"image,omitempty"` Image io.Reader `json:"image,omitempty"`
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
N int `json:"n,omitempty"` N int `json:"n,omitempty"`
Size string `json:"size,omitempty"` Size string `json:"size,omitempty"`
@@ -220,8 +221,8 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
body := &bytes.Buffer{} body := &bytes.Buffer{}
builder := c.createFormBuilder(body) builder := c.createFormBuilder(body)
// image // image, filename is not required
err = builder.CreateFormFile("image", request.Image) err = builder.CreateFormFileReader("image", request.Image, "")
if err != nil { if err != nil {
return return
} }

View File

@@ -54,13 +54,13 @@ func TestImageFormBuilderFailures(t *testing.T) {
} }
mockFailedErr := fmt.Errorf("mock form builder fail") mockFailedErr := fmt.Errorf("mock form builder fail")
mockBuilder.mockCreateFormFile = func(string, *os.File) error { mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
return mockFailedErr return mockFailedErr
} }
_, err := client.CreateEditImage(ctx, req) _, err := client.CreateEditImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error { mockBuilder.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error {
if name == "mask" { if name == "mask" {
return mockFailedErr return mockFailedErr
} }
@@ -119,13 +119,13 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
req := ImageVariRequest{} req := ImageVariRequest{}
mockFailedErr := fmt.Errorf("mock form builder fail") mockFailedErr := fmt.Errorf("mock form builder fail")
mockBuilder.mockCreateFormFile = func(string, *os.File) error { mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
return mockFailedErr return mockFailedErr
} }
_, err := client.CreateVariImage(ctx, req) _, err := client.CreateVariImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
mockBuilder.mockCreateFormFile = func(string, *os.File) error { mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
return nil return nil
} }

View File

@@ -4,8 +4,10 @@ import (
"fmt" "fmt"
"io" "io"
"mime/multipart" "mime/multipart"
"net/textproto"
"os" "os"
"path" "path/filepath"
"strings"
) )
type FormBuilder interface { type FormBuilder interface {
@@ -30,8 +32,37 @@ func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) er
return fb.createFormFile(fieldname, file, file.Name()) return fb.createFormFile(fieldname, file, file.Name())
} }
var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
func escapeQuotes(s string) string {
return quoteEscaper.Replace(s)
}
// CreateFormFileReader creates a form field with a file reader.
// The filename in parameters can be an empty string.
// The filename in Content-Disposition is required, But it can be an empty string.
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
return fb.createFormFile(fieldname, r, path.Base(filename)) h := make(textproto.MIMEHeader)
h.Set(
"Content-Disposition",
fmt.Sprintf(
`form-data; name="%s"; filename="%s"`,
escapeQuotes(fieldname),
escapeQuotes(filepath.Base(filename)),
),
)
fieldWriter, err := fb.writer.CreatePart(h)
if err != nil {
return err
}
_, err = io.Copy(fieldWriter, r)
if err != nil {
return err
}
return nil
} }
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error { func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {

View File

@@ -43,3 +43,32 @@ func TestFormBuilderWithClosedFile(t *testing.T) {
checks.HasError(t, err, "formbuilder should return error if file is closed") checks.HasError(t, err, "formbuilder should return error if file is closed")
checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed") checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed")
} }
type failingReader struct {
}
var errMockFailingReaderError = errors.New("mock reader failed")
func (*failingReader) Read([]byte) (int, error) {
return 0, errMockFailingReaderError
}
func TestFormBuilderWithReader(t *testing.T) {
file, err := os.CreateTemp(t.TempDir(), "")
if err != nil {
t.Fatalf("Error creating tmp file: %v", err)
}
defer file.Close()
builder := NewFormBuilder(&failingWriter{})
err = builder.CreateFormFileReader("file", file, file.Name())
checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails")
builder = NewFormBuilder(&bytes.Buffer{})
reader := &failingReader{}
err = builder.CreateFormFileReader("file", reader, "")
checks.ErrorIs(t, err, errMockFailingReaderError, "formbuilder should return error if copy reader fails")
successReader := &bytes.Buffer{}
err = builder.CreateFormFileReader("file", successReader, "")
checks.NoError(t, err, "formbuilder should not return error")
}