Add api client code

This commit is contained in:
Alexander Baranov
2020-08-19 12:57:32 +03:00
parent d62f32901b
commit f5b3ec4ffe
6 changed files with 245 additions and 0 deletions

3
.gitignore vendored
View File

@@ -13,3 +13,6 @@
# Dependency directories (remove the comment below to include it) # Dependency directories (remove the comment below to include it)
# vendor/ # vendor/
# Auth token for tests
.openai-token

55
api.go Normal file
View File

@@ -0,0 +1,55 @@
package gogpt
import (
"encoding/json"
"fmt"
"net/http"
"time"
)
const apiURLv1 = "https://api.openai.com/v1"
// Client is OpenAI GPT-3 API client
type Client struct {
BaseURL string
authToken string
HTTPClient *http.Client
}
// NewClient creates new OpenAI API client
func NewClient(authToken string) *Client {
return &Client{
BaseURL: apiURLv1,
authToken: authToken,
HTTPClient: &http.Client{
Timeout: time.Minute,
},
}
}
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
req.Header.Set("Content-Type", "application/json; charset=utf-8")
req.Header.Set("Accept", "application/json; charset=utf-8")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken))
res, err := c.HTTPClient.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
return fmt.Errorf("error, status code: %d", res.StatusCode)
}
if err = json.NewDecoder(res.Body).Decode(&v); err != nil {
return err
}
return nil
}
func (c *Client) fullURL(suffix string) string {
return fmt.Sprintf("%s%s", c.BaseURL, suffix)
}

42
api_test.go Normal file
View File

@@ -0,0 +1,42 @@
package gogpt
import (
"context"
"io/ioutil"
"testing"
)
func TestAPI(t *testing.T) {
tokenBytes, err := ioutil.ReadFile(".openai-token")
if err != nil {
t.Fatalf("Could not load auth token from .openai-token file")
}
c := NewClient(string(tokenBytes))
ctx := context.Background()
_, err = c.ListEngines(ctx)
if err != nil {
t.Fatalf("ListEngines error: %v", err)
}
_, err = c.GetEngine(ctx, "davinci")
if err != nil {
t.Fatalf("GetEngine error: %v", err)
}
req := CompletionRequest{MaxTokens: 5}
req.Prompt = "Lorem ipsum"
_, err = c.CreateCompletion(ctx, "ada", req)
if err != nil {
t.Fatalf("CreateCompletion error: %v", err)
}
searchReq := SearchRequest{
Documents: []string{"White House", "hospital", "school"},
Query: "the president",
}
_, err = c.Search(ctx, "ada", searchReq)
if err != nil {
t.Fatalf("Search error: %v", err)
}
}

60
completion.go Normal file
View File

@@ -0,0 +1,60 @@
package gogpt
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
)
type CompletionRequest struct {
Prompt string `json:"prompt,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
LogProbs int `json:"logobs,omitempty"`
Echo bool `json:"echo,omitempty"`
Stop string `json:"stop,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
}
type Choice struct {
Text string `json:"text"`
Index int `json:"index"`
FinishReason string `json:"finish_reason"`
}
type CompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created uint64 `json:"created"`
Model string `json:"model"`
Сhoices []Choice `json:"choices"`
}
// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well as, if requested, the probabilities over each alternative token at each position.
func (c *Client) CreateCompletion(ctx context.Context, engineID string, request CompletionRequest) (response CompletionResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
if err != nil {
return
}
urlSuffix := fmt.Sprintf("/engines/%s/completions", engineID)
req, err := http.NewRequest("POST", c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
if err != nil {
return
}
req = req.WithContext(ctx)
err = c.sendRequest(req, &response)
return
}

43
engines.go Normal file
View File

@@ -0,0 +1,43 @@
package gogpt
import (
"context"
"fmt"
"net/http"
)
type Engine struct {
ID string `json:"id"`
Object string `json:"object"`
Owner string `json:"owner"`
Ready bool `json:"ready"`
}
type EnginesList struct {
Engines []Engine `json:"data"`
}
// ListEngines Lists the currently available engines, and provides basic information about each option such as the owner and availability.
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
req, err := http.NewRequest("GET", c.fullURL("/engines"), nil)
if err != nil {
return
}
req = req.WithContext(ctx)
err = c.sendRequest(req, &engines)
return
}
// GetEngine Retrieves an engine instance, providing basic information about the engine such as the owner and availability.
func (c *Client) GetEngine(ctx context.Context, engineID string) (engine Engine, err error) {
urlSuffix := fmt.Sprintf("/engines/%s", engineID)
req, err := http.NewRequest("GET", c.fullURL(urlSuffix), nil)
if err != nil {
return
}
req = req.WithContext(ctx)
err = c.sendRequest(req, &engine)
return
}

42
search.go Normal file
View File

@@ -0,0 +1,42 @@
package gogpt
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
)
type SearchRequest struct {
Documents []string `json:"documents"`
Query string `json:"query"`
}
type SearchResult struct {
Document int `json:"document"`
Score float32 `json:"score"`
}
type SearchResponse struct {
SearchResults []SearchResult `json:"data"`
}
// Search — perform a semantic search api call over a list of documents.
func (c *Client) Search(ctx context.Context, engineID string, request SearchRequest) (response SearchResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
if err != nil {
return
}
urlSuffix := fmt.Sprintf("/engines/%s/search", engineID)
req, err := http.NewRequest("POST", c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
if err != nil {
return
}
req = req.WithContext(ctx)
err = c.sendRequest(req, &response)
return
}