Add api client code
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -13,3 +13,6 @@
|
||||
|
||||
# Dependency directories (remove the comment below to include it)
|
||||
# vendor/
|
||||
|
||||
# Auth token for tests
|
||||
.openai-token
|
||||
55
api.go
Normal file
55
api.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package gogpt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const apiURLv1 = "https://api.openai.com/v1"
|
||||
|
||||
// Client is OpenAI GPT-3 API client
|
||||
type Client struct {
|
||||
BaseURL string
|
||||
authToken string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
// NewClient creates new OpenAI API client
|
||||
func NewClient(authToken string) *Client {
|
||||
return &Client{
|
||||
BaseURL: apiURLv1,
|
||||
authToken: authToken,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: time.Minute,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
|
||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
req.Header.Set("Accept", "application/json; charset=utf-8")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken))
|
||||
|
||||
res, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
|
||||
return fmt.Errorf("error, status code: %d", res.StatusCode)
|
||||
}
|
||||
|
||||
if err = json.NewDecoder(res.Body).Decode(&v); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) fullURL(suffix string) string {
|
||||
return fmt.Sprintf("%s%s", c.BaseURL, suffix)
|
||||
}
|
||||
42
api_test.go
Normal file
42
api_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package gogpt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAPI(t *testing.T) {
|
||||
tokenBytes, err := ioutil.ReadFile(".openai-token")
|
||||
if err != nil {
|
||||
t.Fatalf("Could not load auth token from .openai-token file")
|
||||
}
|
||||
|
||||
c := NewClient(string(tokenBytes))
|
||||
ctx := context.Background()
|
||||
_, err = c.ListEngines(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ListEngines error: %v", err)
|
||||
}
|
||||
|
||||
_, err = c.GetEngine(ctx, "davinci")
|
||||
if err != nil {
|
||||
t.Fatalf("GetEngine error: %v", err)
|
||||
}
|
||||
|
||||
req := CompletionRequest{MaxTokens: 5}
|
||||
req.Prompt = "Lorem ipsum"
|
||||
_, err = c.CreateCompletion(ctx, "ada", req)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCompletion error: %v", err)
|
||||
}
|
||||
|
||||
searchReq := SearchRequest{
|
||||
Documents: []string{"White House", "hospital", "school"},
|
||||
Query: "the president",
|
||||
}
|
||||
_, err = c.Search(ctx, "ada", searchReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Search error: %v", err)
|
||||
}
|
||||
}
|
||||
60
completion.go
Normal file
60
completion.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package gogpt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type CompletionRequest struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
TopP float32 `json:"top_p,omitempty"`
|
||||
|
||||
N int `json:"n,omitempty"`
|
||||
|
||||
LogProbs int `json:"logobs,omitempty"`
|
||||
|
||||
Echo bool `json:"echo,omitempty"`
|
||||
Stop string `json:"stop,omitempty"`
|
||||
|
||||
PresencePenalty float32 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
|
||||
}
|
||||
|
||||
type Choice struct {
|
||||
Text string `json:"text"`
|
||||
Index int `json:"index"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created uint64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Сhoices []Choice `json:"choices"`
|
||||
}
|
||||
|
||||
// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well as, if requested, the probabilities over each alternative token at each position.
|
||||
func (c *Client) CreateCompletion(ctx context.Context, engineID string, request CompletionRequest) (response CompletionResponse, err error) {
|
||||
var reqBytes []byte
|
||||
reqBytes, err = json.Marshal(request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
urlSuffix := fmt.Sprintf("/engines/%s/completions", engineID)
|
||||
req, err := http.NewRequest("POST", c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
req = req.WithContext(ctx)
|
||||
err = c.sendRequest(req, &response)
|
||||
return
|
||||
}
|
||||
43
engines.go
Normal file
43
engines.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package gogpt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Engine struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Owner string `json:"owner"`
|
||||
Ready bool `json:"ready"`
|
||||
}
|
||||
|
||||
type EnginesList struct {
|
||||
Engines []Engine `json:"data"`
|
||||
}
|
||||
|
||||
// ListEngines Lists the currently available engines, and provides basic information about each option such as the owner and availability.
|
||||
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
|
||||
req, err := http.NewRequest("GET", c.fullURL("/engines"), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
req = req.WithContext(ctx)
|
||||
err = c.sendRequest(req, &engines)
|
||||
return
|
||||
}
|
||||
|
||||
// GetEngine Retrieves an engine instance, providing basic information about the engine such as the owner and availability.
|
||||
func (c *Client) GetEngine(ctx context.Context, engineID string) (engine Engine, err error) {
|
||||
urlSuffix := fmt.Sprintf("/engines/%s", engineID)
|
||||
req, err := http.NewRequest("GET", c.fullURL(urlSuffix), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
req = req.WithContext(ctx)
|
||||
err = c.sendRequest(req, &engine)
|
||||
return
|
||||
}
|
||||
42
search.go
Normal file
42
search.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package gogpt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type SearchRequest struct {
|
||||
Documents []string `json:"documents"`
|
||||
Query string `json:"query"`
|
||||
}
|
||||
|
||||
type SearchResult struct {
|
||||
Document int `json:"document"`
|
||||
Score float32 `json:"score"`
|
||||
}
|
||||
|
||||
type SearchResponse struct {
|
||||
SearchResults []SearchResult `json:"data"`
|
||||
}
|
||||
|
||||
// Search — perform a semantic search api call over a list of documents.
|
||||
func (c *Client) Search(ctx context.Context, engineID string, request SearchRequest) (response SearchResponse, err error) {
|
||||
var reqBytes []byte
|
||||
reqBytes, err = json.Marshal(request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
urlSuffix := fmt.Sprintf("/engines/%s/search", engineID)
|
||||
req, err := http.NewRequest("POST", c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
req = req.WithContext(ctx)
|
||||
err = c.sendRequest(req, &response)
|
||||
return
|
||||
}
|
||||
Reference in New Issue
Block a user