diff --git a/chat.go b/chat.go
index efb14fd..a1eb117 100644
--- a/chat.go
+++ b/chat.go
@@ -216,6 +216,16 @@ type ChatCompletionRequest struct {
 	Tools []Tool `json:"tools,omitempty"`
 	// This can be either a string or an ToolChoice object.
 	ToolChoice any `json:"tool_choice,omitempty"`
+	// Options for streaming response. Only set this when you set stream: true.
+	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
+}
+
+type StreamOptions struct {
+	// If set, an additional chunk will be streamed before the data: [DONE] message.
+	// The usage field on this chunk shows the token usage statistics for the entire request,
+	// and the choices field will always be an empty array.
+	// All other chunks will also include a usage field, but with a null value.
+	IncludeUsage bool `json:"include_usage,omitempty"`
 }
 
 type ToolType string
diff --git a/chat_stream.go b/chat_stream.go
index 159f9f4..ffd512f 100644
--- a/chat_stream.go
+++ b/chat_stream.go
@@ -33,6 +33,10 @@ type ChatCompletionStreamResponse struct {
 	SystemFingerprint   string               `json:"system_fingerprint"`
 	PromptAnnotations   []PromptAnnotation   `json:"prompt_annotations,omitempty"`
 	PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"`
+	// An optional field that will only be present when you set stream_options: {"include_usage": true} in your request.
+	// When present, it contains a null value except for the last chunk, which contains the token usage statistics
+	// for the entire request.
+	Usage *Usage `json:"usage,omitempty"`
 }
 
 // ChatCompletionStream
diff --git a/chat_stream_test.go b/chat_stream_test.go
index bd1c737..63e45ee 100644
--- a/chat_stream_test.go
+++ b/chat_stream_test.go
@@ -388,6 +388,119 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
 	}
 }
 
+func TestCreateChatCompletionStreamStreamOptions(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+
+		// Send test responses
+		var dataBytes []byte
+		//nolint:lll
+		data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}],"usage":null}`
+		dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
+		//nolint:lll
+		data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}],"usage":null}`
+		dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
+		//nolint:lll
+		data = `{"id":"3","object":"completion","created":1598069256,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`
+		dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
+		dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
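+		// Each chunk above is framed as a server-sent event ("data: <json>\n\n");
+		// the trailing "data: [DONE]" sentinel marks the end of the stream.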
+
+		_, err := w.Write(dataBytes)
+		checks.NoError(t, err, "Write error")
+	})
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
+		MaxTokens: 5,
+		Model:     openai.GPT3Dot5Turbo,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+		StreamOptions: &openai.StreamOptions{
+			IncludeUsage: true,
+		},
+	})
+	checks.NoError(t, err, "CreateChatCompletionStream returned error")
+	defer stream.Close()
+
+	expectedResponses := []openai.ChatCompletionStreamResponse{
+		{
+			ID:                "1",
+			Object:            "completion",
+			Created:           1598069254,
+			Model:             openai.GPT3Dot5Turbo,
+			SystemFingerprint: "fp_d9767fc5b9",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: "response1",
+					},
+					FinishReason: "max_tokens",
+				},
+			},
+		},
+		{
+			ID:                "2",
+			Object:            "completion",
+			Created:           1598069255,
+			Model:             openai.GPT3Dot5Turbo,
+			SystemFingerprint: "fp_d9767fc5b9",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: "response2",
+					},
+					FinishReason: "max_tokens",
+				},
+			},
+		},
+		{
+			ID:                "3",
+			Object:            "completion",
+			Created:           1598069256,
+			Model:             openai.GPT3Dot5Turbo,
+			SystemFingerprint: "fp_d9767fc5b9",
+			Choices:           []openai.ChatCompletionStreamChoice{},
+			Usage: &openai.Usage{
+				PromptTokens:     1,
+				CompletionTokens: 1,
+				TotalTokens:      2,
+			},
+		},
+	}
+
+	for ix, expectedResponse := range expectedResponses {
+		b, _ := json.Marshal(expectedResponse)
+		t.Logf("%d: %s", ix, string(b))
+
+		receivedResponse, streamErr := stream.Recv()
+		checks.NoError(t, streamErr, "stream.Recv() failed")
+		if !compareChatResponses(expectedResponse, receivedResponse) {
+			t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
+		}
+	}
+
+	_, streamErr := stream.Recv()
+	if !errors.Is(streamErr, io.EOF) {
+		t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
+	}
+
+	// A repeated Recv after the stream is drained should keep returning io.EOF.
+	_, streamErr = stream.Recv()
+	checks.ErrorIs(t, streamErr, io.EOF, "stream.Recv() did not return EOF when the stream is finished")
+}
+
 // Helper funcs.
 func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
 	if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
@@ -401,6 +514,15 @@ func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
 			return false
 		}
 	}
+	if r1.Usage != nil || r2.Usage != nil {
+		if r1.Usage == nil || r2.Usage == nil {
+			return false
+		}
+		if r1.Usage.PromptTokens != r2.Usage.PromptTokens || r1.Usage.CompletionTokens != r2.Usage.CompletionTokens ||
+			r1.Usage.TotalTokens != r2.Usage.TotalTokens {
+			return false
+		}
+	}
 	return true
 }
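Usage sketch (not part of the patch): a minimal example of how a caller might opt in to the new stream_options field and read the final usage chunk. Only the StreamOptions and Usage handling comes from this change; the rest assumes the repository's existing client surface (openai.NewClient, CreateChatCompletionStream) and an OPENAI_API_KEY environment variable.

package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"os"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))

	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
		Stream: true,
		// Opt in to the extra usage chunk added by this change.
		StreamOptions: &openai.StreamOptions{IncludeUsage: true},
	})
	if err != nil {
		log.Fatalf("CreateChatCompletionStream: %v", err)
	}
	defer stream.Close()

	for {
		resp, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			break // "data: [DONE]" received
		}
		if err != nil {
			log.Fatalf("stream.Recv: %v", err)
		}
		// Usage is non-nil only on the final chunk, whose choices slice is empty.
		if resp.Usage != nil {
			fmt.Printf("\nprompt=%d completion=%d total=%d tokens\n",
				resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
			continue
		}
		if len(resp.Choices) > 0 {
			fmt.Print(resp.Choices[0].Delta.Content)
		}
	}
}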