move request_builder into internal pkg (#304) (#329)

* move request_builder into internal pkg (#304)

* add some test for internal.RequestBuilder

* add a test for openai.GetEngine
This commit is contained in:
渡邉祐一 / Yuichi Watanabe
2023-05-31 17:01:42 +09:00
committed by GitHub
parent 62eb4beed2
commit 61ba5f3369
16 changed files with 273 additions and 208 deletions

View File

@@ -77,7 +77,7 @@ func (c *Client) CreateChatCompletion(
return return
} }
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
if err != nil { if err != nil {
return return
} }

View File

@@ -15,7 +15,7 @@ import (
type Client struct { type Client struct {
config ClientConfig config ClientConfig
requestBuilder requestBuilder requestBuilder utils.RequestBuilder
createFormBuilder func(io.Writer) utils.FormBuilder createFormBuilder func(io.Writer) utils.FormBuilder
} }
@@ -29,7 +29,7 @@ func NewClient(authToken string) *Client {
func NewClientWithConfig(config ClientConfig) *Client { func NewClientWithConfig(config ClientConfig) *Client {
return &Client{ return &Client{
config: config, config: config,
requestBuilder: newRequestBuilder(), requestBuilder: utils.NewRequestBuilder(),
createFormBuilder: func(body io.Writer) utils.FormBuilder { createFormBuilder: func(body io.Writer) utils.FormBuilder {
return utils.NewFormBuilder(body) return utils.NewFormBuilder(body)
}, },
@@ -135,7 +135,7 @@ func (c *Client) newStreamRequest(
urlSuffix string, urlSuffix string,
body any, body any,
model string) (*http.Request, error) { model string) (*http.Request, error) {
req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix, model), body) req, err := c.requestBuilder.Build(ctx, method, c.fullURL(urlSuffix, model), body)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -2,13 +2,24 @@ package openai //nolint:testpackage // testing private field
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"testing" "testing"
"github.com/sashabaranov/go-openai/internal/test"
) )
var errTestRequestBuilderFailed = errors.New("test request builder failed")
type failingRequestBuilder struct{}
func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any) (*http.Request, error) {
return nil, errTestRequestBuilderFailed
}
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
const mockToken = "mock token" const mockToken = "mock token"
client := NewClient(mockToken) client := NewClient(mockToken)
@@ -145,3 +156,141 @@ func TestHandleErrorResp(t *testing.T) {
}) })
} }
} }
func TestClientReturnsRequestBuilderErrors(t *testing.T) {
var err error
ts := test.NewTestServer().OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client.requestBuilder = &failingRequestBuilder{}
ctx := context.Background()
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.ListFineTunes(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CancelFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.GetFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.DeleteFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.ListFineTuneEvents(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.Moderations(ctx, ModerationRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.Edits(ctx, EditsRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateImage(ctx, ImageRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
err = client.DeleteFile(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.GetFile(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.ListFiles(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.ListEngines(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.GetEngine(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.ListModels(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
}
func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) {
var err error
ts := test.NewTestServer().OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client.requestBuilder = &failingRequestBuilder{}
ctx := context.Background()
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1})
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1})
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
}

View File

@@ -155,7 +155,7 @@ func (c *Client) CreateCompletion(
return return
} }
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
if err != nil { if err != nil {
return return
} }

View File

@@ -32,7 +32,7 @@ type EditsResponse struct {
// Perform an API call to the Edits endpoint. // Perform an API call to the Edits endpoint.
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request) req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request)
if err != nil { if err != nil {
return return
} }

View File

@@ -132,7 +132,7 @@ type EmbeddingRequest struct {
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
// https://beta.openai.com/docs/api-reference/embeddings/create // https://beta.openai.com/docs/api-reference/embeddings/create
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request) req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request)
if err != nil { if err != nil {
return return
} }

View File

@@ -22,7 +22,7 @@ type EnginesList struct {
// ListEngines Lists the currently available engines, and provides basic // ListEngines Lists the currently available engines, and provides basic
// information about each option such as the owner and availability. // information about each option such as the owner and availability.
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) { func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/engines"), nil) req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/engines"), nil)
if err != nil { if err != nil {
return return
} }
@@ -38,7 +38,7 @@ func (c *Client) GetEngine(
engineID string, engineID string,
) (engine Engine, err error) { ) (engine Engine, err error) {
urlSuffix := fmt.Sprintf("/engines/%s", engineID) urlSuffix := fmt.Sprintf("/engines/%s", engineID)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil { if err != nil {
return return
} }

34
engines_test.go Normal file
View File

@@ -0,0 +1,34 @@
package openai_test
import (
"context"
"encoding/json"
"fmt"
"net/http"
"testing"
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
)
// TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server.
func TestGetEngine(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(Engine{})
fmt.Fprintln(w, string(resBytes))
})
// create the test server
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
_, err := client.GetEngine(ctx, "text-davinci-003")
checks.NoError(t, err, "GetEngine error")
}

View File

@@ -70,7 +70,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
// DeleteFile deletes an existing file. // DeleteFile deletes an existing file.
func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil) req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
if err != nil { if err != nil {
return return
} }
@@ -82,7 +82,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
// ListFiles Lists the currently available files, // ListFiles Lists the currently available files,
// and provides basic information about each file such as the file name and purpose. // and provides basic information about each file such as the file name and purpose.
func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/files"), nil) req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/files"), nil)
if err != nil { if err != nil {
return return
} }
@@ -95,7 +95,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
// such as the file name and purpose. // such as the file name and purpose.
func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) { func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) {
urlSuffix := fmt.Sprintf("/files/%s", fileID) urlSuffix := fmt.Sprintf("/files/%s", fileID)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil { if err != nil {
return return
} }

View File

@@ -68,7 +68,7 @@ type FineTuneDeleteResponse struct {
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
urlSuffix := "/fine-tunes" urlSuffix := "/fine-tunes"
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
if err != nil { if err != nil {
return return
} }
@@ -79,7 +79,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r
// CancelFineTune cancel a fine-tune job. // CancelFineTune cancel a fine-tune job.
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil) req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil)
if err != nil { if err != nil {
return return
} }
@@ -89,7 +89,7 @@ func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (respons
} }
func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) { func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil) req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil)
if err != nil { if err != nil {
return return
} }
@@ -100,7 +100,7 @@ func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err
func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID) urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID)
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil { if err != nil {
return return
} }
@@ -110,7 +110,7 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F
} }
func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) { func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil) req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil)
if err != nil { if err != nil {
return return
} }
@@ -120,7 +120,7 @@ func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (respons
} }
func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) { func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil) req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil)
if err != nil { if err != nil {
return return
} }

