@@ -7,6 +7,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,12 +17,18 @@ var (
|
|||||||
|
|
||||||
type CompletionStream struct {
|
type CompletionStream struct {
|
||||||
emptyMessagesLimit uint
|
emptyMessagesLimit uint
|
||||||
|
isFinished bool
|
||||||
|
|
||||||
reader *bufio.Reader
|
reader *bufio.Reader
|
||||||
response *http.Response
|
response *http.Response
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stream *CompletionStream) Recv() (response CompletionResponse, err error) {
|
func (stream *CompletionStream) Recv() (response CompletionResponse, err error) {
|
||||||
|
if stream.isFinished {
|
||||||
|
err = io.EOF
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var emptyMessagesCount uint
|
var emptyMessagesCount uint
|
||||||
|
|
||||||
waitForData:
|
waitForData:
|
||||||
@@ -44,6 +51,8 @@ waitForData:
|
|||||||
|
|
||||||
line = bytes.TrimPrefix(line, headerData)
|
line = bytes.TrimPrefix(line, headerData)
|
||||||
if string(line) == "[DONE]" {
|
if string(line) == "[DONE]" {
|
||||||
|
stream.isFinished = true
|
||||||
|
err = io.EOF
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"github.com/sashabaranov/go-gpt3/internal/test"
|
"github.com/sashabaranov/go-gpt3/internal/test"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -75,7 +77,6 @@ func TestCreateCompletionStream(t *testing.T) {
|
|||||||
Model: "text-davinci-002",
|
Model: "text-davinci-002",
|
||||||
Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}},
|
Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}},
|
||||||
},
|
},
|
||||||
{},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for ix, expectedResponse := range expectedResponses {
|
for ix, expectedResponse := range expectedResponses {
|
||||||
@@ -87,6 +88,11 @@ func TestCreateCompletionStream(t *testing.T) {
|
|||||||
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
|
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, streamErr := stream.Recv()
|
||||||
|
if !errors.Is(streamErr, io.EOF) {
|
||||||
|
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// A "tokenRoundTripper" is a struct that implements the RoundTripper
|
// A "tokenRoundTripper" is a struct that implements the RoundTripper
|
||||||
|
|||||||
Reference in New Issue
Block a user