From 0349b1ac06f8933a7bb3c985eaccaca464ec225c Mon Sep 17 00:00:00 2001 From: Thijs Schreijer Date: Tue, 7 May 2024 10:13:29 +0200 Subject: [PATCH] chore(ai-prompt-decorator): improve error handling and cleanup (#12907) * chore(ai-prompt-decorator): improve error handling and cleanup * chore(test): standard test filenames * chore(ai-prompt-guard): improve error handling and cleanup (cherry picked from commit b1b5ac99e06a78f9c50aa3bd7c3dc0a20f14cede) --- changelog/unreleased/kong/cleanup_ai.yml | 4 + kong/plugins/ai-prompt-decorator/handler.lua | 56 +- kong/plugins/ai-prompt-guard/handler.lua | 161 +++-- kong/plugins/ai-prompt-guard/schema.lua | 8 +- .../41-ai-prompt-decorator/01-unit_spec.lua | 27 +- .../02-integration_spec.lua | 87 +-- ...{00_config_spec.lua => 00-config_spec.lua} | 7 + .../{01_unit_spec.lua => 01-unit_spec.lua} | 55 +- .../02-integration_spec.lua | 636 ++++++++++-------- 9 files changed, 609 insertions(+), 432 deletions(-) create mode 100644 changelog/unreleased/kong/cleanup_ai.yml rename spec/03-plugins/42-ai-prompt-guard/{00_config_spec.lua => 00-config_spec.lua} (99%) rename spec/03-plugins/42-ai-prompt-guard/{01_unit_spec.lua => 01-unit_spec.lua} (76%) diff --git a/changelog/unreleased/kong/cleanup_ai.yml b/changelog/unreleased/kong/cleanup_ai.yml new file mode 100644 index 000000000000..61e9c2c70dc4 --- /dev/null +++ b/changelog/unreleased/kong/cleanup_ai.yml @@ -0,0 +1,4 @@ +message: | + Cleanup some AI plugins, and improve errorhandling. +type: bugfix +scope: Plugin diff --git a/kong/plugins/ai-prompt-decorator/handler.lua b/kong/plugins/ai-prompt-decorator/handler.lua index 891ea77f4515..7103ce5903b4 100644 --- a/kong/plugins/ai-prompt-decorator/handler.lua +++ b/kong/plugins/ai-prompt-decorator/handler.lua @@ -1,13 +1,12 @@ -local _M = {} - --- imports -local kong_meta = require "kong.meta" -local new_tab = require("table.new") +local new_tab = require("table.new") local EMPTY = {} --- -_M.PRIORITY = 772 -_M.VERSION = kong_meta.version + +local plugin = { + PRIORITY = 772, + VERSION = require("kong.meta").version +} + local function bad_request(msg) @@ -15,14 +14,16 @@ local function bad_request(msg) return kong.response.exit(400, { error = { message = msg } }) end -function _M.execute(request, conf) + + +-- Adds the prompts to the request prepend/append. +-- @tparam table request The deserialized JSON body of the request +-- @tparam table conf The plugin configuration +-- @treturn table The decorated request (same table, content updated) +local function execute(request, conf) local prepend = conf.prompts.prepend or EMPTY local append = conf.prompts.append or EMPTY - if #prepend == 0 and #append == 0 then - return request, nil - end - local old_messages = request.messages local new_messages = new_tab(#append + #prepend + #old_messages, 0) request.messages = new_messages @@ -44,29 +45,34 @@ function _M.execute(request, conf) new_messages[n] = { role = msg.role, content = msg.content } end - return request, nil + return request end -function _M:access(conf) + + +function plugin:access(conf) kong.service.request.enable_buffering() kong.ctx.shared.ai_prompt_decorated = true -- future use -- if plugin ordering was altered, receive the "decorated" request - local request, err = kong.request.get_body("application/json") - if err then + local request = kong.request.get_body("application/json") + if type(request) ~= "table" then return bad_request("this LLM route only supports application/json requests") end - if not request.messages or #request.messages < 1 then + if #(request.messages or EMPTY) < 1 then return bad_request("this LLM route only supports llm/chat type requests") end - local decorated_request, err = self.execute(request, conf) - if err then - return bad_request(err) - end - - kong.service.request.set_body(decorated_request, "application/json") + kong.service.request.set_body(execute(request, conf), "application/json") end -return _M + + +if _G._TEST then + -- only if we're testing export this function (using a different name!) + plugin._execute = execute +end + + +return plugin diff --git a/kong/plugins/ai-prompt-guard/handler.lua b/kong/plugins/ai-prompt-guard/handler.lua index 50c64315f712..321fefad2024 100644 --- a/kong/plugins/ai-prompt-guard/handler.lua +++ b/kong/plugins/ai-prompt-guard/handler.lua @@ -1,112 +1,145 @@ -local _M = {} - --- imports -local kong_meta = require "kong.meta" -local buffer = require("string.buffer") +local buffer = require("string.buffer") local ngx_re_find = ngx.re.find --- +local EMPTY = {} -_M.PRIORITY = 771 -_M.VERSION = kong_meta.version -local function bad_request(msg, reveal_msg_to_client) - -- don't let users know 'ai-prompt-guard' is in use - kong.log.info(msg) - if not reveal_msg_to_client then - msg = "bad request" - end + +local plugin = { + PRIORITY = 771, + VERSION = require("kong.meta").version +} + + + +local function bad_request(msg) + kong.log.debug(msg) return kong.response.exit(400, { error = { message = msg } }) end -function _M.execute(request, conf) - local user_prompt - -- concat all 'user' prompts into one string, if conversation history must be checked - if request.messages and not conf.allow_all_conversation_history then - local buf = buffer.new() - for _, v in ipairs(request.messages) do - if v.role == "user" then - buf:put(v.content) +local execute do + local bad_format_error = "ai-prompt-guard only supports llm/v1/chat or llm/v1/completions prompts" + + -- Checks the prompt for the given patterns. + -- _Note_: if a regex fails, it returns a 500, and exits the request. + -- @tparam table request The deserialized JSON body of the request + -- @tparam table conf The plugin configuration + -- @treturn[1] table The decorated request (same table, content updated) + -- @treturn[2] nil + -- @treturn[2] string The error message + function execute(request, conf) + local user_prompt + + -- concat all 'user' prompts into one string, if conversation history must be checked + if type(request.messages) == "table" and not conf.allow_all_conversation_history then + local buf = buffer.new() + + for _, v in ipairs(request.messages) do + if type(v.role) ~= "string" then + return nil, bad_format_error + end + if v.role == "user" then + if type(v.content) ~= "string" then + return nil, bad_format_error + end + buf:put(v.content) + end end - end - user_prompt = buf:get() - - elseif request.messages then - -- just take the trailing 'user' prompt - for _, v in ipairs(request.messages) do - if v.role == "user" then - user_prompt = v.content + user_prompt = buf:get() + + elseif type(request.messages) == "table" then + -- just take the trailing 'user' prompt + for _, v in ipairs(request.messages) do + if type(v.role) ~= "string" then + return nil, bad_format_error + end + if v.role == "user" then + if type(v.content) ~= "string" then + return nil, bad_format_error + end + user_prompt = v.content + end end - end - elseif request.prompt then - user_prompt = request.prompt + elseif type(request.prompt) == "string" then + user_prompt = request.prompt - else - return nil, "ai-prompt-guard only supports llm/v1/chat or llm/v1/completions prompts" - end + else + return nil, bad_format_error + end + + if not user_prompt then + return nil, "no 'prompt' or 'messages' received" + end - if not user_prompt then - return nil, "no 'prompt' or 'messages' received" - end - -- check the prompt for explcit ban patterns - if conf.deny_patterns and #conf.deny_patterns > 0 then - for _, v in ipairs(conf.deny_patterns) do + -- check the prompt for explcit ban patterns + for _, v in ipairs(conf.deny_patterns or EMPTY) do -- check each denylist; if prompt matches it, deny immediately local m, _, err = ngx_re_find(user_prompt, v, "jo") if err then - return nil, "bad regex execution for: " .. v + -- regex failed, that's an error by the administrator + kong.log.err("bad regex pattern '", v ,"', failed to execute: ", err) + return kong.response.exit(500) elseif m then return nil, "prompt pattern is blocked" end end - end - -- if any allow_patterns specified, make sure the prompt matches one of them - if conf.allow_patterns and #conf.allow_patterns > 0 then - local valid = false - for _, v in ipairs(conf.allow_patterns) do + if #(conf.allow_patterns or EMPTY) == 0 then + -- no allow_patterns, so we're good + return true + end + + -- if any allow_patterns specified, make sure the prompt matches one of them + for _, v in ipairs(conf.allow_patterns or EMPTY) do -- check each denylist; if prompt matches it, deny immediately local m, _, err = ngx_re_find(user_prompt, v, "jo") if err then - return nil, "bad regex execution for: " .. v + -- regex failed, that's an error by the administrator + kong.log.err("bad regex pattern '", v ,"', failed to execute: ", err) + return kong.response.exit(500) elseif m then - valid = true - break + return true -- got a match so is allowed, exit early end end - if not valid then - return false, "prompt doesn't match any allowed pattern" - end + return false, "prompt doesn't match any allowed pattern" end - - return true, nil end -function _M:access(conf) + + +function plugin:access(conf) kong.service.request.enable_buffering() kong.ctx.shared.ai_prompt_guarded = true -- future use -- if plugin ordering was altered, receive the "decorated" request - local request, err = kong.request.get_body("application/json") - - if err then - return bad_request("this LLM route only supports application/json requests", true) + local request = kong.request.get_body("application/json") + if type(request) ~= "table" then + return bad_request("this LLM route only supports application/json requests") end -- run access handler - local ok, err = self.execute(request, conf) + local ok, err = execute(request, conf) if not ok then - return bad_request(err, false) + kong.log.debug(err) + return bad_request("bad request") -- don't let users know 'ai-prompt-guard' is in use end end -return _M + + +if _G._TEST then + -- only if we're testing export this function (using a different name!) + plugin._execute = execute +end + + +return plugin diff --git a/kong/plugins/ai-prompt-guard/schema.lua b/kong/plugins/ai-prompt-guard/schema.lua index d5a8e8aa1bd9..9c0172752bdb 100644 --- a/kong/plugins/ai-prompt-guard/schema.lua +++ b/kong/plugins/ai-prompt-guard/schema.lua @@ -8,9 +8,9 @@ return { type = "record", fields = { { allow_patterns = { - description = "Array of valid patterns, or valid questions from the 'user' role in chat.", + description = "Array of valid regex patterns, or valid questions from the 'user' role in chat.", type = "array", - default = {}, + required = false, len_max = 10, elements = { type = "string", @@ -18,9 +18,9 @@ return { len_max = 500, }}}, { deny_patterns = { - description = "Array of invalid patterns, or invalid questions from the 'user' role in chat.", + description = "Array of invalid regex patterns, or invalid questions from the 'user' role in chat.", type = "array", - default = {}, + required = false, len_max = 10, elements = { type = "string", diff --git a/spec/03-plugins/41-ai-prompt-decorator/01-unit_spec.lua b/spec/03-plugins/41-ai-prompt-decorator/01-unit_spec.lua index 9477d0c29912..57254beaa3ad 100644 --- a/spec/03-plugins/41-ai-prompt-decorator/01-unit_spec.lua +++ b/spec/03-plugins/41-ai-prompt-decorator/01-unit_spec.lua @@ -1,8 +1,5 @@ local PLUGIN_NAME = "ai-prompt-decorator" --- imports -local access_handler = require("kong.plugins.ai-prompt-decorator.handler") --- local function deepcopy(o, seen) seen = seen or {} @@ -108,8 +105,24 @@ local injector_conf_both = { }, } + + describe(PLUGIN_NAME .. ": (unit)", function() + local access_handler + + setup(function() + _G._TEST = true + package.loaded["kong.plugins.ai-prompt-decorator.handler"] = nil + access_handler = require("kong.plugins.ai-prompt-decorator.handler") + end) + + teardown(function() + _G._TEST = nil + end) + + + describe("chat v1 operations", function() it("adds messages to the start of the array", function() @@ -121,12 +134,13 @@ describe(PLUGIN_NAME .. ": (unit)", function() table.insert(expected_request_copy.messages, 2, injector_conf_prepend.prompts.prepend[2]) table.insert(expected_request_copy.messages, 3, injector_conf_prepend.prompts.prepend[3]) - local decorated_request, err = access_handler.execute(request_copy, injector_conf_prepend) + local decorated_request, err = access_handler._execute(request_copy, injector_conf_prepend) assert.is_nil(err) assert.same(decorated_request, expected_request_copy) end) + it("adds messages to the end of the array", function() local request_copy = deepcopy(general_chat_request) local expected_request_copy = deepcopy(general_chat_request) @@ -135,12 +149,13 @@ describe(PLUGIN_NAME .. ": (unit)", function() table.insert(expected_request_copy.messages, #expected_request_copy.messages + 1, injector_conf_append.prompts.append[1]) table.insert(expected_request_copy.messages, #expected_request_copy.messages + 1, injector_conf_append.prompts.append[2]) - local decorated_request, err = access_handler.execute(request_copy, injector_conf_append) + local decorated_request, err = access_handler._execute(request_copy, injector_conf_append) assert.is_nil(err) assert.same(expected_request_copy, decorated_request) end) + it("adds messages to the start and the end of the array", function() local request_copy = deepcopy(general_chat_request) local expected_request_copy = deepcopy(general_chat_request) @@ -152,7 +167,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() table.insert(expected_request_copy.messages, #expected_request_copy.messages + 1, injector_conf_both.prompts.append[1]) table.insert(expected_request_copy.messages, #expected_request_copy.messages + 1, injector_conf_both.prompts.append[2]) - local decorated_request, err = access_handler.execute(request_copy, injector_conf_both) + local decorated_request, err = access_handler._execute(request_copy, injector_conf_both) assert.is_nil(err) assert.same(expected_request_copy, decorated_request) diff --git a/spec/03-plugins/41-ai-prompt-decorator/02-integration_spec.lua b/spec/03-plugins/41-ai-prompt-decorator/02-integration_spec.lua index 4fdc8b025324..1f89724821b6 100644 --- a/spec/03-plugins/41-ai-prompt-decorator/02-integration_spec.lua +++ b/spec/03-plugins/41-ai-prompt-decorator/02-integration_spec.lua @@ -1,5 +1,4 @@ local helpers = require "spec.helpers" -local cjson = require "cjson" local PLUGIN_NAME = "ai-prompt-decorator" @@ -53,60 +52,62 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then })) end) + lazy_teardown(function() helpers.stop_kong() end) + before_each(function() client = helpers.proxy_client() end) + after_each(function() if client then client:close() end end) - describe("request", function() - it("sends in a non-chat message", function() - local r = client:get("/request", { - headers = { - host = "test1.com", - ["Content-Type"] = "application/json", - }, - body = [[ - { - "anything": [ - { - "random": "data" - } - ] - }]], - method = "POST", - }) - - local body = assert.res_status(400, r) - local json = cjson.decode(body) - - assert.same(json, { error = { message = "this LLM route only supports llm/chat type requests" }}) - end) - - it("sends in an empty messages array", function() - local r = client:get("/request", { - headers = { - host = "test1.com", - ["Content-Type"] = "application/json", - }, - body = [[ - { - "messages": [] - }]], - method = "POST", - }) - - local body = assert.res_status(400, r) - local json = cjson.decode(body) - - assert.same(json, { error = { message = "this LLM route only supports llm/chat type requests" }}) - end) + + + it("blocks a non-chat message", function() + local r = client:get("/request", { + headers = { + host = "test1.com", + ["Content-Type"] = "application/json", + }, + body = [[ + { + "anything": [ + { + "random": "data" + } + ] + }]], + method = "POST", + }) + + assert.response(r).has.status(400) + local json = assert.response(r).has.jsonbody() + assert.same(json, { error = { message = "this LLM route only supports llm/chat type requests" }}) + end) + + + it("blocks an empty messages array", function() + local r = client:get("/request", { + headers = { + host = "test1.com", + ["Content-Type"] = "application/json", + }, + body = [[ + { + "messages": [] + }]], + method = "POST", + }) + + assert.response(r).has.status(400) + local json = assert.response(r).has.jsonbody() + assert.same(json, { error = { message = "this LLM route only supports llm/chat type requests" }}) end) end) diff --git a/spec/03-plugins/42-ai-prompt-guard/00_config_spec.lua b/spec/03-plugins/42-ai-prompt-guard/00-config_spec.lua similarity index 99% rename from spec/03-plugins/42-ai-prompt-guard/00_config_spec.lua rename to spec/03-plugins/42-ai-prompt-guard/00-config_spec.lua index 69dd6c82c52d..103ed45840a5 100644 --- a/spec/03-plugins/42-ai-prompt-guard/00_config_spec.lua +++ b/spec/03-plugins/42-ai-prompt-guard/00-config_spec.lua @@ -11,7 +11,10 @@ local validate do end end + + describe(PLUGIN_NAME .. ": (schema)", function() + it("won't allow both allow_patterns and deny_patterns to be unset", function() local config = { allow_all_conversation_history = true, @@ -24,6 +27,7 @@ describe(PLUGIN_NAME .. ": (schema)", function() assert.equal("at least one of these fields must be non-empty: 'config.allow_patterns', 'config.deny_patterns'", err["@entity"][1]) end) + it("won't allow both allow_patterns and deny_patterns to be empty arrays", function() local config = { allow_all_conversation_history = true, @@ -38,6 +42,7 @@ describe(PLUGIN_NAME .. ": (schema)", function() assert.equal("at least one of these fields must be non-empty: 'config.allow_patterns', 'config.deny_patterns'", err["@entity"][1]) end) + it("won't allow patterns that are too long", function() local config = { allow_all_conversation_history = true, @@ -53,6 +58,7 @@ describe(PLUGIN_NAME .. ": (schema)", function() assert.same({ config = {allow_patterns = { [1] = "length must be at most 500" }}}, err) end) + it("won't allow too many array items", function() local config = { allow_all_conversation_history = true, @@ -77,4 +83,5 @@ describe(PLUGIN_NAME .. ": (schema)", function() assert.not_nil(err) assert.same({ config = {allow_patterns = "length must be at most 10" }}, err) end) + end) diff --git a/spec/03-plugins/42-ai-prompt-guard/01_unit_spec.lua b/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua similarity index 76% rename from spec/03-plugins/42-ai-prompt-guard/01_unit_spec.lua rename to spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua index ac82622755cd..9007376fcf03 100644 --- a/spec/03-plugins/42-ai-prompt-guard/01_unit_spec.lua +++ b/spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua @@ -1,5 +1,5 @@ local PLUGIN_NAME = "ai-prompt-guard" -local access_handler = require("kong.plugins.ai-prompt-guard.handler") + local general_chat_request = { @@ -114,62 +114,84 @@ local both_patterns_no_history = { allow_all_conversation_history = true, } + + describe(PLUGIN_NAME .. ": (unit)", function() + local access_handler + + setup(function() + _G._TEST = true + package.loaded["kong.plugins.ai-prompt-guard.handler"] = nil + access_handler = require("kong.plugins.ai-prompt-guard.handler") + end) + + teardown(function() + _G._TEST = nil + end) + + describe("chat operations", function() it("allows request when only conf.allow_patterns is set", function() - local ok, err = access_handler.execute(general_chat_request, allow_patterns_no_history) + local ok, err = access_handler._execute(general_chat_request, allow_patterns_no_history) assert.is_truthy(ok) assert.is_nil(err) end) + it("allows request when only conf.deny_patterns is set, and pattern should not match", function() - local ok, err = access_handler.execute(general_chat_request, deny_patterns_no_history) + local ok, err = access_handler._execute(general_chat_request, deny_patterns_no_history) assert.is_truthy(ok) assert.is_nil(err) end) + it("denies request when only conf.allow_patterns is set, and pattern should not match", function() - local ok, err = access_handler.execute(denied_chat_request, allow_patterns_no_history) + local ok, err = access_handler._execute(denied_chat_request, allow_patterns_no_history) assert.is_falsy(ok) assert.equal(err, "prompt doesn't match any allowed pattern") end) + it("denies request when only conf.deny_patterns is set, and pattern should match", function() - local ok, err = access_handler.execute(denied_chat_request, deny_patterns_no_history) + local ok, err = access_handler._execute(denied_chat_request, deny_patterns_no_history) assert.is_falsy(ok) assert.equal(err, "prompt pattern is blocked") end) + it("allows request when both conf.allow_patterns and conf.deny_patterns are set, and pattern matches allow", function() - local ok, err = access_handler.execute(general_chat_request, both_patterns_no_history) + local ok, err = access_handler._execute(general_chat_request, both_patterns_no_history) assert.is_truthy(ok) assert.is_nil(err) end) + it("denies request when both conf.allow_patterns and conf.deny_patterns are set, and pattern matches neither", function() - local ok, err = access_handler.execute(neither_allowed_nor_denied_chat_request, both_patterns_no_history) + local ok, err = access_handler._execute(neither_allowed_nor_denied_chat_request, both_patterns_no_history) assert.is_falsy(ok) assert.equal(err, "prompt doesn't match any allowed pattern") end) + it("denies request when only conf.allow_patterns is set and previous chat history should not match", function() - local ok, err = access_handler.execute(general_chat_request_with_history, allow_patterns_with_history) + local ok, err = access_handler._execute(general_chat_request_with_history, allow_patterns_with_history) assert.is_falsy(ok) assert.equal(err, "prompt doesn't match any allowed pattern") end) + it("denies request when only conf.deny_patterns is set and previous chat history should match", function() - local ok, err = access_handler.execute(general_chat_request_with_history, deny_patterns_with_history) + local ok, err = access_handler._execute(general_chat_request_with_history, deny_patterns_with_history) assert.is_falsy(ok) assert.equal(err, "prompt pattern is blocked") @@ -181,35 +203,39 @@ describe(PLUGIN_NAME .. ": (unit)", function() describe("completions operations", function() it("allows request when only conf.allow_patterns is set", function() - local ok, err = access_handler.execute(general_completions_request, allow_patterns_no_history) + local ok, err = access_handler._execute(general_completions_request, allow_patterns_no_history) assert.is_truthy(ok) assert.is_nil(err) end) + it("allows request when only conf.deny_patterns is set, and pattern should not match", function() - local ok, err = access_handler.execute(general_completions_request, deny_patterns_no_history) + local ok, err = access_handler._execute(general_completions_request, deny_patterns_no_history) assert.is_truthy(ok) assert.is_nil(err) end) + it("denies request when only conf.allow_patterns is set, and pattern should not match", function() - local ok, err = access_handler.execute(denied_completions_request, allow_patterns_no_history) + local ok, err = access_handler._execute(denied_completions_request, allow_patterns_no_history) assert.is_falsy(ok) assert.equal(err, "prompt doesn't match any allowed pattern") end) + it("denies request when only conf.deny_patterns is set, and pattern should match", function() - local ok, err = access_handler.execute(denied_completions_request, deny_patterns_no_history) + local ok, err = access_handler._execute(denied_completions_request, deny_patterns_no_history) assert.is_falsy(ok) assert.equal("prompt pattern is blocked", err) end) + it("denies request when both conf.allow_patterns and conf.deny_patterns are set, and pattern matches neither", function() - local ok, err = access_handler.execute(neither_allowed_nor_denied_completions_request, both_patterns_no_history) + local ok, err = access_handler._execute(neither_allowed_nor_denied_completions_request, both_patterns_no_history) assert.is_falsy(ok) assert.equal(err, "prompt doesn't match any allowed pattern") @@ -217,5 +243,4 @@ describe(PLUGIN_NAME .. ": (unit)", function() end) - end) diff --git a/spec/03-plugins/42-ai-prompt-guard/02-integration_spec.lua b/spec/03-plugins/42-ai-prompt-guard/02-integration_spec.lua index 05258f659cc9..c31aa0cf0241 100644 --- a/spec/03-plugins/42-ai-prompt-guard/02-integration_spec.lua +++ b/spec/03-plugins/42-ai-prompt-guard/02-integration_spec.lua @@ -121,6 +121,37 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then } -- + local bad_regex_allow = bp.routes:insert({ + paths = { "~/bad-regex-allow$" }, + }) + + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = bad_regex_allow.id }, + config = { + deny_patterns = { + [1] = "[]", + }, + allow_all_conversation_history = false, + }, + } + + local bad_regex_deny = bp.routes:insert({ + paths = { "~/bad-regex-deny$" }, + }) + + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = bad_regex_deny.id }, + config = { + deny_patterns = { + [1] = "[]", + }, + allow_all_conversation_history = false, + }, + } + + -- assert(helpers.start_kong({ database = strategy, nginx_conf = "spec/fixtures/custom_nginx.template", @@ -129,300 +160,355 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then })) end) + + lazy_teardown(function() helpers.stop_kong() end) + before_each(function() client = helpers.proxy_client() end) + after_each(function() if client then client:close() end end) - describe("request", function() - -- both - it("allows message with 'allow' and 'deny' set, with history", function() - local r = client:get("/permit-history", { - headers = { - ["Content-Type"] = "application/json", - }, - body = [[ - { - "messages": [ - { - "role": "system", - "content": "You run a cheese shop." - }, - { - "role": "user", - "content": "I think that cheddar is the best cheese." - }, - { - "role": "assistant", - "content": "No, brie is the best cheese." - }, - { - "role": "user", - "content": "Why brie?" - } - ] - } - ]], - method = "POST", - }) - - -- the body is just an echo, don't need to test it - assert.res_status(200, r) - end) - - it("allows message with 'allow' and 'deny' set, without history", function() - local r = client:get("/block-history", { - headers = { - ["Content-Type"] = "application/json", - }, - body = [[ - { - "messages": [ - { - "role": "system", - "content": "You run a cheese shop." - }, - { - "role": "user", - "content": "I think that cheddar is the best cheese." - }, - { - "role": "assistant", - "content": "No, brie is the best cheese." - }, - { - "role": "user", - "content": "Why brie?" - } - ] - } - ]], - method = "POST", - }) - - assert.res_status(200, r) - end) - - it("blocks message with 'allow' and 'deny' set, with history", function() - local r = client:get("/permit-history", { - headers = { - ["Content-Type"] = "application/json", - }, - body = [[ - { - "messages": [ - { - "role": "system", - "content": "You run a cheese shop." - }, - { - "role": "user", - "content": "I think that cheddar or edam are the best cheeses." - }, - { - "role": "assistant", - "content": "No, brie is the best cheese." - }, - { - "role": "user", - "content": "Why?" - } - ] - } - ]], - method = "POST", - }) - - assert.res_status(400, r) - end) - -- - -- allows only - it("allows message with 'allow' only set, with history", function() - local r = client:get("/allow-only-permit-history", { - headers = { - ["Content-Type"] = "application/json", - }, - body = [[ - { - "messages": [ - { - "role": "system", - "content": "You run a cheese shop." - }, - { - "role": "user", - "content": "I think that brie is the best cheese." - }, - { - "role": "assistant", - "content": "No, cheddar is the best cheese." - }, - { - "role": "user", - "content": "Why cheddar?" - } - ] - } - ]], - method = "POST", - }) - - assert.res_status(200, r) - end) - - it("allows message with 'allow' only set, without history", function() - local r = client:get("/allow-only-block-history", { - headers = { - ["Content-Type"] = "application/json", - }, - body = [[ - { - "messages": [ - { - "role": "system", - "content": "You run a cheese shop." - }, - { - "role": "user", - "content": "I think that brie is the best cheese." - }, - { - "role": "assistant", - "content": "No, cheddar is the best cheese." - }, - { - "role": "user", - "content": "Why cheddar?" - } - ] - } - ]], - method = "POST", - }) - - assert.res_status(200, r) - end) + -- both + it("allows message with 'allow' and 'deny' set, with history", function() + local r = client:get("/permit-history", { + headers = { + ["Content-Type"] = "application/json", + }, + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + }, + { + "role": "user", + "content": "I think that cheddar is the best cheese." + }, + { + "role": "assistant", + "content": "No, brie is the best cheese." + }, + { + "role": "user", + "content": "Why brie?" + } + ] + } + ]], + method = "POST", + }) - -- denies only - it("allows message with 'deny' only set, permit history", function() - local r = client:get("/deny-only-permit-history", { - headers = { - ["Content-Type"] = "application/json", - }, - - -- this will be permitted, because the BAD PHRASE is only in chat history, - -- which the developer "controls" - body = [[ - { - "messages": [ - { - "role": "system", - "content": "You run a cheese shop." - }, - { - "role": "user", - "content": "I think that leicester is the best cheese." - }, - { - "role": "assistant", - "content": "No, cheddar is the best cheese." - }, - { - "role": "user", - "content": "Why cheddar?" - } - ] - } - ]], - method = "POST", - }) - - assert.res_status(200, r) - end) - - it("blocks message with 'deny' only set, permit history", function() - local r = client:get("/deny-only-permit-history", { - headers = { - ["Content-Type"] = "application/json", - }, + -- the body is just an echo, don't need to test it + assert.response(r).has.status(200) + end) - -- this will be blocks, because the BAD PHRASE is in the latest chat message, - -- which the user "controls" - body = [[ - { - "messages": [ - { - "role": "system", - "content": "You run a cheese shop." - }, - { - "role": "user", - "content": "I think that leicester is the best cheese." - }, - { - "role": "assistant", - "content": "No, edam is the best cheese." - }, - { - "role": "user", - "content": "Why edam?" - } - ] - } - ]], - method = "POST", - }) - - assert.res_status(400, r) - end) - - it("blocks message with 'deny' only set, scan history", function() - local r = client:get("/deny-only-block-history", { - headers = { - ["Content-Type"] = "application/json", - }, - -- this will NOT be permitted, because the BAD PHRASE is in chat history, - -- as specified by the Kong admins - body = [[ - { - "messages": [ - { - "role": "system", - "content": "You run a cheese shop." - }, - { - "role": "user", - "content": "I think that leicester is the best cheese." - }, - { - "role": "assistant", - "content": "No, cheddar is the best cheese." - }, - { - "role": "user", - "content": "Why cheddar?" - } - ] - } - ]], - method = "POST", - }) - - assert.res_status(400, r) - end) - -- + it("allows message with 'allow' and 'deny' set, without history", function() + local r = client:get("/block-history", { + headers = { + ["Content-Type"] = "application/json", + }, + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + }, + { + "role": "user", + "content": "I think that cheddar is the best cheese." + }, + { + "role": "assistant", + "content": "No, brie is the best cheese." + }, + { + "role": "user", + "content": "Why brie?" + } + ] + } + ]], + method = "POST", + }) + assert.response(r).has.status(200) end) + + + it("blocks message with 'allow' and 'deny' set, with history", function() + local r = client:get("/permit-history", { + headers = { + ["Content-Type"] = "application/json", + }, + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + }, + { + "role": "user", + "content": "I think that cheddar or edam are the best cheeses." + }, + { + "role": "assistant", + "content": "No, brie is the best cheese." + }, + { + "role": "user", + "content": "Why?" + } + ] + } + ]], + method = "POST", + }) + + assert.response(r).has.status(400) + end) + -- + + + -- allows only + it("allows message with 'allow' only set, with history", function() + local r = client:get("/allow-only-permit-history", { + headers = { + ["Content-Type"] = "application/json", + }, + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + }, + { + "role": "user", + "content": "I think that brie is the best cheese." + }, + { + "role": "assistant", + "content": "No, cheddar is the best cheese." + }, + { + "role": "user", + "content": "Why cheddar?" + } + ] + } + ]], + method = "POST", + }) + + assert.response(r).has.status(200) + end) + + + it("allows message with 'allow' only set, without history", function() + local r = client:get("/allow-only-block-history", { + headers = { + ["Content-Type"] = "application/json", + }, + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + }, + { + "role": "user", + "content": "I think that brie is the best cheese." + }, + { + "role": "assistant", + "content": "No, cheddar is the best cheese." + }, + { + "role": "user", + "content": "Why cheddar?" + } + ] + } + ]], + method = "POST", + }) + + assert.response(r).has.status(200) + end) + + + -- denies only + it("allows message with 'deny' only set, permit history", function() + local r = client:get("/deny-only-permit-history", { + headers = { + ["Content-Type"] = "application/json", + }, + + -- this will be permitted, because the BAD PHRASE is only in chat history, + -- which the developer "controls" + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + }, + { + "role": "user", + "content": "I think that leicester is the best cheese." + }, + { + "role": "assistant", + "content": "No, cheddar is the best cheese." + }, + { + "role": "user", + "content": "Why cheddar?" + } + ] + } + ]], + method = "POST", + }) + + assert.response(r).has.status(200) + end) + + + it("blocks message with 'deny' only set, permit history", function() + local r = client:get("/deny-only-permit-history", { + headers = { + ["Content-Type"] = "application/json", + }, + + -- this will be blocks, because the BAD PHRASE is in the latest chat message, + -- which the user "controls" + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + }, + { + "role": "user", + "content": "I think that leicester is the best cheese." + }, + { + "role": "assistant", + "content": "No, edam is the best cheese." + }, + { + "role": "user", + "content": "Why edam?" + } + ] + } + ]], + method = "POST", + }) + + assert.response(r).has.status(400) + end) + + + it("blocks message with 'deny' only set, scan history", function() + local r = client:get("/deny-only-block-history", { + headers = { + ["Content-Type"] = "application/json", + }, + + -- this will NOT be permitted, because the BAD PHRASE is in chat history, + -- as specified by the Kong admins + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + }, + { + "role": "user", + "content": "I think that leicester is the best cheese." + }, + { + "role": "assistant", + "content": "No, cheddar is the best cheese." + }, + { + "role": "user", + "content": "Why cheddar?" + } + ] + } + ]], + method = "POST", + }) + + assert.response(r).has.status(400) + end) + + + it("returns a 500 on a bad regex in allow list", function() + local r = client:get("/bad-regex-allow", { + headers = { + ["Content-Type"] = "application/json", + }, + + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + } + ] + } + ]], + method = "POST", + }) + + assert.response(r).has.status(500) + end) + + + it("returns a 500 on a bad regex in deny list", function() + local r = client:get("/bad-regex-deny", { + headers = { + ["Content-Type"] = "application/json", + }, + + body = [[ + { + "messages": [ + { + "role": "system", + "content": "You run a cheese shop." + } + ] + } + ]], + method = "POST", + }) + + assert.response(r).has.status(500) + end) + end) end end