diff --git a/api.go b/api.go index 90e5066..3e14c8c 100644 --- a/api.go +++ b/api.go @@ -11,17 +11,22 @@ import ( // Client is OpenAI GPT-3 API client. type Client struct { config ClientConfig + + marshaller marshaller } // NewClient creates new OpenAI API client. func NewClient(authToken string) *Client { config := DefaultConfig(authToken) - return &Client{config} + return NewClientWithConfig(config) } // NewClientWithConfig creates new OpenAI API client for specified config. func NewClientWithConfig(config ClientConfig) *Client { - return &Client{config} + return &Client{ + config: config, + marshaller: &jsonMarshaller{}, + } } // NewOrgClient creates new OpenAI API client for specified Organization ID. @@ -30,7 +35,7 @@ func NewClientWithConfig(config ClientConfig) *Client { func NewOrgClient(authToken, org string) *Client { config := DefaultConfig(authToken) config.OrgID = org - return &Client{config} + return NewClientWithConfig(config) } func (c *Client) sendRequest(req *http.Request, v interface{}) error { @@ -90,7 +95,7 @@ func (c *Client) newStreamRequest( var reqBody []byte if body != nil { var err error - reqBody, err = json.Marshal(body) + reqBody, err = c.marshaller.marshal(body) if err != nil { return nil, err } diff --git a/chat.go b/chat.go index cfa86b6..035ab56 100644 --- a/chat.go +++ b/chat.go @@ -3,7 +3,6 @@ package openai import ( "bytes" "context" - "encoding/json" "errors" "net/http" ) @@ -74,7 +73,7 @@ func (c *Client) CreateChatCompletion( } var reqBytes []byte - reqBytes, err = json.Marshal(request) + reqBytes, err = c.marshaller.marshal(request) if err != nil { return } diff --git a/completion.go b/completion.go index e727193..79a6de7 100644 --- a/completion.go +++ b/completion.go @@ -3,7 +3,6 @@ package openai import ( "bytes" "context" - "encoding/json" "errors" "net/http" ) @@ -107,7 +106,7 @@ func (c *Client) CreateCompletion( } var reqBytes []byte - reqBytes, err = json.Marshal(request) + reqBytes, err = c.marshaller.marshal(request) if err != nil { return } diff --git a/edits.go b/edits.go index fa2706d..169665e 100644 --- a/edits.go +++ b/edits.go @@ -3,7 +3,6 @@ package openai import ( "bytes" "context" - "encoding/json" "net/http" ) @@ -34,7 +33,7 @@ type EditsResponse struct { // Perform an API call to the Edits endpoint. func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { var reqBytes []byte - reqBytes, err = json.Marshal(request) + reqBytes, err = c.marshaller.marshal(request) if err != nil { return } diff --git a/embeddings.go b/embeddings.go index e8247d2..5a6f078 100644 --- a/embeddings.go +++ b/embeddings.go @@ -3,7 +3,6 @@ package openai import ( "bytes" "context" - "encoding/json" "net/http" ) @@ -135,7 +134,7 @@ type EmbeddingRequest struct { // https://beta.openai.com/docs/api-reference/embeddings/create func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { var reqBytes []byte - reqBytes, err = json.Marshal(request) + reqBytes, err = c.marshaller.marshal(request) if err != nil { return } diff --git a/fine_tunes.go b/fine_tunes.go index 458a169..a296b5b 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -3,7 +3,6 @@ package openai import ( "bytes" "context" - "encoding/json" "fmt" "net/http" ) @@ -70,7 +69,7 @@ type FineTuneDeleteResponse struct { func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { var reqBytes []byte - reqBytes, err = json.Marshal(request) + reqBytes, err = c.marshaller.marshal(request) if err != nil { return } diff --git a/image.go b/image.go index 107d1bb..35fa4e6 100644 --- a/image.go +++ b/image.go @@ -3,7 +3,6 @@ package openai import ( "bytes" "context" - "encoding/json" "io" "mime/multipart" "net/http" @@ -47,7 +46,7 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { var reqBytes []byte - reqBytes, err = json.Marshal(request) + reqBytes, err = c.marshaller.marshal(request) if err != nil { return } diff --git a/marshaller.go b/marshaller.go new file mode 100644 index 0000000..308ccd1 --- /dev/null +++ b/marshaller.go @@ -0,0 +1,15 @@ +package openai + +import ( + "encoding/json" +) + +type marshaller interface { + marshal(value any) ([]byte, error) +} + +type jsonMarshaller struct{} + +func (jm *jsonMarshaller) marshal(value any) ([]byte, error) { + return json.Marshal(value) +} diff --git a/marshaller_test.go b/marshaller_test.go new file mode 100644 index 0000000..096fe52 --- /dev/null +++ b/marshaller_test.go @@ -0,0 +1,71 @@ +package openai //nolint:testpackage // testing private field + +import ( + "github.com/sashabaranov/go-openai/internal/test" + + "context" + "errors" + "testing" +) + +type failingMarshaller struct{} + +var errTestMarshallerFailed = errors.New("test marshaller failed") + +func (jm *failingMarshaller) marshal(value any) ([]byte, error) { + return []byte{}, errTestMarshallerFailed +} + +func TestClientReturnMarshallerErrors(t *testing.T) { + var err error + ts := test.NewTestServer().OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + client.marshaller = &failingMarshaller{} + + ctx := context.Background() + + _, err = client.CreateCompletion(ctx, CompletionRequest{}) + if !errors.Is(err, errTestMarshallerFailed) { + t.Fatalf("Did not return error when marshaller failed: %v", err) + } + + _, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) + if !errors.Is(err, errTestMarshallerFailed) { + t.Fatalf("Did not return error when marshaller failed: %v", err) + } + + _, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) + if !errors.Is(err, errTestMarshallerFailed) { + t.Fatalf("Did not return error when marshaller failed: %v", err) + } + + _, err = client.CreateFineTune(ctx, FineTuneRequest{}) + if !errors.Is(err, errTestMarshallerFailed) { + t.Fatalf("Did not return error when marshaller failed: %v", err) + } + + _, err = client.Moderations(ctx, ModerationRequest{}) + if !errors.Is(err, errTestMarshallerFailed) { + t.Fatalf("Did not return error when marshaller failed: %v", err) + } + + _, err = client.Edits(ctx, EditsRequest{}) + if !errors.Is(err, errTestMarshallerFailed) { + t.Fatalf("Did not return error when marshaller failed: %v", err) + } + + _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) + if !errors.Is(err, errTestMarshallerFailed) { + t.Fatalf("Did not return error when marshaller failed: %v", err) + } + + _, err = client.CreateImage(ctx, ImageRequest{}) + if !errors.Is(err, errTestMarshallerFailed) { + t.Fatalf("Did not return error when marshaller failed: %v", err) + } +} diff --git a/moderation.go b/moderation.go index 789443e..745b709 100644 --- a/moderation.go +++ b/moderation.go @@ -3,7 +3,6 @@ package openai import ( "bytes" "context" - "encoding/json" "net/http" ) @@ -53,7 +52,7 @@ type ModerationResponse struct { // Input can be an array or slice but a string will reduce the complexity. func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) { var reqBytes []byte - reqBytes, err = json.Marshal(request) + reqBytes, err = c.marshaller.marshal(request) if err != nil { return }