289 lines
7.4 KiB
Go
289 lines
7.4 KiB
Go
package openai
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
|
|
utils "github.com/sashabaranov/go-openai/internal"
|
|
)
|
|
|
|
// Client is OpenAI GPT-3 API client.
|
|
type Client struct {
|
|
config ClientConfig
|
|
|
|
requestBuilder utils.RequestBuilder
|
|
createFormBuilder func(io.Writer) utils.FormBuilder
|
|
}
|
|
|
|
type Response interface {
|
|
SetHeader(http.Header)
|
|
}
|
|
|
|
type httpHeader http.Header
|
|
|
|
func (h *httpHeader) SetHeader(header http.Header) {
|
|
*h = httpHeader(header)
|
|
}
|
|
|
|
func (h *httpHeader) Header() http.Header {
|
|
return http.Header(*h)
|
|
}
|
|
|
|
func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders {
|
|
return newRateLimitHeaders(h.Header())
|
|
}
|
|
|
|
type RawResponse struct {
|
|
io.ReadCloser
|
|
|
|
httpHeader
|
|
}
|
|
|
|
// NewClient creates new OpenAI API client.
|
|
func NewClient(authToken string) *Client {
|
|
config := DefaultConfig(authToken)
|
|
return NewClientWithConfig(config)
|
|
}
|
|
|
|
// NewClientWithConfig creates new OpenAI API client for specified config.
|
|
func NewClientWithConfig(config ClientConfig) *Client {
|
|
return &Client{
|
|
config: config,
|
|
requestBuilder: utils.NewRequestBuilder(),
|
|
createFormBuilder: func(body io.Writer) utils.FormBuilder {
|
|
return utils.NewFormBuilder(body)
|
|
},
|
|
}
|
|
}
|
|
|
|
// NewOrgClient creates new OpenAI API client for specified Organization ID.
|
|
//
|
|
// Deprecated: Please use NewClientWithConfig.
|
|
func NewOrgClient(authToken, org string) *Client {
|
|
config := DefaultConfig(authToken)
|
|
config.OrgID = org
|
|
return NewClientWithConfig(config)
|
|
}
|
|
|
|
type requestOptions struct {
|
|
body any
|
|
header http.Header
|
|
}
|
|
|
|
type requestOption func(*requestOptions)
|
|
|
|
func withBody(body any) requestOption {
|
|
return func(args *requestOptions) {
|
|
args.body = body
|
|
}
|
|
}
|
|
|
|
func withContentType(contentType string) requestOption {
|
|
return func(args *requestOptions) {
|
|
args.header.Set("Content-Type", contentType)
|
|
}
|
|
}
|
|
|
|
func withBetaAssistantVersion(version string) requestOption {
|
|
return func(args *requestOptions) {
|
|
args.header.Set("OpenAI-Beta", fmt.Sprintf("assistants=%s", version))
|
|
}
|
|
}
|
|
|
|
func (c *Client) newRequest(ctx context.Context, method, url string, setters ...requestOption) (*http.Request, error) {
|
|
// Default Options
|
|
args := &requestOptions{
|
|
body: nil,
|
|
header: make(http.Header),
|
|
}
|
|
for _, setter := range setters {
|
|
setter(args)
|
|
}
|
|
req, err := c.requestBuilder.Build(ctx, method, url, args.body, args.header)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.setCommonHeaders(req)
|
|
return req, nil
|
|
}
|
|
|
|
func (c *Client) sendRequest(req *http.Request, v Response) error {
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
// Check whether Content-Type is already set, Upload Files API requires
|
|
// Content-Type == multipart/form-data
|
|
contentType := req.Header.Get("Content-Type")
|
|
if contentType == "" {
|
|
req.Header.Set("Content-Type", "application/json")
|
|
}
|
|
|
|
res, err := c.config.HTTPClient.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
defer res.Body.Close()
|
|
|
|
if v != nil {
|
|
v.SetHeader(res.Header)
|
|
}
|
|
|
|
if isFailureStatusCode(res) {
|
|
return c.handleErrorResp(res)
|
|
}
|
|
|
|
return decodeResponse(res.Body, v)
|
|
}
|
|
|
|
func (c *Client) sendRequestRaw(req *http.Request) (response RawResponse, err error) {
|
|
resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body should be closed by outer function
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if isFailureStatusCode(resp) {
|
|
err = c.handleErrorResp(resp)
|
|
return
|
|
}
|
|
|
|
response.SetHeader(resp.Header)
|
|
response.ReadCloser = resp.Body
|
|
return
|
|
}
|
|
|
|
func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) {
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Accept", "text/event-stream")
|
|
req.Header.Set("Cache-Control", "no-cache")
|
|
req.Header.Set("Connection", "keep-alive")
|
|
|
|
resp, err := client.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
|
|
if err != nil {
|
|
return new(streamReader[T]), err
|
|
}
|
|
if isFailureStatusCode(resp) {
|
|
return new(streamReader[T]), client.handleErrorResp(resp)
|
|
}
|
|
return &streamReader[T]{
|
|
emptyMessagesLimit: client.config.EmptyMessagesLimit,
|
|
reader: bufio.NewReader(resp.Body),
|
|
response: resp,
|
|
errAccumulator: utils.NewErrorAccumulator(),
|
|
unmarshaler: &utils.JSONUnmarshaler{},
|
|
httpHeader: httpHeader(resp.Header),
|
|
}, nil
|
|
}
|
|
|
|
func (c *Client) setCommonHeaders(req *http.Request) {
|
|
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
|
|
// Azure API Key authentication
|
|
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure {
|
|
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
|
|
} else if c.config.authToken != "" {
|
|
// OpenAI or Azure AD authentication
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
|
|
}
|
|
if c.config.OrgID != "" {
|
|
req.Header.Set("OpenAI-Organization", c.config.OrgID)
|
|
}
|
|
}
|
|
|
|
func isFailureStatusCode(resp *http.Response) bool {
|
|
return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
|
|
}
|
|
|
|
func decodeResponse(body io.Reader, v any) error {
|
|
if v == nil {
|
|
return nil
|
|
}
|
|
|
|
switch o := v.(type) {
|
|
case *string:
|
|
return decodeString(body, o)
|
|
case *audioTextResponse:
|
|
return decodeString(body, &o.Text)
|
|
default:
|
|
return json.NewDecoder(body).Decode(v)
|
|
}
|
|
}
|
|
|
|
func decodeString(body io.Reader, output *string) error {
|
|
b, err := io.ReadAll(body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*output = string(b)
|
|
return nil
|
|
}
|
|
|
|
// fullURL returns full URL for request.
|
|
// args[0] is model name, if API type is Azure, model name is required to get deployment name.
|
|
func (c *Client) fullURL(suffix string, args ...any) string {
|
|
// /openai/deployments/{model}/chat/completions?api-version={api_version}
|
|
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
|
|
baseURL := c.config.BaseURL
|
|
baseURL = strings.TrimRight(baseURL, "/")
|
|
parseURL, _ := url.Parse(baseURL)
|
|
query := parseURL.Query()
|
|
query.Add("api-version", c.config.APIVersion)
|
|
// if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01
|
|
// https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP
|
|
if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) {
|
|
return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, suffix, query.Encode())
|
|
}
|
|
azureDeploymentName := "UNKNOWN"
|
|
if len(args) > 0 {
|
|
model, ok := args[0].(string)
|
|
if ok {
|
|
azureDeploymentName = c.config.GetAzureDeploymentByModel(model)
|
|
}
|
|
}
|
|
return fmt.Sprintf("%s/%s/%s/%s%s?%s",
|
|
baseURL, azureAPIPrefix, azureDeploymentsPrefix,
|
|
azureDeploymentName, suffix, query.Encode(),
|
|
)
|
|
}
|
|
|
|
// https://developers.cloudflare.com/ai-gateway/providers/azureopenai/
|
|
if c.config.APIType == APITypeCloudflareAzure {
|
|
baseURL := c.config.BaseURL
|
|
baseURL = strings.TrimRight(baseURL, "/")
|
|
return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion)
|
|
}
|
|
|
|
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
|
|
}
|
|
|
|
func (c *Client) handleErrorResp(resp *http.Response) error {
|
|
var errRes ErrorResponse
|
|
err := json.NewDecoder(resp.Body).Decode(&errRes)
|
|
if err != nil || errRes.Error == nil {
|
|
reqErr := &RequestError{
|
|
HTTPStatusCode: resp.StatusCode,
|
|
Err: err,
|
|
}
|
|
if errRes.Error != nil {
|
|
reqErr.Err = errRes.Error
|
|
}
|
|
return reqErr
|
|
}
|
|
|
|
errRes.Error.HTTPStatusCode = resp.StatusCode
|
|
return errRes.Error
|
|
}
|
|
|
|
func containsSubstr(s []string, e string) bool {
|
|
for _, v := range s {
|
|
if strings.Contains(e, v) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|