Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor to support parameterized OpenSSL.SSLStream #1079

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 13 additions & 46 deletions src/Connections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,11 @@ function __init__()
# there was no artificial restriction on overall throughput
default_connection_limit[] = max(16, Threads.nthreads() * 4)
nosslcontext[] = OpenSSL.SSLContext(OpenSSL.TLSClientMethod())
TCP_POOL[] = CPool{Sockets.TCPSocket}(default_connection_limit[])
MBEDTLS_POOL[] = CPool{MbedTLS.SSLContext}(default_connection_limit[])
OPENSSL_POOL[] = CPool{OpenSSL.SSLStream}(default_connection_limit[])
return
end

function set_default_connection_limit!(n)
default_connection_limit[] = n
# reinitialize the global connection pools
TCP_POOL[] = CPool{Sockets.TCPSocket}(n)
MBEDTLS_POOL[] = CPool{MbedTLS.SSLContext}(n)
OPENSSL_POOL[] = CPool{OpenSSL.SSLStream}(n)
return
end

Expand Down Expand Up @@ -360,47 +353,27 @@ A pool can be passed to any of the `HTTP.request` methods via the `pool` keyword
"""
struct Pool
lock::ReentrantLock
tcp::CPool{Sockets.TCPSocket}
mbedtls::CPool{MbedTLS.SSLContext}
openssl::CPool{OpenSSL.SSLStream}
other::IdDict{Type, CPool}
pools::IdDict{Type, CPool}
max::Int
end

function Pool(max::Union{Int, Nothing}=nothing)
max = something(max, default_connection_limit[])
return Pool(ReentrantLock(),
CPool{Sockets.TCPSocket}(max),
CPool{MbedTLS.SSLContext}(max),
CPool{OpenSSL.SSLStream}(max),
IdDict{Type, CPool}(),
max,
)
end

# Default HTTP global connection pools
const TCP_POOL = Ref{CPool{Sockets.TCPSocket}}()
const MBEDTLS_POOL = Ref{CPool{MbedTLS.SSLContext}}()
const OPENSSL_POOL = Ref{CPool{OpenSSL.SSLStream}}()
const OTHER_POOL = Lockable(IdDict{Type, CPool}())
# Default HTTP global connection pool
const POOL = Lockable(IdDict{Type, CPool}())

getpool(::Nothing, ::Type{Sockets.TCPSocket}) = TCP_POOL[]
getpool(::Nothing, ::Type{MbedTLS.SSLContext}) = MBEDTLS_POOL[]
getpool(::Nothing, ::Type{OpenSSL.SSLStream}) = OPENSSL_POOL[]
getpool(::Nothing, ::Type{T}) where {T} = Base.@lock OTHER_POOL get!(OTHER_POOL[], T) do
getpool(::Nothing, ::Type{T}) where {T} = Base.@lock POOL get!(POOL[], T) do
CPool{T}(default_connection_limit[])
end

function getpool(pool::Pool, ::Type{T})::CPool{T} where {T}
if T === Sockets.TCPSocket
return pool.tcp
elseif T === MbedTLS.SSLContext
return pool.mbedtls
elseif T === OpenSSL.SSLStream
return pool.openssl
else
return Base.@lock pool.lock get!(() -> CPool{T}(pool.max), pool.other, T)
end
return Base.@lock pool.lock get!(() -> CPool{T}(pool.max), pool.pools, T)
end

"""
Expand All @@ -411,15 +384,9 @@ If `pool` is not specified, the default global pools are closed.
"""
function closeall(pool::Union{Nothing, Pool}=nothing)
if pool === nothing
drain!(TCP_POOL[])
drain!(MBEDTLS_POOL[])
drain!(OPENSSL_POOL[])
Base.@lock OTHER_POOL foreach(drain!, values(OTHER_POOL[]))
Base.@lock POOL foreach(drain!, values(POOL[]))
else
drain!(pool.tcp)
drain!(pool.mbedtls)
drain!(pool.openssl)
Base.@lock pool.lock foreach(drain!, values(pool.other))
Base.@lock pool.lock foreach(drain!, values(pool.pools))
end
return
end
Expand Down Expand Up @@ -570,20 +537,20 @@ function getconnection(::Type{SSLContext},
return sslconnection(SSLContext, tcp, host; kw...)
end

function getconnection(::Type{SSLStream},
function getconnection(::Type{SSLStream{T}},
host::AbstractString,
port::AbstractString;
kw...)::SSLStream
kw...)::SSLStream{T} where {T}
port = isempty(port) ? "443" : port
@debugv 2 "SSL connect: $host:$port..."
tcp = getconnection(TCPSocket, host, port; kw...)
return sslconnection(SSLStream, tcp, host; kw...)
tcp = getconnection(T, host, port; kw...)
return sslconnection(SSLStream{T}, tcp, host; kw...)
end

function sslconnection(::Type{SSLStream}, tcp::TCPSocket, host::AbstractString;
function sslconnection(::Type{SSLStream{T}}, tcp::T, host::AbstractString;
require_ssl_verification::Bool=NetworkOptions.verify_host(host, "SSL"),
sslconfig::OpenSSL.SSLContext=nosslcontext[],
kw...)::SSLStream
kw...)::SSLStream{T} where {T}
if sslconfig === nosslcontext[]
sslconfig = global_sslcontext()
end
Expand Down
6 changes: 3 additions & 3 deletions src/HTTP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ end

function open end

const SOCKET_TYPE_TLS = Ref{Any}(OpenSSL.SSLStream)
const SOCKET_TYPE_TLS = Ref{Any}(OpenSSL.SSLStream{TCPSocket})

include("Conditions.jl") ;using .Conditions
include("access_log.jl")
Expand Down Expand Up @@ -190,8 +190,8 @@ SSL arguments:
["... peer must present a valid certificate, handshake is aborted if
verification failed."](https://tls.mbed.org/api/ssl_8h.html#a5695285c9dbfefec295012b566290f37)
- `sslconfig = SSLConfig(require_ssl_verification)`
- `socket_type_tls = MbedTLS.SSLContext`, the type of socket to use for TLS connections. Defaults to `MbedTLS.SSLContext`.
Also supported is passing `socket_type_tls = OpenSSL.SSLStream`. To change the global default, set `HTTP.SOCKET_TYPE_TLS[] = OpenSSL.SSLStream`.
- `socket_type_tls = OpenSSL.SSLStream{TCPSocket}`, the type of socket to use for TLS connections. Defaults to `OpenSSL.SSLStream{TCPSocket}`.
Also supported is passing `socket_type_tls = MbedTLS.SSLContext`. To change the global default, set `HTTP.SOCKET_TYPE_TLS[] = MbedTLS.SSLContext`.

Cookie arguments:
- `cookies::Union{Bool, Dict{<:AbstractString, <:AbstractString}} = true`, enable cookies, or alternatively,
Expand Down
2 changes: 1 addition & 1 deletion src/Servers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export listen, listen!, Server, forceclose, port
using Sockets, Logging, LoggingExtras, MbedTLS, Dates
using MbedTLS: SSLContext, SSLConfig
using ..IOExtras, ..Streams, ..Messages, ..Parsers, ..Connections, ..Exceptions
import ..access_threaded, ..SOCKET_TYPE_TLS, ..@logfmt_str
import ..access_threaded, ..@logfmt_str

TRUE(x) = true
getinet(host::String, port::Integer) = Sockets.InetAddr(parse(IPAddr, host), port)
Expand Down
6 changes: 2 additions & 4 deletions test/client.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ using InteractiveUtils: @which
# test we can adjust default_connection_limit
for x in (10, 12)
HTTP.set_default_connection_limit!(x)
@test HTTP.Connections.TCP_POOL[].max == x
@test HTTP.Connections.MBEDTLS_POOL[].max == x
@test HTTP.Connections.OPENSSL_POOL[].max == x
@test HTTP.Connections.default_connection_limit[] == x
end

@testset "@client macro" begin
Expand All @@ -43,7 +41,7 @@ end
end
end

@testset "Client.jl" for tls in [MbedTLS.SSLContext, OpenSSL.SSLStream]
@testset "Client.jl" for tls in [MbedTLS.SSLContext, OpenSSL.SSLStream{TCPSocket}]
@testset "GET, HEAD, POST, PUT, DELETE, PATCH" begin
@test isok(HTTP.get("https://$httpbin/ip", socket_type_tls=tls))
@test isok(HTTP.head("https://$httpbin/ip", socket_type_tls=tls))
Expand Down