diff --git a/apisix/plugins/ai.lua b/apisix/plugins/ai.lua index 6126ed873353..ad6167e25325 100644 --- a/apisix/plugins/ai.lua +++ b/apisix/plugins/ai.lua @@ -27,6 +27,7 @@ local ipairs = ipairs local pcall = pcall local loadstring = loadstring local type = type +local pairs = pairs local get_cache_key_func local get_cache_key_func_def_render @@ -66,6 +67,7 @@ local orig_router_match local orig_handle_upstream = apisix.handle_upstream local orig_balancer_run = load_balancer.run +local default_keepalive_pool = {} local function match_route(ctx) orig_router_match(ctx) @@ -113,7 +115,7 @@ local function ai_upstream() end -local pool_opt = { pool_size = 320 } +local pool_opt local function ai_balancer_run(route) local server = route.value.upstream.nodes[1] if enable_keepalive then @@ -123,96 +125,72 @@ local function ai_balancer_run(route) server.port, "] err: ", err) return ok, err end - balancer.enable_keepalive(60, 1000) + balancer.enable_keepalive(default_keepalive_pool.idle_timeout, + default_keepalive_pool.requests) else balancer.set_current_peer(server.host, server.port or 80) end end local function routes_analyze(routes) - -- TODO: need to add a option in config.yaml to enable this feature(default is true) - local route_flags = core.table.new(0, 5) - local route_up_flags = core.table.new(0, 8) + local route_flags = core.table.new(0, 16) + local route_up_flags = core.table.new(0, 12) for _, route in ipairs(routes) do if type(route) == "table" then - if route.value.methods then - route_flags["methods"] = true - end - - if route.value.host or route.hosts then - route_flags["host"] = true - end - - if route.value.vars then - route_flags["vars"] = true - end - - if route.value.filter_fun then - route_flags["filter_fun"] = true - end - - if route.value.remote_addr or route.remote_addrs then - route_flags["remote_addr"] = true - end - - if route.value.service then - route_flags["service"] = true - end - - if route.value.enable_websocket then - route_flags["enable_websocket"] = true - end - - if route.value.plugins then - route_flags["plugins"] = true - end - - if route.value.upstream_id then - route_flags["upstream_id"] = true - end - - if route.value.service_id then - route_flags["service_id"] = true - end - - if route.value.plugin_config_id then - route_flags["plugin_config_id"] = true - end - - local upstream = route.value.upstream - if upstream and upstream.nodes and #upstream.nodes == 1 then - local node = upstream.nodes[1] - if not core.utils.parse_ipv4(node.host) - and not core.utils.parse_ipv6(node.host) then - route_up_flags["has_domain"] = true - end - - if upstream.pass_host == "pass" then - route_up_flags["pass_host"] = true - end - - if upstream.scheme == "http" then - route_up_flags["scheme"] = true - end - - if upstream.checks then - route_up_flags["checks"] = true - end - - if upstream.retries then - route_up_flags["retries"] = true - end - - if upstream.timeout then - route_up_flags["timeout"] = true - end - - if upstream.tls then - route_up_flags["tls"] = true + for key, value in pairs(route.value) do + -- collect route flags + if key == "methods" then + route_flags["methods"] = true + elseif key == "host" or key == "hosts" then + route_flags["host"] = true + elseif key == "vars" then + route_flags["vars"] = true + elseif key == "filter_fun"then + route_flags["filter_fun"] = true + elseif key == "remote_addr" or key == "remote_addrs" then + route_flags["remote_addr"] = true + elseif key == "service" then + route_flags["service"] = true + elseif key == "enable_websocket" then + route_flags["enable_websocket"] = true + elseif key == "plugins" then + route_flags["plugins"] = true + elseif key == "upstream_id" then + route_flags["upstream_id"] = true + elseif key == "service_id" then + route_flags["service_id"] = true + elseif key == "plugin_config_id" then + route_flags["plugin_config_id"] = true end - if upstream.keepalive then - route_up_flags["keepalive"] = true + -- collect upstream flags + if key == "upstream" then + if #value.nodes == 1 then + for k, v in pairs(value) do + if k == "nodes" then + if (not core.utils.parse_ipv4(v[1].host) + and not core.utils.parse_ipv6(v[1].host)) then + route_up_flags["has_domain"] = true + end + elseif k == "pass_host" and v ~= "pass" then + route_up_flags["pass_host"] = true + elseif k == "scheme" and v ~= "http" then + route_up_flags["scheme"] = true + elseif k == "checks" then + route_up_flags["checks"] = true + elseif k == "retries" then + route_up_flags["retries"] = true + elseif k == "timeout" then + route_up_flags["timeout"] = true + elseif k == "tls" then + route_up_flags["tls"] = true + elseif k == "keepalive" then + route_up_flags["keepalive"] = true + end + end + else + route_up_flags["more_nodes"] = true + end end end end @@ -234,31 +212,42 @@ local function routes_analyze(routes) end end - if not route_flags["service"] - and not route_flags["service_id"] - and not route_flags["upstream_id"] - and not route_flags["enable_websocket"] - and not route_flags["plugins"] - and not route_up_flags["has_domain"] - and route_up_flags["pass_host"] - and route_up_flags["scheme"] - and not route_up_flags["checks"] - and not route_up_flags["retries"] - and not route_up_flags["timeout"] - and not route_up_flags["timeout"] - and not route_up_flags["keepalive"] then - -- replace the upstream module - apisix.handle_upstream = ai_upstream - load_balancer.run = ai_balancer_run - else + if route_flags["service"] + or route_flags["service_id"] + or route_flags["upstream_id"] + or route_flags["enable_websocket"] + or route_flags["plugins"] + or route_up_flags["has_domain"] + or route_up_flags["pass_host"] + or route_up_flags["scheme"] + or route_up_flags["checks"] + or route_up_flags["retries"] + or route_up_flags["timeout"] + or route_up_flags["tls"] + or route_up_flags["keepalive"] + or route_up_flags["more_nodes"] then apisix.handle_upstream = orig_handle_upstream load_balancer.run = orig_balancer_run + else + -- replace the upstream module + apisix.handle_upstream = ai_upstream + load_balancer.run = ai_balancer_run end end function _M.init() event.register(event.CONST.BUILD_ROUTER, routes_analyze) + local local_conf = core.config.local_conf() + local up_keepalive_conf = + core.table.try_read_attr(local_conf, "nginx_config", + "http", "upstream") + default_keepalive_pool.idle_timeout = + core.config_util.parse_time_unit(up_keepalive_conf.keepalive_timeout) + default_keepalive_pool.size = up_keepalive_conf.keepalive + default_keepalive_pool.requests = up_keepalive_conf.keepalive_requests + + pool_opt = { pool_size = default_keepalive_pool.size } end