Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
shreemaan-abhishek committed Aug 15, 2024
1 parent bf90fc2 commit f5d902f
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 130 deletions.
2 changes: 2 additions & 0 deletions apisix/balancer.lua
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ do
local default_keepalive_pool

function set_current_peer(server, ctx)
core.log.warn("dibag peer: ", core.json.encode(server))
core.log.warn(debug.traceback("dibag"))
local up_conf = ctx.upstream_conf
local keepalive_pool = up_conf.keepalive_pool

Expand Down
15 changes: 13 additions & 2 deletions apisix/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ local tonumber = tonumber
local type = type
local pairs = pairs
local ngx_re_match = ngx.re.match
local balancer = require("ngx.balancer")
local control_api_router

local is_http = false
Expand Down Expand Up @@ -722,7 +723,9 @@ function _M.http_access_phase()
plugin.run_plugin("access", plugins, api_ctx)
end

_M.handle_upstream(api_ctx, route, enable_websocket)
if not api_ctx.custom_upstream_ip then
_M.handle_upstream(api_ctx, route, enable_websocket)
end
end


Expand Down Expand Up @@ -893,7 +896,15 @@ function _M.http_balancer_phase()
return core.response.exit(500)
end

load_balancer.run(api_ctx.matched_route, api_ctx, common_phase)
if api_ctx.custom_upstream_ip then
local ok, err = balancer.set_current_peer(api_ctx.custom_upstream_ip, api_ctx.custom_upstream_port)
if not ok then
core.log.error("failed to overwrite upstream for ai_proxy: ", err)
return core.response.exit(500)
end
else
load_balancer.run(api_ctx.matched_route, api_ctx, common_phase)
end
end


