Add testable request builder (#162)

* Add testable request builder

* improve tests
This commit is contained in:
sashabaranov
2023-03-15 13:16:33 +04:00
committed by GitHub
parent 53d195cf5a
commit c34bc77f1a
13 changed files with 205 additions and 152 deletions

18
api.go
View File

@@ -1,7 +1,6 @@
package openai
import (
"bytes"
"context"
"encoding/json"
"fmt"
@@ -12,7 +11,7 @@ import (
type Client struct {
config ClientConfig
marshaller marshaller
requestBuilder requestBuilder
}
// NewClient creates new OpenAI API client.
@@ -25,7 +24,7 @@ func NewClient(authToken string) *Client {
func NewClientWithConfig(config ClientConfig) *Client {
return &Client{
config: config,
marshaller: &jsonMarshaller{},
requestBuilder: newRequestBuilder(),
}
}
@@ -91,17 +90,8 @@ func (c *Client) newStreamRequest(
ctx context.Context,
method string,
urlSuffix string,
body interface{}) (*http.Request, error) {
var reqBody []byte
if body != nil {
var err error
reqBody, err = c.marshaller.marshal(body)
if err != nil {
return nil, err
}
}
req, err := http.NewRequestWithContext(ctx, method, c.fullURL(urlSuffix), bytes.NewBuffer(reqBody))
body any) (*http.Request, error) {
req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix), body)
if err != nil {
return nil, err
}

View File

@@ -1,7 +1,6 @@
package openai
import (
"bytes"
"context"
"errors"
"net/http"
@@ -72,14 +71,8 @@ func (c *Client) CreateChatCompletion(
return
}
var reqBytes []byte
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}
urlSuffix := "/chat/completions"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
if err != nil {
return
}

View File

@@ -1,7 +1,6 @@
package openai
import (
"bytes"
"context"
"errors"
"net/http"
@@ -105,14 +104,8 @@ func (c *Client) CreateCompletion(
return
}
var reqBytes []byte
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}
urlSuffix := "/completions"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
if err != nil {
return
}

View File

@@ -1,7 +1,6 @@
package openai
import (
"bytes"
"context"
"net/http"
)
@@ -32,13 +31,7 @@ type EditsResponse struct {
// Perform an API call to the Edits endpoint.
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
var reqBytes []byte
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/edits"), bytes.NewBuffer(reqBytes))
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits"), request)
if err != nil {
return
}

View File

@@ -1,7 +1,6 @@
package openai
import (
"bytes"
"context"
"net/http"
)
@@ -133,14 +132,7 @@ type EmbeddingRequest struct {
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
// https://beta.openai.com/docs/api-reference/embeddings/create
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
var reqBytes []byte
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}
urlSuffix := "/embeddings"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request)
if err != nil {
return
}

View File

@@ -22,7 +22,7 @@ type EnginesList struct {
// ListEngines Lists the currently available engines, and provides basic
// information about each option such as the owner and availability.
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/engines"), nil)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/engines"), nil)
if err != nil {
return
}
@@ -38,7 +38,7 @@ func (c *Client) GetEngine(
engineID string,
) (engine Engine, err error) {
urlSuffix := fmt.Sprintf("/engines/%s", engineID)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil {
return
}

View File

@@ -112,7 +112,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
// DeleteFile deletes an existing file.
func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
if err != nil {
return
}
@@ -124,7 +124,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
// ListFiles Lists the currently available files,
// and provides basic information about each file such as the file name and purpose.
func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/files"), nil)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/files"), nil)
if err != nil {
return
}
@@ -137,7 +137,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
// such as the file name and purpose.
func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) {
urlSuffix := fmt.Sprintf("/files/%s", fileID)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil {
return
}

View File

@@ -1,7 +1,6 @@
package openai
import (
"bytes"
"context"
"fmt"
"net/http"
@@ -68,14 +67,8 @@ type FineTuneDeleteResponse struct {
}
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
var reqBytes []byte
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}
urlSuffix := "/fine-tunes"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
if err != nil {
return
}
@@ -86,7 +79,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r
// CancelFineTune cancel a fine-tune job.
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil)
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil)
if err != nil {
return
}
@@ -96,7 +89,7 @@ func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (respons
}
func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil)
if err != nil {
return
}
@@ -107,7 +100,7 @@ func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err
func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil {
return
}
@@ -117,7 +110,7 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F
}
func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil)
req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil)
if err != nil {
return
}
@@ -127,7 +120,7 @@ func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (respons
}
func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil)
if err != nil {
return
}

