diff --git a/README.md b/README.md
index 7526ea3..f7e6990 100644
--- a/README.md
+++ b/README.md
@@ -10,13 +10,13 @@ This library provides Go clients for [OpenAI API](https://platform.openai.com/).
* DALLĀ·E 2
* Whisper
-Installation:
+### Installation:
```
go get github.com/sashabaranov/go-openai
```
-ChatGPT example usage:
+### ChatGPT example usage:
```go
package main
@@ -52,9 +52,7 @@ func main() {
```
-
-
-Other examples:
+### Other examples:
ChatGPT streaming completion
@@ -462,3 +460,29 @@ func main() {
}
```
+
+
+Error handling
+
+Open-AI maintains clear documentation on how to [handle API errors](https://platform.openai.com/docs/guides/error-codes/api-errors)
+
+example:
+```
+e := &openai.APIError{}
+if errors.As(err, &e) {
+ switch e.HTTPStatusCode {
+ case 401:
+ // invalid auth or key (do not retry)
+ case 429:
+ // rate limiting or engine overload (wait and retry)
+ case 500:
+ // openai server error (retry)
+ default:
+ // unhandled
+ }
+}
+
+```
+
+
+
diff --git a/client.go b/client.go
index b3d7595..0f8aa41 100644
--- a/client.go
+++ b/client.go
@@ -149,15 +149,16 @@ func (c *Client) handleErrorResp(resp *http.Response) error {
var errRes ErrorResponse
err := json.NewDecoder(resp.Body).Decode(&errRes)
if err != nil || errRes.Error == nil {
- reqErr := RequestError{
+ reqErr := &RequestError{
HTTPStatusCode: resp.StatusCode,
Err: err,
}
if errRes.Error != nil {
reqErr.Err = errRes.Error
}
- return fmt.Errorf("error, %w", &reqErr)
+ return reqErr
}
+
errRes.Error.HTTPStatusCode = resp.StatusCode
- return fmt.Errorf("error, status code: %d, message: %w", resp.StatusCode, errRes.Error)
+ return errRes.Error
}
diff --git a/client_test.go b/client_test.go
index 7ef6284..e30fa39 100644
--- a/client_test.go
+++ b/client_test.go
@@ -2,6 +2,7 @@ package openai //nolint:testpackage // testing private field
import (
"bytes"
+ "errors"
"fmt"
"io"
"net/http"
@@ -106,7 +107,7 @@ func TestHandleErrorResp(t *testing.T) {
}
}`,
)),
- expected: "error, status code 401, message: Access denied due to Virtual Network/Firewall rules.",
+ expected: "error, status code: 401, message: Access denied due to Virtual Network/Firewall rules.",
},
{
name: "503 Model Overloaded",
@@ -135,6 +136,12 @@ func TestHandleErrorResp(t *testing.T) {
t.Errorf("Unexpected error: %v , expected: %s", err, tc.expected)
t.Fail()
}
+
+ e := &APIError{}
+ if !errors.As(err, &e) {
+ t.Errorf("(%s) Expected error to be of type APIError", tc.name)
+ t.Fail()
+ }
})
}
}
diff --git a/error.go b/error.go
index 86b75f4..6354f43 100644
--- a/error.go
+++ b/error.go
@@ -25,6 +25,10 @@ type ErrorResponse struct {
}
func (e *APIError) Error() string {
+ if e.HTTPStatusCode > 0 {
+ return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Message)
+ }
+
return e.Message
}
@@ -70,7 +74,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
}
func (e *RequestError) Error() string {
- return fmt.Sprintf("status code %d, message: %s", e.HTTPStatusCode, e.Err)
+ return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Err)
}
func (e *RequestError) Unwrap() error {