feat: add Anthropic API support with custom version header (#934)
* feat: add Anthropic API support with custom version header * refactor: use switch statement for API type header handling * refactor: add OpenAI & AzureAD types to be exhaustive * Update client.go need explicit fallthrough in empty case statements * constant for APIVersion; addtl tests
This commit is contained in:
14
client.go
14
client.go
@@ -182,13 +182,21 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
|
||||
|
||||
func (c *Client) setCommonHeaders(req *http.Request) {
|
||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
|
||||
switch c.config.APIType {
|
||||
case APITypeAzure, APITypeCloudflareAzure:
|
||||
// Azure API Key authentication
|
||||
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure {
|
||||
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
|
||||
} else if c.config.authToken != "" {
|
||||
// OpenAI or Azure AD authentication
|
||||
case APITypeAnthropic:
|
||||
// https://docs.anthropic.com/en/api/versioning
|
||||
req.Header.Set("anthropic-version", c.config.APIVersion)
|
||||
case APITypeOpenAI, APITypeAzureAD:
|
||||
fallthrough
|
||||
default:
|
||||
if c.config.authToken != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
|
||||
}
|
||||
}
|
||||
|
||||
if c.config.OrgID != "" {
|
||||
req.Header.Set("OpenAI-Organization", c.config.OrgID)
|
||||
}
|
||||
|
||||
@@ -39,6 +39,21 @@ func TestClient(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCommonHeadersAnthropic(t *testing.T) {
|
||||
config := DefaultAnthropicConfig("mock-token", "")
|
||||
client := NewClientWithConfig(config)
|
||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
|
||||
client.setCommonHeaders(req)
|
||||
|
||||
if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion {
|
||||
t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeResponse(t *testing.T) {
|
||||
stringInput := ""
|
||||
|
||||
|
||||
22
config.go
22
config.go
@@ -11,6 +11,8 @@ const (
|
||||
|
||||
azureAPIPrefix = "openai"
|
||||
azureDeploymentsPrefix = "deployments"
|
||||
|
||||
AnthropicAPIVersion = "2023-06-01"
|
||||
)
|
||||
|
||||
type APIType string
|
||||
@@ -20,6 +22,7 @@ const (
|
||||
APITypeAzure APIType = "AZURE"
|
||||
APITypeAzureAD APIType = "AZURE_AD"
|
||||
APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE"
|
||||
APITypeAnthropic APIType = "ANTHROPIC"
|
||||
)
|
||||
|
||||
const AzureAPIKeyHeader = "api-key"
|
||||
@@ -37,7 +40,7 @@ type ClientConfig struct {
|
||||
BaseURL string
|
||||
OrgID string
|
||||
APIType APIType
|
||||
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
|
||||
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic
|
||||
AssistantVersion string
|
||||
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
|
||||
HTTPClient HTTPDoer
|
||||
@@ -76,6 +79,23 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig {
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.anthropic.com/v1"
|
||||
}
|
||||
return ClientConfig{
|
||||
authToken: apiKey,
|
||||
BaseURL: baseURL,
|
||||
OrgID: "",
|
||||
APIType: APITypeAnthropic,
|
||||
APIVersion: AnthropicAPIVersion,
|
||||
|
||||
HTTPClient: &http.Client{},
|
||||
|
||||
EmptyMessagesLimit: defaultEmptyMessagesLimit,
|
||||
}
|
||||
}
|
||||
|
||||
func (ClientConfig) String() string {
|
||||
return "<OpenAI API ClientConfig>"
|
||||
}
|
||||
|
||||
@@ -60,3 +60,43 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAnthropicConfig(t *testing.T) {
|
||||
apiKey := "test-key"
|
||||
baseURL := "https://api.anthropic.com/v1"
|
||||
|
||||
config := openai.DefaultAnthropicConfig(apiKey, baseURL)
|
||||
|
||||
if config.APIType != openai.APITypeAnthropic {
|
||||
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
|
||||
}
|
||||
|
||||
if config.APIVersion != openai.AnthropicAPIVersion {
|
||||
t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion)
|
||||
}
|
||||
|
||||
if config.BaseURL != baseURL {
|
||||
t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL)
|
||||
}
|
||||
|
||||
if config.EmptyMessagesLimit != 300 {
|
||||
t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) {
|
||||
config := openai.DefaultAnthropicConfig("", "")
|
||||
|
||||
if config.APIType != openai.APITypeAnthropic {
|
||||
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
|
||||
}
|
||||
|
||||
if config.APIVersion != openai.AnthropicAPIVersion {
|
||||
t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion)
|
||||
}
|
||||
|
||||
expectedBaseURL := "https://api.anthropic.com/v1"
|
||||
if config.BaseURL != expectedBaseURL {
|
||||
t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user