move error_accumulator into internal pkg (#304) (#335)

* move error_accumulator into internal pkg (#304)

* move error_accumulator into internal pkg (#304)

* add a test for ErrTooManyEmptyStreamMessages in stream_reader (#304)
This commit is contained in:
渡邉祐一 / Yuichi Watanabe
2023-06-05 23:35:46 +09:00
committed by GitHub
parent fa694c61c2
commit 1394329e44
12 changed files with 249 additions and 201 deletions

View File

@@ -1,7 +1,7 @@
package openai_test
package openai //nolint:testpackage // testing private field
import (
. "github.com/sashabaranov/go-openai"
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
@@ -63,9 +63,9 @@ func TestCreateChatCompletionStream(t *testing.T) {
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}
client := NewClientWithConfig(config)
@@ -170,9 +170,9 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}
client := NewClientWithConfig(config)
@@ -227,9 +227,9 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}
client := NewClientWithConfig(config)
@@ -255,6 +255,33 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
t.Logf("%+v\n", apiErr)
}
func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) {
var err error
server := test.NewTestServer()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "error", 200)
})
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
checks.NoError(t, err)
stream.errAccumulator = &utils.DefaultErrorAccumulator{
Buffer: &test.FailingErrorBuffer{},
}
_, err = stream.Recv()
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when Write failed", err.Error())
}
// Helper funcs.
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {