From fae97ced3a1f71fc46b6eadd3313e19705cc0006 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Fri, 2 Sep 2022 13:26:55 -0700 Subject: [PATCH] feat(rt): add Timer trait (#2974) This adds a `hyper::rt::Timer` trait, and it is used in connection builders to configure a custom timer source for timeouts. Co-authored-by: Robert Cunningham --- Cargo.toml | 1 - benches/support/mod.rs | 3 ++ benches/support/tokiort.rs | 66 +++++++++++++++++++++++++++++ src/client/conn/http1.rs | 14 +++--- src/client/conn/http2.rs | 30 +++++++++---- src/client/conn/mod.rs | 23 ++++++++-- src/common/mod.rs | 9 ++++ src/common/time.rs | 87 ++++++++++++++++++++++++++++++++++++++ src/proto/h1/conn.rs | 22 ++++++++-- src/proto/h1/dispatch.rs | 21 +++++---- src/proto/h1/io.rs | 20 +++------ src/proto/h1/mod.rs | 10 +++-- src/proto/h1/role.rs | 56 ++++++++++++++++++------ src/proto/h2/client.rs | 4 +- src/proto/h2/ping.rs | 62 +++++++++++++++------------ src/proto/h2/server.rs | 19 +++++++-- src/rt.rs | 23 ++++++++++ src/server/conn.rs | 68 +++++++++++++++++++++++++---- tests/client.rs | 26 ++++++++---- tests/server.rs | 16 +++++-- tests/support/mod.rs | 3 ++ tests/support/tokiort.rs | 1 + 22 files changed, 470 insertions(+), 114 deletions(-) create mode 100644 benches/support/mod.rs create mode 100644 benches/support/tokiort.rs create mode 100644 src/common/time.rs create mode 120000 tests/support/tokiort.rs diff --git a/Cargo.toml b/Cargo.toml index afbc8936f1..11d15ce3b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -96,7 +96,6 @@ server = [] runtime = [ "tokio/net", "tokio/rt", - "tokio/time", ] # C-API support (currently unstable (no semver)) diff --git a/benches/support/mod.rs b/benches/support/mod.rs new file mode 100644 index 0000000000..3718a8905f --- /dev/null +++ b/benches/support/mod.rs @@ -0,0 +1,3 @@ + +mod tokiort; +pub use tokiort::TokioTimer; \ No newline at end of file diff --git a/benches/support/tokiort.rs b/benches/support/tokiort.rs new file mode 100644 index 0000000000..9e4b924ee3 --- /dev/null +++ b/benches/support/tokiort.rs @@ -0,0 +1,66 @@ +#![allow(dead_code)] +//! Various runtimes for hyper +use std::{ + pin::Pin, + task::{Context, Poll}, + time::{Duration, Instant}, +}; + +use futures_util::Future; +use hyper::rt::{Sleep, Timer}; + +/// An Executor that uses the tokio runtime. +pub struct TokioExecutor; + +/// A Timer that uses the tokio runtime. + +#[derive(Clone, Debug)] +pub struct TokioTimer; + +impl Timer for TokioTimer { + fn sleep(&self, duration: Duration) -> Box { + let s = tokio::time::sleep(duration); + let hs = TokioSleep { inner: Box::pin(s) }; + return Box::new(hs); + } + + fn sleep_until(&self, deadline: Instant) -> Box { + return Box::new(TokioSleep { + inner: Box::pin(tokio::time::sleep_until(deadline.into())), + }); + } +} + +struct TokioTimeout { + inner: Pin>>, +} + +impl Future for TokioTimeout +where + T: Future, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll { + self.inner.as_mut().poll(context) + } +} + +// Use TokioSleep to get tokio::time::Sleep to implement Unpin. +// see https://docs.rs/tokio/latest/tokio/time/struct.Sleep.html +pub(crate) struct TokioSleep { + pub(crate) inner: Pin>, +} + +impl Future for TokioSleep { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.inner.as_mut().poll(cx) + } +} + +// Use HasSleep to get tokio::time::Sleep to implement Unpin. +// see https://docs.rs/tokio/latest/tokio/time/struct.Sleep.html + +impl Sleep for TokioSleep {} diff --git a/src/client/conn/http1.rs b/src/client/conn/http1.rs index e50567ce6b..7878031656 100644 --- a/src/client/conn/http1.rs +++ b/src/client/conn/http1.rs @@ -10,14 +10,14 @@ use tokio::io::{AsyncRead, AsyncWrite}; use crate::Recv; use crate::body::Body; +use super::super::dispatch; use crate::common::{ exec::{BoxSendFuture, Exec}, task, Future, Pin, Poll, }; -use crate::upgrade::Upgraded; use crate::proto; -use crate::rt::Executor; -use super::super::dispatch; +use crate::rt::{Executor}; +use crate::upgrade::Upgraded; type Dispatcher = proto::dispatch::Dispatcher, B, T, proto::h1::ClientTransaction>; @@ -120,7 +120,10 @@ where /// before calling this method. /// - Since absolute-form `Uri`s are not required, if received, they will /// be serialized as-is. - pub fn send_request(&mut self, req: Request) -> impl Future>> { + pub fn send_request( + &mut self, + req: Request, + ) -> impl Future>> { let sent = self.dispatch.send(req); async move { @@ -130,7 +133,7 @@ where Ok(Err(err)) => Err(err), // this is definite bug if it happens, but it shouldn't happen! Err(_canceled) => panic!("dispatch dropped without returning error"), - } + }, Err(_req) => { tracing::debug!("connection was not ready"); @@ -476,4 +479,3 @@ impl Builder { } } } - diff --git a/src/client/conn/http2.rs b/src/client/conn/http2.rs index 50bcb20afd..f6f9cb3099 100644 --- a/src/client/conn/http2.rs +++ b/src/client/conn/http2.rs @@ -12,13 +12,14 @@ use tokio::io::{AsyncRead, AsyncWrite}; use crate::Recv; use crate::body::Body; +use super::super::dispatch; +use crate::common::time::Time; use crate::common::{ exec::{BoxSendFuture, Exec}, task, Future, Pin, Poll, }; use crate::proto; -use crate::rt::Executor; -use super::super::dispatch; +use crate::rt::{Executor, Timer}; /// The sender side of an established connection. pub struct SendRequest { @@ -44,6 +45,7 @@ where #[derive(Clone, Debug)] pub struct Builder { pub(super) exec: Exec, + pub(super) timer: Time, h2_builder: proto::h2::client::Config, } @@ -114,7 +116,10 @@ where /// before calling this method. /// - Since absolute-form `Uri`s are not required, if received, they will /// be serialized as-is. - pub fn send_request(&mut self, req: Request) -> impl Future>> { + pub fn send_request( + &mut self, + req: Request, + ) -> impl Future>> { let sent = self.dispatch.send(req); async move { @@ -124,7 +129,7 @@ where Ok(Err(err)) => Err(err), // this is definite bug if it happens, but it shouldn't happen! Err(_canceled) => panic!("dispatch dropped without returning error"), - } + }, Err(_req) => { tracing::debug!("connection was not ready"); @@ -207,6 +212,7 @@ impl Builder { pub fn new() -> Builder { Builder { exec: Exec::Default, + timer: Time::Empty, h2_builder: Default::default(), } } @@ -220,6 +226,15 @@ impl Builder { self } + /// Provide a timer to execute background HTTP2 tasks. + pub fn timer(&mut self, timer: M) -> &mut Builder + where + M: Timer + Send + Sync + 'static, + { + self.timer = Time::Timer(Arc::new(timer)); + self + } + /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2 /// stream-level flow control. /// @@ -398,14 +413,13 @@ impl Builder { tracing::trace!("client handshake HTTP/1"); let (tx, rx) = dispatch::channel(); - let h2 = - proto::h2::client::handshake(io, rx, &opts.h2_builder, opts.exec) - .await?; + let h2 = proto::h2::client::handshake(io, rx, &opts.h2_builder, opts.exec, opts.timer) + .await?; Ok(( SendRequest { dispatch: tx.unbound() }, + //SendRequest { dispatch: tx }, Connection { inner: (PhantomData, h2) }, )) } } } - diff --git a/src/client/conn/mod.rs b/src/client/conn/mod.rs index 159a24369d..ae5ce15b71 100644 --- a/src/client/conn/mod.rs +++ b/src/client/conn/mod.rs @@ -85,6 +85,7 @@ use crate::rt::Executor; #[cfg(feature = "http1")] use crate::upgrade::Upgraded; use crate::{Recv, Request, Response}; +use crate::{common::time::Time, rt::Timer}; #[cfg(feature = "http1")] pub mod http1; @@ -161,6 +162,7 @@ where #[derive(Clone, Debug)] pub struct Builder { pub(super) exec: Exec, + pub(super) timer: Time, h09_responses: bool, h1_parser_config: ParserConfig, h1_writev: Option, @@ -418,6 +420,7 @@ impl Builder { pub fn new() -> Builder { Builder { exec: Exec::Default, + timer: Time::Empty, h09_responses: false, h1_writev: None, h1_read_buf_exact_size: None, @@ -447,6 +450,15 @@ impl Builder { self } + /// Provide a timer to execute background HTTP2 tasks. + pub fn timer(&mut self, timer: M) -> &mut Builder + where + M: Timer + Send + Sync + 'static, + { + self.timer = Time::Timer(Arc::new(timer)); + self + } + /// Set whether HTTP/0.9 responses should be tolerated. /// /// Default is false. @@ -857,9 +869,14 @@ impl Builder { } #[cfg(feature = "http2")] Proto::Http2 => { - let h2 = - proto::h2::client::handshake(io, rx, &opts.h2_builder, opts.exec.clone()) - .await?; + let h2 = proto::h2::client::handshake( + io, + rx, + &opts.h2_builder, + opts.exec.clone(), + opts.timer.clone(), + ) + .await?; ProtoClient::H2 { h2 } } }; diff --git a/src/common/mod.rs b/src/common/mod.rs index bc1781e832..0a3c65eeb0 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -15,6 +15,8 @@ pub(crate) mod exec; pub(crate) mod io; mod never; pub(crate) mod task; +#[cfg(any(feature = "http1", feature = "http2", feature = "server"))] +pub(crate) mod time; pub(crate) mod watch; #[cfg(any(feature = "http1", feature = "http2", feature = "runtime"))] @@ -26,3 +28,10 @@ cfg_proto! { pub(crate) use std::marker::Unpin; } pub(crate) use std::{future::Future, pin::Pin}; + +pub(crate) fn into_pin(boxed: Box) -> Pin> { + // It's not possible to move or replace the insides of a `Pin>` + // when `T: !Unpin`, so it's safe to pin it directly without any + // additional requirements. + unsafe { Pin::new_unchecked(boxed) } +} diff --git a/src/common/time.rs b/src/common/time.rs new file mode 100644 index 0000000000..a26cf6e3cd --- /dev/null +++ b/src/common/time.rs @@ -0,0 +1,87 @@ +use std::{fmt, sync::Arc}; +#[cfg(all(feature = "server", feature = "runtime"))] +use std::{ + pin::Pin, + time::{Duration, Instant}, +}; + +#[cfg(all(feature = "server", feature = "runtime"))] +use crate::rt::Sleep; +use crate::rt::Timer; + +/// A user-provided timer to time background tasks. +#[derive(Clone)] +pub(crate) enum Time { + Timer(Arc), + Empty, +} + +impl fmt::Debug for Time { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Time").finish() + } +} + +/* +pub(crate) fn timeout(tim: Tim, future: F, duration: Duration) -> HyperTimeout { + HyperTimeout { sleep: tim.sleep(duration), future: future } +} + +pin_project_lite::pin_project! { + pub(crate) struct HyperTimeout { + sleep: Box, + #[pin] + future: F + } +} + +pub(crate) struct Timeout; + +impl Future for HyperTimeout where F: Future { + + type Output = Result; + + fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll{ + let mut this = self.project(); + if let Poll::Ready(v) = this.future.poll(ctx) { + return Poll::Ready(Ok(v)); + } + + if let Poll::Ready(_) = Pin::new(&mut this.sleep).poll(ctx) { + return Poll::Ready(Err(Timeout)); + } + + return Poll::Pending; + } +} +*/ + +#[cfg(all(feature = "server", feature = "runtime"))] +impl Time { + pub(crate) fn sleep(&self, duration: Duration) -> Box { + match *self { + Time::Empty => { + panic!("You must supply a timer.") + } + Time::Timer(ref t) => t.sleep(duration), + } + } + + pub(crate) fn sleep_until(&self, deadline: Instant) -> Box { + match *self { + Time::Empty => { + panic!("You must supply a timer.") + } + Time::Timer(ref t) => t.sleep_until(deadline), + } + } + + pub(crate) fn reset(&self, sleep: &mut Pin>, new_deadline: Instant) { + match *self { + Time::Empty => { + panic!("You must supply a timer.") + } + Time::Timer(ref t) => t.reset(sleep, new_deadline), + } + } +} diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 2db8380c4c..b57e6a8918 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -9,16 +9,18 @@ use http::header::{HeaderValue, CONNECTION}; use http::{HeaderMap, Method, Version}; use httparse::ParserConfig; use tokio::io::{AsyncRead, AsyncWrite}; -#[cfg(all(feature = "server", feature = "runtime"))] -use tokio::time::Sleep; use tracing::{debug, error, trace}; use super::io::Buffered; use super::{Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext, Wants}; use crate::body::DecodedLength; +#[cfg(all(feature = "server", feature = "runtime"))] +use crate::common::time::Time; use crate::common::{task, Pin, Poll, Unpin}; use crate::headers::connection_keep_alive; use crate::proto::{BodyLength, MessageHead}; +#[cfg(all(feature = "server", feature = "runtime"))] +use crate::rt::Sleep; const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; @@ -57,6 +59,8 @@ where h1_header_read_timeout_fut: None, #[cfg(all(feature = "server", feature = "runtime"))] h1_header_read_timeout_running: false, + #[cfg(all(feature = "server", feature = "runtime"))] + timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -78,6 +82,11 @@ where } } + #[cfg(all(feature = "server", feature = "runtime"))] + pub(crate) fn set_timer(&mut self, timer: Time) { + self.state.timer = timer; + } + #[cfg(feature = "server")] pub(crate) fn set_flush_pipeline(&mut self, enabled: bool) { self.io.set_flush_pipeline(enabled); @@ -202,6 +211,8 @@ where h1_header_read_timeout_fut: &mut self.state.h1_header_read_timeout_fut, #[cfg(all(feature = "server", feature = "runtime"))] h1_header_read_timeout_running: &mut self.state.h1_header_read_timeout_running, + #[cfg(all(feature = "server", feature = "runtime"))] + timer: self.state.timer.clone(), preserve_header_case: self.state.preserve_header_case, #[cfg(feature = "ffi")] preserve_header_order: self.state.preserve_header_order, @@ -802,9 +813,11 @@ struct State { #[cfg(all(feature = "server", feature = "runtime"))] h1_header_read_timeout: Option, #[cfg(all(feature = "server", feature = "runtime"))] - h1_header_read_timeout_fut: Option>>, + h1_header_read_timeout_fut: Option>>, #[cfg(all(feature = "server", feature = "runtime"))] h1_header_read_timeout_running: bool, + #[cfg(all(feature = "server", feature = "runtime"))] + timer: Time, preserve_header_case: bool, #[cfg(feature = "ffi")] preserve_header_order: bool, @@ -1035,7 +1048,8 @@ mod tests { // an empty IO, we'll be skipping and using the read buffer anyways let io = tokio_test::io::Builder::new().build(); - let mut conn = Conn::<_, bytes::Bytes, crate::proto::h1::ServerTransaction>::new(io); + let mut conn = + Conn::<_, bytes::Bytes, crate::proto::h1::ServerTransaction>::new(io); *conn.io.read_buf_mut() = ::bytes::BytesMut::from(&s[..]); conn.state.cached_headers = Some(HeaderMap::with_capacity(2)); diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index 38061421cf..5c75f302fa 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -58,10 +58,10 @@ cfg_client! { impl Dispatcher where D: Dispatch< - PollItem = MessageHead, - PollBody = Bs, - RecvItem = MessageHead, - > + Unpin, + PollItem = MessageHead, + PollBody = Bs, + RecvItem = MessageHead, + > + Unpin, D::PollError: Into>, I: AsyncRead + AsyncWrite + Unpin, T: Http1Transaction + Unpin, @@ -256,7 +256,10 @@ where if wants.contains(Wants::UPGRADE) { let upgrade = self.conn.on_upgrade(); debug_assert!(!upgrade.is_none(), "empty upgrade"); - debug_assert!(head.extensions.get::().is_none(), "OnUpgrade already set"); + debug_assert!( + head.extensions.get::().is_none(), + "OnUpgrade already set" + ); head.extensions.insert(upgrade); } self.dispatch.recv_msg(Ok((head, body)))?; @@ -396,10 +399,10 @@ where impl Future for Dispatcher where D: Dispatch< - PollItem = MessageHead, - PollBody = Bs, - RecvItem = MessageHead, - > + Unpin, + PollItem = MessageHead, + PollBody = Bs, + RecvItem = MessageHead, + > + Unpin, D::PollError: Into>, I: AsyncRead + AsyncWrite + Unpin, T: Http1Transaction + Unpin, diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index 1d251e2c84..caf76d921f 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -5,13 +5,9 @@ use std::future::Future; use std::io::{self, IoSlice}; use std::marker::Unpin; use std::mem::MaybeUninit; -#[cfg(all(feature = "server", feature = "runtime"))] -use std::time::Duration; use bytes::{Buf, BufMut, Bytes, BytesMut}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -#[cfg(all(feature = "server", feature = "runtime"))] -use tokio::time::Instant; use tracing::{debug, trace}; use super::{Http1Transaction, ParseContext, ParsedMessage}; @@ -193,6 +189,8 @@ where h1_header_read_timeout_fut: parse_ctx.h1_header_read_timeout_fut, #[cfg(all(feature = "server", feature = "runtime"))] h1_header_read_timeout_running: parse_ctx.h1_header_read_timeout_running, + #[cfg(all(feature = "server", feature = "runtime"))] + timer: parse_ctx.timer.clone(), preserve_header_case: parse_ctx.preserve_header_case, #[cfg(feature = "ffi")] preserve_header_order: parse_ctx.preserve_header_order, @@ -209,15 +207,7 @@ where #[cfg(all(feature = "server", feature = "runtime"))] { *parse_ctx.h1_header_read_timeout_running = false; - - if let Some(h1_header_read_timeout_fut) = - parse_ctx.h1_header_read_timeout_fut - { - // Reset the timer in order to avoid woken up when the timeout finishes - h1_header_read_timeout_fut - .as_mut() - .reset(Instant::now() + Duration::from_secs(30 * 24 * 60 * 60)); - } + parse_ctx.h1_header_read_timeout_fut.take(); } return Poll::Ready(Ok(msg)); } @@ -674,6 +664,8 @@ enum WriteStrategy { #[cfg(test)] mod tests { + use crate::common::time::Time; + use super::*; use std::time::Duration; @@ -741,6 +733,8 @@ mod tests { h1_header_read_timeout_fut: &mut None, #[cfg(feature = "runtime")] h1_header_read_timeout_running: &mut false, + #[cfg(feature = "runtime")] + timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs index 5a2587a843..03b4ea28b3 100644 --- a/src/proto/h1/mod.rs +++ b/src/proto/h1/mod.rs @@ -4,11 +4,13 @@ use std::{pin::Pin, time::Duration}; use bytes::BytesMut; use http::{HeaderMap, Method}; use httparse::ParserConfig; -#[cfg(all(feature = "server", feature = "runtime"))] -use tokio::time::Sleep; use crate::body::DecodedLength; +#[cfg(all(feature = "server", feature = "runtime"))] +use crate::common::time::Time; use crate::proto::{BodyLength, MessageHead}; +#[cfg(all(feature = "server", feature = "runtime"))] +use crate::rt::Sleep; pub(crate) use self::conn::Conn; pub(crate) use self::decode::Decoder; @@ -79,9 +81,11 @@ pub(crate) struct ParseContext<'a> { #[cfg(all(feature = "server", feature = "runtime"))] h1_header_read_timeout: Option, #[cfg(all(feature = "server", feature = "runtime"))] - h1_header_read_timeout_fut: &'a mut Option>>, + h1_header_read_timeout_fut: &'a mut Option>>, #[cfg(all(feature = "server", feature = "runtime"))] h1_header_read_timeout_running: &'a mut bool, + #[cfg(all(feature = "server", feature = "runtime"))] + timer: Time, preserve_header_case: bool, #[cfg(feature = "ffi")] preserve_header_order: bool, diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index 408df4effb..ad95a0d012 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -1,5 +1,7 @@ use std::fmt::{self, Write}; use std::mem::MaybeUninit; +#[cfg(all(feature = "server", feature = "runtime"))] +use std::time::Instant; use bytes::Bytes; use bytes::BytesMut; @@ -7,8 +9,6 @@ use bytes::BytesMut; use http::header::ValueIter; use http::header::{self, Entry, HeaderName, HeaderValue}; use http::{HeaderMap, Method, StatusCode, Version}; -#[cfg(all(feature = "server", feature = "runtime"))] -use tokio::time::Instant; use tracing::{debug, error, trace, trace_span, warn}; use crate::body::DecodedLength; @@ -83,12 +83,12 @@ where match ctx.h1_header_read_timeout_fut { Some(h1_header_read_timeout_fut) => { debug!("resetting h1 header read timeout timer"); - h1_header_read_timeout_fut.as_mut().reset(deadline); + ctx.timer.reset(h1_header_read_timeout_fut, deadline); } None => { debug!("setting h1 header read timeout timer"); *ctx.h1_header_read_timeout_fut = - Some(Box::pin(tokio::time::sleep_until(deadline))); + Some(crate::common::into_pin(ctx.timer.sleep_until(deadline))); } } } @@ -994,7 +994,6 @@ impl Http1Transaction for Client { // SAFETY: array is valid up to `headers_len` let header = unsafe { &mut *header.as_mut_ptr() }; Client::obs_fold_line(&mut slice, header); - } } @@ -1566,6 +1565,9 @@ fn extend(dst: &mut Vec, data: &[u8]) { mod tests { use bytes::BytesMut; + #[cfg(feature = "runtime")] + use crate::common::time::Time; + use super::*; #[test] @@ -1585,6 +1587,8 @@ mod tests { h1_header_read_timeout_fut: &mut None, #[cfg(feature = "runtime")] h1_header_read_timeout_running: &mut false, + #[cfg(feature = "runtime")] + timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1620,6 +1624,8 @@ mod tests { h1_header_read_timeout_fut: &mut None, #[cfg(feature = "runtime")] h1_header_read_timeout_running: &mut false, + #[cfg(feature = "runtime")] + timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1650,6 +1656,8 @@ mod tests { h1_header_read_timeout_fut: &mut None, #[cfg(feature = "runtime")] h1_header_read_timeout_running: &mut false, + #[cfg(feature = "runtime")] + timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1678,6 +1686,8 @@ mod tests { h1_header_read_timeout_fut: &mut None, #[cfg(feature = "runtime")] h1_header_read_timeout_running: &mut false, + #[cfg(feature = "runtime")] + timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1708,6 +1718,8 @@ mod tests { h1_header_read_timeout_fut: &mut None, #[cfg(feature = "runtime")] h1_header_read_timeout_running: &mut false, + #[cfg(feature = "runtime")] + timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1742,6 +1754,8 @@ mod tests { h1_header_read_timeout_fut: &mut None, #[cfg(feature = "runtime")] h1_header_read_timeout_running: &mut false, + #[cfg(feature = "runtime")] + timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1773,6 +1787,8 @@ mod tests { h1_header_read_timeout_fut: &mut None, #[cfg(feature = "runtime")] h1_header_read_timeout_running: &mut false, + #[cfg(feature = "runtime")] + timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1799,6 +1815,8 @@ mod tests { h1_header_read_timeout_fut: &mut None, #[cfg(feature = "runtime")] h1_header_read_timeout_running: &mut false, + #[cfg(feature = "runtime")] + timer: Time::Empty, preserve_header_case: true, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1846,6 +1864,8 @@ mod tests { h1_header_read_timeout_fut: &mut None, #[cfg(feature = "runtime")] h1_header_read_timeout_running: &mut false, + #[cfg(feature = "runtime")] + timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -1874,6 +1894,8 @@ mod tests { h1_header_read_timeout_fut: &mut None, #[cfg(feature = "runtime")] h1_header_read_timeout_running: &mut false, + #[cfg(feature = "runtime")] + timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -2111,6 +2133,8 @@ mod tests { h1_header_read_timeout_fut: &mut None, #[cfg(feature = "runtime")] h1_header_read_timeout_running: &mut false, + #[cfg(feature = "runtime")] + timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -2139,6 +2163,8 @@ mod tests { h1_header_read_timeout_fut: &mut None, #[cfg(feature = "runtime")] h1_header_read_timeout_running: &mut false, + #[cfg(feature = "runtime")] + timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -2167,6 +2193,8 @@ mod tests { h1_header_read_timeout_fut: &mut None, #[cfg(feature = "runtime")] h1_header_read_timeout_running: &mut false, + #[cfg(feature = "runtime")] + timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -2450,18 +2478,12 @@ mod tests { value: (0, buf.len()), }; Client::obs_fold_line(&mut buf, &mut idx); - String::from_utf8(buf[idx.value.0 .. idx.value.1].to_vec()).unwrap() + String::from_utf8(buf[idx.value.0..idx.value.1].to_vec()).unwrap() } - assert_eq!( - unfold("a normal line"), - "a normal line", - ); + assert_eq!(unfold("a normal line"), "a normal line",); - assert_eq!( - unfold("obs\r\n fold\r\n\t line"), - "obs fold line", - ); + assert_eq!(unfold("obs\r\n fold\r\n\t line"), "obs fold line",); } #[test] @@ -2696,6 +2718,8 @@ mod tests { h1_header_read_timeout_fut: &mut None, #[cfg(feature = "runtime")] h1_header_read_timeout_running: &mut false, + #[cfg(feature = "runtime")] + timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -2788,6 +2812,8 @@ mod tests { h1_header_read_timeout_fut: &mut None, #[cfg(feature = "runtime")] h1_header_read_timeout_running: &mut false, + #[cfg(feature = "runtime")] + timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, @@ -2836,6 +2862,8 @@ mod tests { h1_header_read_timeout_fut: &mut None, #[cfg(feature = "runtime")] h1_header_read_timeout_running: &mut false, + #[cfg(feature = "runtime")] + timer: Time::Empty, preserve_header_case: false, #[cfg(feature = "ffi")] preserve_header_order: false, diff --git a/src/proto/h2/client.rs b/src/proto/h2/client.rs index e2252955c4..e2032af4cb 100644 --- a/src/proto/h2/client.rs +++ b/src/proto/h2/client.rs @@ -13,6 +13,7 @@ use tracing::{debug, trace, warn}; use super::{ping, H2Upgraded, PipeToSendStream, SendBuf}; use crate::body::Body; +use crate::common::time::Time; use crate::common::{exec::Exec, task, Future, Never, Pin, Poll}; use crate::ext::Protocol; use crate::headers; @@ -109,6 +110,7 @@ pub(crate) async fn handshake( req_rx: ClientRx, config: &Config, exec: Exec, + timer: Time, ) -> crate::Result> where T: AsyncRead + AsyncWrite + Send + Unpin + 'static, @@ -137,7 +139,7 @@ where let (conn, ping) = if ping_config.is_enabled() { let pp = conn.ping_pong().expect("conn.ping_pong"); - let (recorder, mut ponger) = ping::channel(pp, ping_config); + let (recorder, mut ponger) = ping::channel(pp, ping_config, timer); let conn = future::poll_fn(move |cx| { match ponger.poll(cx) { diff --git a/src/proto/h2/ping.rs b/src/proto/h2/ping.rs index 1e8386497c..22257f9c41 100644 --- a/src/proto/h2/ping.rs +++ b/src/proto/h2/ping.rs @@ -27,22 +27,24 @@ use std::future::Future; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::{self, Poll}; -use std::time::Duration; -#[cfg(not(feature = "runtime"))] -use std::time::Instant; +use std::time::{Duration, Instant}; + use h2::{Ping, PingPong}; -#[cfg(feature = "runtime")] -use tokio::time::{Instant, Sleep}; use tracing::{debug, trace}; +#[cfg_attr(not(feature = "runtime"), allow(unused))] +use crate::common::time::Time; +#[cfg_attr(not(feature = "runtime"), allow(unused))] +use crate::rt::Sleep; + type WindowSize = u32; pub(super) fn disabled() -> Recorder { Recorder { shared: None } } -pub(super) fn channel(ping_pong: PingPong, config: Config) -> (Recorder, Ponger) { +pub(super) fn channel(ping_pong: PingPong, config: Config, __timer: Time) -> (Recorder, Ponger) { debug_assert!( config.is_enabled(), "ping channel requires bdp or keep-alive config", @@ -67,8 +69,9 @@ pub(super) fn channel(ping_pong: PingPong, config: Config) -> (Recorder, Ponger) interval, timeout: config.keep_alive_timeout, while_idle: config.keep_alive_while_idle, - timer: Box::pin(tokio::time::sleep(interval)), + sleep: crate::common::into_pin(__timer.sleep(interval)), state: KeepAliveState::Init, + timer: __timer, }); #[cfg(feature = "runtime")] @@ -173,13 +176,14 @@ struct KeepAlive { while_idle: bool, state: KeepAliveState, - timer: Pin>, + sleep: Pin>, + timer: Time, } #[cfg(feature = "runtime")] enum KeepAliveState { Init, - Scheduled, + Scheduled(Instant), PingSent, } @@ -301,7 +305,7 @@ impl Ponger { #[cfg(feature = "runtime")] { if let Some(ref mut ka) = self.keep_alive { - ka.schedule(is_idle, &locked); + ka.maybe_schedule(is_idle, &locked); ka.maybe_ping(cx, &mut locked); } } @@ -324,11 +328,12 @@ impl Ponger { { if let Some(ref mut ka) = self.keep_alive { locked.update_last_read_at(); - ka.schedule(is_idle, &locked); + ka.maybe_schedule(is_idle, &locked); + ka.maybe_ping(cx, &mut locked); } } - if let Some(ref mut bdp) = self.bdp { + if let Some(ref mut bdp) = self.bdp { let bytes = locked.bytes.expect("bdp enabled implies bytes"); locked.bytes = Some(0); // reset trace!("received BDP ack; bytes = {}, rtt = {:?}", bytes, rtt); @@ -336,7 +341,7 @@ impl Ponger { let update = bdp.calculate(bytes, rtt); locked.next_bdp_at = Some(now + bdp.ping_delay); if let Some(update) = update { - return Poll::Ready(Ponged::SizeUpdate(update)) + return Poll::Ready(Ponged::SizeUpdate(update)); } } } @@ -471,38 +476,39 @@ fn seconds(dur: Duration) -> f64 { #[cfg(feature = "runtime")] impl KeepAlive { - fn schedule(&mut self, is_idle: bool, shared: &Shared) { + fn maybe_schedule(&mut self, is_idle: bool, shared: &Shared) { match self.state { KeepAliveState::Init => { if !self.while_idle && is_idle { return; } - self.state = KeepAliveState::Scheduled; - let interval = shared.last_read_at() + self.interval; - self.timer.as_mut().reset(interval); + self.schedule(shared); } KeepAliveState::PingSent => { if shared.is_ping_sent() { return; } - - self.state = KeepAliveState::Scheduled; - let interval = shared.last_read_at() + self.interval; - self.timer.as_mut().reset(interval); + self.schedule(shared); } - KeepAliveState::Scheduled => (), + KeepAliveState::Scheduled(..) => (), } } + fn schedule(&mut self, shared: &Shared) { + let interval = shared.last_read_at() + self.interval; + self.state = KeepAliveState::Scheduled(interval); + self.timer.reset(&mut self.sleep, interval); + } + fn maybe_ping(&mut self, cx: &mut task::Context<'_>, shared: &mut Shared) { match self.state { - KeepAliveState::Scheduled => { - if Pin::new(&mut self.timer).poll(cx).is_pending() { + KeepAliveState::Scheduled(at) => { + if Pin::new(&mut self.sleep).poll(cx).is_pending() { return; } // check if we've received a frame while we were scheduled - if shared.last_read_at() + self.interval > self.timer.deadline() { + if shared.last_read_at() + self.interval > at { self.state = KeepAliveState::Init; cx.waker().wake_by_ref(); // schedule us again return; @@ -511,7 +517,7 @@ impl KeepAlive { shared.send_ping(); self.state = KeepAliveState::PingSent; let timeout = Instant::now() + self.timeout; - self.timer.as_mut().reset(timeout); + self.timer.reset(&mut self.sleep, timeout); } KeepAliveState::Init | KeepAliveState::PingSent => (), } @@ -520,13 +526,13 @@ impl KeepAlive { fn maybe_timeout(&mut self, cx: &mut task::Context<'_>) -> Result<(), KeepAliveTimedOut> { match self.state { KeepAliveState::PingSent => { - if Pin::new(&mut self.timer).poll(cx).is_pending() { + if Pin::new(&mut self.sleep).poll(cx).is_pending() { return Ok(()); } trace!("keep-alive timeout ({:?}) reached", self.timeout); Err(KeepAliveTimedOut) } - KeepAliveState::Init | KeepAliveState::Scheduled => Ok(()), + KeepAliveState::Init | KeepAliveState::Scheduled(..) => Ok(()), } } } diff --git a/src/proto/h2/server.rs b/src/proto/h2/server.rs index 20867e62ba..f2c2e7d763 100644 --- a/src/proto/h2/server.rs +++ b/src/proto/h2/server.rs @@ -14,6 +14,7 @@ use tracing::{debug, trace, warn}; use super::{ping, PipeToSendStream, SendBuf}; use crate::body::Body; use crate::common::exec::ConnStreamExec; +use crate::common::time::Time; use crate::common::{date, task, Future, Pin, Poll}; use crate::ext::Protocol; use crate::headers; @@ -35,7 +36,7 @@ const DEFAULT_CONN_WINDOW: u32 = 1024 * 1024; // 1mb const DEFAULT_STREAM_WINDOW: u32 = 1024 * 1024; // 1mb const DEFAULT_MAX_FRAME_SIZE: u32 = 1024 * 16; // 16kb const DEFAULT_MAX_SEND_BUF_SIZE: usize = 1024 * 400; // 400kb -// 16 MB "sane default" taken from golang http2 + // 16 MB "sane default" taken from golang http2 const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: u32 = 16 << 20; #[derive(Clone, Debug)] @@ -80,6 +81,7 @@ pin_project! { B: Body, { exec: E, + timer: Time, service: S, state: State, } @@ -114,7 +116,13 @@ where B: Body + 'static, E: ConnStreamExec, { - pub(crate) fn new(io: T, service: S, config: &Config, exec: E) -> Server { + pub(crate) fn new( + io: T, + service: S, + config: &Config, + exec: E, + timer: Time, + ) -> Server { let mut builder = h2::server::Builder::default(); builder .initial_window_size(config.initial_stream_window_size) @@ -150,6 +158,7 @@ where Server { exec, + timer, state: State::Handshaking { ping_config, hs: handshake, @@ -199,7 +208,11 @@ where let mut conn = ready!(Pin::new(hs).poll(cx).map_err(crate::Error::new_h2))?; let ping = if ping_config.is_enabled() { let pp = conn.ping_pong().expect("conn.ping_pong"); - Some(ping::channel(pp, ping_config.clone())) + Some(ping::channel( + pp, + ping_config.clone(), + me.timer.clone(), + )) } else { None }; diff --git a/src/rt.rs b/src/rt.rs index 2614b59112..9998980670 100644 --- a/src/rt.rs +++ b/src/rt.rs @@ -5,8 +5,31 @@ //! If the `runtime` feature is disabled, the types in this module can be used //! to plug in other runtimes. +use std::{ + future::Future, + pin::Pin, + time::{Duration, Instant}, +}; + /// An executor of futures. pub trait Executor { /// Place the future into the executor to be run. fn execute(&self, fut: Fut); } + +/// A timer which provides timer-like functions. +pub trait Timer { + /// Return a future that resolves in `duration` time. + fn sleep(&self, duration: Duration) -> Box; + + /// Return a future that resolves at `deadline`. + fn sleep_until(&self, deadline: Instant) -> Box; + + /// Reset a future to resolve at `new_deadline` instead. + fn reset(&self, sleep: &mut Pin>, new_deadline: Instant) { + *sleep = crate::common::into_pin(self.sleep_until(new_deadline)); + } +} + +/// A future returned by a `Timer`. +pub trait Sleep: Send + Sync + Unpin + Future {} diff --git a/src/server/conn.rs b/src/server/conn.rs index 2ab4ce336c..f7d9a90784 100644 --- a/src/server/conn.rs +++ b/src/server/conn.rs @@ -46,6 +46,7 @@ not(all(feature = "http1", feature = "http2")) ))] use std::marker::PhantomData; +use std::sync::Arc; #[cfg(all(any(feature = "http1", feature = "http2"), feature = "runtime"))] use std::time::Duration; @@ -55,6 +56,7 @@ use crate::common::io::Rewind; use crate::error::{Kind, Parse}; #[cfg(feature = "http1")] use crate::upgrade::Upgraded; +use crate::{common::time::Time, rt::Timer}; cfg_feature! { #![any(feature = "http1", feature = "http2")] @@ -86,6 +88,7 @@ cfg_feature! { #[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))] pub struct Http { pub(crate) exec: E, + pub(crate) timer: Time, h1_half_close: bool, h1_keep_alive: bool, h1_title_case_headers: bool, @@ -169,7 +172,7 @@ pin_project! { #[cfg(all(feature = "http1", feature = "http2"))] #[derive(Clone, Debug)] enum Fallback { - ToHttp2(proto::h2::server::Config, E), + ToHttp2(proto::h2::server::Config, E, Time), Http1Only, } @@ -225,6 +228,7 @@ impl Http { pub fn new() -> Http { Http { exec: Exec::Default, + timer: Time::Empty, h1_half_close: false, h1_keep_alive: true, h1_title_case_headers: false, @@ -554,6 +558,30 @@ impl Http { pub fn with_executor(self, exec: E2) -> Http { Http { exec, + timer: self.timer, + h1_half_close: self.h1_half_close, + h1_keep_alive: self.h1_keep_alive, + h1_title_case_headers: self.h1_title_case_headers, + h1_preserve_header_case: self.h1_preserve_header_case, + #[cfg(all(feature = "http1", feature = "runtime"))] + h1_header_read_timeout: self.h1_header_read_timeout, + h1_writev: self.h1_writev, + #[cfg(feature = "http2")] + h2_builder: self.h2_builder, + mode: self.mode, + max_buf_size: self.max_buf_size, + pipeline_flush: self.pipeline_flush, + } + } + + /// Set the timer used in background tasks. + pub fn with_timer(self, timer: M) -> Http + where + M: Timer + Send + Sync + 'static, + { + Http { + exec: self.exec, + timer: Time::Timer(Arc::new(timer)), h1_half_close: self.h1_half_close, h1_keep_alive: self.h1_keep_alive, h1_title_case_headers: self.h1_title_case_headers, @@ -610,6 +638,10 @@ impl Http { macro_rules! h1 { () => {{ let mut conn = proto::Conn::new(io); + #[cfg(feature = "runtime")] + { + conn.set_timer(self.timer.clone()); + } if !self.h1_keep_alive { conn.disable_keep_alive(); } @@ -654,8 +686,13 @@ impl Http { #[cfg(feature = "http2")] ConnectionMode::H2Only => { let rewind_io = Rewind::new(io); - let h2 = - proto::h2::Server::new(rewind_io, service, &self.h2_builder, self.exec.clone()); + let h2 = proto::h2::Server::new( + rewind_io, + service, + &self.h2_builder, + self.exec.clone(), + self.timer.clone(), + ); ProtoServer::H2 { h2 } } }; @@ -664,7 +701,11 @@ impl Http { conn: Some(proto), #[cfg(all(feature = "http1", feature = "http2"))] fallback: if self.mode == ConnectionMode::Fallback { - Fallback::ToHttp2(self.h2_builder.clone(), self.exec.clone()) + Fallback::ToHttp2( + self.h2_builder.clone(), + self.exec.clone(), + self.timer.clone(), + ) } else { Fallback::Http1Only }, @@ -808,7 +849,12 @@ where let mut conn = Some(self); futures_util::future::poll_fn(move |cx| { ready!(conn.as_mut().unwrap().poll_without_shutdown(cx))?; - Poll::Ready(conn.take().unwrap().try_into_parts().ok_or_else(crate::Error::new_without_shutdown_not_h1)) + Poll::Ready( + conn.take() + .unwrap() + .try_into_parts() + .ok_or_else(crate::Error::new_without_shutdown_not_h1), + ) }) } @@ -825,11 +871,17 @@ where }; let mut rewind_io = Rewind::new(io); rewind_io.rewind(read_buf); - let (builder, exec) = match self.fallback { - Fallback::ToHttp2(ref builder, ref exec) => (builder, exec), + let (builder, exec, timer) = match self.fallback { + Fallback::ToHttp2(ref builder, ref exec, ref timer) => (builder, exec, timer), Fallback::Http1Only => unreachable!("upgrade_h2 with Fallback::Http1Only"), }; - let h2 = proto::h2::Server::new(rewind_io, dispatch.into_service(), builder, exec.clone()); + let h2 = proto::h2::Server::new( + rewind_io, + dispatch.into_service(), + builder, + exec.clone(), + timer.clone(), + ); debug_assert!(self.conn.is_none()); self.conn = Some(ProtoServer::H2 { h2 }); diff --git a/tests/client.rs b/tests/client.rs index 4016361b4d..70b5a3e38c 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -1337,6 +1337,7 @@ mod conn { use futures_channel::{mpsc, oneshot}; use futures_util::future::{self, poll_fn, FutureExt, TryFutureExt}; use http_body_util::{Empty, StreamBody}; + use hyper::rt::Timer; use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _, ReadBuf}; use tokio::net::{TcpListener as TkTcpListener, TcpStream}; @@ -1347,6 +1348,8 @@ mod conn { use super::{concat, s, support, tcp_connect, FutureHyperExt}; + use support::TokioTimer; + #[tokio::test] async fn get() { let _ = ::pretty_env_logger::try_init(); @@ -1491,7 +1494,7 @@ mod conn { }); let rx = rx1.expect("thread panicked"); - let rx = rx.then(|_| tokio::time::sleep(Duration::from_millis(200))); + let rx = rx.then(|_| TokioTimer.sleep(Duration::from_millis(200))); let chunk = rt.block_on(future::join(res, rx).map(|r| r.0)).unwrap(); assert_eq!(chunk.len(), 5); } @@ -1592,7 +1595,7 @@ mod conn { concat(res) }); let rx = rx1.expect("thread panicked"); - let rx = rx.then(|_| tokio::time::sleep(Duration::from_millis(200))); + let rx = rx.then(|_| TokioTimer.sleep(Duration::from_millis(200))); rt.block_on(future::join(res, rx).map(|r| r.0)).unwrap(); } @@ -1638,7 +1641,7 @@ mod conn { concat(res) }); let rx = rx1.expect("thread panicked"); - let rx = rx.then(|_| tokio::time::sleep(Duration::from_millis(200))); + let rx = rx.then(|_| TokioTimer.sleep(Duration::from_millis(200))); rt.block_on(future::join(res, rx).map(|r| r.0)).unwrap(); } @@ -1690,7 +1693,7 @@ mod conn { }); let rx = rx1.expect("thread panicked"); - let rx = rx.then(|_| tokio::time::sleep(Duration::from_millis(200))); + let rx = rx.then(|_| TokioTimer.sleep(Duration::from_millis(200))); rt.block_on(future::join3(res1, res2, rx).map(|r| r.0)) .unwrap(); } @@ -1751,7 +1754,7 @@ mod conn { }); let rx = rx1.expect("thread panicked"); - let rx = rx.then(|_| tokio::time::sleep(Duration::from_millis(200))); + let rx = rx.then(|_| TokioTimer.sleep(Duration::from_millis(200))); rt.block_on(future::join3(until_upgrade, res, rx).map(|r| r.0)) .unwrap(); @@ -1842,7 +1845,7 @@ mod conn { }); let rx = rx1.expect("thread panicked"); - let rx = rx.then(|_| tokio::time::sleep(Duration::from_millis(200))); + let rx = rx.then(|_| TokioTimer.sleep(Duration::from_millis(200))); rt.block_on(future::join3(until_tunneled, res, rx).map(|r| r.0)) .unwrap(); @@ -1950,7 +1953,7 @@ mod conn { let _ = shdn_tx.send(true); // Allow time for graceful shutdown roundtrips... - tokio::time::sleep(Duration::from_millis(100)).await; + TokioTimer.sleep(Duration::from_millis(100)).await; // After graceful shutdown roundtrips, the client should be closed... future::poll_fn(|ctx| client.poll_ready(ctx)) @@ -1982,6 +1985,7 @@ mod conn { let io = tcp_connect(&addr).await.expect("tcp connect"); let (_client, conn) = conn::Builder::new() + .timer(TokioTimer) .http2_only(true) .http2_keep_alive_interval(Duration::from_secs(1)) .http2_keep_alive_timeout(Duration::from_secs(1)) @@ -2015,6 +2019,7 @@ mod conn { let io = tcp_connect(&addr).await.expect("tcp connect"); let (mut client, conn) = conn::Builder::new() + .timer(TokioTimer) .http2_only(true) .http2_keep_alive_interval(Duration::from_secs(1)) .http2_keep_alive_timeout(Duration::from_secs(1)) @@ -2027,7 +2032,7 @@ mod conn { }); // sleep longer than keepalive would trigger - tokio::time::sleep(Duration::from_secs(4)).await; + TokioTimer.sleep(Duration::from_secs(4)).await; future::poll_fn(|ctx| client.poll_ready(ctx)) .await @@ -2051,6 +2056,7 @@ mod conn { let io = tcp_connect(&addr).await.expect("tcp connect"); let (mut client, conn) = conn::Builder::new() + .timer(TokioTimer) .http2_only(true) .http2_keep_alive_interval(Duration::from_secs(1)) .http2_keep_alive_timeout(Duration::from_secs(1)) @@ -2097,6 +2103,7 @@ mod conn { tokio::spawn(async move { let sock = listener.accept().await.unwrap().0; hyper::server::conn::Http::new() + .with_timer(TokioTimer) .http2_only(true) .serve_connection( sock, @@ -2115,6 +2122,7 @@ mod conn { let io = tcp_connect(&addr).await.expect("tcp connect"); let (mut client, conn) = conn::Builder::new() + .timer(TokioTimer) .http2_only(true) .http2_keep_alive_interval(Duration::from_secs(1)) .http2_keep_alive_timeout(Duration::from_secs(1)) @@ -2133,7 +2141,7 @@ mod conn { let _resp = client.send_request(req).await.expect("send_request"); // sleep longer than keepalive would trigger - tokio::time::sleep(Duration::from_secs(4)).await; + TokioTimer.sleep(Duration::from_secs(4)).await; future::poll_fn(|ctx| client.poll_ready(ctx)) .await diff --git a/tests/server.rs b/tests/server.rs index 04b82b67f9..c294a70f21 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -21,6 +21,8 @@ use h2::client::SendRequest; use h2::{RecvStream, SendStream}; use http::header::{HeaderName, HeaderValue}; use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody}; +use hyper::rt::Timer; +use support::TokioTimer; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener as TkTcpListener, TcpListener, TcpStream as TkTcpStream}; @@ -958,7 +960,7 @@ async fn expect_continue_waits_for_body_poll() { service_fn(|req| { assert_eq!(req.headers()["expect"], "100-continue"); // But! We're never going to poll the body! - tokio::time::sleep(Duration::from_millis(50)).map(move |_| { + TokioTimer.sleep(Duration::from_millis(50)).map(move |_| { // Move and drop the req, so we don't auto-close drop(req); Response::builder() @@ -1255,7 +1257,8 @@ async fn http1_allow_half_close() { .serve_connection( socket, service_fn(|_| { - tokio::time::sleep(Duration::from_millis(500)) + TokioTimer + .sleep(Duration::from_millis(500)) .map(|_| Ok::<_, hyper::Error>(Response::new(Empty::::new()))) }), ) @@ -1283,7 +1286,7 @@ async fn disconnect_after_reading_request_before_responding() { .serve_connection( socket, service_fn(|_| { - tokio::time::sleep(Duration::from_secs(2)).map( + TokioTimer.sleep(Duration::from_secs(2)).map( |_| -> Result, hyper::Error> { panic!("response future should have been dropped"); }, @@ -1376,6 +1379,7 @@ async fn header_read_timeout_slow_writes() { let (socket, _) = listener.accept().await.unwrap(); let conn = Http::new() + .with_timer(TokioTimer) .http1_header_read_timeout(Duration::from_secs(5)) .serve_connection( socket, @@ -1451,6 +1455,7 @@ async fn header_read_timeout_slow_writes_multiple_requests() { let (socket, _) = listener.accept().await.unwrap(); let conn = Http::new() + .with_timer(TokioTimer) .http1_header_read_timeout(Duration::from_secs(5)) .serve_connection( socket, @@ -2486,6 +2491,7 @@ async fn http2_keep_alive_detects_unresponsive_client() { let (socket, _) = listener.accept().await.expect("accept"); let err = Http::new() + .with_timer(TokioTimer) .http2_only(true) .http2_keep_alive_interval(Duration::from_secs(1)) .http2_keep_alive_timeout(Duration::from_secs(1)) @@ -2507,6 +2513,7 @@ async fn http2_keep_alive_with_responsive_client() { let (socket, _) = listener.accept().await.expect("accept"); Http::new() + .with_timer(TokioTimer) .http2_only(true) .http2_keep_alive_interval(Duration::from_secs(1)) .http2_keep_alive_timeout(Duration::from_secs(1)) @@ -2526,7 +2533,7 @@ async fn http2_keep_alive_with_responsive_client() { conn.await.expect("client conn"); }); - tokio::time::sleep(Duration::from_secs(4)).await; + TokioTimer.sleep(Duration::from_secs(4)).await; let req = http::Request::new(Empty::::new()); client.send_request(req).await.expect("client.send_request"); @@ -2574,6 +2581,7 @@ async fn http2_keep_alive_count_server_pings() { let (socket, _) = listener.accept().await.expect("accept"); Http::new() + .with_timer(TokioTimer) .http2_only(true) .http2_keep_alive_interval(Duration::from_secs(1)) .http2_keep_alive_timeout(Duration::from_secs(1)) diff --git a/tests/support/mod.rs b/tests/support/mod.rs index 95bd576c73..f5ae663c13 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -21,6 +21,9 @@ pub use futures_util::{ pub use hyper::{HeaderMap, StatusCode}; pub use std::net::SocketAddr; +mod tokiort; +pub use tokiort::TokioTimer; + #[allow(unused_macros)] macro_rules! t { ( diff --git a/tests/support/tokiort.rs b/tests/support/tokiort.rs new file mode 120000 index 0000000000..d410b9522c --- /dev/null +++ b/tests/support/tokiort.rs @@ -0,0 +1 @@ +../../benches/support/tokiort.rs \ No newline at end of file