
Commit

tests(plugin/ai-request-transformer): replace mocking server by http_mock module

(cherry picked from commit acffb9d)
ADD-SP committed Feb 19, 2024
1 parent 39dd520 commit 0216500
Showing 1 changed file with 95 additions and 126 deletions.
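The change swaps the per-spec nginx fixture (a server {} block rendered through spec/fixtures/custom_nginx.template and passed to helpers.start_kong) for the spec.helpers.http_mock module, which starts a standalone mock server from inside the spec itself. A minimal sketch of the helper's shape as this diff uses it; the /echo route, port variable, and response body below are illustrative, not taken from the commit:

local helpers   = require "spec.helpers"
local http_mock = require "spec.helpers.http_mock"

local port = helpers.get_available_port()

-- Keys are nginx location matchers; `content` is Lua source executed
-- inside the mock server for each request to that location.
local mock = http_mock.new(port, {
  ["~/echo"] = {
    content = [[
      ngx.header["Content-Type"] = "application/json"
      ngx.print('{"ok": true}')
    ]],
  },
})

assert(mock:start()) -- e.g. in lazy_setup
assert(mock:stop())  -- e.g. in lazy_teardown

Because the unit tests only exercise llm:ai_introspect_body against the mock, Kong itself no longer has to be started, which is why the helpers.start_kong/helpers.stop_kong and proxy-client plumbing disappear in the diff below.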
spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua (95 additions, 126 deletions)
@@ -1,6 +1,8 @@
 local llm_class = require("kong.llm")
 local helpers = require "spec.helpers"
 local cjson = require "cjson"
+local http_mock = require "spec.helpers.http_mock"
+local pl_path = require "pl.path"
 
 local MOCK_PORT = helpers.get_available_port()
 local PLUGIN_NAME = "ai-request-transformer"
@@ -14,7 +16,7 @@ local FORMATS = {
       options = {
         max_tokens = 512,
         temperature = 0.5,
-        upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/openai"
+        upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/chat/openai"
       },
     },
     auth = {
@@ -30,7 +32,7 @@ local FORMATS = {
       options = {
         max_tokens = 512,
         temperature = 0.5,
-        upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/cohere"
+        upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/chat/cohere"
       },
     },
     auth = {
@@ -46,7 +48,7 @@ local FORMATS = {
       options = {
         max_tokens = 512,
         temperature = 0.5,
-        upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/anthropic"
+        upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/chat/anthropic"
       },
     },
     auth = {
@@ -62,7 +64,7 @@ local FORMATS = {
       options = {
         max_tokens = 512,
         temperature = 0.5,
-        upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/azure"
+        upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/chat/azure"
       },
     },
     auth = {
@@ -78,7 +80,7 @@ local FORMATS = {
       options = {
         max_tokens = 512,
         temperature = 0.5,
-        upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/llama2",
+        upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/chat/llama2",
         llama2_format = "raw",
       },
     },
@@ -95,7 +97,7 @@ local FORMATS = {
       options = {
         max_tokens = 512,
         temperature = 0.5,
-        upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/chat/mistral",
+        upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/chat/mistral",
         mistral_format = "ollama",
       },
     },
@@ -114,7 +116,7 @@ local OPENAI_NOT_JSON = {
       options = {
         max_tokens = 512,
         temperature = 0.5,
-        upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/not-json"
+        upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/not-json"
       },
     },
     auth = {
@@ -152,131 +154,77 @@ local EXPECTED_RESULT = {
 }
 
 local SYSTEM_PROMPT = "You are a mathematician. "
-    .. "Multiply all numbers in my JSON request, by 2. Return me the JSON output only"
+                   .. "Multiply all numbers in my JSON request, by 2. Return me the JSON output only"
 
 
-local client
-for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
-  describe(PLUGIN_NAME .. ": (unit)", function()
-
-    lazy_setup(function()
-      -- set up provider fixtures
-      local fixtures = {
-        http_mock = {},
-      }
-
-      fixtures.http_mock.openai = [[
-        server {
-          server_name llm;
-          listen ]]..MOCK_PORT..[[;
-
-          default_type 'application/json';
-
-          location ~/chat/(?<provider>[a-z0-9]+) {
-            content_by_lua_block {
-              local pl_file = require "pl.file"
-              local json = require("cjson.safe")
-
-              local token = ngx.req.get_headers()["authorization"]
-              local token_query = ngx.req.get_uri_args()["apikey"]
-
-              if token == "Bearer " .. ngx.var.provider .. "-key" or token_query == "$1-key" or body.apikey == "$1-key" then
-                ngx.req.read_body()
-                local body, err = ngx.req.get_body_data()
-                body, err = json.decode(body)
-
-                if err or (body.messages == ngx.null) then
-                  ngx.status = 400
-                  ngx.print(pl_file.read("spec/fixtures/ai-proxy/" .. ngx.var.provider .. "/llm-v1-chat/responses/bad_request.json"))
-                else
-                  ngx.status = 200
-                  ngx.print(pl_file.read("spec/fixtures/ai-proxy/" .. ngx.var.provider .. "/request-transformer/response-in-json.json"))
-                end
-              else
-                ngx.status = 401
-                ngx.print(pl_file.read("spec/fixtures/ai-proxy/" .. ngx.var.provider .. "/llm-v1-chat/responses/unauthorized.json"))
-              end
-            }
-          }
-
-          location ~/not-json {
-            content_by_lua_block {
-              local pl_file = require "pl.file"
-              ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/request-transformer/response-not-json.json"))
-            }
-          }
-        }
-      ]]
-
-      -- start kong
-      assert(helpers.start_kong({
-        -- set the strategy
-        database = strategy,
-        -- use the custom test template to create a local mock server
-        nginx_conf = "spec/fixtures/custom_nginx.template",
-        -- make sure our plugin gets loaded
-        plugins = "bundled," .. PLUGIN_NAME,
-        -- write & load declarative config, only if 'strategy=off'
-        declarative_config = strategy == "off" and helpers.make_yaml_file() or nil,
-      }, nil, nil, fixtures))
-    end)
-
-    lazy_teardown(function()
-      helpers.stop_kong(nil, true)
-    end)
-
-    before_each(function()
-      client = helpers.proxy_client()
-    end)
-
-    after_each(function()
-      if client then client:close() end
-    end)
-
-    for name, format_options in pairs(FORMATS) do
-
-      describe(name .. " transformer tests, exact json response", function()
-
-        it("transforms request based on LLM instructions", function()
-          local llm = llm_class:new(format_options, {})
-          assert.truthy(llm)
-
-          local result, err = llm:ai_introspect_body(
-            REQUEST_BODY, -- request body
-            SYSTEM_PROMPT, -- conf.prompt
-            {}, -- http opts
-            nil -- transformation extraction pattern
-          )
-
-          assert.is_nil(err)
-
-          result, err = cjson.decode(result)
-          assert.is_nil(err)
-
-          assert.same(EXPECTED_RESULT, result)
-        end)
-      end)
-
-
-    end
-
-    describe("openai transformer tests, pattern matchers", function()
-      it("transforms request based on LLM instructions, with json extraction pattern", function()
-        local llm = llm_class:new(OPENAI_NOT_JSON, {})
-        assert.truthy(llm)
-
-        local result, err = llm:ai_introspect_body(
-          REQUEST_BODY, -- request body
-          SYSTEM_PROMPT, -- conf.prompt
-          {}, -- http opts
-          "\\{((.|\n)*)\\}" -- transformation extraction pattern (loose json)
-        )
-
-        assert.is_nil(err)
+describe(PLUGIN_NAME .. ": (unit)", function()
+  local mock
+  local ai_proxy_fixtures_dir = pl_path.abspath("spec/fixtures/ai-proxy/")
+
+  lazy_setup(function()
+    mock = http_mock.new(MOCK_PORT, {
+      ["~/chat/(?<provider>[a-z0-9]+)"] = {
+        content = string.format([[
+          local base_dir = "%s/"
+          ngx.header["Content-Type"] = "application/json"
+
+          local pl_file = require "pl.file"
+          local json = require("cjson.safe")
+          ngx.req.read_body()
+          local body, err = ngx.req.get_body_data()
+          body, err = json.decode(body)
+
+          local token = ngx.req.get_headers()["authorization"]
+          local token_query = ngx.req.get_uri_args()["apikey"]
+
+          if token == "Bearer " .. ngx.var.provider .. "-key" or token_query == "$1-key" or body.apikey == "$1-key" then
+            ngx.req.read_body()
+            local body, err = ngx.req.get_body_data()
+            body, err = json.decode(body)
+
+            if err or (body.messages == ngx.null) then
+              ngx.status = 400
+              ngx.say(pl_file.read(base_dir .. ngx.var.provider .. "/llm-v1-chat/responses/bad_request.json"))
+            else
+              ngx.status = 200
+              ngx.say(pl_file.read(base_dir .. ngx.var.provider .. "/request-transformer/response-in-json.json"))
+            end
+          else
+            ngx.status = 401
+            ngx.say(pl_file.read(base_dir .. ngx.var.provider .. "/llm-v1-chat/responses/unauthorized.json"))
+          end
+        ]], ai_proxy_fixtures_dir),
+      },
+      ["~/not-json"] = {
+        content = string.format([[
+          local base_dir = "%s/"
+          local pl_file = require "pl.file"
+          ngx.header["Content-Type"] = "application/json"
+          ngx.print(pl_file.read(base_dir .. "openai/request-transformer/response-not-json.json"))
+        ]], ai_proxy_fixtures_dir),
+      },
+    })
+
+    assert(mock:start())
+  end)
+
+  lazy_teardown(function()
+    assert(mock:stop())
+  end)
+
+  for name, format_options in pairs(FORMATS) do
+    describe(name .. " transformer tests, exact json response", function()
+      it("transforms request based on LLM instructions", function()
+        local llm = llm_class:new(format_options, {})
+        assert.truthy(llm)
+
+        local result, err = llm:ai_introspect_body(
+          REQUEST_BODY,  -- request body
+          SYSTEM_PROMPT, -- conf.prompt
+          {},            -- http opts
+          nil            -- transformation extraction pattern
+        )
+
+        assert.is_nil(err)
 
         result, err = cjson.decode(result)
         assert.is_nil(err)
@@ -286,22 +234,43 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
 
         assert.same(EXPECTED_RESULT, result)
       end)
+    end)
+  end
+
+  describe("openai transformer tests, pattern matchers", function()
+    it("transforms request based on LLM instructions, with json extraction pattern", function()
+      local llm = llm_class:new(OPENAI_NOT_JSON, {})
+      assert.truthy(llm)
+
+      local result, err = llm:ai_introspect_body(
+        REQUEST_BODY,     -- request body
+        SYSTEM_PROMPT,    -- conf.prompt
+        {},               -- http opts
+        "\\{((.|\n)*)\\}" -- transformation extraction pattern (loose json)
+      )
+
+      assert.is_nil(err)
+
+      result, err = cjson.decode(result)
+      assert.is_nil(err)
+
+      assert.same(EXPECTED_RESULT, result)
+    end)
 
-      it("transforms request based on LLM instructions, but fails to match pattern", function()
-        local llm = llm_class:new(OPENAI_NOT_JSON, {})
-        assert.truthy(llm)
+    it("transforms request based on LLM instructions, but fails to match pattern", function()
+      local llm = llm_class:new(OPENAI_NOT_JSON, {})
+      assert.truthy(llm)
 
-        local result, err = llm:ai_introspect_body(
-          REQUEST_BODY, -- request body
-          SYSTEM_PROMPT, -- conf.prompt
-          {}, -- http opts
-          "\\#*\\=" -- transformation extraction pattern (loose json)
-        )
+      local result, err = llm:ai_introspect_body(
+        REQUEST_BODY,  -- request body
+        SYSTEM_PROMPT, -- conf.prompt
+        {},            -- http opts
+        "\\#*\\="      -- transformation extraction pattern (loose json)
+      )
 
-        assert.is_nil(result)
-        assert.is_not_nil(err)
-        assert.same("AI response did not match specified regular expression", err)
-      end)
-    end)
-  end)
-end end
+      assert.is_nil(result)
+      assert.is_not_nil(err)
+      assert.same("AI response did not match specified regular expression", err)
+    end) -- it
+  end)
+end)
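One detail of the new lazy_setup worth calling out: each route's content is a plain string compiled inside the mock server, so it cannot capture spec-local variables as upvalues. Values are instead baked in with string.format, and the fixtures directory is resolved with pl_path.abspath first, presumably so the paths stay valid even if the mock server's working directory differs from the spec runner's. A reduced sketch of that idiom, reusing the fixture file from the /not-json route above:

local pl_path = require "pl.path"

-- Resolve once, outside the handler: an absolute path stays valid
-- no matter where the mock server process is running from.
local fixtures_dir = pl_path.abspath("spec/fixtures/ai-proxy/")

-- The %s placeholder is filled at spec-load time, before the string
-- is handed to http_mock as the route body.
local handler = string.format([[
  local pl_file = require "pl.file"
  ngx.header["Content-Type"] = "application/json"
  ngx.print(pl_file.read("%s/openai/request-transformer/response-not-json.json"))
]], fixtures_dir)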

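The two extraction patterns exercised at the end of the diff behave quite differently. "\\{((.|\n)*)\\}" spans greedily from the first "{" to the last "}", so well-formed JSON can be recovered even when the model wraps it in chatty prose, while "\\#*\\=" requires a literal "=", which such a reply never contains, forcing the "did not match" error path. A hypothetical illustration (the reply string is invented, and PCRE semantics via ngx.re are assumed, as available in Kong's OpenResty runtime):

local reply = 'Sure! Here is your JSON: {"persons":[{"age":62},{"age":84}]}'

-- The greedy brace pattern captures from the first "{" to the last "}",
-- discarding the conversational preamble.
local m = ngx.re.match(reply, "\\{((.|\n)*)\\}")
-- m[0] is '{"persons":[{"age":62},{"age":84}]}'

-- "\#*\=" only matches a literal "=" (optionally preceded by "#"s); a JSON
-- reply contains no "=", so the match returns nil and the plugin reports
-- "AI response did not match specified regular expression".
local none = ngx.re.match(reply, "\\#*\\=")
-- none is nil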