diff --git a/api.go b/api.go index c339afe..0a7bf15 100644 --- a/api.go +++ b/api.go @@ -66,9 +66,14 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error { var errRes ErrorResponse err = json.NewDecoder(res.Body).Decode(&errRes) if err != nil || errRes.Error == nil { - return fmt.Errorf("error, status code: %d", res.StatusCode) + reqErr := RequestError{ + StatusCode: res.StatusCode, + Err: err, + } + return fmt.Errorf("error, %w", &reqErr) } - return fmt.Errorf("error, status code: %d, message: %s", res.StatusCode, errRes.Error.Message) + errRes.Error.StatusCode = res.StatusCode + return fmt.Errorf("error, status code: %d, message: %w", res.StatusCode, errRes.Error) } if v != nil { diff --git a/api_test.go b/api_test.go index 7843bef..4c8732f 100644 --- a/api_test.go +++ b/api_test.go @@ -81,6 +81,53 @@ func TestAPI(t *testing.T) { } } +func TestAPIError(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := NewClient(apiToken + "_invalid") + ctx := context.Background() + _, err = c.ListEngines(ctx) + if err == nil { + t.Fatal("ListEngines did not fail") + } + + var apiErr *APIError + if !errors.As(err, &apiErr) { + t.Fatalf("Error is not an APIError: %+v", err) + } + + if apiErr.StatusCode != 401 { + t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode) + } + if *apiErr.Code != "invalid_api_key" { + t.Fatalf("Unexpected API error code: %s", *apiErr.Code) + } +} + +func TestRequestError(t *testing.T) { + var err error + c := NewClient("dummy") + c.BaseURL = "https://httpbin.org/status/418?" + ctx := context.Background() + _, err = c.ListEngines(ctx) + if err == nil { + t.Fatal("ListEngines request did not fail") + } + + var reqErr *RequestError + if !errors.As(err, &reqErr) { + t.Fatalf("Error is not a RequestError: %+v", err) + } + + if reqErr.StatusCode != 418 { + t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode) + } +} + // numTokens Returns the number of GPT-3 encoded tokens in the given text. // This function approximates based on the rule of thumb stated by OpenAI: // https://beta.openai.com/tokenizer diff --git a/error.go b/error.go index 4d0a324..927fafd 100644 --- a/error.go +++ b/error.go @@ -1,10 +1,37 @@ package gogpt -type ErrorResponse struct { - Error *struct { - Code *int `json:"code,omitempty"` - Message string `json:"message"` - Param *string `json:"param,omitempty"` - Type string `json:"type"` - } `json:"error,omitempty"` +import "fmt" + +// APIError provides error information returned by the OpenAI API. +type APIError struct { + Code *string `json:"code,omitempty"` + Message string `json:"message"` + Param *string `json:"param,omitempty"` + Type string `json:"type"` + StatusCode int `json:"-"` +} + +// RequestError provides informations about generic request errors. +type RequestError struct { + StatusCode int + Err error +} + +type ErrorResponse struct { + Error *APIError `json:"error,omitempty"` +} + +func (e *APIError) Error() string { + return e.Message +} + +func (e *RequestError) Error() string { + if e.Err != nil { + return e.Err.Error() + } + return fmt.Sprintf("status code %d", e.StatusCode) +} + +func (e *RequestError) Unwrap() error { + return e.Err }