* move error_accumulator into internal pkg (#304) * move error_accumulator into internal pkg (#304) * add a test for ErrTooManyEmptyStreamMessages in stream_reader (#304)
This commit is contained in:
committed by
GitHub
parent
fa694c61c2
commit
1394329e44
@@ -66,7 +66,7 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
emptyMessagesLimit: c.config.EmptyMessagesLimit,
|
emptyMessagesLimit: c.config.EmptyMessagesLimit,
|
||||||
reader: bufio.NewReader(resp.Body),
|
reader: bufio.NewReader(resp.Body),
|
||||||
response: resp,
|
response: resp,
|
||||||
errAccumulator: newErrorAccumulator(),
|
errAccumulator: utils.NewErrorAccumulator(),
|
||||||
unmarshaler: &utils.JSONUnmarshaler{},
|
unmarshaler: &utils.JSONUnmarshaler{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package openai_test
|
package openai //nolint:testpackage // testing private field
|
||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
utils "github.com/sashabaranov/go-openai/internal"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
@@ -63,9 +63,9 @@ func TestCreateChatCompletionStream(t *testing.T) {
|
|||||||
// Client portion of the test
|
// Client portion of the test
|
||||||
config := DefaultConfig(test.GetTestToken())
|
config := DefaultConfig(test.GetTestToken())
|
||||||
config.BaseURL = server.URL + "/v1"
|
config.BaseURL = server.URL + "/v1"
|
||||||
config.HTTPClient.Transport = &tokenRoundTripper{
|
config.HTTPClient.Transport = &test.TokenRoundTripper{
|
||||||
test.GetTestToken(),
|
Token: test.GetTestToken(),
|
||||||
http.DefaultTransport,
|
Fallback: http.DefaultTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
client := NewClientWithConfig(config)
|
client := NewClientWithConfig(config)
|
||||||
@@ -170,9 +170,9 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
|
|||||||
// Client portion of the test
|
// Client portion of the test
|
||||||
config := DefaultConfig(test.GetTestToken())
|
config := DefaultConfig(test.GetTestToken())
|
||||||
config.BaseURL = server.URL + "/v1"
|
config.BaseURL = server.URL + "/v1"
|
||||||
config.HTTPClient.Transport = &tokenRoundTripper{
|
config.HTTPClient.Transport = &test.TokenRoundTripper{
|
||||||
test.GetTestToken(),
|
Token: test.GetTestToken(),
|
||||||
http.DefaultTransport,
|
Fallback: http.DefaultTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
client := NewClientWithConfig(config)
|
client := NewClientWithConfig(config)
|
||||||
@@ -227,9 +227,9 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
|
|||||||
// Client portion of the test
|
// Client portion of the test
|
||||||
config := DefaultConfig(test.GetTestToken())
|
config := DefaultConfig(test.GetTestToken())
|
||||||
config.BaseURL = ts.URL + "/v1"
|
config.BaseURL = ts.URL + "/v1"
|
||||||
config.HTTPClient.Transport = &tokenRoundTripper{
|
config.HTTPClient.Transport = &test.TokenRoundTripper{
|
||||||
test.GetTestToken(),
|
Token: test.GetTestToken(),
|
||||||
http.DefaultTransport,
|
Fallback: http.DefaultTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
client := NewClientWithConfig(config)
|
client := NewClientWithConfig(config)
|
||||||
@@ -255,6 +255,33 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
|
|||||||
t.Logf("%+v\n", apiErr)
|
t.Logf("%+v\n", apiErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
server := test.NewTestServer()
|
||||||
|
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
http.Error(w, "error", 200)
|
||||||
|
})
|
||||||
|
ts := server.OpenAITestServer()
|
||||||
|
ts.Start()
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
|
||||||
|
checks.NoError(t, err)
|
||||||
|
|
||||||
|
stream.errAccumulator = &utils.DefaultErrorAccumulator{
|
||||||
|
Buffer: &test.FailingErrorBuffer{},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = stream.Recv()
|
||||||
|
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when Write failed", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
// Helper funcs.
|
// Helper funcs.
|
||||||
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
|
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
|
||||||
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
|
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
|
||||||
|
|||||||
@@ -1,53 +0,0 @@
|
|||||||
package openai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
utils "github.com/sashabaranov/go-openai/internal"
|
|
||||||
)
|
|
||||||
|
|
||||||
type errorAccumulator interface {
|
|
||||||
write(p []byte) error
|
|
||||||
unmarshalError() *ErrorResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
type errorBuffer interface {
|
|
||||||
io.Writer
|
|
||||||
Len() int
|
|
||||||
Bytes() []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type defaultErrorAccumulator struct {
|
|
||||||
buffer errorBuffer
|
|
||||||
unmarshaler utils.Unmarshaler
|
|
||||||
}
|
|
||||||
|
|
||||||
func newErrorAccumulator() errorAccumulator {
|
|
||||||
return &defaultErrorAccumulator{
|
|
||||||
buffer: &bytes.Buffer{},
|
|
||||||
unmarshaler: &utils.JSONUnmarshaler{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *defaultErrorAccumulator) write(p []byte) error {
|
|
||||||
_, err := e.buffer.Write(p)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error accumulator write error, %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *defaultErrorAccumulator) unmarshalError() (errResp *ErrorResponse) {
|
|
||||||
if e.buffer.Len() == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err := e.unmarshaler.Unmarshal(e.buffer.Bytes(), &errResp)
|
|
||||||
if err != nil {
|
|
||||||
errResp = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
@@ -1,100 +0,0 @@
|
|||||||
package openai //nolint:testpackage // testing private field
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"net/http"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
utils "github.com/sashabaranov/go-openai/internal"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
errTestUnmarshalerFailed = errors.New("test unmarshaler failed")
|
|
||||||
errTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed")
|
|
||||||
)
|
|
||||||
|
|
||||||
type (
|
|
||||||
failingUnMarshaller struct{}
|
|
||||||
failingErrorBuffer struct{}
|
|
||||||
)
|
|
||||||
|
|
||||||
func (b *failingErrorBuffer) Write(_ []byte) (n int, err error) {
|
|
||||||
return 0, errTestErrorAccumulatorWriteFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *failingErrorBuffer) Len() int {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *failingErrorBuffer) Bytes() []byte {
|
|
||||||
return []byte{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error {
|
|
||||||
return errTestUnmarshalerFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) {
|
|
||||||
accumulator := &defaultErrorAccumulator{
|
|
||||||
buffer: &bytes.Buffer{},
|
|
||||||
unmarshaler: &failingUnMarshaller{},
|
|
||||||
}
|
|
||||||
|
|
||||||
respErr := accumulator.unmarshalError()
|
|
||||||
if respErr != nil {
|
|
||||||
t.Fatalf("Did not return nil with empty buffer: %v", respErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := accumulator.write([]byte("{"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("%+v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
respErr = accumulator.unmarshalError()
|
|
||||||
if respErr != nil {
|
|
||||||
t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestErrorByteWriteErrors(t *testing.T) {
|
|
||||||
accumulator := &defaultErrorAccumulator{
|
|
||||||
buffer: &failingErrorBuffer{},
|
|
||||||
unmarshaler: &utils.JSONUnmarshaler{},
|
|
||||||
}
|
|
||||||
err := accumulator.write([]byte("{"))
|
|
||||||
if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
|
|
||||||
t.Fatalf("Did not return error when write failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestErrorAccumulatorWriteErrors(t *testing.T) {
|
|
||||||
var err error
|
|
||||||
server := test.NewTestServer()
|
|
||||||
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
http.Error(w, "error", 200)
|
|
||||||
})
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
|
|
||||||
checks.NoError(t, err)
|
|
||||||
|
|
||||||
stream.errAccumulator = &defaultErrorAccumulator{
|
|
||||||
buffer: &failingErrorBuffer{},
|
|
||||||
unmarshaler: &utils.JSONUnmarshaler{},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = stream.Recv()
|
|
||||||
checks.ErrorIs(t, err, errTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
|
|
||||||
}
|
|
||||||
44
internal/error_accumulator.go
Normal file
44
internal/error_accumulator.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ErrorAccumulator interface {
|
||||||
|
Write(p []byte) error
|
||||||
|
Bytes() []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type errorBuffer interface {
|
||||||
|
io.Writer
|
||||||
|
Len() int
|
||||||
|
Bytes() []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type DefaultErrorAccumulator struct {
|
||||||
|
Buffer errorBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewErrorAccumulator() ErrorAccumulator {
|
||||||
|
return &DefaultErrorAccumulator{
|
||||||
|
Buffer: &bytes.Buffer{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *DefaultErrorAccumulator) Write(p []byte) error {
|
||||||
|
_, err := e.Buffer.Write(p)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error accumulator write error, %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) {
|
||||||
|
if e.Buffer.Len() == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errBytes = e.Buffer.Bytes()
|
||||||
|
return
|
||||||
|
}
|
||||||
41
internal/error_accumulator_test.go
Normal file
41
internal/error_accumulator_test.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package openai_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
utils "github.com/sashabaranov/go-openai/internal"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestErrorAccumulatorBytes(t *testing.T) {
|
||||||
|
accumulator := &utils.DefaultErrorAccumulator{
|
||||||
|
Buffer: &bytes.Buffer{},
|
||||||
|
}
|
||||||
|
|
||||||
|
errBytes := accumulator.Bytes()
|
||||||
|
if len(errBytes) != 0 {
|
||||||
|
t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
err := accumulator.Write([]byte("{}"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
errBytes = accumulator.Bytes()
|
||||||
|
if len(errBytes) == 0 {
|
||||||
|
t.Fatalf("Did not return error bytes when has error: %s", string(errBytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestErrorByteWriteErrors(t *testing.T) {
|
||||||
|
accumulator := &utils.DefaultErrorAccumulator{
|
||||||
|
Buffer: &test.FailingErrorBuffer{},
|
||||||
|
}
|
||||||
|
err := accumulator.Write([]byte("{"))
|
||||||
|
if !errors.Is(err, test.ErrTestErrorAccumulatorWriteFailed) {
|
||||||
|
t.Fatalf("Did not return error when write failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
21
internal/test/failer.go
Normal file
21
internal/test/failer.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package test
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed")
|
||||||
|
)
|
||||||
|
|
||||||
|
type FailingErrorBuffer struct{}
|
||||||
|
|
||||||
|
func (b *FailingErrorBuffer) Write(_ []byte) (n int, err error) {
|
||||||
|
return 0, ErrTestErrorAccumulatorWriteFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *FailingErrorBuffer) Len() int {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *FailingErrorBuffer) Bytes() []byte {
|
||||||
|
return []byte{}
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package test
|
|||||||
import (
|
import (
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@@ -27,3 +28,26 @@ func CreateTestDirectory(t *testing.T) (path string, cleanup func()) {
|
|||||||
|
|
||||||
return path, func() { os.RemoveAll(path) }
|
return path, func() { os.RemoveAll(path) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TokenRoundTripper is a struct that implements the RoundTripper
|
||||||
|
// interface, specifically to handle the authentication token by adding a token
|
||||||
|
// to the request header. We need this because the API requires that each
|
||||||
|
// request include a valid API token in the headers for authentication and
|
||||||
|
// authorization.
|
||||||
|
type TokenRoundTripper struct {
|
||||||
|
Token string
|
||||||
|
Fallback http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundTrip takes an *http.Request as input and returns an
|
||||||
|
// *http.Response and an error.
|
||||||
|
//
|
||||||
|
// It is expected to use the provided request to create a connection to an HTTP
|
||||||
|
// server and return the response, or an error if one occurred. The returned
|
||||||
|
// Response should have its Body closed. If the RoundTrip method returns an
|
||||||
|
// error, the Client's Get, Head, Post, and PostForm methods return the same
|
||||||
|
// error.
|
||||||
|
func (t *TokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+t.Token)
|
||||||
|
return t.Fallback.RoundTrip(req)
|
||||||
|
}
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func (c *Client) CreateCompletionStream(
|
|||||||
emptyMessagesLimit: c.config.EmptyMessagesLimit,
|
emptyMessagesLimit: c.config.EmptyMessagesLimit,
|
||||||
reader: bufio.NewReader(resp.Body),
|
reader: bufio.NewReader(resp.Body),
|
||||||
response: resp,
|
response: resp,
|
||||||
errAccumulator: newErrorAccumulator(),
|
errAccumulator: utils.NewErrorAccumulator(),
|
||||||
unmarshaler: &utils.JSONUnmarshaler{},
|
unmarshaler: &utils.JSONUnmarshaler{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ type streamReader[T streamable] struct {
|
|||||||
|
|
||||||
reader *bufio.Reader
|
reader *bufio.Reader
|
||||||
response *http.Response
|
response *http.Response
|
||||||
errAccumulator errorAccumulator
|
errAccumulator utils.ErrorAccumulator
|
||||||
unmarshaler utils.Unmarshaler
|
unmarshaler utils.Unmarshaler
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ func (stream *streamReader[T]) Recv() (response T, err error) {
|
|||||||
waitForData:
|
waitForData:
|
||||||
line, err := stream.reader.ReadBytes('\n')
|
line, err := stream.reader.ReadBytes('\n')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respErr := stream.errAccumulator.unmarshalError()
|
respErr := stream.unmarshalError()
|
||||||
if respErr != nil {
|
if respErr != nil {
|
||||||
err = fmt.Errorf("error, %w", respErr.Error)
|
err = fmt.Errorf("error, %w", respErr.Error)
|
||||||
}
|
}
|
||||||
@@ -45,7 +45,7 @@ waitForData:
|
|||||||
var headerData = []byte("data: ")
|
var headerData = []byte("data: ")
|
||||||
line = bytes.TrimSpace(line)
|
line = bytes.TrimSpace(line)
|
||||||
if !bytes.HasPrefix(line, headerData) {
|
if !bytes.HasPrefix(line, headerData) {
|
||||||
if writeErr := stream.errAccumulator.write(line); writeErr != nil {
|
if writeErr := stream.errAccumulator.Write(line); writeErr != nil {
|
||||||
err = writeErr
|
err = writeErr
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -69,6 +69,20 @@ waitForData:
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) {
|
||||||
|
errBytes := stream.errAccumulator.Bytes()
|
||||||
|
if len(errBytes) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := stream.unmarshaler.Unmarshal(errBytes, &errResp)
|
||||||
|
if err != nil {
|
||||||
|
errResp = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (stream *streamReader[T]) Close() {
|
func (stream *streamReader[T]) Close() {
|
||||||
stream.response.Body.Close()
|
stream.response.Body.Close()
|
||||||
}
|
}
|
||||||
|
|||||||
53
stream_reader_test.go
Normal file
53
stream_reader_test.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
package openai //nolint:testpackage // testing private field
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
utils "github.com/sashabaranov/go-openai/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errTestUnmarshalerFailed = errors.New("test unmarshaler failed")
|
||||||
|
|
||||||
|
type failingUnMarshaller struct{}
|
||||||
|
|
||||||
|
func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error {
|
||||||
|
return errTestUnmarshalerFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamReaderReturnsUnmarshalerErrors(t *testing.T) {
|
||||||
|
stream := &streamReader[ChatCompletionStreamResponse]{
|
||||||
|
errAccumulator: utils.NewErrorAccumulator(),
|
||||||
|
unmarshaler: &failingUnMarshaller{},
|
||||||
|
}
|
||||||
|
|
||||||
|
respErr := stream.unmarshalError()
|
||||||
|
if respErr != nil {
|
||||||
|
t.Fatalf("Did not return nil with empty buffer: %v", respErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := stream.errAccumulator.Write([]byte("{"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
respErr = stream.unmarshalError()
|
||||||
|
if respErr != nil {
|
||||||
|
t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) {
|
||||||
|
stream := &streamReader[ChatCompletionStreamResponse]{
|
||||||
|
emptyMessagesLimit: 3,
|
||||||
|
reader: bufio.NewReader(bytes.NewReader([]byte("\n\n\n\n"))),
|
||||||
|
errAccumulator: utils.NewErrorAccumulator(),
|
||||||
|
unmarshaler: &utils.JSONUnmarshaler{},
|
||||||
|
}
|
||||||
|
_, err := stream.Recv()
|
||||||
|
if !errors.Is(err, ErrTooManyEmptyStreamMessages) {
|
||||||
|
t.Fatalf("Did not return error when recv failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -57,9 +57,9 @@ func TestCreateCompletionStream(t *testing.T) {
|
|||||||
// Client portion of the test
|
// Client portion of the test
|
||||||
config := DefaultConfig(test.GetTestToken())
|
config := DefaultConfig(test.GetTestToken())
|
||||||
config.BaseURL = server.URL + "/v1"
|
config.BaseURL = server.URL + "/v1"
|
||||||
config.HTTPClient.Transport = &tokenRoundTripper{
|
config.HTTPClient.Transport = &test.TokenRoundTripper{
|
||||||
test.GetTestToken(),
|
Token: test.GetTestToken(),
|
||||||
http.DefaultTransport,
|
Fallback: http.DefaultTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
client := NewClientWithConfig(config)
|
client := NewClientWithConfig(config)
|
||||||
@@ -142,9 +142,9 @@ func TestCreateCompletionStreamError(t *testing.T) {
|
|||||||
// Client portion of the test
|
// Client portion of the test
|
||||||
config := DefaultConfig(test.GetTestToken())
|
config := DefaultConfig(test.GetTestToken())
|
||||||
config.BaseURL = server.URL + "/v1"
|
config.BaseURL = server.URL + "/v1"
|
||||||
config.HTTPClient.Transport = &tokenRoundTripper{
|
config.HTTPClient.Transport = &test.TokenRoundTripper{
|
||||||
test.GetTestToken(),
|
Token: test.GetTestToken(),
|
||||||
http.DefaultTransport,
|
Fallback: http.DefaultTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
client := NewClientWithConfig(config)
|
client := NewClientWithConfig(config)
|
||||||
@@ -194,9 +194,9 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) {
|
|||||||
// Client portion of the test
|
// Client portion of the test
|
||||||
config := DefaultConfig(test.GetTestToken())
|
config := DefaultConfig(test.GetTestToken())
|
||||||
config.BaseURL = ts.URL + "/v1"
|
config.BaseURL = ts.URL + "/v1"
|
||||||
config.HTTPClient.Transport = &tokenRoundTripper{
|
config.HTTPClient.Transport = &test.TokenRoundTripper{
|
||||||
test.GetTestToken(),
|
Token: test.GetTestToken(),
|
||||||
http.DefaultTransport,
|
Fallback: http.DefaultTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
client := NewClientWithConfig(config)
|
client := NewClientWithConfig(config)
|
||||||
@@ -217,29 +217,6 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) {
|
|||||||
t.Logf("%+v\n", apiErr)
|
t.Logf("%+v\n", apiErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// A "tokenRoundTripper" is a struct that implements the RoundTripper
|
|
||||||
// interface, specifically to handle the authentication token by adding a token
|
|
||||||
// to the request header. We need this because the API requires that each
|
|
||||||
// request include a valid API token in the headers for authentication and
|
|
||||||
// authorization.
|
|
||||||
type tokenRoundTripper struct {
|
|
||||||
token string
|
|
||||||
fallback http.RoundTripper
|
|
||||||
}
|
|
||||||
|
|
||||||
// RoundTrip takes an *http.Request as input and returns an
|
|
||||||
// *http.Response and an error.
|
|
||||||
//
|
|
||||||
// It is expected to use the provided request to create a connection to an HTTP
|
|
||||||
// server and return the response, or an error if one occurred. The returned
|
|
||||||
// Response should have its Body closed. If the RoundTrip method returns an
|
|
||||||
// error, the Client's Get, Head, Post, and PostForm methods return the same
|
|
||||||
// error.
|
|
||||||
func (t *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
req.Header.Set("Authorization", "Bearer "+t.token)
|
|
||||||
return t.fallback.RoundTrip(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper funcs.
|
// Helper funcs.
|
||||||
func compareResponses(r1, r2 CompletionResponse) bool {
|
func compareResponses(r1, r2 CompletionResponse) bool {
|
||||||
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
|
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
|
||||||
|
|||||||
Reference in New Issue
Block a user