diff --git a/client.go b/client.go index 368947b..500b3d5 100644 --- a/client.go +++ b/client.go @@ -148,6 +148,9 @@ func (c *Client) handleErrorResp(resp *http.Response) error { HTTPStatusCode: resp.StatusCode, Err: err, } + if errRes.Error != nil { + reqErr.Err = errRes.Error + } return fmt.Errorf("error, %w", &reqErr) } errRes.Error.HTTPStatusCode = resp.StatusCode diff --git a/client_test.go b/client_test.go index 7bea6dd..ca5145c 100644 --- a/client_test.go +++ b/client_test.go @@ -2,7 +2,9 @@ package openai //nolint:testpackage // testing private field import ( "bytes" + "fmt" "io" + "net/http" "testing" ) @@ -57,3 +59,81 @@ func TestDecodeResponse(t *testing.T) { }) } } + +func TestHandleErrorResp(t *testing.T) { + // var errRes *ErrorResponse + var errRes ErrorResponse + var reqErr RequestError + t.Log(errRes, errRes.Error) + if errRes.Error != nil { + reqErr.Err = errRes.Error + } + t.Log(fmt.Errorf("error, %w", &reqErr)) + t.Log(errRes.Error, "nil pointer check Pass") + + const mockToken = "mock token" + client := NewClient(mockToken) + + testCases := []struct { + name string + httpCode int + body io.Reader + expected string + }{ + { + name: "401 Invalid Authentication", + httpCode: http.StatusUnauthorized, + body: bytes.NewReader([]byte( + `{ + "error":{ + "message":"You didn't provide an API key. ....", + "type":"invalid_request_error", + "param":null, + "code":null + } + }`, + )), + expected: "error, status code: 401, message: You didn't provide an API key. ....", + }, + { + name: "401 Azure Access Denied", + httpCode: http.StatusUnauthorized, + body: bytes.NewReader([]byte( + `{ + "error":{ + "code":"AccessDenied", + "message":"Access denied due to Virtual Network/Firewall rules." + } + }`, + )), + expected: "error, Access denied due to Virtual Network/Firewall rules.", + }, + { + name: "503 Model Overloaded", + httpCode: http.StatusServiceUnavailable, + body: bytes.NewReader([]byte(` + { + "error":{ + "message":"That model...", + "type":"server_error", + "param":null, + "code":null + } + }`)), + expected: "error, status code: 503, message: That model...", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + testCase := &http.Response{} + testCase.StatusCode = tc.httpCode + testCase.Body = io.NopCloser(tc.body) + err := client.handleErrorResp(testCase) + if err.Error() != tc.expected { + t.Errorf("Unexpected error: %v , expected: %s", err, tc.expected) + t.Fail() + } + }) + } +}