diff --git a/Makefile b/Makefile index 545a21e4f29f..36447611ca79 100644 --- a/Makefile +++ b/Makefile @@ -380,6 +380,10 @@ install: runtime $(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers $(ENV_INSTALL) apisix/plugins/ai-proxy/drivers/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers + # ai-content-moderation plugin + $(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai + $(ENV_INSTALL) apisix/plugins/ai/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai + $(ENV_INSTALL) bin/apisix $(ENV_INST_BINDIR)/apisix diff --git a/apisix-master-0.rockspec b/apisix-master-0.rockspec index 913a4defe39d..fb288e751787 100644 --- a/apisix-master-0.rockspec +++ b/apisix-master-0.rockspec @@ -82,7 +82,7 @@ dependencies = { "lua-resty-t1k = 1.1.5", "brotli-ffi = 0.3-1", "lua-ffi-zlib = 0.6-0", - "api7-lua-resty-aws == 2.0.1-1", + "api7-lua-resty-aws == 2.0.2-1", } build = { diff --git a/apisix/cli/config.lua b/apisix/cli/config.lua index f5c5d8dcaf94..8baf6bbe008b 100644 --- a/apisix/cli/config.lua +++ b/apisix/cli/config.lua @@ -215,6 +215,7 @@ local _M = { "body-transformer", "ai-prompt-template", "ai-prompt-decorator", + "ai-content-moderation", "proxy-mirror", "proxy-rewrite", "workflow", diff --git a/apisix/plugins/ai-content-moderation.lua b/apisix/plugins/ai-content-moderation.lua new file mode 100644 index 000000000000..19029a65348d --- /dev/null +++ b/apisix/plugins/ai-content-moderation.lua @@ -0,0 +1,179 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +local core = require("apisix.core") +local aws_instance = require("resty.aws")() +local http = require("resty.http") +local fetch_secrets = require("apisix.secret").fetch_secrets + +local next = next +local pairs = pairs +local unpack = unpack +local type = type +local ipairs = ipairs +local require = require +local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR +local HTTP_BAD_REQUEST = ngx.HTTP_BAD_REQUEST + + +local aws_comprehend_schema = { + type = "object", + properties = { + access_key_id = { type = "string" }, + secret_access_key = { type = "string" }, + region = { type = "string" }, + endpoint = { + type = "string", + pattern = [[^https?://]] + }, + ssl_verify = { + type = "boolean", + default = true + } + }, + required = { "access_key_id", "secret_access_key", "region", } +} + +local moderation_categories_pattern = "^(PROFANITY|HATE_SPEECH|INSULT|".. 
+ "HARASSMENT_OR_ABUSE|SEXUAL|VIOLENCE_OR_THREAT)$" +local schema = { + type = "object", + properties = { + provider = { + type = "object", + properties = { + aws_comprehend = aws_comprehend_schema + }, + maxProperties = 1, + -- ensure only one provider can be configured while implementing support for + -- other providers + required = { "aws_comprehend" } + }, + moderation_categories = { + type = "object", + patternProperties = { + [moderation_categories_pattern] = { + type = "number", + minimum = 0, + maximum = 1 + } + }, + additionalProperties = false + }, + moderation_threshold = { + type = "number", + minimum = 0, + maximum = 1, + default = 0.5 + }, + llm_provider = { + type = "string", + enum = { "openai" }, + } + }, + required = { "provider", "llm_provider" }, +} + + +local _M = { + version = 0.1, + priority = 1040, -- TODO: might change + name = "ai-content-moderation", + schema = schema, +} + + +function _M.check_schema(conf) + return core.schema.check(schema, conf) +end + + +function _M.rewrite(conf, ctx) + conf = fetch_secrets(conf, true, conf, "") + if not conf then + return HTTP_INTERNAL_SERVER_ERROR, "failed to retrieve secrets from conf" + end + + local body, err = core.request.get_json_request_body_table() + if not body then + return HTTP_BAD_REQUEST, err + end + + local msgs = body.messages + if type(msgs) ~= "table" or #msgs < 1 then + return HTTP_BAD_REQUEST, "messages not found in request body" + end + + local provider = conf.provider[next(conf.provider)] + + local credentials = aws_instance:Credentials({ + accessKeyId = provider.access_key_id, + secretAccessKey = provider.secret_access_key, + sessionToken = provider.session_token, + }) + + local default_endpoint = "https://comprehend." .. provider.region .. ".amazonaws.com" + local scheme, host, port = unpack(http:parse_uri(provider.endpoint or default_endpoint)) + local endpoint = scheme .. "://" .. host + aws_instance.config.endpoint = endpoint + aws_instance.config.ssl_verify = provider.ssl_verify + + local comprehend = aws_instance:Comprehend({ + credentials = credentials, + endpoint = endpoint, + region = provider.region, + port = port, + }) + + local ai_module = require("apisix.plugins.ai." .. conf.llm_provider) + local create_request_text_segments = ai_module.create_request_text_segments + + local text_segments = create_request_text_segments(msgs) + local res, err = comprehend:detectToxicContent({ + LanguageCode = "en", + TextSegments = text_segments, + }) + + if not res then + core.log.error("failed to send request to ", provider, ": ", err) + return HTTP_INTERNAL_SERVER_ERROR, err + end + + local results = res.body and res.body.ResultList + if type(results) ~= "table" or core.table.isempty(results) then + return HTTP_INTERNAL_SERVER_ERROR, "failed to get moderation results from response" + end + + for _, result in ipairs(results) do + if conf.moderation_categories then + for _, item in pairs(result.Labels) do + if not conf.moderation_categories[item.Name] then + goto continue + end + if item.Score > conf.moderation_categories[item.Name] then + return HTTP_BAD_REQUEST, "request body exceeds " .. item.Name .. 
" threshold" + end + ::continue:: + end + end + + if result.Toxicity > conf.moderation_threshold then + return HTTP_BAD_REQUEST, "request body exceeds toxicity threshold" + end + end +end + +return _M diff --git a/apisix/plugins/ai/openai.lua b/apisix/plugins/ai/openai.lua new file mode 100644 index 000000000000..203debb7e6d1 --- /dev/null +++ b/apisix/plugins/ai/openai.lua @@ -0,0 +1,33 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +local core = require("apisix.core") +local ipairs = ipairs + +local _M = {} + + +function _M.create_request_text_segments(msgs) + local text_segments = {} + for _, msg in ipairs(msgs) do + core.table.insert_tail(text_segments, { + Text = msg.content + }) + end + return text_segments +end + +return _M diff --git a/conf/config.yaml.example b/conf/config.yaml.example index bd741b2f767b..700bb7707999 100644 --- a/conf/config.yaml.example +++ b/conf/config.yaml.example @@ -478,6 +478,7 @@ plugins: # plugin list (sorted by priority) - body-transformer # priority: 1080 - ai-prompt-template # priority: 1071 - ai-prompt-decorator # priority: 1070 + - ai-content-moderation # priority: 1040 TODO: compare priority with other ai plugins - proxy-mirror # priority: 1010 - proxy-rewrite # priority: 1008 - workflow # priority: 1006 diff --git a/docs/en/latest/config.json b/docs/en/latest/config.json index ad9c1e051523..25ec82b23dad 100644 --- a/docs/en/latest/config.json +++ b/docs/en/latest/config.json @@ -80,7 +80,8 @@ "plugins/ext-plugin-post-req", "plugins/ext-plugin-post-resp", "plugins/inspect", - "plugins/ocsp-stapling" + "plugins/ocsp-stapling", + "plugins/ai-content-moderation" ] }, { diff --git a/docs/en/latest/plugins/ai-content-moderation.md b/docs/en/latest/plugins/ai-content-moderation.md new file mode 100644 index 000000000000..781b203d9130 --- /dev/null +++ b/docs/en/latest/plugins/ai-content-moderation.md @@ -0,0 +1,253 @@ +--- +title: ai-content-moderation +keywords: + - Apache APISIX + - API Gateway + - Plugin + - ai-content-moderation +description: This document contains information about the Apache APISIX ai-content-moderation Plugin. +--- + + + +## Description + +The `ai-content-moderation` plugin processes the request body to check for toxicity and rejects the request if it exceeds the configured threshold. + +**_This plugin must be used in routes that proxy requests to LLMs only._** + +**_As of now, the plugin only supports the integration with [AWS Comprehend](https://aws.amazon.com/comprehend/) for content moderation. 
PRs for introducing support for other service providers are welcome._**
+
+## Plugin Attributes
+
+| **Field** | **Required** | **Type** | **Description** |
+| ------------------------------------------ | ------------ | -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| provider.aws_comprehend.access_key_id | Yes | String | AWS access key ID |
+| provider.aws_comprehend.secret_access_key | Yes | String | AWS secret access key |
+| provider.aws_comprehend.region | Yes | String | AWS region |
+| provider.aws_comprehend.endpoint | No | String | AWS Comprehend service endpoint. Must match the pattern `^https?://` |
+| provider.aws_comprehend.ssl_verify | No | Boolean | Whether to verify the TLS certificate of the AWS Comprehend endpoint. Default: `true` |
+| moderation_categories | No | Object | Key-value pairs of moderation categories and their score thresholds. In each pair, the key must be one of `PROFANITY`, `HATE_SPEECH`, `INSULT`, `HARASSMENT_OR_ABUSE`, `SEXUAL`, or `VIOLENCE_OR_THREAT`, and the value must be between 0 and 1 (inclusive). A request is rejected if its score for that category exceeds the configured value. |
+| moderation_threshold | No | Number | The overall toxicity score above which a request is rejected. A higher value allows more toxic content through. Range: 0 - 1. Default: 0.5 |
+| llm_provider | Yes | String | Name of the LLM provider that this route will proxy requests to. Currently only `openai` is supported. |
+
+## Example usage
+
+First initialise these shell variables:
+
+```shell
+ADMIN_API_KEY=edd1c9f034335f136f87ad84b625c8f1
+ACCESS_KEY_ID=aws-comprehend-access-key-id-here
+SECRET_ACCESS_KEY=aws-comprehend-secret-access-key-here
+OPENAI_KEY=open-ai-key-here
+```
+
+Create a route with the `ai-content-moderation` and `ai-proxy` plugins like so:
+
+```shell
+curl "http://127.0.0.1:9180/apisix/admin/routes/1" -X PUT \
+  -H "X-API-KEY: ${ADMIN_API_KEY}" \
+  -d '{
+    "uri": "/post",
+    "plugins": {
+      "ai-content-moderation": {
+        "provider": {
+          "aws_comprehend": {
+            "access_key_id": "'"$ACCESS_KEY_ID"'",
+            "secret_access_key": "'"$SECRET_ACCESS_KEY"'",
+            "region": "us-east-1"
+          }
+        },
+        "moderation_categories": {
+          "PROFANITY": 0.5
+        },
+        "llm_provider": "openai"
+      },
+      "ai-proxy": {
+        "auth": {
+          "header": {
+            "api-key": "'"$OPENAI_KEY"'"
+          }
+        },
+        "model": {
+          "provider": "openai",
+          "name": "gpt-4",
+          "options": {
+            "max_tokens": 512,
+            "temperature": 1.0
+          }
+        }
+      }
+    },
+    "upstream": {
+      "type": "roundrobin",
+      "nodes": {
+        "httpbin.org:80": 1
+      }
+    }
+  }'
+```
+
+The `ai-proxy` plugin is used here as it simplifies access to LLMs. However, you may configure the LLM in the upstream configuration as well.
+
+Now send a request, filling the empty `content` field with text that would violate the `PROFANITY` threshold:
+
+```shell
+curl http://127.0.0.1:9080/post -i -XPOST -H 'Content-Type: application/json' -d '{
+  "messages": [
+    {
+      "role": "user",
+      "content": ""
+    }
+  ]
+}'
+```
+
+Then the request will be blocked with an error like this:
+
+```text
+HTTP/1.1 400 Bad Request
+Date: Thu, 03 Oct 2024 11:53:15 GMT
+Content-Type: text/plain; charset=utf-8
+Transfer-Encoding: chunked
+Connection: keep-alive
+Server: APISIX/3.10.0
+
+request body exceeds PROFANITY threshold
+```
+
+Send a request with compliant content in the request body:
+
+```shell
+curl http://127.0.0.1:9080/post -i -XPOST -H 'Content-Type: application/json' -d '{
+  "messages": [
+    {
+      "role": "system",
+      "content": "You are a mathematician"
+    },
+    { "role": "user", "content": "What is 1+1?" }
+  ]
+}'
+```
+
+This request will be proxied normally to the configured LLM.
+ +```text +HTTP/1.1 200 OK +Date: Thu, 03 Oct 2024 11:53:00 GMT +Content-Type: text/plain; charset=utf-8 +Transfer-Encoding: chunked +Connection: keep-alive +Server: APISIX/3.10.0 + +{"choices":[{"finish_reason":"stop","index":0,"message":{"content":"1+1 equals 2.","role":"assistant"}}],"created":1727956380,"id":"chatcmpl-AEEg8Pe5BAW5Sw3C1gdwXnuyulIkY","model":"gpt-4o-2024-05-13","object":"chat.completion","system_fingerprint":"fp_67802d9a6d","usage":{"completion_tokens":7,"prompt_tokens":23,"total_tokens":30}} +``` + +You can also configure filters on other moderation categories like so: + +```shell +curl "http://127.0.0.1:9180/apisix/admin/routes/1" -X PUT \ + -H "X-API-KEY: ${ADMIN_API_KEY}" \ + -d '{ + "uri": "/post", + "plugins": { + "ai-content-moderation": { + "provider": { + "aws_comprehend": { + "access_key_id": "'"$ACCESS_KEY_ID"'", + "secret_access_key": "'"$SECRET_ACCESS_KEY"'", + "region": "us-east-1" + } + }, + "llm_provider": "openai", + "moderation_categories": { + "PROFANITY": 0.5, + "HARASSMENT_OR_ABUSE": 0.7, + "SEXUAL": 0.2 + } + }, + "ai-proxy": { + "auth": { + "header": { + "api-key": "'"$OPENAI_KEY"'" + } + }, + "model": { + "provider": "openai", + "name": "gpt-4", + "options": { + "max_tokens": 512, + "temperature": 1.0 + } + } + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "httpbin.org:80": 1 + } + } + }' +``` + +If none of the `moderation_categories` are configured, request bodies will be moderated on the basis of overall toxicity. +The default `moderation_threshold` is 0.5, it can be configured like so. + +```shell +curl "http://127.0.0.1:9180/apisix/admin/routes/1" -X PUT \ + -H "X-API-KEY: ${ADMIN_API_KEY}" \ + -d '{ + "uri": "/post", + "plugins": { + "ai-content-moderation": { + "provider": { + "aws_comprehend": { + "access_key_id": "'"$ACCESS_KEY_ID"'", + "secret_access_key": "'"$SECRET_ACCESS_KEY"'", + "region": "us-east-1" + } + }, + "moderation_threshold": 0.7, + "llm_provider": "openai" + }, + "ai-proxy": { + "auth": { + "header": { + "api-key": "'"$OPENAI_KEY"'" + } + }, + "model": { + "provider": "openai", + "name": "gpt-4", + "options": { + "max_tokens": 512, + "temperature": 1.0 + } + } + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "httpbin.org:80": 1 + } + } +}' +``` diff --git a/t/admin/plugins.t b/t/admin/plugins.t index bf3d485e8b31..2b536c10e715 100644 --- a/t/admin/plugins.t +++ b/t/admin/plugins.t @@ -95,6 +95,7 @@ proxy-cache body-transformer ai-prompt-template ai-prompt-decorator +ai-content-moderation proxy-mirror proxy-rewrite workflow diff --git a/t/assets/content-moderation-responses.json b/t/assets/content-moderation-responses.json new file mode 100644 index 000000000000..e10c3d030be3 --- /dev/null +++ b/t/assets/content-moderation-responses.json @@ -0,0 +1,224 @@ +{ + "good_request": { + "ResultList": [ + { + "Toxicity": 0.02150000333786, + "Labels": [ + { + "Name": "PROFANITY", + "Score": 0.00589999556541 + }, + { + "Name": "HATE_SPEECH", + "Score": 0.01729999780655 + }, + { + "Name": "INSULT", + "Score": 0.00519999861717 + }, + { + "Name": "GRAPHIC", + "Score": 0.00520000338554 + }, + { + "Name": "HARASSMENT_OR_ABUSE", + "Score": 0.00090001106262 + }, + { + "Name": "SEXUAL", + "Score": 0.00810000061989 + }, + { + "Name": "VIOLENCE_OR_THREAT", + "Score": 0.00570000290871 + } + ] + } + ] + }, + "profane": { + "ResultList": [ + { + "Toxicity": 0.62150000333786, + "Labels": [ + { + "Name": "PROFANITY", + "Score": 0.55589999556541 + }, + { + "Name": "HATE_SPEECH", + "Score": 0.21729999780655 + }, + { + 
"Name": "INSULT", + "Score": 0.25519999861717 + }, + { + "Name": "GRAPHIC", + "Score": 0.12520000338554 + }, + { + "Name": "HARASSMENT_OR_ABUSE", + "Score": 0.27090001106262 + }, + { + "Name": "SEXUAL", + "Score": 0.44810000061989 + }, + { + "Name": "VIOLENCE_OR_THREAT", + "Score": 0.27570000290871 + } + ] + } + ] + }, + "profane_but_not_toxic": { + "ResultList": [ + { + "Toxicity": 0.12150000333786, + "Labels": [ + { + "Name": "PROFANITY", + "Score": 0.55589999556541 + }, + { + "Name": "HATE_SPEECH", + "Score": 0.21729999780655 + }, + { + "Name": "INSULT", + "Score": 0.25519999861717 + }, + { + "Name": "GRAPHIC", + "Score": 0.12520000338554 + }, + { + "Name": "HARASSMENT_OR_ABUSE", + "Score": 0.27090001106262 + }, + { + "Name": "SEXUAL", + "Score": 0.44810000061989 + }, + { + "Name": "VIOLENCE_OR_THREAT", + "Score": 0.27570000290871 + } + ] + } + ] + }, + "very_profane": { + "ResultList": [ + { + "Toxicity": 0.72150000333786, + "Labels": [ + { + "Name": "PROFANITY", + "Score": 0.85589999556541 + }, + { + "Name": "HATE_SPEECH", + "Score": 0.21729999780655 + }, + { + "Name": "INSULT", + "Score": 0.25519999861717 + }, + { + "Name": "GRAPHIC", + "Score": 0.12520000338554 + }, + { + "Name": "HARASSMENT_OR_ABUSE", + "Score": 0.27090001106262 + }, + { + "Name": "SEXUAL", + "Score": 0.94810000061989 + }, + { + "Name": "VIOLENCE_OR_THREAT", + "Score": 0.27570000290871 + } + ] + } + ] + }, + "toxic": { + "ResultList": [ + { + "Toxicity": 0.72150000333786, + "Labels": [ + { + "Name": "PROFANITY", + "Score": 0.25589999556541 + }, + { + "Name": "HATE_SPEECH", + "Score": 0.21729999780655 + }, + { + "Name": "INSULT", + "Score": 0.75519999861717 + }, + { + "Name": "GRAPHIC", + "Score": 0.12520000338554 + }, + { + "Name": "HARASSMENT_OR_ABUSE", + "Score": 0.27090001106262 + }, + { + "Name": "SEXUAL", + "Score": 0.64810000061989 + }, + { + "Name": "VIOLENCE_OR_THREAT", + "Score": 0.27570000290871 + } + ] + } + ] + }, + "very_toxic": { + "ResultList": [ + { + "Toxicity": 0.92150000333786, + "Labels": [ + { + "Name": "PROFANITY", + "Score": 0.25589999556541 + }, + { + "Name": "HATE_SPEECH", + "Score": 0.21729999780655 + }, + { + "Name": "INSULT", + "Score": 0.25519999861717 + }, + { + "Name": "GRAPHIC", + "Score": 0.12520000338554 + }, + { + "Name": "HARASSMENT_OR_ABUSE", + "Score": 0.27090001106262 + }, + { + "Name": "SEXUAL", + "Score": 0.44810000061989 + }, + { + "Name": "VIOLENCE_OR_THREAT", + "Score": 0.27570000290871 + } + ] + } + ] + } +} diff --git a/t/plugin/ai-content-moderation-secrets.t b/t/plugin/ai-content-moderation-secrets.t new file mode 100644 index 000000000000..06d7941f7be6 --- /dev/null +++ b/t/plugin/ai-content-moderation-secrets.t @@ -0,0 +1,213 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +BEGIN { + $ENV{VAULT_TOKEN} = "root"; + $ENV{SECRET_ACCESS_KEY} = "super-secret"; + $ENV{ACCESS_KEY_ID} = "access-key-id"; +} + +use t::APISIX 'no_plan'; + +repeat_each(1); +no_long_string(); +no_root_location(); + +add_block_preprocessor(sub { + my ($block) = @_; + + if (!defined $block->request) { + $block->set_value("request", "GET /t"); + } + + my $http_config = $block->http_config // <<_EOC_; + server { + listen 2668; + + default_type 'application/json'; + + location / { + content_by_lua_block { + local json = require("cjson.safe") + local core = require("apisix.core") + local open = io.open + + local f = open('t/assets/content-moderation-responses.json', "r") + local resp = f:read("*a") + f:close() + + if not resp then + ngx.status(503) + ngx.say("[INTERNAL FAILURE]: failed to open response.json file") + end + + local responses = json.decode(resp) + if not responses then + ngx.status(503) + ngx.say("[INTERNAL FAILURE]: failed to decode response.json contents") + end + + local headers = ngx.req.get_headers() + local auth_header = headers["Authorization"] + if core.string.find(auth_header, "access-key-id") then + ngx.say(json.encode(responses["good_request"])) + return + end + ngx.status = 403 + ngx.say("unauthorized") + } + } + } +_EOC_ + + $block->set_value("http_config", $http_config); +}); + +run_tests; + +__DATA__ + +=== TEST 1: store secret into vault +--- exec +VAULT_TOKEN='root' VAULT_ADDR='http://0.0.0.0:8200' vault kv put kv/apisix/foo secret_access_key=super-secret +VAULT_TOKEN='root' VAULT_ADDR='http://0.0.0.0:8200' vault kv put kv/apisix/foo access_key_id=access-key-id +--- response_body +Success! Data written to: kv/apisix/foo +Success! Data written to: kv/apisix/foo + + + +=== TEST 2: set secret_access_key and access_key_id as a reference to secret +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + -- put secret vault config + local code, body = t('/apisix/admin/secrets/vault/test1', + ngx.HTTP_PUT, + [[{ + "uri": "http://127.0.0.1:8200", + "prefix" : "kv/apisix", + "token" : "root" + }]] + ) + + if code >= 300 then + ngx.status = code + return ngx.say(body) + end + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/echo", + "plugins": { + "ai-content-moderation": { + "provider": { + "aws_comprehend": { + "access_key_id": "$secret://vault/test1/foo/access_key_id", + "secret_access_key": "$secret://vault/test1/foo/secret_access_key", + "region": "us-east-1", + "endpoint": "http://localhost:2668" + } + }, + "llm_provider": "openai" + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "127.0.0.1:1980": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + return ngx.say(body) + end + ngx.say("success") + } + } +--- request +GET /t +--- response_body +success + + + +=== TEST 3: good request should pass +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"good_request"}]} +--- error_code: 200 +--- response_body chomp +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"good_request"}]} + + + +=== TEST 4: set secret_access_key as a reference to env variable +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/echo", + "plugins": { + "ai-content-moderation": { + "provider": { + "aws_comprehend": { + "access_key_id": "$env://ACCESS_KEY_ID", + "secret_access_key": "$env://SECRET_ACCESS_KEY", + "region": 
"us-east-1", + "endpoint": "http://localhost:2668" + } + }, + "llm_provider": "openai" + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "127.0.0.1:1980": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + return + end + ngx.say("success") + } + } +--- request +GET /t +--- response_body +success + + + +=== TEST 5: good request should pass +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"good_request"}]} +--- error_code: 200 +--- response_body chomp +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"good_request"}]} diff --git a/t/plugin/ai-content-moderation.t b/t/plugin/ai-content-moderation.t new file mode 100644 index 000000000000..66393ef988f7 --- /dev/null +++ b/t/plugin/ai-content-moderation.t @@ -0,0 +1,304 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +use t::APISIX 'no_plan'; + +log_level("info"); +repeat_each(1); +no_long_string(); +no_root_location(); + + +add_block_preprocessor(sub { + my ($block) = @_; + + if (!defined $block->request) { + $block->set_value("request", "GET /t"); + } + + my $http_config = $block->http_config // <<_EOC_; + server { + listen 2668; + + default_type 'application/json'; + + location / { + content_by_lua_block { + local json = require("cjson.safe") + local open = io.open + local f = open('t/assets/content-moderation-responses.json', "r") + local resp = f:read("*a") + f:close() + + if not resp then + ngx.status(503) + ngx.say("[INTERNAL FAILURE]: failed to open response.json file") + end + + local responses = json.decode(resp) + if not responses then + ngx.status(503) + ngx.say("[INTERNAL FAILURE]: failed to decode response.json contents") + end + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + if not body then + ngx.status(503) + ngx.say("[INTERNAL FAILURE]: failed to get request body: ", err) + end + + body, err = json.decode(body) + if not body then + ngx.status(503) + ngx.say("[INTERNAL FAILURE]: failed to decoded request body: ", err) + end + local result = body.TextSegments[1].Text + local final_response = responses[result] or "invalid" + + if final_response == "invalid" then + ngx.status = 500 + end + ngx.say(json.encode(final_response)) + } + } + } +_EOC_ + + $block->set_value("http_config", $http_config); +}); + +run_tests(); + +__DATA__ + +=== TEST 1: sanity +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/echo", + "plugins": { + "ai-content-moderation": { + "provider": { + "aws_comprehend": { + "access_key_id": "access", + "secret_access_key": 
"ea+secret", + "region": "us-east-1", + "endpoint": "http://localhost:2668" + } + }, + "llm_provider": "openai" + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "127.0.0.1:1980": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 2: toxic request should fail +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"toxic"}]} +--- error_code: 400 +--- response_body chomp +request body exceeds toxicity threshold + + + +=== TEST 3: good request should pass +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"good_request"}]} +--- error_code: 200 + + + +=== TEST 4: profanity filter +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/echo", + "plugins": { + "ai-content-moderation": { + "provider": { + "aws_comprehend": { + "access_key_id": "access", + "secret_access_key": "ea+secret", + "region": "us-east-1", + "endpoint": "http://localhost:2668" + } + }, + "moderation_categories": { + "PROFANITY": 0.5 + }, + "llm_provider": "openai" + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "127.0.0.1:1980": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 5: profane request should fail +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"profane"}]} +--- error_code: 400 +--- response_body chomp +request body exceeds PROFANITY threshold + + + +=== TEST 6: very profane request should also fail +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"very_profane"}]} +--- error_code: 400 +--- response_body chomp +request body exceeds PROFANITY threshold + + + +=== TEST 7: good_request should pass +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"good_request"}]} +--- error_code: 200 + + + +=== TEST 8: set profanity = 0.7 (allow profane request but disallow very_profane) +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/echo", + "plugins": { + "ai-content-moderation": { + "provider": { + "aws_comprehend": { + "access_key_id": "access", + "secret_access_key": "ea+secret", + "region": "us-east-1", + "endpoint": "http://localhost:2668" + } + }, + "moderation_categories": { + "PROFANITY": 0.7 + }, + "llm_provider": "openai" + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "127.0.0.1:1980": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 9: profane request should pass profanity check but fail toxicity check +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"profane"}]} +--- error_code: 400 +--- response_body chomp +request body exceeds toxicity threshold + + + +=== TEST 10: profane_but_not_toxic request should pass +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"profane_but_not_toxic"}]} +--- error_code: 200 + + + +=== TEST 11: but very profane request will fail +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"very_profane"}]} +--- error_code: 400 +--- response_body chomp +request body 
exceeds PROFANITY threshold + + + +=== TEST 12: good_request should pass +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"good_request"}]} +--- error_code: 200
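+
+
+
+=== TEST 13: illustrative sketch (not in the original test set): set moderation_threshold = 0.8 to allow mildly toxic content
+--- config
+    location /t {
+        content_by_lua_block {
+            -- sketch added for illustration: it reuses the mock Comprehend server on
+            -- port 2668 and the "toxic"/"very_toxic" fixtures from
+            -- t/assets/content-moderation-responses.json defined above
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                ngx.HTTP_PUT,
+                [[{
+                    "uri": "/echo",
+                    "plugins": {
+                        "ai-content-moderation": {
+                            "provider": {
+                                "aws_comprehend": {
+                                    "access_key_id": "access",
+                                    "secret_access_key": "ea+secret",
+                                    "region": "us-east-1",
+                                    "endpoint": "http://localhost:2668"
+                                }
+                            },
+                            "moderation_threshold": 0.8,
+                            "llm_provider": "openai"
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "127.0.0.1:1980": 1
+                        }
+                    }
+                }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 14: toxic request (Toxicity 0.72) should now pass the relaxed threshold
+--- request
+POST /echo
+{"model":"gpt-4o-mini","messages":[{"role":"user","content":"toxic"}]}
+--- error_code: 200
+
+
+
+=== TEST 15: very_toxic request (Toxicity 0.92) should still fail
+--- request
+POST /echo
+{"model":"gpt-4o-mini","messages":[{"role":"user","content":"very_toxic"}]}
+--- error_code: 400
+--- response_body chomp
+request body exceeds toxicity threshold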