-
Notifications
You must be signed in to change notification settings - Fork 2.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: ai-content-moderation plugin #11541
Changes from 22 commits
c80f932
d29f928
8ac0738
7ffa489
5214d0d
2fe1ea2
5cecf2a
cf08f04
460081e
d475eb0
6350d15
f713f87
e16a823
b21b64b
ee34e37
12529f0
57c59ab
093d7a9
7b52fa5
6e3bee2
6a2d575
6bb399c
1f4528d
ef16068
f6f3451
f3672fa
8447d6d
0949327
3a616e1
4c1f2a6
5b1be91
a3e47b2
81958e4
3da00a2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,176 @@ | ||||||||
-- | ||||||||
-- Licensed to the Apache Software Foundation (ASF) under one or more | ||||||||
-- contributor license agreements. See the NOTICE file distributed with | ||||||||
-- this work for additional information regarding copyright ownership. | ||||||||
-- The ASF licenses this file to You under the Apache License, Version 2.0 | ||||||||
-- (the "License"); you may not use this file except in compliance with | ||||||||
-- the License. You may obtain a copy of the License at | ||||||||
-- | ||||||||
-- http://www.apache.org/licenses/LICENSE-2.0 | ||||||||
-- | ||||||||
-- Unless required by applicable law or agreed to in writing, software | ||||||||
-- distributed under the License is distributed on an "AS IS" BASIS, | ||||||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||
-- See the License for the specific language governing permissions and | ||||||||
-- limitations under the License. | ||||||||
-- | ||||||||
local core = require("apisix.core") | ||||||||
local aws = require("resty.aws") | ||||||||
local http = require("resty.http") | ||||||||
local fetch_secrets = require("apisix.secret").fetch_secrets | ||||||||
|
||||||||
local aws_instance = aws() | ||||||||
local next = next | ||||||||
local pairs = pairs | ||||||||
local unpack = unpack | ||||||||
local type = type | ||||||||
local ipairs = ipairs | ||||||||
local require = require | ||||||||
local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
const var. sugg. this style |
||||||||
local bad_request = ngx.HTTP_BAD_REQUEST | ||||||||
|
||||||||
|
||||||||
-- JSON schema for the AWS Comprehend provider settings.
-- The credentials and the region are mandatory; `endpoint` is optional and,
-- when given, must start with http:// or https://.
local aws_comprehend_schema = {
    type = "object",
    required = { "access_key_id", "secret_access_key", "region", },
    properties = {
        -- static AWS credentials
        access_key_id = { type = "string" },
        secret_access_key = { type = "string" },
        -- AWS region name; also used to build the default endpoint
        region = { type = "string" },
        -- optional override of the service endpoint
        endpoint = {
            type = "string",
            pattern = [[^https?://]]
        },
    },
}
|
||||||||
-- Category names that may carry an individual threshold. The pattern is
-- anchored on both ends so only these exact names are accepted as keys.
local moderation_categories_pattern = "^(PROFANITY|HATE_SPEECH|INSULT|"..
                                      "HARASSMENT_OR_ABUSE|SEXUAL|VIOLENCE_OR_THREAT)$"

-- JSON schema of the plugin configuration.
local schema = {
    type = "object",
    properties = {
        -- moderation backend; only AWS Comprehend is available
        provider = {
            type = "object",
            properties = {
                aws_comprehend = aws_comprehend_schema
            },
            -- exactly one provider may be configured: without this limit any
            -- number of unknown keys would be accepted next to the required one
            maxProperties = 1,
            -- change to oneOf/enum while implementing support for other services
            required = { "aws_comprehend" }
        },
        -- optional per-category thresholds in [0, 1]; the rewrite handler
        -- rejects a request when a label score is above its threshold
        moderation_categories = {
            type = "object",
            patternProperties = {
                [moderation_categories_pattern] = {
                    type = "number",
                    minimum = 0,
                    maximum = 1
                }
            },
            additionalProperties = false
        },
        -- overall toxicity threshold in [0, 1]
        toxicity_level = {
            type = "number",
            minimum = 0,
            maximum = 1,
            default = 0.5
        },
        -- request format; selects the module "apisix.plugins.ai.<type>" that
        -- turns the request messages into text segments
        type = {
            type = "string",
            enum = { "openai" },
        }
    },
    required = { "provider", "type" },
}
|
||||||||
|
||||||||
-- Plugin descriptor: name, version, priority and the configuration schema.
local _M = {
    name = "ai-content-moderation",
    version = 0.1,
    priority = 1040, -- TODO: might change
    schema = schema,
}
|
||||||||
|
||||||||
--- Validate a plugin configuration against the plugin schema.
-- @param conf plugin configuration table to validate
-- @return the validation result of `core.schema.check`, followed by its
--         error value when validation fails
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    return ok, err
end
|
||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Two blank lines between functions? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed. |
||||||||
--- Rewrite-phase handler: sends the request's `messages` to AWS Comprehend
-- (`detectToxicContent`) and rejects the request when a score is above the
-- configured thresholds.
-- @param conf plugin configuration (see `schema`)
-- @param ctx  request context (not used here)
-- @return nothing when the request passes; otherwise an HTTP status code
--         (400 or 500) followed by an error message
function _M.rewrite(conf, ctx)
    conf = fetch_secrets(conf, true, conf, "")
    if not conf then
        return internal_server_error, "failed to retrieve secrets from conf"
    end

    local body, body_err = core.request.get_json_request_body_table()
    if not body then
        return bad_request, body_err
    end

    local msgs = body.messages
    if type(msgs) ~= "table" or #msgs < 1 then
        return bad_request, "messages not found in request body"
    end

    -- Read the schema-required key directly. Picking "whatever next() returns"
    -- is arbitrary as soon as the provider table holds more than one key,
    -- which was raised during review of this handler.
    local provider = conf.provider.aws_comprehend

    -- TODO support secret
    local credentials = aws_instance:Credentials({
        accessKeyId = provider.access_key_id,
        secretAccessKey = provider.secret_access_key,
        sessionToken = provider.session_token,
    })

    local default_endpoint = "https://comprehend." .. provider.region .. ".amazonaws.com"
    -- parse_uri reports failure as nil plus an error; check it before unpacking
    local parsed_uri, uri_err = http:parse_uri(provider.endpoint or default_endpoint)
    if not parsed_uri then
        core.log.error("failed to parse provider endpoint: ", uri_err)
        return internal_server_error, "invalid provider endpoint"
    end

    local scheme, host, port = unpack(parsed_uri)
    local endpoint = scheme .. "://" .. host

    -- aws_instance is created once at module level, so the two settings
    -- below are shared by every request handled by this plugin.
    aws_instance.config.endpoint = endpoint
    -- NOTE(review): certificate verification is switched off for the AWS
    -- call. It is left as-is to keep current behaviour, but it was questioned
    -- in review and should become a configurable option.
    aws_instance.config.ssl_verify = false

    local comprehend = aws_instance:Comprehend({
        credentials = credentials,
        endpoint = endpoint,
        region = provider.region,
        port = port,
    })

    -- conf.type is restricted by the schema enum, so the module name is not
    -- an arbitrary user-controlled string
    local ai_module = require("apisix.plugins.ai." .. conf.type)
    local text_segments = ai_module.create_request_text_segments(msgs)

    local res, req_err = comprehend:detectToxicContent({
        LanguageCode = "en",
        TextSegments = text_segments,
    })

    if not res then
        -- log the endpoint only: the provider table holds the credentials
        -- and a plain table is not a loggable value
        core.log.error("failed to send request to ", endpoint, ": ", req_err)
        return internal_server_error, req_err
    end

    local results = res.body and res.body.ResultList
    if type(results) ~= "table" or #results < 1 then
        return internal_server_error, "failed to get moderation results from response"
    end

    local categories = conf.moderation_categories
    -- an empty table is truthy; skip the per-label pass when nothing is configured
    local check_categories = categories ~= nil and next(categories) ~= nil

    for _, result in ipairs(results) do
        local labels = result.Labels
        if check_categories and type(labels) == "table" then
            for _, item in pairs(labels) do
                -- labels without a configured threshold are ignored
                local threshold = categories[item.Name]
                if threshold and item.Score > threshold then
                    return bad_request, "request body exceeds " .. item.Name .. " threshold"
                end
            end
        end

        local toxicity = result.Toxicity
        if type(toxicity) ~= "number" then
            -- fail closed on a malformed result instead of raising on the comparison
            return internal_server_error, "failed to get moderation results from response"
        end

        if toxicity > conf.toxicity_level then
            return bad_request, "request body exceeds toxicity threshold"
        end
    end
end
|
||||||||
return _M |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
-- | ||
-- Licensed to the Apache Software Foundation (ASF) under one or more | ||
-- contributor license agreements. See the NOTICE file distributed with | ||
-- this work for additional information regarding copyright ownership. | ||
-- The ASF licenses this file to You under the Apache License, Version 2.0 | ||
-- (the "License"); you may not use this file except in compliance with | ||
-- the License. You may obtain a copy of the License at | ||
-- | ||
-- http://www.apache.org/licenses/LICENSE-2.0 | ||
-- | ||
-- Unless required by applicable law or agreed to in writing, software | ||
-- distributed under the License is distributed on an "AS IS" BASIS, | ||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
-- See the License for the specific language governing permissions and | ||
-- limitations under the License. | ||
-- | ||
local core = require("apisix.core") | ||
local ipairs = ipairs | ||
|
||
local _M = {} | ||
|
||
|
||
--- Build a list of text segments from a list of chat messages.
-- Each message contributes one `{ Text = <message content> }` entry, in the
-- same order as the input.
-- @param msgs array of message tables; the `content` field of each is used
-- @return array of tables, each with a single `Text` field
function _M.create_request_text_segments(msgs)
    local segments = {}
    for _, message in ipairs(msgs) do
        local segment = { Text = message.content }
        core.table.insert_tail(segments, segment)
    end
    return segments
end
|
||
return _M |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
then we can remove
local aws = require("resty.aws")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done