refactoring tests with mock servers (#30) (#356)

This commit is contained in:
渡邉祐一 / Yuichi Watanabe
2023-06-12 22:40:26 +09:00
committed by GitHub
parent a243e7331f
commit b616090e69
20 changed files with 732 additions and 1061 deletions

View File

@@ -6,11 +6,9 @@ import (
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
)
@@ -32,7 +30,9 @@ func TestCompletionsStreamWrongModel(t *testing.T) {
}
func TestCreateCompletionStream(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
// Send test responses
@@ -52,28 +52,14 @@ func TestCreateCompletionStream(t *testing.T) {
_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
}))
defer server.Close()
})
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}
client := NewClientWithConfig(config)
ctx := context.Background()
request := CompletionRequest{
stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
Prompt: "Ex falso quodlibet",
Model: "text-davinci-002",
MaxTokens: 10,
Stream: true,
}
stream, err := client.CreateCompletionStream(ctx, request)
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
@@ -116,7 +102,9 @@ func TestCreateCompletionStream(t *testing.T) {
}
func TestCreateCompletionStreamError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
// Send test responses
@@ -137,28 +125,14 @@ func TestCreateCompletionStreamError(t *testing.T) {
_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
}))
defer server.Close()
})
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}
client := NewClientWithConfig(config)
ctx := context.Background()
request := CompletionRequest{
stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
MaxTokens: 5,
Model: GPT3TextDavinci003,
Prompt: "Hello!",
Stream: true,
}
stream, err := client.CreateCompletionStream(ctx, request)
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
@@ -173,7 +147,8 @@ func TestCreateCompletionStreamError(t *testing.T) {
}
func TestCreateCompletionStreamRateLimitError(t *testing.T) {
server := test.NewTestServer()
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(429)
@@ -188,30 +163,14 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) {
_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
})
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}
client := NewClientWithConfig(config)
ctx := context.Background()
request := CompletionRequest{
var apiErr *APIError
_, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
MaxTokens: 5,
Model: GPT3Ada,
Prompt: "Hello!",
Stream: true,
}
var apiErr *APIError
_, err := client.CreateCompletionStream(ctx, request)
})
if !errors.As(err, &apiErr) {
t.Errorf("TestCreateCompletionStreamRateLimitError did not return APIError")
}
@@ -219,7 +178,9 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) {
}
func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
// Send test responses
@@ -244,28 +205,14 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) {
_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
}))
defer server.Close()
})
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}
client := NewClientWithConfig(config)
ctx := context.Background()
request := CompletionRequest{
stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
Prompt: "Ex falso quodlibet",
Model: "text-davinci-002",
MaxTokens: 10,
Stream: true,
}
stream, err := client.CreateCompletionStream(ctx, request)
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
@@ -277,7 +224,9 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) {
}
func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
// Send test responses
@@ -291,28 +240,14 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) {
_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
}))
defer server.Close()
})
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}
client := NewClientWithConfig(config)
ctx := context.Background()
request := CompletionRequest{
stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
Prompt: "Ex falso quodlibet",
Model: "text-davinci-002",
MaxTokens: 10,
Stream: true,
}
stream, err := client.CreateCompletionStream(ctx, request)
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
@@ -324,7 +259,9 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) {
}
func TestCreateCompletionStreamBrokenJSONError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
// Send test responses
@@ -344,28 +281,14 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) {
_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
}))
defer server.Close()
})
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &test.TokenRoundTripper{
Token: test.GetTestToken(),
Fallback: http.DefaultTransport,
}
client := NewClientWithConfig(config)
ctx := context.Background()
request := CompletionRequest{
stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
Prompt: "Ex falso quodlibet",
Model: "text-davinci-002",
MaxTokens: 10,
Stream: true,
}
stream, err := client.CreateCompletionStream(ctx, request)
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()