From c34bc77f1ae0a8a85169ba74e3f3b16374c77313 Mon Sep 17 00:00:00 2001 From: sashabaranov <677093+sashabaranov@users.noreply.github.com> Date: Wed, 15 Mar 2023 13:16:33 +0400 Subject: [PATCH] Add testable request builder (#162) * Add testable request builder * improve tests --- api.go | 20 ++---- chat.go | 9 +-- completion.go | 9 +-- edits.go | 9 +-- embeddings.go | 10 +-- engines.go | 4 +- files.go | 6 +- fine_tunes.go | 19 ++---- image.go | 8 +-- marshaller_test.go | 71 -------------------- moderation.go | 9 +-- request_builder.go | 40 +++++++++++ request_builder_test.go | 143 ++++++++++++++++++++++++++++++++++++++++ 13 files changed, 205 insertions(+), 152 deletions(-) delete mode 100644 marshaller_test.go create mode 100644 request_builder.go create mode 100644 request_builder_test.go diff --git a/api.go b/api.go index 3e14c8c..00d6d35 100644 --- a/api.go +++ b/api.go @@ -1,7 +1,6 @@ package openai import ( - "bytes" "context" "encoding/json" "fmt" @@ -12,7 +11,7 @@ import ( type Client struct { config ClientConfig - marshaller marshaller + requestBuilder requestBuilder } // NewClient creates new OpenAI API client. @@ -24,8 +23,8 @@ func NewClient(authToken string) *Client { // NewClientWithConfig creates new OpenAI API client for specified config. func NewClientWithConfig(config ClientConfig) *Client { return &Client{ - config: config, - marshaller: &jsonMarshaller{}, + config: config, + requestBuilder: newRequestBuilder(), } } @@ -91,17 +90,8 @@ func (c *Client) newStreamRequest( ctx context.Context, method string, urlSuffix string, - body interface{}) (*http.Request, error) { - var reqBody []byte - if body != nil { - var err error - reqBody, err = c.marshaller.marshal(body) - if err != nil { - return nil, err - } - } - - req, err := http.NewRequestWithContext(ctx, method, c.fullURL(urlSuffix), bytes.NewBuffer(reqBody)) + body any) (*http.Request, error) { + req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix), body) if err != nil { return nil, err } diff --git a/chat.go b/chat.go index 035ab56..14be6f4 100644 --- a/chat.go +++ b/chat.go @@ -1,7 +1,6 @@ package openai import ( - "bytes" "context" "errors" "net/http" @@ -72,14 +71,8 @@ func (c *Client) CreateChatCompletion( return } - var reqBytes []byte - reqBytes, err = c.marshaller.marshal(request) - if err != nil { - return - } - urlSuffix := "/chat/completions" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) if err != nil { return } diff --git a/completion.go b/completion.go index 79a6de7..66b4866 100644 --- a/completion.go +++ b/completion.go @@ -1,7 +1,6 @@ package openai import ( - "bytes" "context" "errors" "net/http" @@ -105,14 +104,8 @@ func (c *Client) CreateCompletion( return } - var reqBytes []byte - reqBytes, err = c.marshaller.marshal(request) - if err != nil { - return - } - urlSuffix := "/completions" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) if err != nil { return } diff --git a/edits.go b/edits.go index 169665e..858a8e5 100644 --- a/edits.go +++ b/edits.go @@ -1,7 +1,6 @@ package openai import ( - "bytes" "context" "net/http" ) @@ -32,13 +31,7 @@ type EditsResponse struct { // Perform an API call to the Edits endpoint. func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - var reqBytes []byte - reqBytes, err = c.marshaller.marshal(request) - if err != nil { - return - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/edits"), bytes.NewBuffer(reqBytes)) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits"), request) if err != nil { return } diff --git a/embeddings.go b/embeddings.go index 5a6f078..2deaccc 100644 --- a/embeddings.go +++ b/embeddings.go @@ -1,7 +1,6 @@ package openai import ( - "bytes" "context" "net/http" ) @@ -133,14 +132,7 @@ type EmbeddingRequest struct { // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. // https://beta.openai.com/docs/api-reference/embeddings/create func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { - var reqBytes []byte - reqBytes, err = c.marshaller.marshal(request) - if err != nil { - return - } - - urlSuffix := "/embeddings" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request) if err != nil { return } diff --git a/engines.go b/engines.go index cc40248..bb6a66c 100644 --- a/engines.go +++ b/engines.go @@ -22,7 +22,7 @@ type EnginesList struct { // ListEngines Lists the currently available engines, and provides basic // information about each option such as the owner and availability. func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/engines"), nil) + req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/engines"), nil) if err != nil { return } @@ -38,7 +38,7 @@ func (c *Client) GetEngine( engineID string, ) (engine Engine, err error) { urlSuffix := fmt.Sprintf("/engines/%s", engineID) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) if err != nil { return } diff --git a/files.go b/files.go index 9e731db..ec441c3 100644 --- a/files.go +++ b/files.go @@ -112,7 +112,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File // DeleteFile deletes an existing file. func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { - req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil) + req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil) if err != nil { return } @@ -124,7 +124,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { // ListFiles Lists the currently available files, // and provides basic information about each file such as the file name and purpose. func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/files"), nil) + req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/files"), nil) if err != nil { return } @@ -137,7 +137,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { // such as the file name and purpose. func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) { urlSuffix := fmt.Sprintf("/files/%s", fileID) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) if err != nil { return } diff --git a/fine_tunes.go b/fine_tunes.go index a296b5b..a121867 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -1,7 +1,6 @@ package openai import ( - "bytes" "context" "fmt" "net/http" @@ -68,14 +67,8 @@ type FineTuneDeleteResponse struct { } func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { - var reqBytes []byte - reqBytes, err = c.marshaller.marshal(request) - if err != nil { - return - } - urlSuffix := "/fine-tunes" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) if err != nil { return } @@ -86,7 +79,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // CancelFineTune cancel a fine-tune job. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil) if err != nil { return } @@ -96,7 +89,7 @@ func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (respons } func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil) + req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil) if err != nil { return } @@ -107,7 +100,7 @@ func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) if err != nil { return } @@ -117,7 +110,7 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F } func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) { - req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil) + req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil) if err != nil { return } @@ -127,7 +120,7 @@ func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (respons } func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil) + req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil) if err != nil { return } diff --git a/image.go b/image.go index 35fa4e6..c0dfa64 100644 --- a/image.go +++ b/image.go @@ -45,14 +45,8 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { - var reqBytes []byte - reqBytes, err = c.marshaller.marshal(request) - if err != nil { - return - } - urlSuffix := "/images/generations" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) if err != nil { return } diff --git a/marshaller_test.go b/marshaller_test.go deleted file mode 100644 index 096fe52..0000000 --- a/marshaller_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package openai //nolint:testpackage // testing private field - -import ( - "github.com/sashabaranov/go-openai/internal/test" - - "context" - "errors" - "testing" -) - -type failingMarshaller struct{} - -var errTestMarshallerFailed = errors.New("test marshaller failed") - -func (jm *failingMarshaller) marshal(value any) ([]byte, error) { - return []byte{}, errTestMarshallerFailed -} - -func TestClientReturnMarshallerErrors(t *testing.T) { - var err error - ts := test.NewTestServer().OpenAITestServer() - ts.Start() - defer ts.Close() - - config := DefaultConfig(test.GetTestToken()) - config.BaseURL = ts.URL + "/v1" - client := NewClientWithConfig(config) - client.marshaller = &failingMarshaller{} - - ctx := context.Background() - - _, err = client.CreateCompletion(ctx, CompletionRequest{}) - if !errors.Is(err, errTestMarshallerFailed) { - t.Fatalf("Did not return error when marshaller failed: %v", err) - } - - _, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) - if !errors.Is(err, errTestMarshallerFailed) { - t.Fatalf("Did not return error when marshaller failed: %v", err) - } - - _, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) - if !errors.Is(err, errTestMarshallerFailed) { - t.Fatalf("Did not return error when marshaller failed: %v", err) - } - - _, err = client.CreateFineTune(ctx, FineTuneRequest{}) - if !errors.Is(err, errTestMarshallerFailed) { - t.Fatalf("Did not return error when marshaller failed: %v", err) - } - - _, err = client.Moderations(ctx, ModerationRequest{}) - if !errors.Is(err, errTestMarshallerFailed) { - t.Fatalf("Did not return error when marshaller failed: %v", err) - } - - _, err = client.Edits(ctx, EditsRequest{}) - if !errors.Is(err, errTestMarshallerFailed) { - t.Fatalf("Did not return error when marshaller failed: %v", err) - } - - _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) - if !errors.Is(err, errTestMarshallerFailed) { - t.Fatalf("Did not return error when marshaller failed: %v", err) - } - - _, err = client.CreateImage(ctx, ImageRequest{}) - if !errors.Is(err, errTestMarshallerFailed) { - t.Fatalf("Did not return error when marshaller failed: %v", err) - } -} diff --git a/moderation.go b/moderation.go index 745b709..ff789a6 100644 --- a/moderation.go +++ b/moderation.go @@ -1,7 +1,6 @@ package openai import ( - "bytes" "context" "net/http" ) @@ -51,13 +50,7 @@ type ModerationResponse struct { // Moderations — perform a moderation api call over a string. // Input can be an array or slice but a string will reduce the complexity. func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) { - var reqBytes []byte - reqBytes, err = c.marshaller.marshal(request) - if err != nil { - return - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/moderations"), bytes.NewBuffer(reqBytes)) + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations"), request) if err != nil { return } diff --git a/request_builder.go b/request_builder.go new file mode 100644 index 0000000..f0cef10 --- /dev/null +++ b/request_builder.go @@ -0,0 +1,40 @@ +package openai + +import ( + "bytes" + "context" + "net/http" +) + +type requestBuilder interface { + build(ctx context.Context, method, url string, request any) (*http.Request, error) +} + +type httpRequestBuilder struct { + marshaller marshaller +} + +func newRequestBuilder() *httpRequestBuilder { + return &httpRequestBuilder{ + marshaller: &jsonMarshaller{}, + } +} + +func (b *httpRequestBuilder) build(ctx context.Context, method, url string, request any) (*http.Request, error) { + if request == nil { + return http.NewRequestWithContext(ctx, method, url, nil) + } + + var reqBytes []byte + reqBytes, err := b.marshaller.marshal(request) + if err != nil { + return nil, err + } + + return http.NewRequestWithContext( + ctx, + method, + url, + bytes.NewBuffer(reqBytes), + ) +} diff --git a/request_builder_test.go b/request_builder_test.go new file mode 100644 index 0000000..c06112c --- /dev/null +++ b/request_builder_test.go @@ -0,0 +1,143 @@ +package openai //nolint:testpackage // testing private field + +import ( + "github.com/sashabaranov/go-openai/internal/test" + + "context" + "errors" + "net/http" + "testing" +) + +var ( + errTestMarshallerFailed = errors.New("test marshaller failed") + errTestRequestBuilderFailed = errors.New("test request builder failed") +) + +type ( + failingRequestBuilder struct{} + failingMarshaller struct{} +) + +func (*failingMarshaller) marshal(value any) ([]byte, error) { + return []byte{}, errTestMarshallerFailed +} + +func (*failingRequestBuilder) build(ctx context.Context, method, url string, requset any) (*http.Request, error) { + return nil, errTestRequestBuilderFailed +} + +func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) { + builder := httpRequestBuilder{ + marshaller: &failingMarshaller{}, + } + + _, err := builder.build(context.Background(), "", "", struct{}{}) + if !errors.Is(err, errTestMarshallerFailed) { + t.Fatalf("Did not return error when marshaller failed: %v", err) + } +} + +func TestClientReturnsRequestBuilderErrors(t *testing.T) { + var err error + ts := test.NewTestServer().OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + client.requestBuilder = &failingRequestBuilder{} + + ctx := context.Background() + + _, err = client.CreateCompletion(ctx, CompletionRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateFineTune(ctx, FineTuneRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListFineTunes(ctx) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CancelFineTune(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.GetFineTune(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.DeleteFineTune(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListFineTuneEvents(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.Moderations(ctx, ModerationRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.Edits(ctx, EditsRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateImage(ctx, ImageRequest{}) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + err = client.DeleteFile(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.GetFile(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListFiles(ctx) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.ListEngines(ctx) + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.GetEngine(ctx, "") + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } +}