From f5b3ec4ffe3b23bad05fa43209b7d137768b27de Mon Sep 17 00:00:00 2001 From: Alexander Baranov Date: Wed, 19 Aug 2020 12:57:32 +0300 Subject: [PATCH] Add api client code --- .gitignore | 3 +++ api.go | 55 ++++++++++++++++++++++++++++++++++++++++++++++ api_test.go | 42 ++++++++++++++++++++++++++++++++++++ completion.go | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++ engines.go | 43 ++++++++++++++++++++++++++++++++++++ search.go | 42 ++++++++++++++++++++++++++++++++++++ 6 files changed, 245 insertions(+) create mode 100644 api.go create mode 100644 api_test.go create mode 100644 completion.go create mode 100644 engines.go create mode 100644 search.go diff --git a/.gitignore b/.gitignore index 66fd13c..42385aa 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,6 @@ # Dependency directories (remove the comment below to include it) # vendor/ + +# Auth token for tests +.openai-token \ No newline at end of file diff --git a/api.go b/api.go new file mode 100644 index 0000000..fec3af0 --- /dev/null +++ b/api.go @@ -0,0 +1,55 @@ +package gogpt + +import ( + "encoding/json" + "fmt" + "net/http" + "time" +) + +const apiURLv1 = "https://api.openai.com/v1" + +// Client is OpenAI GPT-3 API client +type Client struct { + BaseURL string + authToken string + HTTPClient *http.Client +} + +// NewClient creates new OpenAI API client +func NewClient(authToken string) *Client { + return &Client{ + BaseURL: apiURLv1, + authToken: authToken, + HTTPClient: &http.Client{ + Timeout: time.Minute, + }, + } +} + +func (c *Client) sendRequest(req *http.Request, v interface{}) error { + req.Header.Set("Content-Type", "application/json; charset=utf-8") + req.Header.Set("Accept", "application/json; charset=utf-8") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken)) + + res, err := c.HTTPClient.Do(req) + if err != nil { + return err + } + + defer res.Body.Close() + + if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { + return fmt.Errorf("error, status code: %d", res.StatusCode) + } + + if err = json.NewDecoder(res.Body).Decode(&v); err != nil { + return err + } + + return nil +} + +func (c *Client) fullURL(suffix string) string { + return fmt.Sprintf("%s%s", c.BaseURL, suffix) +} diff --git a/api_test.go b/api_test.go new file mode 100644 index 0000000..ccccccd --- /dev/null +++ b/api_test.go @@ -0,0 +1,42 @@ +package gogpt + +import ( + "context" + "io/ioutil" + "testing" +) + +func TestAPI(t *testing.T) { + tokenBytes, err := ioutil.ReadFile(".openai-token") + if err != nil { + t.Fatalf("Could not load auth token from .openai-token file") + } + + c := NewClient(string(tokenBytes)) + ctx := context.Background() + _, err = c.ListEngines(ctx) + if err != nil { + t.Fatalf("ListEngines error: %v", err) + } + + _, err = c.GetEngine(ctx, "davinci") + if err != nil { + t.Fatalf("GetEngine error: %v", err) + } + + req := CompletionRequest{MaxTokens: 5} + req.Prompt = "Lorem ipsum" + _, err = c.CreateCompletion(ctx, "ada", req) + if err != nil { + t.Fatalf("CreateCompletion error: %v", err) + } + + searchReq := SearchRequest{ + Documents: []string{"White House", "hospital", "school"}, + Query: "the president", + } + _, err = c.Search(ctx, "ada", searchReq) + if err != nil { + t.Fatalf("Search error: %v", err) + } +} diff --git a/completion.go b/completion.go new file mode 100644 index 0000000..e0074a9 --- /dev/null +++ b/completion.go @@ -0,0 +1,60 @@ +package gogpt + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" +) + +type CompletionRequest struct { + Prompt string `json:"prompt,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + + N int `json:"n,omitempty"` + + LogProbs int `json:"logobs,omitempty"` + + Echo bool `json:"echo,omitempty"` + Stop string `json:"stop,omitempty"` + + PresencePenalty float32 `json:"presence_penalty,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` +} + +type Choice struct { + Text string `json:"text"` + Index int `json:"index"` + FinishReason string `json:"finish_reason"` +} + +type CompletionResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created uint64 `json:"created"` + Model string `json:"model"` + Сhoices []Choice `json:"choices"` +} + +// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well as, if requested, the probabilities over each alternative token at each position. +func (c *Client) CreateCompletion(ctx context.Context, engineID string, request CompletionRequest) (response CompletionResponse, err error) { + var reqBytes []byte + reqBytes, err = json.Marshal(request) + if err != nil { + return + } + + urlSuffix := fmt.Sprintf("/engines/%s/completions", engineID) + req, err := http.NewRequest("POST", c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) + if err != nil { + return + } + + req = req.WithContext(ctx) + err = c.sendRequest(req, &response) + return +} diff --git a/engines.go b/engines.go new file mode 100644 index 0000000..4af32d1 --- /dev/null +++ b/engines.go @@ -0,0 +1,43 @@ +package gogpt + +import ( + "context" + "fmt" + "net/http" +) + +type Engine struct { + ID string `json:"id"` + Object string `json:"object"` + Owner string `json:"owner"` + Ready bool `json:"ready"` +} + +type EnginesList struct { + Engines []Engine `json:"data"` +} + +// ListEngines Lists the currently available engines, and provides basic information about each option such as the owner and availability. +func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) { + req, err := http.NewRequest("GET", c.fullURL("/engines"), nil) + if err != nil { + return + } + + req = req.WithContext(ctx) + err = c.sendRequest(req, &engines) + return +} + +// GetEngine Retrieves an engine instance, providing basic information about the engine such as the owner and availability. +func (c *Client) GetEngine(ctx context.Context, engineID string) (engine Engine, err error) { + urlSuffix := fmt.Sprintf("/engines/%s", engineID) + req, err := http.NewRequest("GET", c.fullURL(urlSuffix), nil) + if err != nil { + return + } + + req = req.WithContext(ctx) + err = c.sendRequest(req, &engine) + return +} diff --git a/search.go b/search.go new file mode 100644 index 0000000..1ee4cb2 --- /dev/null +++ b/search.go @@ -0,0 +1,42 @@ +package gogpt + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" +) + +type SearchRequest struct { + Documents []string `json:"documents"` + Query string `json:"query"` +} + +type SearchResult struct { + Document int `json:"document"` + Score float32 `json:"score"` +} + +type SearchResponse struct { + SearchResults []SearchResult `json:"data"` +} + +// Search — perform a semantic search api call over a list of documents. +func (c *Client) Search(ctx context.Context, engineID string, request SearchRequest) (response SearchResponse, err error) { + var reqBytes []byte + reqBytes, err = json.Marshal(request) + if err != nil { + return + } + + urlSuffix := fmt.Sprintf("/engines/%s/search", engineID) + req, err := http.NewRequest("POST", c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) + if err != nil { + return + } + + req = req.WithContext(ctx) + err = c.sendRequest(req, &response) + return +}