diff --git a/completion.go b/completion.go index 84ef2ad2..77ea8c3a 100644 --- a/completion.go +++ b/completion.go @@ -161,7 +161,23 @@ func checkEndpointSupportsModel(endpoint, model string) bool { func checkPromptType(prompt any) bool { _, isString := prompt.(string) _, isStringSlice := prompt.([]string) - return isString || isStringSlice + if isString || isStringSlice { + return true + } + + // check if it is prompt is []string hidden under []any + slice, isSlice := prompt.([]any) + if !isSlice { + return false + } + + for _, item := range slice { + _, itemIsString := item.(string) + if !itemIsString { + return false + } + } + return true // all items in the slice are string, so it is []string } var unsupportedToolsForO1Models = map[ToolType]struct{}{ diff --git a/completion_test.go b/completion_test.go index 89950bf9..935bbe86 100644 --- a/completion_test.go +++ b/completion_test.go @@ -59,6 +59,38 @@ func TestCompletions(t *testing.T) { checks.NoError(t, err, "CreateCompletion error") } +// TestMultiplePromptsCompletionsWrong Tests the completions endpoint of the API using the mocked server +// where the completions requests has a list of prompts with wrong type. +func TestMultiplePromptsCompletionsWrong(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", handleCompletionEndpoint) + req := openai.CompletionRequest{ + MaxTokens: 5, + Model: "ada", + Prompt: []interface{}{"Lorem ipsum", 9}, + } + _, err := client.CreateCompletion(context.Background(), req) + if !errors.Is(err, openai.ErrCompletionRequestPromptTypeNotSupported) { + t.Fatalf("CreateCompletion should return ErrCompletionRequestPromptTypeNotSupported, but returned: %v", err) + } +} + +// TestMultiplePromptsCompletions Tests the completions endpoint of the API using the mocked server +// where the completions requests has a list of prompts. +func TestMultiplePromptsCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", handleCompletionEndpoint) + req := openai.CompletionRequest{ + MaxTokens: 5, + Model: "ada", + Prompt: []interface{}{"Lorem ipsum", "Lorem ipsum"}, + } + _, err := client.CreateCompletion(context.Background(), req) + checks.NoError(t, err, "CreateCompletion error") +} + // handleCompletionEndpoint Handles the completion endpoint by the test server. func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { var err error @@ -87,24 +119,50 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if n == 0 { n = 1 } + // Handle different types of prompts: single string or list of strings + prompts := []string{} + switch v := completionReq.Prompt.(type) { + case string: + prompts = append(prompts, v) + case []interface{}: + for _, item := range v { + if str, ok := item.(string); ok { + prompts = append(prompts, str) + } + } + default: + http.Error(w, "Invalid prompt type", http.StatusBadRequest) + return + } + for i := 0; i < n; i++ { - // generate a random string of length completionReq.Length - completionStr := strings.Repeat("a", completionReq.MaxTokens) - if completionReq.Echo { - completionStr = completionReq.Prompt.(string) + completionStr + for _, prompt := range prompts { + // Generate a random string of length completionReq.MaxTokens + completionStr := strings.Repeat("a", completionReq.MaxTokens) + if completionReq.Echo { + completionStr = prompt + completionStr + } + + res.Choices = append(res.Choices, openai.CompletionChoice{ + Text: completionStr, + Index: len(res.Choices), + }) } - res.Choices = append(res.Choices, openai.CompletionChoice{ - Text: completionStr, - Index: i, - }) } - inputTokens := numTokens(completionReq.Prompt.(string)) * n - completionTokens := completionReq.MaxTokens * n + + inputTokens := 0 + for _, prompt := range prompts { + inputTokens += numTokens(prompt) + } + inputTokens *= n + completionTokens := completionReq.MaxTokens * len(prompts) * n res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, } + + // Serialize the response and send it back resBytes, _ = json.Marshal(res) fmt.Fprintln(w, string(resBytes)) }