diff --git a/client.go b/client.go index 056226c..8bbbb87 100644 --- a/client.go +++ b/client.go @@ -193,10 +193,14 @@ func decodeResponse(body io.Reader, v any) error { return nil } - if result, ok := v.(*string); ok { - return decodeString(body, result) + switch o := v.(type) { + case *string: + return decodeString(body, o) + case *audioTextResponse: + return decodeString(body, &o.Text) + default: + return json.NewDecoder(body).Decode(v) } - return json.NewDecoder(body).Decode(v) } func decodeString(body io.Reader, output *string) error { diff --git a/client_test.go b/client_test.go index 664f9fb..bc5133e 100644 --- a/client_test.go +++ b/client_test.go @@ -7,9 +7,11 @@ import ( "fmt" "io" "net/http" + "reflect" "testing" "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" ) var errTestRequestBuilderFailed = errors.New("test request builder failed") @@ -43,23 +45,29 @@ func TestDecodeResponse(t *testing.T) { testCases := []struct { name string value interface{} + expected interface{} body io.Reader hasError bool }{ { - name: "nil input", - value: nil, - body: bytes.NewReader([]byte("")), + name: "nil input", + value: nil, + body: bytes.NewReader([]byte("")), + expected: nil, }, { - name: "string input", - value: &stringInput, - body: bytes.NewReader([]byte("test")), + name: "string input", + value: &stringInput, + body: bytes.NewReader([]byte("test")), + expected: "test", }, { name: "map input", value: &map[string]interface{}{}, body: bytes.NewReader([]byte(`{"test": "test"}`)), + expected: map[string]interface{}{ + "test": "test", + }, }, { name: "reader return error", @@ -67,14 +75,38 @@ func TestDecodeResponse(t *testing.T) { body: &errorReader{err: errors.New("dummy")}, hasError: true, }, + { + name: "audio text input", + value: &audioTextResponse{}, + body: bytes.NewReader([]byte("test")), + expected: audioTextResponse{ + Text: "test", + }, + }, + } + + assertEqual := func(t *testing.T, expected, actual interface{}) { + t.Helper() + if expected == actual { + return + } + v := reflect.ValueOf(actual).Elem().Interface() + if !reflect.DeepEqual(v, expected) { + t.Fatalf("Unexpected value: %v, expected: %v", v, expected) + } } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { err := decodeResponse(tc.body, tc.value) - if (err != nil) != tc.hasError { - t.Errorf("Unexpected error: %v", err) + if tc.hasError { + checks.HasError(t, err, "Unexpected nil error") + return } + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + assertEqual(t, tc.expected, tc.value) }) } }