* move request_builder into internal pkg (#304) * add some test for internal.RequestBuilder * add a test for openai.GetEngine
This commit is contained in:
committed by
GitHub
parent
62eb4beed2
commit
61ba5f3369
2
chat.go
2
chat.go
@@ -77,7 +77,7 @@ func (c *Client) CreateChatCompletion(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
|
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
type Client struct {
|
type Client struct {
|
||||||
config ClientConfig
|
config ClientConfig
|
||||||
|
|
||||||
requestBuilder requestBuilder
|
requestBuilder utils.RequestBuilder
|
||||||
createFormBuilder func(io.Writer) utils.FormBuilder
|
createFormBuilder func(io.Writer) utils.FormBuilder
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ func NewClient(authToken string) *Client {
|
|||||||
func NewClientWithConfig(config ClientConfig) *Client {
|
func NewClientWithConfig(config ClientConfig) *Client {
|
||||||
return &Client{
|
return &Client{
|
||||||
config: config,
|
config: config,
|
||||||
requestBuilder: newRequestBuilder(),
|
requestBuilder: utils.NewRequestBuilder(),
|
||||||
createFormBuilder: func(body io.Writer) utils.FormBuilder {
|
createFormBuilder: func(body io.Writer) utils.FormBuilder {
|
||||||
return utils.NewFormBuilder(body)
|
return utils.NewFormBuilder(body)
|
||||||
},
|
},
|
||||||
@@ -135,7 +135,7 @@ func (c *Client) newStreamRequest(
|
|||||||
urlSuffix string,
|
urlSuffix string,
|
||||||
body any,
|
body any,
|
||||||
model string) (*http.Request, error) {
|
model string) (*http.Request, error) {
|
||||||
req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix, model), body)
|
req, err := c.requestBuilder.Build(ctx, method, c.fullURL(urlSuffix, model), body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
149
client_test.go
149
client_test.go
@@ -2,13 +2,24 @@ package openai //nolint:testpackage // testing private field
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var errTestRequestBuilderFailed = errors.New("test request builder failed")
|
||||||
|
|
||||||
|
type failingRequestBuilder struct{}
|
||||||
|
|
||||||
|
func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any) (*http.Request, error) {
|
||||||
|
return nil, errTestRequestBuilderFailed
|
||||||
|
}
|
||||||
|
|
||||||
func TestClient(t *testing.T) {
|
func TestClient(t *testing.T) {
|
||||||
const mockToken = "mock token"
|
const mockToken = "mock token"
|
||||||
client := NewClient(mockToken)
|
client := NewClient(mockToken)
|
||||||
@@ -145,3 +156,141 @@ func TestHandleErrorResp(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClientReturnsRequestBuilderErrors(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
ts := test.NewTestServer().OpenAITestServer()
|
||||||
|
ts.Start()
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
|
client.requestBuilder = &failingRequestBuilder{}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"})
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.ListFineTunes(ctx)
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CancelFineTune(ctx, "")
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.GetFineTune(ctx, "")
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.DeleteFineTune(ctx, "")
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.ListFineTuneEvents(ctx, "")
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.Moderations(ctx, ModerationRequest{})
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.Edits(ctx, EditsRequest{})
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CreateImage(ctx, ImageRequest{})
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = client.DeleteFile(ctx, "")
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.GetFile(ctx, "")
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.ListFiles(ctx)
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.ListEngines(ctx)
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.GetEngine(ctx, "")
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.ListModels(ctx)
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""})
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
ts := test.NewTestServer().OpenAITestServer()
|
||||||
|
ts.Start()
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
|
client.requestBuilder = &failingRequestBuilder{}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1})
|
||||||
|
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1})
|
||||||
|
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ func (c *Client) CreateCompletion(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
|
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
2
edits.go
2
edits.go
@@ -32,7 +32,7 @@ type EditsResponse struct {
|
|||||||
|
|
||||||
// Perform an API call to the Edits endpoint.
|
// Perform an API call to the Edits endpoint.
|
||||||
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
|
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request)
|
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ type EmbeddingRequest struct {
|
|||||||
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
|
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
|
||||||
// https://beta.openai.com/docs/api-reference/embeddings/create
|
// https://beta.openai.com/docs/api-reference/embeddings/create
|
||||||
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
|
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request)
|
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ type EnginesList struct {
|
|||||||
// ListEngines Lists the currently available engines, and provides basic
|
// ListEngines Lists the currently available engines, and provides basic
|
||||||
// information about each option such as the owner and availability.
|
// information about each option such as the owner and availability.
|
||||||
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
|
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/engines"), nil)
|
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/engines"), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -38,7 +38,7 @@ func (c *Client) GetEngine(
|
|||||||
engineID string,
|
engineID string,
|
||||||
) (engine Engine, err error) {
|
) (engine Engine, err error) {
|
||||||
urlSuffix := fmt.Sprintf("/engines/%s", engineID)
|
urlSuffix := fmt.Sprintf("/engines/%s", engineID)
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
34
engines_test.go
Normal file
34
engines_test.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package openai_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server.
|
||||||
|
func TestGetEngine(t *testing.T) {
|
||||||
|
server := test.NewTestServer()
|
||||||
|
server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resBytes, _ := json.Marshal(Engine{})
|
||||||
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
})
|
||||||
|
// create the test server
|
||||||
|
ts := server.OpenAITestServer()
|
||||||
|
ts.Start()
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
_, err := client.GetEngine(ctx, "text-davinci-003")
|
||||||
|
checks.NoError(t, err, "GetEngine error")
|
||||||
|
}
|
||||||
6
files.go
6
files.go
@@ -70,7 +70,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
|
|||||||
|
|
||||||
// DeleteFile deletes an existing file.
|
// DeleteFile deletes an existing file.
|
||||||
func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
|
func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
|
req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -82,7 +82,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
|
|||||||
// ListFiles Lists the currently available files,
|
// ListFiles Lists the currently available files,
|
||||||
// and provides basic information about each file such as the file name and purpose.
|
// and provides basic information about each file such as the file name and purpose.
|
||||||
func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
|
func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/files"), nil)
|
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/files"), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -95,7 +95,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
|
|||||||
// such as the file name and purpose.
|
// such as the file name and purpose.
|
||||||
func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) {
|
func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) {
|
||||||
urlSuffix := fmt.Sprintf("/files/%s", fileID)
|
urlSuffix := fmt.Sprintf("/files/%s", fileID)
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ type FineTuneDeleteResponse struct {
|
|||||||
|
|
||||||
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
|
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
|
||||||
urlSuffix := "/fine-tunes"
|
urlSuffix := "/fine-tunes"
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
|
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -79,7 +79,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r
|
|||||||
|
|
||||||
// CancelFineTune cancel a fine-tune job.
|
// CancelFineTune cancel a fine-tune job.
|
||||||
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
|
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil)
|
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -89,7 +89,7 @@ func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (respons
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) {
|
func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) {
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil)
|
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -100,7 +100,7 @@ func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err
|
|||||||
|
|
||||||
func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
|
func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
|
||||||
urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID)
|
urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID)
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -110,7 +110,7 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) {
|
func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) {
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil)
|
req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -120,7 +120,7 @@ func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (respons
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) {
|
func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) {
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil)
|
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
2
image.go
2
image.go
@@ -44,7 +44,7 @@ type ImageResponseDataInner struct {
|
|||||||
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
|
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
|
||||||
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
|
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
|
||||||
urlSuffix := "/images/generations"
|
urlSuffix := "/images/generations"
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
|
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,25 +4,23 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
utils "github.com/sashabaranov/go-openai/internal"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type requestBuilder interface {
|
type RequestBuilder interface {
|
||||||
build(ctx context.Context, method, url string, request any) (*http.Request, error)
|
Build(ctx context.Context, method, url string, request any) (*http.Request, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type httpRequestBuilder struct {
|
type HTTPRequestBuilder struct {
|
||||||
marshaller utils.Marshaller
|
marshaller Marshaller
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRequestBuilder() *httpRequestBuilder {
|
func NewRequestBuilder() *HTTPRequestBuilder {
|
||||||
return &httpRequestBuilder{
|
return &HTTPRequestBuilder{
|
||||||
marshaller: &utils.JSONMarshaller{},
|
marshaller: &JSONMarshaller{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *httpRequestBuilder) build(ctx context.Context, method, url string, request any) (*http.Request, error) {
|
func (b *HTTPRequestBuilder) Build(ctx context.Context, method, url string, request any) (*http.Request, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return http.NewRequestWithContext(ctx, method, url, nil)
|
return http.NewRequestWithContext(ctx, method, url, nil)
|
||||||
}
|
}
|
||||||
61
internal/request_builder_test.go
Normal file
61
internal/request_builder_test.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package openai //nolint:testpackage // testing private field
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errTestMarshallerFailed = errors.New("test marshaller failed")
|
||||||
|
|
||||||
|
type failingMarshaller struct{}
|
||||||
|
|
||||||
|
func (*failingMarshaller) Marshal(_ any) ([]byte, error) {
|
||||||
|
return []byte{}, errTestMarshallerFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) {
|
||||||
|
builder := HTTPRequestBuilder{
|
||||||
|
marshaller: &failingMarshaller{},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := builder.Build(context.Background(), "", "", struct{}{})
|
||||||
|
if !errors.Is(err, errTestMarshallerFailed) {
|
||||||
|
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestBuilderReturnsRequest(t *testing.T) {
|
||||||
|
b := NewRequestBuilder()
|
||||||
|
var (
|
||||||
|
ctx = context.Background()
|
||||||
|
method = http.MethodPost
|
||||||
|
url = "/foo"
|
||||||
|
request = map[string]string{"foo": "bar"}
|
||||||
|
reqBytes, _ = b.marshaller.Marshal(request)
|
||||||
|
want, _ = http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(reqBytes))
|
||||||
|
)
|
||||||
|
got, _ := b.Build(ctx, method, url, request)
|
||||||
|
if !reflect.DeepEqual(got.Body, want.Body) ||
|
||||||
|
!reflect.DeepEqual(got.URL, want.URL) ||
|
||||||
|
!reflect.DeepEqual(got.Method, want.Method) {
|
||||||
|
t.Errorf("Build() got = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) {
|
||||||
|
var (
|
||||||
|
ctx = context.Background()
|
||||||
|
method = http.MethodGet
|
||||||
|
url = "/foo"
|
||||||
|
want, _ = http.NewRequestWithContext(ctx, method, url, nil)
|
||||||
|
)
|
||||||
|
b := NewRequestBuilder()
|
||||||
|
got, _ := b.Build(ctx, method, url, nil)
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Errorf("Build() got = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -40,7 +40,7 @@ type ModelsList struct {
|
|||||||
// ListModels Lists the currently available models,
|
// ListModels Lists the currently available models,
|
||||||
// and provides basic information about each model such as the model id and parent.
|
// and provides basic information about each model such as the model id and parent.
|
||||||
func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) {
|
func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) {
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/models"), nil)
|
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/models"), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ type ModerationResponse struct {
|
|||||||
// Moderations — perform a moderation api call over a string.
|
// Moderations — perform a moderation api call over a string.
|
||||||
// Input can be an array or slice but a string will reduce the complexity.
|
// Input can be an array or slice but a string will reduce the complexity.
|
||||||
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
|
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request)
|
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,177 +0,0 @@
|
|||||||
package openai //nolint:testpackage // testing private field
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
|
||||||
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"net/http"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
errTestMarshallerFailed = errors.New("test marshaller failed")
|
|
||||||
errTestRequestBuilderFailed = errors.New("test request builder failed")
|
|
||||||
)
|
|
||||||
|
|
||||||
type (
|
|
||||||
failingRequestBuilder struct{}
|
|
||||||
failingMarshaller struct{}
|
|
||||||
)
|
|
||||||
|
|
||||||
func (*failingMarshaller) Marshal(_ any) ([]byte, error) {
|
|
||||||
return []byte{}, errTestMarshallerFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*failingRequestBuilder) build(_ context.Context, _, _ string, _ any) (*http.Request, error) {
|
|
||||||
return nil, errTestRequestBuilderFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) {
|
|
||||||
builder := httpRequestBuilder{
|
|
||||||
marshaller: &failingMarshaller{},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := builder.build(context.Background(), "", "", struct{}{})
|
|
||||||
if !errors.Is(err, errTestMarshallerFailed) {
|
|
||||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClientReturnsRequestBuilderErrors(t *testing.T) {
|
|
||||||
var err error
|
|
||||||
ts := test.NewTestServer().OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
client.requestBuilder = &failingRequestBuilder{}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"})
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.ListFineTunes(ctx)
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CancelFineTune(ctx, "")
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.GetFineTune(ctx, "")
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.DeleteFineTune(ctx, "")
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.ListFineTuneEvents(ctx, "")
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.Moderations(ctx, ModerationRequest{})
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.Edits(ctx, EditsRequest{})
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CreateImage(ctx, ImageRequest{})
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = client.DeleteFile(ctx, "")
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.GetFile(ctx, "")
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.ListFiles(ctx)
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.ListEngines(ctx)
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.GetEngine(ctx, "")
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.ListModels(ctx)
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""})
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReturnsRequestBuilderErrorsAddtion(t *testing.T) {
|
|
||||||
var err error
|
|
||||||
ts := test.NewTestServer().OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
client.requestBuilder = &failingRequestBuilder{}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1})
|
|
||||||
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1})
|
|
||||||
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user