Add tests (#171)

* test models listing

* remove non-needed method

* test for .streamFinished

* add more error tests

* improve stream testing

* fix typo
This commit is contained in:
sashabaranov
2023-03-16 19:10:27 +04:00
committed by GitHub
parent fd44d3665e
commit a8acb5f63b
6 changed files with 23 additions and 5 deletions

View File

@@ -141,6 +141,9 @@ func TestAPIError(t *testing.T) {
if *apiErr.Code != "invalid_api_key" { if *apiErr.Code != "invalid_api_key" {
t.Fatalf("Unexpected API error code: %s", *apiErr.Code) t.Fatalf("Unexpected API error code: %s", *apiErr.Code)
} }
if apiErr.Error() == "" {
t.Fatal("Empty error message occured")
}
} }
func TestRequestError(t *testing.T) { func TestRequestError(t *testing.T) {
@@ -163,6 +166,10 @@ func TestRequestError(t *testing.T) {
if reqErr.StatusCode != 418 { if reqErr.StatusCode != 418 {
t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode) t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode)
} }
if reqErr.Unwrap() == nil {
t.Fatalf("Empty request error occured")
}
} }
// numTokens Returns the number of GPT-3 encoded tokens in the given text. // numTokens Returns the number of GPT-3 encoded tokens in the given text.

View File

@@ -78,10 +78,6 @@ func (stream *ChatCompletionStream) Close() {
stream.response.Body.Close() stream.response.Body.Close()
} }
func (stream *ChatCompletionStream) GetResponse() *http.Response {
return stream.response
}
// CreateChatCompletionStream — API call to create a chat completion w/ streaming // CreateChatCompletionStream — API call to create a chat completion w/ streaming
// support. It sets whether to stream back partial progress. If set, tokens will be // support. It sets whether to stream back partial progress. If set, tokens will be
// sent as data-only server-sent events as they become available, with the // sent as data-only server-sent events as they become available, with the

View File

@@ -116,6 +116,11 @@ func TestCreateChatCompletionStream(t *testing.T) {
if !errors.Is(streamErr, io.EOF) { if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
} }
_, streamErr = stream.Recv()
if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
}
} }
// Helper funcs. // Helper funcs.

View File

@@ -40,7 +40,7 @@ type ModelsList struct {
// ListModels Lists the currently available models, // ListModels Lists the currently available models,
// and provides basic information about each model such as the model id and parent. // and provides basic information about each model such as the model id and parent.
func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) { func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/models"), nil) req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/models"), nil)
if err != nil { if err != nil {
return return
} }

View File

@@ -140,4 +140,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
if !errors.Is(err, errTestRequestBuilderFailed) { if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err) t.Fatalf("Did not return error when request builder failed: %v", err)
} }
_, err = client.ListModels(ctx)
if !errors.Is(err, errTestRequestBuilderFailed) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
} }

View File

@@ -93,6 +93,11 @@ func TestCreateCompletionStream(t *testing.T) {
if !errors.Is(streamErr, io.EOF) { if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
} }
_, streamErr = stream.Recv()
if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
}
} }
// A "tokenRoundTripper" is a struct that implements the RoundTripper // A "tokenRoundTripper" is a struct that implements the RoundTripper