fix: stream return EOF when openai return error (#184)
* fix: stream return EOF when openai return error * perf: add error accumulator * fix: golangci-lint * fix: unmarshal error possibly null * fix: error accumulator * test: error accumulator use interface and add test code * test: error accumulator add test code * refactor: use stream reader to re-use stream code * refactor: stream reader use generics
This commit is contained in:
@@ -100,6 +100,68 @@ func TestCreateCompletionStream(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateCompletionStreamError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
// Send test responses
|
||||
dataBytes := []byte{}
|
||||
dataStr := []string{
|
||||
`{`,
|
||||
`"error": {`,
|
||||
`"message": "Incorrect API key provided: sk-***************************************",`,
|
||||
`"type": "invalid_request_error",`,
|
||||
`"param": null,`,
|
||||
`"code": "invalid_api_key"`,
|
||||
`}`,
|
||||
`}`,
|
||||
}
|
||||
for _, str := range dataStr {
|
||||
dataBytes = append(dataBytes, []byte(str+"\n")...)
|
||||
}
|
||||
|
||||
_, err := w.Write(dataBytes)
|
||||
if err != nil {
|
||||
t.Errorf("Write error: %s", err)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Client portion of the test
|
||||
config := DefaultConfig(test.GetTestToken())
|
||||
config.BaseURL = server.URL + "/v1"
|
||||
config.HTTPClient.Transport = &tokenRoundTripper{
|
||||
test.GetTestToken(),
|
||||
http.DefaultTransport,
|
||||
}
|
||||
|
||||
client := NewClientWithConfig(config)
|
||||
ctx := context.Background()
|
||||
|
||||
request := CompletionRequest{
|
||||
MaxTokens: 5,
|
||||
Model: GPT3Dot5Turbo,
|
||||
Prompt: "Hello!",
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
stream, err := client.CreateCompletionStream(ctx, request)
|
||||
if err != nil {
|
||||
t.Errorf("CreateCompletionStream returned error: %v", err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
_, streamErr := stream.Recv()
|
||||
if streamErr == nil {
|
||||
t.Errorf("stream.Recv() did not return error")
|
||||
}
|
||||
var apiErr *APIError
|
||||
if !errors.As(streamErr, &apiErr) {
|
||||
t.Errorf("stream.Recv() did not return APIError")
|
||||
}
|
||||
t.Logf("%+v\n", apiErr)
|
||||
}
|
||||
|
||||
// A "tokenRoundTripper" is a struct that implements the RoundTripper
|
||||
// interface, specifically to handle the authentication token by adding a token
|
||||
// to the request header. We need this because the API requires that each
|
||||
|
||||
Reference in New Issue
Block a user