diff --git a/src/Connections.jl b/src/Connections.jl index 60d6eb9b7..95b8de16d 100644 --- a/src/Connections.jl +++ b/src/Connections.jl @@ -610,31 +610,6 @@ function sslconnection(::Type{SSLContext}, tcp::TCPSocket, host::AbstractString; return io end -function sslupgrade(::Type{IOType}, c::Connection{T}, - host::AbstractString; - pool::Union{Nothing, Pool}=nothing, - require_ssl_verification::Bool=NetworkOptions.verify_host(host, "SSL"), - keepalive::Bool=true, - readtimeout::Int=0, - kw...)::Connection{IOType} where {T, IOType} - # initiate the upgrade to SSL - # if the upgrade fails, an error will be thrown and the original c will be closed - # in ConnectionRequest - tls = if readtimeout > 0 - try_with_timeout(readtimeout) do _ - sslconnection(IOType, c.io, host; require_ssl_verification=require_ssl_verification, keepalive=keepalive, kw...) - end - else - sslconnection(IOType, c.io, host; require_ssl_verification=require_ssl_verification, keepalive=keepalive, kw...) - end - # success, now we turn it into a new Connection - conn = Connection(host, "", 0, require_ssl_verification, keepalive, tls) - # release the "old" one, but don't return the connection since we're hijacking the socket - release(getpool(pool, T), connectionkey(c)) - # and return the new one - return acquire(() -> conn, getpool(pool, IOType), connectionkey(conn); forcenew=true) -end - function Base.show(io::IO, c::Connection) nwaiting = applicable(tcpsocket, c.io) ? bytesavailable(tcpsocket(c.io)) : 0 print( diff --git a/src/HTTP.jl b/src/HTTP.jl index 941bab9f5..f707b043a 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -42,6 +42,7 @@ include("Connections.jl") ;using .Connections const ConnectionPool = Connections include("StatusCodes.jl") ;using .StatusCodes include("Messages.jl") ;using .Messages +include("Tunnel.jl") ;using .Tunnel include("cookies.jl") ;using .Cookies include("Streams.jl") ;using .Streams diff --git a/src/Tunnel.jl b/src/Tunnel.jl new file mode 100644 index 000000000..7eca6693e --- /dev/null +++ b/src/Tunnel.jl @@ -0,0 +1,103 @@ +module Tunnel + +export newtunnelconnection + +using Sockets, LoggingExtras, NetworkOptions, URIs +using ConcurrentUtilities: acquire, try_with_timeout + +using ..Connections, ..Messages, ..Exceptions +using ..Connections: connection_limit_warning, getpool, getconnection, sslconnection, connectionkey, connection_isvalid + +function newtunnelconnection(; + target_type::Type{<:IO}, + target_host::AbstractString, + target_port::AbstractString, + proxy_type::Type{<:IO}, + proxy_host::AbstractString, + proxy_port::AbstractString, + proxy_auth::AbstractString="", + pool::Union{Nothing, Pool}=nothing, + connection_limit=nothing, + forcenew::Bool=false, + idle_timeout=typemax(Int), + connect_timeout::Int=30, + readtimeout::Int=30, + keepalive::Bool=true, + kw...) + connection_limit_warning(connection_limit) + + if isempty(target_port) + target_port = istcptype(target_type) ? "80" : "443" + end + + require_ssl_verification = get(kw, :require_ssl_verification, NetworkOptions.verify_host(target_host, "SSL")) + host_key = proxy_host * "/" * target_host + port_key = proxy_port * "/" * target_port + key = (host_key, port_key, require_ssl_verification, keepalive, true) + + return acquire( + getpool(pool, target_type), + key; + forcenew=forcenew, + isvalid=c->connection_isvalid(c, Int(idle_timeout))) do + + conn = Connection(host_key, port_key, idle_timeout, require_ssl_verification, keepalive, + try_with_timeout0(connect_timeout) do _ + getconnection(proxy_type, proxy_host, proxy_port; keepalive, kw...) + end + ) + try + try_with_timeout0(readtimeout) do _ + connect_tunnel(conn, target_host, target_port, proxy_auth) + end + + if !istcptype(target_type) + tls = try_with_timeout0(readtimeout) do _ + sslconnection(target_type, conn.io, target_host; keepalive, kw...) + end + + # success, now we turn it into a new Connection + conn = Connection(host_key, port_key, idle_timeout, require_ssl_verification, keepalive, tls) + end + + @assert connectionkey(conn) === key + + conn + catch ex + close(conn) + rethrow() + end + end +end + +function connect_tunnel(io, target_host, target_port, proxy_auth) + target = "$(URIs.hoststring(target_host)):$(target_port)" + @debugv 1 "📡 CONNECT HTTPS tunnel to $target" + headers = Dict("Host" => target) + if (!isempty(proxy_auth)) + headers["Proxy-Authorization"] = proxy_auth + end + request = Request("CONNECT", target, headers) + # @debugv 2 "connect_tunnel: writing headers" + writeheaders(io, request) + # @debugv 2 "connect_tunnel: reading headers" + readheaders(io, request.response) + # @debugv 2 "connect_tunnel: done reading headers" + if request.response.status != 200 + throw(StatusError(request.response.status, + request.method, request.target, request.response)) + end +end + +function try_with_timeout0(f, timeout, ::Type{T}=Any) where {T} + if timeout > 0 + try_with_timeout(f, timeout, T) + else + f(Ref(false)) + end +end + +istcptype(::Type{TCPSocket}) = true +istcptype(::Type{<:IO}) = false + +end # module Tunnel diff --git a/src/clientlayers/ConnectionRequest.jl b/src/clientlayers/ConnectionRequest.jl index 564a8f088..8e4bbe0ca 100644 --- a/src/clientlayers/ConnectionRequest.jl +++ b/src/clientlayers/ConnectionRequest.jl @@ -3,7 +3,7 @@ module ConnectionRequest using URIs, Sockets, Base64, LoggingExtras, ConcurrentUtilities, ExceptionUnwrapping using MbedTLS: SSLContext, SSLConfig using OpenSSL: SSLStream -using ..Messages, ..IOExtras, ..Connections, ..Streams, ..Exceptions +using ..Messages, ..IOExtras, ..Connections, ..Streams, ..Exceptions, ..Tunnel import ..SOCKET_TYPE_TLS islocalhost(host::AbstractString) = host == "localhost" || host == "127.0.0.1" || host == "::1" || host == "0000:0000:0000:0000:0000:0000:0000:0001" || host == "0:0:0:0:0:0:0:1" @@ -77,8 +77,31 @@ function connectionlayer(handler) IOType = sockettype(url, socket_type, socket_type_tls) start_time = time() try - io = newconnection(IOType, url.host, url.port; readtimeout=readtimeout, connect_timeout=connect_timeout, kw...) + if !isnothing(proxy) && req.url.scheme in ("https", "wss", "ws") + target_IOType = sockettype(target_url, socket_type, socket_type_tls) + + io = newtunnelconnection(; + target_type=target_IOType, + target_host=target_url.host, + target_port=target_url.port, + proxy_type=IOType, + proxy_host=url.host, + proxy_port=url.port, + proxy_auth=header(req, "Proxy-Authorization"), + connect_timeout, + readtimeout, + kw... + ) + + req.headers = filter(x->x.first != "Proxy-Authorization", req.headers) + else + io = newconnection(IOType, url.host, url.port; readtimeout=readtimeout, connect_timeout=connect_timeout, kw...) + end catch e + if e isa StatusError + return e.response + end + if logerrors msg = current_exceptions_to_string() @error msg type=Symbol("HTTP.ConnectError") method=req.method url=req.url context=req.context logtag=logtag @@ -91,31 +114,6 @@ function connectionlayer(handler) shouldreuse = !(target_url.scheme in ("ws", "wss")) try - if proxy !== nothing && target_url.scheme in ("https", "wss", "ws") - shouldreuse = false - # tunnel request - if target_url.scheme in ("https", "wss") - target_url = URI(target_url, port=443) - elseif target_url.scheme in ("ws", ) && target_url.port == "" - target_url = URI(target_url, port=80) # if there is no port info, connect_tunnel will fail - end - r = if readtimeout > 0 - try_with_timeout(readtimeout) do _ - connect_tunnel(io, target_url, req) - end - else - connect_tunnel(io, target_url, req) - end - if r.status != 200 - close(io) - return r - end - if target_url.scheme in ("https", "wss") - io = Connections.sslupgrade(socket_type_tls, io, target_url.host; readtimeout=readtimeout, kw...) - end - req.headers = filter(x->x.first != "Proxy-Authorization", req.headers) - end - stream = Stream(req.response, io) return handler(stream; readtimeout=readtimeout, logerrors=logerrors, logtag=logtag, kw...) catch e @@ -153,20 +151,4 @@ end sockettype(url::URI, tcp, tls) = url.scheme in ("wss", "https") ? tls : tcp -function connect_tunnel(io, target_url, req) - target = "$(URIs.hoststring(target_url.host)):$(target_url.port)" - @debugv 1 "📡 CONNECT HTTPS tunnel to $target" - headers = Dict("Host" => target) - if (auth = header(req, "Proxy-Authorization"); !isempty(auth)) - headers["Proxy-Authorization"] = auth - end - request = Request("CONNECT", target, headers) - # @debugv 2 "connect_tunnel: writing headers" - writeheaders(io, request) - # @debugv 2 "connect_tunnel: reading headers" - readheaders(io, request.response) - # @debugv 2 "connect_tunnel: done reading headers" - return request.response -end - end # module ConnectionRequest diff --git a/test/pool.jl b/test/pool.jl new file mode 100644 index 000000000..ebf58a37e --- /dev/null +++ b/test/pool.jl @@ -0,0 +1,181 @@ +module TestPool + +using HTTP +import ..httpbin +using Sockets +using Test + +function pooledconnections(socket_type) + pool = HTTP.Connections.getpool(nothing, socket_type) + conns_per_key = values(pool.keyedvalues) + [c for conns in conns_per_key for c in conns if isopen(c)] +end + +@testset "$schema pool" for (schema, socket_type) in [ + ("http", Sockets.TCPSocket), + ("https", HTTP.SOCKET_TYPE_TLS[])] + HTTP.Connections.closeall() + @test length(pooledconnections(socket_type)) == 0 + try + function request_ip() + r = HTTP.get("$schema://$httpbin/ip"; retry=false, redirect = false, status_exception=true) + String(r.body) + end + + @testset "Sequential request use the same socket" begin + request_ip() + conns = pooledconnections(socket_type) + @test length(conns) == 1 + conn1io = conns[1].io + + request_ip() + conns = pooledconnections(socket_type) + @test length(conns) == 1 + @test conn1io === conns[1].io + end + + @testset "Parallell requests however use parallell connections" begin + n_asyncgetters = 3 + asyncgetters = [@async request_ip() for _ in 1:n_asyncgetters] + wait.(asyncgetters) + + conns = pooledconnections(socket_type) + @test length(conns) == n_asyncgetters + end + finally + HTTP.Connections.closeall() + end +end + +function readwrite(src, dst) + n = 0 + while isopen(dst) && !eof(src) + buff = readavailable(src) + if isopen(dst) + write(dst, buff) + end + n += length(buff) + end + n +end + +@testset "http pool with proxy" begin + downstreamconnections = Base.IdSet{HTTP.Connections.Connection}() + upstreamconnections = Base.IdSet{HTTP.Connections.Connection}() + downstreamcount = 0 + upstreamcount = 0 + + # Simple implementation of an http proxy server + proxy = HTTP.listen!(IPv4(0), 8082; stream = true) do http::HTTP.Stream + push!(downstreamconnections, http.stream) + downstreamcount += 1 + + HTTP.open(http.message.method, http.message.target, http.message.headers; + decompress = false, version = http.message.version, retry=false, + redirect = false) do targetstream + push!(upstreamconnections, targetstream.stream) + upstreamcount += 1 + + up = @async readwrite(http, targetstream) + targetresponse = startread(targetstream) + + HTTP.setstatus(http, targetresponse.status) + for h in targetresponse.headers + HTTP.setheader(http, h) + end + + HTTP.startwrite(http) + readwrite(targetstream, http) + + wait(up) + end + end + + try + function http_request_ip_through_proxy() + r = HTTP.get("http://$httpbin/ip"; proxy="http://localhost:8082", retry=false, redirect = false, status_exception=true) + String(r.body) + end + + # Make the HTTP request + http_request_ip_through_proxy() + @test length(downstreamconnections) == 1 + @test length(upstreamconnections) == 1 + @test downstreamcount == 1 + @test upstreamcount == 1 + + # Make another request + # This should reuse connections from the pool in both the client and the proxy + http_request_ip_through_proxy() + + # Check that additional requests were made, both downstream and upstream + @test downstreamcount == 2 + @test upstreamcount == 2 + # But the set of unique connections in either direction should remain of size 1 + @test length(downstreamconnections) == 1 + @test length(upstreamconnections) == 1 + finally + HTTP.Connections.closeall() + close(proxy) + wait(proxy) + end +end + +function readwriteclose(src, dst) + try + readwrite(src, dst) + finally + close(src) + close(dst) + end +end + +@testset "https pool with proxy" begin + connectcount = 0 + + # Simple implementation of a connect proxy server + proxy = HTTP.listen!(IPv4(0), 8082; stream = true) do http::HTTP.Stream + @assert http.message.method == "CONNECT" + connectcount += 1 + + hostport = split(http.message.target, ":") + targetstream = connect(hostport[1], parse(Int, get(hostport, 2, "443"))) + + HTTP.setstatus(http, 200) + HTTP.startwrite(http) + up = @async readwriteclose(http.stream.io, targetstream) + readwriteclose(targetstream, http.stream.io) + wait(up) + end + + try + function https_request_ip_through_proxy() + r = HTTP.get("https://$httpbin/ip"; proxy="http://localhost:8082", retry=false, status_exception=true) + String(r.body) + end + + @testset "Only one tunnel should be established with sequential requests" begin + https_request_ip_through_proxy() + https_request_ip_through_proxy() + @test connectcount == 1 + end + + @testset "parallell tunnels should be established with parallell requests" begin + n_asyncgetters = 3 + asyncgetters = [@async https_request_ip_through_proxy() for _ in 1:n_asyncgetters] + wait.(asyncgetters) + @test connectcount == n_asyncgetters + end + + finally + # Close pooled connections explicitly so the proxy handler can finish + # Connections.closeall never closes anything + close.(pooledconnections(HTTP.SOCKET_TYPE_TLS[])) + + HTTP.Connections.closeall() + close(proxy) + wait(proxy) + end +end + +end # module diff --git a/test/runtests.jl b/test/runtests.jl index a7e3ee3d5..cefcc1d1f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ include(joinpath(dir, "resources/TestRequest.jl")) "chunking.jl", "utils.jl", "client.jl", + "pool.jl", # "download.jl", "multipart.jl", "parsemultipart.jl",