diff --git a/stream_reader.go b/stream_reader.go index a9940b0..3416198 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -30,43 +30,52 @@ func (stream *streamReader[T]) Recv() (response T, err error) { return } + response, err = stream.processLines() + return +} + +func (stream *streamReader[T]) processLines() (T, error) { var emptyMessagesCount uint -waitForData: - line, err := stream.reader.ReadBytes('\n') - if err != nil { - respErr := stream.unmarshalError() - if respErr != nil { - err = fmt.Errorf("error, %w", respErr.Error) - } - return - } - - var headerData = []byte("data: ") - line = bytes.TrimSpace(line) - if !bytes.HasPrefix(line, headerData) { - if writeErr := stream.errAccumulator.Write(line); writeErr != nil { - err = writeErr - return - } - emptyMessagesCount++ - if emptyMessagesCount > stream.emptyMessagesLimit { - err = ErrTooManyEmptyStreamMessages - return + for { + rawLine, readErr := stream.reader.ReadBytes('\n') + if readErr != nil { + respErr := stream.unmarshalError() + if respErr != nil { + return *new(T), fmt.Errorf("error, %w", respErr.Error) + } + return *new(T), readErr } - goto waitForData - } + var headerData = []byte("data: ") + noSpaceLine := bytes.TrimSpace(rawLine) + if !bytes.HasPrefix(noSpaceLine, headerData) { + writeErr := stream.errAccumulator.Write(noSpaceLine) + if writeErr != nil { + return *new(T), writeErr + } + emptyMessagesCount++ + if emptyMessagesCount > stream.emptyMessagesLimit { + return *new(T), ErrTooManyEmptyStreamMessages + } - line = bytes.TrimPrefix(line, headerData) - if string(line) == "[DONE]" { - stream.isFinished = true - err = io.EOF - return - } + continue + } - err = stream.unmarshaler.Unmarshal(line, &response) - return + noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData) + if string(noPrefixLine) == "[DONE]" { + stream.isFinished = true + return *new(T), io.EOF + } + + var response T + unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response) + if unmarshalErr != nil { + return *new(T), unmarshalErr + } + + return response, nil + } } func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) { diff --git a/stream_test.go b/stream_test.go index 589fc9e..0faa212 100644 --- a/stream_test.go +++ b/stream_test.go @@ -2,6 +2,7 @@ package openai_test import ( "context" + "encoding/json" "errors" "io" "net/http" @@ -217,6 +218,165 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + // Totally 301 empty messages (300 is the limit) + for i := 0; i < 299; i++ { + dataBytes = append(dataBytes, '\n') + } + + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data = `{"id":"2","object":"completion","created":1598069255,"model":"text-davinci-002","choices":[{"text":"response2","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("event: done\n")...) + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + })) + defer server.Close() + + // Client portion of the test + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = server.URL + "/v1" + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, + } + + client := NewClientWithConfig(config) + ctx := context.Background() + + request := CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: "text-davinci-002", + MaxTokens: 10, + Stream: true, + } + + stream, err := client.CreateCompletionStream(ctx, request) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + _, _ = stream.Recv() + _, streamErr := stream.Recv() + if !errors.Is(streamErr, ErrTooManyEmptyStreamMessages) { + t.Errorf("TestCreateCompletionStreamTooManyEmptyStreamMessagesError did not return ErrTooManyEmptyStreamMessages") + } +} + +func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + // Stream is terminated without sending "done" message + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + })) + defer server.Close() + + // Client portion of the test + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = server.URL + "/v1" + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, + } + + client := NewClientWithConfig(config) + ctx := context.Background() + + request := CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: "text-davinci-002", + MaxTokens: 10, + Stream: true, + } + + stream, err := client.CreateCompletionStream(ctx, request) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + _, _ = stream.Recv() + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("TestCreateCompletionStreamUnexpectedTerminatedError did not return io.EOF") + } +} + +func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + // Send broken json + dataBytes = append(dataBytes, []byte("event: message\n")...) + data = `{"id":"2","object":"completion","created":1598069255,"model":` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("event: done\n")...) + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + })) + defer server.Close() + + // Client portion of the test + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = server.URL + "/v1" + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, + } + + client := NewClientWithConfig(config) + ctx := context.Background() + + request := CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: "text-davinci-002", + MaxTokens: 10, + Stream: true, + } + + stream, err := client.CreateCompletionStream(ctx, request) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + _, _ = stream.Recv() + _, streamErr := stream.Recv() + var syntaxError *json.SyntaxError + if !errors.As(streamErr, &syntaxError) { + t.Errorf("TestCreateCompletionStreamBrokenJSONError did not return json.SyntaxError") + } +} + // Helper funcs. func compareResponses(r1, r2 CompletionResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {