add optional params for audio api, e.g. prompt (#183)

* Compatible with the situation where the mask is empty in CreateEditImage.

* Fix the test for the unnecessary removal of the mask.png file.

* add image variation implementation

* fix image variation bugs

* fix ci-lint problem with max line character limit

* add offitial doc link

* just for codeball test

* fix lint problem

* add optional params for audio api, e.g. prompt

* add comment for new args in translation
This commit is contained in:
itegel
2023-03-20 18:02:19 +08:00
committed by GitHub
parent d529d13ba1
commit aa149c1bf8
2 changed files with 100 additions and 2 deletions

View File

@@ -16,9 +16,13 @@ const (
)
// AudioRequest represents a request structure for audio API.
// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
type AudioRequest struct {
Model string
FilePath string
Model string
FilePath string
Prompt string // For translation, it should be in English
Temperature float32
Language string // For translation, just do not use it. It seems "en" works, not confirmed...
}
// AudioResponse represents a response structure for audio API.
@@ -94,6 +98,47 @@ func audioMultipartForm(request AudioRequest, w *multipart.Writer) error {
if _, err = io.Copy(fw, modelName); err != nil {
return fmt.Errorf("writing model name: %w", err)
}
// Create a form field for the prompt (if provided)
if request.Prompt != "" {
fw, err = w.CreateFormField("prompt")
if err != nil {
return fmt.Errorf("creating form field: %w", err)
}
prompt := bytes.NewReader([]byte(request.Prompt))
if _, err = io.Copy(fw, prompt); err != nil {
return fmt.Errorf("writing prompt: %w", err)
}
}
// Create a form field for the temperature (if provided)
if request.Temperature != 0 {
fw, err = w.CreateFormField("temperature")
if err != nil {
return fmt.Errorf("creating form field: %w", err)
}
temperature := bytes.NewReader([]byte(fmt.Sprintf("%.2f", request.Temperature)))
if _, err = io.Copy(fw, temperature); err != nil {
return fmt.Errorf("writing temperature: %w", err)
}
}
// Create a form field for the language (if provided)
if request.Language != "" {
fw, err = w.CreateFormField("language")
if err != nil {
return fmt.Errorf("creating form field: %w", err)
}
language := bytes.NewReader([]byte(request.Language))
if _, err = io.Copy(fw, language); err != nil {
return fmt.Errorf("writing language: %w", err)
}
}
// Close the multipart writer
w.Close()
return nil

View File

@@ -69,6 +69,59 @@ func TestAudio(t *testing.T) {
}
}
func TestAudioWithOptionalArgs(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
testcases := []struct {
name string
createFn func(context.Context, AudioRequest) (AudioResponse, error)
}{
{
"transcribe",
client.CreateTranscription,
},
{
"translate",
client.CreateTranslation,
},
}
ctx := context.Background()
dir, cleanup := createTestDirectory(t)
defer cleanup()
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
path := filepath.Join(dir, "fake.mp3")
createTestFile(t, path)
req := AudioRequest{
FilePath: path,
Model: "whisper-3",
Prompt: "用简体中文",
Temperature: 0.5,
Language: "zh",
}
_, err = tc.createFn(ctx, req)
if err != nil {
t.Fatalf("audio API error: %v", err)
}
})
}
}
// createTestFile creates a fake file with "hello" as the content.
func createTestFile(t *testing.T, path string) {
file, err := os.Create(path)