* Support Retrieve model API (#340) * Test for GetModel error cases. (#340) * Reduce the cognitive complexity of TestClientReturnsRequestBuilderErrors (#340)
This commit is contained in:
committed by
GitHub
parent
1394329e44
commit
6830e00406
162
client_test.go
162
client_test.go
@@ -170,104 +170,82 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
|
|||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"})
|
type TestCase struct {
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
Name string
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
TestFunc func() (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
testCases := []TestCase{
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
{"CreateCompletion", func() (any, error) {
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
return client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"})
|
||||||
|
}},
|
||||||
|
{"CreateCompletionStream", func() (any, error) {
|
||||||
|
return client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""})
|
||||||
|
}},
|
||||||
|
{"CreateChatCompletion", func() (any, error) {
|
||||||
|
return client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
||||||
|
}},
|
||||||
|
{"CreateChatCompletionStream", func() (any, error) {
|
||||||
|
return client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
||||||
|
}},
|
||||||
|
{"CreateFineTune", func() (any, error) {
|
||||||
|
return client.CreateFineTune(ctx, FineTuneRequest{})
|
||||||
|
}},
|
||||||
|
{"ListFineTunes", func() (any, error) {
|
||||||
|
return client.ListFineTunes(ctx)
|
||||||
|
}},
|
||||||
|
{"CancelFineTune", func() (any, error) {
|
||||||
|
return client.CancelFineTune(ctx, "")
|
||||||
|
}},
|
||||||
|
{"GetFineTune", func() (any, error) {
|
||||||
|
return client.GetFineTune(ctx, "")
|
||||||
|
}},
|
||||||
|
{"DeleteFineTune", func() (any, error) {
|
||||||
|
return client.DeleteFineTune(ctx, "")
|
||||||
|
}},
|
||||||
|
{"ListFineTuneEvents", func() (any, error) {
|
||||||
|
return client.ListFineTuneEvents(ctx, "")
|
||||||
|
}},
|
||||||
|
{"Moderations", func() (any, error) {
|
||||||
|
return client.Moderations(ctx, ModerationRequest{})
|
||||||
|
}},
|
||||||
|
{"Edits", func() (any, error) {
|
||||||
|
return client.Edits(ctx, EditsRequest{})
|
||||||
|
}},
|
||||||
|
{"CreateEmbeddings", func() (any, error) {
|
||||||
|
return client.CreateEmbeddings(ctx, EmbeddingRequest{})
|
||||||
|
}},
|
||||||
|
{"CreateImage", func() (any, error) {
|
||||||
|
return client.CreateImage(ctx, ImageRequest{})
|
||||||
|
}},
|
||||||
|
{"DeleteFile", func() (any, error) {
|
||||||
|
return nil, client.DeleteFile(ctx, "")
|
||||||
|
}},
|
||||||
|
{"GetFile", func() (any, error) {
|
||||||
|
return client.GetFile(ctx, "")
|
||||||
|
}},
|
||||||
|
{"ListFiles", func() (any, error) {
|
||||||
|
return client.ListFiles(ctx)
|
||||||
|
}},
|
||||||
|
{"ListEngines", func() (any, error) {
|
||||||
|
return client.ListEngines(ctx)
|
||||||
|
}},
|
||||||
|
{"GetEngine", func() (any, error) {
|
||||||
|
return client.GetEngine(ctx, "")
|
||||||
|
}},
|
||||||
|
{"ListModels", func() (any, error) {
|
||||||
|
return client.ListModels(ctx)
|
||||||
|
}},
|
||||||
|
{"GetModel", func() (any, error) {
|
||||||
|
return client.GetModel(ctx, "text-davinci-003")
|
||||||
|
}},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
for _, testCase := range testCases {
|
||||||
|
_, err = testCase.TestFunc()
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
t.Fatalf("%s did not return error when request builder failed: %v", testCase.Name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.ListFineTunes(ctx)
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CancelFineTune(ctx, "")
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.GetFineTune(ctx, "")
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.DeleteFineTune(ctx, "")
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.ListFineTuneEvents(ctx, "")
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.Moderations(ctx, ModerationRequest{})
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.Edits(ctx, EditsRequest{})
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CreateImage(ctx, ImageRequest{})
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = client.DeleteFile(ctx, "")
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.GetFile(ctx, "")
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.ListFiles(ctx)
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.ListEngines(ctx)
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.GetEngine(ctx, "")
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.ListModels(ctx)
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""})
|
|
||||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
|
||||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
14
models.go
14
models.go
@@ -2,6 +2,7 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -48,3 +49,16 @@ func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error)
|
|||||||
err = c.sendRequest(req, &models)
|
err = c.sendRequest(req, &models)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetModel Retrieves a model instance, providing basic information about
|
||||||
|
// the model such as the owner and permissioning.
|
||||||
|
func (c *Client) GetModel(ctx context.Context, modelID string) (model Model, err error) {
|
||||||
|
urlSuffix := fmt.Sprintf("/models/%s", modelID)
|
||||||
|
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.sendRequest(req, &model)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|||||||
@@ -54,3 +54,44 @@ func handleModelsEndpoint(w http.ResponseWriter, _ *http.Request) {
|
|||||||
resBytes, _ := json.Marshal(ModelsList{})
|
resBytes, _ := json.Marshal(ModelsList{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestGetModel Tests the retrieve model endpoint of the API using the mocked server.
|
||||||
|
func TestGetModel(t *testing.T) {
|
||||||
|
server := test.NewTestServer()
|
||||||
|
server.RegisterHandler("/v1/models/text-davinci-003", handleGetModelEndpoint)
|
||||||
|
// create the test server
|
||||||
|
ts := server.OpenAITestServer()
|
||||||
|
ts.Start()
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
_, err := client.GetModel(ctx, "text-davinci-003")
|
||||||
|
checks.NoError(t, err, "GetModel error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAzureGetModel(t *testing.T) {
|
||||||
|
server := test.NewTestServer()
|
||||||
|
server.RegisterHandler("/openai/models/text-davinci-003", handleModelsEndpoint)
|
||||||
|
// create the test server
|
||||||
|
ts := server.OpenAITestServer()
|
||||||
|
ts.Start()
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
|
||||||
|
config.BaseURL = ts.URL
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
_, err := client.GetModel(ctx, "text-davinci-003")
|
||||||
|
checks.NoError(t, err, "GetModel error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleModelsEndpoint Handles the models endpoint by the test server.
|
||||||
|
func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
resBytes, _ := json.Marshal(Model{})
|
||||||
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user