Migrate From Old Completions + Embedding Endpoint (#28)
* migrate away from deprecated OpenAI endpoints Signed-off-by: Oleg <97077423+RobotSail@users.noreply.github.com> * test embedding correctness Signed-off-by: Oleg <97077423+RobotSail@users.noreply.github.com>
This commit is contained in:
@@ -27,10 +27,11 @@ func main() {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
req := gogpt.CompletionRequest{
|
req := gogpt.CompletionRequest{
|
||||||
|
Model: "ada",
|
||||||
MaxTokens: 5,
|
MaxTokens: 5,
|
||||||
Prompt: "Lorem ipsum",
|
Prompt: "Lorem ipsum",
|
||||||
}
|
}
|
||||||
resp, err := c.CreateCompletion(ctx, "ada", req)
|
resp, err := c.CreateCompletion(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
51
api_test.go
51
api_test.go
@@ -1,7 +1,9 @@
|
|||||||
package gogpt
|
package gogpt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@@ -36,9 +38,12 @@ func TestAPI(t *testing.T) {
|
|||||||
}
|
}
|
||||||
} // else skip
|
} // else skip
|
||||||
|
|
||||||
req := CompletionRequest{MaxTokens: 5}
|
req := CompletionRequest{
|
||||||
|
MaxTokens: 5,
|
||||||
|
Model: "ada",
|
||||||
|
}
|
||||||
req.Prompt = "Lorem ipsum"
|
req.Prompt = "Lorem ipsum"
|
||||||
_, err = c.CreateCompletion(ctx, "ada", req)
|
_, err = c.CreateCompletion(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("CreateCompletion error: %v", err)
|
t.Fatalf("CreateCompletion error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -57,9 +62,49 @@ func TestAPI(t *testing.T) {
|
|||||||
"The food was delicious and the waiter",
|
"The food was delicious and the waiter",
|
||||||
"Other examples of embedding request",
|
"Other examples of embedding request",
|
||||||
},
|
},
|
||||||
|
Model: AdaSearchQuery,
|
||||||
}
|
}
|
||||||
_, err = c.CreateEmbeddings(ctx, embeddingReq, AdaSearchQuery)
|
_, err = c.CreateEmbeddings(ctx, embeddingReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Embedding error: %v", err)
|
t.Fatalf("Embedding error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEmbedding(t *testing.T) {
|
||||||
|
embeddedModels := []EmbeddingModel{
|
||||||
|
AdaSimilarity,
|
||||||
|
BabbageSimilarity,
|
||||||
|
CurieSimilarity,
|
||||||
|
DavinciSimilarity,
|
||||||
|
AdaSearchDocument,
|
||||||
|
AdaSearchQuery,
|
||||||
|
BabbageSearchDocument,
|
||||||
|
BabbageSearchQuery,
|
||||||
|
CurieSearchDocument,
|
||||||
|
CurieSearchQuery,
|
||||||
|
DavinciSearchDocument,
|
||||||
|
DavinciSearchQuery,
|
||||||
|
AdaCodeSearchCode,
|
||||||
|
AdaCodeSearchText,
|
||||||
|
BabbageCodeSearchCode,
|
||||||
|
BabbageCodeSearchText,
|
||||||
|
}
|
||||||
|
for _, model := range embeddedModels {
|
||||||
|
embeddingReq := EmbeddingRequest{
|
||||||
|
Input: []string{
|
||||||
|
"The food was delicious and the waiter",
|
||||||
|
"Other examples of embedding request",
|
||||||
|
},
|
||||||
|
Model: model,
|
||||||
|
}
|
||||||
|
// marshal embeddingReq to JSON and confirm that the model field equals
|
||||||
|
// the AdaSearchQuery type
|
||||||
|
marshaled, err := json.Marshal(embeddingReq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Could not marshal embedding request: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
|
||||||
|
t.Fatalf("Expected embedding request to contain model field")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,13 +4,12 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CompletionRequest represents a request structure for completion API
|
// CompletionRequest represents a request structure for completion API
|
||||||
type CompletionRequest struct {
|
type CompletionRequest struct {
|
||||||
Model *string `json:"model,omitempty"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt,omitempty"`
|
Prompt string `json:"prompt,omitempty"`
|
||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
Temperature float32 `json:"temperature,omitempty"`
|
Temperature float32 `json:"temperature,omitempty"`
|
||||||
@@ -60,29 +59,12 @@ type CompletionResponse struct {
|
|||||||
Usage CompletionUsage `json:"usage"`
|
Usage CompletionUsage `json:"usage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well as, if requested, the probabilities over each alternative token at each position.
|
// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well
|
||||||
func (c *Client) CreateCompletion(ctx context.Context, engineID string, request CompletionRequest) (response CompletionResponse, err error) {
|
// as, if requested, the probabilities over each alternative token at each position.
|
||||||
var reqBytes []byte
|
//
|
||||||
reqBytes, err = json.Marshal(request)
|
// If using a fine-tuned model, simply provide the model's ID in the CompletionRequest object,
|
||||||
if err != nil {
|
// and the server will use the model's parameters to generate the completion.
|
||||||
return
|
func (c *Client) CreateCompletion(ctx context.Context, request CompletionRequest) (response CompletionResponse, err error) {
|
||||||
}
|
|
||||||
|
|
||||||
urlSuffix := fmt.Sprintf("/engines/%s/completions", engineID)
|
|
||||||
req, err := http.NewRequest("POST", c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
req = req.WithContext(ctx)
|
|
||||||
err = c.sendRequest(req, &response)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateCompletionWithFineTunedModel - API call to create a completion with a fine tuned model
|
|
||||||
// See https://beta.openai.com/docs/guides/fine-tuning/use-a-fine-tuned-model
|
|
||||||
// In this case, the model is specified in the CompletionRequest object.
|
|
||||||
func (c *Client) CreateCompletionWithFineTunedModel(ctx context.Context, request CompletionRequest) (response CompletionResponse, err error) {
|
|
||||||
var reqBytes []byte
|
var reqBytes []byte
|
||||||
reqBytes, err = json.Marshal(request)
|
reqBytes, err = json.Marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -120,18 +119,23 @@ type EmbeddingRequest struct {
|
|||||||
// E.g.
|
// E.g.
|
||||||
// "The food was delicious and the waiter..."
|
// "The food was delicious and the waiter..."
|
||||||
Input []string `json:"input"`
|
Input []string `json:"input"`
|
||||||
|
// ID of the model to use. You can use the List models API to see all of your available models,
|
||||||
|
// or see our Model overview for descriptions of them.
|
||||||
|
Model EmbeddingModel `json:"model"`
|
||||||
|
// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
|
||||||
|
User string `json:"user"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
|
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
|
||||||
// https://beta.openai.com/docs/api-reference/embeddings/create
|
// https://beta.openai.com/docs/api-reference/embeddings/create
|
||||||
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest, model EmbeddingModel) (resp EmbeddingResponse, err error) {
|
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
|
||||||
var reqBytes []byte
|
var reqBytes []byte
|
||||||
reqBytes, err = json.Marshal(request)
|
reqBytes, err = json.Marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
urlSuffix := fmt.Sprintf("/engines/%s/embeddings", model)
|
urlSuffix := "/embeddings"
|
||||||
req, err := http.NewRequest(http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
|
req, err := http.NewRequest(http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|||||||
10
search.go
10
search.go
@@ -17,11 +17,11 @@ import (
|
|||||||
*/
|
*/
|
||||||
type SearchRequest struct {
|
type SearchRequest struct {
|
||||||
Query string `json:"query"`
|
Query string `json:"query"`
|
||||||
Documents []string `json:"documents"` // 1*
|
Documents []string `json:"documents"` // 1*
|
||||||
FileID string `json:"file"` // 1*
|
FileID string `json:"file,omitempty"` // 1*
|
||||||
MaxRerank int `json:"max_rerank"` // 2*
|
MaxRerank int `json:"max_rerank,omitempty"` // 2*
|
||||||
ReturnMetadata bool `json:"return_metadata"`
|
ReturnMetadata bool `json:"return_metadata,omitempty"`
|
||||||
User string `json:"user"`
|
User string `json:"user,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SearchResult represents single result from search API
|
// SearchResult represents single result from search API
|
||||||
|
|||||||
Reference in New Issue
Block a user