fix: fullURL endpoint generation (#817)

This commit is contained in:
eiixy
2024-08-17 01:11:38 +08:00
committed by GitHub
parent 2c6889e081
commit dd7f5824f9
14 changed files with 244 additions and 51 deletions

View File

@@ -112,6 +112,7 @@ func TestAzureFullURL(t *testing.T) {
Name string Name string
BaseURL string BaseURL string
AzureModelMapper map[string]string AzureModelMapper map[string]string
Suffix string
Model string Model string
Expect string Expect string
}{ }{
@@ -119,6 +120,7 @@ func TestAzureFullURL(t *testing.T) {
"AzureBaseURLWithSlashAutoStrip", "AzureBaseURLWithSlashAutoStrip",
"https://httpbin.org/", "https://httpbin.org/",
nil, nil,
"/chat/completions",
"chatgpt-demo", "chatgpt-demo",
"https://httpbin.org/" + "https://httpbin.org/" +
"openai/deployments/chatgpt-demo" + "openai/deployments/chatgpt-demo" +
@@ -128,11 +130,20 @@ func TestAzureFullURL(t *testing.T) {
"AzureBaseURLWithoutSlashOK", "AzureBaseURLWithoutSlashOK",
"https://httpbin.org", "https://httpbin.org",
nil, nil,
"/chat/completions",
"chatgpt-demo", "chatgpt-demo",
"https://httpbin.org/" + "https://httpbin.org/" +
"openai/deployments/chatgpt-demo" + "openai/deployments/chatgpt-demo" +
"/chat/completions?api-version=2023-05-15", "/chat/completions?api-version=2023-05-15",
}, },
{
"",
"https://httpbin.org",
nil,
"/assistants?limit=10",
"chatgpt-demo",
"https://httpbin.org/openai/assistants?api-version=2023-05-15&limit=10",
},
} }
for _, c := range cases { for _, c := range cases {
@@ -140,7 +151,7 @@ func TestAzureFullURL(t *testing.T) {
az := DefaultAzureConfig("dummy", c.BaseURL) az := DefaultAzureConfig("dummy", c.BaseURL)
cli := NewClientWithConfig(az) cli := NewClientWithConfig(az)
// /openai/deployments/{engine}/chat/completions?api-version={api_version} // /openai/deployments/{engine}/chat/completions?api-version={api_version}
actual := cli.fullURL("/chat/completions", c.Model) actual := cli.fullURL(c.Suffix, withModel(c.Model))
if actual != c.Expect { if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual) t.Errorf("Expected %s, got %s", c.Expect, actual)
} }
@@ -153,19 +164,22 @@ func TestCloudflareAzureFullURL(t *testing.T) {
cases := []struct { cases := []struct {
Name string Name string
BaseURL string BaseURL string
Suffix string
Expect string Expect string
}{ }{
{ {
"CloudflareAzureBaseURLWithSlashAutoStrip", "CloudflareAzureBaseURLWithSlashAutoStrip",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/",
"/chat/completions",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" +
"chat/completions?api-version=2023-05-15", "chat/completions?api-version=2023-05-15",
}, },
{ {
"CloudflareAzureBaseURLWithoutSlashOK", "",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + "/assistants?limit=10",
"chat/completions?api-version=2023-05-15", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo" +
"/assistants?api-version=2023-05-15&limit=10",
}, },
} }
@@ -176,7 +190,7 @@ func TestCloudflareAzureFullURL(t *testing.T) {
cli := NewClientWithConfig(az) cli := NewClientWithConfig(az)
actual := cli.fullURL("/chat/completions") actual := cli.fullURL(c.Suffix)
if actual != c.Expect { if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual) t.Errorf("Expected %s, got %s", c.Expect, actual)
} }

View File

@@ -122,8 +122,13 @@ func (c *Client) callAudioAPI(
} }
urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), req, err := c.newRequest(
withBody(&formBody), withContentType(builder.FormDataContentType())) ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(&formBody),
withContentType(builder.FormDataContentType()),
)
if err != nil { if err != nil {
return AudioResponse{}, err return AudioResponse{}, err
} }

View File

@@ -358,7 +358,12 @@ func (c *Client) CreateChatCompletion(
return return
} }
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(request),
)
if err != nil { if err != nil {
return return
} }

View File

