Adds support for audio captioning with Whisper (#267)
* Add speech to text example in docs * Add caption formats for audio transcription * Add caption example to README * Address sanity check errors * Add tests for decodeResponse * Use typechecker for audio response format * Decoding response refactors
This commit is contained in:
43
README.md
43
README.md
@@ -223,6 +223,47 @@ func main() {
|
|||||||
```
|
```
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Audio Captions</summary>
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
openai "github.com/sashabaranov/go-openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
c := openai.NewClient(os.Getenv("OPENAI_KEY"))
|
||||||
|
|
||||||
|
req := openai.AudioRequest{
|
||||||
|
Model: openai.Whisper1,
|
||||||
|
FilePath: os.Args[1],
|
||||||
|
Format: openai.AudioResponseFormatSRT,
|
||||||
|
}
|
||||||
|
resp, err := c.CreateTranscription(context.Background(), req)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Transcription error: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f, err := os.Create(os.Args[1] + ".srt")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Could not open file: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
if _, err := f.WriteString(resp.Text); err != nil {
|
||||||
|
fmt.Printf("Error writing to file: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>DALL-E 2 image generation</summary>
|
<summary>DALL-E 2 image generation</summary>
|
||||||
|
|
||||||
@@ -420,4 +461,4 @@ func main() {
|
|||||||
fmt.Println(resp.Choices[0].Message.Content)
|
fmt.Println(resp.Choices[0].Message.Content)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
</details>
|
</details>
|
||||||
|
|||||||
29
audio.go
29
audio.go
@@ -13,6 +13,15 @@ const (
|
|||||||
Whisper1 = "whisper-1"
|
Whisper1 = "whisper-1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Response formats; Whisper uses AudioResponseFormatJSON by default.
|
||||||
|
type AudioResponseFormat string
|
||||||
|
|
||||||
|
const (
|
||||||
|
AudioResponseFormatJSON AudioResponseFormat = "json"
|
||||||
|
AudioResponseFormatSRT AudioResponseFormat = "srt"
|
||||||
|
AudioResponseFormatVTT AudioResponseFormat = "vtt"
|
||||||
|
)
|
||||||
|
|
||||||
// AudioRequest represents a request structure for audio API.
|
// AudioRequest represents a request structure for audio API.
|
||||||
// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
|
// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
|
||||||
type AudioRequest struct {
|
type AudioRequest struct {
|
||||||
@@ -21,6 +30,7 @@ type AudioRequest struct {
|
|||||||
Prompt string // For translation, it should be in English
|
Prompt string // For translation, it should be in English
|
||||||
Temperature float32
|
Temperature float32
|
||||||
Language string // For translation, just do not use it. It seems "en" works, not confirmed...
|
Language string // For translation, just do not use it. It seems "en" works, not confirmed...
|
||||||
|
Format AudioResponseFormat
|
||||||
}
|
}
|
||||||
|
|
||||||
// AudioResponse represents a response structure for audio API.
|
// AudioResponse represents a response structure for audio API.
|
||||||
@@ -66,10 +76,19 @@ func (c *Client) callAudioAPI(
|
|||||||
}
|
}
|
||||||
req.Header.Add("Content-Type", builder.formDataContentType())
|
req.Header.Add("Content-Type", builder.formDataContentType())
|
||||||
|
|
||||||
err = c.sendRequest(req, &response)
|
if request.HasJSONResponse() {
|
||||||
|
err = c.sendRequest(req, &response)
|
||||||
|
} else {
|
||||||
|
err = c.sendRequest(req, &response.Text)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasJSONResponse returns true if the response format is JSON.
|
||||||
|
func (r AudioRequest) HasJSONResponse() bool {
|
||||||
|
return r.Format == "" || r.Format == AudioResponseFormatJSON
|
||||||
|
}
|
||||||
|
|
||||||
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
|
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
|
||||||
// audio processing.
|
// audio processing.
|
||||||
func audioMultipartForm(request AudioRequest, b formBuilder) error {
|
func audioMultipartForm(request AudioRequest, b formBuilder) error {
|
||||||
@@ -97,6 +116,14 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create a form field for the format (if provided)
|
||||||
|
if request.Format != "" {
|
||||||
|
err = b.writeField("response_format", string(request.Format))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writing format: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create a form field for the temperature (if provided)
|
// Create a form field for the temperature (if provided)
|
||||||
if request.Temperature != 0 {
|
if request.Temperature != 0 {
|
||||||
err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature))
|
err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature))
|
||||||
|
|||||||
@@ -112,6 +112,7 @@ func TestAudioWithOptionalArgs(t *testing.T) {
|
|||||||
Prompt: "用简体中文",
|
Prompt: "用简体中文",
|
||||||
Temperature: 0.5,
|
Temperature: 0.5,
|
||||||
Language: "zh",
|
Language: "zh",
|
||||||
|
Format: AudioResponseFormatSRT,
|
||||||
}
|
}
|
||||||
_, err = tc.createFn(ctx, req)
|
_, err = tc.createFn(ctx, req)
|
||||||
checks.NoError(t, err, "audio API error")
|
checks.NoError(t, err, "audio API error")
|
||||||
@@ -179,6 +180,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
|
|||||||
Prompt: "test",
|
Prompt: "test",
|
||||||
Temperature: 0.5,
|
Temperature: 0.5,
|
||||||
Language: "en",
|
Language: "en",
|
||||||
|
Format: AudioResponseFormatSRT,
|
||||||
}
|
}
|
||||||
|
|
||||||
mockFailedErr := fmt.Errorf("mock form builder fail")
|
mockFailedErr := fmt.Errorf("mock form builder fail")
|
||||||
@@ -202,7 +204,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
failOn := []string{"model", "prompt", "temperature", "language"}
|
failOn := []string{"model", "prompt", "temperature", "language", "response_format"}
|
||||||
for _, failingField := range failOn {
|
for _, failingField := range failOn {
|
||||||
failForField = failingField
|
failForField = failingField
|
||||||
mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField)
|
mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField)
|
||||||
|
|||||||
24
client.go
24
client.go
@@ -43,7 +43,7 @@ func NewOrgClient(authToken, org string) *Client {
|
|||||||
return NewClientWithConfig(config)
|
return NewClientWithConfig(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
|
func (c *Client) sendRequest(req *http.Request, v any) error {
|
||||||
req.Header.Set("Accept", "application/json; charset=utf-8")
|
req.Header.Set("Accept", "application/json; charset=utf-8")
|
||||||
// Azure API Key authentication
|
// Azure API Key authentication
|
||||||
if c.config.APIType == APITypeAzure {
|
if c.config.APIType == APITypeAzure {
|
||||||
@@ -75,12 +75,26 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
|
|||||||
return c.handleErrorResp(res)
|
return c.handleErrorResp(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
if v != nil {
|
return decodeResponse(res.Body, v)
|
||||||
if err = json.NewDecoder(res.Body).Decode(v); err != nil {
|
}
|
||||||
return err
|
|
||||||
}
|
func decodeResponse(body io.Reader, v any) error {
|
||||||
|
if v == nil {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if result, ok := v.(*string); ok {
|
||||||
|
return decodeString(body, result)
|
||||||
|
}
|
||||||
|
return json.NewDecoder(body).Decode(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeString(body io.Reader, output *string) error {
|
||||||
|
b, err := io.ReadAll(body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*output = string(b)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package openai //nolint:testpackage // testing private field
|
package openai //nolint:testpackage // testing private field
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,3 +22,38 @@ func TestClient(t *testing.T) {
|
|||||||
t.Errorf("Client does not contain proper orgID")
|
t.Errorf("Client does not contain proper orgID")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDecodeResponse(t *testing.T) {
|
||||||
|
stringInput := ""
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
value interface{}
|
||||||
|
body io.Reader
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil input",
|
||||||
|
value: nil,
|
||||||
|
body: bytes.NewReader([]byte("")),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string input",
|
||||||
|
value: &stringInput,
|
||||||
|
body: bytes.NewReader([]byte("test")),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "map input",
|
||||||
|
value: &map[string]interface{}{},
|
||||||
|
body: bytes.NewReader([]byte(`{"test": "test"}`)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err := decodeResponse(tc.body, tc.value)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user