Updated checkPromptType function to handle prompt list in completions (#885)
* updated checkPromptType function to handle prompt list in completions * removed generated test file * added corresponding unit testcases * Updated to use less nesting with early returns
This commit is contained in:
@@ -161,7 +161,23 @@ func checkEndpointSupportsModel(endpoint, model string) bool {
|
|||||||
func checkPromptType(prompt any) bool {
|
func checkPromptType(prompt any) bool {
|
||||||
_, isString := prompt.(string)
|
_, isString := prompt.(string)
|
||||||
_, isStringSlice := prompt.([]string)
|
_, isStringSlice := prompt.([]string)
|
||||||
return isString || isStringSlice
|
if isString || isStringSlice {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if it is prompt is []string hidden under []any
|
||||||
|
slice, isSlice := prompt.([]any)
|
||||||
|
if !isSlice {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range slice {
|
||||||
|
_, itemIsString := item.(string)
|
||||||
|
if !itemIsString {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true // all items in the slice are string, so it is []string
|
||||||
}
|
}
|
||||||
|
|
||||||
var unsupportedToolsForO1Models = map[ToolType]struct{}{
|
var unsupportedToolsForO1Models = map[ToolType]struct{}{
|
||||||
|
|||||||
@@ -59,6 +59,38 @@ func TestCompletions(t *testing.T) {
|
|||||||
checks.NoError(t, err, "CreateCompletion error")
|
checks.NoError(t, err, "CreateCompletion error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestMultiplePromptsCompletionsWrong Tests the completions endpoint of the API using the mocked server
|
||||||
|
// where the completions requests has a list of prompts with wrong type.
|
||||||
|
func TestMultiplePromptsCompletionsWrong(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
|
||||||
|
req := openai.CompletionRequest{
|
||||||
|
MaxTokens: 5,
|
||||||
|
Model: "ada",
|
||||||
|
Prompt: []interface{}{"Lorem ipsum", 9},
|
||||||
|
}
|
||||||
|
_, err := client.CreateCompletion(context.Background(), req)
|
||||||
|
if !errors.Is(err, openai.ErrCompletionRequestPromptTypeNotSupported) {
|
||||||
|
t.Fatalf("CreateCompletion should return ErrCompletionRequestPromptTypeNotSupported, but returned: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMultiplePromptsCompletions Tests the completions endpoint of the API using the mocked server
|
||||||
|
// where the completions requests has a list of prompts.
|
||||||
|
func TestMultiplePromptsCompletions(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
|
||||||
|
req := openai.CompletionRequest{
|
||||||
|
MaxTokens: 5,
|
||||||
|
Model: "ada",
|
||||||
|
Prompt: []interface{}{"Lorem ipsum", "Lorem ipsum"},
|
||||||
|
}
|
||||||
|
_, err := client.CreateCompletion(context.Background(), req)
|
||||||
|
checks.NoError(t, err, "CreateCompletion error")
|
||||||
|
}
|
||||||
|
|
||||||
// handleCompletionEndpoint Handles the completion endpoint by the test server.
|
// handleCompletionEndpoint Handles the completion endpoint by the test server.
|
||||||
func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||||
var err error
|
var err error
|
||||||
@@ -87,24 +119,50 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
|||||||
if n == 0 {
|
if n == 0 {
|
||||||
n = 1
|
n = 1
|
||||||
}
|
}
|
||||||
for i := 0; i < n; i++ {
|
// Handle different types of prompts: single string or list of strings
|
||||||
// generate a random string of length completionReq.Length
|
prompts := []string{}
|
||||||
completionStr := strings.Repeat("a", completionReq.MaxTokens)
|
switch v := completionReq.Prompt.(type) {
|
||||||
if completionReq.Echo {
|
case string:
|
||||||
completionStr = completionReq.Prompt.(string) + completionStr
|
prompts = append(prompts, v)
|
||||||
|
case []interface{}:
|
||||||
|
for _, item := range v {
|
||||||
|
if str, ok := item.(string); ok {
|
||||||
|
prompts = append(prompts, str)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
res.Choices = append(res.Choices, openai.CompletionChoice{
|
default:
|
||||||
Text: completionStr,
|
http.Error(w, "Invalid prompt type", http.StatusBadRequest)
|
||||||
Index: i,
|
return
|
||||||
})
|
|
||||||
}
|
}
|
||||||
inputTokens := numTokens(completionReq.Prompt.(string)) * n
|
|
||||||
completionTokens := completionReq.MaxTokens * n
|
for i := 0; i < n; i++ {
|
||||||
|
for _, prompt := range prompts {
|
||||||
|
// Generate a random string of length completionReq.MaxTokens
|
||||||
|
completionStr := strings.Repeat("a", completionReq.MaxTokens)
|
||||||
|
if completionReq.Echo {
|
||||||
|
completionStr = prompt + completionStr
|
||||||
|
}
|
||||||
|
|
||||||
|
res.Choices = append(res.Choices, openai.CompletionChoice{
|
||||||
|
Text: completionStr,
|
||||||
|
Index: len(res.Choices),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputTokens := 0
|
||||||
|
for _, prompt := range prompts {
|
||||||
|
inputTokens += numTokens(prompt)
|
||||||
|
}
|
||||||
|
inputTokens *= n
|
||||||
|
completionTokens := completionReq.MaxTokens * len(prompts) * n
|
||||||
res.Usage = openai.Usage{
|
res.Usage = openai.Usage{
|
||||||
PromptTokens: inputTokens,
|
PromptTokens: inputTokens,
|
||||||
CompletionTokens: completionTokens,
|
CompletionTokens: completionTokens,
|
||||||
TotalTokens: inputTokens + completionTokens,
|
TotalTokens: inputTokens + completionTokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Serialize the response and send it back
|
||||||
resBytes, _ = json.Marshal(res)
|
resBytes, _ = json.Marshal(res)
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user