View File

@@ -44,7 +44,7 @@ type ImageResponseDataInner struct {
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
urlSuffix := "/images/generations" urlSuffix := "/images/generations"
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
if err != nil { if err != nil {
return return
} }

View File

@@ -4,25 +4,23 @@ import (
"bytes" "bytes"
"context" "context"
"net/http" "net/http"
utils "github.com/sashabaranov/go-openai/internal"
) )
type requestBuilder interface { type RequestBuilder interface {
build(ctx context.Context, method, url string, request any) (*http.Request, error) Build(ctx context.Context, method, url string, request any) (*http.Request, error)
} }
type httpRequestBuilder struct { type HTTPRequestBuilder struct {
marshaller utils.Marshaller marshaller Marshaller
} }
func newRequestBuilder() *httpRequestBuilder { func NewRequestBuilder() *HTTPRequestBuilder {
return &httpRequestBuilder{ return &HTTPRequestBuilder{
marshaller: &utils.JSONMarshaller{}, marshaller: &JSONMarshaller{},
} }
} }
func (b *httpRequestBuilder) build(ctx context.Context, method, url string, request any) (*http.Request, error) { func (b *HTTPRequestBuilder) Build(ctx context.Context, method, url string, request any) (*http.Request, error) {
if request == nil { if request == nil {
return http.NewRequestWithContext(ctx, method, url, nil) return http.NewRequestWithContext(ctx, method, url, nil)
} }

View File

@@ -0,0 +1,61 @@
package openai //nolint:testpackage // testing private field
import (
"bytes"
"context"
"errors"
"net/http"
"reflect"
"testing"
)
var errTestMarshallerFailed = errors.New("test marshaller failed")
type failingMarshaller struct{}
func (*failingMarshaller) Marshal(_ any) ([]byte, error) {
return []byte{}, errTestMarshallerFailed
}
func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) {
builder := HTTPRequestBuilder{
marshaller: &failingMarshaller{},
}
_, err := builder.Build(context.Background(), "", "", struct{}{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
}
func TestRequestBuilderReturnsRequest(t *testing.T) {
b := NewRequestBuilder()
var (
ctx = context.Background()
method = http.MethodPost
url = "/foo"
request = map[string]string{"foo": "bar"}
reqBytes, _ = b.marshaller.Marshal(request)
want, _ = http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(reqBytes))
)
got, _ := b.Build(ctx, method, url, request)
if !reflect.DeepEqual(got.Body, want.Body) ||
!reflect.DeepEqual(got.URL, want.URL) ||
!reflect.DeepEqual(got.Method, want.Method) {
t.Errorf("Build() got = %v, want %v", got, want)
}
}
func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) {
var (
ctx = context.Background()
method = http.MethodGet
url = "/foo"
want, _ = http.NewRequestWithContext(ctx, method, url, nil)
)
b := NewRequestBuilder()
got, _ := b.Build(ctx, method, url, nil)
if !reflect.DeepEqual(got, want) {
t.Errorf("Build() got = %v, want %v", got, want)
}
}

View File

@@ -40,7 +40,7 @@ type ModelsList struct {
// ListModels Lists the currently available models, // ListModels Lists the currently available models,
// and provides basic information about each model such as the model id and parent. // and provides basic information about each model such as the model id and parent.
func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) { func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/models"), nil) req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/models"), nil)
if err != nil { if err != nil {
return return
} }

