Adds support for audio captioning with Whisper (#267)

* Add speech to text example in docs

* Add caption formats for audio transcription

* Add caption example to README

* Address sanity check errors

* Add tests for decodeResponse

* Use typechecker for audio response format

* Decoding response refactors
This commit is contained in:
Hoani Bryson
2023-04-21 01:07:04 +12:00
committed by GitHub
parent d6ab1b3a4f
commit ecdea45b67
5 changed files with 129 additions and 8 deletions

View File

@@ -223,6 +223,47 @@ func main() {
``` ```
</details> </details>
<details>
<summary>Audio Captions</summary>
```go
package main
import (
"context"
"fmt"
"os"
openai "github.com/sashabaranov/go-openai"
)
func main() {
c := openai.NewClient(os.Getenv("OPENAI_KEY"))
req := openai.AudioRequest{
Model: openai.Whisper1,
FilePath: os.Args[1],
Format: openai.AudioResponseFormatSRT,
}
resp, err := c.CreateTranscription(context.Background(), req)
if err != nil {
fmt.Printf("Transcription error: %v\n", err)
return
}
f, err := os.Create(os.Args[1] + ".srt")
if err != nil {
fmt.Printf("Could not open file: %v\n", err)
return
}
defer f.Close()
if _, err := f.WriteString(resp.Text); err != nil {
fmt.Printf("Error writing to file: %v\n", err)
return
}
}
```
</details>
<details> <details>
<summary>DALL-E 2 image generation</summary> <summary>DALL-E 2 image generation</summary>

View File

@@ -13,6 +13,15 @@ const (
Whisper1 = "whisper-1" Whisper1 = "whisper-1"
) )
// Response formats; Whisper uses AudioResponseFormatJSON by default.
type AudioResponseFormat string
const (
AudioResponseFormatJSON AudioResponseFormat = "json"
AudioResponseFormatSRT AudioResponseFormat = "srt"
AudioResponseFormatVTT AudioResponseFormat = "vtt"
)
// AudioRequest represents a request structure for audio API. // AudioRequest represents a request structure for audio API.
// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient. // ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
type AudioRequest struct { type AudioRequest struct {
@@ -21,6 +30,7 @@ type AudioRequest struct {
Prompt string // For translation, it should be in English Prompt string // For translation, it should be in English
Temperature float32 Temperature float32
Language string // For translation, just do not use it. It seems "en" works, not confirmed... Language string // For translation, just do not use it. It seems "en" works, not confirmed...
Format AudioResponseFormat
} }
// AudioResponse represents a response structure for audio API. // AudioResponse represents a response structure for audio API.
@@ -66,10 +76,19 @@ func (c *Client) callAudioAPI(
} }
req.Header.Add("Content-Type", builder.formDataContentType()) req.Header.Add("Content-Type", builder.formDataContentType())
err = c.sendRequest(req, &response) if request.HasJSONResponse() {
err = c.sendRequest(req, &response)
} else {
err = c.sendRequest(req, &response.Text)
}
return return
} }
// HasJSONResponse returns true if the response format is JSON.
func (r AudioRequest) HasJSONResponse() bool {
return r.Format == "" || r.Format == AudioResponseFormatJSON
}
// audioMultipartForm creates a form with audio file contents and the name of the model to use for // audioMultipartForm creates a form with audio file contents and the name of the model to use for
// audio processing. // audio processing.
func audioMultipartForm(request AudioRequest, b formBuilder) error { func audioMultipartForm(request AudioRequest, b formBuilder) error {
@@ -97,6 +116,14 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error {
} }
} }
// Create a form field for the format (if provided)
if request.Format != "" {
err = b.writeField("response_format", string(request.Format))
if err != nil {
return fmt.Errorf("writing format: %w", err)
}
}
// Create a form field for the temperature (if provided) // Create a form field for the temperature (if provided)
if request.Temperature != 0 { if request.Temperature != 0 {
err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature)) err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature))

View File

@@ -112,6 +112,7 @@ func TestAudioWithOptionalArgs(t *testing.T) {
Prompt: "用简体中文", Prompt: "用简体中文",
Temperature: 0.5, Temperature: 0.5,
Language: "zh", Language: "zh",
Format: AudioResponseFormatSRT,
} }
_, err = tc.createFn(ctx, req) _, err = tc.createFn(ctx, req)
checks.NoError(t, err, "audio API error") checks.NoError(t, err, "audio API error")
@@ -179,6 +180,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
Prompt: "test", Prompt: "test",
Temperature: 0.5, Temperature: 0.5,
Language: "en", Language: "en",
Format: AudioResponseFormatSRT,
} }
mockFailedErr := fmt.Errorf("mock form builder fail") mockFailedErr := fmt.Errorf("mock form builder fail")
@@ -202,7 +204,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
return nil return nil
} }
failOn := []string{"model", "prompt", "temperature", "language"} failOn := []string{"model", "prompt", "temperature", "language", "response_format"}
for _, failingField := range failOn { for _, failingField := range failOn {
failForField = failingField failForField = failingField
mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField) mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField)

View File

@@ -43,7 +43,7 @@ func NewOrgClient(authToken, org string) *Client {
return NewClientWithConfig(config) return NewClientWithConfig(config)
} }
func (c *Client) sendRequest(req *http.Request, v interface{}) error { func (c *Client) sendRequest(req *http.Request, v any) error {
req.Header.Set("Accept", "application/json; charset=utf-8") req.Header.Set("Accept", "application/json; charset=utf-8")
// Azure API Key authentication // Azure API Key authentication
if c.config.APIType == APITypeAzure { if c.config.APIType == APITypeAzure {
@@ -75,12 +75,26 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
return c.handleErrorResp(res) return c.handleErrorResp(res)
} }
if v != nil { return decodeResponse(res.Body, v)
if err = json.NewDecoder(res.Body).Decode(v); err != nil { }
return err
} func decodeResponse(body io.Reader, v any) error {
if v == nil {
return nil
} }
if result, ok := v.(*string); ok {
return decodeString(body, result)
}
return json.NewDecoder(body).Decode(v)
}
func decodeString(body io.Reader, output *string) error {
b, err := io.ReadAll(body)
if err != nil {
return err
}
*output = string(b)
return nil return nil
} }

View File

@@ -1,6 +1,8 @@
package openai //nolint:testpackage // testing private field package openai //nolint:testpackage // testing private field
import ( import (
"bytes"
"io"
"testing" "testing"
) )
@@ -20,3 +22,38 @@ func TestClient(t *testing.T) {
t.Errorf("Client does not contain proper orgID") t.Errorf("Client does not contain proper orgID")
} }
} }
func TestDecodeResponse(t *testing.T) {
stringInput := ""
testCases := []struct {
name string
value interface{}
body io.Reader
}{
{
name: "nil input",
value: nil,
body: bytes.NewReader([]byte("")),
},
{
name: "string input",
value: &stringInput,
body: bytes.NewReader([]byte("test")),
},
{
name: "map input",
value: &map[string]interface{}{},
body: bytes.NewReader([]byte(`{"test": "test"}`)),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := decodeResponse(tc.body, tc.value)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
})
}
}