From 0216500ba323cf0bf627109cb6a9d8ba221c2592 Mon Sep 17 00:00:00 2001 From: Qi Date: Mon, 19 Feb 2024 06:28:43 +0000 Subject: [PATCH] tests(plugin/ai-request-transformer): replace mocking server by http_mock module (cherry picked from commit acffb9d52ec1ec25a11b80f7e4887b06e8fb38f6) --- .../01-transformer_spec.lua | 221 ++++++++---------- 1 file changed, 95 insertions(+), 126 deletions(-) diff --git a/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua b/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua index de6b0d25416..db1aef512b0 100644 --- a/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua +++ b/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua @@ -1,6 +1,8 @@ local llm_class = require("kong.llm") local helpers = require "spec.helpers" local cjson = require "cjson" +local http_mock = require "spec.helpers.http_mock" +local pl_path = require "pl.path" local MOCK_PORT = helpers.get_available_port() local PLUGIN_NAME = "ai-request-transformer" @@ -14,7 +16,7 @@ local FORMATS = { options = { max_tokens = 512, temperature = 0.5, - upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/openai" + upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/chat/openai" }, }, auth = { @@ -30,7 +32,7 @@ local FORMATS = { options = { max_tokens = 512, temperature = 0.5, - upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/cohere" + upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/chat/cohere" }, }, auth = { @@ -46,7 +48,7 @@ local FORMATS = { options = { max_tokens = 512, temperature = 0.5, - upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/anthropic" + upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/chat/anthropic" }, }, auth = { @@ -62,7 +64,7 @@ local FORMATS = { options = { max_tokens = 512, temperature = 0.5, - upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/azure" + upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/chat/azure" }, }, auth = { @@ -78,7 +80,7 @@ local FORMATS = { options = { max_tokens = 512, temperature = 0.5, - upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/llama2", + upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/chat/llama2", llama2_format = "raw", }, }, @@ -95,7 +97,7 @@ local FORMATS = { options = { max_tokens = 512, temperature = 0.5, - upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/mistral", + upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/chat/mistral", mistral_format = "ollama", }, }, @@ -114,7 +116,7 @@ local OPENAI_NOT_JSON = { options = { max_tokens = 512, temperature = 0.5, - upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/not-json" + upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/not-json" }, }, auth = { @@ -152,131 +154,77 @@ local EXPECTED_RESULT = { } local SYSTEM_PROMPT = "You are a mathematician. " - .. "Multiply all numbers in my JSON request, by 2. Return me the JSON output only" + .. "Multiply all numbers in my JSON request, by 2. Return me the JSON output only" -local client +describe(PLUGIN_NAME .. ": (unit)", function() + local mock + local ai_proxy_fixtures_dir = pl_path.abspath("spec/fixtures/ai-proxy/") + lazy_setup(function() + mock = http_mock.new(MOCK_PORT, { + ["~/chat/(?[a-z0-9]+)"] = { + content = string.format([[ + local base_dir = "%s/" + ngx.header["Content-Type"] = "application/json" -for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then + local pl_file = require "pl.file" + local json = require("cjson.safe") + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) - describe(PLUGIN_NAME .. ": (unit)", function() - - lazy_setup(function() - -- set up provider fixtures - local fixtures = { - http_mock = {}, - } - - fixtures.http_mock.openai = [[ - server { - server_name llm; - listen ]]..MOCK_PORT..[[; - - default_type 'application/json'; - - location ~/chat/(?[a-z0-9]+) { - content_by_lua_block { - local pl_file = require "pl.file" - local json = require("cjson.safe") + local token = ngx.req.get_headers()["authorization"] + local token_query = ngx.req.get_uri_args()["apikey"] + if token == "Bearer " .. ngx.var.provider .. "-key" or token_query == "$1-key" or body.apikey == "$1-key" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - local token = ngx.req.get_headers()["authorization"] - local token_query = ngx.req.get_uri_args()["apikey"] + if err or (body.messages == ngx.null) then + ngx.status = 400 + ngx.say(pl_file.read(base_dir .. ngx.var.provider .. "/llm-v1-chat/responses/bad_request.json")) - if token == "Bearer " .. ngx.var.provider .. "-key" or token_query == "$1-key" or body.apikey == "$1-key" then - ngx.req.read_body() - local body, err = ngx.req.get_body_data() - body, err = json.decode(body) - - if err or (body.messages == ngx.null) then - ngx.status = 400 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/" .. ngx.var.provider .. "/llm-v1-chat/responses/bad_request.json")) - else - ngx.status = 200 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/" .. ngx.var.provider .. "/request-transformer/response-in-json.json")) - end else - ngx.status = 401 - ngx.print(pl_file.read("spec/fixtures/ai-proxy/" .. ngx.var.provider .. "/llm-v1-chat/responses/unauthorized.json")) + ngx.status = 200 + ngx.say(pl_file.read(base_dir .. ngx.var.provider .. "/request-transformer/response-in-json.json")) end - } - } - - location ~/not-json { - content_by_lua_block { - local pl_file = require "pl.file" - ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/request-transformer/response-not-json.json")) - } - } - } - ]] - - -- start kong - assert(helpers.start_kong({ - -- set the strategy - database = strategy, - -- use the custom test template to create a local mock server - nginx_conf = "spec/fixtures/custom_nginx.template", - -- make sure our plugin gets loaded - plugins = "bundled," .. PLUGIN_NAME, - -- write & load declarative config, only if 'strategy=off' - declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, - }, nil, nil, fixtures)) - end) - - lazy_teardown(function() - helpers.stop_kong(nil, true) - end) - - before_each(function() - client = helpers.proxy_client() - end) - - after_each(function() - if client then client:close() end - end) - - for name, format_options in pairs(FORMATS) do - - describe(name .. " transformer tests, exact json response", function() - - it("transforms request based on LLM instructions", function() - local llm = llm_class:new(format_options, {}) - assert.truthy(llm) - local result, err = llm:ai_introspect_body( - REQUEST_BODY, -- request body - SYSTEM_PROMPT, -- conf.prompt - {}, -- http opts - nil -- transformation extraction pattern - ) - - assert.is_nil(err) - - result, err = cjson.decode(result) - assert.is_nil(err) + else + ngx.status = 401 + ngx.say(pl_file.read(base_dir .. ngx.var.provider .. "/llm-v1-chat/responses/unauthorized.json")) + end + ]], ai_proxy_fixtures_dir), + }, + ["~/not-json"] = { + content = string.format([[ + local base_dir = "%s/" + local pl_file = require "pl.file" + ngx.header["Content-Type"] = "application/json" + ngx.print(pl_file.read(base_dir .. "openai/request-transformer/response-not-json.json")) + ]], ai_proxy_fixtures_dir), + }, + }) - assert.same(EXPECTED_RESULT, result) - end) - end) + assert(mock:start()) + end) - - end + lazy_teardown(function() + assert(mock:stop()) + end) - describe("openai transformer tests, pattern matchers", function() - it("transforms request based on LLM instructions, with json extraction pattern", function() - local llm = llm_class:new(OPENAI_NOT_JSON, {}) + for name, format_options in pairs(FORMATS) do + describe(name .. " transformer tests, exact json response", function() + it("transforms request based on LLM instructions", function() + local llm = llm_class:new(format_options, {}) assert.truthy(llm) local result, err = llm:ai_introspect_body( - REQUEST_BODY, -- request body - SYSTEM_PROMPT, -- conf.prompt - {}, -- http opts - "\\{((.|\n)*)\\}" -- transformation extraction pattern (loose json) + REQUEST_BODY, -- request body + SYSTEM_PROMPT, -- conf.prompt + {}, -- http opts + nil -- transformation extraction pattern ) assert.is_nil(err) @@ -286,22 +234,43 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then assert.same(EXPECTED_RESULT, result) end) + end) + end - it("transforms request based on LLM instructions, but fails to match pattern", function() - local llm = llm_class:new(OPENAI_NOT_JSON, {}) - assert.truthy(llm) + describe("openai transformer tests, pattern matchers", function() + it("transforms request based on LLM instructions, with json extraction pattern", function() + local llm = llm_class:new(OPENAI_NOT_JSON, {}) + assert.truthy(llm) - local result, err = llm:ai_introspect_body( - REQUEST_BODY, -- request body - SYSTEM_PROMPT, -- conf.prompt - {}, -- http opts - "\\#*\\=" -- transformation extraction pattern (loose json) - ) + local result, err = llm:ai_introspect_body( + REQUEST_BODY, -- request body + SYSTEM_PROMPT, -- conf.prompt + {}, -- http opts + "\\{((.|\n)*)\\}" -- transformation extraction pattern (loose json) + ) - assert.is_nil(result) - assert.is_not_nil(err) - assert.same("AI response did not match specified regular expression", err) - end) + assert.is_nil(err) + + result, err = cjson.decode(result) + assert.is_nil(err) + + assert.same(EXPECTED_RESULT, result) end) + + it("transforms request based on LLM instructions, but fails to match pattern", function() + local llm = llm_class:new(OPENAI_NOT_JSON, {}) + assert.truthy(llm) + + local result, err = llm:ai_introspect_body( + REQUEST_BODY, -- request body + SYSTEM_PROMPT, -- conf.prompt + {}, -- http opts + "\\#*\\=" -- transformation extraction pattern (loose json) + ) + + assert.is_nil(result) + assert.is_not_nil(err) + assert.same("AI response did not match specified regular expression", err) + end) -- it end) -end end +end)