Skip to content

Commit

Permalink
feat(ai-prompt-gaurd): add match_all_roles option to match non-user
Browse files Browse the repository at this point in the history
message
  • Loading branch information
fffonion committed Jul 11, 2024
1 parent 8e3a665 commit 409cf9c
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 11 deletions.
3 changes: 3 additions & 0 deletions changelog/unreleased/kong/feat-ai-prompt-guard-all-roles.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
message: "**AI-Prompt-Guard**: add `match_all_roles` option to allow match all roles in addition to `user`."
type: feature
scope: Plugin
1 change: 1 addition & 0 deletions kong/clustering/compat/removed_fields.lua
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ return {
"max_request_body_size",
},
ai_prompt_guard = {
"match_all_roles",
"max_request_body_size",
},
ai_prompt_template = {
Expand Down
6 changes: 3 additions & 3 deletions kong/plugins/ai-prompt-guard/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ local execute do
-- _Note_: if a regex fails, it returns a 500, and exits the request.
-- @tparam table request The deserialized JSON body of the request
-- @tparam table conf The plugin configuration
-- @treturn[1] table The decorated request (same table, content updated)
-- @treturn[2] string The error message
-- @treturn table The decorated request (same table, content updated)
-- @treturn nil|string The error message
function execute(request, conf)
local collected_prompts
local messages = request.messages
Expand All @@ -44,7 +44,7 @@ local execute do
if type(v.role) ~= "string" then
return nil, bad_format_error
end
if v.role == "user" then
if v.role == "user" or conf.match_all_roles then
if type(v.content) ~= "string" then
return nil, bad_format_error
end
Expand Down
14 changes: 11 additions & 3 deletions kong/plugins/ai-prompt-guard/schema.lua
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,23 @@ return {
type = "integer",
default = 8 * 1024,
gt = 0,
description = "max allowed body size allowed to be introspected",}
},
description = "max allowed body size allowed to be introspected" } },
{ match_all_roles = {
description = "If true, will match all roles in addition to 'user' role in conversation history.",
type = "boolean",
required = true,
default = false } },
}
}
}
},
entity_checks = {
{
at_least_one_of = { "config.allow_patterns", "config.deny_patterns" },
}
},
{ conditional = {
if_field = "config.match_all_roles", if_match = { eq = true },
then_field = "config.allow_all_conversation_history", then_match = { eq = false },
} },
}
}
24 changes: 24 additions & 0 deletions spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,30 @@ describe("CP/DP config compat transformations #" .. strategy, function()
-- cleanup
admin.plugins:remove({ id = ai_response_transformer.id })
end)

it("[ai-prompt-guard] sets unsupported match_all_roles to nil or defaults", function()
-- [[ 3.8.x ]] --
local ai_prompt_guard = admin.plugins:insert {
name = "ai-prompt-guard",
enabled = true,
config = {
allow_patterns = { "a" },
allow_all_conversation_history = false,
match_all_roles = true,
max_request_body_size = 8192,
},
}
-- ]]

local expected = cycle_aware_deep_copy(ai_prompt_guard)
expected.config.match_all_roles = nil
expected.config.max_request_body_size = nil

do_assert(uuid(), "3.7.0", expected)

-- cleanup
admin.plugins:remove({ id = ai_prompt_guard.id })
end)
end)

