Updated checkPromptType function to handle prompt list in completions (#885)
* updated checkPromptType function to handle prompt list in completions * removed generated test file * added corresponding unit testcases * Updated to use less nesting with early returns
This commit is contained in:
@@ -161,7 +161,23 @@ func checkEndpointSupportsModel(endpoint, model string) bool {
|
||||
func checkPromptType(prompt any) bool {
|
||||
_, isString := prompt.(string)
|
||||
_, isStringSlice := prompt.([]string)
|
||||
return isString || isStringSlice
|
||||
if isString || isStringSlice {
|
||||
return true
|
||||
}
|
||||
|
||||
// check if it is prompt is []string hidden under []any
|
||||
slice, isSlice := prompt.([]any)
|
||||
if !isSlice {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, item := range slice {
|
||||
_, itemIsString := item.(string)
|
||||
if !itemIsString {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true // all items in the slice are string, so it is []string
|
||||
}
|
||||
|
||||
var unsupportedToolsForO1Models = map[ToolType]struct{}{
|
||||
|
||||
@@ -59,6 +59,38 @@ func TestCompletions(t *testing.T) {
|
||||
checks.NoError(t, err, "CreateCompletion error")
|
||||
}
|
||||
|
||||
// TestMultiplePromptsCompletionsWrong Tests the completions endpoint of the API using the mocked server
|
||||
// where the completions requests has a list of prompts with wrong type.
|
||||
func TestMultiplePromptsCompletionsWrong(t *testing.T) {
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
defer teardown()
|
||||
server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
|
||||
req := openai.CompletionRequest{
|
||||
MaxTokens: 5,
|
||||
Model: "ada",
|
||||
Prompt: []interface{}{"Lorem ipsum", 9},
|
||||
}
|
||||
_, err := client.CreateCompletion(context.Background(), req)
|
||||
if !errors.Is(err, openai.ErrCompletionRequestPromptTypeNotSupported) {
|
||||
t.Fatalf("CreateCompletion should return ErrCompletionRequestPromptTypeNotSupported, but returned: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiplePromptsCompletions Tests the completions endpoint of the API using the mocked server
|
||||
// where the completions requests has a list of prompts.
|
||||
func TestMultiplePromptsCompletions(t *testing.T) {
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
defer teardown()
|
||||
server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
|
||||
req := openai.CompletionRequest{
|
||||
MaxTokens: 5,
|
||||
Model: "ada",
|
||||
Prompt: []interface{}{"Lorem ipsum", "Lorem ipsum"},
|
||||
}
|
||||
_, err := client.CreateCompletion(context.Background(), req)
|
||||
checks.NoError(t, err, "CreateCompletion error")
|
||||
}
|
||||
|
||||
// handleCompletionEndpoint Handles the completion endpoint by the test server.
|
||||
func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
@@ -87,24 +119,50 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
if n == 0 {
|
||||
n = 1
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
// generate a random string of length completionReq.Length
|
||||
completionStr := strings.Repeat("a", completionReq.MaxTokens)
|
||||
if completionReq.Echo {
|
||||
completionStr = completionReq.Prompt.(string) + completionStr
|
||||
// Handle different types of prompts: single string or list of strings
|
||||
prompts := []string{}
|
||||
switch v := completionReq.Prompt.(type) {
|
||||
case string:
|
||||
prompts = append(prompts, v)
|
||||
case []interface{}:
|
||||
for _, item := range v {
|
||||
if str, ok := item.(string); ok {
|
||||
prompts = append(prompts, str)
|
||||
}
|
||||
}
|
||||
res.Choices = append(res.Choices, openai.CompletionChoice{
|
||||
Text: completionStr,
|
||||
Index: i,
|
||||
})
|
||||
default:
|
||||
http.Error(w, "Invalid prompt type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
inputTokens := numTokens(completionReq.Prompt.(string)) * n
|
||||
completionTokens := completionReq.MaxTokens * n
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
for _, prompt := range prompts {
|
||||
// Generate a random string of length completionReq.MaxTokens
|
||||
completionStr := strings.Repeat("a", completionReq.MaxTokens)
|
||||
if completionReq.Echo {
|
||||
completionStr = prompt + completionStr
|
||||
}
|
||||
|
||||
res.Choices = append(res.Choices, openai.CompletionChoice{
|
||||
Text: completionStr,
|
||||
Index: len(res.Choices),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
inputTokens := 0
|
||||
for _, prompt := range prompts {
|
||||
inputTokens += numTokens(prompt)
|
||||
}
|
||||
inputTokens *= n
|
||||
completionTokens := completionReq.MaxTokens * len(prompts) * n
|
||||
res.Usage = openai.Usage{
|
||||
PromptTokens: inputTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: inputTokens + completionTokens,
|
||||
}
|
||||
|
||||
// Serialize the response and send it back
|
||||
resBytes, _ = json.Marshal(res)
|
||||
fmt.Fprintln(w, string(resBytes))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user