diff --git a/chat_stream_test.go b/chat_stream_test.go index c3cb9f3..5fc70b0 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -178,6 +178,45 @@ func TestCreateChatCompletionStreamError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + _, streamErr := stream.Recv() + checks.HasError(t, streamErr, "stream.Recv() did not return error") + + var apiErr *APIError + if !errors.As(streamErr, &apiErr) { + t.Errorf("stream.Recv() did not return APIError") + } + t.Logf("%+v\n", apiErr) +} + func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/stream_reader.go b/stream_reader.go index 3416198..87e59e0 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -10,6 +10,11 @@ import ( utils "github.com/sashabaranov/go-openai/internal" ) +var ( + headerData = []byte("data: ") + errorPrefix = []byte(`data: {"error":`) +) + type streamable interface { ChatCompletionStreamResponse | CompletionResponse } @@ -34,12 +39,16 @@ func (stream *streamReader[T]) Recv() (response T, err error) { return } +//nolint:gocognit func (stream *streamReader[T]) processLines() (T, error) { - var emptyMessagesCount uint + var ( + emptyMessagesCount uint + hasErrorPrefix bool + ) for { rawLine, readErr := stream.reader.ReadBytes('\n') - if readErr != nil { + if readErr != nil || hasErrorPrefix { respErr := stream.unmarshalError() if respErr != nil { return *new(T), fmt.Errorf("error, %w", respErr.Error) @@ -47,9 +56,14 @@ func (stream *streamReader[T]) processLines() (T, error) { return *new(T), readErr } - var headerData = []byte("data: ") noSpaceLine := bytes.TrimSpace(rawLine) - if !bytes.HasPrefix(noSpaceLine, headerData) { + if bytes.HasPrefix(noSpaceLine, errorPrefix) { + hasErrorPrefix = true + } + if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix { + if hasErrorPrefix { + noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData) + } writeErr := stream.errAccumulator.Write(noSpaceLine) if writeErr != nil { return *new(T), writeErr