From f22da8a7ed896d19661dfcce3e330e4b209b2eb3 Mon Sep 17 00:00:00 2001 From: Chris Hua Date: Wed, 21 Jun 2023 08:58:27 -0400 Subject: [PATCH] feat: allow more input types to functions, fix tests (#377) * feat: use json.rawMessage, test functions * chore: lint * fix: tests the ChatCompletion mock server doesn't actually run otherwise. N=0 is the default request but the server will treat it as n=1 * fix: tests should default to n=1 completions * chore: add back removed interfaces, custom marshal * chore: lint * chore: lint * chore: add some tests * chore: appease lint * clean up JSON schema + tests * chore: lint * feat: remove backwards compatible functions for illustrative purposes * fix: revert params change * chore: use interface{} * chore: add test * chore: add back FunctionDefine * chore: /s/interface{}/any * chore: add back jsonschemadefinition * chore: testcov * chore: lint * chore: remove pointers * chore: update comment * chore: address CR added test for compatibility as well --------- Co-authored-by: James --- chat.go | 34 +++++----- chat_test.go | 157 ++++++++++++++++++++++++++++++++++++++++++++- completion_test.go | 10 ++- 3 files changed, 180 insertions(+), 21 deletions(-) diff --git a/chat.go b/chat.go index 4764e36..f99af27 100644 --- a/chat.go +++ b/chat.go @@ -54,23 +54,23 @@ type ChatCompletionRequest struct { FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` LogitBias map[string]int `json:"logit_bias,omitempty"` User string `json:"user,omitempty"` - Functions []*FunctionDefine `json:"functions,omitempty"` - FunctionCall string `json:"function_call,omitempty"` + Functions []FunctionDefinition `json:"functions,omitempty"` + FunctionCall any `json:"function_call,omitempty"` } -type FunctionDefine struct { +type FunctionDefinition struct { Name string `json:"name"` Description string `json:"description,omitempty"` - // it's required in function call - Parameters *FunctionParams `json:"parameters"` + // Parameters is an object describing the function. + // You can pass a raw byte array describing the schema, + // or you can pass in a struct which serializes to the proper JSONSchema. + // The JSONSchemaDefinition struct is provided for convenience, but you should + // consider another specialized library for more complex schemas. + Parameters any `json:"parameters"` } -type FunctionParams struct { - // the Type must be JSONSchemaTypeObject - Type JSONSchemaType `json:"type"` - Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"` - Required []string `json:"required,omitempty"` -} +// Deprecated: use FunctionDefinition instead. +type FunctionDefine = FunctionDefinition type JSONSchemaType string @@ -83,8 +83,9 @@ const ( JSONSchemaTypeBoolean JSONSchemaType = "boolean" ) -// JSONSchemaDefine is a struct for JSON Schema. -type JSONSchemaDefine struct { +// JSONSchemaDefinition is a struct for JSON Schema. +// It is fairly limited and you may have better luck using a third-party library. +type JSONSchemaDefinition struct { // Type is a type of JSON Schema. Type JSONSchemaType `json:"type,omitempty"` // Description is a description of JSON Schema. @@ -92,13 +93,16 @@ type JSONSchemaDefine struct { // Enum is a enum of JSON Schema. It used if Type is JSONSchemaTypeString. Enum []string `json:"enum,omitempty"` // Properties is a properties of JSON Schema. It used if Type is JSONSchemaTypeObject. - Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"` + Properties map[string]JSONSchemaDefinition `json:"properties,omitempty"` // Required is a required of JSON Schema. It used if Type is JSONSchemaTypeObject. Required []string `json:"required,omitempty"` // Items is a property of JSON Schema. It used if Type is JSONSchemaTypeArray. - Items *JSONSchemaDefine `json:"items,omitempty"` + Items *JSONSchemaDefinition `json:"items,omitempty"` } +// Deprecated: use JSONSchemaDefinition instead. +type JSONSchemaDefine = JSONSchemaDefinition + type FinishReason string const ( diff --git a/chat_test.go b/chat_test.go index a43bb4a..3c759b3 100644 --- a/chat_test.go +++ b/chat_test.go @@ -67,6 +67,130 @@ func TestChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +// TestChatCompletionsFunctions tests including a function call. +func TestChatCompletionsFunctions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + t.Run("bytes", func(t *testing.T) { + //nolint:lll + msg := json.RawMessage(`{"properties":{"count":{"type":"integer","description":"total number of words in sentence"},"words":{"items":{"type":"string"},"type":"array","description":"list of words in sentence"}},"type":"object","required":["count","words"]}`) + _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo0613, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []FunctionDefine{{ + Name: "test", + Parameters: &msg, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) + t.Run("struct", func(t *testing.T) { + type testMessage struct { + Count int `json:"count"` + Words []string `json:"words"` + } + msg := testMessage{ + Count: 2, + Words: []string{"hello", "world"}, + } + _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo0613, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []FunctionDefinition{{ + Name: "test", + Parameters: &msg, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) + t.Run("JSONSchemaDefine", func(t *testing.T) { + _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo0613, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []FunctionDefinition{{ + Name: "test", + Parameters: &JSONSchemaDefinition{ + Type: JSONSchemaTypeObject, + Properties: map[string]JSONSchemaDefinition{ + "count": { + Type: JSONSchemaTypeNumber, + Description: "total number of words in sentence", + }, + "words": { + Type: JSONSchemaTypeArray, + Description: "list of words in sentence", + Items: &JSONSchemaDefinition{ + Type: JSONSchemaTypeString, + }, + }, + "enumTest": { + Type: JSONSchemaTypeString, + Enum: []string{"hello", "world"}, + }, + }, + }, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) + t.Run("JSONSchemaDefineWithFunctionDefine", func(t *testing.T) { + // this is a compatibility check + _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo0613, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []FunctionDefine{{ + Name: "test", + Parameters: &JSONSchemaDefine{ + Type: JSONSchemaTypeObject, + Properties: map[string]JSONSchemaDefine{ + "count": { + Type: JSONSchemaTypeNumber, + Description: "total number of words in sentence", + }, + "words": { + Type: JSONSchemaTypeArray, + Description: "list of words in sentence", + Items: &JSONSchemaDefine{ + Type: JSONSchemaTypeString, + }, + }, + "enumTest": { + Type: JSONSchemaTypeString, + Enum: []string{"hello", "world"}, + }, + }, + }, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) +} + func TestAzureChatCompletions(t *testing.T) { client, server, teardown := setupAzureTestServer() defer teardown() @@ -109,7 +233,34 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { Model: completionReq.Model, } // create completions - for i := 0; i < completionReq.N; i++ { + n := completionReq.N + if n == 0 { + n = 1 + } + for i := 0; i < n; i++ { + // if there are functions, include them + if len(completionReq.Functions) > 0 { + var fcb []byte + b := completionReq.Functions[0].Parameters + fcb, err = json.Marshal(b) + if err != nil { + http.Error(w, "could not marshal function parameters", http.StatusInternalServerError) + return + } + + res.Choices = append(res.Choices, ChatCompletionChoice{ + Message: ChatCompletionMessage{ + Role: ChatMessageRoleFunction, + // this is valid json so it should be fine + FunctionCall: &FunctionCall{ + Name: completionReq.Functions[0].Name, + Arguments: string(fcb), + }, + }, + Index: i, + }) + continue + } // generate a random string of length completionReq.Length completionStr := strings.Repeat("a", completionReq.MaxTokens) @@ -121,8 +272,8 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { Index: i, }) } - inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N - completionTokens := completionReq.MaxTokens * completionReq.N + inputTokens := numTokens(completionReq.Messages[0].Content) * n + completionTokens := completionReq.MaxTokens * n res.Usage = Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, diff --git a/completion_test.go b/completion_test.go index aeddcfc..844ef48 100644 --- a/completion_test.go +++ b/completion_test.go @@ -83,7 +83,11 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { Model: completionReq.Model, } // create completions - for i := 0; i < completionReq.N; i++ { + n := completionReq.N + if n == 0 { + n = 1 + } + for i := 0; i < n; i++ { // generate a random string of length completionReq.Length completionStr := strings.Repeat("a", completionReq.MaxTokens) if completionReq.Echo { @@ -94,8 +98,8 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { Index: i, }) } - inputTokens := numTokens(completionReq.Prompt.(string)) * completionReq.N - completionTokens := completionReq.MaxTokens * completionReq.N + inputTokens := numTokens(completionReq.Prompt.(string)) * n + completionTokens := completionReq.MaxTokens * n res.Usage = Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens,