change azure engine config to modelMapper (#306)
* change azure engine config to azure modelMapper config * Update go.mod * Revert "Update go.mod" This reverts commit 78d14c58f2a9ce668da43f6adbe20b60afcfe0d7. * lint fix * add test * lint fix * lint fix * lint fix * opt * opt * opt * opt
This commit is contained in:
@@ -94,7 +94,7 @@ func TestRequestAuthHeader(t *testing.T) {
|
|||||||
az.OrgID = c.OrgID
|
az.OrgID = c.OrgID
|
||||||
|
|
||||||
cli := NewClientWithConfig(az)
|
cli := NewClientWithConfig(az)
|
||||||
req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil)
|
req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed to create request: %v", err)
|
t.Errorf("Failed to create request: %v", err)
|
||||||
}
|
}
|
||||||
@@ -111,12 +111,14 @@ func TestAzureFullURL(t *testing.T) {
|
|||||||
cases := []struct {
|
cases := []struct {
|
||||||
Name string
|
Name string
|
||||||
BaseURL string
|
BaseURL string
|
||||||
Engine string
|
AzureModelMapper map[string]string
|
||||||
|
Model string
|
||||||
Expect string
|
Expect string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
"AzureBaseURLWithSlashAutoStrip",
|
"AzureBaseURLWithSlashAutoStrip",
|
||||||
"https://httpbin.org/",
|
"https://httpbin.org/",
|
||||||
|
nil,
|
||||||
"chatgpt-demo",
|
"chatgpt-demo",
|
||||||
"https://httpbin.org/" +
|
"https://httpbin.org/" +
|
||||||
"openai/deployments/chatgpt-demo" +
|
"openai/deployments/chatgpt-demo" +
|
||||||
@@ -125,6 +127,7 @@ func TestAzureFullURL(t *testing.T) {
|
|||||||
{
|
{
|
||||||
"AzureBaseURLWithoutSlashOK",
|
"AzureBaseURLWithoutSlashOK",
|
||||||
"https://httpbin.org",
|
"https://httpbin.org",
|
||||||
|
nil,
|
||||||
"chatgpt-demo",
|
"chatgpt-demo",
|
||||||
"https://httpbin.org/" +
|
"https://httpbin.org/" +
|
||||||
"openai/deployments/chatgpt-demo" +
|
"openai/deployments/chatgpt-demo" +
|
||||||
@@ -134,10 +137,10 @@ func TestAzureFullURL(t *testing.T) {
|
|||||||
|
|
||||||
for _, c := range cases {
|
for _, c := range cases {
|
||||||
t.Run(c.Name, func(t *testing.T) {
|
t.Run(c.Name, func(t *testing.T) {
|
||||||
az := DefaultAzureConfig("dummy", c.BaseURL, c.Engine)
|
az := DefaultAzureConfig("dummy", c.BaseURL)
|
||||||
cli := NewClientWithConfig(az)
|
cli := NewClientWithConfig(az)
|
||||||
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
|
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
|
||||||
actual := cli.fullURL("/chat/completions")
|
actual := cli.fullURL("/chat/completions", c.Model)
|
||||||
if actual != c.Expect {
|
if actual != c.Expect {
|
||||||
t.Errorf("Expected %s, got %s", c.Expect, actual)
|
t.Errorf("Expected %s, got %s", c.Expect, actual)
|
||||||
}
|
}
|
||||||
|
|||||||
2
audio.go
2
audio.go
@@ -68,7 +68,7 @@ func (c *Client) callAudioAPI(
|
|||||||
}
|
}
|
||||||
|
|
||||||
urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix)
|
urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix)
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), &formBody)
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), &formBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return AudioResponse{}, err
|
return AudioResponse{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
2
chat.go
2
chat.go
@@ -77,7 +77,7 @@ func (c *Client) CreateChatCompletion(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
|
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
}
|
}
|
||||||
|
|
||||||
request.Stream = true
|
request.Stream = true
|
||||||
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request)
|
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
22
client.go
22
client.go
@@ -98,8 +98,10 @@ func decodeString(body io.Reader, output *string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) fullURL(suffix string) string {
|
// fullURL returns full URL for request.
|
||||||
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
|
// args[0] is model name, if API type is Azure, model name is required to get deployment name.
|
||||||
|
func (c *Client) fullURL(suffix string, args ...any) string {
|
||||||
|
// /openai/deployments/{model}/chat/completions?api-version={api_version}
|
||||||
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
|
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
|
||||||
baseURL := c.config.BaseURL
|
baseURL := c.config.BaseURL
|
||||||
baseURL = strings.TrimRight(baseURL, "/")
|
baseURL = strings.TrimRight(baseURL, "/")
|
||||||
@@ -108,8 +110,17 @@ func (c *Client) fullURL(suffix string) string {
|
|||||||
if strings.Contains(suffix, "/models") {
|
if strings.Contains(suffix, "/models") {
|
||||||
return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion)
|
return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion)
|
||||||
}
|
}
|
||||||
|
azureDeploymentName := "UNKNOWN"
|
||||||
|
if len(args) > 0 {
|
||||||
|
model, ok := args[0].(string)
|
||||||
|
if ok {
|
||||||
|
azureDeploymentName = c.config.GetAzureDeploymentByModel(model)
|
||||||
|
}
|
||||||
|
}
|
||||||
return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s",
|
return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s",
|
||||||
baseURL, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion)
|
baseURL, azureAPIPrefix, azureDeploymentsPrefix,
|
||||||
|
azureDeploymentName, suffix, c.config.APIVersion,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// c.config.APIType == APITypeOpenAI || c.config.APIType == ""
|
// c.config.APIType == APITypeOpenAI || c.config.APIType == ""
|
||||||
@@ -120,8 +131,9 @@ func (c *Client) newStreamRequest(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
method string,
|
method string,
|
||||||
urlSuffix string,
|
urlSuffix string,
|
||||||
body any) (*http.Request, error) {
|
body any,
|
||||||
req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix), body)
|
model string) (*http.Request, error) {
|
||||||
|
req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix, model), body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ func (c *Client) CreateCompletion(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
|
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
18
config.go
18
config.go
@@ -2,6 +2,7 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"regexp"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -30,8 +31,7 @@ type ClientConfig struct {
|
|||||||
OrgID string
|
OrgID string
|
||||||
APIType APIType
|
APIType APIType
|
||||||
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
|
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
|
||||||
Engine string // required when APIType is APITypeAzure or APITypeAzureAD
|
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
|
||||||
|
|
||||||
HTTPClient *http.Client
|
HTTPClient *http.Client
|
||||||
|
|
||||||
EmptyMessagesLimit uint
|
EmptyMessagesLimit uint
|
||||||
@@ -50,14 +50,16 @@ func DefaultConfig(authToken string) ClientConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig {
|
func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
|
||||||
return ClientConfig{
|
return ClientConfig{
|
||||||
authToken: apiKey,
|
authToken: apiKey,
|
||||||
BaseURL: baseURL,
|
BaseURL: baseURL,
|
||||||
OrgID: "",
|
OrgID: "",
|
||||||
APIType: APITypeAzure,
|
APIType: APITypeAzure,
|
||||||
APIVersion: "2023-03-15-preview",
|
APIVersion: "2023-03-15-preview",
|
||||||
Engine: engine,
|
AzureModelMapperFunc: func(model string) string {
|
||||||
|
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
|
||||||
|
},
|
||||||
|
|
||||||
HTTPClient: &http.Client{},
|
HTTPClient: &http.Client{},
|
||||||
|
|
||||||
@@ -68,3 +70,11 @@ func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig {
|
|||||||
func (ClientConfig) String() string {
|
func (ClientConfig) String() string {
|
||||||
return "<OpenAI API ClientConfig>"
|
return "<OpenAI API ClientConfig>"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c ClientConfig) GetAzureDeploymentByModel(model string) string {
|
||||||
|
if c.AzureModelMapperFunc != nil {
|
||||||
|
return c.AzureModelMapperFunc(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
return model
|
||||||
|
}
|
||||||
|
|||||||
62
config_test.go
Normal file
62
config_test.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package openai_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/sashabaranov/go-openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetAzureDeploymentByModel(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
Model string
|
||||||
|
AzureModelMapperFunc func(model string) string
|
||||||
|
Expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Model: "gpt-3.5-turbo",
|
||||||
|
Expect: "gpt-35-turbo",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Model: "gpt-3.5-turbo-0301",
|
||||||
|
Expect: "gpt-35-turbo-0301",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Model: "text-embedding-ada-002",
|
||||||
|
Expect: "text-embedding-ada-002",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Model: "",
|
||||||
|
Expect: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Model: "models",
|
||||||
|
Expect: "models",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Model: "gpt-3.5-turbo",
|
||||||
|
Expect: "my-gpt35",
|
||||||
|
AzureModelMapperFunc: func(model string) string {
|
||||||
|
modelmapper := map[string]string{
|
||||||
|
"gpt-3.5-turbo": "my-gpt35",
|
||||||
|
}
|
||||||
|
if val, ok := modelmapper[model]; ok {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
return model
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.Model, func(t *testing.T) {
|
||||||
|
conf := DefaultAzureConfig("", "https://test.openai.azure.com/")
|
||||||
|
if c.AzureModelMapperFunc != nil {
|
||||||
|
conf.AzureModelMapperFunc = c.AzureModelMapperFunc
|
||||||
|
}
|
||||||
|
actual := conf.GetAzureDeploymentByModel(c.Model)
|
||||||
|
if actual != c.Expect {
|
||||||
|
t.Errorf("Expected %s, got %s", c.Expect, actual)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
3
edits.go
3
edits.go
@@ -2,6 +2,7 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,7 +32,7 @@ type EditsResponse struct {
|
|||||||
|
|
||||||
// Perform an API call to the Edits endpoint.
|
// Perform an API call to the Edits endpoint.
|
||||||
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
|
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits"), request)
|
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ type EmbeddingRequest struct {
|
|||||||
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
|
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
|
||||||
// https://beta.openai.com/docs/api-reference/embeddings/create
|
// https://beta.openai.com/docs/api-reference/embeddings/create
|
||||||
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
|
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request)
|
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -305,8 +305,7 @@ func Example_chatbot() {
|
|||||||
func ExampleDefaultAzureConfig() {
|
func ExampleDefaultAzureConfig() {
|
||||||
azureKey := os.Getenv("AZURE_OPENAI_API_KEY") // Your azure API key
|
azureKey := os.Getenv("AZURE_OPENAI_API_KEY") // Your azure API key
|
||||||
azureEndpoint := os.Getenv("AZURE_OPENAI_ENDPOINT") // Your azure OpenAI endpoint
|
azureEndpoint := os.Getenv("AZURE_OPENAI_ENDPOINT") // Your azure OpenAI endpoint
|
||||||
azureModel := os.Getenv("AZURE_OPENAI_MODEL") // Your model deployment name
|
config := openai.DefaultAzureConfig(azureKey, azureEndpoint)
|
||||||
config := openai.DefaultAzureConfig(azureKey, azureEndpoint, azureModel)
|
|
||||||
client := openai.NewClientWithConfig(config)
|
client := openai.NewClientWithConfig(config)
|
||||||
resp, err := client.CreateChatCompletion(
|
resp, err := client.CreateChatCompletion(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func TestAzureListModels(t *testing.T) {
|
|||||||
ts.Start()
|
ts.Start()
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/", "dummyengine")
|
config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
|
||||||
config.BaseURL = ts.URL
|
config.BaseURL = ts.URL
|
||||||
client := NewClientWithConfig(config)
|
client := NewClientWithConfig(config)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ type ModerationResponse struct {
|
|||||||
// Moderations — perform a moderation api call over a string.
|
// Moderations — perform a moderation api call over a string.
|
||||||
// Input can be an array or slice but a string will reduce the complexity.
|
// Input can be an array or slice but a string will reduce the complexity.
|
||||||
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
|
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
|
||||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations"), request)
|
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ func (c *Client) CreateCompletionStream(
|
|||||||
}
|
}
|
||||||
|
|
||||||
request.Stream = true
|
request.Stream = true
|
||||||
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request)
|
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user