Adds support for audio captioning with Whisper (#267)
* Add speech to text example in docs * Add caption formats for audio transcription * Add caption example to README * Address sanity check errors * Add tests for decodeResponse * Use typechecker for audio response format * Decoding response refactors
This commit is contained in:
41
README.md
41
README.md
@@ -223,6 +223,47 @@ func main() {
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Audio Captions</summary>
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
c := openai.NewClient(os.Getenv("OPENAI_KEY"))
|
||||
|
||||
req := openai.AudioRequest{
|
||||
Model: openai.Whisper1,
|
||||
FilePath: os.Args[1],
|
||||
Format: openai.AudioResponseFormatSRT,
|
||||
}
|
||||
resp, err := c.CreateTranscription(context.Background(), req)
|
||||
if err != nil {
|
||||
fmt.Printf("Transcription error: %v\n", err)
|
||||
return
|
||||
}
|
||||
f, err := os.Create(os.Args[1] + ".srt")
|
||||
if err != nil {
|
||||
fmt.Printf("Could not open file: %v\n", err)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
if _, err := f.WriteString(resp.Text); err != nil {
|
||||
fmt.Printf("Error writing to file: %v\n", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>DALL-E 2 image generation</summary>
|
||||
|
||||
|
||||
29
audio.go
29
audio.go
@@ -13,6 +13,15 @@ const (
|
||||
Whisper1 = "whisper-1"
|
||||
)
|
||||
|
||||
// AudioResponseFormat is the response format requested from the audio API; Whisper uses AudioResponseFormatJSON by default.
|
||||
type AudioResponseFormat string
|
||||
|
||||
const (
|
||||
AudioResponseFormatJSON AudioResponseFormat = "json"
|
||||
AudioResponseFormatSRT AudioResponseFormat = "srt"
|
||||
AudioResponseFormatVTT AudioResponseFormat = "vtt"
|
||||
)
|
||||
|
||||
// AudioRequest represents a request structure for audio API.
|
||||
// The response format is selected with the Format field; an empty Format yields the default JSON response.
|
||||
type AudioRequest struct {
|
||||
@@ -21,6 +30,7 @@ type AudioRequest struct {
|
||||
Prompt string // For translation, it should be in English
|
||||
Temperature float32
|
||||
Language string // For translation, just do not use it. It seems "en" works, not confirmed...
|
||||
Format AudioResponseFormat
|
||||
}
|
||||
|
||||
// AudioResponse represents a response structure for audio API.
|
||||
@@ -66,10 +76,19 @@ func (c *Client) callAudioAPI(
|
||||
}
|
||||
req.Header.Add("Content-Type", builder.formDataContentType())
|
||||
|
||||
err = c.sendRequest(req, &response)
|
||||
if request.HasJSONResponse() {
|
||||
err = c.sendRequest(req, &response)
|
||||
} else {
|
||||
err = c.sendRequest(req, &response.Text)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// HasJSONResponse returns true if the response format is JSON.
|
||||
func (r AudioRequest) HasJSONResponse() bool {
|
||||
return r.Format == "" || r.Format == AudioResponseFormatJSON
|
||||
}
|
||||
|
||||
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
|
||||
// audio processing.
|
||||
func audioMultipartForm(request AudioRequest, b formBuilder) error {
|
||||
@@ -97,6 +116,14 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Create a form field for the format (if provided)
|
||||
if request.Format != "" {
|
||||
err = b.writeField("response_format", string(request.Format))
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing format: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a form field for the temperature (if provided)
|
||||
if request.Temperature != 0 {
|
||||
err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature))
|
||||
|
||||
@@ -112,6 +112,7 @@ func TestAudioWithOptionalArgs(t *testing.T) {
|
||||
Prompt: "用简体中文",
|
||||
Temperature: 0.5,
|
||||
Language: "zh",
|
||||
Format: AudioResponseFormatSRT,
|
||||
}
|
||||
_, err = tc.createFn(ctx, req)
|
||||
checks.NoError(t, err, "audio API error")
|
||||
@@ -179,6 +180,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
|
||||
Prompt: "test",
|
||||
Temperature: 0.5,
|
||||
Language: "en",
|
||||
Format: AudioResponseFormatSRT,
|
||||
}
|
||||
|
||||
mockFailedErr := fmt.Errorf("mock form builder fail")
|
||||
@@ -202,7 +204,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
|
||||
failOn := []string{"model", "prompt", "temperature", "language"}
|
||||
failOn := []string{"model", "prompt", "temperature", "language", "response_format"}
|
||||
for _, failingField := range failOn {
|
||||
failForField = failingField
|
||||
mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField)
|
||||
|
||||
24
client.go
24
client.go
@@ -43,7 +43,7 @@ func NewOrgClient(authToken, org string) *Client {
|
||||
return NewClientWithConfig(config)
|
||||
}
|
||||
|
||||
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
|
||||
func (c *Client) sendRequest(req *http.Request, v any) error {
|
||||
req.Header.Set("Accept", "application/json; charset=utf-8")
|
||||
// Azure API Key authentication
|
||||
if c.config.APIType == APITypeAzure {
|
||||
@@ -75,12 +75,26 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
|
||||
return c.handleErrorResp(res)
|
||||
}
|
||||
|
||||
if v != nil {
|
||||
if err = json.NewDecoder(res.Body).Decode(v); err != nil {
|
||||
return err
|
||||
}
|
||||
return decodeResponse(res.Body, v)
|
||||
}
|
||||
|
||||
func decodeResponse(body io.Reader, v any) error {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if result, ok := v.(*string); ok {
|
||||
return decodeString(body, result)
|
||||
}
|
||||
return json.NewDecoder(body).Decode(v)
|
||||
}
|
||||
|
||||
// decodeString reads body to the end and stores its contents in *output.
// On a read error it returns that error and leaves *output untouched.
func decodeString(body io.Reader, output *string) error {
	raw, readErr := io.ReadAll(body)
	if readErr == nil {
		*output = string(raw)
	}
	return readErr
}
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package openai //nolint:testpackage // testing private field
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -20,3 +22,38 @@ func TestClient(t *testing.T) {
|
||||
t.Errorf("Client does not contain proper orgID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeResponse(t *testing.T) {
|
||||
stringInput := ""
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
value interface{}
|
||||
body io.Reader
|
||||
}{
|
||||
{
|
||||
name: "nil input",
|
||||
value: nil,
|
||||
body: bytes.NewReader([]byte("")),
|
||||
},
|
||||
{
|
||||
name: "string input",
|
||||
value: &stringInput,
|
||||
body: bytes.NewReader([]byte("test")),
|
||||
},
|
||||
{
|
||||
name: "map input",
|
||||
value: &map[string]interface{}{},
|
||||
body: bytes.NewReader([]byte(`{"test": "test"}`)),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := decodeResponse(tc.body, tc.value)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user