diff --git a/chat_stream_test.go b/chat_stream_test.go index 77d373c..19c2e3c 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -255,6 +255,67 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { + wantCode := "429" + wantMessage := "Requests to the Creates a completion for the chat message Operation under Azure OpenAI API " + + "version 2023-03-15-preview have exceeded token rate limit of your current OpenAI S0 pricing tier. " + + "Please retry after 20 seconds. " + + "Please go here: https://aka.ms/oai/quotaincrease if you would like to further increase the default rate limit." + + server := test.NewTestServer() + server.RegisterHandler("/openai/deployments/gpt-35-turbo/chat/completions", + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + // Send test responses + dataBytes := []byte(`{"error": { "code": "` + wantCode + `", "message": "` + wantMessage + `"}}`) + _, err := w.Write(dataBytes) + + checks.NoError(t, err, "Write error") + }) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultAzureConfig(test.GetTestToken(), ts.URL) + client := NewClientWithConfig(config) + ctx := context.Background() + + request := ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + } + + apiErr := &APIError{} + _, err = client.CreateChatCompletionStream(ctx, request) + if !errors.As(err, &apiErr) { + t.Errorf("Did not return APIError: %+v\n", apiErr) + return + } + if apiErr.HTTPStatusCode != http.StatusTooManyRequests { + t.Errorf("Did not return HTTPStatusCode got = %d, want = %d\n", apiErr.HTTPStatusCode, http.StatusTooManyRequests) + return + } + code, ok := apiErr.Code.(string) + if !ok || code != wantCode { + t.Errorf("Did not return Code. got = %v, want = %s\n", apiErr.Code, wantCode) + return + } + if apiErr.Message != wantMessage { + t.Errorf("Did not return Message. got = %s, want = %s\n", apiErr.Message, wantMessage) + return + } +} + func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) { var err error server := test.NewTestServer() diff --git a/client_test.go b/client_test.go index 5e63539..c96ceb7 100644 --- a/client_test.go +++ b/client_test.go @@ -134,6 +134,15 @@ func TestHandleErrorResp(t *testing.T) { }`)), expected: "error, status code: 503, message: That model...", }, + { + name: "503 no message (Unknown response)", + httpCode: http.StatusServiceUnavailable, + body: bytes.NewReader([]byte(` + { + "error":{} + }`)), + expected: "error, status code: 503, message: ", + }, } for _, tc := range testCases { diff --git a/error.go b/error.go index 6354f43..b789ed7 100644 --- a/error.go +++ b/error.go @@ -44,9 +44,13 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { return } - err = json.Unmarshal(rawMap["type"], &e.Type) - if err != nil { - return + // optional fields for azure openai + // refs: https://github.com/sashabaranov/go-openai/issues/343 + if _, ok := rawMap["type"]; ok { + err = json.Unmarshal(rawMap["type"], &e.Type) + if err != nil { + return + } } // optional fields