committed by
GitHub
parent
a243e7331f
commit
b616090e69
16
api_test.go
16
api_test.go
@@ -6,7 +6,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -226,18 +225,13 @@ func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRequestError(t *testing.T) {
|
func TestRequestError(t *testing.T) {
|
||||||
var err error
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusTeapot)
|
w.WriteHeader(http.StatusTeapot)
|
||||||
}))
|
})
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig("dummy")
|
_, err := client.ListEngines(context.Background())
|
||||||
config.BaseURL = ts.URL
|
|
||||||
c := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
_, err = c.ListEngines(ctx)
|
|
||||||
checks.HasError(t, err, "ListEngines did not fail")
|
checks.HasError(t, err, "ListEngines did not fail")
|
||||||
|
|
||||||
var reqErr *RequestError
|
var reqErr *RequestError
|
||||||
|
|||||||
162
audio_api_test.go
Normal file
162
audio_api_test.go
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
package openai_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"mime"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
|
||||||
|
func TestAudio(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
|
||||||
|
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
|
||||||
|
|
||||||
|
testcases := []struct {
|
||||||
|
name string
|
||||||
|
createFn func(context.Context, AudioRequest) (AudioResponse, error)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"transcribe",
|
||||||
|
client.CreateTranscription,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"translate",
|
||||||
|
client.CreateTranslation,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
dir, cleanup := test.CreateTestDirectory(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
for _, tc := range testcases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
path := filepath.Join(dir, "fake.mp3")
|
||||||
|
test.CreateTestFile(t, path)
|
||||||
|
|
||||||
|
req := AudioRequest{
|
||||||
|
FilePath: path,
|
||||||
|
Model: "whisper-3",
|
||||||
|
}
|
||||||
|
_, err := tc.createFn(ctx, req)
|
||||||
|
checks.NoError(t, err, "audio API error")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run(tc.name+" (with reader)", func(t *testing.T) {
|
||||||
|
req := AudioRequest{
|
||||||
|
FilePath: "fake.webm",
|
||||||
|
Reader: bytes.NewBuffer([]byte(`some webm binary data`)),
|
||||||
|
Model: "whisper-3",
|
||||||
|
}
|
||||||
|
_, err := tc.createFn(ctx, req)
|
||||||
|
checks.NoError(t, err, "audio API error")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAudioWithOptionalArgs(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
|
||||||
|
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
|
||||||
|
|
||||||
|
testcases := []struct {
|
||||||
|
name string
|
||||||
|
createFn func(context.Context, AudioRequest) (AudioResponse, error)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"transcribe",
|
||||||
|
client.CreateTranscription,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"translate",
|
||||||
|
client.CreateTranslation,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
dir, cleanup := test.CreateTestDirectory(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
for _, tc := range testcases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
path := filepath.Join(dir, "fake.mp3")
|
||||||
|
test.CreateTestFile(t, path)
|
||||||
|
|
||||||
|
req := AudioRequest{
|
||||||
|
FilePath: path,
|
||||||
|
Model: "whisper-3",
|
||||||
|
Prompt: "用简体中文",
|
||||||
|
Temperature: 0.5,
|
||||||
|
Language: "zh",
|
||||||
|
Format: AudioResponseFormatSRT,
|
||||||
|
}
|
||||||
|
_, err := tc.createFn(ctx, req)
|
||||||
|
checks.NoError(t, err, "audio API error")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleAudioEndpoint Handles the completion endpoint by the test server.
|
||||||
|
func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// audio endpoints only accept POST requests
|
||||||
|
if r.Method != "POST" {
|
||||||
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
|
||||||
|
mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "failed to parse media type", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasPrefix(mediaType, "multipart") {
|
||||||
|
http.Error(w, "request is not multipart", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
boundary, ok := params["boundary"]
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "no boundary in params", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fileData := &bytes.Buffer{}
|
||||||
|
mr := multipart.NewReader(r.Body, boundary)
|
||||||
|
part, err := mr.NextPart()
|
||||||
|
if err != nil && errors.Is(err, io.EOF) {
|
||||||
|
http.Error(w, "error accessing file", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err = io.Copy(fileData, part); err != nil {
|
||||||
|
http.Error(w, "failed to copy file", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(fileData.Bytes()) == 0 {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
http.Error(w, "received empty file data", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil {
|
||||||
|
http.Error(w, "failed to write body", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
166
audio_test.go
166
audio_test.go
@@ -2,182 +2,16 @@ package openai //nolint:testpackage // testing private field
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime"
|
|
||||||
"mime/multipart"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
|
|
||||||
func TestAudio(t *testing.T) {
|
|
||||||
server := test.NewTestServer()
|
|
||||||
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
|
|
||||||
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
|
|
||||||
// create the test server
|
|
||||||
var err error
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
|
|
||||||
testcases := []struct {
|
|
||||||
name string
|
|
||||||
createFn func(context.Context, AudioRequest) (AudioResponse, error)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"transcribe",
|
|
||||||
client.CreateTranscription,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"translate",
|
|
||||||
client.CreateTranslation,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
dir, cleanup := test.CreateTestDirectory(t)
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
for _, tc := range testcases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
path := filepath.Join(dir, "fake.mp3")
|
|
||||||
test.CreateTestFile(t, path)
|
|
||||||
|
|
||||||
req := AudioRequest{
|
|
||||||
FilePath: path,
|
|
||||||
Model: "whisper-3",
|
|
||||||
}
|
|
||||||
_, err = tc.createFn(ctx, req)
|
|
||||||
checks.NoError(t, err, "audio API error")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run(tc.name+" (with reader)", func(t *testing.T) {
|
|
||||||
req := AudioRequest{
|
|
||||||
FilePath: "fake.webm",
|
|
||||||
Reader: bytes.NewBuffer([]byte(`some webm binary data`)),
|
|
||||||
Model: "whisper-3",
|
|
||||||
}
|
|
||||||
_, err = tc.createFn(ctx, req)
|
|
||||||
checks.NoError(t, err, "audio API error")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAudioWithOptionalArgs(t *testing.T) {
|
|
||||||
server := test.NewTestServer()
|
|
||||||
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
|
|
||||||
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
|
|
||||||
// create the test server
|
|
||||||
var err error
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
|
|
||||||
testcases := []struct {
|
|
||||||
name string
|
|
||||||
createFn func(context.Context, AudioRequest) (AudioResponse, error)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"transcribe",
|
|
||||||
client.CreateTranscription,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"translate",
|
|
||||||
client.CreateTranslation,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
dir, cleanup := test.CreateTestDirectory(t)
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
for _, tc := range testcases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
path := filepath.Join(dir, "fake.mp3")
|
|
||||||
test.CreateTestFile(t, path)
|
|
||||||
|
|
||||||
req := AudioRequest{
|
|
||||||
FilePath: path,
|
|
||||||
Model: "whisper-3",
|
|
||||||
Prompt: "用简体中文",
|
|
||||||
Temperature: 0.5,
|
|
||||||
Language: "zh",
|
|
||||||
Format: AudioResponseFormatSRT,
|
|
||||||
}
|
|
||||||
_, err = tc.createFn(ctx, req)
|
|
||||||
checks.NoError(t, err, "audio API error")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleAudioEndpoint Handles the completion endpoint by the test server.
|
|
||||||
func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
// audio endpoints only accept POST requests
|
|
||||||
if r.Method != "POST" {
|
|
||||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
||||||
}
|
|
||||||
|
|
||||||
mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "failed to parse media type", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.HasPrefix(mediaType, "multipart") {
|
|
||||||
http.Error(w, "request is not multipart", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
boundary, ok := params["boundary"]
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, "no boundary in params", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
fileData := &bytes.Buffer{}
|
|
||||||
mr := multipart.NewReader(r.Body, boundary)
|
|
||||||
part, err := mr.NextPart()
|
|
||||||
if err != nil && errors.Is(err, io.EOF) {
|
|
||||||
http.Error(w, "error accessing file", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, err = io.Copy(fileData, part); err != nil {
|
|
||||||
http.Error(w, "failed to copy file", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(fileData.Bytes()) == 0 {
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
http.Error(w, "received empty file data", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil {
|
|
||||||
http.Error(w, "failed to write body", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAudioWithFailingFormBuilder(t *testing.T) {
|
func TestAudioWithFailingFormBuilder(t *testing.T) {
|
||||||
dir, cleanup := test.CreateTestDirectory(t)
|
dir, cleanup := test.CreateTestDirectory(t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
package openai //nolint:testpackage // testing private field
|
package openai_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
utils "github.com/sashabaranov/go-openai/internal"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
@@ -10,7 +9,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -37,7 +35,9 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateChatCompletionStream(t *testing.T) {
|
func TestCreateChatCompletionStream(t *testing.T) {
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
|
||||||
// Send test responses
|
// Send test responses
|
||||||
@@ -57,21 +57,9 @@ func TestCreateChatCompletionStream(t *testing.T) {
|
|||||||
|
|
||||||
_, err := w.Write(dataBytes)
|
_, err := w.Write(dataBytes)
|
||||||
checks.NoError(t, err, "Write error")
|
checks.NoError(t, err, "Write error")
|
||||||
}))
|
})
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
// Client portion of the test
|
stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = server.URL + "/v1"
|
|
||||||
config.HTTPClient.Transport = &test.TokenRoundTripper{
|
|
||||||
Token: test.GetTestToken(),
|
|
||||||
Fallback: http.DefaultTransport,
|
|
||||||
}
|
|
||||||
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
request := ChatCompletionRequest{
|
|
||||||
MaxTokens: 5,
|
MaxTokens: 5,
|
||||||
Model: GPT3Dot5Turbo,
|
Model: GPT3Dot5Turbo,
|
||||||
Messages: []ChatCompletionMessage{
|
Messages: []ChatCompletionMessage{
|
||||||
@@ -81,9 +69,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
Stream: true,
|
Stream: true,
|
||||||
}
|
})
|
||||||
|
|
||||||
stream, err := client.CreateChatCompletionStream(ctx, request)
|
|
||||||
checks.NoError(t, err, "CreateCompletionStream returned error")
|
checks.NoError(t, err, "CreateCompletionStream returned error")
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
@@ -143,7 +129,9 @@ func TestCreateChatCompletionStream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateChatCompletionStreamError(t *testing.T) {
|
func TestCreateChatCompletionStreamError(t *testing.T) {
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
|
||||||
// Send test responses
|
// Send test responses
|
||||||
@@ -164,21 +152,9 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
|
|||||||
|
|
||||||
_, err := w.Write(dataBytes)
|
_, err := w.Write(dataBytes)
|
||||||
checks.NoError(t, err, "Write error")
|
checks.NoError(t, err, "Write error")
|
||||||
}))
|
})
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
// Client portion of the test
|
stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = server.URL + "/v1"
|
|
||||||
config.HTTPClient.Transport = &test.TokenRoundTripper{
|
|
||||||
Token: test.GetTestToken(),
|
|
||||||
Fallback: http.DefaultTransport,
|
|
||||||
}
|
|
||||||
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
request := ChatCompletionRequest{
|
|
||||||
MaxTokens: 5,
|
MaxTokens: 5,
|
||||||
Model: GPT3Dot5Turbo,
|
Model: GPT3Dot5Turbo,
|
||||||
Messages: []ChatCompletionMessage{
|
Messages: []ChatCompletionMessage{
|
||||||
@@ -188,9 +164,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
Stream: true,
|
Stream: true,
|
||||||
}
|
})
|
||||||
|
|
||||||
stream, err := client.CreateChatCompletionStream(ctx, request)
|
|
||||||
checks.NoError(t, err, "CreateCompletionStream returned error")
|
checks.NoError(t, err, "CreateCompletionStream returned error")
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
@@ -205,7 +179,8 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
|
func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
|
||||||
server := test.NewTestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(429)
|
w.WriteHeader(429)
|
||||||
@@ -220,22 +195,7 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
|
|||||||
_, err := w.Write(dataBytes)
|
_, err := w.Write(dataBytes)
|
||||||
checks.NoError(t, err, "Write error")
|
checks.NoError(t, err, "Write error")
|
||||||
})
|
})
|
||||||
ts := server.OpenAITestServer()
|
_, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
// Client portion of the test
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
config.HTTPClient.Transport = &test.TokenRoundTripper{
|
|
||||||
Token: test.GetTestToken(),
|
|
||||||
Fallback: http.DefaultTransport,
|
|
||||||
}
|
|
||||||
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
request := ChatCompletionRequest{
|
|
||||||
MaxTokens: 5,
|
MaxTokens: 5,
|
||||||
Model: GPT3Dot5Turbo,
|
Model: GPT3Dot5Turbo,
|
||||||
Messages: []ChatCompletionMessage{
|
Messages: []ChatCompletionMessage{
|
||||||
@@ -245,10 +205,8 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
Stream: true,
|
Stream: true,
|
||||||
}
|
})
|
||||||
|
|
||||||
var apiErr *APIError
|
var apiErr *APIError
|
||||||
_, err := client.CreateChatCompletionStream(ctx, request)
|
|
||||||
if !errors.As(err, &apiErr) {
|
if !errors.As(err, &apiErr) {
|
||||||
t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError")
|
t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError")
|
||||||
}
|
}
|
||||||
@@ -262,7 +220,8 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
|
|||||||
"Please retry after 20 seconds. " +
|
"Please retry after 20 seconds. " +
|
||||||
"Please go here: https://aka.ms/oai/quotaincrease if you would like to further increase the default rate limit."
|
"Please go here: https://aka.ms/oai/quotaincrease if you would like to further increase the default rate limit."
|
||||||
|
|
||||||
server := test.NewTestServer()
|
client, server, teardown := setupAzureTestServer()
|
||||||
|
defer teardown()
|
||||||
server.RegisterHandler("/openai/deployments/gpt-35-turbo/chat/completions",
|
server.RegisterHandler("/openai/deployments/gpt-35-turbo/chat/completions",
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
@@ -273,17 +232,9 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
|
|||||||
|
|
||||||
checks.NoError(t, err, "Write error")
|
checks.NoError(t, err, "Write error")
|
||||||
})
|
})
|
||||||
// create the test server
|
|
||||||
var err error
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultAzureConfig(test.GetTestToken(), ts.URL)
|
apiErr := &APIError{}
|
||||||
client := NewClientWithConfig(config)
|
_, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
request := ChatCompletionRequest{
|
|
||||||
MaxTokens: 5,
|
MaxTokens: 5,
|
||||||
Model: GPT3Dot5Turbo,
|
Model: GPT3Dot5Turbo,
|
||||||
Messages: []ChatCompletionMessage{
|
Messages: []ChatCompletionMessage{
|
||||||
@@ -293,10 +244,7 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
Stream: true,
|
Stream: true,
|
||||||
}
|
})
|
||||||
|
|
||||||
apiErr := &APIError{}
|
|
||||||
_, err = client.CreateChatCompletionStream(ctx, request)
|
|
||||||
if !errors.As(err, &apiErr) {
|
if !errors.As(err, &apiErr) {
|
||||||
t.Errorf("Did not return APIError: %+v\n", apiErr)
|
t.Errorf("Did not return APIError: %+v\n", apiErr)
|
||||||
return
|
return
|
||||||
@@ -316,33 +264,6 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) {
|
|
||||||
var err error
|
|
||||||
server := test.NewTestServer()
|
|
||||||
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
http.Error(w, "error", 200)
|
|
||||||
})
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
|
|
||||||
checks.NoError(t, err)
|
|
||||||
|
|
||||||
stream.errAccumulator = &utils.DefaultErrorAccumulator{
|
|
||||||
Buffer: &test.FailingErrorBuffer{},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = stream.Recv()
|
|
||||||
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when Write failed", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper funcs.
|
// Helper funcs.
|
||||||
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
|
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
|
||||||
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
|
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
|
||||||
|
|||||||
20
chat_test.go
20
chat_test.go
@@ -2,7 +2,6 @@ package openai_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
@@ -52,20 +51,10 @@ func TestChatCompletionsWithStream(t *testing.T) {
|
|||||||
|
|
||||||
// TestCompletions Tests the completions endpoint of the API using the mocked server.
|
// TestCompletions Tests the completions endpoint of the API using the mocked server.
|
||||||
func TestChatCompletions(t *testing.T) {
|
func TestChatCompletions(t *testing.T) {
|
||||||
server := test.NewTestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
|
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
|
||||||
// create the test server
|
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
|
||||||
var err error
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
req := ChatCompletionRequest{
|
|
||||||
MaxTokens: 5,
|
MaxTokens: 5,
|
||||||
Model: GPT3Dot5Turbo,
|
Model: GPT3Dot5Turbo,
|
||||||
Messages: []ChatCompletionMessage{
|
Messages: []ChatCompletionMessage{
|
||||||
@@ -74,8 +63,7 @@ func TestChatCompletions(t *testing.T) {
|
|||||||
Content: "Hello!",
|
Content: "Hello!",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
})
|
||||||
_, err = client.CreateChatCompletion(ctx, req)
|
|
||||||
checks.NoError(t, err, "CreateChatCompletion error")
|
checks.NoError(t, err, "CreateChatCompletion error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -167,16 +167,9 @@ func TestHandleErrorResp(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClientReturnsRequestBuilderErrors(t *testing.T) {
|
func TestClientReturnsRequestBuilderErrors(t *testing.T) {
|
||||||
var err error
|
|
||||||
ts := test.NewTestServer().OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
config := DefaultConfig(test.GetTestToken())
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
client := NewClientWithConfig(config)
|
||||||
client.requestBuilder = &failingRequestBuilder{}
|
client.requestBuilder = &failingRequestBuilder{}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
type TestCase struct {
|
type TestCase struct {
|
||||||
@@ -254,7 +247,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
_, err = testCase.TestFunc()
|
_, err := testCase.TestFunc()
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
t.Fatalf("%s did not return error when request builder failed: %v", testCase.Name, err)
|
t.Fatalf("%s did not return error when request builder failed: %v", testCase.Name, err)
|
||||||
}
|
}
|
||||||
@@ -262,23 +255,14 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) {
|
func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) {
|
||||||
var err error
|
|
||||||
ts := test.NewTestServer().OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
config := DefaultConfig(test.GetTestToken())
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
client := NewClientWithConfig(config)
|
||||||
client.requestBuilder = &failingRequestBuilder{}
|
client.requestBuilder = &failingRequestBuilder{}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
_, err := client.CreateCompletion(ctx, CompletionRequest{Prompt: 1})
|
||||||
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1})
|
|
||||||
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
|
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1})
|
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1})
|
||||||
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
|
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package openai_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
@@ -48,25 +47,15 @@ func TestCompletionWithStream(t *testing.T) {
|
|||||||
|
|
||||||
// TestCompletions Tests the completions endpoint of the API using the mocked server.
|
// TestCompletions Tests the completions endpoint of the API using the mocked server.
|
||||||
func TestCompletions(t *testing.T) {
|
func TestCompletions(t *testing.T) {
|
||||||
server := test.NewTestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
|
server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
|
||||||
// create the test server
|
|
||||||
var err error
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
req := CompletionRequest{
|
req := CompletionRequest{
|
||||||
MaxTokens: 5,
|
MaxTokens: 5,
|
||||||
Model: "ada",
|
Model: "ada",
|
||||||
|
Prompt: "Lorem ipsum",
|
||||||
}
|
}
|
||||||
req.Prompt = "Lorem ipsum"
|
_, err := client.CreateCompletion(context.Background(), req)
|
||||||
_, err = client.CreateCompletion(ctx, req)
|
|
||||||
checks.NoError(t, err, "CreateCompletion error")
|
checks.NoError(t, err, "CreateCompletion error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package openai_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
@@ -16,19 +15,9 @@ import (
|
|||||||
|
|
||||||
// TestEdits Tests the edits endpoint of the API using the mocked server.
|
// TestEdits Tests the edits endpoint of the API using the mocked server.
|
||||||
func TestEdits(t *testing.T) {
|
func TestEdits(t *testing.T) {
|
||||||
server := test.NewTestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
server.RegisterHandler("/v1/edits", handleEditEndpoint)
|
server.RegisterHandler("/v1/edits", handleEditEndpoint)
|
||||||
// create the test server
|
|
||||||
var err error
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// create an edit request
|
// create an edit request
|
||||||
model := "ada"
|
model := "ada"
|
||||||
editReq := EditsRequest{
|
editReq := EditsRequest{
|
||||||
@@ -40,7 +29,7 @@ func TestEdits(t *testing.T) {
|
|||||||
Instruction: "test instruction",
|
Instruction: "test instruction",
|
||||||
N: 3,
|
N: 3,
|
||||||
}
|
}
|
||||||
response, err := client.Edits(ctx, editReq)
|
response, err := client.Edits(context.Background(), editReq)
|
||||||
checks.NoError(t, err, "Edits error")
|
checks.NoError(t, err, "Edits error")
|
||||||
if len(response.Choices) != editReq.N {
|
if len(response.Choices) != editReq.N {
|
||||||
t.Fatalf("edits does not properly return the correct number of choices")
|
t.Fatalf("edits does not properly return the correct number of choices")
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package openai_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"bytes"
|
"bytes"
|
||||||
@@ -67,7 +66,8 @@ func TestEmbeddingModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEmbeddingEndpoint(t *testing.T) {
|
func TestEmbeddingEndpoint(t *testing.T) {
|
||||||
server := test.NewTestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
server.RegisterHandler(
|
server.RegisterHandler(
|
||||||
"/v1/embeddings",
|
"/v1/embeddings",
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -75,17 +75,6 @@ func TestEmbeddingEndpoint(t *testing.T) {
|
|||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
// create the test server
|
_, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{})
|
||||||
var err error
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
|
|
||||||
checks.NoError(t, err, "CreateEmbeddings error")
|
checks.NoError(t, err, "CreateEmbeddings error")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,27 +8,29 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server.
|
// TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server.
|
||||||
func TestGetEngine(t *testing.T) {
|
func TestGetEngine(t *testing.T) {
|
||||||
server := test.NewTestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, r *http.Request) {
|
server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, r *http.Request) {
|
||||||
resBytes, _ := json.Marshal(Engine{})
|
resBytes, _ := json.Marshal(Engine{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
})
|
})
|
||||||
// create the test server
|
_, err := client.GetEngine(context.Background(), "text-davinci-003")
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
_, err := client.GetEngine(ctx, "text-davinci-003")
|
|
||||||
checks.NoError(t, err, "GetEngine error")
|
checks.NoError(t, err, "GetEngine error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestListEngines Tests the list engines endpoint of the API using the mocked server.
|
||||||
|
func TestListEngines(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resBytes, _ := json.Marshal(EnginesList{})
|
||||||
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
})
|
||||||
|
_, err := client.ListEngines(context.Background())
|
||||||
|
checks.NoError(t, err, "ListEngines error")
|
||||||
|
}
|
||||||
|
|||||||
183
files_api_test.go
Normal file
183
files_api_test.go
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
package openai_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
. "github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFileUpload(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/files", handleCreateFile)
|
||||||
|
req := FileRequest{
|
||||||
|
FileName: "test.go",
|
||||||
|
FilePath: "client.go",
|
||||||
|
Purpose: "fine-tune",
|
||||||
|
}
|
||||||
|
_, err := client.CreateFile(context.Background(), req)
|
||||||
|
checks.NoError(t, err, "CreateFile error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleCreateFile Handles the images endpoint by the test server.
|
||||||
|
func handleCreateFile(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var err error
|
||||||
|
var resBytes []byte
|
||||||
|
|
||||||
|
// edits only accepts POST requests
|
||||||
|
if r.Method != "POST" {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
err = r.ParseMultipartForm(1024 * 1024 * 1024)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "file is more than 1GB", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
values := r.Form
|
||||||
|
var purpose string
|
||||||
|
for key, value := range values {
|
||||||
|
if key == "purpose" {
|
||||||
|
purpose = value[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
file, header, err := r.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
var fileReq = File{
|
||||||
|
Bytes: int(header.Size),
|
||||||
|
ID: strconv.Itoa(int(time.Now().Unix())),
|
||||||
|
FileName: header.Filename,
|
||||||
|
Purpose: purpose,
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
Object: "test-objecct",
|
||||||
|
Owner: "test-owner",
|
||||||
|
}
|
||||||
|
|
||||||
|
resBytes, _ = json.Marshal(fileReq)
|
||||||
|
fmt.Fprint(w, string(resBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteFile(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {})
|
||||||
|
err := client.DeleteFile(context.Background(), "deadbeef")
|
||||||
|
checks.NoError(t, err, "DeleteFile error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListFile(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resBytes, _ := json.Marshal(FilesList{})
|
||||||
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
})
|
||||||
|
_, err := client.ListFiles(context.Background())
|
||||||
|
checks.NoError(t, err, "ListFiles error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFile(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resBytes, _ := json.Marshal(File{})
|
||||||
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
})
|
||||||
|
_, err := client.GetFile(context.Background(), "deadbeef")
|
||||||
|
checks.NoError(t, err, "GetFile error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileContent(t *testing.T) {
|
||||||
|
wantRespJsonl := `{"prompt": "foo", "completion": "foo"}
|
||||||
|
{"prompt": "bar", "completion": "bar"}
|
||||||
|
{"prompt": "baz", "completion": "baz"}
|
||||||
|
`
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// edits only accepts GET requests
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
fmt.Fprint(w, wantRespJsonl)
|
||||||
|
})
|
||||||
|
|
||||||
|
content, err := client.GetFileContent(context.Background(), "deadbeef")
|
||||||
|
checks.NoError(t, err, "GetFileContent error")
|
||||||
|
defer content.Close()
|
||||||
|
|
||||||
|
actual, _ := io.ReadAll(content)
|
||||||
|
if string(actual) != wantRespJsonl {
|
||||||
|
t.Errorf("Expected %s, got %s", wantRespJsonl, string(actual))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileContentReturnError(t *testing.T) {
|
||||||
|
wantMessage := "To help mitigate abuse, downloading of fine-tune training files is disabled for free accounts."
|
||||||
|
wantType := "invalid_request_error"
|
||||||
|
wantErrorResp := `{
|
||||||
|
"error": {
|
||||||
|
"message": "` + wantMessage + `",
|
||||||
|
"type": "` + wantType + `",
|
||||||
|
"param": null,
|
||||||
|
"code": null
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
fmt.Fprint(w, wantErrorResp)
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := client.GetFileContent(context.Background(), "deadbeef")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Did not return error")
|
||||||
|
}
|
||||||
|
|
||||||
|
apiErr := &APIError{}
|
||||||
|
if !errors.As(err, &apiErr) {
|
||||||
|
t.Fatalf("Did not return APIError: %+v\n", apiErr)
|
||||||
|
}
|
||||||
|
if apiErr.Message != wantMessage {
|
||||||
|
t.Fatalf("Expected %s Message, got = %s\n", wantMessage, apiErr.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if apiErr.Type != wantType {
|
||||||
|
t.Fatalf("Expected %s Type, got = %s\n", wantType, apiErr.Type)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileContentReturnTimeoutError(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(10 * time.Nanosecond)
|
||||||
|
})
|
||||||
|
ctx := context.Background()
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, time.Nanosecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err := client.GetFileContent(ctx, "deadbeef")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Did not return error")
|
||||||
|
}
|
||||||
|
if !os.IsTimeout(err) {
|
||||||
|
t.Fatal("Did not return timeout error")
|
||||||
|
}
|
||||||
|
}
|
||||||
236
files_test.go
236
files_test.go
@@ -2,86 +2,15 @@ package openai //nolint:testpackage // testing private field
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
utils "github.com/sashabaranov/go-openai/internal"
|
utils "github.com/sashabaranov/go-openai/internal"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestFileUpload(t *testing.T) {
|
|
||||||
server := test.NewTestServer()
|
|
||||||
server.RegisterHandler("/v1/files", handleCreateFile)
|
|
||||||
// create the test server
|
|
||||||
var err error
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
req := FileRequest{
|
|
||||||
FileName: "test.go",
|
|
||||||
FilePath: "client.go",
|
|
||||||
Purpose: "fine-tune",
|
|
||||||
}
|
|
||||||
_, err = client.CreateFile(ctx, req)
|
|
||||||
checks.NoError(t, err, "CreateFile error")
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleCreateFile Handles the images endpoint by the test server.
|
|
||||||
func handleCreateFile(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var err error
|
|
||||||
var resBytes []byte
|
|
||||||
|
|
||||||
// edits only accepts POST requests
|
|
||||||
if r.Method != "POST" {
|
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
||||||
}
|
|
||||||
err = r.ParseMultipartForm(1024 * 1024 * 1024)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "file is more than 1GB", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
values := r.Form
|
|
||||||
var purpose string
|
|
||||||
for key, value := range values {
|
|
||||||
if key == "purpose" {
|
|
||||||
purpose = value[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
file, header, err := r.FormFile("file")
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
var fileReq = File{
|
|
||||||
Bytes: int(header.Size),
|
|
||||||
ID: strconv.Itoa(int(time.Now().Unix())),
|
|
||||||
FileName: header.Filename,
|
|
||||||
Purpose: purpose,
|
|
||||||
CreatedAt: time.Now().Unix(),
|
|
||||||
Object: "test-objecct",
|
|
||||||
Owner: "test-owner",
|
|
||||||
}
|
|
||||||
|
|
||||||
resBytes, _ = json.Marshal(fileReq)
|
|
||||||
fmt.Fprint(w, string(resBytes))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFileUploadWithFailingFormBuilder(t *testing.T) {
|
func TestFileUploadWithFailingFormBuilder(t *testing.T) {
|
||||||
config := DefaultConfig("")
|
config := DefaultConfig("")
|
||||||
config.BaseURL = ""
|
config.BaseURL = ""
|
||||||
@@ -142,168 +71,3 @@ func TestFileUploadWithNonExistentPath(t *testing.T) {
|
|||||||
_, err := client.CreateFile(ctx, req)
|
_, err := client.CreateFile(ctx, req)
|
||||||
checks.ErrorIs(t, err, os.ErrNotExist, "CreateFile should return error if file does not exist")
|
checks.ErrorIs(t, err, os.ErrNotExist, "CreateFile should return error if file does not exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeleteFile(t *testing.T) {
|
|
||||||
server := test.NewTestServer()
|
|
||||||
server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
|
|
||||||
})
|
|
||||||
// create the test server
|
|
||||||
var err error
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
err = client.DeleteFile(ctx, "deadbeef")
|
|
||||||
checks.NoError(t, err, "DeleteFile error")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestListFile(t *testing.T) {
|
|
||||||
server := test.NewTestServer()
|
|
||||||
server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
fmt.Fprint(w, "{}")
|
|
||||||
})
|
|
||||||
// create the test server
|
|
||||||
var err error
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
_, err = client.ListFiles(ctx)
|
|
||||||
checks.NoError(t, err, "ListFiles error")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetFile(t *testing.T) {
|
|
||||||
server := test.NewTestServer()
|
|
||||||
server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
fmt.Fprint(w, "{}")
|
|
||||||
})
|
|
||||||
// create the test server
|
|
||||||
var err error
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
_, err = client.GetFile(ctx, "deadbeef")
|
|
||||||
checks.NoError(t, err, "GetFile error")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetFileContent(t *testing.T) {
|
|
||||||
wantRespJsonl := `{"prompt": "foo", "completion": "foo"}
|
|
||||||
{"prompt": "bar", "completion": "bar"}
|
|
||||||
{"prompt": "baz", "completion": "baz"}
|
|
||||||
`
|
|
||||||
server := test.NewTestServer()
|
|
||||||
server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// edits only accepts GET requests
|
|
||||||
if r.Method != http.MethodGet {
|
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
||||||
}
|
|
||||||
fmt.Fprint(w, wantRespJsonl)
|
|
||||||
})
|
|
||||||
// create the test server
|
|
||||||
var err error
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
content, err := client.GetFileContent(ctx, "deadbeef")
|
|
||||||
checks.NoError(t, err, "GetFileContent error")
|
|
||||||
defer content.Close()
|
|
||||||
|
|
||||||
actual, _ := io.ReadAll(content)
|
|
||||||
if string(actual) != wantRespJsonl {
|
|
||||||
t.Errorf("Expected %s, got %s", wantRespJsonl, string(actual))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetFileContentReturnError(t *testing.T) {
|
|
||||||
wantMessage := "To help mitigate abuse, downloading of fine-tune training files is disabled for free accounts."
|
|
||||||
wantType := "invalid_request_error"
|
|
||||||
wantErrorResp := `{
|
|
||||||
"error": {
|
|
||||||
"message": "` + wantMessage + `",
|
|
||||||
"type": "` + wantType + `",
|
|
||||||
"param": null,
|
|
||||||
"code": null
|
|
||||||
}
|
|
||||||
}`
|
|
||||||
server := test.NewTestServer()
|
|
||||||
server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
|
||||||
fmt.Fprint(w, wantErrorResp)
|
|
||||||
})
|
|
||||||
// create the test server
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
_, err := client.GetFileContent(ctx, "deadbeef")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Did not return error")
|
|
||||||
}
|
|
||||||
|
|
||||||
apiErr := &APIError{}
|
|
||||||
if !errors.As(err, &apiErr) {
|
|
||||||
t.Fatalf("Did not return APIError: %+v\n", apiErr)
|
|
||||||
}
|
|
||||||
if apiErr.Message != wantMessage {
|
|
||||||
t.Fatalf("Expected %s Message, got = %s\n", wantMessage, apiErr.Message)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if apiErr.Type != wantType {
|
|
||||||
t.Fatalf("Expected %s Type, got = %s\n", wantType, apiErr.Type)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetFileContentReturnTimeoutError(t *testing.T) {
|
|
||||||
server := test.NewTestServer()
|
|
||||||
server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
time.Sleep(10 * time.Nanosecond)
|
|
||||||
})
|
|
||||||
// create the test server
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, time.Nanosecond)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
_, err := client.GetFileContent(ctx, "deadbeef")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Did not return error")
|
|
||||||
}
|
|
||||||
if !os.IsTimeout(err) {
|
|
||||||
t.Fatal("Did not return timeout error")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package openai_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
@@ -16,7 +15,8 @@ const testFineTuneID = "fine-tune-id"
|
|||||||
|
|
||||||
// TestFineTunes Tests the fine tunes endpoint of the API using the mocked server.
|
// TestFineTunes Tests the fine tunes endpoint of the API using the mocked server.
|
||||||
func TestFineTunes(t *testing.T) {
|
func TestFineTunes(t *testing.T) {
|
||||||
server := test.NewTestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
server.RegisterHandler(
|
server.RegisterHandler(
|
||||||
"/v1/fine-tunes",
|
"/v1/fine-tunes",
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -59,18 +59,9 @@ func TestFineTunes(t *testing.T) {
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
// create the test server
|
|
||||||
var err error
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
_, err = client.ListFineTunes(ctx)
|
_, err := client.ListFineTunes(ctx)
|
||||||
checks.NoError(t, err, "ListFineTunes error")
|
checks.NoError(t, err, "ListFineTunes error")
|
||||||
|
|
||||||
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
|
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
|
||||||
|
|||||||
223
image_api_test.go
Normal file
223
image_api_test.go
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
package openai_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
. "github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestImages(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/images/generations", handleImageEndpoint)
|
||||||
|
_, err := client.CreateImage(context.Background(), ImageRequest{
|
||||||
|
Prompt: "Lorem ipsum",
|
||||||
|
})
|
||||||
|
checks.NoError(t, err, "CreateImage error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleImageEndpoint Handles the images endpoint by the test server.
|
||||||
|
func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var err error
|
||||||
|
var resBytes []byte
|
||||||
|
|
||||||
|
// imagess only accepts POST requests
|
||||||
|
if r.Method != "POST" {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
var imageReq ImageRequest
|
||||||
|
if imageReq, err = getImageBody(r); err != nil {
|
||||||
|
http.Error(w, "could not read request", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res := ImageResponse{
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
for i := 0; i < imageReq.N; i++ {
|
||||||
|
imageData := ImageResponseDataInner{}
|
||||||
|
switch imageReq.ResponseFormat {
|
||||||
|
case CreateImageResponseFormatURL, "":
|
||||||
|
imageData.URL = "https://example.com/image.png"
|
||||||
|
case CreateImageResponseFormatB64JSON:
|
||||||
|
// This decodes to "{}" in base64.
|
||||||
|
imageData.B64JSON = "e30K"
|
||||||
|
default:
|
||||||
|
http.Error(w, "invalid response format", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res.Data = append(res.Data, imageData)
|
||||||
|
}
|
||||||
|
resBytes, _ = json.Marshal(res)
|
||||||
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// getImageBody Returns the body of the request to create a image.
|
||||||
|
func getImageBody(r *http.Request) (ImageRequest, error) {
|
||||||
|
image := ImageRequest{}
|
||||||
|
// read the request body
|
||||||
|
reqBody, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRequest{}, err
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(reqBody, &image)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRequest{}, err
|
||||||
|
}
|
||||||
|
return image, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestImageEdit(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
|
||||||
|
|
||||||
|
origin, err := os.Create("image.png")
|
||||||
|
if err != nil {
|
||||||
|
t.Error("open origin file error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mask, err := os.Create("mask.png")
|
||||||
|
if err != nil {
|
||||||
|
t.Error("open mask file error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
mask.Close()
|
||||||
|
origin.Close()
|
||||||
|
os.Remove("mask.png")
|
||||||
|
os.Remove("image.png")
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = client.CreateEditImage(context.Background(), ImageEditRequest{
|
||||||
|
Image: origin,
|
||||||
|
Mask: mask,
|
||||||
|
Prompt: "There is a turtle in the pool",
|
||||||
|
N: 3,
|
||||||
|
Size: CreateImageSize1024x1024,
|
||||||
|
ResponseFormat: CreateImageResponseFormatURL,
|
||||||
|
})
|
||||||
|
checks.NoError(t, err, "CreateImage error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestImageEditWithoutMask(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
|
||||||
|
|
||||||
|
origin, err := os.Create("image.png")
|
||||||
|
if err != nil {
|
||||||
|
t.Error("open origin file error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
origin.Close()
|
||||||
|
os.Remove("image.png")
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = client.CreateEditImage(context.Background(), ImageEditRequest{
|
||||||
|
Image: origin,
|
||||||
|
Prompt: "There is a turtle in the pool",
|
||||||
|
N: 3,
|
||||||
|
Size: CreateImageSize1024x1024,
|
||||||
|
ResponseFormat: CreateImageResponseFormatURL,
|
||||||
|
})
|
||||||
|
checks.NoError(t, err, "CreateImage error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleEditImageEndpoint Handles the images endpoint by the test server.
|
||||||
|
func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var resBytes []byte
|
||||||
|
|
||||||
|
// imagess only accepts POST requests
|
||||||
|
if r.Method != "POST" {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
|
||||||
|
responses := ImageResponse{
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
Data: []ImageResponseDataInner{
|
||||||
|
{
|
||||||
|
URL: "test-url1",
|
||||||
|
B64JSON: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
URL: "test-url2",
|
||||||
|
B64JSON: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
URL: "test-url3",
|
||||||
|
B64JSON: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resBytes, _ = json.Marshal(responses)
|
||||||
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestImageVariation(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint)
|
||||||
|
|
||||||
|
origin, err := os.Create("image.png")
|
||||||
|
if err != nil {
|
||||||
|
t.Error("open origin file error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
origin.Close()
|
||||||
|
os.Remove("image.png")
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = client.CreateVariImage(context.Background(), ImageVariRequest{
|
||||||
|
Image: origin,
|
||||||
|
N: 3,
|
||||||
|
Size: CreateImageSize1024x1024,
|
||||||
|
ResponseFormat: CreateImageResponseFormatURL,
|
||||||
|
})
|
||||||
|
checks.NoError(t, err, "CreateImage error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleVariateImageEndpoint Handles the images endpoint by the test server.
|
||||||
|
func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var resBytes []byte
|
||||||
|
|
||||||
|
// imagess only accepts POST requests
|
||||||
|
if r.Method != "POST" {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
|
||||||
|
responses := ImageResponse{
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
Data: []ImageResponseDataInner{
|
||||||
|
{
|
||||||
|
URL: "test-url1",
|
||||||
|
B64JSON: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
URL: "test-url2",
|
||||||
|
B64JSON: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
URL: "test-url3",
|
||||||
|
B64JSON: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resBytes, _ = json.Marshal(responses)
|
||||||
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
}
|
||||||
252
image_test.go
252
image_test.go
@@ -2,267 +2,15 @@ package openai //nolint:testpackage // testing private field
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
utils "github.com/sashabaranov/go-openai/internal"
|
utils "github.com/sashabaranov/go-openai/internal"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestImages(t *testing.T) {
|
|
||||||
server := test.NewTestServer()
|
|
||||||
server.RegisterHandler("/v1/images/generations", handleImageEndpoint)
|
|
||||||
// create the test server
|
|
||||||
var err error
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
req := ImageRequest{}
|
|
||||||
req.Prompt = "Lorem ipsum"
|
|
||||||
_, err = client.CreateImage(ctx, req)
|
|
||||||
checks.NoError(t, err, "CreateImage error")
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleImageEndpoint Handles the images endpoint by the test server.
|
|
||||||
func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var err error
|
|
||||||
var resBytes []byte
|
|
||||||
|
|
||||||
// imagess only accepts POST requests
|
|
||||||
if r.Method != "POST" {
|
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
||||||
}
|
|
||||||
var imageReq ImageRequest
|
|
||||||
if imageReq, err = getImageBody(r); err != nil {
|
|
||||||
http.Error(w, "could not read request", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
res := ImageResponse{
|
|
||||||
Created: time.Now().Unix(),
|
|
||||||
}
|
|
||||||
for i := 0; i < imageReq.N; i++ {
|
|
||||||
imageData := ImageResponseDataInner{}
|
|
||||||
switch imageReq.ResponseFormat {
|
|
||||||
case CreateImageResponseFormatURL, "":
|
|
||||||
imageData.URL = "https://example.com/image.png"
|
|
||||||
case CreateImageResponseFormatB64JSON:
|
|
||||||
// This decodes to "{}" in base64.
|
|
||||||
imageData.B64JSON = "e30K"
|
|
||||||
default:
|
|
||||||
http.Error(w, "invalid response format", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
res.Data = append(res.Data, imageData)
|
|
||||||
}
|
|
||||||
resBytes, _ = json.Marshal(res)
|
|
||||||
fmt.Fprintln(w, string(resBytes))
|
|
||||||
}
|
|
||||||
|
|
||||||
// getImageBody Returns the body of the request to create a image.
|
|
||||||
func getImageBody(r *http.Request) (ImageRequest, error) {
|
|
||||||
image := ImageRequest{}
|
|
||||||
// read the request body
|
|
||||||
reqBody, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
return ImageRequest{}, err
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(reqBody, &image)
|
|
||||||
if err != nil {
|
|
||||||
return ImageRequest{}, err
|
|
||||||
}
|
|
||||||
return image, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestImageEdit(t *testing.T) {
|
|
||||||
server := test.NewTestServer()
|
|
||||||
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
|
|
||||||
// create the test server
|
|
||||||
var err error
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
origin, err := os.Create("image.png")
|
|
||||||
if err != nil {
|
|
||||||
t.Error("open origin file error")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
mask, err := os.Create("mask.png")
|
|
||||||
if err != nil {
|
|
||||||
t.Error("open mask file error")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
mask.Close()
|
|
||||||
origin.Close()
|
|
||||||
os.Remove("mask.png")
|
|
||||||
os.Remove("image.png")
|
|
||||||
}()
|
|
||||||
|
|
||||||
req := ImageEditRequest{
|
|
||||||
Image: origin,
|
|
||||||
Mask: mask,
|
|
||||||
Prompt: "There is a turtle in the pool",
|
|
||||||
N: 3,
|
|
||||||
Size: CreateImageSize1024x1024,
|
|
||||||
ResponseFormat: CreateImageResponseFormatURL,
|
|
||||||
}
|
|
||||||
_, err = client.CreateEditImage(ctx, req)
|
|
||||||
checks.NoError(t, err, "CreateImage error")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestImageEditWithoutMask(t *testing.T) {
|
|
||||||
server := test.NewTestServer()
|
|
||||||
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
|
|
||||||
// create the test server
|
|
||||||
var err error
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
origin, err := os.Create("image.png")
|
|
||||||
if err != nil {
|
|
||||||
t.Error("open origin file error")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
origin.Close()
|
|
||||||
os.Remove("image.png")
|
|
||||||
}()
|
|
||||||
|
|
||||||
req := ImageEditRequest{
|
|
||||||
Image: origin,
|
|
||||||
Prompt: "There is a turtle in the pool",
|
|
||||||
N: 3,
|
|
||||||
Size: CreateImageSize1024x1024,
|
|
||||||
ResponseFormat: CreateImageResponseFormatURL,
|
|
||||||
}
|
|
||||||
_, err = client.CreateEditImage(ctx, req)
|
|
||||||
checks.NoError(t, err, "CreateImage error")
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleEditImageEndpoint Handles the images endpoint by the test server.
|
|
||||||
func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var resBytes []byte
|
|
||||||
|
|
||||||
// imagess only accepts POST requests
|
|
||||||
if r.Method != "POST" {
|
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
||||||
}
|
|
||||||
|
|
||||||
responses := ImageResponse{
|
|
||||||
Created: time.Now().Unix(),
|
|
||||||
Data: []ImageResponseDataInner{
|
|
||||||
{
|
|
||||||
URL: "test-url1",
|
|
||||||
B64JSON: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
URL: "test-url2",
|
|
||||||
B64JSON: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
URL: "test-url3",
|
|
||||||
B64JSON: "",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
resBytes, _ = json.Marshal(responses)
|
|
||||||
fmt.Fprintln(w, string(resBytes))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestImageVariation(t *testing.T) {
|
|
||||||
server := test.NewTestServer()
|
|
||||||
server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint)
|
|
||||||
// create the test server
|
|
||||||
var err error
|
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
origin, err := os.Create("image.png")
|
|
||||||
if err != nil {
|
|
||||||
t.Error("open origin file error")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
origin.Close()
|
|
||||||
os.Remove("image.png")
|
|
||||||
}()
|
|
||||||
|
|
||||||
req := ImageVariRequest{
|
|
||||||
Image: origin,
|
|
||||||
N: 3,
|
|
||||||
Size: CreateImageSize1024x1024,
|
|
||||||
ResponseFormat: CreateImageResponseFormatURL,
|
|
||||||
}
|
|
||||||
_, err = client.CreateVariImage(ctx, req)
|
|
||||||
checks.NoError(t, err, "CreateImage error")
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleVariateImageEndpoint Handles the images endpoint by the test server.
|
|
||||||
func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var resBytes []byte
|
|
||||||
|
|
||||||
// imagess only accepts POST requests
|
|
||||||
if r.Method != "POST" {
|
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
||||||
}
|
|
||||||
|
|
||||||
responses := ImageResponse{
|
|
||||||
Created: time.Now().Unix(),
|
|
||||||
Data: []ImageResponseDataInner{
|
|
||||||
{
|
|
||||||
URL: "test-url1",
|
|
||||||
B64JSON: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
URL: "test-url2",
|
|
||||||
B64JSON: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
URL: "test-url3",
|
|
||||||
B64JSON: "",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
resBytes, _ = json.Marshal(responses)
|
|
||||||
fmt.Fprintln(w, string(resBytes))
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockFormBuilder struct {
|
type mockFormBuilder struct {
|
||||||
mockCreateFormFile func(string, *os.File) error
|
mockCreateFormFile func(string, *os.File) error
|
||||||
mockCreateFormFileReader func(string, io.Reader, string) error
|
mockCreateFormFileReader func(string, io.Reader, string) error
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package openai_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
@@ -12,85 +11,47 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestListModels Tests the models endpoint of the API using the mocked server.
|
// TestListModels Tests the list models endpoint of the API using the mocked server.
|
||||||
func TestListModels(t *testing.T) {
|
func TestListModels(t *testing.T) {
|
||||||
server := test.NewTestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
server.RegisterHandler("/v1/models", handleModelsEndpoint)
|
defer teardown()
|
||||||
// create the test server
|
server.RegisterHandler("/v1/models", handleListModelsEndpoint)
|
||||||
var err error
|
_, err := client.ListModels(context.Background())
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
_, err = client.ListModels(ctx)
|
|
||||||
checks.NoError(t, err, "ListModels error")
|
checks.NoError(t, err, "ListModels error")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAzureListModels(t *testing.T) {
|
func TestAzureListModels(t *testing.T) {
|
||||||
server := test.NewTestServer()
|
client, server, teardown := setupAzureTestServer()
|
||||||
server.RegisterHandler("/openai/models", handleModelsEndpoint)
|
defer teardown()
|
||||||
// create the test server
|
server.RegisterHandler("/openai/models", handleListModelsEndpoint)
|
||||||
var err error
|
_, err := client.ListModels(context.Background())
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
|
|
||||||
config.BaseURL = ts.URL
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
_, err = client.ListModels(ctx)
|
|
||||||
checks.NoError(t, err, "ListModels error")
|
checks.NoError(t, err, "ListModels error")
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleModelsEndpoint Handles the models endpoint by the test server.
|
// handleListModelsEndpoint Handles the list models endpoint by the test server.
|
||||||
func handleModelsEndpoint(w http.ResponseWriter, _ *http.Request) {
|
func handleListModelsEndpoint(w http.ResponseWriter, _ *http.Request) {
|
||||||
resBytes, _ := json.Marshal(ModelsList{})
|
resBytes, _ := json.Marshal(ModelsList{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestGetModel Tests the retrieve model endpoint of the API using the mocked server.
|
// TestGetModel Tests the retrieve model endpoint of the API using the mocked server.
|
||||||
func TestGetModel(t *testing.T) {
|
func TestGetModel(t *testing.T) {
|
||||||
server := test.NewTestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
server.RegisterHandler("/v1/models/text-davinci-003", handleGetModelEndpoint)
|
server.RegisterHandler("/v1/models/text-davinci-003", handleGetModelEndpoint)
|
||||||
// create the test server
|
_, err := client.GetModel(context.Background(), "text-davinci-003")
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
_, err := client.GetModel(ctx, "text-davinci-003")
|
|
||||||
checks.NoError(t, err, "GetModel error")
|
checks.NoError(t, err, "GetModel error")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAzureGetModel(t *testing.T) {
|
func TestAzureGetModel(t *testing.T) {
|
||||||
server := test.NewTestServer()
|
client, server, teardown := setupAzureTestServer()
|
||||||
server.RegisterHandler("/openai/models/text-davinci-003", handleModelsEndpoint)
|
defer teardown()
|
||||||
// create the test server
|
server.RegisterHandler("/openai/models/text-davinci-003", handleGetModelEndpoint)
|
||||||
ts := server.OpenAITestServer()
|
_, err := client.GetModel(context.Background(), "text-davinci-003")
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
|
|
||||||
config.BaseURL = ts.URL
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
_, err := client.GetModel(ctx, "text-davinci-003")
|
|
||||||
checks.NoError(t, err, "GetModel error")
|
checks.NoError(t, err, "GetModel error")
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleModelsEndpoint Handles the models endpoint by the test server.
|
// handleGetModelsEndpoint Handles the get model endpoint by the test server.
|
||||||
func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) {
|
func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) {
|
||||||
resBytes, _ := json.Marshal(Model{})
|
resBytes, _ := json.Marshal(Model{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package openai_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
@@ -18,26 +17,13 @@ import (
|
|||||||
|
|
||||||
// TestModeration Tests the moderations endpoint of the API using the mocked server.
|
// TestModeration Tests the moderations endpoint of the API using the mocked server.
|
||||||
func TestModerations(t *testing.T) {
|
func TestModerations(t *testing.T) {
|
||||||
server := test.NewTestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
|
server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
|
||||||
// create the test server
|
_, err := client.Moderations(context.Background(), ModerationRequest{
|
||||||
var err error
|
Model: ModerationTextStable,
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// create an edit request
|
|
||||||
model := "text-moderation-stable"
|
|
||||||
moderationReq := ModerationRequest{
|
|
||||||
Model: model,
|
|
||||||
Input: "I want to kill them.",
|
Input: "I want to kill them.",
|
||||||
}
|
})
|
||||||
_, err = client.Moderations(ctx, moderationReq)
|
|
||||||
checks.NoError(t, err, "Moderation error")
|
checks.NoError(t, err, "Moderation error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
28
openai_test.go
Normal file
28
openai_test.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package openai_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
. "github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupOpenAITestServer() (client *Client, server *test.ServerTest, teardown func()) {
|
||||||
|
server = test.NewTestServer()
|
||||||
|
ts := server.OpenAITestServer()
|
||||||
|
ts.Start()
|
||||||
|
teardown = ts.Close
|
||||||
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
client = NewClientWithConfig(config)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupAzureTestServer() (client *Client, server *test.ServerTest, teardown func()) {
|
||||||
|
server = test.NewTestServer()
|
||||||
|
ts := server.OpenAITestServer()
|
||||||
|
ts.Start()
|
||||||
|
teardown = ts.Close
|
||||||
|
config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
|
||||||
|
config.BaseURL = ts.URL
|
||||||
|
client = NewClientWithConfig(config)
|
||||||
|
return
|
||||||
|
}
|
||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
utils "github.com/sashabaranov/go-openai/internal"
|
utils "github.com/sashabaranov/go-openai/internal"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errTestUnmarshalerFailed = errors.New("test unmarshaler failed")
|
var errTestUnmarshalerFailed = errors.New("test unmarshaler failed")
|
||||||
@@ -47,7 +49,17 @@ func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) {
|
|||||||
unmarshaler: &utils.JSONUnmarshaler{},
|
unmarshaler: &utils.JSONUnmarshaler{},
|
||||||
}
|
}
|
||||||
_, err := stream.Recv()
|
_, err := stream.Recv()
|
||||||
if !errors.Is(err, ErrTooManyEmptyStreamMessages) {
|
checks.ErrorIs(t, err, ErrTooManyEmptyStreamMessages, "Did not return error when recv failed", err.Error())
|
||||||
t.Fatalf("Did not return error when recv failed: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) {
|
||||||
|
stream := &streamReader[ChatCompletionStreamResponse]{
|
||||||
|
reader: bufio.NewReader(bytes.NewReader([]byte("\n"))),
|
||||||
|
errAccumulator: &utils.DefaultErrorAccumulator{
|
||||||
|
Buffer: &test.FailingErrorBuffer{},
|
||||||
|
},
|
||||||
|
unmarshaler: &utils.JSONUnmarshaler{},
|
||||||
|
}
|
||||||
|
_, err := stream.Recv()
|
||||||
|
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
147
stream_test.go
147
stream_test.go
@@ -6,11 +6,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -32,7 +30,9 @@ func TestCompletionsStreamWrongModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateCompletionStream(t *testing.T) {
|
func TestCreateCompletionStream(t *testing.T) {
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
|
||||||
// Send test responses
|
// Send test responses
|
||||||
@@ -52,28 +52,14 @@ func TestCreateCompletionStream(t *testing.T) {
|
|||||||
|
|
||||||
_, err := w.Write(dataBytes)
|
_, err := w.Write(dataBytes)
|
||||||
checks.NoError(t, err, "Write error")
|
checks.NoError(t, err, "Write error")
|
||||||
}))
|
})
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
// Client portion of the test
|
stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = server.URL + "/v1"
|
|
||||||
config.HTTPClient.Transport = &test.TokenRoundTripper{
|
|
||||||
Token: test.GetTestToken(),
|
|
||||||
Fallback: http.DefaultTransport,
|
|
||||||
}
|
|
||||||
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
request := CompletionRequest{
|
|
||||||
Prompt: "Ex falso quodlibet",
|
Prompt: "Ex falso quodlibet",
|
||||||
Model: "text-davinci-002",
|
Model: "text-davinci-002",
|
||||||
MaxTokens: 10,
|
MaxTokens: 10,
|
||||||
Stream: true,
|
Stream: true,
|
||||||
}
|
})
|
||||||
|
|
||||||
stream, err := client.CreateCompletionStream(ctx, request)
|
|
||||||
checks.NoError(t, err, "CreateCompletionStream returned error")
|
checks.NoError(t, err, "CreateCompletionStream returned error")
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
@@ -116,7 +102,9 @@ func TestCreateCompletionStream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateCompletionStreamError(t *testing.T) {
|
func TestCreateCompletionStreamError(t *testing.T) {
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
|
||||||
// Send test responses
|
// Send test responses
|
||||||
@@ -137,28 +125,14 @@ func TestCreateCompletionStreamError(t *testing.T) {
|
|||||||
|
|
||||||
_, err := w.Write(dataBytes)
|
_, err := w.Write(dataBytes)
|
||||||
checks.NoError(t, err, "Write error")
|
checks.NoError(t, err, "Write error")
|
||||||
}))
|
})
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
// Client portion of the test
|
stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = server.URL + "/v1"
|
|
||||||
config.HTTPClient.Transport = &test.TokenRoundTripper{
|
|
||||||
Token: test.GetTestToken(),
|
|
||||||
Fallback: http.DefaultTransport,
|
|
||||||
}
|
|
||||||
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
request := CompletionRequest{
|
|
||||||
MaxTokens: 5,
|
MaxTokens: 5,
|
||||||
Model: GPT3TextDavinci003,
|
Model: GPT3TextDavinci003,
|
||||||
Prompt: "Hello!",
|
Prompt: "Hello!",
|
||||||
Stream: true,
|
Stream: true,
|
||||||
}
|
})
|
||||||
|
|
||||||
stream, err := client.CreateCompletionStream(ctx, request)
|
|
||||||
checks.NoError(t, err, "CreateCompletionStream returned error")
|
checks.NoError(t, err, "CreateCompletionStream returned error")
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
@@ -173,7 +147,8 @@ func TestCreateCompletionStreamError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateCompletionStreamRateLimitError(t *testing.T) {
|
func TestCreateCompletionStreamRateLimitError(t *testing.T) {
|
||||||
server := test.NewTestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(429)
|
w.WriteHeader(429)
|
||||||
@@ -188,30 +163,14 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) {
|
|||||||
_, err := w.Write(dataBytes)
|
_, err := w.Write(dataBytes)
|
||||||
checks.NoError(t, err, "Write error")
|
checks.NoError(t, err, "Write error")
|
||||||
})
|
})
|
||||||
ts := server.OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
// Client portion of the test
|
var apiErr *APIError
|
||||||
config := DefaultConfig(test.GetTestToken())
|
_, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
config.HTTPClient.Transport = &test.TokenRoundTripper{
|
|
||||||
Token: test.GetTestToken(),
|
|
||||||
Fallback: http.DefaultTransport,
|
|
||||||
}
|
|
||||||
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
request := CompletionRequest{
|
|
||||||
MaxTokens: 5,
|
MaxTokens: 5,
|
||||||
Model: GPT3Ada,
|
Model: GPT3Ada,
|
||||||
Prompt: "Hello!",
|
Prompt: "Hello!",
|
||||||
Stream: true,
|
Stream: true,
|
||||||
}
|
})
|
||||||
|
|
||||||
var apiErr *APIError
|
|
||||||
_, err := client.CreateCompletionStream(ctx, request)
|
|
||||||
if !errors.As(err, &apiErr) {
|
if !errors.As(err, &apiErr) {
|
||||||
t.Errorf("TestCreateCompletionStreamRateLimitError did not return APIError")
|
t.Errorf("TestCreateCompletionStreamRateLimitError did not return APIError")
|
||||||
}
|
}
|
||||||
@@ -219,7 +178,9 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) {
|
func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) {
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
|
||||||
// Send test responses
|
// Send test responses
|
||||||
@@ -244,28 +205,14 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) {
|
|||||||
|
|
||||||
_, err := w.Write(dataBytes)
|
_, err := w.Write(dataBytes)
|
||||||
checks.NoError(t, err, "Write error")
|
checks.NoError(t, err, "Write error")
|
||||||
}))
|
})
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
// Client portion of the test
|
stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = server.URL + "/v1"
|
|
||||||
config.HTTPClient.Transport = &test.TokenRoundTripper{
|
|
||||||
Token: test.GetTestToken(),
|
|
||||||
Fallback: http.DefaultTransport,
|
|
||||||
}
|
|
||||||
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
request := CompletionRequest{
|
|
||||||
Prompt: "Ex falso quodlibet",
|
Prompt: "Ex falso quodlibet",
|
||||||
Model: "text-davinci-002",
|
Model: "text-davinci-002",
|
||||||
MaxTokens: 10,
|
MaxTokens: 10,
|
||||||
Stream: true,
|
Stream: true,
|
||||||
}
|
})
|
||||||
|
|
||||||
stream, err := client.CreateCompletionStream(ctx, request)
|
|
||||||
checks.NoError(t, err, "CreateCompletionStream returned error")
|
checks.NoError(t, err, "CreateCompletionStream returned error")
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
@@ -277,7 +224,9 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) {
|
func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) {
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
|
||||||
// Send test responses
|
// Send test responses
|
||||||
@@ -291,28 +240,14 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) {
|
|||||||
|
|
||||||
_, err := w.Write(dataBytes)
|
_, err := w.Write(dataBytes)
|
||||||
checks.NoError(t, err, "Write error")
|
checks.NoError(t, err, "Write error")
|
||||||
}))
|
})
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
// Client portion of the test
|
stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = server.URL + "/v1"
|
|
||||||
config.HTTPClient.Transport = &test.TokenRoundTripper{
|
|
||||||
Token: test.GetTestToken(),
|
|
||||||
Fallback: http.DefaultTransport,
|
|
||||||
}
|
|
||||||
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
request := CompletionRequest{
|
|
||||||
Prompt: "Ex falso quodlibet",
|
Prompt: "Ex falso quodlibet",
|
||||||
Model: "text-davinci-002",
|
Model: "text-davinci-002",
|
||||||
MaxTokens: 10,
|
MaxTokens: 10,
|
||||||
Stream: true,
|
Stream: true,
|
||||||
}
|
})
|
||||||
|
|
||||||
stream, err := client.CreateCompletionStream(ctx, request)
|
|
||||||
checks.NoError(t, err, "CreateCompletionStream returned error")
|
checks.NoError(t, err, "CreateCompletionStream returned error")
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
@@ -324,7 +259,9 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateCompletionStreamBrokenJSONError(t *testing.T) {
|
func TestCreateCompletionStreamBrokenJSONError(t *testing.T) {
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
|
||||||
// Send test responses
|
// Send test responses
|
||||||
@@ -344,28 +281,14 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) {
|
|||||||
|
|
||||||
_, err := w.Write(dataBytes)
|
_, err := w.Write(dataBytes)
|
||||||
checks.NoError(t, err, "Write error")
|
checks.NoError(t, err, "Write error")
|
||||||
}))
|
})
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
// Client portion of the test
|
stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = server.URL + "/v1"
|
|
||||||
config.HTTPClient.Transport = &test.TokenRoundTripper{
|
|
||||||
Token: test.GetTestToken(),
|
|
||||||
Fallback: http.DefaultTransport,
|
|
||||||
}
|
|
||||||
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
request := CompletionRequest{
|
|
||||||
Prompt: "Ex falso quodlibet",
|
Prompt: "Ex falso quodlibet",
|
||||||
Model: "text-davinci-002",
|
Model: "text-davinci-002",
|
||||||
MaxTokens: 10,
|
MaxTokens: 10,
|
||||||
Stream: true,
|
Stream: true,
|
||||||
}
|
})
|
||||||
|
|
||||||
stream, err := client.CreateCompletionStream(ctx, request)
|
|
||||||
checks.NoError(t, err, "CreateCompletionStream returned error")
|
checks.NoError(t, err, "CreateCompletionStream returned error")
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user