describe("www-authenticate header in plugins (realm config)", function()
Expand Down
18 changes: 18 additions & 0 deletions spec/03-plugins/42-ai-prompt-guard/00-config_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,22 @@ describe(PLUGIN_NAME .. ": (schema)", function()
assert.same({ config = {allow_patterns = "length must be at most 10" }}, err)
end)

it("allow_all_conversation_history needs to be false if match_all_roles is set to true", function()
local config = {
allow_patterns = { "wat" },
allow_all_conversation_history = true,
match_all_roles = true,
}

local ok, err = validate(config)

assert.is_falsy(ok)
assert.not_nil(err)
assert.same({
["@entity"] = {
[1] = 'failed conditional validation given value of field \'config.match_all_roles\'' },
["config"] = {
["allow_all_conversation_history"] = 'value must be false' }}, err)
end)

end)
20 changes: 15 additions & 5 deletions spec/03-plugins/42-ai-prompt-guard/01-unit_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ describe(PLUGIN_NAME .. ": (unit)", function()


for _, request_type in ipairs({"chat", "completions"}) do

describe(request_type .. " operations", function()
it("allows a user request when nothing is set", function()
-- deny_pattern in this case should be made to have no effect
Expand All @@ -82,7 +83,13 @@ describe(PLUGIN_NAME .. ": (unit)", function()
assert.is_nil(err)
end)

-- only chat has history
-- match_all_roles require history
for _, has_history in ipairs({false, request_type == "chat" and true or nil}) do
for _, match_all_roles in ipairs({false, has_history and true or nil}) do

-- we only have user or not user, so testing "assistant" is not necessary
local role = match_all_roles and "system" or "user"

describe("conf.allow_patterns is set", function()
for _, has_deny_patterns in ipairs({true, false}) do
Expand All @@ -92,7 +99,7 @@ describe(PLUGIN_NAME .. ": (unit)", function()

it("allows a matching user request" .. test_description, function()
-- deny_pattern in this case should be made to have no effect
local ctx = create_request(request_type):append_message("user", "pattern")
local ctx = create_request(request_type):append_message(role, "pattern")

if has_history then
ctx:append_message("user", "no match")
Expand All @@ -103,6 +110,7 @@ describe(PLUGIN_NAME .. ": (unit)", function()
},
deny_patterns = has_deny_patterns and {"deny match"} or nil,
allow_all_conversation_history = not has_history,
match_all_roles = match_all_roles,
})

assert.is_truthy(ok)
Expand All @@ -117,7 +125,7 @@ describe(PLUGIN_NAME .. ": (unit)", function()
ctx:append_message("user", "no match")
else
-- if we are ignoring history, actually put a matched message in history to test edge case
ctx:append_message("user", "pattern"):append_message("user", "no match")
ctx:append_message(role, "pattern"):append_message("user", "no match")
end

local ok, err = access_handler._execute(ctx, {
Expand All @@ -126,6 +134,7 @@ describe(PLUGIN_NAME .. ": (unit)", function()
},
deny_patterns = has_deny_patterns and {"deny match"} or nil,
allow_all_conversation_history = not has_history,
match_all_roles = match_all_roles,
})

assert.is_falsy(ok)
Expand All @@ -143,7 +152,7 @@ describe(PLUGIN_NAME .. ": (unit)", function()

it("denies a matching user request" .. test_description, function()
-- allow_pattern in this case should be made to have no effect
local ctx = create_request(request_type):append_message("user", "pattern")
local ctx = create_request(request_type):append_message(role, "pattern")

if has_history then
ctx:append_message("user", "no match")
Expand All @@ -162,13 +171,13 @@ describe(PLUGIN_NAME .. ": (unit)", function()

it("allows unmatched user request" .. test_description, function()
-- allow_pattern in this case should be made to have no effect
local ctx = create_request(request_type):append_message("user", "allow match")
local ctx = create_request(request_type):append_message(role, "allow match")

if has_history then
ctx:append_message("user", "no match")
else
-- if we are ignoring history, actually put a matched message in history to test edge case
ctx:append_message("user", "pattern"):append_message("user", "allow match")
ctx:append_message(role, "pattern"):append_message(role, "allow match")
end

local ok, err = access_handler._execute(ctx, {
Expand All @@ -185,6 +194,7 @@ describe(PLUGIN_NAME .. ": (unit)", function()
end -- for for _, has_allow_patterns in ipairs({true, false}) do
end)

end -- for _, match_all_role in ipairs(false, true)) do
end -- for _, has_history in ipairs({true, false}) do
end)
end -- for _, request_type in ipairs({"chat", "completions"}) do
Expand Down

0 comments on commit 409cf9c

Please sign in to comment.