@@ -60,7 +60,12 @@ func (c *Client) CreateChatCompletionStream(
} }
request.Stream = true request.Stream = true
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(request),
)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -222,42 +222,66 @@ func decodeString(body io.Reader, output *string) error {
return nil return nil
} }
type fullURLOptions struct {
model string
}
type fullURLOption func(*fullURLOptions)
func withModel(model string) fullURLOption {
return func(args *fullURLOptions) {
args.model = model
}
}
var azureDeploymentsEndpoints = []string{
"/completions",
"/embeddings",
"/chat/completions",
"/audio/transcriptions",
"/audio/translations",
"/audio/speech",
"/images/generations",
}
// fullURL returns full URL for request. // fullURL returns full URL for request.
// args[0] is model name, if API type is Azure, model name is required to get deployment name. func (c *Client) fullURL(suffix string, setters ...fullURLOption) string {
func (c *Client) fullURL(suffix string, args ...any) string { baseURL := strings.TrimRight(c.config.BaseURL, "/")
// /openai/deployments/{model}/chat/completions?api-version={api_version} args := fullURLOptions{}
for _, setter := range setters {
setter(&args)
}
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
baseURL := c.config.BaseURL baseURL = c.baseURLWithAzureDeployment(baseURL, suffix, args.model)
baseURL = strings.TrimRight(baseURL, "/") }
parseURL, _ := url.Parse(baseURL)
query := parseURL.Query() if c.config.APIVersion != "" {
suffix = c.suffixWithAPIVersion(suffix)
}
return fmt.Sprintf("%s%s", baseURL, suffix)
}
func (c *Client) suffixWithAPIVersion(suffix string) string {
parsedSuffix, err := url.Parse(suffix)
if err != nil {
panic("failed to parse url suffix")
}
query := parsedSuffix.Query()
query.Add("api-version", c.config.APIVersion) query.Add("api-version", c.config.APIVersion)
// if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 return fmt.Sprintf("%s?%s", parsedSuffix.Path, query.Encode())
// https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP
if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) {
return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, suffix, query.Encode())
}
azureDeploymentName := "UNKNOWN"
if len(args) > 0 {
model, ok := args[0].(string)
if ok {
azureDeploymentName = c.config.GetAzureDeploymentByModel(model)
}
}
return fmt.Sprintf("%s/%s/%s/%s%s?%s",
baseURL, azureAPIPrefix, azureDeploymentsPrefix,
azureDeploymentName, suffix, query.Encode(),
)
} }
// https://developers.cloudflare.com/ai-gateway/providers/azureopenai/ func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newBaseURL string) {
if c.config.APIType == APITypeCloudflareAzure { baseURL = fmt.Sprintf("%s/%s", strings.TrimRight(baseURL, "/"), azureAPIPrefix)
baseURL := c.config.BaseURL if containsSubstr(azureDeploymentsEndpoints, suffix) {
baseURL = strings.TrimRight(baseURL, "/") azureDeploymentName := c.config.GetAzureDeploymentByModel(model)
return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion) if azureDeploymentName == "" {
azureDeploymentName = "UNKNOWN"
} }
baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName)
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) }
return baseURL
} }
func (c *Client) handleErrorResp(resp *http.Response) error { func (c *Client) handleErrorResp(resp *http.Response) error {

View File

@@ -431,3 +431,99 @@ func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) {
t.Fatalf("Did not return error when request builder failed: %v", err) t.Fatalf("Did not return error when request builder failed: %v", err)
} }
} }
func TestClient_suffixWithAPIVersion(t *testing.T) {
type fields struct {
apiVersion string
}
type args struct {
suffix string
}
tests := []struct {
name string
fields fields
args args
want string
wantPanic string
}{
{
"",
fields{apiVersion: "2023-05"},
args{suffix: "/assistants"},
"/assistants?api-version=2023-05",
"",
},
{
"",
fields{apiVersion: "2023-05"},
args{suffix: "/assistants?limit=5"},
"/assistants?api-version=2023-05&limit=5",
"",
},
{
"",
fields{apiVersion: "2023-05"},
args{suffix: "123:assistants?limit=5"},
"/assistants?api-version=2023-05&limit=5",
"failed to parse url suffix",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Client{
config: ClientConfig{APIVersion: tt.fields.apiVersion},
}
defer func() {
if r := recover(); r != nil {
if r.(string) != tt.wantPanic {
t.Errorf("suffixWithAPIVersion() = %v, want %v", r, tt.wantPanic)
}
}
}()
if got := c.suffixWithAPIVersion(tt.args.suffix); got != tt.want {
t.Errorf("suffixWithAPIVersion() = %v, want %v", got, tt.want)
}
})
}
}
func TestClient_baseURLWithAzureDeployment(t *testing.T) {
type args struct {
baseURL string
suffix string
model string
}
tests := []struct {
name string
args args
wantNewBaseURL string
}{
{
"",
args{baseURL: "https://test.openai.azure.com/", suffix: assistantsSuffix, model: GPT4oMini},
"https://test.openai.azure.com/openai",
},
{
"",
args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: GPT4oMini},
"https://test.openai.azure.com/openai/deployments/gpt-4o-mini",
},
{
"",
args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: ""},
"https://test.openai.azure.com/openai/deployments/UNKNOWN",
},
}
client := NewClient("")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotNewBaseURL := client.baseURLWithAzureDeployment(
tt.args.baseURL,
tt.args.suffix,
tt.args.model,
); gotNewBaseURL != tt.wantNewBaseURL {
t.Errorf("baseURLWithAzureDeployment() = %v, want %v", gotNewBaseURL, tt.wantNewBaseURL)
}
})
}
}

