Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
tzssangglass committed Oct 21, 2022
1 parent 7e520c5 commit 38242df
Showing 1 changed file with 87 additions and 98 deletions.
185 changes: 87 additions & 98 deletions apisix/plugins/ai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ local ipairs = ipairs
local pcall = pcall
local loadstring = loadstring
local type = type
local pairs = pairs

local get_cache_key_func
local get_cache_key_func_def_render
Expand Down Expand Up @@ -66,6 +67,7 @@ local orig_router_match
local orig_handle_upstream = apisix.handle_upstream
local orig_balancer_run = load_balancer.run

local default_keepalive_pool = {}

local function match_route(ctx)
orig_router_match(ctx)
Expand Down Expand Up @@ -113,7 +115,7 @@ local function ai_upstream()
end


local pool_opt = { pool_size = 320 }
local pool_opt
local function ai_balancer_run(route)
local server = route.value.upstream.nodes[1]
if enable_keepalive then
Expand All @@ -123,96 +125,72 @@ local function ai_balancer_run(route)
server.port, "] err: ", err)
return ok, err
end
balancer.enable_keepalive(60, 1000)
balancer.enable_keepalive(default_keepalive_pool.idle_timeout,
default_keepalive_pool.requests)
else
balancer.set_current_peer(server.host, server.port or 80)
end
end

local function routes_analyze(routes)
-- TODO: need to add a option in config.yaml to enable this feature(default is true)
local route_flags = core.table.new(0, 5)
local route_up_flags = core.table.new(0, 8)
local route_flags = core.table.new(0, 16)
local route_up_flags = core.table.new(0, 12)
for _, route in ipairs(routes) do
if type(route) == "table" then
if route.value.methods then
route_flags["methods"] = true
end

if route.value.host or route.hosts then
route_flags["host"] = true
end

if route.value.vars then
route_flags["vars"] = true
end

if route.value.filter_fun then
route_flags["filter_fun"] = true
end

if route.value.remote_addr or route.remote_addrs then
route_flags["remote_addr"] = true
end

if route.value.service then
route_flags["service"] = true
end

if route.value.enable_websocket then
route_flags["enable_websocket"] = true
end

if route.value.plugins then
route_flags["plugins"] = true
end

if route.value.upstream_id then
route_flags["upstream_id"] = true
end

if route.value.service_id then
route_flags["service_id"] = true
end

if route.value.plugin_config_id then
route_flags["plugin_config_id"] = true
end

local upstream = route.value.upstream
if upstream and upstream.nodes and #upstream.nodes == 1 then
local node = upstream.nodes[1]
if not core.utils.parse_ipv4(node.host)
and not core.utils.parse_ipv6(node.host) then
route_up_flags["has_domain"] = true
end

if upstream.pass_host == "pass" then
route_up_flags["pass_host"] = true
end

if upstream.scheme == "http" then
route_up_flags["scheme"] = true
end

if upstream.checks then
route_up_flags["checks"] = true
end

if upstream.retries then
route_up_flags["retries"] = true
end

if upstream.timeout then
route_up_flags["timeout"] = true
end

if upstream.tls then
route_up_flags["tls"] = true
for key, value in pairs(route.value) do
-- collect route flags
if key == "methods" then
route_flags["methods"] = true
elseif key == "host" or key == "hosts" then
route_flags["host"] = true
elseif key == "vars" then
route_flags["vars"] = true
elseif key == "filter_fun"then
route_flags["filter_fun"] = true
elseif key == "remote_addr" or key == "remote_addrs" then
route_flags["remote_addr"] = true
elseif key == "service" then
route_flags["service"] = true
elseif key == "enable_websocket" then
route_flags["enable_websocket"] = true
elseif key == "plugins" then
route_flags["plugins"] = true
elseif key == "upstream_id" then
route_flags["upstream_id"] = true
elseif key == "service_id" then
route_flags["service_id"] = true
elseif key == "plugin_config_id" then
route_flags["plugin_config_id"] = true
end

if upstream.keepalive then
route_up_flags["keepalive"] = true
-- collect upstream flags
if key == "upstream" then
if #value.nodes == 1 then
for k, v in pairs(value) do
if k == "nodes" then
if (not core.utils.parse_ipv4(v[1].host)
and not core.utils.parse_ipv6(v[1].host)) then
route_up_flags["has_domain"] = true
end
elseif k == "pass_host" and v ~= "pass" then
route_up_flags["pass_host"] = true
elseif k == "scheme" and v ~= "http" then
route_up_flags["scheme"] = true
elseif k == "checks" then
route_up_flags["checks"] = true
elseif k == "retries" then
route_up_flags["retries"] = true
elseif k == "timeout" then
route_up_flags["timeout"] = true
elseif k == "tls" then
route_up_flags["tls"] = true
elseif k == "keepalive" then
route_up_flags["keepalive"] = true
end
end
else
route_up_flags["more_nodes"] = true
end
end
end
end
Expand All @@ -234,31 +212,42 @@ local function routes_analyze(routes)
end
end

if not route_flags["service"]
and not route_flags["service_id"]
and not route_flags["upstream_id"]
and not route_flags["enable_websocket"]
and not route_flags["plugins"]
and not route_up_flags["has_domain"]
and route_up_flags["pass_host"]
and route_up_flags["scheme"]
and not route_up_flags["checks"]
and not route_up_flags["retries"]
and not route_up_flags["timeout"]
and not route_up_flags["timeout"]
and not route_up_flags["keepalive"] then
-- replace the upstream module
apisix.handle_upstream = ai_upstream
load_balancer.run = ai_balancer_run
else
if route_flags["service"]
or route_flags["service_id"]
or route_flags["upstream_id"]
or route_flags["enable_websocket"]
or route_flags["plugins"]
or route_up_flags["has_domain"]
or route_up_flags["pass_host"]
or route_up_flags["scheme"]
or route_up_flags["checks"]
or route_up_flags["retries"]
or route_up_flags["timeout"]
or route_up_flags["tls"]
or route_up_flags["keepalive"]
or route_up_flags["more_nodes"] then
apisix.handle_upstream = orig_handle_upstream
load_balancer.run = orig_balancer_run
else
-- replace the upstream module
apisix.handle_upstream = ai_upstream
load_balancer.run = ai_balancer_run
end
end


function _M.init()
event.register(event.CONST.BUILD_ROUTER, routes_analyze)
local local_conf = core.config.local_conf()
local up_keepalive_conf =
core.table.try_read_attr(local_conf, "nginx_config",
"http", "upstream")
default_keepalive_pool.idle_timeout =
core.config_util.parse_time_unit(up_keepalive_conf.keepalive_timeout)
default_keepalive_pool.size = up_keepalive_conf.keepalive
default_keepalive_pool.requests = up_keepalive_conf.keepalive_requests

pool_opt = { pool_size = default_keepalive_pool.size }
end


Expand Down

0 comments on commit 38242df

Please sign in to comment.