diff --git a/api_internal_test.go b/api_internal_test.go index 0fb0f89..a590ec9 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -148,3 +148,39 @@ func TestAzureFullURL(t *testing.T) { }) } } + +func TestCloudflareAzureFullURL(t *testing.T) { + cases := []struct { + Name string + BaseURL string + Expect string + }{ + { + "CloudflareAzureBaseURLWithSlashAutoStrip", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + + "chat/completions?api-version=2023-05-15", + }, + { + "CloudflareAzureBaseURLWithoutSlashOK", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + + "chat/completions?api-version=2023-05-15", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + az := DefaultAzureConfig("dummy", c.BaseURL) + az.APIType = APITypeCloudflareAzure + + cli := NewClientWithConfig(az) + + actual := cli.fullURL("/chat/completions") + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + t.Logf("Full URL: %s", actual) + }) + } +} diff --git a/client.go b/client.go index 77d6932..c57ba17 100644 --- a/client.go +++ b/client.go @@ -182,7 +182,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream func (c *Client) setCommonHeaders(req *http.Request) { // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication // Azure API Key authentication - if c.config.APIType == APITypeAzure { + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure { req.Header.Set(AzureAPIKeyHeader, c.config.authToken) } else if c.config.authToken != "" { // OpenAI or Azure AD authentication @@ -246,7 +246,13 @@ func (c *Client) fullURL(suffix string, args ...any) string { ) } - // c.config.APIType == APITypeOpenAI || c.config.APIType == "" + // https://developers.cloudflare.com/ai-gateway/providers/azureopenai/ + if c.config.APIType == APITypeCloudflareAzure { + baseURL := c.config.BaseURL + baseURL = strings.TrimRight(baseURL, "/") + return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion) + } + return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } diff --git a/config.go b/config.go index 599fa89..bb437c9 100644 --- a/config.go +++ b/config.go @@ -16,9 +16,10 @@ const ( type APIType string const ( - APITypeOpenAI APIType = "OPEN_AI" - APITypeAzure APIType = "AZURE" - APITypeAzureAD APIType = "AZURE_AD" + APITypeOpenAI APIType = "OPEN_AI" + APITypeAzure APIType = "AZURE" + APITypeAzureAD APIType = "AZURE_AD" + APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE" ) const AzureAPIKeyHeader = "api-key"