Support getting HTTP response headers and x-ratelimit-* headers (#507)
* feat: add headers to http response
* feat: support rate limit headers
* fix: go lint
* fix: test coverage
* refactor streamReader
* refactor streamReader
* refactor: NewRateLimitHeaders to newRateLimitHeaders
* refactor: RateLimitHeaders Resets field
* refactor: move RateLimitHeaders struct
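Taken together, the diff below threads the raw response headers through both regular and streaming responses and adds a parsed view of OpenAI's x-ratelimit-* headers. A minimal usage sketch, assuming only the API added in this commit (the token and request values are illustrative, not from the diff):

package main

import (
	"context"
	"fmt"
	"log"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-token") // placeholder token
	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model:    openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{{Role: openai.ChatMessageRoleUser, Content: "Hello!"}},
	})
	if err != nil {
		log.Fatal(err)
	}

	headers := resp.GetRateLimitHeaders() // parsed x-ratelimit-* values
	fmt.Println(headers.RemainingRequests)      // e.g. 59
	fmt.Println(headers.ResetRequests.String()) // e.g. "1s"
	fmt.Println(headers.ResetRequests.Time())   // reset duration added to time.Now()
}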
chat_stream_test.go

@@ -1,15 +1,17 @@
 package openai_test
 
 import (
-	. "github.com/sashabaranov/go-openai"
-	"github.com/sashabaranov/go-openai/internal/test/checks"
-
 	"context"
 	"encoding/json"
 	"errors"
+	"fmt"
 	"io"
 	"net/http"
+	"strconv"
 	"testing"
+
+	. "github.com/sashabaranov/go-openai"
+	"github.com/sashabaranov/go-openai/internal/test/checks"
 )
 
 func TestChatCompletionsStreamWrongModel(t *testing.T) {
@@ -178,6 +180,87 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
 	t.Logf("%+v\n", apiErr)
 }
 
+func TestCreateChatCompletionStreamWithHeaders(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+		w.Header().Set(xCustomHeader, xCustomHeaderValue)
+
+		// Send test responses
+		//nolint:lll
+		dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`)
+		dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)
+
+		_, err := w.Write(dataBytes)
+		checks.NoError(t, err, "Write error")
+	})
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
+		MaxTokens: 5,
+		Model:     GPT3Dot5Turbo,
+		Messages: []ChatCompletionMessage{
+			{
+				Role:    ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+	})
+	checks.NoError(t, err, "CreateCompletionStream returned error")
+	defer stream.Close()
+
+	value := stream.Header().Get(xCustomHeader)
+	if value != xCustomHeaderValue {
+		t.Errorf("expected %s to be %s", xCustomHeaderValue, value)
+	}
+}
+
+func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+		for k, v := range rateLimitHeaders {
+			switch val := v.(type) {
+			case int:
+				w.Header().Set(k, strconv.Itoa(val))
+			default:
+				w.Header().Set(k, fmt.Sprintf("%s", v))
+			}
+		}
+
+		// Send test responses
+		//nolint:lll
+		dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`)
+		dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)
+
+		_, err := w.Write(dataBytes)
+		checks.NoError(t, err, "Write error")
+	})
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
+		MaxTokens: 5,
+		Model:     GPT3Dot5Turbo,
+		Messages: []ChatCompletionMessage{
+			{
+				Role:    ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+	})
+	checks.NoError(t, err, "CreateCompletionStream returned error")
+	defer stream.Close()
+
+	headers := stream.GetRateLimitHeaders()
+	bs1, _ := json.Marshal(headers)
+	bs2, _ := json.Marshal(rateLimitHeaders)
+	if string(bs1) != string(bs2) {
+		t.Errorf("expected rate limit header %s to be %s", bs2, bs1)
+	}
+}
+
 func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
 	defer teardown()
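Both new tests compare the parsed headers against the rateLimitHeaders fixture (defined in chat_test.go below) by marshaling each side to JSON. That works because the RateLimitHeaders struct's json tags are exactly the header names used as map keys, and json.Marshal emits map keys in sorted order, which here matches the struct's field order. A minimal reproduction of that property, with hypothetical local types:

package main

import (
	"encoding/json"
	"fmt"
)

// limits mimics two fields of RateLimitHeaders; the tags match the map keys.
type limits struct {
	LimitRequests int    `json:"x-ratelimit-limit-requests"`
	ResetRequests string `json:"x-ratelimit-reset-requests"`
}

func main() {
	fromStruct, _ := json.Marshal(limits{LimitRequests: 60, ResetRequests: "1s"})
	fromMap, _ := json.Marshal(map[string]any{
		"x-ratelimit-limit-requests": 60,
		"x-ratelimit-reset-requests": "1s",
	})
	// true: map keys marshal in sorted order, matching the struct's field order
	fmt.Println(string(fromStruct) == string(fromMap))
}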
chat_test.go
@@ -21,6 +21,17 @@ const (
 	xCustomHeaderValue = "test"
 )
 
+var (
+	rateLimitHeaders = map[string]any{
+		"x-ratelimit-limit-requests":     60,
+		"x-ratelimit-limit-tokens":       150000,
+		"x-ratelimit-remaining-requests": 59,
+		"x-ratelimit-remaining-tokens":   149984,
+		"x-ratelimit-reset-requests":     "1s",
+		"x-ratelimit-reset-tokens":       "6m0s",
+	}
+)
+
 func TestChatCompletionsWrongModel(t *testing.T) {
 	config := DefaultConfig("whatever")
 	config.BaseURL = "http://localhost/v1"

@@ -97,6 +108,40 @@ func TestChatCompletionsWithHeaders(t *testing.T) {
 	}
 }
 
+// TestChatCompletionsWithRateLimitHeaders Tests the completions endpoint of the API using the mocked server.
+func TestChatCompletionsWithRateLimitHeaders(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
+	resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
+		MaxTokens: 5,
+		Model:     GPT3Dot5Turbo,
+		Messages: []ChatCompletionMessage{
+			{
+				Role:    ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+	})
+	checks.NoError(t, err, "CreateChatCompletion error")
+
+	headers := resp.GetRateLimitHeaders()
+	resetRequests := headers.ResetRequests.String()
+	if resetRequests != rateLimitHeaders["x-ratelimit-reset-requests"] {
+		t.Errorf("expected resetRequests %s to be %s", resetRequests, rateLimitHeaders["x-ratelimit-reset-requests"])
+	}
+	resetRequestsTime := headers.ResetRequests.Time()
+	if resetRequestsTime.Before(time.Now()) {
+		t.Errorf("unexpected reset requetsts: %v", resetRequestsTime)
+	}
+
+	bs1, _ := json.Marshal(headers)
+	bs2, _ := json.Marshal(rateLimitHeaders)
+	if string(bs1) != string(bs2) {
+		t.Errorf("expected rate limit header %s to be %s", bs2, bs1)
+	}
+}
+
 // TestChatCompletionsFunctions tests including a function call.
 func TestChatCompletionsFunctions(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()

@@ -311,6 +356,14 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 	}
 	resBytes, _ = json.Marshal(res)
 	w.Header().Set(xCustomHeader, xCustomHeaderValue)
+	for k, v := range rateLimitHeaders {
+		switch val := v.(type) {
+		case int:
+			w.Header().Set(k, strconv.Itoa(val))
+		default:
+			w.Header().Set(k, fmt.Sprintf("%s", v))
+		}
+	}
 	fmt.Fprintln(w, string(resBytes))
 }
 
client.go
@@ -30,8 +30,12 @@ func (h *httpHeader) SetHeader(header http.Header) {
 	*h = httpHeader(header)
 }
 
-func (h httpHeader) Header() http.Header {
-	return http.Header(h)
+func (h *httpHeader) Header() http.Header {
+	return http.Header(*h)
+}
+
+func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders {
+	return newRateLimitHeaders(h.Header())
 }
 
 // NewClient creates new OpenAI API client.

@@ -156,6 +160,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) {
 		response:       resp,
 		errAccumulator: utils.NewErrorAccumulator(),
 		unmarshaler:    &utils.JSONUnmarshaler{},
+		httpHeader:     httpHeader(resp.Header),
 	}, nil
 }
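The Header and GetRateLimitHeaders accessors live on the unexported httpHeader type, so any response type that embeds it picks them up via method promotion; sendRequestStream now fills that embedded value from resp.Header. A generic sketch of the same embedding pattern, with hypothetical names (not the library's types):

package main

import (
	"fmt"
	"net/http"
)

// header stands in for the library's unexported httpHeader type.
type header http.Header

func (h header) Get(key string) string { return http.Header(h).Get(key) }

// reader stands in for streamReader: embedding header promotes Get onto it.
type reader struct {
	header
}

func main() {
	r := reader{header: header{"X-Demo": {"42"}}}
	fmt.Println(r.Get("X-Demo")) // "42" — promoted from the embedded field
}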
ratelimit.go (new file)
@@ -0,0 +1,43 @@
+package openai
+
+import (
+	"net/http"
+	"strconv"
+	"time"
+)
+
+// RateLimitHeaders struct represents Openai rate limits headers.
+type RateLimitHeaders struct {
+	LimitRequests     int       `json:"x-ratelimit-limit-requests"`
+	LimitTokens       int       `json:"x-ratelimit-limit-tokens"`
+	RemainingRequests int       `json:"x-ratelimit-remaining-requests"`
+	RemainingTokens   int       `json:"x-ratelimit-remaining-tokens"`
+	ResetRequests     ResetTime `json:"x-ratelimit-reset-requests"`
+	ResetTokens       ResetTime `json:"x-ratelimit-reset-tokens"`
+}
+
+type ResetTime string
+
+func (r ResetTime) String() string {
+	return string(r)
+}
+
+func (r ResetTime) Time() time.Time {
+	d, _ := time.ParseDuration(string(r))
+	return time.Now().Add(d)
+}
+
+func newRateLimitHeaders(h http.Header) RateLimitHeaders {
+	limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests"))
+	limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens"))
+	remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests"))
+	remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens"))
+	return RateLimitHeaders{
+		LimitRequests:     limitReq,
+		LimitTokens:       limitTokens,
+		RemainingRequests: remainingReq,
+		RemainingTokens:   remainingTokens,
+		ResetRequests:     ResetTime(h.Get("x-ratelimit-reset-requests")),
+		ResetTokens:       ResetTime(h.Get("x-ratelimit-reset-tokens")),
+	}
+}
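ResetTime keeps the raw duration string from the header; String() returns it verbatim, while Time() parses it as a Go duration and adds it to the current time, discarding any parse error (an unparsable value therefore yields roughly time.Now()). A small sketch against the exported type, with illustrative values:

package main

import (
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	r := openai.ResetTime("6m0s")
	fmt.Println(r.String()) // "6m0s"
	fmt.Println(r.Time())   // roughly time.Now().Add(6 * time.Minute)

	bad := openai.ResetTime("not-a-duration")
	fmt.Println(bad.Time()) // parse error is discarded, so this is just time.Now()
}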
stream_reader.go

@@ -27,6 +27,8 @@ type streamReader[T streamable] struct {
 	response       *http.Response
 	errAccumulator utils.ErrorAccumulator
 	unmarshaler    utils.Unmarshaler
+
+	httpHeader
 }
 
 func (stream *streamReader[T]) Recv() (response T, err error) {
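With httpHeader embedded in streamReader and populated by sendRequestStream, streaming responses expose the same accessors as plain ones. A sketch of reading headers from a stream, assuming the API in this commit (the token is a placeholder):

package main

import (
	"context"
	"fmt"
	"log"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-token") // placeholder token
	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
		Model:    openai.GPT3Dot5Turbo,
		Stream:   true,
		Messages: []openai.ChatCompletionMessage{{Role: openai.ChatMessageRoleUser, Content: "Hello!"}},
	})
	if err != nil {
		log.Fatal(err)
	}
	defer stream.Close()

	fmt.Println(stream.Header().Get("Content-Type"))          // raw header access via the embedded httpHeader
	fmt.Println(stream.GetRateLimitHeaders().RemainingTokens) // parsed x-ratelimit-remaining-tokens
}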