fix(audio): fix audioTextResponse decode (#638)

* fix(audio): fix audioTextResponse decode

* test(audio): add audioTextResponse decode test

* test(audio): simplify code
This commit is contained in:
Qiying Wang
2024-01-18 01:42:07 +08:00
committed by GitHub
parent 4ce03a919a
commit eff8dc1118
2 changed files with 47 additions and 11 deletions

View File

@@ -193,11 +193,15 @@ func decodeResponse(body io.Reader, v any) error {
return nil
}
if result, ok := v.(*string); ok {
return decodeString(body, result)
}
switch o := v.(type) {
case *string:
return decodeString(body, o)
case *audioTextResponse:
return decodeString(body, &o.Text)
default:
return json.NewDecoder(body).Decode(v)
}
}
func decodeString(body io.Reader, output *string) error {
b, err := io.ReadAll(body)

View File

@@ -7,9 +7,11 @@ import (
"fmt"
"io"
"net/http"
"reflect"
"testing"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
)
var errTestRequestBuilderFailed = errors.New("test request builder failed")
@@ -43,6 +45,7 @@ func TestDecodeResponse(t *testing.T) {
testCases := []struct {
name string
value interface{}
expected interface{}
body io.Reader
hasError bool
}{
@@ -50,16 +53,21 @@ func TestDecodeResponse(t *testing.T) {
name: "nil input",
value: nil,
body: bytes.NewReader([]byte("")),
expected: nil,
},
{
name: "string input",
value: &stringInput,
body: bytes.NewReader([]byte("test")),
expected: "test",
},
{
name: "map input",
value: &map[string]interface{}{},
body: bytes.NewReader([]byte(`{"test": "test"}`)),
expected: map[string]interface{}{
"test": "test",
},
},
{
name: "reader return error",
@@ -67,14 +75,38 @@ func TestDecodeResponse(t *testing.T) {
body: &errorReader{err: errors.New("dummy")},
hasError: true,
},
{
name: "audio text input",
value: &audioTextResponse{},
body: bytes.NewReader([]byte("test")),
expected: audioTextResponse{
Text: "test",
},
},
}
assertEqual := func(t *testing.T, expected, actual interface{}) {
t.Helper()
if expected == actual {
return
}
v := reflect.ValueOf(actual).Elem().Interface()
if !reflect.DeepEqual(v, expected) {
t.Fatalf("Unexpected value: %v, expected: %v", v, expected)
}
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := decodeResponse(tc.body, tc.value)
if (err != nil) != tc.hasError {
t.Errorf("Unexpected error: %v", err)
if tc.hasError {
checks.HasError(t, err, "Unexpected nil error")
return
}
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
assertEqual(t, tc.expected, tc.value)
})
}
}