Migrate From Old Completions + Embedding Endpoint (#28)

* migrate away from deprecated OpenAI endpoints

Signed-off-by: Oleg <97077423+RobotSail@users.noreply.github.com>

* test embedding correctness

Signed-off-by: Oleg <97077423+RobotSail@users.noreply.github.com>
This commit is contained in:
Oleg
2022-08-02 04:16:54 -04:00
committed by GitHub
parent 51f94a6ab3
commit 53212c71df
5 changed files with 69 additions and 37 deletions

View File

@@ -27,10 +27,11 @@ func main() {
ctx := context.Background() ctx := context.Background()
req := gogpt.CompletionRequest{ req := gogpt.CompletionRequest{
Model: "ada",
MaxTokens: 5, MaxTokens: 5,
Prompt: "Lorem ipsum", Prompt: "Lorem ipsum",
} }
resp, err := c.CreateCompletion(ctx, "ada", req) resp, err := c.CreateCompletion(ctx, req)
if err != nil { if err != nil {
return return
} }

View File

@@ -1,7 +1,9 @@
package gogpt package gogpt
import ( import (
"bytes"
"context" "context"
"encoding/json"
"io/ioutil" "io/ioutil"
"testing" "testing"
) )
@@ -36,9 +38,12 @@ func TestAPI(t *testing.T) {
} }
} // else skip } // else skip
req := CompletionRequest{MaxTokens: 5} req := CompletionRequest{
MaxTokens: 5,
Model: "ada",
}
req.Prompt = "Lorem ipsum" req.Prompt = "Lorem ipsum"
_, err = c.CreateCompletion(ctx, "ada", req) _, err = c.CreateCompletion(ctx, req)
if err != nil { if err != nil {
t.Fatalf("CreateCompletion error: %v", err) t.Fatalf("CreateCompletion error: %v", err)
} }
@@ -57,9 +62,49 @@ func TestAPI(t *testing.T) {
"The food was delicious and the waiter", "The food was delicious and the waiter",
"Other examples of embedding request", "Other examples of embedding request",
}, },
Model: AdaSearchQuery,
} }
_, err = c.CreateEmbeddings(ctx, embeddingReq, AdaSearchQuery) _, err = c.CreateEmbeddings(ctx, embeddingReq)
if err != nil { if err != nil {
t.Fatalf("Embedding error: %v", err) t.Fatalf("Embedding error: %v", err)
} }
} }
// TestEmbedding verifies that EmbeddingRequest serializes its Model field
// correctly for every known embedding model: the marshaled JSON must carry
// a "model" key whose value is the model's string form.
func TestEmbedding(t *testing.T) {
	embeddedModels := []EmbeddingModel{
		AdaSimilarity,
		BabbageSimilarity,
		CurieSimilarity,
		DavinciSimilarity,
		AdaSearchDocument,
		AdaSearchQuery,
		BabbageSearchDocument,
		BabbageSearchQuery,
		CurieSearchDocument,
		CurieSearchQuery,
		DavinciSearchDocument,
		DavinciSearchQuery,
		AdaCodeSearchCode,
		AdaCodeSearchText,
		BabbageCodeSearchCode,
		BabbageCodeSearchText,
	}
	for _, model := range embeddedModels {
		embeddingReq := EmbeddingRequest{
			Input: []string{
				"The food was delicious and the waiter",
				"Other examples of embedding request",
			},
			Model: model,
		}
		// Marshal embeddingReq to JSON and confirm that the "model" field
		// equals the string form of the model under test in this iteration.
		marshaled, err := json.Marshal(embeddingReq)
		if err != nil {
			t.Fatalf("Could not marshal embedding request: %v", err)
		}
		// Substring check: the encoded body must contain `"model":"<name>"`.
		if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
			t.Fatalf("Expected embedding request to contain model field")
		}
	}
}

View File

@@ -4,13 +4,12 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
) )
// CompletionRequest represents a request structure for completion API // CompletionRequest represents a request structure for completion API
type CompletionRequest struct { type CompletionRequest struct {
Model *string `json:"model,omitempty"` Model string `json:"model"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens int `json:"max_tokens,omitempty"`
Temperature float32 `json:"temperature,omitempty"` Temperature float32 `json:"temperature,omitempty"`
@@ -60,29 +59,12 @@ type CompletionResponse struct {
Usage CompletionUsage `json:"usage"` Usage CompletionUsage `json:"usage"`
} }
// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well as, if requested, the probabilities over each alternative token at each position. // CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well
func (c *Client) CreateCompletion(ctx context.Context, engineID string, request CompletionRequest) (response CompletionResponse, err error) { // as, if requested, the probabilities over each alternative token at each position.
var reqBytes []byte //
reqBytes, err = json.Marshal(request) // If using a fine-tuned model, simply provide the model's ID in the CompletionRequest object,
if err != nil { // and the server will use the model's parameters to generate the completion.
return func (c *Client) CreateCompletion(ctx context.Context, request CompletionRequest) (response CompletionResponse, err error) {
}
urlSuffix := fmt.Sprintf("/engines/%s/completions", engineID)
req, err := http.NewRequest("POST", c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
if err != nil {
return
}
req = req.WithContext(ctx)
err = c.sendRequest(req, &response)
return
}
// CreateCompletionWithFineTunedModel - API call to create a completion with a fine tuned model
// See https://beta.openai.com/docs/guides/fine-tuning/use-a-fine-tuned-model
// In this case, the model is specified in the CompletionRequest object.
func (c *Client) CreateCompletionWithFineTunedModel(ctx context.Context, request CompletionRequest) (response CompletionResponse, err error) {
var reqBytes []byte var reqBytes []byte
reqBytes, err = json.Marshal(request) reqBytes, err = json.Marshal(request)
if err != nil { if err != nil {

View File

@@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
) )
@@ -120,18 +119,23 @@ type EmbeddingRequest struct {
// E.g. // E.g.
// "The food was delicious and the waiter..." // "The food was delicious and the waiter..."
Input []string `json:"input"` Input []string `json:"input"`
// ID of the model to use. You can use the List models API to see all of your available models,
// or see our Model overview for descriptions of them.
Model EmbeddingModel `json:"model"`
// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
User string `json:"user"`
} }
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
// https://beta.openai.com/docs/api-reference/embeddings/create // https://beta.openai.com/docs/api-reference/embeddings/create
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest, model EmbeddingModel) (resp EmbeddingResponse, err error) { func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
var reqBytes []byte var reqBytes []byte
reqBytes, err = json.Marshal(request) reqBytes, err = json.Marshal(request)
if err != nil { if err != nil {
return return
} }
urlSuffix := fmt.Sprintf("/engines/%s/embeddings", model) urlSuffix := "/embeddings"
req, err := http.NewRequest(http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) req, err := http.NewRequest(http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
if err != nil { if err != nil {
return return

View File

@@ -18,10 +18,10 @@ import (
type SearchRequest struct { type SearchRequest struct {
Query string `json:"query"` Query string `json:"query"`
Documents []string `json:"documents"` // 1* Documents []string `json:"documents"` // 1*
FileID string `json:"file"` // 1* FileID string `json:"file,omitempty"` // 1*
MaxRerank int `json:"max_rerank"` // 2* MaxRerank int `json:"max_rerank,omitempty"` // 2*
ReturnMetadata bool `json:"return_metadata"` ReturnMetadata bool `json:"return_metadata,omitempty"`
User string `json:"user"` User string `json:"user,omitempty"`
} }
// SearchResult represents single result from search API // SearchResult represents single result from search API