Skip to content

Commit

Permalink
feat(client): Add connect timeout to HttpConnector (#1972)
Browse files Browse the repository at this point in the history
This takes the same strategy as golang, where the timeout value is
divided equally between the candidate socket addresses.

If happy eyeballs is enabled, the division takes place "below" the
IPv4/IPv6 partitioning.
  • Loading branch information
sfackler authored and seanmonstar committed Oct 14, 2019
1 parent 536b1e1 commit 4179297
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 16 deletions.
4 changes: 4 additions & 0 deletions src/client/connect/dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,10 @@ impl IpAddrs {
pub(super) fn is_empty(&self) -> bool {
self.iter.as_slice().is_empty()
}

pub(super) fn len(&self) -> usize {
self.iter.as_slice().len()
}
}

impl Iterator for IpAddrs {
Expand Down
65 changes: 49 additions & 16 deletions src/client/connect/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use futures_util::{TryFutureExt, FutureExt};
use net2::TcpBuilder;
use tokio_net::driver::Handle;
use tokio_net::tcp::TcpStream;
use tokio_timer::Delay;
use tokio_timer::{Delay, Timeout};

use crate::common::{Future, Pin, Poll, task};
use super::{Connect, Connected, Destination};
Expand All @@ -32,6 +32,7 @@ type ConnectFuture = Pin<Box<dyn Future<Output = io::Result<TcpStream>> + Send>>
pub struct HttpConnector<R = GaiResolver> {
enforce_http: bool,
handle: Option<Handle>,
connect_timeout: Option<Duration>,
happy_eyeballs_timeout: Option<Duration>,
keep_alive_timeout: Option<Duration>,
local_address: Option<IpAddr>,
Expand Down Expand Up @@ -101,6 +102,7 @@ impl<R> HttpConnector<R> {
HttpConnector {
enforce_http: true,
handle: None,
connect_timeout: None,
happy_eyeballs_timeout: Some(Duration::from_millis(300)),
keep_alive_timeout: None,
local_address: None,
Expand Down Expand Up @@ -168,6 +170,17 @@ impl<R> HttpConnector<R> {
self.local_address = addr;
}

/// Set the connect timeout.
///
/// If a domain resolves to multiple IP addresses, the timeout will be
/// evenly divided across them.
///
/// Default is `None`.
#[inline]
pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
self.connect_timeout = dur;
}

/// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm.
///
/// If hostname resolves to both IPv4 and IPv6 addresses and connection
Expand Down Expand Up @@ -240,6 +253,7 @@ where
HttpConnecting {
state: State::Lazy(self.resolver.clone(), host.into(), self.local_address),
handle: self.handle.clone(),
connect_timeout: self.connect_timeout,
happy_eyeballs_timeout: self.happy_eyeballs_timeout,
keep_alive_timeout: self.keep_alive_timeout,
nodelay: self.nodelay,
Expand Down Expand Up @@ -295,6 +309,7 @@ where
let fut = HttpConnecting {
state: State::Lazy(self.resolver.clone(), host.into(), self.local_address),
handle: self.handle.clone(),
connect_timeout: self.connect_timeout,
happy_eyeballs_timeout: self.happy_eyeballs_timeout,
keep_alive_timeout: self.keep_alive_timeout,
nodelay: self.nodelay,
Expand Down Expand Up @@ -323,6 +338,7 @@ fn invalid_url<R: Resolve>(err: InvalidUrl, handle: &Option<Handle>) -> HttpConn
keep_alive_timeout: None,
nodelay: false,
port: 0,
connect_timeout: None,
happy_eyeballs_timeout: None,
reuse_address: false,
send_buffer_size: None,
Expand Down Expand Up @@ -357,6 +373,7 @@ impl StdError for InvalidUrl {
pub struct HttpConnecting<R: Resolve = GaiResolver> {
state: State<R>,
handle: Option<Handle>,
connect_timeout: Option<Duration>,
happy_eyeballs_timeout: Option<Duration>,
keep_alive_timeout: Option<Duration>,
nodelay: bool,
Expand Down Expand Up @@ -389,7 +406,7 @@ where
// skip resolving the dns and start connecting right away.
if let Some(addrs) = dns::IpAddrs::try_parse(host, me.port) {
state = State::Connecting(ConnectingTcp::new(
local_addr, addrs, me.happy_eyeballs_timeout, me.reuse_address));
local_addr, addrs, me.connect_timeout, me.happy_eyeballs_timeout, me.reuse_address));
} else {
let name = dns::Name::new(mem::replace(host, String::new()));
state = State::Resolving(resolver.resolve(name), local_addr);
Expand All @@ -403,7 +420,7 @@ where
.collect();
let addrs = dns::IpAddrs::new(addrs);
state = State::Connecting(ConnectingTcp::new(
local_addr, addrs, me.happy_eyeballs_timeout, me.reuse_address));
local_addr, addrs, me.connect_timeout, me.happy_eyeballs_timeout, me.reuse_address));
},
State::Connecting(ref mut c) => {
let sock = ready!(c.poll(cx, &me.handle))?;
Expand Down Expand Up @@ -454,6 +471,7 @@ impl ConnectingTcp {
fn new(
local_addr: Option<IpAddr>,
remote_addrs: dns::IpAddrs,
connect_timeout: Option<Duration>,
fallback_timeout: Option<Duration>,
reuse_address: bool,
) -> ConnectingTcp {
Expand All @@ -462,25 +480,25 @@ impl ConnectingTcp {
if fallback_addrs.is_empty() {
return ConnectingTcp {
local_addr,
preferred: ConnectingTcpRemote::new(preferred_addrs),
preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout),
fallback: None,
reuse_address,
};
}

ConnectingTcp {
local_addr,
preferred: ConnectingTcpRemote::new(preferred_addrs),
preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout),
fallback: Some(ConnectingTcpFallback {
delay: tokio_timer::delay_for(fallback_timeout),
remote: ConnectingTcpRemote::new(fallback_addrs),
remote: ConnectingTcpRemote::new(fallback_addrs, connect_timeout),
}),
reuse_address,
}
} else {
ConnectingTcp {
local_addr,
preferred: ConnectingTcpRemote::new(remote_addrs),
preferred: ConnectingTcpRemote::new(remote_addrs, connect_timeout),
fallback: None,
reuse_address,
}
Expand All @@ -495,13 +513,17 @@ struct ConnectingTcpFallback {

struct ConnectingTcpRemote {
addrs: dns::IpAddrs,
connect_timeout: Option<Duration>,
current: Option<ConnectFuture>,
}

impl ConnectingTcpRemote {
fn new(addrs: dns::IpAddrs) -> Self {
fn new(addrs: dns::IpAddrs, connect_timeout: Option<Duration>) -> Self {
let connect_timeout = connect_timeout.map(|t| t / (addrs.len() as u32));

Self {
addrs,
connect_timeout,
current: None,
}
}
Expand Down Expand Up @@ -530,14 +552,14 @@ impl ConnectingTcpRemote {
err = Some(e);
if let Some(addr) = self.addrs.next() {
debug!("connecting to {}", addr);
*current = connect(&addr, local_addr, handle, reuse_address)?;
*current = connect(&addr, local_addr, handle, reuse_address, self.connect_timeout)?;
continue;
}
}
}
} else if let Some(addr) = self.addrs.next() {
debug!("connecting to {}", addr);
self.current = Some(connect(&addr, local_addr, handle, reuse_address)?);
self.current = Some(connect(&addr, local_addr, handle, reuse_address, self.connect_timeout)?);
continue;
}

Expand All @@ -546,7 +568,13 @@ impl ConnectingTcpRemote {
}
}

fn connect(addr: &SocketAddr, local_addr: &Option<IpAddr>, handle: &Option<Handle>, reuse_address: bool) -> io::Result<ConnectFuture> {
fn connect(
addr: &SocketAddr,
local_addr: &Option<IpAddr>,
handle: &Option<Handle>,
reuse_address: bool,
connect_timeout: Option<Duration>,
) -> io::Result<ConnectFuture> {
let builder = match addr {
&SocketAddr::V4(_) => TcpBuilder::new_v4()?,
&SocketAddr::V6(_) => TcpBuilder::new_v6()?,
Expand Down Expand Up @@ -581,10 +609,16 @@ fn connect(addr: &SocketAddr, local_addr: &Option<IpAddr>, handle: &Option<Handl
let std_tcp = builder.to_tcp_stream()?;

Ok(Box::pin(async move {
TcpStream::connect_std(std_tcp, &addr, &handle).await
let connect = TcpStream::connect_std(std_tcp, &addr, &handle);
match connect_timeout {
Some(timeout) => match Timeout::new(connect, timeout).await {
Ok(Ok(s)) => Ok(s),
Ok(Err(e)) => Err(e),
Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
}
None => connect.await,
}
}))

//Ok(Box::pin(TcpStream::connect_std(std_tcp, addr, &handle)))
}

impl ConnectingTcp {
Expand Down Expand Up @@ -673,7 +707,6 @@ mod tests {
})
}


#[test]
fn test_errors_missing_scheme() {
let mut rt = Runtime::new().unwrap();
Expand Down Expand Up @@ -765,7 +798,7 @@ mod tests {
}

let addrs = hosts.iter().map(|host| (host.clone(), addr.port()).into()).collect();
let connecting_tcp = ConnectingTcp::new(None, dns::IpAddrs::new(addrs), Some(fallback_timeout), false);
let connecting_tcp = ConnectingTcp::new(None, dns::IpAddrs::new(addrs), None, Some(fallback_timeout), false);
let fut = ConnectingTcpFuture(connecting_tcp);

let start = Instant::now();
Expand Down

0 comments on commit 4179297

Please sign in to comment.