Support getting HTTP headers and x-ratelimit-* headers (#507)
* feat: add headers to http response
* feat: support rate limit headers
* fix: go lint
* fix: test coverage
* refactor: streamReader
* refactor: NewRateLimitHeaders to newRateLimitHeaders
* refactor: RateLimitHeaders Resets field
* refactor: move RateLimitHeaders struct
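In practice, the new accessors look like this (a minimal usage sketch based on the tests in this commit; the API token and the X-Request-Id header name are illustrative, not part of the change):

package main

import (
	"context"
	"fmt"
	"log"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-token") // illustrative token
	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	})
	if err != nil {
		log.Fatal(err)
	}

	// Raw response headers via the embedded httpHeader.
	fmt.Println("request id:", resp.Header().Get("X-Request-Id")) // header name is an example

	// Parsed x-ratelimit-* headers.
	rl := resp.GetRateLimitHeaders()
	fmt.Printf("remaining requests: %d (resets in %s)\n", rl.RemainingRequests, rl.ResetRequests)
}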
chat_stream_test.go
@@ -1,15 +1,17 @@
 package openai_test
 
 import (
-	. "github.com/sashabaranov/go-openai"
-	"github.com/sashabaranov/go-openai/internal/test/checks"
 	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
 	"net/http"
+	"strconv"
 	"testing"
+
+	. "github.com/sashabaranov/go-openai"
+	"github.com/sashabaranov/go-openai/internal/test/checks"
 )
 
 func TestChatCompletionsStreamWrongModel(t *testing.T) {
@@ -178,6 +180,87 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
 	t.Logf("%+v\n", apiErr)
 }
 
+func TestCreateChatCompletionStreamWithHeaders(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+		w.Header().Set(xCustomHeader, xCustomHeaderValue)
+
+		// Send test responses
+		//nolint:lll
+		dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`)
+		dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)
+
+		_, err := w.Write(dataBytes)
+		checks.NoError(t, err, "Write error")
+	})
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
+		MaxTokens: 5,
+		Model:     GPT3Dot5Turbo,
+		Messages: []ChatCompletionMessage{
+			{
+				Role:    ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+	})
+	checks.NoError(t, err, "CreateCompletionStream returned error")
+	defer stream.Close()
+
+	value := stream.Header().Get(xCustomHeader)
+	if value != xCustomHeaderValue {
+		t.Errorf("expected %s to be %s", xCustomHeaderValue, value)
+	}
+}
+
+func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+		for k, v := range rateLimitHeaders {
+			switch val := v.(type) {
+			case int:
+				w.Header().Set(k, strconv.Itoa(val))
+			default:
+				w.Header().Set(k, fmt.Sprintf("%s", v))
+			}
+		}
+
+		// Send test responses
+		//nolint:lll
+		dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`)
+		dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)
+
+		_, err := w.Write(dataBytes)
+		checks.NoError(t, err, "Write error")
+	})
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
+		MaxTokens: 5,
+		Model:     GPT3Dot5Turbo,
+		Messages: []ChatCompletionMessage{
+			{
+				Role:    ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+	})
+	checks.NoError(t, err, "CreateCompletionStream returned error")
+	defer stream.Close()
+
+	headers := stream.GetRateLimitHeaders()
+	bs1, _ := json.Marshal(headers)
+	bs2, _ := json.Marshal(rateLimitHeaders)
+	if string(bs1) != string(bs2) {
+		t.Errorf("expected rate limit header %s to be %s", bs2, bs1)
+	}
+}
+
 func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
 	defer teardown()
chat_test.go (53 additions)
@@ -21,6 +21,17 @@ const (
 	xCustomHeaderValue = "test"
 )
 
+var (
+	rateLimitHeaders = map[string]any{
+		"x-ratelimit-limit-requests":     60,
+		"x-ratelimit-limit-tokens":       150000,
+		"x-ratelimit-remaining-requests": 59,
+		"x-ratelimit-remaining-tokens":   149984,
+		"x-ratelimit-reset-requests":     "1s",
+		"x-ratelimit-reset-tokens":       "6m0s",
+	}
+)
+
 func TestChatCompletionsWrongModel(t *testing.T) {
 	config := DefaultConfig("whatever")
 	config.BaseURL = "http://localhost/v1"
@@ -97,6 +108,40 @@ func TestChatCompletionsWithHeaders(t *testing.T) {
 	}
 }
 
+// TestChatCompletionsWithRateLimitHeaders tests the completions endpoint of the API using the mocked server.
+func TestChatCompletionsWithRateLimitHeaders(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
+	resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
+		MaxTokens: 5,
+		Model:     GPT3Dot5Turbo,
+		Messages: []ChatCompletionMessage{
+			{
+				Role:    ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+	})
+	checks.NoError(t, err, "CreateChatCompletion error")
+
+	headers := resp.GetRateLimitHeaders()
+	resetRequests := headers.ResetRequests.String()
+	if resetRequests != rateLimitHeaders["x-ratelimit-reset-requests"] {
+		t.Errorf("expected resetRequests %s to be %s", resetRequests, rateLimitHeaders["x-ratelimit-reset-requests"])
+	}
+	resetRequestsTime := headers.ResetRequests.Time()
+	if resetRequestsTime.Before(time.Now()) {
+		t.Errorf("unexpected reset requests: %v", resetRequestsTime)
+	}
+
+	bs1, _ := json.Marshal(headers)
+	bs2, _ := json.Marshal(rateLimitHeaders)
+	if string(bs1) != string(bs2) {
+		t.Errorf("expected rate limit header %s to be %s", bs2, bs1)
+	}
+}
+
 // TestChatCompletionsFunctions tests including a function call.
 func TestChatCompletionsFunctions(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
@@ -311,6 +356,14 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 	}
 	resBytes, _ = json.Marshal(res)
 	w.Header().Set(xCustomHeader, xCustomHeaderValue)
+	for k, v := range rateLimitHeaders {
+		switch val := v.(type) {
+		case int:
+			w.Header().Set(k, strconv.Itoa(val))
+		default:
+			w.Header().Set(k, fmt.Sprintf("%s", v))
+		}
+	}
 	fmt.Fprintln(w, string(resBytes))
 }
client.go
@@ -30,8 +30,12 @@ func (h *httpHeader) SetHeader(header http.Header) {
 	*h = httpHeader(header)
 }
 
-func (h httpHeader) Header() http.Header {
-	return http.Header(h)
+func (h *httpHeader) Header() http.Header {
+	return http.Header(*h)
 }
 
+func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders {
+	return newRateLimitHeaders(h.Header())
+}
+
 // NewClient creates new OpenAI API client.
@@ -156,6 +160,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) {
 		response:       resp,
 		errAccumulator: utils.NewErrorAccumulator(),
 		unmarshaler:    &utils.JSONUnmarshaler{},
+		httpHeader:     httpHeader(resp.Header),
 	}, nil
 }
ratelimit.go (new file, 43 lines)
@@ -0,0 +1,43 @@
+package openai
+
+import (
+	"net/http"
+	"strconv"
+	"time"
+)
+
+// RateLimitHeaders struct represents OpenAI rate limit headers.
+type RateLimitHeaders struct {
+	LimitRequests     int       `json:"x-ratelimit-limit-requests"`
+	LimitTokens       int       `json:"x-ratelimit-limit-tokens"`
+	RemainingRequests int       `json:"x-ratelimit-remaining-requests"`
+	RemainingTokens   int       `json:"x-ratelimit-remaining-tokens"`
+	ResetRequests     ResetTime `json:"x-ratelimit-reset-requests"`
+	ResetTokens       ResetTime `json:"x-ratelimit-reset-tokens"`
+}
+
+type ResetTime string
+
+func (r ResetTime) String() string {
+	return string(r)
+}
+
+func (r ResetTime) Time() time.Time {
+	d, _ := time.ParseDuration(string(r))
+	return time.Now().Add(d)
+}
+
+func newRateLimitHeaders(h http.Header) RateLimitHeaders {
+	limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests"))
+	limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens"))
+	remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests"))
+	remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens"))
+	return RateLimitHeaders{
+		LimitRequests:     limitReq,
+		LimitTokens:       limitTokens,
+		RemainingRequests: remainingReq,
+		RemainingTokens:   remainingTokens,
+		ResetRequests:     ResetTime(h.Get("x-ratelimit-reset-requests")),
+		ResetTokens:       ResetTime(h.Get("x-ratelimit-reset-tokens")),
+	}
+}
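Since the x-ratelimit-reset-* values are Go-style duration strings such as "1s" or "6m0s", ResetTime.Time() adds the parsed duration to time.Now(). A rough sketch of how a caller could back off (assuming resp is a ChatCompletionResponse; the wait-until-reset policy is illustrative, not part of this commit):

rl := resp.GetRateLimitHeaders()
if rl.RemainingRequests == 0 {
	// e.g. "1s" parses to one second, so this sleeps until now + 1s.
	time.Sleep(time.Until(rl.ResetRequests.Time()))
}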
stream_reader.go
@@ -27,6 +27,8 @@ type streamReader[T streamable] struct {
 	response       *http.Response
 	errAccumulator utils.ErrorAccumulator
 	unmarshaler    utils.Unmarshaler
+
+	httpHeader
 }
 
 func (stream *streamReader[T]) Recv() (response T, err error) {
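Because streamReader now embeds httpHeader and sendRequestStream copies resp.Header into it, the same accessors work for streaming calls. A minimal sketch, assuming ctx is a context.Context and req is a ChatCompletionRequest with Stream: true:

stream, err := client.CreateChatCompletionStream(ctx, req)
if err != nil {
	log.Fatal(err)
}
defer stream.Close()

// Headers are available as soon as the stream opens, before any Recv call.
fmt.Println("content type:", stream.Header().Get("Content-Type"))
fmt.Println("remaining tokens:", stream.GetRateLimitHeaders().RemainingTokens)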