diff --git a/apisix/plugins/ai-content-moderation.lua b/apisix/plugins/ai-content-moderation.lua index 62fe8f6e979d..bb39aed3bfee 100644 --- a/apisix/plugins/ai-content-moderation.lua +++ b/apisix/plugins/ai-content-moderation.lua @@ -26,8 +26,8 @@ local unpack = unpack local type = type local ipairs = ipairs local require = require -local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR -local bad_request = ngx.HTTP_BAD_REQUEST +local INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR +local BAD_REQUEST = ngx.HTTP_BAD_REQUEST local aws_comprehend_schema = { @@ -54,16 +54,15 @@ local schema = { properties = { aws_comprehend = aws_comprehend_schema }, - -- change to oneOf/enum while implementing support for other services + -- ensure only one provider can be configured while implementing support for + -- other providers required = { "aws_comprehend" } }, moderation_categories = { type = "object", patternProperties = { - -- luacheck: push max code line length 300 [moderation_categories_pattern] = { - -- luacheck: pop - type = "number", + type = "number", minimum = 0, maximum = 1 } @@ -76,12 +75,12 @@ local schema = { maximum = 1, default = 0.5 }, - type = { + llm_provider = { type = "string", enum = { "openai" }, } }, - required = { "provider", "type" }, + required = { "provider", "llm_provider" }, } @@ -100,17 +99,17 @@ end function _M.rewrite(conf, ctx) conf = fetch_secrets(conf, true, conf, "") if not conf then - return internal_server_error, "failed to retrieve secrets from conf" + return INTERNAL_SERVER_ERROR, "failed to retrieve secrets from conf" end local body, err = core.request.get_json_request_body_table() if not body then - return bad_request, err + return BAD_REQUEST, err end local msgs = body.messages if type(msgs) ~= "table" or #msgs < 1 then - return bad_request, "messages not found in request body" + return BAD_REQUEST, "messages not found in request body" end local provider = conf.provider[next(conf.provider)] @@ -135,7 +134,7 @@ function _M.rewrite(conf, ctx) port = port, }) - local ai_module = require("apisix.plugins.ai." .. conf.type) + local ai_module = require("apisix.plugins.ai." .. conf.llm_provider) local create_request_text_segments = ai_module.create_request_text_segments local text_segments = create_request_text_segments(msgs) @@ -146,12 +145,12 @@ function _M.rewrite(conf, ctx) if not res then core.log.error("failed to send request to ", provider, ": ", err) - return internal_server_error, err + return INTERNAL_SERVER_ERROR, err end local results = res.body and res.body.ResultList - if not results or type(results) ~= "table" or #results < 1 then - return internal_server_error, "failed to get moderation results from response" + if type(results) ~= "table" or core.table.isempty(results) then + return INTERNAL_SERVER_ERROR, "failed to get moderation results from response" end for _, result in ipairs(results) do @@ -161,14 +160,14 @@ function _M.rewrite(conf, ctx) goto continue end if item.Score > conf.moderation_categories[item.Name] then - return bad_request, "request body exceeds " .. item.Name .. " threshold" + return BAD_REQUEST, "request body exceeds " .. item.Name .. " threshold" end ::continue:: end end if result.Toxicity > conf.toxicity_level then - return bad_request, "request body exceeds toxicity threshold" + return BAD_REQUEST, "request body exceeds toxicity threshold" end end end diff --git a/t/plugin/ai-content-moderation-secrets.t b/t/plugin/ai-content-moderation-secrets.t index 6c27e2dc53be..06d7941f7be6 100644 --- a/t/plugin/ai-content-moderation-secrets.t +++ b/t/plugin/ai-content-moderation-secrets.t @@ -123,7 +123,7 @@ Success! Data written to: kv/apisix/foo "endpoint": "http://localhost:2668" } }, - "type": "openai" + "llm_provider": "openai" } }, "upstream": { @@ -178,7 +178,7 @@ POST /echo "endpoint": "http://localhost:2668" } }, - "type": "openai" + "llm_provider": "openai" } }, "upstream": { diff --git a/t/plugin/ai-content-moderation.t b/t/plugin/ai-content-moderation.t index 50772938cef9..66393ef988f7 100644 --- a/t/plugin/ai-content-moderation.t +++ b/t/plugin/ai-content-moderation.t @@ -110,7 +110,7 @@ __DATA__ "endpoint": "http://localhost:2668" } }, - "type": "openai" + "llm_provider": "openai" } }, "upstream": { @@ -173,7 +173,7 @@ POST /echo "moderation_categories": { "PROFANITY": 0.5 }, - "type": "openai" + "llm_provider": "openai" } }, "upstream": { @@ -246,7 +246,7 @@ POST /echo "moderation_categories": { "PROFANITY": 0.7 }, - "type": "openai" + "llm_provider": "openai" } }, "upstream": {