From a2ca01bb6dae1a7d58860a5b2d5d5273667e089e Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Tue, 29 Aug 2023 14:04:27 +0200 Subject: [PATCH] feat: implement new fine tuning job API (#479) * feat: implement new fine tuning job API * fix: export ListFineTuningJobEventsParameter * fix: lint errors * fix: test errors * fix: code test coverage * fix: code test coverage * fix: use any * chore: use url.Values --- client_test.go | 12 ++++ fine_tuning_job.go | 153 ++++++++++++++++++++++++++++++++++++++++ fine_tuning_job_test.go | 90 +++++++++++++++++++++++ 3 files changed, 255 insertions(+) create mode 100644 fine_tuning_job.go create mode 100644 fine_tuning_job_test.go diff --git a/client_test.go b/client_test.go index 29d84ed..9b50468 100644 --- a/client_test.go +++ b/client_test.go @@ -223,6 +223,18 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"ListFineTuneEvents", func() (any, error) { return client.ListFineTuneEvents(ctx, "") }}, + {"CreateFineTuningJob", func() (any, error) { + return client.CreateFineTuningJob(ctx, FineTuningJobRequest{}) + }}, + {"CancelFineTuningJob", func() (any, error) { + return client.CancelFineTuningJob(ctx, "") + }}, + {"RetrieveFineTuningJob", func() (any, error) { + return client.RetrieveFineTuningJob(ctx, "") + }}, + {"ListFineTuningJobEvents", func() (any, error) { + return client.ListFineTuningJobEvents(ctx, "") + }}, {"Moderations", func() (any, error) { return client.Moderations(ctx, ModerationRequest{}) }}, diff --git a/fine_tuning_job.go b/fine_tuning_job.go new file mode 100644 index 0000000..a840b7e --- /dev/null +++ b/fine_tuning_job.go @@ -0,0 +1,153 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +type FineTuningJob struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + FinishedAt int64 `json:"finished_at"` + Model string `json:"model"` + FineTunedModel string `json:"fine_tuned_model,omitempty"` + OrganizationID string `json:"organization_id"` + Status string `json:"status"` + Hyperparameters Hyperparameters `json:"hyperparameters"` + TrainingFile string `json:"training_file"` + ValidationFile string `json:"validation_file,omitempty"` + ResultFiles []string `json:"result_files"` + TrainedTokens int `json:"trained_tokens"` +} + +type Hyperparameters struct { + Epochs int `json:"n_epochs"` +} + +type FineTuningJobRequest struct { + TrainingFile string `json:"training_file"` + ValidationFile string `json:"validation_file,omitempty"` + Model string `json:"model,omitempty"` + Hyperparameters *Hyperparameters `json:"hyperparameters,omitempty"` + Suffix string `json:"suffix,omitempty"` +} + +type FineTuningJobEventList struct { + Object string `json:"object"` + Data []FineTuneEvent `json:"data"` + HasMore bool `json:"has_more"` +} + +type FineTuningJobEvent struct { + Object string `json:"object"` + ID string `json:"id"` + CreatedAt int `json:"created_at"` + Level string `json:"level"` + Message string `json:"message"` + Data any `json:"data"` + Type string `json:"type"` +} + +// CreateFineTuningJob create a fine tuning job. +func (c *Client) CreateFineTuningJob( + ctx context.Context, + request FineTuningJobRequest, +) (response FineTuningJob, err error) { + urlSuffix := "/fine_tuning/jobs" + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// CancelFineTuningJob cancel a fine tuning job. +func (c *Client) CancelFineTuningJob(ctx context.Context, fineTuningJobID string) (response FineTuningJob, err error) { + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/cancel")) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveFineTuningJob retrieve a fine tuning job. +func (c *Client) RetrieveFineTuningJob( + ctx context.Context, + fineTuningJobID string, +) (response FineTuningJob, err error) { + urlSuffix := fmt.Sprintf("/fine_tuning/jobs/%s", fineTuningJobID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +type listFineTuningJobEventsParameters struct { + after *string + limit *int +} + +type ListFineTuningJobEventsParameter func(*listFineTuningJobEventsParameters) + +func ListFineTuningJobEventsWithAfter(after string) ListFineTuningJobEventsParameter { + return func(args *listFineTuningJobEventsParameters) { + args.after = &after + } +} + +func ListFineTuningJobEventsWithLimit(limit int) ListFineTuningJobEventsParameter { + return func(args *listFineTuningJobEventsParameters) { + args.limit = &limit + } +} + +// ListFineTuningJobs list fine tuning jobs events. +func (c *Client) ListFineTuningJobEvents( + ctx context.Context, + fineTuningJobID string, + setters ...ListFineTuningJobEventsParameter, +) (response FineTuningJobEventList, err error) { + parameters := &listFineTuningJobEventsParameters{ + after: nil, + limit: nil, + } + + for _, setter := range setters { + setter(parameters) + } + + urlValues := url.Values{} + if parameters.after != nil { + urlValues.Add("after", *parameters.after) + } + if parameters.limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *parameters.limit)) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/events"+encodedValues), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go new file mode 100644 index 0000000..519c6cd --- /dev/null +++ b/fine_tuning_job_test.go @@ -0,0 +1,90 @@ +package openai_test + +import ( + "context" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +const testFineTuninigJobID = "fine-tuning-job-id" + +// TestFineTuningJob Tests the fine tuning job endpoint of the API using the mocked server. +func TestFineTuningJob(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler( + "/v1/fine_tuning/jobs", + func(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + resBytes, _ = json.Marshal(FineTuningJob{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel", + func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(FineTuningJob{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/v1/fine_tuning/jobs/"+testFineTuninigJobID, + func(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + resBytes, _ = json.Marshal(FineTuningJob{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events", + func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(FineTuningJobEventList{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + ctx := context.Background() + + _, err := client.CreateFineTuningJob(ctx, FineTuningJobRequest{}) + checks.NoError(t, err, "CreateFineTuningJob error") + + _, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID) + checks.NoError(t, err, "CancelFineTuningJob error") + + _, err = client.RetrieveFineTuningJob(ctx, testFineTuninigJobID) + checks.NoError(t, err, "RetrieveFineTuningJob error") + + _, err = client.ListFineTuningJobEvents(ctx, testFineTuninigJobID) + checks.NoError(t, err, "ListFineTuningJobEvents error") + + _, err = client.ListFineTuningJobEvents( + ctx, + testFineTuninigJobID, + ListFineTuningJobEventsWithAfter("last-event-id"), + ) + checks.NoError(t, err, "ListFineTuningJobEvents error") + + _, err = client.ListFineTuningJobEvents( + ctx, + testFineTuninigJobID, + ListFineTuningJobEventsWithLimit(10), + ) + checks.NoError(t, err, "ListFineTuningJobEvents error") + + _, err = client.ListFineTuningJobEvents( + ctx, + testFineTuninigJobID, + ListFineTuningJobEventsWithAfter("last-event-id"), + ListFineTuningJobEventsWithLimit(10), + ) + checks.NoError(t, err, "ListFineTuningJobEvents error") +}