fix(audio): fix audioTextResponse decode (#638)
* fix(audio): fix audioTextResponse decode * test(audio): add audioTextResponse decode test * test(audio): simplify code
This commit is contained in:
10
client.go
10
client.go
@@ -193,11 +193,15 @@ func decodeResponse(body io.Reader, v any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if result, ok := v.(*string); ok {
|
||||
return decodeString(body, result)
|
||||
}
|
||||
switch o := v.(type) {
|
||||
case *string:
|
||||
return decodeString(body, o)
|
||||
case *audioTextResponse:
|
||||
return decodeString(body, &o.Text)
|
||||
default:
|
||||
return json.NewDecoder(body).Decode(v)
|
||||
}
|
||||
}
|
||||
|
||||
func decodeString(body io.Reader, output *string) error {
|
||||
b, err := io.ReadAll(body)
|
||||
|
||||
@@ -7,9 +7,11 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
var errTestRequestBuilderFailed = errors.New("test request builder failed")
|
||||
@@ -43,6 +45,7 @@ func TestDecodeResponse(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
value interface{}
|
||||
expected interface{}
|
||||
body io.Reader
|
||||
hasError bool
|
||||
}{
|
||||
@@ -50,16 +53,21 @@ func TestDecodeResponse(t *testing.T) {
|
||||
name: "nil input",
|
||||
value: nil,
|
||||
body: bytes.NewReader([]byte("")),
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "string input",
|
||||
value: &stringInput,
|
||||
body: bytes.NewReader([]byte("test")),
|
||||
expected: "test",
|
||||
},
|
||||
{
|
||||
name: "map input",
|
||||
value: &map[string]interface{}{},
|
||||
body: bytes.NewReader([]byte(`{"test": "test"}`)),
|
||||
expected: map[string]interface{}{
|
||||
"test": "test",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "reader return error",
|
||||
@@ -67,14 +75,38 @@ func TestDecodeResponse(t *testing.T) {
|
||||
body: &errorReader{err: errors.New("dummy")},
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "audio text input",
|
||||
value: &audioTextResponse{},
|
||||
body: bytes.NewReader([]byte("test")),
|
||||
expected: audioTextResponse{
|
||||
Text: "test",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assertEqual := func(t *testing.T, expected, actual interface{}) {
|
||||
t.Helper()
|
||||
if expected == actual {
|
||||
return
|
||||
}
|
||||
v := reflect.ValueOf(actual).Elem().Interface()
|
||||
if !reflect.DeepEqual(v, expected) {
|
||||
t.Fatalf("Unexpected value: %v, expected: %v", v, expected)
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := decodeResponse(tc.body, tc.value)
|
||||
if (err != nil) != tc.hasError {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
if tc.hasError {
|
||||
checks.HasError(t, err, "Unexpected nil error")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
assertEqual(t, tc.expected, tc.value)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user