fix: fullURL endpoint generation (#817)
This commit is contained in:
84
client.go
84
client.go
@@ -222,42 +222,66 @@ func decodeString(body io.Reader, output *string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type fullURLOptions struct {
|
||||
model string
|
||||
}
|
||||
|
||||
type fullURLOption func(*fullURLOptions)
|
||||
|
||||
func withModel(model string) fullURLOption {
|
||||
return func(args *fullURLOptions) {
|
||||
args.model = model
|
||||
}
|
||||
}
|
||||
|
||||
var azureDeploymentsEndpoints = []string{
|
||||
"/completions",
|
||||
"/embeddings",
|
||||
"/chat/completions",
|
||||
"/audio/transcriptions",
|
||||
"/audio/translations",
|
||||
"/audio/speech",
|
||||
"/images/generations",
|
||||
}
|
||||
|
||||
// fullURL returns full URL for request.
|
||||
// args[0] is model name, if API type is Azure, model name is required to get deployment name.
|
||||
func (c *Client) fullURL(suffix string, args ...any) string {
|
||||
// /openai/deployments/{model}/chat/completions?api-version={api_version}
|
||||
func (c *Client) fullURL(suffix string, setters ...fullURLOption) string {
|
||||
baseURL := strings.TrimRight(c.config.BaseURL, "/")
|
||||
args := fullURLOptions{}
|
||||
for _, setter := range setters {
|
||||
setter(&args)
|
||||
}
|
||||
|
||||
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
|
||||
baseURL := c.config.BaseURL
|
||||
baseURL = strings.TrimRight(baseURL, "/")
|
||||
parseURL, _ := url.Parse(baseURL)
|
||||
query := parseURL.Query()
|
||||
query.Add("api-version", c.config.APIVersion)
|
||||
// if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01
|
||||
// https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP
|
||||
if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) {
|
||||
return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, suffix, query.Encode())
|
||||
}
|
||||
azureDeploymentName := "UNKNOWN"
|
||||
if len(args) > 0 {
|
||||
model, ok := args[0].(string)
|
||||
if ok {
|
||||
azureDeploymentName = c.config.GetAzureDeploymentByModel(model)
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s/%s%s?%s",
|
||||
baseURL, azureAPIPrefix, azureDeploymentsPrefix,
|
||||
azureDeploymentName, suffix, query.Encode(),
|
||||
)
|
||||
baseURL = c.baseURLWithAzureDeployment(baseURL, suffix, args.model)
|
||||
}
|
||||
|
||||
// https://developers.cloudflare.com/ai-gateway/providers/azureopenai/
|
||||
if c.config.APIType == APITypeCloudflareAzure {
|
||||
baseURL := c.config.BaseURL
|
||||
baseURL = strings.TrimRight(baseURL, "/")
|
||||
return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion)
|
||||
if c.config.APIVersion != "" {
|
||||
suffix = c.suffixWithAPIVersion(suffix)
|
||||
}
|
||||
return fmt.Sprintf("%s%s", baseURL, suffix)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
|
||||
func (c *Client) suffixWithAPIVersion(suffix string) string {
|
||||
parsedSuffix, err := url.Parse(suffix)
|
||||
if err != nil {
|
||||
panic("failed to parse url suffix")
|
||||
}
|
||||
query := parsedSuffix.Query()
|
||||
query.Add("api-version", c.config.APIVersion)
|
||||
return fmt.Sprintf("%s?%s", parsedSuffix.Path, query.Encode())
|
||||
}
|
||||
|
||||
func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newBaseURL string) {
|
||||
baseURL = fmt.Sprintf("%s/%s", strings.TrimRight(baseURL, "/"), azureAPIPrefix)
|
||||
if containsSubstr(azureDeploymentsEndpoints, suffix) {
|
||||
azureDeploymentName := c.config.GetAzureDeploymentByModel(model)
|
||||
if azureDeploymentName == "" {
|
||||
azureDeploymentName = "UNKNOWN"
|
||||
}
|
||||
baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName)
|
||||
}
|
||||
return baseURL
|
||||
}
|
||||
|
||||
func (c *Client) handleErrorResp(resp *http.Response) error {
|
||||
|
||||
Reference in New Issue
Block a user