diff --git a/stream.go b/stream.go index 51c7f84..b3c9eb1 100644 --- a/stream.go +++ b/stream.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" ) @@ -16,12 +17,18 @@ var ( type CompletionStream struct { emptyMessagesLimit uint + isFinished bool reader *bufio.Reader response *http.Response } func (stream *CompletionStream) Recv() (response CompletionResponse, err error) { + if stream.isFinished { + err = io.EOF + return + } + var emptyMessagesCount uint waitForData: @@ -44,6 +51,8 @@ waitForData: line = bytes.TrimPrefix(line, headerData) if string(line) == "[DONE]" { + stream.isFinished = true + err = io.EOF return } diff --git a/stream_test.go b/stream_test.go index cdf574a..8d2bfa2 100644 --- a/stream_test.go +++ b/stream_test.go @@ -5,6 +5,8 @@ import ( "github.com/sashabaranov/go-gpt3/internal/test" "context" + "errors" + "io" "net/http" "net/http/httptest" "testing" @@ -75,7 +77,6 @@ func TestCreateCompletionStream(t *testing.T) { Model: "text-davinci-002", Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}}, }, - {}, } for ix, expectedResponse := range expectedResponses { @@ -87,6 +88,11 @@ func TestCreateCompletionStream(t *testing.T) { t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) } } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } } // A "tokenRoundTripper" is a struct that implements the RoundTripper