Skip to content

Commit

Permalink
refactor(llm): extract class from module
Browse files Browse the repository at this point in the history
  • Loading branch information
Tieske committed May 15, 2024
1 parent 5e01766 commit 31bb26e
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 127 deletions.
222 changes: 113 additions & 109 deletions kong/llm/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@ local ai_shared = require("kong.llm.drivers.shared")
local re_match = ngx.re.match
local cjson = require("cjson.safe")
local fmt = string.format
local EMPTY = {}

-- TODO: this module returns a class but also has functions (not methods) that are not part of the class.
-- This is confusing. Refactor: the `new` function (not method!) should remain and return a new instance.

-- The module table
local _M = {
config_schema = require "kong.llm.schemas",
}
_M.__index = _M



Expand Down Expand Up @@ -82,140 +81,145 @@ do
end



function _M:ai_introspect_body(request, system_prompt, http_opts, response_regex_match)
local err, _

-- set up the request
local ai_request = {
messages = {
[1] = {
role = "system",
content = system_prompt,
do
------------------------------------------------------------------------------
-- LLM class implementation
------------------------------------------------------------------------------
local LLM = {}
LLM.__index = LLM



function LLM:ai_introspect_body(request, system_prompt, http_opts, response_regex_match)
local err, _

-- set up the request
local ai_request = {
messages = {
[1] = {
role = "system",
content = system_prompt,
},
[2] = {
role = "user",
content = request,
}
},
[2] = {
role = "user",
content = request,
}
},
stream = false,
}
stream = false,
}

-- convert it to the specified driver format
ai_request, _, err = self.driver.to_format(ai_request, self.conf.model, "llm/v1/chat")
if err then
return nil, err
end

-- run the shared logging/analytics/auth function
ai_shared.pre_request(self.conf, ai_request)
-- convert it to the specified driver format
ai_request, _, err = self.driver.to_format(ai_request, self.conf.model, "llm/v1/chat")
if err then
return nil, err
end

-- send it to the ai service
local ai_response, _, err = self.driver.subrequest(ai_request, self.conf, http_opts, false)
if err then
return nil, "failed to introspect request with AI service: " .. err
end
-- run the shared logging/analytics/auth function
ai_shared.pre_request(self.conf, ai_request)

-- parse and convert the response
local ai_response, _, err = self.driver.from_format(ai_response, self.conf.model, self.conf.route_type)
if err then
return nil, "failed to convert AI response to Kong format: " .. err
end
-- send it to the ai service
local ai_response, _, err = self.driver.subrequest(ai_request, self.conf, http_opts, false)
if err then
return nil, "failed to introspect request with AI service: " .. err
end

-- run the shared logging/analytics function
ai_shared.post_request(self.conf, ai_response)
-- parse and convert the response
local ai_response, _, err = self.driver.from_format(ai_response, self.conf.model, self.conf.route_type)
if err then
return nil, "failed to convert AI response to Kong format: " .. err
end

local ai_response, err = cjson.decode(ai_response)
if err then
return nil, "failed to convert AI response to JSON: " .. err
end
-- run the shared logging/analytics function
ai_shared.post_request(self.conf, ai_response)

local new_request_body = ai_response.choices
and #ai_response.choices > 0
and ai_response.choices[1]
and ai_response.choices[1].message
and ai_response.choices[1].message.content
if not new_request_body then
return nil, "no 'choices' in upstream AI service response"
end

-- if specified, extract the first regex match from the AI response
-- this is useful for AI models that pad with assistant text, even when
-- we ask them NOT to.
if response_regex_match then
local matches, err = re_match(new_request_body, response_regex_match, "ijom")
local ai_response, err = cjson.decode(ai_response)
if err then
return nil, "failed regex matching ai response: " .. err
return nil, "failed to convert AI response to JSON: " .. err
end

if matches then
new_request_body = matches[0] -- this array DOES start at 0, for some reason
local new_request_body = ((ai_response.choices or EMPTY)[1].message or EMPTY).content
if not new_request_body then
return nil, "no 'choices' in upstream AI service response"
end

else
return nil, "AI response did not match specified regular expression"
-- if specified, extract the first regex match from the AI response
-- this is useful for AI models that pad with assistant text, even when
-- we ask them NOT to.
if response_regex_match then
local matches, err = re_match(new_request_body, response_regex_match, "ijom")
if err then
return nil, "failed regex matching ai response: " .. err
end

if matches then
new_request_body = matches[0] -- this array DOES start at 0, for some reason

else
return nil, "AI response did not match specified regular expression"

end
end

return new_request_body
end

return new_request_body
end


-- Parse the response instructions.
-- @tparam string|table in_body The response to parse; if a string, it will be parsed as JSON.
-- @treturn[1] table The headers, field `in_body.headers`
-- @treturn[1] string The body, field `in_body.body` (or if absent `in_body` itself as a table)
-- @treturn[1] number The status, field `in_body.status` (or 200 if absent)
-- @treturn[2] nil
-- @treturn[2] string An error message if parsing failed or input wasn't a table
function LLM:parse_json_instructions(in_body)
local err
if type(in_body) == "string" then
in_body, err = cjson.decode(in_body)
if err then
return nil, nil, nil, err
end
end

-- Parse the response instructions.
-- @tparam string|table in_body The response to parse, if a string, it will be parsed as JSON.
-- @treturn[1] table The headers, field `in_body.headers`
-- @treturn[1] string The body, field `in_body.body` (or if absent `in_body` itself as a table)
-- @treturn[1] number The status, field `in_body.status` (or 200 if absent)
-- @treturn[2] nil
-- @treturn[2] string An error message if parsing failed or input wasn't a table
function _M:parse_json_instructions(in_body)
local err
if type(in_body) == "string" then
in_body, err = cjson.decode(in_body)
if err then
return nil, nil, nil, err
if type(in_body) ~= "table" then
return nil, nil, nil, "input not table or string"
end
end

if type(in_body) ~= "table" then
return nil, nil, nil, "input not table or string"
return
in_body.headers,
in_body.body or in_body,
in_body.status or 200
end

return
in_body.headers,
in_body.body or in_body,
in_body.status or 200
end


--- Instantiate a new LLM driver instance.
-- @tparam table conf Configuration table
-- @tparam table http_opts HTTP options table
-- @treturn[1] table A new LLM driver instance
-- @treturn[2] nil
-- @treturn[2] string An error message if instantiation failed
function _M.new_driver(conf, http_opts)
local self = {
conf = conf or {},
http_opts = http_opts or {},
}
setmetatable(self, LLM)

local provider = (self.conf.model or {}).provider or "NONE_SET"
local driver_module = "kong.llm.drivers." .. provider
local ok
ok, self.driver = pcall(require, driver_module)
if not ok then
local err = "could not instantiate " .. driver_module .. " package"
kong.log.err(err)
return nil, err
end

--- Instantiate a new LLM instance.
-- @tparam table conf Configuration table
-- @tparam table http_opts HTTP options table
-- @treturn[1] table A new LLM instance
-- @treturn[2] nil
-- @treturn[2] string An error message if instantiation failed
function _M:new(conf, http_opts)
local self = {
conf = conf or {},
http_opts = http_opts or {},
}
setmetatable(self, _M)

local provider = (self.conf.model or {}).provider or "NONE_SET"
local driver_module = "kong.llm.drivers." .. provider
local ok
ok, self.driver = pcall(require, driver_module)
if not ok then
local err = "could not instantiate " .. driver_module .. " package"
kong.log.err(err)
return nil, err
return self
end

return self
end



return _M
2 changes: 1 addition & 1 deletion kong/plugins/ai-request-transformer/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function _M:access(conf)
local http_opts = create_http_opts(conf)
conf.llm.__plugin_id = conf.__plugin_id
conf.llm.__key__ = conf.__key__
local ai_driver, err = llm:new(conf.llm, http_opts)
local ai_driver, err = llm.new_driver(conf.llm, http_opts)

if not ai_driver then
return internal_server_error(err)
Expand Down
2 changes: 1 addition & 1 deletion kong/plugins/ai-response-transformer/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ function _M:access(conf)
local http_opts = create_http_opts(conf)
conf.llm.__plugin_id = conf.__plugin_id
conf.llm.__key__ = conf.__key__
local ai_driver, err = llm:new(conf.llm, http_opts)
local ai_driver, err = llm.new_driver(conf.llm, http_opts)

if not ai_driver then
return internal_server_error(err)
Expand Down
20 changes: 10 additions & 10 deletions spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
local llm_class = require("kong.llm")
local llm = require("kong.llm")
local helpers = require "spec.helpers"
local cjson = require "cjson"
local http_mock = require "spec.helpers.http_mock"
Expand Down Expand Up @@ -224,10 +224,10 @@ describe(PLUGIN_NAME .. ": (unit)", function()
for name, format_options in pairs(FORMATS) do
describe(name .. " transformer tests, exact json response", function()
it("transforms request based on LLM instructions", function()
local llm = llm_class:new(format_options, {})
assert.truthy(llm)
local llmdriver = llm.new_driver(format_options, {})
assert.truthy(llmdriver)

local result, err = llm:ai_introspect_body(
local result, err = llmdriver:ai_introspect_body(
REQUEST_BODY, -- request body
SYSTEM_PROMPT, -- conf.prompt
{}, -- http opts
Expand All @@ -246,10 +246,10 @@ describe(PLUGIN_NAME .. ": (unit)", function()

describe("openai transformer tests, pattern matchers", function()
it("transforms request based on LLM instructions, with json extraction pattern", function()
local llm = llm_class:new(OPENAI_NOT_JSON, {})
assert.truthy(llm)
local llmdriver = llm.new_driver(OPENAI_NOT_JSON, {})
assert.truthy(llmdriver)

local result, err = llm:ai_introspect_body(
local result, err = llmdriver:ai_introspect_body(
REQUEST_BODY, -- request body
SYSTEM_PROMPT, -- conf.prompt
{}, -- http opts
Expand All @@ -265,10 +265,10 @@ describe(PLUGIN_NAME .. ": (unit)", function()
end)

it("transforms request based on LLM instructions, but fails to match pattern", function()
local llm = llm_class:new(OPENAI_NOT_JSON, {})
assert.truthy(llm)
local llmdriver = llm.new_driver(OPENAI_NOT_JSON, {})
assert.truthy(llmdriver)

local result, err = llm:ai_introspect_body(
local result, err = llmdriver:ai_introspect_body(
REQUEST_BODY, -- request body
SYSTEM_PROMPT, -- conf.prompt
{}, -- http opts
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
local llm_class = require("kong.llm")
local llm = require("kong.llm")
local helpers = require "spec.helpers"
local cjson = require "cjson"
local http_mock = require "spec.helpers.http_mock"
Expand Down Expand Up @@ -90,10 +90,10 @@ describe(PLUGIN_NAME .. ": (unit)", function()

describe("openai transformer tests, specific response", function()
it("transforms request based on LLM instructions, with response transformation instructions format", function()
local llm = llm_class:new(OPENAI_INSTRUCTIONAL_RESPONSE, {})
assert.truthy(llm)
local llmdriver = llm.new_driver(OPENAI_INSTRUCTIONAL_RESPONSE, {})
assert.truthy(llmdriver)

local result, err = llm:ai_introspect_body(
local result, err = llmdriver:ai_introspect_body(
REQUEST_BODY, -- request body
SYSTEM_PROMPT, -- conf.prompt
{}, -- http opts
Expand All @@ -107,14 +107,14 @@ describe(PLUGIN_NAME .. ": (unit)", function()
assert.same(EXPECTED_RESULT, table_result)

-- parse in response string format
local headers, body, status, err = llm:parse_json_instructions(result)
local headers, body, status, err = llmdriver:parse_json_instructions(result)
assert.is_nil(err)
assert.same({ ["content-type"] = "application/xml" }, headers)
assert.same(209, status)
assert.same(EXPECTED_RESULT.body, body)

-- parse in response table format
headers, body, status, err = llm:parse_json_instructions(table_result)
headers, body, status, err = llmdriver:parse_json_instructions(table_result)
assert.is_nil(err)
assert.same({ ["content-type"] = "application/xml" }, headers)
assert.same(209, status)
Expand Down

0 comments on commit 31bb26e

Please sign in to comment.