From 9428f6cc3df02488c3ada31cfa7005c4c3d6c027 Mon Sep 17 00:00:00 2001 From: sashabaranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 9 Mar 2023 23:56:23 +0400 Subject: [PATCH] add more tests (#140) * test models endpoint * simplify * add fine tune tests --- fine_tunes_test.go | 101 +++++++++++++++++++++++++++++++++++++++++++++ models_test.go | 39 +++++++++++++++++ 2 files changed, 140 insertions(+) create mode 100644 fine_tunes_test.go create mode 100644 models_test.go diff --git a/fine_tunes_test.go b/fine_tunes_test.go new file mode 100644 index 0000000..1f6f967 --- /dev/null +++ b/fine_tunes_test.go @@ -0,0 +1,101 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + + "context" + "encoding/json" + "fmt" + "net/http" + "testing" +) + +const testFineTuneID = "fine-tune-id" + +// TestFineTunes Tests the fine tunes endpoint of the API using the mocked server. +func TestFineTunes(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler( + "/v1/fine-tunes", + func(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + if r.Method == http.MethodGet { + resBytes, _ = json.Marshal(FineTuneList{}) + } else { + resBytes, _ = json.Marshal(FineTune{}) + } + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/v1/fine-tunes/"+testFineTuneID+"/cancel", + func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(FineTune{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/v1/fine-tunes/"+testFineTuneID, + func(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + if r.Method == http.MethodDelete { + resBytes, _ = json.Marshal(FineTuneDeleteResponse{}) + } else { + resBytes, _ = json.Marshal(FineTune{}) + } + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/v1/fine-tunes/"+testFineTuneID+"/events", + func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(FineTuneEventList{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err = client.ListFineTunes(ctx) + if err != nil { + t.Fatalf("ListFineTunes error: %v", err) + } + + _, err = client.CreateFineTune(ctx, FineTuneRequest{}) + if err != nil { + t.Fatalf("CreateFineTune error: %v", err) + } + + _, err = client.CancelFineTune(ctx, testFineTuneID) + if err != nil { + t.Fatalf("CancelFineTune error: %v", err) + } + + _, err = client.GetFineTune(ctx, testFineTuneID) + if err != nil { + t.Fatalf("GetFineTune error: %v", err) + } + + _, err = client.DeleteFineTune(ctx, testFineTuneID) + if err != nil { + t.Fatalf("DeleteFineTune error: %v", err) + } + + _, err = client.ListFineTuneEvents(ctx, testFineTuneID) + if err != nil { + t.Fatalf("ListFineTuneEvents error: %v", err) + } +} diff --git a/models_test.go b/models_test.go new file mode 100644 index 0000000..c96ece8 --- /dev/null +++ b/models_test.go @@ -0,0 +1,39 @@ +package openai_test + +import ( + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + + "context" + "encoding/json" + "fmt" + "net/http" + "testing" +) + +// TestListModels Tests the models endpoint of the API using the mocked server. +func TestListModels(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/models", handleModelsEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err = client.ListModels(ctx) + if err != nil { + t.Fatalf("ListModels error: %v", err) + } +} + +// handleModelsEndpoint Handles the models endpoint by the test server. +func handleModelsEndpoint(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(ModelsList{}) + fmt.Fprintln(w, string(resBytes)) +}