fix: chat stream resp error (#259)
This commit is contained in:
17
api_test.go
17
api_test.go
@@ -1,16 +1,15 @@
|
|||||||
package openai_test
|
package openai_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
. "github.com/sashabaranov/go-openai"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAPI(t *testing.T) {
|
func TestAPI(t *testing.T) {
|
||||||
@@ -119,8 +118,8 @@ func TestAPIError(t *testing.T) {
|
|||||||
t.Fatalf("Error is not an APIError: %+v", err)
|
t.Fatalf("Error is not an APIError: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if apiErr.StatusCode != 401 {
|
if apiErr.HTTPStatusCode != 401 {
|
||||||
t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode)
|
t.Fatalf("Unexpected API error status code: %d", apiErr.HTTPStatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch v := apiErr.Code.(type) {
|
switch v := apiErr.Code.(type) {
|
||||||
@@ -239,8 +238,8 @@ func TestRequestError(t *testing.T) {
|
|||||||
t.Fatalf("Error is not a RequestError: %+v", err)
|
t.Fatalf("Error is not a RequestError: %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if reqErr.StatusCode != 418 {
|
if reqErr.HTTPStatusCode != 418 {
|
||||||
t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode)
|
t.Fatalf("Unexpected request error status code: %d", reqErr.HTTPStatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
if reqErr.Unwrap() == nil {
|
if reqErr.Unwrap() == nil {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChatCompletionStreamChoiceDelta struct {
|
type ChatCompletionStreamChoiceDelta struct {
|
||||||
@@ -53,6 +54,9 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest {
|
||||||
|
return nil, c.handleErrorResp(resp)
|
||||||
|
}
|
||||||
|
|
||||||
stream = &ChatCompletionStream{
|
stream = &ChatCompletionStream{
|
||||||
streamReader: &streamReader[ChatCompletionStreamResponse]{
|
streamReader: &streamReader[ChatCompletionStreamResponse]{
|
||||||
|
|||||||
@@ -204,6 +204,57 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
|
|||||||
t.Logf("%+v\n", apiErr)
|
t.Logf("%+v\n", apiErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
|
||||||
|
server := test.NewTestServer()
|
||||||
|
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(429)
|
||||||
|
|
||||||
|
// Send test responses
|
||||||
|
dataBytes := []byte(`{"error":{` +
|
||||||
|
`"message": "You are sending requests too quickly.",` +
|
||||||
|
`"type":"rate_limit_reached",` +
|
||||||
|
`"param":null,` +
|
||||||
|
`"code":"rate_limit_reached"}}`)
|
||||||
|
|
||||||
|
_, err := w.Write(dataBytes)
|
||||||
|
checks.NoError(t, err, "Write error")
|
||||||
|
})
|
||||||
|
ts := server.OpenAITestServer()
|
||||||
|
ts.Start()
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
// Client portion of the test
|
||||||
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
config.HTTPClient.Transport = &tokenRoundTripper{
|
||||||
|
test.GetTestToken(),
|
||||||
|
http.DefaultTransport,
|
||||||
|
}
|
||||||
|
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
request := ChatCompletionRequest{
|
||||||
|
MaxTokens: 5,
|
||||||
|
Model: GPT3Dot5Turbo,
|
||||||
|
Messages: []ChatCompletionMessage{
|
||||||
|
{
|
||||||
|
Role: ChatMessageRoleUser,
|
||||||
|
Content: "Hello!",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Stream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var apiErr *APIError
|
||||||
|
_, err := client.CreateChatCompletionStream(ctx, request)
|
||||||
|
if !errors.As(err, &apiErr) {
|
||||||
|
t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError")
|
||||||
|
}
|
||||||
|
t.Logf("%+v\n", apiErr)
|
||||||
|
}
|
||||||
|
|
||||||
// Helper funcs.
|
// Helper funcs.
|
||||||
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
|
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
|
||||||
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
|
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
|
||||||
|
|||||||
26
client.go
26
client.go
@@ -72,17 +72,7 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
|
|||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
|
if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
|
||||||
var errRes ErrorResponse
|
return c.handleErrorResp(res)
|
||||||
err = json.NewDecoder(res.Body).Decode(&errRes)
|
|
||||||
if err != nil || errRes.Error == nil {
|
|
||||||
reqErr := RequestError{
|
|
||||||
StatusCode: res.StatusCode,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
return fmt.Errorf("error, %w", &reqErr)
|
|
||||||
}
|
|
||||||
errRes.Error.StatusCode = res.StatusCode
|
|
||||||
return fmt.Errorf("error, status code: %d, message: %w", res.StatusCode, errRes.Error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if v != nil {
|
if v != nil {
|
||||||
@@ -132,3 +122,17 @@ func (c *Client) newStreamRequest(
|
|||||||
}
|
}
|
||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) handleErrorResp(resp *http.Response) error {
|
||||||
|
var errRes ErrorResponse
|
||||||
|
err := json.NewDecoder(resp.Body).Decode(&errRes)
|
||||||
|
if err != nil || errRes.Error == nil {
|
||||||
|
reqErr := RequestError{
|
||||||
|
HTTPStatusCode: resp.StatusCode,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
return fmt.Errorf("error, %w", &reqErr)
|
||||||
|
}
|
||||||
|
errRes.Error.HTTPStatusCode = resp.StatusCode
|
||||||
|
return fmt.Errorf("error, status code: %d, message: %w", resp.StatusCode, errRes.Error)
|
||||||
|
}
|
||||||
|
|||||||
6
error.go
6
error.go
@@ -11,12 +11,12 @@ type APIError struct {
|
|||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Param *string `json:"param,omitempty"`
|
Param *string `json:"param,omitempty"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
StatusCode int `json:"-"`
|
HTTPStatusCode int `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequestError provides informations about generic request errors.
|
// RequestError provides informations about generic request errors.
|
||||||
type RequestError struct {
|
type RequestError struct {
|
||||||
StatusCode int
|
HTTPStatusCode int
|
||||||
Err error
|
Err error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,7 +73,7 @@ func (e *RequestError) Error() string {
|
|||||||
if e.Err != nil {
|
if e.Err != nil {
|
||||||
return e.Err.Error()
|
return e.Err.Error()
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("status code %d", e.StatusCode)
|
return fmt.Sprintf("status code %d", e.HTTPStatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *RequestError) Unwrap() error {
|
func (e *RequestError) Unwrap() error {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
@@ -71,7 +72,11 @@ func TestErrorByteWriteErrors(t *testing.T) {
|
|||||||
|
|
||||||
func TestErrorAccumulatorWriteErrors(t *testing.T) {
|
func TestErrorAccumulatorWriteErrors(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
ts := test.NewTestServer().OpenAITestServer()
|
server := test.NewTestServer()
|
||||||
|
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
http.Error(w, "error", 200)
|
||||||
|
})
|
||||||
|
ts := server.OpenAITestServer()
|
||||||
ts.Start()
|
ts.Start()
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -43,6 +44,9 @@ func (c *Client) CreateCompletionStream(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest {
|
||||||
|
return nil, c.handleErrorResp(resp)
|
||||||
|
}
|
||||||
|
|
||||||
stream = &CompletionStream{
|
stream = &CompletionStream{
|
||||||
streamReader: &streamReader[CompletionResponse]{
|
streamReader: &streamReader[CompletionResponse]{
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
package openai_test
|
package openai_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCompletionsStreamWrongModel(t *testing.T) {
|
func TestCompletionsStreamWrongModel(t *testing.T) {
|
||||||
@@ -171,6 +171,52 @@ func TestCreateCompletionStreamError(t *testing.T) {
|
|||||||
t.Logf("%+v\n", apiErr)
|
t.Logf("%+v\n", apiErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateCompletionStreamRateLimitError(t *testing.T) {
|
||||||
|
server := test.NewTestServer()
|
||||||
|
server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(429)
|
||||||
|
|
||||||
|
// Send test responses
|
||||||
|
dataBytes := []byte(`{"error":{` +
|
||||||
|
`"message": "You are sending requests too quickly.",` +
|
||||||
|
`"type":"rate_limit_reached",` +
|
||||||
|
`"param":null,` +
|
||||||
|
`"code":"rate_limit_reached"}}`)
|
||||||
|
|
||||||
|
_, err := w.Write(dataBytes)
|
||||||
|
checks.NoError(t, err, "Write error")
|
||||||
|
})
|
||||||
|
ts := server.OpenAITestServer()
|
||||||
|
ts.Start()
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
// Client portion of the test
|
||||||
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
config.HTTPClient.Transport = &tokenRoundTripper{
|
||||||
|
test.GetTestToken(),
|
||||||
|
http.DefaultTransport,
|
||||||
|
}
|
||||||
|
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
request := CompletionRequest{
|
||||||
|
MaxTokens: 5,
|
||||||
|
Model: GPT3Ada,
|
||||||
|
Prompt: "Hello!",
|
||||||
|
Stream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var apiErr *APIError
|
||||||
|
_, err := client.CreateCompletionStream(ctx, request)
|
||||||
|
if !errors.As(err, &apiErr) {
|
||||||
|
t.Errorf("TestCreateCompletionStreamRateLimitError did not return APIError")
|
||||||
|
}
|
||||||
|
t.Logf("%+v\n", apiErr)
|
||||||
|
}
|
||||||
|
|
||||||
// A "tokenRoundTripper" is a struct that implements the RoundTripper
|
// A "tokenRoundTripper" is a struct that implements the RoundTripper
|
||||||
// interface, specifically to handle the authentication token by adding a token
|
// interface, specifically to handle the authentication token by adding a token
|
||||||
// to the request header. We need this because the API requires that each
|
// to the request header. We need this because the API requires that each
|
||||||
|
|||||||
Reference in New Issue
Block a user