diff --git a/README.md b/README.md index 7526ea3..f7e6990 100644 --- a/README.md +++ b/README.md @@ -10,13 +10,13 @@ This library provides Go clients for [OpenAI API](https://platform.openai.com/). * DALLĀ·E 2 * Whisper -Installation: +### Installation: ``` go get github.com/sashabaranov/go-openai ``` -ChatGPT example usage: +### ChatGPT example usage: ```go package main @@ -52,9 +52,7 @@ func main() { ``` - - -Other examples: +### Other examples:
ChatGPT streaming completion @@ -462,3 +460,29 @@ func main() { } ```
+ +
+Error handling + +Open-AI maintains clear documentation on how to [handle API errors](https://platform.openai.com/docs/guides/error-codes/api-errors) + +example: +``` +e := &openai.APIError{} +if errors.As(err, &e) { + switch e.HTTPStatusCode { + case 401: + // invalid auth or key (do not retry) + case 429: + // rate limiting or engine overload (wait and retry) + case 500: + // openai server error (retry) + default: + // unhandled + } +} + +``` +
+ + diff --git a/client.go b/client.go index b3d7595..0f8aa41 100644 --- a/client.go +++ b/client.go @@ -149,15 +149,16 @@ func (c *Client) handleErrorResp(resp *http.Response) error { var errRes ErrorResponse err := json.NewDecoder(resp.Body).Decode(&errRes) if err != nil || errRes.Error == nil { - reqErr := RequestError{ + reqErr := &RequestError{ HTTPStatusCode: resp.StatusCode, Err: err, } if errRes.Error != nil { reqErr.Err = errRes.Error } - return fmt.Errorf("error, %w", &reqErr) + return reqErr } + errRes.Error.HTTPStatusCode = resp.StatusCode - return fmt.Errorf("error, status code: %d, message: %w", resp.StatusCode, errRes.Error) + return errRes.Error } diff --git a/client_test.go b/client_test.go index 7ef6284..e30fa39 100644 --- a/client_test.go +++ b/client_test.go @@ -2,6 +2,7 @@ package openai //nolint:testpackage // testing private field import ( "bytes" + "errors" "fmt" "io" "net/http" @@ -106,7 +107,7 @@ func TestHandleErrorResp(t *testing.T) { } }`, )), - expected: "error, status code 401, message: Access denied due to Virtual Network/Firewall rules.", + expected: "error, status code: 401, message: Access denied due to Virtual Network/Firewall rules.", }, { name: "503 Model Overloaded", @@ -135,6 +136,12 @@ func TestHandleErrorResp(t *testing.T) { t.Errorf("Unexpected error: %v , expected: %s", err, tc.expected) t.Fail() } + + e := &APIError{} + if !errors.As(err, &e) { + t.Errorf("(%s) Expected error to be of type APIError", tc.name) + t.Fail() + } }) } } diff --git a/error.go b/error.go index 86b75f4..6354f43 100644 --- a/error.go +++ b/error.go @@ -25,6 +25,10 @@ type ErrorResponse struct { } func (e *APIError) Error() string { + if e.HTTPStatusCode > 0 { + return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Message) + } + return e.Message } @@ -70,7 +74,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { } func (e *RequestError) Error() string { - return fmt.Sprintf("status code %d, message: %s", e.HTTPStatusCode, e.Err) + return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Err) } func (e *RequestError) Unwrap() error {