View File

@@ -213,7 +213,12 @@ func (c *Client) CreateCompletion(
return return
} }
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(request),
)
if err != nil { if err != nil {
return return
} }

View File

@@ -38,7 +38,12 @@ will need to migrate to GPT-3.5 Turbo by January 4, 2024.
You can use CreateChatCompletion or CreateChatCompletionStream instead. You can use CreateChatCompletion or CreateChatCompletionStream instead.
*/ */
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request)) req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL("/edits", withModel(fmt.Sprint(request.Model))),
withBody(request),
)
if err != nil { if err != nil {
return return
} }

View File

@@ -241,7 +241,12 @@ func (c *Client) CreateEmbeddings(
conv EmbeddingRequestConverter, conv EmbeddingRequestConverter,
) (res EmbeddingResponse, err error) { ) (res EmbeddingResponse, err error) {
baseReq := conv.Convert() baseReq := conv.Convert()
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", string(baseReq.Model)), withBody(baseReq)) req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL("/embeddings", withModel(string(baseReq.Model))),
withBody(baseReq),
)
if err != nil { if err != nil {
return return
} }

View File

@@ -73,7 +73,7 @@ func ExampleClient_CreateChatCompletionStream() {
return return
} }
fmt.Printf(response.Choices[0].Delta.Content) fmt.Println(response.Choices[0].Delta.Content)
} }
} }

View File

@@ -68,7 +68,12 @@ type ImageResponseDataInner struct {
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
urlSuffix := "/images/generations" urlSuffix := "/images/generations"
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(request),
)
if err != nil { if err != nil {
return return
} }
@@ -132,8 +137,13 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
return return
} }
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits", request.Model), req, err := c.newRequest(
withBody(body), withContentType(builder.FormDataContentType())) ctx,
http.MethodPost,
c.fullURL("/images/edits", withModel(request.Model)),
withBody(body),
withContentType(builder.FormDataContentType()),
)
if err != nil { if err != nil {
return return
} }
@@ -183,8 +193,13 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
return return
} }
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations", request.Model), req, err := c.newRequest(
withBody(body), withContentType(builder.FormDataContentType())) ctx,
http.MethodPost,
c.fullURL("/images/variations", withModel(request.Model)),
withBody(body),
withContentType(builder.FormDataContentType()),
)
if err != nil { if err != nil {
return return
} }

View File

@@ -88,7 +88,12 @@ func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (re
err = ErrModerationInvalidModel err = ErrModerationInvalidModel
return return
} }
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request)) req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL("/moderations", withModel(request.Model)),
withBody(&request),
)
if err != nil { if err != nil {
return return
} }

View File

@@ -44,7 +44,10 @@ type CreateSpeechRequest struct {
} }
func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) {
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL("/audio/speech", withModel(string(request.Model))),
withBody(request), withBody(request),
withContentType("application/json"), withContentType("application/json"),
) )

View File

@@ -3,6 +3,7 @@ package openai
import ( import (
"context" "context"
"errors" "errors"
"net/http"
) )
var ( var (
@@ -33,7 +34,12 @@ func (c *Client) CreateCompletionStream(
} }
request.Stream = true request.Stream = true
req, err := c.newRequest(ctx, "POST", c.fullURL(urlSuffix, request.Model), withBody(request)) req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(request),
)
if err != nil { if err != nil {
return nil, err return nil, err
} }