fix: stream returns EOF when OpenAI returns an error (#184)

* fix: stream return EOF when openai return error

* perf: add error accumulator

* fix: golangci-lint

* fix: unmarshal error possibly null

* fix: error accumulator

* test: error accumulator use interface and add test code

* test: error accumulator add test code

* refactor: use stream reader to re-use stream code

* refactor: stream reader use generics
Author: Liu Shuang
Date: 2023-03-22 13:32:47 +08:00
Committed by: GitHub
Parent: aa149c1bf8
Commit: a5a945ad14
8 changed files with 372 additions and 107 deletions
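For context, here is a hedged caller-side sketch of what this fix changes in practice. It mirrors the tests added in this commit; the client setup and the public import path are assumed from the go-openai README and are not part of this diff.

package main

import (
	"context"
	"errors"
	"log"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-token")
	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
		Model:    openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{{Role: openai.ChatMessageRoleUser, Content: "Hello!"}},
		Stream:   true,
	})
	if err != nil {
		log.Fatal(err)
	}
	defer stream.Close()

	_, err = stream.Recv()
	// When the server replies with an error JSON body instead of SSE data
	// (as the test servers in this commit do), Recv now surfaces the wrapped
	// *APIError instead of silently returning io.EOF.
	var apiErr *openai.APIError
	if errors.As(err, &apiErr) {
		log.Printf("API error: %s", apiErr.Message)
	}
}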


@@ -2,11 +2,7 @@ package openai
 import (
 	"bufio"
-	"bytes"
 	"context"
-	"encoding/json"
-	"io"
 	"net/http"
 )
 type ChatCompletionStreamChoiceDelta struct {
@@ -30,52 +26,7 @@ type ChatCompletionStreamResponse struct {
 // ChatCompletionStream
 // Note: Perhaps it is more elegant to abstract Stream using generics.
 type ChatCompletionStream struct {
-	emptyMessagesLimit uint
-	isFinished bool
-	reader *bufio.Reader
-	response *http.Response
-}
-func (stream *ChatCompletionStream) Recv() (response ChatCompletionStreamResponse, err error) {
-	if stream.isFinished {
-		err = io.EOF
-		return
-	}
-	var emptyMessagesCount uint
-waitForData:
-	line, err := stream.reader.ReadBytes('\n')
-	if err != nil {
-		return
-	}
-	var headerData = []byte("data: ")
-	line = bytes.TrimSpace(line)
-	if !bytes.HasPrefix(line, headerData) {
-		emptyMessagesCount++
-		if emptyMessagesCount > stream.emptyMessagesLimit {
-			err = ErrTooManyEmptyStreamMessages
-			return
-		}
-		goto waitForData
-	}
-	line = bytes.TrimPrefix(line, headerData)
-	if string(line) == "[DONE]" {
-		stream.isFinished = true
-		err = io.EOF
-		return
-	}
-	err = json.Unmarshal(line, &response)
-	return
-}
-func (stream *ChatCompletionStream) Close() {
-	stream.response.Body.Close()
+	*streamReader[ChatCompletionStreamResponse]
 }
 // CreateChatCompletionStream — API call to create a chat completion w/ streaming
@@ -98,9 +49,13 @@ func (c *Client) CreateChatCompletionStream(
 	}
 	stream = &ChatCompletionStream{
-		emptyMessagesLimit: c.config.EmptyMessagesLimit,
-		reader: bufio.NewReader(resp.Body),
-		response: resp,
+		streamReader: &streamReader[ChatCompletionStreamResponse]{
+			emptyMessagesLimit: c.config.EmptyMessagesLimit,
+			reader: bufio.NewReader(resp.Body),
+			response: resp,
+			errAccumulator: newErrorAccumulator(),
+			unmarshaler: &jsonUnmarshaler{},
+		},
 	}
 	return
 }


@@ -123,6 +123,73 @@ func TestCreateChatCompletionStream(t *testing.T) {
	}
}
func TestCreateChatCompletionStreamError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		// Send test responses
		dataBytes := []byte{}
		dataStr := []string{
			`{`,
			`"error": {`,
			`"message": "Incorrect API key provided: sk-***************************************",`,
			`"type": "invalid_request_error",`,
			`"param": null,`,
			`"code": "invalid_api_key"`,
			`}`,
			`}`,
		}
		for _, str := range dataStr {
			dataBytes = append(dataBytes, []byte(str+"\n")...)
		}
		_, err := w.Write(dataBytes)
		if err != nil {
			t.Errorf("Write error: %s", err)
		}
	}))
	defer server.Close()
	// Client portion of the test
	config := DefaultConfig(test.GetTestToken())
	config.BaseURL = server.URL + "/v1"
	config.HTTPClient.Transport = &tokenRoundTripper{
		test.GetTestToken(),
		http.DefaultTransport,
	}
	client := NewClientWithConfig(config)
	ctx := context.Background()
	request := ChatCompletionRequest{
		MaxTokens: 5,
		Model: GPT3Dot5Turbo,
		Messages: []ChatCompletionMessage{
			{
				Role: ChatMessageRoleUser,
				Content: "Hello!",
			},
		},
		Stream: true,
	}
	stream, err := client.CreateChatCompletionStream(ctx, request)
	if err != nil {
		t.Errorf("CreateCompletionStream returned error: %v", err)
	}
	defer stream.Close()
	_, streamErr := stream.Recv()
	if streamErr == nil {
		t.Errorf("stream.Recv() did not return error")
	}
	var apiErr *APIError
	if !errors.As(streamErr, &apiErr) {
		t.Errorf("stream.Recv() did not return APIError")
	}
	t.Logf("%+v\n", apiErr)
}
// Helper funcs.
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
	if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {

error_accumulator.go (new file)

@@ -0,0 +1,51 @@
package openai
import (
	"bytes"
	"fmt"
	"io"
)
type errorAccumulator interface {
	write(p []byte) error
	unmarshalError() (*ErrorResponse, error)
}
type errorBuffer interface {
	io.Writer
	Len() int
	Bytes() []byte
}
type errorAccumulate struct {
	buffer errorBuffer
	unmarshaler unmarshaler
}
func newErrorAccumulator() errorAccumulator {
	return &errorAccumulate{
		buffer: &bytes.Buffer{},
		unmarshaler: &jsonUnmarshaler{},
	}
}
func (e *errorAccumulate) write(p []byte) error {
	_, err := e.buffer.Write(p)
	if err != nil {
		return fmt.Errorf("error accumulator write error, %w", err)
	}
	return nil
}
func (e *errorAccumulate) unmarshalError() (*ErrorResponse, error) {
	var err error
	if e.buffer.Len() > 0 {
		var errRes ErrorResponse
		err = e.unmarshaler.unmarshal(e.buffer.Bytes(), &errRes)
		if err != nil {
			return nil, err
		}
		return &errRes, nil
	}
	return nil, err
}
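A minimal sketch of how this accumulator behaves in isolation, purely for illustration; these identifiers are unexported, so code like this has to live inside the openai package, as the tests below do.

// Non-"data: " lines from the stream are fed to write() as they arrive...
acc := newErrorAccumulator()
_ = acc.write([]byte(`{"error": {"message": "Incorrect API key provided",`))
_ = acc.write([]byte(`"type": "invalid_request_error"}}`))

// ...and once the stream ends, unmarshalError() decodes whatever accumulated
// into an ErrorResponse; it returns nil if nothing was written.
errRes, err := acc.unmarshalError()
if err == nil && errRes != nil {
	fmt.Println(errRes.Error.Message) // "Incorrect API key provided"
}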

error_accumulator_test.go (new file)

@@ -0,0 +1,90 @@
package openai //nolint:testpackage // testing private field
import (
	"bytes"
	"context"
	"errors"
	"testing"
	"github.com/sashabaranov/go-openai/internal/test"
)
var (
	errTestUnmarshalerFailed = errors.New("test unmarshaler failed")
	errTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed")
)
type (
	failingUnMarshaller struct{}
	failingErrorBuffer struct{}
)
func (b *failingErrorBuffer) Write(_ []byte) (n int, err error) {
	return 0, errTestErrorAccumulatorWriteFailed
}
func (b *failingErrorBuffer) Len() int {
	return 0
}
func (b *failingErrorBuffer) Bytes() []byte {
	return []byte{}
}
func (*failingUnMarshaller) unmarshal(_ []byte, _ any) error {
	return errTestUnmarshalerFailed
}
func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) {
	accumulator := &errorAccumulate{
		buffer: &bytes.Buffer{},
		unmarshaler: &failingUnMarshaller{},
	}
	err := accumulator.write([]byte("{"))
	if err != nil {
		t.Fatalf("%+v", err)
	}
	_, err = accumulator.unmarshalError()
	if !errors.Is(err, errTestUnmarshalerFailed) {
		t.Fatalf("Did not return error when unmarshaler failed: %v", err)
	}
}
func TestErrorByteWriteErrors(t *testing.T) {
	accumulator := &errorAccumulate{
		buffer: &failingErrorBuffer{},
		unmarshaler: &jsonUnmarshaler{},
	}
	err := accumulator.write([]byte("{"))
	if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
		t.Fatalf("Did not return error when write failed: %v", err)
	}
}
func TestErrorAccumulatorWriteErrors(t *testing.T) {
	var err error
	ts := test.NewTestServer().OpenAITestServer()
	ts.Start()
	defer ts.Close()
	config := DefaultConfig(test.GetTestToken())
	config.BaseURL = ts.URL + "/v1"
	client := NewClientWithConfig(config)
	ctx := context.Background()
	stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
	if err != nil {
		t.Fatal(err)
	}
	stream.errAccumulator = &errorAccumulate{
		buffer: &failingErrorBuffer{},
		unmarshaler: &jsonUnmarshaler{},
	}
	_, err = stream.Recv()
	if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
		t.Fatalf("Did not return error when write failed: %v", err)
	}
}


@@ -2,12 +2,8 @@ package openai
 import (
 	"bufio"
-	"bytes"
 	"context"
-	"encoding/json"
 	"errors"
-	"io"
 	"net/http"
 )
 var (
@@ -15,52 +11,7 @@ var (
 )
 type CompletionStream struct {
-	emptyMessagesLimit uint
-	isFinished bool
-	reader *bufio.Reader
-	response *http.Response
-}
-func (stream *CompletionStream) Recv() (response CompletionResponse, err error) {
-	if stream.isFinished {
-		err = io.EOF
-		return
-	}
-	var emptyMessagesCount uint
-waitForData:
-	line, err := stream.reader.ReadBytes('\n')
-	if err != nil {
-		return
-	}
-	var headerData = []byte("data: ")
-	line = bytes.TrimSpace(line)
-	if !bytes.HasPrefix(line, headerData) {
-		emptyMessagesCount++
-		if emptyMessagesCount > stream.emptyMessagesLimit {
-			err = ErrTooManyEmptyStreamMessages
-			return
-		}
-		goto waitForData
-	}
-	line = bytes.TrimPrefix(line, headerData)
-	if string(line) == "[DONE]" {
-		stream.isFinished = true
-		err = io.EOF
-		return
-	}
-	err = json.Unmarshal(line, &response)
-	return
-}
-func (stream *CompletionStream) Close() {
-	stream.response.Body.Close()
+	*streamReader[CompletionResponse]
 }
 // CreateCompletionStream — API call to create a completion w/ streaming
@@ -83,10 +34,13 @@ func (c *Client) CreateCompletionStream(
 	}
 	stream = &CompletionStream{
-		emptyMessagesLimit: c.config.EmptyMessagesLimit,
-		reader: bufio.NewReader(resp.Body),
-		response: resp,
+		streamReader: &streamReader[CompletionResponse]{
+			emptyMessagesLimit: c.config.EmptyMessagesLimit,
+			reader: bufio.NewReader(resp.Body),
+			response: resp,
+			errAccumulator: newErrorAccumulator(),
+			unmarshaler: &jsonUnmarshaler{},
+		},
 	}
 	return
 }

stream_reader.go (new file)

@@ -0,0 +1,71 @@
package openai
import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"net/http"
)
type streamable interface {
	ChatCompletionStreamResponse | CompletionResponse
}
type streamReader[T streamable] struct {
	emptyMessagesLimit uint
	isFinished bool
	reader *bufio.Reader
	response *http.Response
	errAccumulator errorAccumulator
	unmarshaler unmarshaler
}
func (stream *streamReader[T]) Recv() (response T, err error) {
	if stream.isFinished {
		err = io.EOF
		return
	}
	var emptyMessagesCount uint
waitForData:
	line, err := stream.reader.ReadBytes('\n')
	if err != nil {
		if errRes, _ := stream.errAccumulator.unmarshalError(); errRes != nil {
			err = fmt.Errorf("error, %w", errRes.Error)
		}
		return
	}
	var headerData = []byte("data: ")
	line = bytes.TrimSpace(line)
	if !bytes.HasPrefix(line, headerData) {
		if writeErr := stream.errAccumulator.write(line); writeErr != nil {
			err = writeErr
			return
		}
		emptyMessagesCount++
		if emptyMessagesCount > stream.emptyMessagesLimit {
			err = ErrTooManyEmptyStreamMessages
			return
		}
		goto waitForData
	}
	line = bytes.TrimPrefix(line, headerData)
	if string(line) == "[DONE]" {
		stream.isFinished = true
		err = io.EOF
		return
	}
	err = stream.unmarshaler.unmarshal(line, &response)
	return
}
func (stream *streamReader[T]) Close() {
	stream.response.Body.Close()
}
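A side note on the generics refactor: supporting another streaming payload later would only require extending the streamable union and embedding the shared reader, roughly as below. FooResponse and FooStream are hypothetical names for illustration and are not part of this commit.

// Hypothetical: extend the type-set constraint with the new payload type.
type streamable interface {
	ChatCompletionStreamResponse | CompletionResponse | FooResponse
}

// Hypothetical: the public stream type just embeds the generic reader,
// exactly as ChatCompletionStream and CompletionStream do in this commit.
type FooStream struct {
	*streamReader[FooResponse]
}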


@@ -100,6 +100,68 @@ func TestCreateCompletionStream(t *testing.T) {
	}
}
func TestCreateCompletionStreamError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		// Send test responses
		dataBytes := []byte{}
		dataStr := []string{
			`{`,
			`"error": {`,
			`"message": "Incorrect API key provided: sk-***************************************",`,
			`"type": "invalid_request_error",`,
			`"param": null,`,
			`"code": "invalid_api_key"`,
			`}`,
			`}`,
		}
		for _, str := range dataStr {
			dataBytes = append(dataBytes, []byte(str+"\n")...)
		}
		_, err := w.Write(dataBytes)
		if err != nil {
			t.Errorf("Write error: %s", err)
		}
	}))
	defer server.Close()
	// Client portion of the test
	config := DefaultConfig(test.GetTestToken())
	config.BaseURL = server.URL + "/v1"
	config.HTTPClient.Transport = &tokenRoundTripper{
		test.GetTestToken(),
		http.DefaultTransport,
	}
	client := NewClientWithConfig(config)
	ctx := context.Background()
	request := CompletionRequest{
		MaxTokens: 5,
		Model: GPT3Dot5Turbo,
		Prompt: "Hello!",
		Stream: true,
	}
	stream, err := client.CreateCompletionStream(ctx, request)
	if err != nil {
		t.Errorf("CreateCompletionStream returned error: %v", err)
	}
	defer stream.Close()
	_, streamErr := stream.Recv()
	if streamErr == nil {
		t.Errorf("stream.Recv() did not return error")
	}
	var apiErr *APIError
	if !errors.As(streamErr, &apiErr) {
		t.Errorf("stream.Recv() did not return APIError")
	}
	t.Logf("%+v\n", apiErr)
}
// A "tokenRoundTripper" is a struct that implements the RoundTripper
// interface, specifically to handle the authentication token by adding a token
// to the request header. We need this because the API requires that each

unmarshaler.go (new file)

@@ -0,0 +1,15 @@
package openai
import (
	"encoding/json"
)
type unmarshaler interface {
	unmarshal(data []byte, v any) error
}
type jsonUnmarshaler struct{}
func (jm *jsonUnmarshaler) unmarshal(data []byte, v any) error {
	return json.Unmarshal(data, v)
}