Support get http header and x-ratelimit-* headers (#507)

* feat: add headers to http response

* feat: support rate limit headers

* fix: go lint

* fix: test coverage

* refactor streamReader

* refactor streamReader

* refactor: NewRateLimitHeaders to newRateLimitHeaders

* refactor: RateLimitHeaders Resets filed

* refactor: move RateLimitHeaders struct
This commit is contained in:
Liu Shuang
2023-10-10 10:29:41 -05:00
committed by GitHub
parent 8e165dc9aa
commit b77d01edca
5 changed files with 191 additions and 5 deletions

View File

@@ -1,15 +1,17 @@
package openai_test
import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"testing"
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
)
func TestChatCompletionsStreamWrongModel(t *testing.T) {
@@ -178,6 +180,87 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
t.Logf("%+v\n", apiErr)
}
func TestCreateChatCompletionStreamWithHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set(xCustomHeader, xCustomHeaderValue)
// Send test responses
//nolint:lll
dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`)
dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)
_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
})
stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
value := stream.Header().Get(xCustomHeader)
if value != xCustomHeaderValue {
t.Errorf("expected %s to be %s", xCustomHeaderValue, value)
}
}
func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
for k, v := range rateLimitHeaders {
switch val := v.(type) {
case int:
w.Header().Set(k, strconv.Itoa(val))
default:
w.Header().Set(k, fmt.Sprintf("%s", v))
}
}
// Send test responses
//nolint:lll
dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`)
dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)
_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
})
stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
headers := stream.GetRateLimitHeaders()
bs1, _ := json.Marshal(headers)
bs2, _ := json.Marshal(rateLimitHeaders)
if string(bs1) != string(bs2) {
t.Errorf("expected rate limit header %s to be %s", bs2, bs1)
}
}
func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()