* Support Retrieve model API (#340) * Test for GetModel error cases. (#340) * Reduce the cognitive complexity of TestClientReturnsRequestBuilderErrors (#340)
This commit is contained in:
committed by
GitHub
parent
1394329e44
commit
6830e00406
162
client_test.go
162
client_test.go
@@ -170,104 +170,82 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
type TestCase struct {
|
||||
Name string
|
||||
TestFunc func() (any, error)
|
||||
}
|
||||
|
||||
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
testCases := []TestCase{
|
||||
{"CreateCompletion", func() (any, error) {
|
||||
return client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"})
|
||||
}},
|
||||
{"CreateCompletionStream", func() (any, error) {
|
||||
return client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""})
|
||||
}},
|
||||
{"CreateChatCompletion", func() (any, error) {
|
||||
return client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
||||
}},
|
||||
{"CreateChatCompletionStream", func() (any, error) {
|
||||
return client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
||||
}},
|
||||
{"CreateFineTune", func() (any, error) {
|
||||
return client.CreateFineTune(ctx, FineTuneRequest{})
|
||||
}},
|
||||
{"ListFineTunes", func() (any, error) {
|
||||
return client.ListFineTunes(ctx)
|
||||
}},
|
||||
{"CancelFineTune", func() (any, error) {
|
||||
return client.CancelFineTune(ctx, "")
|
||||
}},
|
||||
{"GetFineTune", func() (any, error) {
|
||||
return client.GetFineTune(ctx, "")
|
||||
}},
|
||||
{"DeleteFineTune", func() (any, error) {
|
||||
return client.DeleteFineTune(ctx, "")
|
||||
}},
|
||||
{"ListFineTuneEvents", func() (any, error) {
|
||||
return client.ListFineTuneEvents(ctx, "")
|
||||
}},
|
||||
{"Moderations", func() (any, error) {
|
||||
return client.Moderations(ctx, ModerationRequest{})
|
||||
}},
|
||||
{"Edits", func() (any, error) {
|
||||
return client.Edits(ctx, EditsRequest{})
|
||||
}},
|
||||
{"CreateEmbeddings", func() (any, error) {
|
||||
return client.CreateEmbeddings(ctx, EmbeddingRequest{})
|
||||
}},
|
||||
{"CreateImage", func() (any, error) {
|
||||
return client.CreateImage(ctx, ImageRequest{})
|
||||
}},
|
||||
{"DeleteFile", func() (any, error) {
|
||||
return nil, client.DeleteFile(ctx, "")
|
||||
}},
|
||||
{"GetFile", func() (any, error) {
|
||||
return client.GetFile(ctx, "")
|
||||
}},
|
||||
{"ListFiles", func() (any, error) {
|
||||
return client.ListFiles(ctx)
|
||||
}},
|
||||
{"ListEngines", func() (any, error) {
|
||||
return client.ListEngines(ctx)
|
||||
}},
|
||||
{"GetEngine", func() (any, error) {
|
||||
return client.GetEngine(ctx, "")
|
||||
}},
|
||||
{"ListModels", func() (any, error) {
|
||||
return client.ListModels(ctx)
|
||||
}},
|
||||
{"GetModel", func() (any, error) {
|
||||
return client.GetModel(ctx, "text-davinci-003")
|
||||
}},
|
||||
}
|
||||
|
||||
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
||||
for _, testCase := range testCases {
|
||||
_, err = testCase.TestFunc()
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
t.Fatalf("%s did not return error when request builder failed: %v", testCase.Name, err)
|
||||
}
|
||||
|
||||
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.ListFineTunes(ctx)
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CancelFineTune(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.GetFineTune(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.DeleteFineTune(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.ListFineTuneEvents(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.Moderations(ctx, ModerationRequest{})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.Edits(ctx, EditsRequest{})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateImage(ctx, ImageRequest{})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
err = client.DeleteFile(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.GetFile(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.ListFiles(ctx)
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.ListEngines(ctx)
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.GetEngine(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.ListModels(ctx)
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
14
models.go
14
models.go
@@ -2,6 +2,7 @@ package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -48,3 +49,16 @@ func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error)
|
||||
err = c.sendRequest(req, &models)
|
||||
return
|
||||
}
|
||||
|
||||
// GetModel Retrieves a model instance, providing basic information about
|
||||
// the model such as the owner and permissioning.
|
||||
func (c *Client) GetModel(ctx context.Context, modelID string) (model Model, err error) {
|
||||
urlSuffix := fmt.Sprintf("/models/%s", modelID)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = c.sendRequest(req, &model)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -54,3 +54,44 @@ func handleModelsEndpoint(w http.ResponseWriter, _ *http.Request) {
|
||||
resBytes, _ := json.Marshal(ModelsList{})
|
||||
fmt.Fprintln(w, string(resBytes))
|
||||
}
|
||||
|
||||
// TestGetModel Tests the retrieve model endpoint of the API using the mocked server.
|
||||
func TestGetModel(t *testing.T) {
|
||||
server := test.NewTestServer()
|
||||
server.RegisterHandler("/v1/models/text-davinci-003", handleGetModelEndpoint)
|
||||
// create the test server
|
||||
ts := server.OpenAITestServer()
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
config := DefaultConfig(test.GetTestToken())
|
||||
config.BaseURL = ts.URL + "/v1"
|
||||
client := NewClientWithConfig(config)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := client.GetModel(ctx, "text-davinci-003")
|
||||
checks.NoError(t, err, "GetModel error")
|
||||
}
|
||||
|
||||
func TestAzureGetModel(t *testing.T) {
|
||||
server := test.NewTestServer()
|
||||
server.RegisterHandler("/openai/models/text-davinci-003", handleModelsEndpoint)
|
||||
// create the test server
|
||||
ts := server.OpenAITestServer()
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
|
||||
config.BaseURL = ts.URL
|
||||
client := NewClientWithConfig(config)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := client.GetModel(ctx, "text-davinci-003")
|
||||
checks.NoError(t, err, "GetModel error")
|
||||
}
|
||||
|
||||
// handleModelsEndpoint Handles the models endpoint by the test server.
|
||||
func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) {
|
||||
resBytes, _ := json.Marshal(Model{})
|
||||
fmt.Fprintln(w, string(resBytes))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user