102
api_test.go
102
api_test.go
@@ -1,6 +1,8 @@
|
|||||||
package openai_test
|
package openai_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
@@ -110,7 +112,7 @@ func TestAPIError(t *testing.T) {
|
|||||||
c := NewClient(apiToken + "_invalid")
|
c := NewClient(apiToken + "_invalid")
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
_, err = c.ListEngines(ctx)
|
_, err = c.ListEngines(ctx)
|
||||||
checks.NoError(t, err, "ListEngines did not fail")
|
checks.HasError(t, err, "ListEngines should fail with an invalid key")
|
||||||
|
|
||||||
var apiErr *APIError
|
var apiErr *APIError
|
||||||
if !errors.As(err, &apiErr) {
|
if !errors.As(err, &apiErr) {
|
||||||
@@ -120,14 +122,108 @@ func TestAPIError(t *testing.T) {
|
|||||||
if apiErr.StatusCode != 401 {
|
if apiErr.StatusCode != 401 {
|
||||||
t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode)
|
t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode)
|
||||||
}
|
}
|
||||||
if *apiErr.Code != "invalid_api_key" {
|
|
||||||
t.Fatalf("Unexpected API error code: %s", *apiErr.Code)
|
switch v := apiErr.Code.(type) {
|
||||||
|
case string:
|
||||||
|
if v != "invalid_api_key" {
|
||||||
|
t.Fatalf("Unexpected API error code: %s", v)
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
t.Fatalf("Unexpected API error code type: %T", v)
|
||||||
|
}
|
||||||
|
|
||||||
if apiErr.Error() == "" {
|
if apiErr.Error() == "" {
|
||||||
t.Fatal("Empty error message occurred")
|
t.Fatal("Empty error message occurred")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAPIErrorUnmarshalJSONInteger(t *testing.T) {
|
||||||
|
var apiErr APIError
|
||||||
|
response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`
|
||||||
|
err := json.Unmarshal([]byte(response), &apiErr)
|
||||||
|
checks.NoError(t, err, "Unexpected Unmarshal API response error")
|
||||||
|
|
||||||
|
switch v := apiErr.Code.(type) {
|
||||||
|
case int:
|
||||||
|
if v != 418 {
|
||||||
|
t.Fatalf("Unexpected API code integer: %d; expected 418", v)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
t.Fatalf("Unexpected API error code type: %T", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIErrorUnmarshalJSONString(t *testing.T) {
|
||||||
|
var apiErr APIError
|
||||||
|
response := `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}`
|
||||||
|
err := json.Unmarshal([]byte(response), &apiErr)
|
||||||
|
checks.NoError(t, err, "Unexpected Unmarshal API response error")
|
||||||
|
|
||||||
|
switch v := apiErr.Code.(type) {
|
||||||
|
case string:
|
||||||
|
if v != "teapot" {
|
||||||
|
t.Fatalf("Unexpected API code string: %s; expected `teapot`", v)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
t.Fatalf("Unexpected API error code type: %T", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIErrorUnmarshalJSONNoCode(t *testing.T) {
|
||||||
|
// test integer code
|
||||||
|
response := `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`
|
||||||
|
var apiErr APIError
|
||||||
|
err := json.Unmarshal([]byte(response), &apiErr)
|
||||||
|
checks.NoError(t, err, "Unexpected Unmarshal API response error")
|
||||||
|
|
||||||
|
switch v := apiErr.Code.(type) {
|
||||||
|
case nil:
|
||||||
|
default:
|
||||||
|
t.Fatalf("Unexpected API error code type: %T", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIErrorUnmarshalInvalidData(t *testing.T) {
|
||||||
|
apiErr := APIError{}
|
||||||
|
data := []byte(`--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`)
|
||||||
|
err := apiErr.UnmarshalJSON(data)
|
||||||
|
checks.HasError(t, err, "Expected error when unmarshaling invalid data")
|
||||||
|
|
||||||
|
if apiErr.Code != nil {
|
||||||
|
t.Fatalf("Expected nil code, got %q", apiErr.Code)
|
||||||
|
}
|
||||||
|
if apiErr.Message != "" {
|
||||||
|
t.Fatalf("Expected empty message, got %q", apiErr.Message)
|
||||||
|
}
|
||||||
|
if apiErr.Param != nil {
|
||||||
|
t.Fatalf("Expected nil param, got %q", *apiErr.Param)
|
||||||
|
}
|
||||||
|
if apiErr.Type != "" {
|
||||||
|
t.Fatalf("Expected empty type, got %q", apiErr.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIErrorUnmarshalJSONInvalidParam(t *testing.T) {
|
||||||
|
var apiErr APIError
|
||||||
|
response := `{"code":418,"message":"I'm a teapot","param":true,"type":"teapot_error"}`
|
||||||
|
err := json.Unmarshal([]byte(response), &apiErr)
|
||||||
|
checks.HasError(t, err, "Param should be a string")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIErrorUnmarshalJSONInvalidType(t *testing.T) {
|
||||||
|
var apiErr APIError
|
||||||
|
response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":true}`
|
||||||
|
err := json.Unmarshal([]byte(response), &apiErr)
|
||||||
|
checks.HasError(t, err, "Type should be a string")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) {
|
||||||
|
var apiErr APIError
|
||||||
|
response := `{"code":418,"message":false,"param":"prompt","type":"teapot_error"}`
|
||||||
|
err := json.Unmarshal([]byte(response), &apiErr)
|
||||||
|
checks.HasError(t, err, "Message should be a string")
|
||||||
|
}
|
||||||
|
|
||||||
func TestRequestError(t *testing.T) {
|
func TestRequestError(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
|||||||
48
error.go
48
error.go
@@ -1,10 +1,13 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
// APIError provides error information returned by the OpenAI API.
|
// APIError provides error information returned by the OpenAI API.
|
||||||
type APIError struct {
|
type APIError struct {
|
||||||
Code *string `json:"code,omitempty"`
|
Code any `json:"code,omitempty"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Param *string `json:"param,omitempty"`
|
Param *string `json:"param,omitempty"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
@@ -25,6 +28,47 @@ func (e *APIError) Error() string {
|
|||||||
return e.Message
|
return e.Message
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *APIError) UnmarshalJSON(data []byte) (err error) {
|
||||||
|
var rawMap map[string]json.RawMessage
|
||||||
|
err = json.Unmarshal(data, &rawMap)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(rawMap["message"], &e.Message)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(rawMap["type"], &e.Type)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// optional fields
|
||||||
|
if _, ok := rawMap["param"]; ok {
|
||||||
|
err = json.Unmarshal(rawMap["param"], &e.Param)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := rawMap["code"]; !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the api returned a number, we need to force an integer
|
||||||
|
// since the json package defaults to float64
|
||||||
|
var intCode int
|
||||||
|
err = json.Unmarshal(rawMap["code"], &intCode)
|
||||||
|
if err == nil {
|
||||||
|
e.Code = intCode
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Unmarshal(rawMap["code"], &e.Code)
|
||||||
|
}
|
||||||
|
|
||||||
func (e *RequestError) Error() string {
|
func (e *RequestError) Error() string {
|
||||||
if e.Err != nil {
|
if e.Err != nil {
|
||||||
return e.Err.Error()
|
return e.Err.Error()
|
||||||
|
|||||||
Reference in New Issue
Block a user