add optional params for audio api, e.g. prompt (#183)
* Compatible with the situation where the mask is empty in CreateEditImage. * Fix the test for the unnecessary removal of the mask.png file. * add image variation implementation * fix image variation bugs * fix ci-lint problem with max line character limit * add official doc link * just for codeball test * fix lint problem * add optional params for audio api, e.g. prompt * add comment for new args in translation
This commit is contained in:
49
audio.go
49
audio.go
@@ -16,9 +16,13 @@ const (
|
||||
)
|
||||
|
||||
// AudioRequest represents a request structure for audio API.
// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
//
// Prompt, Temperature and Language are optional: each one is only written to
// the multipart form when it differs from its zero value.
type AudioRequest struct {
	// Model is the model name written to the request form.
	Model string
	// FilePath is the path of the audio file for the request.
	FilePath string
	// Prompt is an optional prompt. For translation, it should be in English.
	Prompt string
	// Temperature is optional; a value of 0 is treated as "not provided" and
	// the field is omitted. Otherwise it is sent formatted with two decimals.
	Temperature float32
	// Language is optional. For translation, just do not use it.
	// It seems "en" works, not confirmed...
	Language string
}
|
||||
|
||||
// AudioResponse represents a response structure for audio API.
|
||||
@@ -94,6 +98,47 @@ func audioMultipartForm(request AudioRequest, w *multipart.Writer) error {
|
||||
if _, err = io.Copy(fw, modelName); err != nil {
|
||||
return fmt.Errorf("writing model name: %w", err)
|
||||
}
|
||||
|
||||
// Create a form field for the prompt (if provided)
|
||||
if request.Prompt != "" {
|
||||
fw, err = w.CreateFormField("prompt")
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating form field: %w", err)
|
||||
}
|
||||
|
||||
prompt := bytes.NewReader([]byte(request.Prompt))
|
||||
if _, err = io.Copy(fw, prompt); err != nil {
|
||||
return fmt.Errorf("writing prompt: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a form field for the temperature (if provided)
|
||||
if request.Temperature != 0 {
|
||||
fw, err = w.CreateFormField("temperature")
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating form field: %w", err)
|
||||
}
|
||||
|
||||
temperature := bytes.NewReader([]byte(fmt.Sprintf("%.2f", request.Temperature)))
|
||||
if _, err = io.Copy(fw, temperature); err != nil {
|
||||
return fmt.Errorf("writing temperature: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a form field for the language (if provided)
|
||||
if request.Language != "" {
|
||||
fw, err = w.CreateFormField("language")
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating form field: %w", err)
|
||||
}
|
||||
|
||||
language := bytes.NewReader([]byte(request.Language))
|
||||
if _, err = io.Copy(fw, language); err != nil {
|
||||
return fmt.Errorf("writing language: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Close the multipart writer
|
||||
w.Close()
|
||||
|
||||
return nil
|
||||
|
||||
@@ -69,6 +69,59 @@ func TestAudio(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAudioWithOptionalArgs(t *testing.T) {
|
||||
server := test.NewTestServer()
|
||||
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
|
||||
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
|
||||
// create the test server
|
||||
var err error
|
||||
ts := server.OpenAITestServer()
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
config := DefaultConfig(test.GetTestToken())
|
||||
config.BaseURL = ts.URL + "/v1"
|
||||
client := NewClientWithConfig(config)
|
||||
|
||||
testcases := []struct {
|
||||
name string
|
||||
createFn func(context.Context, AudioRequest) (AudioResponse, error)
|
||||
}{
|
||||
{
|
||||
"transcribe",
|
||||
client.CreateTranscription,
|
||||
},
|
||||
{
|
||||
"translate",
|
||||
client.CreateTranslation,
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
dir, cleanup := createTestDirectory(t)
|
||||
defer cleanup()
|
||||
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
path := filepath.Join(dir, "fake.mp3")
|
||||
createTestFile(t, path)
|
||||
|
||||
req := AudioRequest{
|
||||
FilePath: path,
|
||||
Model: "whisper-3",
|
||||
Prompt: "用简体中文",
|
||||
Temperature: 0.5,
|
||||
Language: "zh",
|
||||
}
|
||||
_, err = tc.createFn(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("audio API error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// createTestFile creates a fake file with "hello" as the content.
|
||||
func createTestFile(t *testing.T, path string) {
|
||||
file, err := os.Create(path)
|
||||
|
||||
Reference in New Issue
Block a user