From 03caea89b75c4e6a5ac32f6e60e69e309d852e8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rados=C5=82aw=20Kintzi?= Date: Fri, 24 Nov 2023 13:17:00 +0000 Subject: [PATCH] Add support for multi part chat messages (and gpt-4-vision-preview model) (#580) * Add support for multi part chat messages OpenAI has recently introduced a new model called gpt-4-visual-preview, which now supports images as input. The chat completion endpoint accepts multi-part chat messages, where the content can be an array of structs in addition to the usual string format. This commit introduces new structures and constants to represent different types of content parts. It also implements the json.Marshaler and json.Unmarshaler interfaces on ChatCompletionMessage. * Add ImageURLDetail and ChatMessagePartType types * Optimize ChatCompletionMessage deserialization * Add ErrContentFieldsMisused error --- chat.go | 91 ++++++++++++++++++++++++++++++++++++++++++++- chat_test.go | 103 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+), 2 deletions(-) diff --git a/chat.go b/chat.go index ebdc0e2..5b87b6b 100644 --- a/chat.go +++ b/chat.go @@ -2,6 +2,7 @@ package openai import ( "context" + "encoding/json" "errors" "net/http" ) @@ -20,6 +21,7 @@ const chatCompletionsSuffix = "/chat/completions" var ( ErrChatCompletionInvalidModel = errors.New("this model is not supported with this method, please use CreateCompletion client method instead") //nolint:lll ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll + ErrContentFieldsMisused = errors.New("can't use both Content and MultiContent properties simultaneously") ) type Hate struct { @@ -51,9 +53,36 @@ type PromptAnnotation struct { ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } +type ImageURLDetail string + +const ( + ImageURLDetailHigh ImageURLDetail = "high" + ImageURLDetailLow ImageURLDetail = "low" + ImageURLDetailAuto ImageURLDetail = "auto" +) + +type ChatMessageImageURL struct { + URL string `json:"url,omitempty"` + Detail ImageURLDetail `json:"detail,omitempty"` +} + +type ChatMessagePartType string + +const ( + ChatMessagePartTypeText ChatMessagePartType = "text" + ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" +) + +type ChatMessagePart struct { + Type ChatMessagePartType `json:"type,omitempty"` + Text string `json:"text,omitempty"` + ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` +} + type ChatCompletionMessage struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content"` + MultiContent []ChatMessagePart // This property isn't in the official documentation, but it's in // the documentation for the official library for python: @@ -70,6 +99,64 @@ type ChatCompletionMessage struct { ToolCallID string `json:"tool_call_id,omitempty"` } +func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { + if m.Content != "" && m.MultiContent != nil { + return nil, ErrContentFieldsMisused + } + if len(m.MultiContent) > 0 { + msg := struct { + Role string `json:"role"` + Content string `json:"-"` + MultiContent []ChatMessagePart `json:"content,omitempty"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }(m) + return json.Marshal(msg) + } + msg := struct { + Role string `json:"role"` + Content string `json:"content"` + MultiContent []ChatMessagePart `json:"-"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }(m) + return json.Marshal(msg) +} + +func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { + msg := struct { + Role string `json:"role"` + Content string `json:"content"` + MultiContent []ChatMessagePart + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }{} + if err := json.Unmarshal(bs, &msg); err == nil { + *m = ChatCompletionMessage(msg) + return nil + } + multiMsg := struct { + Role string `json:"role"` + Content string + MultiContent []ChatMessagePart `json:"content"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }{} + if err := json.Unmarshal(bs, &multiMsg); err != nil { + return err + } + *m = ChatCompletionMessage(multiMsg) + return nil +} + type ToolCall struct { // Index is not nil only in chat completion chunk object Index *int `json:"index,omitempty"` diff --git a/chat_test.go b/chat_test.go index 8377809..520bf5c 100644 --- a/chat_test.go +++ b/chat_test.go @@ -3,6 +3,7 @@ package openai_test import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -296,6 +297,108 @@ func TestAzureChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateAzureChatCompletion error") } +func TestMultipartChatCompletions(t *testing.T) { + client, server, teardown := setupAzureTestServer() + defer teardown() + server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint) + + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + MultiContent: []openai.ChatMessagePart{ + { + Type: openai.ChatMessagePartTypeText, + Text: "Hello!", + }, + { + Type: openai.ChatMessagePartTypeImageURL, + ImageURL: &openai.ChatMessageImageURL{ + URL: "URL", + Detail: openai.ImageURLDetailLow, + }, + }, + }, + }, + }, + }) + checks.NoError(t, err, "CreateAzureChatCompletion error") +} + +func TestMultipartChatMessageSerialization(t *testing.T) { + jsonText := `[{"role":"system","content":"system-message"},` + + `{"role":"user","content":[{"type":"text","text":"nice-text"},` + + `{"type":"image_url","image_url":{"url":"URL","detail":"high"}}]}]` + + var msgs []openai.ChatCompletionMessage + err := json.Unmarshal([]byte(jsonText), &msgs) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + if len(msgs) != 2 { + t.Errorf("unexpected number of messages") + } + if msgs[0].Role != "system" || msgs[0].Content != "system-message" || msgs[0].MultiContent != nil { + t.Errorf("invalid user message: %v", msgs[0]) + } + if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 2 { + t.Errorf("invalid user message") + } + parts := msgs[1].MultiContent + if parts[0].Type != "text" || parts[0].Text != "nice-text" { + t.Errorf("invalid text part: %v", parts[0]) + } + if parts[1].Type != "image_url" || parts[1].ImageURL.URL != "URL" || parts[1].ImageURL.Detail != "high" { + t.Errorf("invalid image_url part") + } + + s, err := json.Marshal(msgs) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + res := strings.ReplaceAll(string(s), " ", "") + if res != jsonText { + t.Fatalf("invalid message: %s", string(s)) + } + + invalidMsg := []openai.ChatCompletionMessage{ + { + Role: "user", + Content: "some-text", + MultiContent: []openai.ChatMessagePart{ + { + Type: "text", + Text: "nice-text", + }, + }, + }, + } + _, err = json.Marshal(invalidMsg) + if !errors.Is(err, openai.ErrContentFieldsMisused) { + t.Fatalf("Expected error: %s", err) + } + + err = json.Unmarshal([]byte(`["not-a-message"]`), &msgs) + if err == nil { + t.Fatalf("Expected error") + } + + emptyMultiContentMsg := openai.ChatCompletionMessage{ + Role: "user", + MultiContent: []openai.ChatMessagePart{}, + } + s, err = json.Marshal(emptyMultiContentMsg) + if err != nil { + t.Fatalf("Unexpected error") + } + res = strings.ReplaceAll(string(s), " ", "") + if res != `{"role":"user","content":""}` { + t.Fatalf("invalid message: %s", string(s)) + } +} + // handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server. func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { var err error