Refactor streamReader: Replace goto Statement with Loop in Recv Method (#339)

* test: Add tests for improved coverage before refactoring

This commit adds tests to improve coverage before refactoring
to ensure that the changes do not break anything.

* refactor: replace goto statement with loop

This commit introduces a refactor to improve the clarity of the control flow within the method.
The goto statement can sometimes make the code hard to understand and maintain, hence this refactor aims to resolve that.

* refactor: extract for-loop from Recv to another method

This commit improves code readability and maintainability
by making the Recv method simpler.
This commit is contained in:
Yuki Bobier Koshimizu
2023-06-09 00:31:25 +09:00
committed by GitHub
parent 6830e00406
commit b8c13e4c01
2 changed files with 200 additions and 31 deletions

View File

@@ -30,43 +30,52 @@ func (stream *streamReader[T]) Recv() (response T, err error) {
return
}
response, err = stream.processLines()
return
}
func (stream *streamReader[T]) processLines() (T, error) {
var emptyMessagesCount uint
waitForData:
line, err := stream.reader.ReadBytes('\n')
if err != nil {
respErr := stream.unmarshalError()
if respErr != nil {
err = fmt.Errorf("error, %w", respErr.Error)
}
return
}
var headerData = []byte("data: ")
line = bytes.TrimSpace(line)
if !bytes.HasPrefix(line, headerData) {
if writeErr := stream.errAccumulator.Write(line); writeErr != nil {
err = writeErr
return
}
emptyMessagesCount++
if emptyMessagesCount > stream.emptyMessagesLimit {
err = ErrTooManyEmptyStreamMessages
return
for {
rawLine, readErr := stream.reader.ReadBytes('\n')
if readErr != nil {
respErr := stream.unmarshalError()
if respErr != nil {
return *new(T), fmt.Errorf("error, %w", respErr.Error)
}
return *new(T), readErr
}
goto waitForData
}
var headerData = []byte("data: ")
noSpaceLine := bytes.TrimSpace(rawLine)
if !bytes.HasPrefix(noSpaceLine, headerData) {
writeErr := stream.errAccumulator.Write(noSpaceLine)
if writeErr != nil {
return *new(T), writeErr
}
emptyMessagesCount++
if emptyMessagesCount > stream.emptyMessagesLimit {
return *new(T), ErrTooManyEmptyStreamMessages
}
line = bytes.TrimPrefix(line, headerData)
if string(line) == "[DONE]" {
stream.isFinished = true
err = io.EOF
return
}
continue
}
err = stream.unmarshaler.Unmarshal(line, &response)
return
noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
if string(noPrefixLine) == "[DONE]" {
stream.isFinished = true
return *new(T), io.EOF
}
var response T
unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response)
if unmarshalErr != nil {
return *new(T), unmarshalErr
}
return response, nil
}
}
func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) {

View File

@@ -2,6 +2,7 @@ package openai_test
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
@@ -217,6 +218,165 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) {
t.Logf("%+v\n", apiErr)
}
func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
// Send test responses
dataBytes := []byte{}
dataBytes = append(dataBytes, []byte("event: message\n")...)
//nolint:lll
data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
// Totally 301 empty messages (300 is the limit)
for i := 0; i < 299; i++ {
dataBytes = append(dataBytes, '\n')
}
dataBytes = append(dataBytes, []byte("event: message\n")...)
//nolint:lll
data = `{"id":"2","object":"completion","created":1598069255,"model":"text-davinci-002","choices":[{"text":"response2","finish_reason":"max_tokens"}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
dataBytes = append(dataBytes, []byte("event: done\n")...)
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
}))
defer server.Close()
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}
client := NewClientWithConfig(config)
ctx := context.Background()
request := CompletionRequest{
Prompt: "Ex falso quodlibet",
Model: "text-davinci-002",
MaxTokens: 10,
Stream: true,
}
stream, err := client.CreateCompletionStream(ctx, request)
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
_, _ = stream.Recv()
_, streamErr := stream.Recv()
if !errors.Is(streamErr, ErrTooManyEmptyStreamMessages) {
t.Errorf("TestCreateCompletionStreamTooManyEmptyStreamMessagesError did not return ErrTooManyEmptyStreamMessages")
}
}
func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
// Send test responses
dataBytes := []byte{}
dataBytes = append(dataBytes, []byte("event: message\n")...)
//nolint:lll
data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
// Stream is terminated without sending "done" message
_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
}))
defer server.Close()
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}
client := NewClientWithConfig(config)
ctx := context.Background()
request := CompletionRequest{
Prompt: "Ex falso quodlibet",
Model: "text-davinci-002",
MaxTokens: 10,
Stream: true,
}
stream, err := client.CreateCompletionStream(ctx, request)
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
_, _ = stream.Recv()
_, streamErr := stream.Recv()
if !errors.Is(streamErr, io.EOF) {
t.Errorf("TestCreateCompletionStreamUnexpectedTerminatedError did not return io.EOF")
}
}
func TestCreateCompletionStreamBrokenJSONError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
// Send test responses
dataBytes := []byte{}
dataBytes = append(dataBytes, []byte("event: message\n")...)
//nolint:lll
data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
// Send broken json
dataBytes = append(dataBytes, []byte("event: message\n")...)
data = `{"id":"2","object":"completion","created":1598069255,"model":`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
dataBytes = append(dataBytes, []byte("event: done\n")...)
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
}))
defer server.Close()
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}
client := NewClientWithConfig(config)
ctx := context.Background()
request := CompletionRequest{
Prompt: "Ex falso quodlibet",
Model: "text-davinci-002",
MaxTokens: 10,
Stream: true,
}
stream, err := client.CreateCompletionStream(ctx, request)
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
_, _ = stream.Recv()
_, streamErr := stream.Recv()
var syntaxError *json.SyntaxError
if !errors.As(streamErr, &syntaxError) {
t.Errorf("TestCreateCompletionStreamBrokenJSONError did not return json.SyntaxError")
}
}
// Helper funcs.
func compareResponses(r1, r2 CompletionResponse) bool {
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {