Add files api tests (#238)
* drop support for downloading files * use form builder to submit files * update doc * add form builder tests
This commit is contained in:
62
files.go
62
files.go
@@ -4,10 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"mime/multipart"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -33,77 +30,38 @@ type FilesList struct {
|
|||||||
Files []File `json:"data"`
|
Files []File `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// isUrl is a helper function that determines whether the given FilePath
|
|
||||||
// is a remote URL or a local file path.
|
|
||||||
func isURL(path string) bool {
|
|
||||||
_, err := url.ParseRequestURI(path)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
u, err := url.Parse(path)
|
|
||||||
if err != nil || u.Scheme == "" || u.Host == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateFile uploads a jsonl file to GPT3
|
// CreateFile uploads a jsonl file to GPT3
|
||||||
// FilePath can be either a local file path or a URL.
|
// FilePath must be a local file path.
|
||||||
func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File, err error) {
|
func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File, err error) {
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
w := multipart.NewWriter(&b)
|
builder := c.createFormBuilder(&b)
|
||||||
|
|
||||||
var fw io.Writer
|
err = builder.writeField("purpose", request.Purpose)
|
||||||
|
|
||||||
err = w.WriteField("purpose", request.Purpose)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fw, err = w.CreateFormFile("file", request.FileName)
|
fileData, err := os.Open(request.FilePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var fileData io.ReadCloser
|
err = builder.createFormFile("file", fileData)
|
||||||
if isURL(request.FilePath) {
|
|
||||||
var remoteFile *http.Response
|
|
||||||
remoteFile, err = http.Get(request.FilePath)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer remoteFile.Body.Close()
|
|
||||||
|
|
||||||
// Check server response
|
|
||||||
if remoteFile.StatusCode != http.StatusOK {
|
|
||||||
err = fmt.Errorf("error, status code: %d, message: failed to fetch file", remoteFile.StatusCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
fileData = remoteFile.Body
|
|
||||||
} else {
|
|
||||||
fileData, err = os.Open(request.FilePath)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = io.Copy(fw, fileData)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Close()
|
err = builder.close()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/files"), &b)
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/files"), &b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
req.Header.Set("Content-Type", builder.formDataContentType())
|
||||||
|
|
||||||
err = c.sendRequest(req, &file)
|
err = c.sendRequest(req, &file)
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
package openai_test
|
package openai //nolint:testpackage // testing private field
|
||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -34,7 +35,7 @@ func TestFileUpload(t *testing.T) {
|
|||||||
Purpose: "fine-tune",
|
Purpose: "fine-tune",
|
||||||
}
|
}
|
||||||
_, err = client.CreateFile(ctx, req)
|
_, err = client.CreateFile(ctx, req)
|
||||||
checks.NoError(t, err, "CreateFile erro")
|
checks.NoError(t, err, "CreateFile error")
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleCreateFile Handles the images endpoint by the test server.
|
// handleCreateFile Handles the images endpoint by the test server.
|
||||||
@@ -78,3 +79,50 @@ func handleCreateFile(w http.ResponseWriter, r *http.Request) {
|
|||||||
resBytes, _ = json.Marshal(fileReq)
|
resBytes, _ = json.Marshal(fileReq)
|
||||||
fmt.Fprint(w, string(resBytes))
|
fmt.Fprint(w, string(resBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFileUploadWithFailingFormBuilder(t *testing.T) {
|
||||||
|
config := DefaultConfig("")
|
||||||
|
config.BaseURL = ""
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
|
mockBuilder := &mockFormBuilder{}
|
||||||
|
client.createFormBuilder = func(io.Writer) formBuilder {
|
||||||
|
return mockBuilder
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
req := FileRequest{
|
||||||
|
FileName: "test.go",
|
||||||
|
FilePath: "client.go",
|
||||||
|
Purpose: "fine-tune",
|
||||||
|
}
|
||||||
|
|
||||||
|
mockError := fmt.Errorf("mockWriteField error")
|
||||||
|
mockBuilder.mockWriteField = func(string, string) error {
|
||||||
|
return mockError
|
||||||
|
}
|
||||||
|
_, err := client.CreateFile(ctx, req)
|
||||||
|
checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails")
|
||||||
|
|
||||||
|
mockError = fmt.Errorf("mockCreateFormFile error")
|
||||||
|
mockBuilder.mockWriteField = func(string, string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
|
||||||
|
return mockError
|
||||||
|
}
|
||||||
|
_, err = client.CreateFile(ctx, req)
|
||||||
|
checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails")
|
||||||
|
|
||||||
|
mockError = fmt.Errorf("mockClose error")
|
||||||
|
mockBuilder.mockWriteField = func(string, string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mockBuilder.mockClose = func() error {
|
||||||
|
return mockError
|
||||||
|
}
|
||||||
|
_, err = client.CreateFile(ctx, req)
|
||||||
|
checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails")
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user