View File

@@ -45,14 +45,8 @@ type ImageResponseDataInner struct {
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
var reqBytes []byte
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}
urlSuffix := "/images/generations"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
if err != nil {
return
}

View File

@@ -1,71 +0,0 @@
package openai //nolint:testpackage // testing private field
import (
"github.com/sashabaranov/go-openai/internal/test"
"context"
"errors"
"testing"
)
type failingMarshaller struct{}
var errTestMarshallerFailed = errors.New("test marshaller failed")
func (jm *failingMarshaller) marshal(value any) ([]byte, error) {
return []byte{}, errTestMarshallerFailed
}
func TestClientReturnMarshallerErrors(t *testing.T) {
var err error
ts := test.NewTestServer().OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client.marshaller = &failingMarshaller{}
ctx := context.Background()
_, err = client.CreateCompletion(ctx, CompletionRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.Moderations(ctx, ModerationRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.Edits(ctx, EditsRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.CreateImage(ctx, ImageRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
}

View File

@@ -1,7 +1,6 @@
package openai
import (
"bytes"
"context"
"net/http"
)
@@ -51,13 +50,7 @@ type ModerationResponse struct {
// Moderations — perform a moderation api call over a string.
// Input can be an array or slice but a string will reduce the complexity.
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
var reqBytes []byte
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/moderations"), bytes.NewBuffer(reqBytes))
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations"), request)
if err != nil {
return
}

40
request_builder.go Normal file
View File

@@ -0,0 +1,40 @@
package openai
import (
"bytes"
"context"
"net/http"
)
type requestBuilder interface {
build(ctx context.Context, method, url string, request any) (*http.Request, error)
}
type httpRequestBuilder struct {
marshaller marshaller
}
func newRequestBuilder() *httpRequestBuilder {
return &httpRequestBuilder{
marshaller: &jsonMarshaller{},
}
}
func (b *httpRequestBuilder) build(ctx context.Context, method, url string, request any) (*http.Request, error) {
if request == nil {
return http.NewRequestWithContext(ctx, method, url, nil)
}
var reqBytes []byte
reqBytes, err := b.marshaller.marshal(request)
if err != nil {
return nil, err
}
return http.NewRequestWithContext(
ctx,
method,
url,
bytes.NewBuffer(reqBytes),
)
}

143
request_builder_test.go Normal file
View File

@@ -0,0 +1,143 @@
package openai //nolint:testpackage // testing private field
import (
"github.com/sashabaranov/go-openai/internal/test"
"context"
"errors"
"net/http"
"testing"
)
var (
errTestMarshallerFailed = errors.New("test marshaller failed")
errTestRequestBuilderFailed = errors.New("test request builder failed")
)
type (
failingRequestBuilder struct{}
failingMarshaller struct{}
)
func (*failingMarshaller) marshal(value any) ([]byte, error) {
return []byte{}, errTestMarshallerFailed
}
func (*failingRequestBuilder) build(ctx context.Context, method, url string, requset any) (*http.Request, error) {
return nil, errTestRequestBuilderFailed
}
func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) {
builder := httpRequestBuilder{
marshaller: &failingMarshaller{},
}
_, err := builder.build(context.Background(), "", "", struct{}{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
}
func TestClientReturnsRequestBuilderErrors(t *testing.T) {
var err error
ts := test.NewTestServer().OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client.requestBuilder = &failingRequestBuilder{}
ctx := context.Background()
_, err = client.CreateCompletion(ctx, CompletionRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.ListFineTunes(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CancelFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.GetFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.DeleteFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.ListFineTuneEvents(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.Moderations(ctx, ModerationRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.Edits(ctx, EditsRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateImage(ctx, ImageRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
err = client.DeleteFile(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.GetFile(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.ListFiles(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.ListEngines(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.GetEngine(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
}