View File

@@ -63,7 +63,7 @@ type ModerationResponse struct {
// Moderations — perform a moderation api call over a string. // Moderations — perform a moderation api call over a string.
// Input can be an array or slice but a string will reduce the complexity. // Input can be an array or slice but a string will reduce the complexity.
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) { func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request) req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request)
if err != nil { if err != nil {
return return
} }

View File

@@ -1,177 +0,0 @@
package openai //nolint:testpackage // testing private field
import (
"github.com/sashabaranov/go-openai/internal/test"
"context"
"errors"
"net/http"
"testing"
)
var (
errTestMarshallerFailed = errors.New("test marshaller failed")
errTestRequestBuilderFailed = errors.New("test request builder failed")
)
type (
failingRequestBuilder struct{}
failingMarshaller struct{}
)
func (*failingMarshaller) Marshal(_ any) ([]byte, error) {
return []byte{}, errTestMarshallerFailed
}
func (*failingRequestBuilder) build(_ context.Context, _, _ string, _ any) (*http.Request, error) {
return nil, errTestRequestBuilderFailed
}
func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) {
builder := httpRequestBuilder{
marshaller: &failingMarshaller{},
}
_, err := builder.build(context.Background(), "", "", struct{}{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
}
func TestClientReturnsRequestBuilderErrors(t *testing.T) {
var err error
ts := test.NewTestServer().OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client.requestBuilder = &failingRequestBuilder{}
ctx := context.Background()
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.ListFineTunes(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CancelFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.GetFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.DeleteFineTune(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.ListFineTuneEvents(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.Moderations(ctx, ModerationRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.Edits(ctx, EditsRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateImage(ctx, ImageRequest{})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
err = client.DeleteFile(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.GetFile(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.ListFiles(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.ListEngines(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.GetEngine(ctx, "")
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.ListModels(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""})
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
}
func TestReturnsRequestBuilderErrorsAddtion(t *testing.T) {
var err error
ts := test.NewTestServer().OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client.requestBuilder = &failingRequestBuilder{}
ctx := context.Background()
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1})
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1})
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
}