diff --git a/changelog/unreleased/kong/ai-anthropic-fix-function-calling.yml b/changelog/unreleased/kong/ai-anthropic-fix-function-calling.yml new file mode 100644 index 00000000000..41d2592f46d --- /dev/null +++ b/changelog/unreleased/kong/ai-anthropic-fix-function-calling.yml @@ -0,0 +1,3 @@ +message: "**ai-proxy**: Fixed a bug where tools (function) calls to Anthropic would return empty results." +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/ai-bedrock-fix-function-calling.yml b/changelog/unreleased/kong/ai-bedrock-fix-function-calling.yml new file mode 100644 index 00000000000..622e0532f1d --- /dev/null +++ b/changelog/unreleased/kong/ai-bedrock-fix-function-calling.yml @@ -0,0 +1,3 @@ +message: "**ai-proxy**: Fixed a bug where tools (function) calls to Bedrock would return empty results." +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/ai-bedrock-fix-guardrails.yml b/changelog/unreleased/kong/ai-bedrock-fix-guardrails.yml new file mode 100644 index 00000000000..d29cd7bab36 --- /dev/null +++ b/changelog/unreleased/kong/ai-bedrock-fix-guardrails.yml @@ -0,0 +1,3 @@ +message: "**ai-proxy**: Fixed a bug where Bedrock Guardrail config was ignored." +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/ai-cohere-fix-function-calling.yml b/changelog/unreleased/kong/ai-cohere-fix-function-calling.yml new file mode 100644 index 00000000000..6e4885a2a43 --- /dev/null +++ b/changelog/unreleased/kong/ai-cohere-fix-function-calling.yml @@ -0,0 +1,3 @@ +message: "**ai-proxy**: Fixed a bug where tools (function) calls to Cohere would return empty results." +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/ai-gemini-blocks-content-safety.yml b/changelog/unreleased/kong/ai-gemini-blocks-content-safety.yml new file mode 100644 index 00000000000..3cdd2e3a284 --- /dev/null +++ b/changelog/unreleased/kong/ai-gemini-blocks-content-safety.yml @@ -0,0 +1,3 @@ +message: "**ai-proxy**: Fixed a bug where Gemini provider would return an error if content safety failed in AI Proxy." +type: bugfix +scope: Plugin diff --git a/changelog/unreleased/kong/ai-gemini-fix-function-calling.yml b/changelog/unreleased/kong/ai-gemini-fix-function-calling.yml new file mode 100644 index 00000000000..59e6f5baa27 --- /dev/null +++ b/changelog/unreleased/kong/ai-gemini-fix-function-calling.yml @@ -0,0 +1,3 @@ +message: "**ai-proxy**: Fixed a bug where tools (function) calls to Gemini (or via Vertex) would return empty results." 
+type: bugfix +scope: Plugin diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index 508b62c4851..e09da82817e 100644 --- a/kong/llm/drivers/anthropic.lua +++ b/kong/llm/drivers/anthropic.lua @@ -44,17 +44,47 @@ local function kong_messages_to_claude_prompt(messages) return buf:get() end +local inject_tool_calls = function(tool_calls) + local tools + for _, n in ipairs(tool_calls) do + tools = tools or {} + table.insert(tools, { + type = "tool_use", + id = n.id, + name = n["function"].name, + input = cjson.decode(n["function"].arguments) + }) + end + + return tools +end + -- reuse the messages structure of prompt -- extract messages and system from kong request local function kong_messages_to_claude_messages(messages) local msgs, system, n = {}, nil, 1 for _, v in ipairs(messages) do - if v.role ~= "assistant" and v.role ~= "user" then + if v.role ~= "assistant" and v.role ~= "user" and v.role ~= "tool" then system = v.content - else - msgs[n] = v + if v.role == "assistant" and v.tool_calls then + msgs[n] = { + role = v.role, + content = inject_tool_calls(v.tool_calls), + } + elseif v.role == "tool" then + msgs[n] = { + role = "user", + content = {{ + type = "tool_result", + tool_use_id = v.tool_call_id, + content = v.content + }}, + } + else + msgs[n] = v + end n = n + 1 end end @@ -62,7 +92,6 @@ local function kong_messages_to_claude_messages(messages) return msgs, system end - local function to_claude_prompt(req) if req.prompt then return kong_prompt_to_claude_prompt(req.prompt) @@ -83,6 +112,21 @@ local function to_claude_messages(req) return nil, nil, "request is missing .messages command" end +local function to_tools(in_tools) + local out_tools = {} + + for i, v in ipairs(in_tools) do + if v['function'] then + v['function'].input_schema = v['function'].parameters + v['function'].parameters = nil + + table.insert(out_tools, v['function']) + end + end + + return out_tools +end + local transformers_to = { ["llm/v1/chat"] = function(request_table, model) local messages = {} @@ -98,6 +142,10 @@ local transformers_to = { messages.model = model.name or request_table.model messages.stream = request_table.stream or false -- explicitly set this if nil + -- handle function calling translation from OpenAI format + messages.tools = request_table.tools and to_tools(request_table.tools) + messages.tool_choice = request_table.tool_choice + return messages, "application/json", nil end, @@ -243,16 +291,37 @@ local transformers_from = { local function extract_text_from_content(content) local buf = buffer.new() for i, v in ipairs(content) do - if i ~= 1 then - buf:put("\n") + if v.text then + if i ~= 1 then + buf:put("\n") + end + buf:put(v.text) end - - buf:put(v.text) end return buf:tostring() end + local function extract_tools_from_content(content) + local tools + for i, v in ipairs(content) do + if v.type == "tool_use" then + tools = tools or {} + + table.insert(tools, { + id = v.id, + type = "function", + ['function'] = { + name = v.name, + arguments = cjson.encode(v.input), + } + }) + end + end + + return tools + end + if response_table.content then local usage = response_table.usage @@ -275,13 +344,14 @@ local transformers_from = { message = { role = "assistant", content = extract_text_from_content(response_table.content), + tool_calls = extract_tools_from_content(response_table.content) }, finish_reason = response_table.stop_reason, }, }, usage = usage, model = response_table.model, - object = "chat.content", + object = "chat.completion", } return cjson.encode(res) @@ 
-330,7 +400,10 @@ function _M.from_format(response_string, model_info, route_type) end local ok, response_string, err, metadata = pcall(transform, response_string, model_info, route_type) - if not ok or err then + if not ok then + err = response_string + end + if err then return nil, fmt("transformation failed from type %s://%s: %s", model_info.provider, route_type, @@ -488,7 +561,7 @@ function _M.configure_request(conf) end end - -- if auth_param_location is "form", it will have already been set in a pre-request hook + -- if auth_param_location is "body", it will have already been set in a pre-request hook return true, nil end diff --git a/kong/llm/drivers/bedrock.lua b/kong/llm/drivers/bedrock.lua index a4221a97c81..4ffd1c94b6a 100644 --- a/kong/llm/drivers/bedrock.lua +++ b/kong/llm/drivers/bedrock.lua @@ -7,7 +7,6 @@ local ai_shared = require("kong.llm.drivers.shared") local socket_url = require("socket.url") local string_gsub = string.gsub local table_insert = table.insert -local string_lower = string.lower local signer = require("resty.aws.request.sign") local ai_plugin_ctx = require("kong.llm.plugin.ctx") -- @@ -21,6 +20,14 @@ local _OPENAI_ROLE_MAPPING = { ["system"] = "assistant", ["user"] = "user", ["assistant"] = "assistant", + ["tool"] = "user", +} + +local _OPENAI_STOP_REASON_MAPPING = { + ["max_tokens"] = "length", + ["end_turn"] = "stop", + ["tool_use"] = "tool_calls", + ["guardrail_intervened"] = "guardrail_intervened", } _M.bedrock_unsupported_system_role_patterns = { @@ -40,18 +47,81 @@ local function to_bedrock_generation_config(request_table) } end +local function to_bedrock_guardrail_config(guardrail_config) + return guardrail_config -- may be nil; this is handled +end + +-- this is a placeholder and is archaic now, +-- leave it in for backwards compatibility local function to_additional_request_fields(request_table) return { request_table.bedrock.additionalModelRequestFields } end +-- this is a placeholder and is archaic now, +-- leave it in for backwards compatibility local function to_tool_config(request_table) return { request_table.bedrock.toolConfig } end +local function to_tools(in_tools) + if not in_tools then + return nil + end + + local out_tools + + for i, v in ipairs(in_tools) do + if v['function'] then + out_tools = out_tools or {} + + out_tools[i] = { + toolSpec = { + name = v['function'].name, + description = v['function'].description, + inputSchema = { + json = v['function'].parameters, + }, + }, + } + end + end + + return out_tools +end + +local function from_tool_call_response(content) + if not content then return nil end + + local tools_used + + for _, t in ipairs(content) do + if t.toolUse then + tools_used = tools_used or {} + + local arguments + if t.toolUse['input'] and next(t.toolUse['input']) then + arguments = cjson.encode(t.toolUse['input']) + end + + tools_used[#tools_used+1] = { + -- set explicit numbering to ensure ordering in later modifications + ['function'] = { + arguments = arguments, + name = t.toolUse.name, + }, + id = t.toolUse.toolUseId, + type = "function", + } + end + end + + return tools_used +end + local function handle_stream_event(event_t, model_info, route_type) local new_event, metadata @@ -114,7 +184,7 @@ local function handle_stream_event(event_t, model_info, route_type) [1] = { delta = {}, index = 0, - finish_reason = body.stopReason, + finish_reason = _OPENAI_STOP_REASON_MAPPING[body.stopReason] or "stop", logprobs = cjson.null, }, }, @@ -145,7 +215,7 @@ local function handle_stream_event(event_t, model_info, 
route_type) end local function to_bedrock_chat_openai(request_table, model_info, route_type) - if not request_table then -- try-catch type mechanism + if not request_table then local err = "empty request table received for transformation" ngx.log(ngx.ERR, "[bedrock] ", err) return nil, nil, err @@ -165,10 +235,56 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type) if v.role and v.role == "system" then system_prompts[#system_prompts+1] = { text = v.content } + elseif v.role and v.role == "tool" then + local tool_execution_content, err = cjson.decode(v.content) + if err then + return nil, nil, "failed to decode function response arguments, not JSON format" + end + + local content = { + { + toolResult = { + toolUseId = v.tool_call_id, + content = { + { + json = tool_execution_content, + }, + }, + status = v.status, + }, + }, + } + + new_r.messages = new_r.messages or {} + table_insert(new_r.messages, { + role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user' + content = content, + }) + else local content if type(v.content) == "table" then content = v.content + + elseif v.tool_calls and (type(v.tool_calls) == "table") then + for k, tool in ipairs(v.tool_calls) do + local inputs, err = cjson.decode(tool['function'].arguments) + if err then + return nil, nil, "failed to decode function response arguments from assistant's message, not JSON format" + end + + content = { + { + toolUse = { + toolUseId = tool.id, + name = tool['function'].name, + input = inputs, + }, + }, + } + + end + else content = { { @@ -199,10 +315,20 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type) end new_r.inferenceConfig = to_bedrock_generation_config(request_table) + new_r.guardrailConfig = to_bedrock_guardrail_config(request_table.guardrailConfig) + -- backwards compatibility new_r.toolConfig = request_table.bedrock and request_table.bedrock.toolConfig and to_tool_config(request_table) + + if request_table.tools + and type(request_table.tools) == "table" + and #request_table.tools > 0 then + + new_r.toolConfig = new_r.toolConfig or {} + new_r.toolConfig.tools = to_tools(request_table.tools) + end new_r.additionalModelRequestFields = request_table.bedrock and request_table.bedrock.additionalModelRequestFields @@ -220,23 +346,22 @@ local function from_bedrock_chat_openai(response, model_info, route_type) return nil, err_client end - -- messages/choices table is only 1 size, so don't need to static allocate local client_response = {} client_response.choices = {} if response.output and response.output.message and response.output.message.content - and #response.output.message.content > 0 - and response.output.message.content[1].text then + and #response.output.message.content > 0 then - client_response.choices[1] = { + client_response.choices[1] = { index = 0, message = { role = "assistant", - content = response.output.message.content[1].text, + content = response.output.message.content[1].text, -- may be nil + tool_calls = from_tool_call_response(response.output.message.content), }, - finish_reason = string_lower(response.stopReason), + finish_reason = _OPENAI_STOP_REASON_MAPPING[response.stopReason] or "stop", } client_response.object = "chat.completion" client_response.model = model_info.name @@ -256,6 +381,8 @@ local function from_bedrock_chat_openai(response, model_info, route_type) } end + client_response.trace = response.trace -- may be nil, **do not** map to cjson.null + return cjson.encode(client_response) end @@ -277,7 +404,10 @@ function 
_M.from_format(response_string, model_info, route_type) end local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type) - if not ok or err then + if not ok then + err = response_string + end + if err then return nil, fmt("transformation failed from type %s://%s: %s", model_info.provider, route_type, @@ -295,7 +425,7 @@ function _M.to_format(request_table, model_info, route_type) -- do nothing return request_table, nil, nil end - + if not transformers_to[route_type] then return nil, nil, fmt("no transformer for %s://%s", model_info.provider, route_type) end @@ -475,4 +605,13 @@ function _M.configure_request(conf, aws_sdk) return true end + +if _G._TEST then + -- export locals for testing + _M._to_tools = to_tools + _M._to_bedrock_chat_openai = to_bedrock_chat_openai + _M._from_tool_call_response = from_tool_call_response +end + + return _M diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua index 5f29a928bb0..38acdefc68a 100644 --- a/kong/llm/drivers/cohere.lua +++ b/kong/llm/drivers/cohere.lua @@ -4,6 +4,7 @@ local _M = {} local cjson = require("cjson.safe") local fmt = string.format local ai_shared = require("kong.llm.drivers.shared") +local openai_driver = require("kong.llm.drivers.openai") local socket_url = require "socket.url" local table_new = require("table.new") local string_gsub = string.gsub @@ -260,6 +261,37 @@ local transformers_from = { and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens), } messages.usage = stats + + elseif response_table.message then + -- this is a "co.chat" + + messages.choices[1] = { + index = 0, + message = { + role = "assistant", + content = response_table.message.tool_plan or response_table.message.content, + tool_calls = response_table.message.tool_calls + }, + finish_reason = response_table.finish_reason, + } + messages.object = "chat.completion" + messages.model = model_info.name + messages.id = response_table.id + + local stats = { + completion_tokens = response_table.usage + and response_table.usage.billed_units + and response_table.usage.billed_units.output_tokens, + + prompt_tokens = response_table.usage + and response_table.usage.billed_units + and response_table.usage.billed_units.input_tokens, + + total_tokens = response_table.usage + and response_table.usage.billed_units + and (response_table.usage.billed_units.output_tokens + response_table.usage.billed_units.input_tokens), + } + messages.usage = stats else -- probably a fault return nil, "'text' or 'generations' missing from cohere response body" @@ -343,7 +375,10 @@ function _M.from_format(response_string, model_info, route_type) end local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type) - if not ok or err then + if not ok then + err = response_string + end + if err then return nil, fmt("transformation failed from type %s://%s: %s", model_info.provider, route_type, @@ -357,6 +392,10 @@ end function _M.to_format(request_table, model_info, route_type) ngx.log(ngx.DEBUG, "converting from kong type to ", model_info.provider, "/", route_type) + if request_table.tools then + return openai_driver.to_format(request_table, model_info, route_type) + end + if route_type == "preserve" then -- do nothing return request_table, nil, nil @@ -497,7 +536,7 @@ function _M.configure_request(conf) end end - -- if auth_param_location is "form", it will have already been set in a pre-request hook + -- if 
auth_param_location is "body", it will have already been set in a pre-request hook return true, nil end diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua index 71e223a9fc3..1cbdd6095e5 100644 --- a/kong/llm/drivers/gemini.lua +++ b/kong/llm/drivers/gemini.lua @@ -10,6 +10,7 @@ local buffer = require("string.buffer") local table_insert = table.insert local string_lower = string.lower local ai_plugin_ctx = require("kong.llm.plugin.ctx") +local ai_plugin_base = require("kong.llm.plugin.base") -- -- globals @@ -33,6 +34,14 @@ local function to_gemini_generation_config(request_table) } end +local function is_content_safety_failure(content) + return content + and content.candidates + and #content.candidates > 0 + and content.candidates[1].finishReason + and content.candidates[1].finishReason == "SAFETY" +end + local function is_response_content(content) return content and content.candidates @@ -43,6 +52,25 @@ local function is_response_content(content) and content.candidates[1].content.parts[1].text end +local function is_tool_content(content) + return content + and content.candidates + and #content.candidates > 0 + and content.candidates[1].content + and content.candidates[1].content.parts + and #content.candidates[1].content.parts > 0 + and content.candidates[1].content.parts[1].functionCall +end + +local function is_function_call_message(message) + return message + and message.role + and message.role == "assistant" + and message.tool_calls + and type(message.tool_calls) == "table" + and #message.tool_calls > 0 +end + local function handle_stream_event(event_t, model_info, route_type) -- discard empty frames, it should either be a random new line, or comment if (not event_t.data) or (#event_t.data < 1) then @@ -84,10 +112,32 @@ local function handle_stream_event(event_t, model_info, route_type) end end +local function to_tools(in_tools) + if not in_tools then + return nil + end + + local out_tools + + for i, v in ipairs(in_tools) do + if v['function'] then + out_tools = out_tools or { + [1] = { + function_declarations = {} + } + } + + out_tools[1].function_declarations[i] = v['function'] + end + end + + return out_tools +end + local function to_gemini_chat_openai(request_table, model_info, route_type) - if request_table then -- try-catch type mechanism - local new_r = {} + local new_r = {} + if request_table then if request_table.messages and #request_table.messages > 0 then local system_prompt @@ -97,18 +147,60 @@ local function to_gemini_chat_openai(request_table, model_info, route_type) if v.role and v.role == "system" then system_prompt = system_prompt or buffer.new() system_prompt:put(v.content or "") + + elseif v.role and v.role == "tool" then + -- handle tool execution output + table_insert(new_r.contents, { + role = "function", + parts = { + { + function_response = { + response = { + content = { + v.content, + }, + }, + name = "get_product_info", + }, + }, + }, + }) + + elseif is_function_call_message(v) then + -- treat specific 'assistant function call' tool execution input message + local function_calls = {} + for i, t in ipairs(v.tool_calls) do + function_calls[i] = { + function_call = { + name = t['function'].name, + }, + } + end + + table_insert(new_r.contents, { + role = "function", + parts = function_calls, + }) + else -- for any other role, just construct the chat history as 'parts.text' type new_r.contents = new_r.contents or {} + + local part = v.content + if type(v.content) == "string" then + part = { + text = v.content + } + end + 
table_insert(new_r.contents, { role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user' parts = { - { - text = v.content or "" - }, + part, }, }) end + end -- This was only added in Gemini 1.5 @@ -128,42 +220,19 @@ local function to_gemini_chat_openai(request_table, model_info, route_type) new_r.generationConfig = to_gemini_generation_config(request_table) - return new_r, "application/json", nil + -- handle function calling translation from OpenAI format + new_r.tools = request_table.tools and to_tools(request_table.tools) + new_r.tool_config = request_table.tool_config end - local new_r = {} - - if request_table.messages and #request_table.messages > 0 then - local system_prompt - - for i, v in ipairs(request_table.messages) do - - -- for 'system', we just concat them all into one Gemini instruction - if v.role and v.role == "system" then - system_prompt = system_prompt or buffer.new() - system_prompt:put(v.content or "") - else - -- for any other role, just construct the chat history as 'parts.text' type - new_r.contents = new_r.contents or {} - table_insert(new_r.contents, { - role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user' - parts = { - { - text = v.content or "" - }, - }, - }) - end - end - end - - new_r.generationConfig = to_gemini_generation_config(request_table) - return new_r, "application/json", nil end local function from_gemini_chat_openai(response, model_info, route_type) - local response, err = cjson.decode(response) + local err + if response and (type(response) == "string") then + response, err = cjson.decode(response) + end if err then local err_client = "failed to decode response from Gemini" @@ -175,20 +244,48 @@ local function from_gemini_chat_openai(response, model_info, route_type) local messages = {} messages.choices = {} - if response.candidates - and #response.candidates > 0 - and is_response_content(response) then + if response.candidates and #response.candidates > 0 then + -- for transformer plugins only + if is_content_safety_failure(response) and + (ai_plugin_base.has_filter_executed("ai-request-transformer-transform-request") or + ai_plugin_base.has_filter_executed("ai-response-transformer-transform-response")) then - messages.choices[1] = { - index = 0, - message = { - role = "assistant", - content = response.candidates[1].content.parts[1].text, - }, - finish_reason = string_lower(response.candidates[1].finishReason), - } - messages.object = "chat.completion" - messages.model = model_info.name + local err = "transformation generation candidate breached Gemini content safety" + ngx.log(ngx.ERR, err) + + return nil, err + + elseif is_response_content(response) then + messages.choices[1] = { + index = 0, + message = { + role = "assistant", + content = response.candidates[1].content.parts[1].text, + }, + finish_reason = string_lower(response.candidates[1].finishReason), + } + messages.object = "chat.completion" + messages.model = model_info.name + + elseif is_tool_content(response) then + local function_call_responses = response.candidates[1].content.parts + for i, v in ipairs(function_call_responses) do + messages.choices[i] = { + index = 0, + message = { + role = "assistant", + tool_calls = { + { + ['function'] = { + name = v.functionCall.name, + arguments = cjson.encode(v.functionCall.args), + }, + }, + }, + }, + } + end + end -- process analytics if response.usageMetadata then @@ -199,15 +296,7 @@ local function from_gemini_chat_openai(response, model_info, route_type) } end - elseif response.candidates - and #response.candidates 
> 0 - and response.candidates[1].finishReason - and response.candidates[1].finishReason == "SAFETY" then - local err = "transformation generation candidate breached Gemini content safety" - ngx.log(ngx.ERR, err) - return nil, err - - else-- probably a server fault or other unexpected response + else -- probably a server fault or other unexpected response local err = "no generation candidates received from Gemini, or max_tokens too short" ngx.log(ngx.ERR, err) return nil, err @@ -235,7 +324,10 @@ function _M.from_format(response_string, model_info, route_type) end local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type) - if not ok or err then + if not ok then + err = response_string + end + if err then return nil, fmt("transformation failed from type %s://%s: %s", model_info.provider, route_type, @@ -472,4 +564,12 @@ function _M.configure_request(conf, identity_interface) return true end + +if _G._TEST then + -- export locals for testing + _M._to_tools = to_tools + _M._from_gemini_chat_openai = from_gemini_chat_openai +end + + return _M diff --git a/kong/llm/drivers/llama2.lua b/kong/llm/drivers/llama2.lua index 25a18f91edb..788fcf9ba93 100644 --- a/kong/llm/drivers/llama2.lua +++ b/kong/llm/drivers/llama2.lua @@ -165,7 +165,10 @@ function _M.from_format(response_string, model_info, route_type) model_info, route_type ) - if not ok or err then + if not ok then + err = response_string + end + if err then return nil, fmt("transformation failed from type %s://%s: %s", model_info.provider, route_type, err or "unexpected_error") end diff --git a/kong/llm/drivers/mistral.lua b/kong/llm/drivers/mistral.lua index ad558ccd5f4..396e379e597 100644 --- a/kong/llm/drivers/mistral.lua +++ b/kong/llm/drivers/mistral.lua @@ -47,7 +47,10 @@ function _M.from_format(response_string, model_info, route_type) model_info, route_type ) - if not ok or err then + if not ok then + err = response_string + end + if err then return nil, fmt("transformation failed from type %s://%s/%s: %s", model_info.provider, route_type, model_info.options.mistral_version, err or "unexpected_error") end diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index f6c99b246b5..549b51ea04a 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -74,7 +74,10 @@ function _M.from_format(response_string, model_info, route_type) end local ok, response_string, err = pcall(transformers_from[route_type], response_string, model_info) - if not ok or err then + if not ok then + err = response_string + end + if err then return nil, fmt("transformation failed from type %s://%s: %s", model_info.provider, route_type, diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 19c17537287..139cc739d97 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -460,6 +460,10 @@ function _M.to_ollama(request_table, model) input.stream = request_table.stream or false -- for future capability input.model = model.name or request_table.name + -- handle function calling translation from Ollama format + input.tools = request_table.tools + input.tool_choice = request_table.tool_choice + if model.options then input.options = {} @@ -522,7 +526,7 @@ function _M.from_ollama(response_string, model_info, route_type) output.object = "chat.completion" output.choices = { { - finish_reason = stop_reason, + finish_reason = response_table.finish_reason or stop_reason, index = 0, message = response_table.message, } diff --git 
a/kong/llm/init.lua b/kong/llm/init.lua index 2afa28da2f0..da53b1be7a2 100644 --- a/kong/llm/init.lua +++ b/kong/llm/init.lua @@ -4,6 +4,10 @@ local cjson = require("cjson.safe") local fmt = string.format local EMPTY = require("kong.tools.table").EMPTY +local EMPTY_ARRAY = { + EMPTY, +} + -- The module table local _M = { @@ -142,10 +146,10 @@ do if err then return nil, err end - + -- run the shared logging/analytics/auth function ai_shared.pre_request(self.conf, ai_request) - + -- send it to the ai service local ai_response, _, err = self.driver.subrequest(ai_request, self.conf, http_opts, false, self.identity_interface) if err then @@ -166,7 +170,7 @@ do return nil, "failed to convert AI response to JSON: " .. err end - local new_request_body = ((ai_response.choices or EMPTY)[1].message or EMPTY).content + local new_request_body = (((ai_response.choices or EMPTY_ARRAY)[1] or EMPTY).message or EMPTY).content if not new_request_body then return nil, "no 'choices' in upstream AI service response" end diff --git a/kong/llm/plugin/base.lua b/kong/llm/plugin/base.lua index 24c873ff01d..3a23c78a873 100644 --- a/kong/llm/plugin/base.lua +++ b/kong/llm/plugin/base.lua @@ -194,6 +194,10 @@ function _M.register_filter(f) return f end +function _M.has_filter_executed(name) + return ngx.ctx.ai_executed_filters and ngx.ctx.ai_executed_filters[name] +end + -- enable the filter for current sub plugin function _M:enable(filter) if type(filter) ~= "table" or not filter.NAME then diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua index 644235911a9..7c0814ad7c4 100644 --- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua +++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua @@ -37,6 +37,24 @@ local SAMPLE_LLM_V1_CHAT_WITH_SOME_OPTS = { another_extra_param = 0.5, } +local SAMPLE_LLM_V1_CHAT_WITH_GUARDRAILS = { + messages = { + [1] = { + role = "system", + content = "You are a mathematician." + }, + [2] = { + role = "assistant", + content = "What is 1 + 1?" + }, + }, + guardrailConfig = { + guardrailIdentifier = "yu5xwvfp4sud", + guardrailVersion = "1", + trace = "enabled", + }, +} + local SAMPLE_DOUBLE_FORMAT = { messages = { [1] = { @@ -51,6 +69,80 @@ local SAMPLE_DOUBLE_FORMAT = { prompt = "Hi world", } +local SAMPLE_OPENAI_TOOLS_REQUEST = { + messages = { + [1] = { + role = "user", + content = "Is the NewPhone in stock?" + }, + }, + tools = { + [1] = { + ['function'] = { + parameters = { + ['type'] = "object", + properties = { + product_name = { + ['type'] = "string", + }, + }, + required = { + "product_name", + }, + }, + name = "check_stock", + description = "Check a product is in stock." + }, + ['type'] = "function", + }, + }, +} + +local SAMPLE_GEMINI_TOOLS_RESPONSE = { + candidates = { { + content = { + role = "model", + parts = { { + functionCall = { + name = "sql_execute", + args = { + product_name = "NewPhone" + } + } + } } + }, + finishReason = "STOP", + } }, +} + +local SAMPLE_BEDROCK_TOOLS_RESPONSE = { + metrics = { + latencyMs = 3781 + }, + output = { + message = { + content = { { + text = "Certainly! To calculate the sum of 121, 212, and 313, we can use the \"sumArea\" function that's available to us." 
+ }, { + toolUse = { + input = { + areas = { 121, 212, 313 } + }, + name = "sumArea", + toolUseId = "tooluse_4ZakZPY9SiWoKWrAsY7_eg" + } + } }, + role = "assistant" + } + }, + stopReason = "tool_use", + usage = { + inputTokens = 410, + outputTokens = 115, + totalTokens = 525 + } +} + local FORMATS = { openai = { ["llm/v1/chat"] = { @@ -775,4 +867,149 @@ describe(PLUGIN_NAME .. ": (unit)", function() end) end) + describe("gemini tools", function() + local gemini_driver + + setup(function() + _G._TEST = true + package.loaded["kong.llm.drivers.gemini"] = nil + gemini_driver = require("kong.llm.drivers.gemini") + end) + + teardown(function() + _G._TEST = nil + end) + + it("transforms openai tools to gemini tools GOOD", function() + local gemini_tools = gemini_driver._to_tools(SAMPLE_OPENAI_TOOLS_REQUEST.tools) + + assert.not_nil(gemini_tools) + assert.same(gemini_tools, { + { + function_declarations = { + { + description = "Check a product is in stock.", + name = "check_stock", + parameters = { + properties = { + product_name = { + type = "string" + } + }, + required = { + "product_name" + }, + type = "object" + } + } + } + } + }) + end) + + it("transforms openai tools to gemini tools NO_TOOLS", function() + local gemini_tools = gemini_driver._to_tools(SAMPLE_LLM_V1_CHAT) + + assert.is_nil(gemini_tools) + end) + + it("transforms openai tools to gemini tools NIL", function() + local gemini_tools = gemini_driver._to_tools(nil) + + assert.is_nil(gemini_tools) + end) + + it("transforms gemini tools to openai tools GOOD", function() + local openai_tools = gemini_driver._from_gemini_chat_openai(SAMPLE_GEMINI_TOOLS_RESPONSE, {}, "llm/v1/chat") + + assert.not_nil(openai_tools) + + openai_tools = cjson.decode(openai_tools) + assert.same(openai_tools.choices[1].message.tool_calls[1]['function'], { + name = "sql_execute", + arguments = "{\"product_name\":\"NewPhone\"}" + }) + end) + end) + + describe("bedrock tools", function() + local bedrock_driver + + setup(function() + _G._TEST = true + package.loaded["kong.llm.drivers.bedrock"] = nil + bedrock_driver = require("kong.llm.drivers.bedrock") + end) + + teardown(function() + _G._TEST = nil + end) + + it("transforms openai tools to bedrock tools GOOD", function() + local bedrock_tools = bedrock_driver._to_tools(SAMPLE_OPENAI_TOOLS_REQUEST.tools) + + assert.not_nil(bedrock_tools) + assert.same(bedrock_tools, { + { + toolSpec = { + description = "Check a product is in stock.", + inputSchema = { + json = { + properties = { + product_name = { + type = "string" + } + }, + required = { + "product_name" + }, + type = "object" + } + }, + name = "check_stock" + } + } + }) + end) + + it("transforms openai tools to bedrock tools NO_TOOLS", function() + local bedrock_tools = bedrock_driver._to_tools(SAMPLE_LLM_V1_CHAT) + + assert.is_nil(bedrock_tools) + end) + + it("transforms openai tools to bedrock tools NIL", function() + local bedrock_tools = bedrock_driver._to_tools(nil) + + assert.is_nil(bedrock_tools) + end) + + it("transforms bedrock tools to openai tools GOOD", function() + local openai_tools = bedrock_driver._from_tool_call_response(SAMPLE_BEDROCK_TOOLS_RESPONSE.output.message.content) + + assert.not_nil(openai_tools) + + assert.same(openai_tools[1]['function'], { + name = "sumArea", + arguments = "{\"areas\":[121,212,313]}" + }) + end) + + it("transforms guardrails into bedrock generation config", function() + local model_info = { + route_type = "llm/v1/chat", + name = "some-model", + provider = "bedrock", + } + local bedrock_guardrails = 
bedrock_driver._to_bedrock_chat_openai(SAMPLE_LLM_V1_CHAT_WITH_GUARDRAILS, model_info, "llm/v1/chat") + + assert.not_nil(bedrock_guardrails) + + assert.same(bedrock_guardrails.guardrailConfig, { + ['guardrailIdentifier'] = 'yu5xwvfp4sud', + ['guardrailVersion'] = '1', + ['trace'] = 'enabled', + }) + end) + end) end) diff --git a/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua b/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua index dd568ed79b9..00239f04f1f 100644 --- a/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua @@ -570,7 +570,7 @@ for _, strategy in helpers.all_strategies() do -- check this is in the 'kong' response format -- assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") assert.equals(json.model, "claude-2.1") - assert.equals(json.object, "chat.content") + assert.equals(json.object, "chat.completion") assert.equals(r.headers["X-Kong-LLM-Model"], "anthropic/claude-2.1") assert.is_table(json.choices) @@ -597,7 +597,7 @@ for _, strategy in helpers.all_strategies() do -- check this is in the 'kong' response format -- assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") assert.equals(json.model, "claude-2.1") - assert.equals(json.object, "chat.content") + assert.equals(json.object, "chat.completion") assert.equals(r.headers["X-Kong-LLM-Model"], "anthropic/claude-2.1") assert.is_table(json.choices) @@ -642,7 +642,7 @@ for _, strategy in helpers.all_strategies() do -- check this is in the 'kong' response format -- assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") assert.equals(json.model, "claude-2.1") - assert.equals(json.object, "chat.content") + assert.equals(json.object, "chat.completion") assert.equals(r.headers["X-Kong-LLM-Model"], "anthropic/claude-2.1") assert.is_table(json.choices) @@ -669,7 +669,7 @@ for _, strategy in helpers.all_strategies() do -- check this is in the 'kong' response format -- assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2") assert.equals(json.model, "claude-2.1") - assert.equals(json.object, "chat.content") + assert.equals(json.object, "chat.completion") assert.equals(r.headers["X-Kong-LLM-Model"], "anthropic/claude-2.1") assert.is_table(json.choices)
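For reference, below is a minimal sketch (not part of the diff) of how the exported Bedrock translation helper added above can be exercised with an OpenAI-format "tool" result message, following the same _G._TEST pattern used in 01-unit_spec.lua. The tool_call_id value and the JSON payload are made up for illustration; this only shows the tool-result branch of to_bedrock_chat_openai and is not a definitive test case.

_G._TEST = true
package.loaded["kong.llm.drivers.bedrock"] = nil
local bedrock_driver = require("kong.llm.drivers.bedrock")

-- an OpenAI-format conversation containing a tool execution result;
-- the content of a "tool" message must be a JSON string, otherwise the
-- translation returns the "failed to decode function response arguments" error
local request = {
  messages = {
    { role = "user", content = "Is the NewPhone in stock?" },
    {
      role = "tool",
      tool_call_id = "tooluse_example_id",  -- hypothetical id from a previous tool_calls turn
      content = '{"product_name":"NewPhone","in_stock":true}',
    },
  },
}

local model_info = { route_type = "llm/v1/chat", name = "some-model", provider = "bedrock" }

local bedrock_request = bedrock_driver._to_bedrock_chat_openai(request, model_info, "llm/v1/chat")
-- bedrock_request.messages[2].content[1].toolResult.content[1].json now holds the decoded
-- tool output, and the message role is mapped to "user" per _OPENAI_ROLE_MAPPING

_G._TEST = nil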