Expand Down
92 changes: 18 additions & 74 deletions apisix/plugins/ai-proxy.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ local _M = {


--- Validate the user-supplied plugin configuration against the plugin schema.
-- TODO: additionally verify correctness of any custom URL in the configuration
-- (not yet implemented).
-- @param conf table: plugin configuration to validate
-- @return boolean ok, plus an error message when validation fails
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    return ok, err
end


-- static messages
local ERROR__NOT_SET = 'data: {"error": true, "message": "empty or unsupported transformer response"}'
-- constant messages
local CONTENT_TYPE_JSON = "application/json"

-- formats_compatible is a map of formats that are compatible with each other.
local formats_compatible = {
Expand Down Expand Up @@ -111,115 +111,59 @@ local function transform_body(conf, ctx)
ctx.plugin.buffered_response_body = response_body
end


function _M.header_filter(conf, ctx)
-- only act on 200 in first release - pass the unmodified response all the way through if any failure
if ngx.status ~= 200 then
return
local function get_request_table()
local req_body, err = core.request.get_body() -- TODO: max size
if not req_body then
return nil, "failed to get request body: " .. err
end

local ai_driver = require("apisix.plugins.ai-proxy.drivers." .. conf.model.provider)
-- ai_driver.post_request(conf)

transform_body(conf, ctx)
core.request.set_header(ctx, "Content-Length", nil)
end

function _M.body_filter(conf, ctx)
if ngx.status ~= 200 then
return
req_body, err = req_body:gsub("\\\"", "\"") -- remove escaping in JSON
if not req_body then
return nil, "failed to remove escaping from body: " .. req_body .. ". err: " .. err
end

ngx.arg[1] = ctx.plugin.buffered_response_body
ngx.arg[2] = true

ctx.plugin.buffered_response_body = nil
return core.json.decode(req_body)
end


function _M.access(conf, ctx)
local f = io.open("dibag", "w+")
f:write(core.json.encode(ngx.ctx, true))
f:close()
local route_type = conf.route_type
local multipart = false

local content_type = core.request.header(ctx, "Content-Type") or "application/json"
multipart = content_type == "multipart/form-data" -- this may be a large file upload, so we have to proxy it directly

local request_table = core.request.get_body() -- TODO: max size
local esc, err = request_table:gsub("\\\"", "\"")
if err then
core.log.error("dibagerr: ", err)
end
request_table = core.json.decode(esc)
core.log.warn("dibag: type: ", core.json.encode(esc))
if not request_table then
return 400, "content-type header does not match request body, or bad JSON formatting"
local content_type = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
if content_type ~= CONTENT_TYPE_JSON then
return 400, "unsupported content-type: " .. content_type
end

-- copy from the user request if present
if (not multipart) and (not conf.model.name) and (request_table.model) then
if type(request_table.model) == "string" then
conf.model.name = request_table.model
end
elseif multipart then
conf.model.name = "NOT_SPECIFIED" -- TEST: UPLOAD A FILE hehe
end

-- check that the user isn't trying to override the plugin conf model in the request body
if type(request_table.model) == "string" and request_table.model ~= "" then
if request_table.model ~= conf.model.name then
return 400, "cannot use own model - must be: " .. conf.model.name
end
end

-- model is stashed in the copied plugin conf, for consistency in transformation functions
if not conf.model.name then
return 400, "model parameter not found in request, nor in gateway configuration"
local request_table, err = get_request_table()
if not request_table then
return 400, err
end

-- check the incoming format is the same as the configured LLM format
local compatible, err = is_compatible(request_table, route_type)
if not multipart and not compatible then
-- llm_state.disable_ai_proxy_response_transform()
return 400, err
end

local ai_driver = require("apisix.plugins.ai-proxy.drivers." .. conf.model.provider)

-- execute pre-request hooks for this driver

-- transform the body to kapisix-format for this provider/model
local parsed_request_body, content_type, err
if route_type ~= "preserve" and (not multipart) then
-- transform the body to kapisix-format for this provider/model
parsed_request_body, content_type, err = ai_driver.to_format(request_table, conf.model, route_type)
if err then
-- llm_state.disable_ai_proxy_response_transform()
return 400, err
end
end

-- execute pre-request hooks for "all" drivers before set new body
local ok, err = ai_driver.pre_request(conf, parsed_request_body)
if not ok then
return 400, err
end

if route_type ~= "preserve" then
ngx_req.set_body_data(core.json.encode(parsed_request_body))
core.request.set_header(ctx, "Content-Type", content_type)
end

-- get the provider's cached identity interface - nil may come back, which is fine

-- now re-configure the request for this operation type
local ok, err = ai_driver.configure_request(conf, ctx)
if not ok then
core.log.error("failed to configure request for AI service: ", err)
return 500
end

end

return _M
42 changes: 15 additions & 27 deletions apisix/plugins/ai-proxy/drivers/openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@ local _M = {}
-- imports
local cjson = require("cjson.safe")
local core = require("apisix.core")
local upstream = require("apisix.upstream")
local fmt = string.format
local socket_url = require "socket.url"
local string_gsub = string.gsub
local set_matched_upstream_to_ctx = require("apisix.init").set_matched_upstream_to_ctx
--

-- globals
local DRIVER_NAME = "openai"
local DEFAULT_HOST = "api.openai.com"
local DEFAULT_PORT = 443
--

---
Expand Down Expand Up @@ -86,7 +85,7 @@ local transformers_from = {
}

function _M.from_format(response_string, model_info, route_type)
ngx.log(ngx.DEBUG, "converting from ", model_info.provider, "://", route_type, " type to kapisix")
core.log.info("converting from ", model_info.provider, "://", route_type, " type to kapisix")

-- MUST return a string, to set as the response body
if not transformers_from[route_type] then
Expand All @@ -106,7 +105,7 @@ function _M.from_format(response_string, model_info, route_type)
end

function _M.to_format(request_table, model_info, route_type)
ngx.log(ngx.DEBUG, "converting from kapisix type to ", model_info.provider, "/", route_type)
core.log.info("converting from kapisix type to ", model_info.provider, "/", route_type)

if route_type == "preserve" then
-- do nothing
Expand All @@ -131,22 +130,6 @@ function _M.to_format(request_table, model_info, route_type)
return response_object, content_type, nil
end

--- Driver hook invoked during the header_filter phase.
-- Intentionally a no-op for this driver: there is nothing to parse from the
-- response headers, but the function is kept so the driver satisfies the
-- common driver interface.
-- @param body unused
function _M.header_filter_hooks(body)
    -- nothing to parse in header_filter phase
end

-- function _M.post_request(conf)
-- for i, v in ipairs(headers_to_be_cleared) do
-- response.clear_header(v)
-- end
-- end

--- Driver hook that runs before the request is proxied upstream.
-- @param conf table: plugin configuration (unused here)
-- @param body table: the transformed request body (unused here)
-- @return boolean true and nil error, unconditionally
function _M.pre_request(conf, body)
    -- NOTE(review): core.request.set_header is called here without a ctx
    -- first argument, while other call sites in this changeset use
    -- set_header(ctx, name, value) -- confirm this matches the
    -- core.request API signature in use (ctx is not a parameter of this
    -- hook, so it cannot simply be threaded through).
    core.request.set_header("Accept-Encoding", "gzip, identity") -- tell server not to send brotli

    return true, nil
end

local operation_map = {
["llm/v1/completions"] = {
path = "/v1/completions",
Expand All @@ -160,6 +143,13 @@ local operation_map = {

-- returns err or nil
function _M.configure_request(conf, ctx)
local ip, err = core.resolver.parse_domain(conf.model.options.upstream_host or DEFAULT_HOST)
if not ip then
core.log.error("failed to resolve ai_proxy upstream host: ", err)
return core.response.exit(500)
end
ctx.custom_upstream_ip = ip
ctx.custom_upstream_port = DEFAULT_PORT
local parsed_url

if (conf.model.options and conf.model.options.upstream_url) then
Expand All @@ -176,13 +166,11 @@ function _M.configure_request(conf, ctx)
-- if the path is read from a URL capture, ensure that it is valid
parsed_url.path = string_gsub(parsed_url.path, "^/*", "/")

-- set_matched_upstream_to_ctx(ctx, ctx.matched_route, ctx.route.value.upstream_id) -- TODO: check usage with tfsp plugin
-- ctx.matched_upsteam.
ngx.var.upstream_uri = parsed_url.path -- TODO: escaping
ngx.var.upstream_scheme = parsed_url.scheme -- sanity http/s
core.log.warn("scheme: ", parsed_url.scheme)
ngx.var.upstream_host = parsed_url.host -- TODO: sanity checks. encapsulate to a func
-- upstream.set(parsed_url.host, (tonumber(parsed_url.port) or 443))
ngx.var.upstream_scheme = "https" -- sanity http/s
ngx.var.upstream_host = conf.model.options.upstream_host or DEFAULT_HOST -- TODO: sanity checks. encapsulate to a func
ctx.custom_balancer_host = conf.model.options.upstream_host or DEFAULT_HOST
ctx.custom_balancer_port = conf.model.options.port or DEFAULT_PORT

local auth_header_name = conf.auth and conf.auth.header_name
local auth_header_value = conf.auth and conf.auth.header_value
Expand Down
35 changes: 8 additions & 27 deletions apisix/plugins/ai-proxy/schema.lua
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,14 @@ local model_options_schema = {
maximum = 1,

},
upstream_url = {
upstream_host = {
type = "string",
description = "To be specified to override the URL to the AI provider endpoints",
description = "To be specified to override the host of the AI provider",

},
upstream_port = {
type = "integer",
description = "To be specified to override the AI provider port",

},
upstream_path = {
Expand All @@ -94,7 +99,7 @@ local model_schema = {
type = "string",
description = "AI provider request format - kapisix translates "
.. "requests to and from the specified backend compatible formats.",
oneOf = { "openai" }
oneOf = { "openai" }, -- add more providers later

},
name = {
Expand All @@ -107,30 +112,6 @@ local model_schema = {
}



-- TODO: introduce a new param in ctx that allows plugins to add log items dynamically
-- without having to configure the error_log_format
-- JSON-schema fragment describing the plugin's logging options.
-- NOTE(review): "required = true" attached to the object and to individual
-- properties is draft-03-style JSON Schema; later drafts expect "required"
-- to be an array of property names on the parent object -- confirm the
-- schema validator in use accepts this boolean form.
local logging_schema = {
    type = "object",
    required = true,
    properties = {
        -- When enabled (and supported by the driver), adds model usage and
        -- token metrics to the log plugins' output.
        log_statistics = {
            type = "boolean",
            description = "If enabled and supported by the driver, "
                .. "will add model usage and token metrics into the kapisix log plugin(s) output.",
            required = true,
            default = false
        },
        -- When enabled, logs the full request and response bodies.
        log_payloads = {
            type = "boolean",
            description = "If enabled, will log the request and response body into the kapisix log plugin(s) output.",
            required = true,
            default = false
        },
    }
}


return {
type = "object",
properties = {
Expand Down

0 comments on commit f5d902f

Please sign in to comment.