@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -16,12 +17,18 @@ var (
|
||||
|
||||
type CompletionStream struct {
|
||||
emptyMessagesLimit uint
|
||||
isFinished bool
|
||||
|
||||
reader *bufio.Reader
|
||||
response *http.Response
|
||||
}
|
||||
|
||||
func (stream *CompletionStream) Recv() (response CompletionResponse, err error) {
|
||||
if stream.isFinished {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
|
||||
var emptyMessagesCount uint
|
||||
|
||||
waitForData:
|
||||
@@ -44,6 +51,8 @@ waitForData:
|
||||
|
||||
line = bytes.TrimPrefix(line, headerData)
|
||||
if string(line) == "[DONE]" {
|
||||
stream.isFinished = true
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"github.com/sashabaranov/go-gpt3/internal/test"
|
||||
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -75,7 +77,6 @@ func TestCreateCompletionStream(t *testing.T) {
|
||||
Model: "text-davinci-002",
|
||||
Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}},
|
||||
},
|
||||
{},
|
||||
}
|
||||
|
||||
for ix, expectedResponse := range expectedResponses {
|
||||
@@ -87,6 +88,11 @@ func TestCreateCompletionStream(t *testing.T) {
|
||||
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
|
||||
}
|
||||
}
|
||||
|
||||
_, streamErr := stream.Recv()
|
||||
if !errors.Is(streamErr, io.EOF) {
|
||||
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
|
||||
}
|
||||
}
|
||||
|
||||
// A "tokenRoundTripper" is a struct that implements the RoundTripper
|
||||
|
||||
Reference in New Issue
Block a user