add optional params for audio api, e.g. prompt (#183)
* Compatible with the situation where the mask is empty in CreateEditImage. * Fix the test for the unnecessary removal of the mask.png file. * add image variation implementation * fix image variation bugs * fix ci-lint problem with max line character limit * add offitial doc link * just for codeball test * fix lint problem * add optional params for audio api, e.g. prompt * add comment for new args in translation
This commit is contained in:
49
audio.go
49
audio.go
@@ -16,9 +16,13 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// AudioRequest represents a request structure for audio API.
|
// AudioRequest represents a request structure for audio API.
|
||||||
|
// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
|
||||||
type AudioRequest struct {
|
type AudioRequest struct {
|
||||||
Model string
|
Model string
|
||||||
FilePath string
|
FilePath string
|
||||||
|
Prompt string // For translation, it should be in English
|
||||||
|
Temperature float32
|
||||||
|
Language string // For translation, just do not use it. It seems "en" works, not confirmed...
|
||||||
}
|
}
|
||||||
|
|
||||||
// AudioResponse represents a response structure for audio API.
|
// AudioResponse represents a response structure for audio API.
|
||||||
@@ -94,6 +98,47 @@ func audioMultipartForm(request AudioRequest, w *multipart.Writer) error {
|
|||||||
if _, err = io.Copy(fw, modelName); err != nil {
|
if _, err = io.Copy(fw, modelName); err != nil {
|
||||||
return fmt.Errorf("writing model name: %w", err)
|
return fmt.Errorf("writing model name: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create a form field for the prompt (if provided)
|
||||||
|
if request.Prompt != "" {
|
||||||
|
fw, err = w.CreateFormField("prompt")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating form field: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt := bytes.NewReader([]byte(request.Prompt))
|
||||||
|
if _, err = io.Copy(fw, prompt); err != nil {
|
||||||
|
return fmt.Errorf("writing prompt: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a form field for the temperature (if provided)
|
||||||
|
if request.Temperature != 0 {
|
||||||
|
fw, err = w.CreateFormField("temperature")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating form field: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
temperature := bytes.NewReader([]byte(fmt.Sprintf("%.2f", request.Temperature)))
|
||||||
|
if _, err = io.Copy(fw, temperature); err != nil {
|
||||||
|
return fmt.Errorf("writing temperature: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a form field for the language (if provided)
|
||||||
|
if request.Language != "" {
|
||||||
|
fw, err = w.CreateFormField("language")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating form field: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
language := bytes.NewReader([]byte(request.Language))
|
||||||
|
if _, err = io.Copy(fw, language); err != nil {
|
||||||
|
return fmt.Errorf("writing language: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the multipart writer
|
||||||
w.Close()
|
w.Close()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -69,6 +69,59 @@ func TestAudio(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAudioWithOptionalArgs(t *testing.T) {
|
||||||
|
server := test.NewTestServer()
|
||||||
|
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
|
||||||
|
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
|
||||||
|
// create the test server
|
||||||
|
var err error
|
||||||
|
ts := server.OpenAITestServer()
|
||||||
|
ts.Start()
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
|
|
||||||
|
testcases := []struct {
|
||||||
|
name string
|
||||||
|
createFn func(context.Context, AudioRequest) (AudioResponse, error)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"transcribe",
|
||||||
|
client.CreateTranscription,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"translate",
|
||||||
|
client.CreateTranslation,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
dir, cleanup := createTestDirectory(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
for _, tc := range testcases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
path := filepath.Join(dir, "fake.mp3")
|
||||||
|
createTestFile(t, path)
|
||||||
|
|
||||||
|
req := AudioRequest{
|
||||||
|
FilePath: path,
|
||||||
|
Model: "whisper-3",
|
||||||
|
Prompt: "用简体中文",
|
||||||
|
Temperature: 0.5,
|
||||||
|
Language: "zh",
|
||||||
|
}
|
||||||
|
_, err = tc.createFn(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("audio API error: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// createTestFile creates a fake file with "hello" as the content.
|
// createTestFile creates a fake file with "hello" as the content.
|
||||||
func createTestFile(t *testing.T, path string) {
|
func createTestFile(t *testing.T, path string) {
|
||||||
file, err := os.Create(path)
|
file, err := os.Create(path)
|
||||||
|
|||||||
Reference in New Issue
Block a user