run_id string Optional (#855)

Filter messages by the run ID that generated them.

Co-authored-by: wappi <support@wappi.pro>
This commit is contained in:
floodwm
2024-09-20 23:54:25 +03:00
committed by GitHub
parent 9a4f3a7dbf
commit e095df5325
4 changed files with 9 additions and 3 deletions

0
.zshrc Normal file
View File

View File

@@ -340,7 +340,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
return client.CreateMessage(ctx, "", MessageRequest{}) return client.CreateMessage(ctx, "", MessageRequest{})
}}, }},
{"ListMessage", func() (any, error) { {"ListMessage", func() (any, error) {
return client.ListMessage(ctx, "", nil, nil, nil, nil) return client.ListMessage(ctx, "", nil, nil, nil, nil, nil)
}}, }},
{"RetrieveMessage", func() (any, error) { {"RetrieveMessage", func() (any, error) {
return client.RetrieveMessage(ctx, "", "") return client.RetrieveMessage(ctx, "", "")

View File

@@ -100,6 +100,7 @@ func (c *Client) ListMessage(ctx context.Context, threadID string,
order *string, order *string,
after *string, after *string,
before *string, before *string,
runID *string,
) (messages MessagesList, err error) { ) (messages MessagesList, err error) {
urlValues := url.Values{} urlValues := url.Values{}
if limit != nil { if limit != nil {
@@ -114,6 +115,10 @@ func (c *Client) ListMessage(ctx context.Context, threadID string,
if before != nil { if before != nil {
urlValues.Add("before", *before) urlValues.Add("before", *before)
} }
if runID != nil {
urlValues.Add("run_id", *runID)
}
encodedValues := "" encodedValues := ""
if len(urlValues) > 0 { if len(urlValues) > 0 {
encodedValues = "?" + urlValues.Encode() encodedValues = "?" + urlValues.Encode()

View File

@@ -208,7 +208,7 @@ func TestMessages(t *testing.T) {
} }
var msgs openai.MessagesList var msgs openai.MessagesList
msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil) msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil, nil)
checks.NoError(t, err, "ListMessages error") checks.NoError(t, err, "ListMessages error")
if len(msgs.Messages) != 1 { if len(msgs.Messages) != 1 {
t.Fatalf("unexpected length of fetched messages") t.Fatalf("unexpected length of fetched messages")
@@ -219,7 +219,8 @@ func TestMessages(t *testing.T) {
order := "desc" order := "desc"
after := "obj_foo" after := "obj_foo"
before := "obj_bar" before := "obj_bar"
msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before) runID := "run_abc123"
msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before, &runID)
checks.NoError(t, err, "ListMessages error") checks.NoError(t, err, "ListMessages error")
if len(msgs.Messages) != 1 { if len(msgs.Messages) != 1 {
t.Fatalf("unexpected length of fetched messages") t.Fatalf("unexpected length of fetched messages")