handle stream completion (#86)

* handle stream completion

* fix tests
This commit is contained in:
sashabaranov
2023-02-22 12:33:25 +04:00
committed by GitHub
parent 1eb5d625f8
commit ae05ed976f
2 changed files with 16 additions and 1 deletions

View File

@@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
)
@@ -16,12 +17,18 @@ var (
type CompletionStream struct {
emptyMessagesLimit uint
isFinished bool
reader *bufio.Reader
response *http.Response
}
func (stream *CompletionStream) Recv() (response CompletionResponse, err error) {
if stream.isFinished {
err = io.EOF
return
}
var emptyMessagesCount uint
waitForData:
@@ -44,6 +51,8 @@ waitForData:
line = bytes.TrimPrefix(line, headerData)
if string(line) == "[DONE]" {
stream.isFinished = true
err = io.EOF
return
}

View File

@@ -5,6 +5,8 @@ import (
"github.com/sashabaranov/go-gpt3/internal/test"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
@@ -75,7 +77,6 @@ func TestCreateCompletionStream(t *testing.T) {
Model: "text-davinci-002",
Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}},
},
{},
}
for ix, expectedResponse := range expectedResponses {
@@ -87,6 +88,11 @@ func TestCreateCompletionStream(t *testing.T) {
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
}
}
_, streamErr := stream.Recv()
if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
}
}
// A "tokenRoundTripper" is a struct that implements the RoundTripper