From 53212c71dff4a8813a3152610a67bba0b51dd22a Mon Sep 17 00:00:00 2001
From: Oleg <97077423+RobotSail@users.noreply.github.com>
Date: Tue, 2 Aug 2022 04:16:54 -0400
Subject: [PATCH] Migrate From Old Completions + Embedding Endpoint (#28)

* migrate away from deprecated OpenAI endpoints

Signed-off-by: Oleg <97077423+RobotSail@users.noreply.github.com>

* test embedding correctness

Signed-off-by: Oleg <97077423+RobotSail@users.noreply.github.com>
---
 README.md     |  3 ++-
 api_test.go   | 51 ++++++++++++++++++++++++++++++++++++++++++++++++---
 completion.go | 32 +++++++------------------------
 embeddings.go | 10 +++++++---
 search.go     | 10 +++++-----
 5 files changed, 69 insertions(+), 37 deletions(-)

diff --git a/README.md b/README.md
index 39204ba..3fb5d8c 100644
--- a/README.md
+++ b/README.md
@@ -27,10 +27,11 @@ func main() {
 	ctx := context.Background()
 
 	req := gogpt.CompletionRequest{
+		Model:     "ada",
 		MaxTokens: 5,
 		Prompt:    "Lorem ipsum",
 	}
-	resp, err := c.CreateCompletion(ctx, "ada", req)
+	resp, err := c.CreateCompletion(ctx, req)
 	if err != nil {
 		return
 	}
diff --git a/api_test.go b/api_test.go
index a5da27d..089897b 100644
--- a/api_test.go
+++ b/api_test.go
@@ -1,7 +1,9 @@
 package gogpt
 
 import (
+	"bytes"
 	"context"
+	"encoding/json"
 	"io/ioutil"
 	"testing"
 )
@@ -36,9 +38,12 @@ func TestAPI(t *testing.T) {
 		}
 	}
 	// else skip
-	req := CompletionRequest{MaxTokens: 5}
+	req := CompletionRequest{
+		MaxTokens: 5,
+		Model:     "ada",
+	}
 	req.Prompt = "Lorem ipsum"
-	_, err = c.CreateCompletion(ctx, "ada", req)
+	_, err = c.CreateCompletion(ctx, req)
 	if err != nil {
 		t.Fatalf("CreateCompletion error: %v", err)
 	}
@@ -57,9 +62,49 @@ func TestAPI(t *testing.T) {
 			"The food was delicious and the waiter",
 			"Other examples of embedding request",
 		},
+		Model: AdaSearchQuery,
 	}
-	_, err = c.CreateEmbeddings(ctx, embeddingReq, AdaSearchQuery)
+	_, err = c.CreateEmbeddings(ctx, embeddingReq)
 	if err != nil {
 		t.Fatalf("Embedding error: %v", err)
 	}
 }
+
+func TestEmbedding(t *testing.T) {
+	embeddedModels := []EmbeddingModel{
+		AdaSimilarity,
+		BabbageSimilarity,
+		CurieSimilarity,
+		DavinciSimilarity,
+		AdaSearchDocument,
+		AdaSearchQuery,
+		BabbageSearchDocument,
+		BabbageSearchQuery,
+		CurieSearchDocument,
+		CurieSearchQuery,
+		DavinciSearchDocument,
+		DavinciSearchQuery,
+		AdaCodeSearchCode,
+		AdaCodeSearchText,
+		BabbageCodeSearchCode,
+		BabbageCodeSearchText,
+	}
+	for _, model := range embeddedModels {
+		embeddingReq := EmbeddingRequest{
+			Input: []string{
+				"The food was delicious and the waiter",
+				"Other examples of embedding request",
+			},
+			Model: model,
+		}
+		// marshal embeddingReq to JSON and confirm that the serialized model
+		// field matches the model under test
+		marshaled, err := json.Marshal(embeddingReq)
+		if err != nil {
+			t.Fatalf("Could not marshal embedding request: %v", err)
+		}
+		if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
+			t.Fatalf("Expected embedding request to contain model field")
+		}
+	}
+}
diff --git a/completion.go b/completion.go
index 43f9111..bdbee56 100644
--- a/completion.go
+++ b/completion.go
@@ -4,13 +4,12 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
-	"fmt"
 	"net/http"
 )
 
 // CompletionRequest represents a request structure for completion API
 type CompletionRequest struct {
-	Model       *string `json:"model,omitempty"`
+	Model       string  `json:"model"`
 	Prompt      string  `json:"prompt,omitempty"`
 	MaxTokens   int     `json:"max_tokens,omitempty"`
 	Temperature float32 `json:"temperature,omitempty"`
@@ -60,29 +59,12 @@ type CompletionResponse struct {
 	Usage   CompletionUsage    `json:"usage"`
 }
 
-// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well as, if requested, the probabilities over each alternative token at each position.
-func (c *Client) CreateCompletion(ctx context.Context, engineID string, request CompletionRequest) (response CompletionResponse, err error) {
-	var reqBytes []byte
-	reqBytes, err = json.Marshal(request)
-	if err != nil {
-		return
-	}
-
-	urlSuffix := fmt.Sprintf("/engines/%s/completions", engineID)
-	req, err := http.NewRequest("POST", c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
-	if err != nil {
-		return
-	}
-
-	req = req.WithContext(ctx)
-	err = c.sendRequest(req, &response)
-	return
-}
-
-// CreateCompletionWithFineTunedModel - API call to create a completion with a fine tuned model
-// See https://beta.openai.com/docs/guides/fine-tuning/use-a-fine-tuned-model
-// In this case, the model is specified in the CompletionRequest object.
-func (c *Client) CreateCompletionWithFineTunedModel(ctx context.Context, request CompletionRequest) (response CompletionResponse, err error) {
+// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well
+// as, if requested, the probabilities over each alternative token at each position.
+//
+// If using a fine-tuned model, simply provide the model's ID in the CompletionRequest object,
+// and the server will use the model's parameters to generate the completion.
+func (c *Client) CreateCompletion(ctx context.Context, request CompletionRequest) (response CompletionResponse, err error) {
 	var reqBytes []byte
 	reqBytes, err = json.Marshal(request)
 	if err != nil {
diff --git a/embeddings.go b/embeddings.go
index fb6c199..ec263f5 100644
--- a/embeddings.go
+++ b/embeddings.go
@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
-	"fmt"
 	"net/http"
 )
 
@@ -120,18 +119,23 @@ type EmbeddingRequest struct {
 	// E.g.
 	//	"The food was delicious and the waiter..."
 	Input []string `json:"input"`
+	// ID of the model to use. You can use the List models API to see all of your available models,
+	// or see our Model overview for descriptions of them.
+	Model EmbeddingModel `json:"model"`
+	// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
+	User string `json:"user"`
 }
 
 // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
 // https://beta.openai.com/docs/api-reference/embeddings/create
-func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest, model EmbeddingModel) (resp EmbeddingResponse, err error) {
+func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
 	var reqBytes []byte
 	reqBytes, err = json.Marshal(request)
 	if err != nil {
 		return
 	}
 
-	urlSuffix := fmt.Sprintf("/engines/%s/embeddings", model)
+	urlSuffix := "/embeddings"
 	req, err := http.NewRequest(http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
 	if err != nil {
 		return
diff --git a/search.go b/search.go
index f44191e..b14442d 100644
--- a/search.go
+++ b/search.go
@@ -17,11 +17,11 @@ import (
 */
 type SearchRequest struct {
 	Query          string   `json:"query"`
-	Documents      []string `json:"documents"`       // 1*
-	FileID         string   `json:"file"`            // 1*
-	MaxRerank      int      `json:"max_rerank"`      // 2*
-	ReturnMetadata bool     `json:"return_metadata"`
-	User           string   `json:"user"`
+	Documents      []string `json:"documents"`                 // 1*
+	FileID         string   `json:"file,omitempty"`            // 1*
+	MaxRerank      int      `json:"max_rerank,omitempty"`      // 2*
+	ReturnMetadata bool     `json:"return_metadata,omitempty"`
+	User           string   `json:"user,omitempty"`
 }
 
 // SearchResult represents single result from search API
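
A minimal usage sketch of the migrated call sites, for reference; it is not part of the diff. The
import path, the NewClient constructor, and the printed response fields follow the repository's
README and existing client code rather than anything this patch changes, so treat them as
assumptions here.

package main

import (
	"context"
	"fmt"

	gogpt "github.com/sashabaranov/go-gpt3"
)

func main() {
	// NewClient and the token handling are assumed from the existing client code.
	c := gogpt.NewClient("your token")
	ctx := context.Background()

	// Completions: the model now travels in the request body instead of the
	// /engines/{engine_id}/completions URL path.
	completionReq := gogpt.CompletionRequest{
		Model:     "ada",
		MaxTokens: 5,
		Prompt:    "Lorem ipsum",
	}
	completionResp, err := c.CreateCompletion(ctx, completionReq)
	if err != nil {
		return
	}
	fmt.Println(completionResp.Choices[0].Text)

	// Embeddings: likewise, the model is a field on EmbeddingRequest and the
	// request is posted to the flat /embeddings path.
	embeddingReq := gogpt.EmbeddingRequest{
		Input: []string{"The food was delicious and the waiter"},
		Model: gogpt.AdaSearchQuery,
	}
	embeddingResp, err := c.CreateEmbeddings(ctx, embeddingReq)
	if err != nil {
		return
	}
	fmt.Println(len(embeddingResp.Data))
}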