diff --git a/client_test.go b/client_test.go index 24cb5ff..1c90845 100644 --- a/client_test.go +++ b/client_test.go @@ -358,6 +358,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"ListRunSteps", func() (any, error) { return client.ListRunSteps(ctx, "", "", Pagination{}) }}, + {"CreateSpeech", func() (any, error) { + return client.CreateSpeech(ctx, CreateSpeechRequest{Model: TTSModel1, Voice: VoiceAlloy}) + }}, } for _, testCase := range testCases { diff --git a/speech.go b/speech.go new file mode 100644 index 0000000..a3d5f5d --- /dev/null +++ b/speech.go @@ -0,0 +1,87 @@ +package openai + +import ( + "context" + "errors" + "io" + "net/http" +) + +type SpeechModel string + +const ( + TTSModel1 SpeechModel = "tts-1" + TTsModel1HD SpeechModel = "tts-1-hd" +) + +type SpeechVoice string + +const ( + VoiceAlloy SpeechVoice = "alloy" + VoiceEcho SpeechVoice = "echo" + VoiceFable SpeechVoice = "fable" + VoiceOnyx SpeechVoice = "onyx" + VoiceNova SpeechVoice = "nova" + VoiceShimmer SpeechVoice = "shimmer" +) + +type SpeechResponseFormat string + +const ( + SpeechResponseFormatMp3 SpeechResponseFormat = "mp3" + SpeechResponseFormatOpus SpeechResponseFormat = "opus" + SpeechResponseFormatAac SpeechResponseFormat = "aac" + SpeechResponseFormatFlac SpeechResponseFormat = "flac" +) + +var ( + ErrInvalidSpeechModel = errors.New("invalid speech model") + ErrInvalidVoice = errors.New("invalid voice") +) + +type CreateSpeechRequest struct { + Model SpeechModel `json:"model"` + Input string `json:"input"` + Voice SpeechVoice `json:"voice"` + ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3 + Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 +} + +func contains[T comparable](s []T, e T) bool { + for _, v := range s { + if v == e { + return true + } + } + return false +} + +func isValidSpeechModel(model SpeechModel) bool { + return contains([]SpeechModel{TTSModel1, TTsModel1HD}, model) +} + +func isValidVoice(voice SpeechVoice) bool { + return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice) +} + +func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response io.ReadCloser, err error) { + if !isValidSpeechModel(request.Model) { + err = ErrInvalidSpeechModel + return + } + if !isValidVoice(request.Voice) { + err = ErrInvalidVoice + return + } + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", request.Model), + withBody(request), + withContentType("application/json; charset=utf-8"), + ) + if err != nil { + return + } + + response, err = c.sendRequestRaw(req) + + return +} diff --git a/speech_test.go b/speech_test.go new file mode 100644 index 0000000..d9ba58b --- /dev/null +++ b/speech_test.go @@ -0,0 +1,115 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "mime" + "net/http" + "os" + "path/filepath" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestSpeechIntegration(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) { + dir, cleanup := test.CreateTestDirectory(t) + path := filepath.Join(dir, "fake.mp3") + test.CreateTestFile(t, path) + defer cleanup() + + // audio endpoints only accept POST requests + if r.Method != "POST" { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil { + http.Error(w, "failed to parse media type", http.StatusBadRequest) + return + } + + if mediaType != "application/json" { + http.Error(w, "request is not json", http.StatusBadRequest) + return + } + + // Parse the JSON body of the request + var params map[string]interface{} + err = json.NewDecoder(r.Body).Decode(¶ms) + if err != nil { + http.Error(w, "failed to parse request body", http.StatusBadRequest) + return + } + + // Check if each required field is present in the parsed JSON object + reqParams := []string{"model", "input", "voice"} + for _, param := range reqParams { + _, ok := params[param] + if !ok { + http.Error(w, fmt.Sprintf("no %s in params", param), http.StatusBadRequest) + return + } + } + + // read audio file content + audioFile, err := os.ReadFile(path) + if err != nil { + http.Error(w, "failed to read audio file", http.StatusInternalServerError) + return + } + + // write audio file content to response + w.Header().Set("Content-Type", "audio/mpeg") + w.Header().Set("Transfer-Encoding", "chunked") + w.Header().Set("Connection", "keep-alive") + _, err = w.Write(audioFile) + if err != nil { + http.Error(w, "failed to write body", http.StatusInternalServerError) + return + } + }) + + t.Run("happy path", func(t *testing.T) { + res, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ + Model: openai.TTSModel1, + Input: "Hello!", + Voice: openai.VoiceAlloy, + }) + checks.NoError(t, err, "CreateSpeech error") + defer res.Close() + + buf, err := io.ReadAll(res) + checks.NoError(t, err, "ReadAll error") + + // save buf to file as mp3 + err = os.WriteFile("test.mp3", buf, 0644) + checks.NoError(t, err, "Create error") + }) + t.Run("invalid model", func(t *testing.T) { + _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ + Model: "invalid_model", + Input: "Hello!", + Voice: openai.VoiceAlloy, + }) + checks.ErrorIs(t, err, openai.ErrInvalidSpeechModel, "CreateSpeech error") + }) + + t.Run("invalid voice", func(t *testing.T) { + _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ + Model: openai.TTSModel1, + Input: "Hello!", + Voice: "invalid_voice", + }) + checks.ErrorIs(t, err, openai.ErrInvalidVoice, "CreateSpeech error") + }) +}