Better configuration (#79)
* Configurable Transport (#75) * new functions to allow HTTPClient configuration * updated go.mod for testing from remote * updated go.mod for remote testing * revert go.mod replace directives * Fixed NewOrgClientWithTransport comment * Make client fully configurable * make empty messages limit configurable #70 #71 * make auth token private in config * add docs * lint --------- Co-authored-by: Michael Fox <m.will.fox@gmail.com>
This commit is contained in:
43
api.go
43
api.go
@@ -6,43 +6,34 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
const apiURLv1 = "https://api.openai.com/v1"
|
|
||||||
|
|
||||||
func newTransport() *http.Client {
|
|
||||||
return &http.Client{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Client is OpenAI GPT-3 API client.
|
// Client is OpenAI GPT-3 API client.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
BaseURL string
|
config ClientConfig
|
||||||
HTTPClient *http.Client
|
|
||||||
authToken string
|
|
||||||
idOrg string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates new OpenAI API client.
|
// NewClient creates new OpenAI API client.
|
||||||
func NewClient(authToken string) *Client {
|
func NewClient(authToken string) *Client {
|
||||||
return &Client{
|
config := DefaultConfig(authToken)
|
||||||
BaseURL: apiURLv1,
|
return &Client{config}
|
||||||
HTTPClient: newTransport(),
|
|
||||||
authToken: authToken,
|
|
||||||
idOrg: "",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewClientWithConfig creates new OpenAI API client for specified config.
|
||||||
|
func NewClientWithConfig(config ClientConfig) *Client {
|
||||||
|
return &Client{config}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOrgClient creates new OpenAI API client for specified Organization ID.
|
// NewOrgClient creates new OpenAI API client for specified Organization ID.
|
||||||
|
//
|
||||||
|
// Deprecated: Please use NewClientWithConfig.
|
||||||
func NewOrgClient(authToken, org string) *Client {
|
func NewOrgClient(authToken, org string) *Client {
|
||||||
return &Client{
|
config := DefaultConfig(authToken)
|
||||||
BaseURL: apiURLv1,
|
config.OrgID = org
|
||||||
HTTPClient: newTransport(),
|
return &Client{config}
|
||||||
authToken: authToken,
|
|
||||||
idOrg: org,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
|
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
|
||||||
req.Header.Set("Accept", "application/json; charset=utf-8")
|
req.Header.Set("Accept", "application/json; charset=utf-8")
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken))
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
|
||||||
|
|
||||||
// Check whether Content-Type is already set, Upload Files API requires
|
// Check whether Content-Type is already set, Upload Files API requires
|
||||||
// Content-Type == multipart/form-data
|
// Content-Type == multipart/form-data
|
||||||
@@ -51,11 +42,11 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
|
|||||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(c.idOrg) > 0 {
|
if len(c.config.OrgID) > 0 {
|
||||||
req.Header.Set("OpenAI-Organization", c.idOrg)
|
req.Header.Set("OpenAI-Organization", c.config.OrgID)
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := c.HTTPClient.Do(req)
|
res, err := c.config.HTTPClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -86,5 +77,5 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) fullURL(suffix string) string {
|
func (c *Client) fullURL(suffix string) string {
|
||||||
return fmt.Sprintf("%s%s", c.BaseURL, suffix)
|
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -110,8 +110,10 @@ func TestAPIError(t *testing.T) {
|
|||||||
|
|
||||||
func TestRequestError(t *testing.T) {
|
func TestRequestError(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
c := NewClient("dummy")
|
|
||||||
c.BaseURL = "https://httpbin.org/status/418?"
|
config := DefaultConfig("dummy")
|
||||||
|
config.BaseURL = "https://httpbin.org/status/418?"
|
||||||
|
c := NewClientWithConfig(config)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
_, err = c.ListEngines(ctx)
|
_, err = c.ListEngines(ctx)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|||||||
@@ -25,9 +25,10 @@ func TestCompletions(t *testing.T) {
|
|||||||
ts.Start()
|
ts.Start()
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
client := NewClient(test.GetTestToken())
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
client.BaseURL = ts.URL + "/v1"
|
|
||||||
|
|
||||||
req := CompletionRequest{
|
req := CompletionRequest{
|
||||||
MaxTokens: 5,
|
MaxTokens: 5,
|
||||||
|
|||||||
33
config.go
Normal file
33
config.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package gogpt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
apiURLv1 = "https://api.openai.com/v1"
|
||||||
|
defaultEmptyMessagesLimit uint = 300
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClientConfig is a configuration of a client.
|
||||||
|
type ClientConfig struct {
|
||||||
|
authToken string
|
||||||
|
|
||||||
|
HTTPClient *http.Client
|
||||||
|
|
||||||
|
BaseURL string
|
||||||
|
OrgID string
|
||||||
|
|
||||||
|
EmptyMessagesLimit uint
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultConfig(authToken string) ClientConfig {
|
||||||
|
return ClientConfig{
|
||||||
|
HTTPClient: &http.Client{},
|
||||||
|
BaseURL: apiURLv1,
|
||||||
|
OrgID: "",
|
||||||
|
authToken: authToken,
|
||||||
|
|
||||||
|
EmptyMessagesLimit: defaultEmptyMessagesLimit,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -23,9 +23,10 @@ func TestEdits(t *testing.T) {
|
|||||||
ts.Start()
|
ts.Start()
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
client := NewClient(test.GetTestToken())
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
client.BaseURL = ts.URL + "/v1"
|
|
||||||
|
|
||||||
// create an edit request
|
// create an edit request
|
||||||
model := "ada"
|
model := "ada"
|
||||||
|
|||||||
@@ -22,9 +22,10 @@ func TestFileUpload(t *testing.T) {
|
|||||||
ts.Start()
|
ts.Start()
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
client := NewClient(test.GetTestToken())
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
client.BaseURL = ts.URL + "/v1"
|
|
||||||
|
|
||||||
req := FileRequest{
|
req := FileRequest{
|
||||||
FileName: "test.go",
|
FileName: "test.go",
|
||||||
|
|||||||
@@ -23,9 +23,10 @@ func TestImages(t *testing.T) {
|
|||||||
ts.Start()
|
ts.Start()
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
client := NewClient(test.GetTestToken())
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
client.BaseURL = ts.URL + "/v1"
|
|
||||||
|
|
||||||
req := ImageRequest{}
|
req := ImageRequest{}
|
||||||
req.Prompt = "Lorem ipsum"
|
req.Prompt = "Lorem ipsum"
|
||||||
@@ -94,9 +95,10 @@ func TestImageEdit(t *testing.T) {
|
|||||||
ts.Start()
|
ts.Start()
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
client := NewClient(test.GetTestToken())
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
client.BaseURL = ts.URL + "/v1"
|
|
||||||
|
|
||||||
origin, err := os.Create("image.png")
|
origin, err := os.Create("image.png")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -25,9 +25,10 @@ func TestModerations(t *testing.T) {
|
|||||||
ts.Start()
|
ts.Start()
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
client := NewClient(test.GetTestToken())
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
client.BaseURL = ts.URL + "/v1"
|
|
||||||
|
|
||||||
// create an edit request
|
// create an edit request
|
||||||
model := "text-moderation-stable"
|
model := "text-moderation-stable"
|
||||||
|
|||||||
13
stream.go
13
stream.go
@@ -11,17 +11,18 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
emptyMessagesLimit = 300
|
|
||||||
ErrTooManyEmptyStreamMessages = errors.New("stream has sent too many empty messages")
|
ErrTooManyEmptyStreamMessages = errors.New("stream has sent too many empty messages")
|
||||||
)
|
)
|
||||||
|
|
||||||
type CompletionStream struct {
|
type CompletionStream struct {
|
||||||
|
emptyMessagesLimit uint
|
||||||
|
|
||||||
reader *bufio.Reader
|
reader *bufio.Reader
|
||||||
response *http.Response
|
response *http.Response
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stream *CompletionStream) Recv() (response CompletionResponse, err error) {
|
func (stream *CompletionStream) Recv() (response CompletionResponse, err error) {
|
||||||
emptyMessagesCount := 0
|
var emptyMessagesCount uint
|
||||||
|
|
||||||
waitForData:
|
waitForData:
|
||||||
line, err := stream.reader.ReadBytes('\n')
|
line, err := stream.reader.ReadBytes('\n')
|
||||||
@@ -33,7 +34,7 @@ waitForData:
|
|||||||
line = bytes.TrimSpace(line)
|
line = bytes.TrimSpace(line)
|
||||||
if !bytes.HasPrefix(line, headerData) {
|
if !bytes.HasPrefix(line, headerData) {
|
||||||
emptyMessagesCount++
|
emptyMessagesCount++
|
||||||
if emptyMessagesCount > emptyMessagesLimit {
|
if emptyMessagesCount > stream.emptyMessagesLimit {
|
||||||
err = ErrTooManyEmptyStreamMessages
|
err = ErrTooManyEmptyStreamMessages
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -74,18 +75,20 @@ func (c *Client) CreateCompletionStream(
|
|||||||
req.Header.Set("Accept", "text/event-stream")
|
req.Header.Set("Accept", "text/event-stream")
|
||||||
req.Header.Set("Cache-Control", "no-cache")
|
req.Header.Set("Cache-Control", "no-cache")
|
||||||
req.Header.Set("Connection", "keep-alive")
|
req.Header.Set("Connection", "keep-alive")
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken))
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
resp, err := c.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
|
resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
stream = &CompletionStream{
|
stream = &CompletionStream{
|
||||||
|
emptyMessagesLimit: c.config.EmptyMessagesLimit,
|
||||||
|
|
||||||
reader: bufio.NewReader(resp.Body),
|
reader: bufio.NewReader(resp.Body),
|
||||||
response: resp,
|
response: resp,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,9 +37,15 @@ func TestCreateCompletionStream(t *testing.T) {
|
|||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
// Client portion of the test
|
// Client portion of the test
|
||||||
client := NewClient(test.GetTestToken())
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = server.URL + "/v1"
|
||||||
|
config.HTTPClient.Transport = &tokenRoundTripper{
|
||||||
|
test.GetTestToken(),
|
||||||
|
http.DefaultTransport,
|
||||||
|
}
|
||||||
|
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
client.BaseURL = server.URL + "/v1"
|
|
||||||
|
|
||||||
request := CompletionRequest{
|
request := CompletionRequest{
|
||||||
Prompt: "Ex falso quodlibet",
|
Prompt: "Ex falso quodlibet",
|
||||||
@@ -48,11 +54,6 @@ func TestCreateCompletionStream(t *testing.T) {
|
|||||||
Stream: true,
|
Stream: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
client.HTTPClient.Transport = &tokenRoundTripper{
|
|
||||||
test.GetTestToken(),
|
|
||||||
http.DefaultTransport,
|
|
||||||
}
|
|
||||||
|
|
||||||
stream, err := client.CreateCompletionStream(ctx, request)
|
stream, err := client.CreateCompletionStream(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("CreateCompletionStream returned error: %v", err)
|
t.Errorf("CreateCompletionStream returned error: %v", err)
|
||||||
|
|||||||
Reference in New Issue
Block a user