85 lines
1.9 KiB
Go
85 lines
1.9 KiB
Go
package openai
|
|
|
|
import (
|
|
"net/http"
|
|
"regexp"
|
|
)
|
|
|
|
// Defaults used by DefaultConfig for the public OpenAI API.
const (
	// openaiAPIURLv1 is the root URL of version 1 of the public OpenAI API.
	// DefaultConfig uses it as ClientConfig.BaseURL.
	openaiAPIURLv1 = "https://api.openai.com/v1"

	// defaultEmptyMessagesLimit is the value both default constructors assign
	// to ClientConfig.EmptyMessagesLimit. It is a typed uint constant.
	defaultEmptyMessagesLimit = uint(300)
)

// URL path segments for Azure OpenAI requests.
// NOTE(review): these are not referenced in this part of the file; presumably
// they are joined into request paths elsewhere — confirm before changing.
const (
	azureAPIPrefix         = "openai"
	azureDeploymentsPrefix = "deployments"
)
|
|
|
|
// APIType identifies which flavour of the OpenAI API a ClientConfig targets.
type APIType string

// Recognised APIType values.
const (
	// APITypeOpenAI is the public OpenAI API; DefaultConfig selects it.
	APITypeOpenAI = APIType("OPEN_AI")

	// APITypeAzure is Azure OpenAI; DefaultAzureConfig selects it.
	APITypeAzure = APIType("AZURE")

	// APITypeAzureAD is an Azure variant that, like APITypeAzure, requires
	// ClientConfig.APIVersion. NOTE(review): presumably Azure AD token
	// authentication instead of an API key — the auth handling is not visible here.
	APITypeAzureAD = APIType("AZURE_AD")
)
|
|
|
|
const (
	// AzureAPIKeyHeader is the HTTP header name "api-key".
	// NOTE(review): presumably the header that carries the key for Azure
	// requests; the code that sets it is outside this part of the file.
	AzureAPIKeyHeader = "api-key"

	// defaultAssistantVersion is the AssistantVersion assigned by DefaultConfig.
	// This will be deprecated by the end of 2024.
	defaultAssistantVersion = "v1"
)
|
|
|
|
// ClientConfig is a configuration of a client.
|
|
type ClientConfig struct {
|
|
authToken string
|
|
|
|
BaseURL string
|
|
OrgID string
|
|
APIType APIType
|
|
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
|
|
AssistantVersion string
|
|
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
|
|
HTTPClient *http.Client
|
|
|
|
EmptyMessagesLimit uint
|
|
}
|
|
|
|
func DefaultConfig(authToken string) ClientConfig {
|
|
return ClientConfig{
|
|
authToken: authToken,
|
|
BaseURL: openaiAPIURLv1,
|
|
APIType: APITypeOpenAI,
|
|
AssistantVersion: defaultAssistantVersion,
|
|
OrgID: "",
|
|
|
|
HTTPClient: &http.Client{},
|
|
|
|
EmptyMessagesLimit: defaultEmptyMessagesLimit,
|
|
}
|
|
}
|
|
|
|
func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
|
|
return ClientConfig{
|
|
authToken: apiKey,
|
|
BaseURL: baseURL,
|
|
OrgID: "",
|
|
APIType: APITypeAzure,
|
|
APIVersion: "2023-05-15",
|
|
AzureModelMapperFunc: func(model string) string {
|
|
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
|
|
},
|
|
|
|
HTTPClient: &http.Client{},
|
|
|
|
EmptyMessagesLimit: defaultEmptyMessagesLimit,
|
|
}
|
|
}
|
|
|
|
func (ClientConfig) String() string {
|
|
return "<OpenAI API ClientConfig>"
|
|
}
|
|
|
|
func (c ClientConfig) GetAzureDeploymentByModel(model string) string {
|
|
if c.AzureModelMapperFunc != nil {
|
|
return c.AzureModelMapperFunc(model)
|
|
}
|
|
|
|
return model
|
|
}
|