Implement the fine-tunes API (#130)

- Add FineTune Structs and Requests
- Add CRUD Methods
This commit is contained in:
Matt Trefilek
2023-03-08 04:09:08 -06:00
committed by GitHub
parent c46ebb2f08
commit c380d5031b

137
fine_tunes.go Normal file
View File

@@ -0,0 +1,137 @@
package openai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
)
type FineTuneRequest struct {
TrainingFile string `json:"training_file"`
ValidationFile string `json:"validation_file,omitempty"`
Model string `json:"model,omitempty"`
Epochs int `json:"n_epochs,omitempty"`
BatchSize int `json:"batch_size,omitempty"`
LearningRateMultiplier float32 `json:"learning_rate_multiplier,omitempty"`
PromptLossRate float32 `json:"prompt_loss_rate,omitempty"`
ComputeClassificationMetrics bool `json:"compute_classification_metrics,omitempty"`
ClassificationClasses int `json:"classification_n_classes,omitempty"`
ClassificationPositiveClass string `json:"classification_positive_class,omitempty"`
ClassificationBetas []float32 `json:"classification_betas,omitempty"`
Suffix string `json:"suffix,omitempty"`
}
type FineTune struct {
ID string `json:"id"`
Object string `json:"object"`
Model string `json:"model"`
CreatedAt int `json:"created_at"`
FineTunedModel string `json:"fine_tuned_model"`
Hyperparams FineTuneHyperParams `json:"hyperparams"`
OrganizationID string `json:"organization_id"`
ResultFiles []File `json:"result_files"`
Status string `json:"status"`
ValidationFiles []File `json:"validation_files"`
TrainingFiles []File `json:"training_files"`
UpdatedAt int `json:"updated_at"`
}
type FineTuneEvent struct {
Object string `json:"object"`
CreatedAt int `json:"created_at"`
Level string `json:"level"`
Message string `json:"message"`
}
type FineTuneHyperParams struct {
BatchSize int `json:"batch_size"`
LearningRateMultiplier float64 `json:"learning_rate_multiplier"`
Epochs int `json:"n_epochs"`
PromptLossWeight float64 `json:"prompt_loss_weight"`
}
type FineTuneList struct {
Object string `json:"object"`
Data []FineTune `json:"data"`
}
type FineTuneEventList struct {
Object string `json:"object"`
Data []FineTuneEvent `json:"data"`
}
type FineTuneDeleteResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Deleted bool `json:"deleted"`
}
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
if err != nil {
return
}
urlSuffix := "/fine-tunes"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
if err != nil {
return
}
err = c.sendRequest(req, &response)
return
}
// Cancel a fine-tune job.
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil)
if err != nil {
return
}
err = c.sendRequest(req, &response)
return
}
func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil)
if err != nil {
return
}
err = c.sendRequest(req, &response)
return
}
func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil {
return
}
err = c.sendRequest(req, &response)
return
}
func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil)
if err != nil {
return
}
err = c.sendRequest(req, &response)
return
}
func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil)
if err != nil {
return
}
err = c.sendRequest(req, &response)
return
}