feat: add RecvRaw (#896)
This commit is contained in:
@@ -32,17 +32,28 @@ type streamReader[T streamable] struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (stream *streamReader[T]) Recv() (response T, err error) {
|
func (stream *streamReader[T]) Recv() (response T, err error) {
|
||||||
if stream.isFinished {
|
rawLine, err := stream.RecvRaw()
|
||||||
err = io.EOF
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err = stream.processLines()
|
err = stream.unmarshaler.Unmarshal(rawLine, &response)
|
||||||
return
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stream *streamReader[T]) RecvRaw() ([]byte, error) {
|
||||||
|
if stream.isFinished {
|
||||||
|
return nil, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
return stream.processLines()
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gocognit
|
//nolint:gocognit
|
||||||
func (stream *streamReader[T]) processLines() (T, error) {
|
func (stream *streamReader[T]) processLines() ([]byte, error) {
|
||||||
var (
|
var (
|
||||||
emptyMessagesCount uint
|
emptyMessagesCount uint
|
||||||
hasErrorPrefix bool
|
hasErrorPrefix bool
|
||||||
@@ -53,9 +64,9 @@ func (stream *streamReader[T]) processLines() (T, error) {
|
|||||||
if readErr != nil || hasErrorPrefix {
|
if readErr != nil || hasErrorPrefix {
|
||||||
respErr := stream.unmarshalError()
|
respErr := stream.unmarshalError()
|
||||||
if respErr != nil {
|
if respErr != nil {
|
||||||
return *new(T), fmt.Errorf("error, %w", respErr.Error)
|
return nil, fmt.Errorf("error, %w", respErr.Error)
|
||||||
}
|
}
|
||||||
return *new(T), readErr
|
return nil, readErr
|
||||||
}
|
}
|
||||||
|
|
||||||
noSpaceLine := bytes.TrimSpace(rawLine)
|
noSpaceLine := bytes.TrimSpace(rawLine)
|
||||||
@@ -68,11 +79,11 @@ func (stream *streamReader[T]) processLines() (T, error) {
|
|||||||
}
|
}
|
||||||
writeErr := stream.errAccumulator.Write(noSpaceLine)
|
writeErr := stream.errAccumulator.Write(noSpaceLine)
|
||||||
if writeErr != nil {
|
if writeErr != nil {
|
||||||
return *new(T), writeErr
|
return nil, writeErr
|
||||||
}
|
}
|
||||||
emptyMessagesCount++
|
emptyMessagesCount++
|
||||||
if emptyMessagesCount > stream.emptyMessagesLimit {
|
if emptyMessagesCount > stream.emptyMessagesLimit {
|
||||||
return *new(T), ErrTooManyEmptyStreamMessages
|
return nil, ErrTooManyEmptyStreamMessages
|
||||||
}
|
}
|
||||||
|
|
||||||
continue
|
continue
|
||||||
@@ -81,16 +92,10 @@ func (stream *streamReader[T]) processLines() (T, error) {
|
|||||||
noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
|
noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
|
||||||
if string(noPrefixLine) == "[DONE]" {
|
if string(noPrefixLine) == "[DONE]" {
|
||||||
stream.isFinished = true
|
stream.isFinished = true
|
||||||
return *new(T), io.EOF
|
return nil, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
var response T
|
return noPrefixLine, nil
|
||||||
unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response)
|
|
||||||
if unmarshalErr != nil {
|
|
||||||
return *new(T), unmarshalErr
|
|
||||||
}
|
|
||||||
|
|
||||||
return response, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -63,3 +63,16 @@ func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) {
|
|||||||
_, err := stream.Recv()
|
_, err := stream.Recv()
|
||||||
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
|
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStreamReaderRecvRaw(t *testing.T) {
|
||||||
|
stream := &streamReader[ChatCompletionStreamResponse]{
|
||||||
|
reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"key\": \"value\"}\n"))),
|
||||||
|
}
|
||||||
|
rawLine, err := stream.RecvRaw()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Did not return raw line: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(rawLine, []byte("{\"key\": \"value\"}")) {
|
||||||
|
t.Fatalf("Did not return raw line: %v", string(rawLine))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user