diff --git a/apisix/init.lua b/apisix/init.lua index b4438cf0f0d3..bbd4e636fca8 100644 --- a/apisix/init.lua +++ b/apisix/init.lua @@ -39,7 +39,6 @@ local set_upstream = apisix_upstream.set_by_route local upstream_util = require("apisix.utils.upstream") local xrpc = require("apisix.stream.xrpc") local ctxdump = require("resty.ctxdump") -local ipmatcher = require("resty.ipmatcher") local ngx_balancer = require("ngx.balancer") local debug = require("apisix.debug") local ngx = ngx @@ -48,7 +47,6 @@ local ngx_exit = ngx.exit local math = math local error = error local ipairs = ipairs -local tostring = tostring local ngx_now = ngx.now local ngx_var = ngx.var local str_byte = string.byte @@ -168,61 +166,9 @@ function _M.http_ssl_phase() end - - -local function parse_domain_for_nodes(nodes) - local new_nodes = core.table.new(#nodes, 0) - for _, node in ipairs(nodes) do - local host = node.host - if not ipmatcher.parse_ipv4(host) and - not ipmatcher.parse_ipv6(host) then - local ip, err = core.resolver.parse_domain(host) - if ip then - local new_node = core.table.clone(node) - new_node.host = ip - new_node.domain = host - core.table.insert(new_nodes, new_node) - end - - if err then - core.log.error("dns resolver domain: ", host, " error: ", err) - end - else - core.table.insert(new_nodes, node) - end - end - return new_nodes -end - - -local function parse_domain_in_up(up) - local nodes = up.value.nodes - local new_nodes, err = parse_domain_for_nodes(nodes) - if not new_nodes then - return nil, err - end - - local ok = upstream_util.compare_upstream_node(up.dns_value, new_nodes) - if ok then - return up - end - - if not up.orig_modifiedIndex then - up.orig_modifiedIndex = up.modifiedIndex - end - up.modifiedIndex = up.orig_modifiedIndex .. "#" .. ngx_now() - - up.dns_value = core.table.clone(up.value) - up.dns_value.nodes = new_nodes - core.log.info("resolve upstream which contain domain: ", - core.json.delay_encode(up, true)) - return up -end - - local function parse_domain_in_route(route) local nodes = route.value.upstream.nodes - local new_nodes, err = parse_domain_for_nodes(nodes) + local new_nodes, err = upstream_util.parse_domain_for_nodes(nodes) if not new_nodes then return nil, err end @@ -280,38 +226,6 @@ local function set_upstream_headers(api_ctx, picked_server) end -local function get_upstream_by_id(up_id) - local upstreams = core.config.fetch_created_obj("/upstreams") - if upstreams then - local upstream = upstreams:get(tostring(up_id)) - if not upstream then - core.log.error("failed to find upstream by id: " .. up_id) - if is_http then - return core.response.exit(502) - end - - return ngx_exit(1) - end - - if upstream.has_domain then - local err - upstream, err = parse_domain_in_up(upstream) - if err then - core.log.error("failed to get resolved upstream: ", err) - if is_http then - return core.response.exit(500) - end - - return ngx_exit(1) - end - end - - core.log.info("parsed upstream: ", core.json.delay_encode(upstream, true)) - return upstream.dns_value or upstream.value - end -end - - local function verify_tls_client(ctx) if ctx and ctx.ssl_client_verified then local res = ngx_var.ssl_client_verify @@ -475,7 +389,15 @@ function _M.http_access_phase() end if up_id then - local upstream = get_upstream_by_id(up_id) + local upstream = apisix_upstream.get_by_id(up_id) + if not upstream then + if is_http then + return core.response.exit(502) + end + + return ngx_exit(1) + end + api_ctx.matched_upstream = upstream else @@ -898,7 +820,17 @@ function _M.stream_preread_phase() local up_id = matched_route.value.upstream_id if up_id then - api_ctx.matched_upstream = get_upstream_by_id(up_id) + local upstream = apisix_upstream.get_by_id(up_id) + if not upstream then + if is_http then + return core.response.exit(502) + end + + return ngx_exit(1) + end + + api_ctx.matched_upstream = upstream + else if matched_route.has_domain then local err diff --git a/apisix/stream/xrpc/sdk.lua b/apisix/stream/xrpc/sdk.lua index f77a4c4e0ba7..487ca392b435 100644 --- a/apisix/stream/xrpc/sdk.lua +++ b/apisix/stream/xrpc/sdk.lua @@ -21,8 +21,10 @@ local core = require("apisix.core") local config_util = require("apisix.core.config_util") local router = require("apisix.stream.router.ip_port") +local apisix_upstream = require("apisix.upstream") local xrpc_socket = require("resty.apisix.stream.xrpc.socket") local ngx_now = ngx.now +local str_fmt = string.format local tab_insert = table.insert local error = error local tostring = tostring @@ -153,18 +155,25 @@ end -- @function xrpc.sdk.set_upstream -- @tparam table xrpc session -- @tparam table the route configuration +-- @treturn nil|string error message if present function _M.set_upstream(session, conf) local up if conf.upstream then up = conf.upstream - -- TODO: support upstream_id + else + local id = conf.upstream_id + up = apisix_upstream.get_by_id(id) + if not up then + return str_fmt("upstream %s can't be got", id) + end end - local key = tostring(conf) - core.log.info("set upstream to: ", key) + local key = tostring(up) + core.log.info("set upstream to: ", key, " conf: ", core.json.delay_encode(up, true)) session._upstream_key = key session.upstream_conf = up + return nil end diff --git a/apisix/upstream.lua b/apisix/upstream.lua index 0b9ce2008624..fe731e8a766d 100644 --- a/apisix/upstream.lua +++ b/apisix/upstream.lua @@ -542,4 +542,30 @@ function _M.init_worker() end +function _M.get_by_id(up_id) + local upstream + local upstreams = core.config.fetch_created_obj("/upstreams") + if upstreams then + upstream = upstreams:get(tostring(up_id)) + end + + if not upstream then + core.log.error("failed to find upstream by id: ", up_id) + return nil + end + + if upstream.has_domain then + local err + upstream, err = upstream_util.parse_domain_in_up(upstream) + if err then + core.log.error("failed to get resolved upstream: ", err) + return nil + end + end + + core.log.info("parsed upstream: ", core.json.delay_encode(upstream, true)) + return upstream.dns_value or upstream.value +end + + return _M diff --git a/apisix/utils/upstream.lua b/apisix/utils/upstream.lua index 74666dac51d0..c39d4cce2219 100644 --- a/apisix/utils/upstream.lua +++ b/apisix/utils/upstream.lua @@ -15,6 +15,8 @@ -- limitations under the License. -- local core = require("apisix.core") +local ipmatcher = require("resty.ipmatcher") +local ngx_now = ngx.now local ipairs = ipairs local type = type @@ -27,7 +29,7 @@ local function sort_by_key_host(a, b) end -function _M.compare_upstream_node(up_conf, new_t) +local function compare_upstream_node(up_conf, new_t) if up_conf == nil then return false end @@ -56,6 +58,58 @@ function _M.compare_upstream_node(up_conf, new_t) return true end +_M.compare_upstream_node = compare_upstream_node + + +local function parse_domain_for_nodes(nodes) + local new_nodes = core.table.new(#nodes, 0) + for _, node in ipairs(nodes) do + local host = node.host + if not ipmatcher.parse_ipv4(host) and + not ipmatcher.parse_ipv6(host) then + local ip, err = core.resolver.parse_domain(host) + if ip then + local new_node = core.table.clone(node) + new_node.host = ip + new_node.domain = host + core.table.insert(new_nodes, new_node) + end + + if err then + core.log.error("dns resolver domain: ", host, " error: ", err) + end + else + core.table.insert(new_nodes, node) + end + end + return new_nodes +end +_M.parse_domain_for_nodes = parse_domain_for_nodes + + +function _M.parse_domain_in_up(up) + local nodes = up.value.nodes + local new_nodes, err = parse_domain_for_nodes(nodes) + if not new_nodes then + return nil, err + end + + local ok = compare_upstream_node(up.dns_value, new_nodes) + if ok then + return up + end + + if not up.orig_modifiedIndex then + up.orig_modifiedIndex = up.modifiedIndex + end + up.modifiedIndex = up.orig_modifiedIndex .. "#" .. ngx_now() + + up.dns_value = core.table.clone(up.value) + up.dns_value.nodes = new_nodes + core.log.info("resolve upstream which contain domain: ", + core.json.delay_encode(up, true)) + return up +end return _M diff --git a/t/xrpc/apisix/stream/xrpc/protocols/pingpong/init.lua b/t/xrpc/apisix/stream/xrpc/protocols/pingpong/init.lua index a45e9b8c7f7a..694e5e7ad9af 100644 --- a/t/xrpc/apisix/stream/xrpc/protocols/pingpong/init.lua +++ b/t/xrpc/apisix/stream/xrpc/protocols/pingpong/init.lua @@ -140,7 +140,11 @@ function _M.from_downstream(session, downstream) local conf = router[ctx.service] if conf then - sdk.set_upstream(session, conf) + local err = sdk.set_upstream(session, conf) + if err then + core.log.error("failed to set upstream: ", err) + return DECLINED + end end end diff --git a/t/xrpc/pingpong.t b/t/xrpc/pingpong.t index 2a296a83da23..a264e2fc8263 100644 --- a/t/xrpc/pingpong.t +++ b/t/xrpc/pingpong.t @@ -260,9 +260,9 @@ passed end sock:send(data:sub(5)) end +--- wait: 1.1 --- error_log failed to read: timeout ---- wait: 1.1 @@ -605,3 +605,129 @@ connect to 127.0.0.1:1995 while prereading client data connect to 127.0.0.3:1995 while prereading client data connect to 127.0.0.4:1995 while prereading client data --- stream_conf_enable + + + +=== TEST 18: use upstream_id +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/upstreams/1', + ngx.HTTP_PUT, + { + nodes = { + ["127.0.0.3:1995"] = 1 + }, + type = "roundrobin" + } + ) + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + + local code, body = t('/apisix/admin/stream_routes/2', + ngx.HTTP_PUT, + { + protocol = { + superior_id = 1, + conf = { + service = "a" + }, + name = "pingpong" + }, + upstream_id = 1 + } + ) + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + + ngx.say(body) + } + } +--- request +GET /t +--- response_body +passed + + + +=== TEST 19: hit +--- request eval +"POST /t +" . +"pp\x04\x00\x00\x00\x00\x00\x00\x03a\x00\x00\x00ABC" +--- response_body eval +"pp\x04\x00\x00\x00\x00\x00\x00\x03a\x00\x00\x00ABC" +--- grep_error_log eval +qr/connect to \S+ while prereading client data/ +--- grep_error_log_out +connect to 127.0.0.3:1995 while prereading client data +--- stream_conf_enable + + + +=== TEST 20: cache router by version, with upstream_id +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + + local sock = ngx.socket.tcp() + sock:settimeout(1000) + local ok, err = sock:connect("127.0.0.1", 1985) + if not ok then + ngx.log(ngx.ERR, "failed to connect: ", err) + return ngx.exit(503) + end + + assert(sock:send("pp\x04\x00\x00\x00\x00\x00\x00\x03a\x00\x00\x00ABC")) + + ngx.sleep(0.1) + + local code, body = t('/apisix/admin/upstreams/1', + ngx.HTTP_PUT, + { + nodes = { + ["127.0.0.1:1995"] = 1 + }, + type = "roundrobin" + } + ) + if code >= 300 then + ngx.status = code + ngx.say(body) + return + end + + ngx.sleep(0.1) + + local s = "pp\x04\x00\x00\x00\x00\x00\x00\x04a\x00\x00\x00ABCD" + assert(sock:send(s)) + + while true do + local data, err = sock:receiveany(4096) + if not data then + sock:close() + break + end + ngx.print(data) + end + } + } +--- request +GET /t +--- response_body eval +"pp\x04\x00\x00\x00\x00\x00\x00\x03a\x00\x00\x00ABC" . +"pp\x04\x00\x00\x00\x00\x00\x00\x04a\x00\x00\x00ABCD" +--- grep_error_log eval +qr/connect to \S+ while prereading client data/ +--- grep_error_log_out +connect to 127.0.0.3:1995 while prereading client data +connect to 127.0.0.1:1995 while prereading client data +--- stream_conf_enable