diff --git a/src/async_impl/body.rs b/src/async_impl/body.rs index cd9658c64..a70a853b1 100644 --- a/src/async_impl/body.rs +++ b/src/async_impl/body.rs @@ -2,11 +2,13 @@ use std::fmt; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; +use std::time::Duration; use bytes::Bytes; use http_body::Body as HttpBody; use http_body_util::combinators::BoxBody; //use sync_wrapper::SyncWrapper; +use pin_project_lite::pin_project; #[cfg(feature = "stream")] use tokio::fs::File; use tokio::time::Sleep; @@ -23,13 +25,26 @@ enum Inner { Streaming(BoxBody>), } -/// A body with a total timeout. -/// -/// The timeout does not reset upon each chunk, but rather requires the whole -/// body be streamed before the deadline is reached. -pub(crate) struct TotalTimeoutBody { - inner: B, - timeout: Pin>, +pin_project! { + /// A body with a total timeout. + /// + /// The timeout does not reset upon each chunk, but rather requires the whole + /// body be streamed before the deadline is reached. + pub(crate) struct TotalTimeoutBody { + #[pin] + inner: B, + timeout: Pin>, + } +} + +pin_project! { + pub(crate) struct ReadTimeoutBody { + #[pin] + inner: B, + #[pin] + sleep: Option, + timeout: Duration, + } } /// Converts any `impl Body` into a `impl Stream` of just its DATA frames. @@ -289,23 +304,32 @@ pub(crate) fn total_timeout(body: B, timeout: Pin>) -> TotalTimeou } } +pub(crate) fn with_read_timeout(body: B, timeout: Duration) -> ReadTimeoutBody { + ReadTimeoutBody { + inner: body, + sleep: None, + timeout, + } +} + impl hyper::body::Body for TotalTimeoutBody where - B: hyper::body::Body + Unpin, + B: hyper::body::Body, B::Error: Into>, { type Data = B::Data; type Error = crate::Error; fn poll_frame( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context, ) -> Poll, Self::Error>>> { - if let Poll::Ready(()) = self.timeout.as_mut().poll(cx) { + let this = self.project(); + if let Poll::Ready(()) = this.timeout.as_mut().poll(cx) { return Poll::Ready(Some(Err(crate::error::body(crate::error::TimedOut)))); } Poll::Ready( - futures_core::ready!(Pin::new(&mut self.inner).poll_frame(cx)) + futures_core::ready!(this.inner.poll_frame(cx)) .map(|opt_chunk| opt_chunk.map_err(crate::error::body)), ) } @@ -321,22 +345,79 @@ where } } +impl hyper::body::Body for ReadTimeoutBody +where + B: hyper::body::Body, + B::Error: Into>, +{ + type Data = B::Data; + type Error = crate::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll, Self::Error>>> { + let mut this = self.project(); + + // Start the `Sleep` if not active. + let sleep_pinned = if let Some(some) = this.sleep.as_mut().as_pin_mut() { + some + } else { + this.sleep.set(Some(tokio::time::sleep(*this.timeout))); + this.sleep.as_mut().as_pin_mut().unwrap() + }; + + // Error if the timeout has expired. + if let Poll::Ready(()) = sleep_pinned.poll(cx) { + return Poll::Ready(Some(Err(crate::error::body(crate::error::TimedOut)))); + } + + let item = futures_core::ready!(this.inner.poll_frame(cx)) + .map(|opt_chunk| opt_chunk.map_err(crate::error::body)); + // a ready frame means timeout is reset + this.sleep.set(None); + Poll::Ready(item) + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + self.inner.size_hint() + } + + #[inline] + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } +} + pub(crate) type ResponseBody = http_body_util::combinators::BoxBody>; pub(crate) fn response( body: hyper::body::Incoming, - timeout: Option>>, + deadline: Option>>, + read_timeout: Option, ) -> ResponseBody { use http_body_util::BodyExt; - if let Some(timeout) = timeout { - total_timeout(body, timeout).map_err(Into::into).boxed() - } else { - body.map_err(Into::into).boxed() + match (deadline, read_timeout) { + (Some(total), Some(read)) => { + let body = with_read_timeout(body, read).map_err(box_err); + total_timeout(body, total).map_err(box_err).boxed() + } + (Some(total), None) => total_timeout(body, total).map_err(box_err).boxed(), + (None, Some(read)) => with_read_timeout(body, read).map_err(box_err).boxed(), + (None, None) => body.map_err(box_err).boxed(), } } +fn box_err(err: E) -> Box +where + E: Into>, +{ + err.into() +} + // ===== impl DataStream ===== impl futures_core::Stream for DataStream diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index e72268860..dc509c42a 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -108,6 +108,7 @@ struct Config { auto_sys_proxy: bool, redirect_policy: redirect::Policy, referer: bool, + read_timeout: Option, timeout: Option, #[cfg(feature = "__tls")] root_certs: Vec, @@ -204,6 +205,7 @@ impl ClientBuilder { auto_sys_proxy: true, redirect_policy: redirect::Policy::default(), referer: true, + read_timeout: None, timeout: None, #[cfg(feature = "__tls")] root_certs: Vec::new(), @@ -739,6 +741,7 @@ impl ClientBuilder { headers: config.headers, redirect_policy: config.redirect_policy, referer: config.referer, + read_timeout: config.read_timeout, request_timeout: config.timeout, proxies, proxies_maybe_http_auth, @@ -1028,10 +1031,10 @@ impl ClientBuilder { // Timeout options - /// Enables a request timeout. + /// Enables a total request timeout. /// /// The timeout is applied from when the request starts connecting until the - /// response body has finished. + /// response body has finished. Also considered a total deadline. /// /// Default is no timeout. pub fn timeout(mut self, timeout: Duration) -> ClientBuilder { @@ -1039,6 +1042,18 @@ impl ClientBuilder { self } + /// Enables a read timeout. + /// + /// The timeout applies to each read operation, and resets after a + /// successful read. This is more appropriate for detecting stalled + /// connections when the size isn't known beforehand. + /// + /// Default is no timeout. + pub fn read_timeout(mut self, timeout: Duration) -> ClientBuilder { + self.config.read_timeout = Some(timeout); + self + } + /// Set a timeout for only the connect phase of a `Client`. /// /// Default is `None`. @@ -1985,11 +2000,17 @@ impl Client { } }; - let timeout = timeout + let total_timeout = timeout .or(self.inner.request_timeout) .map(tokio::time::sleep) .map(Box::pin); + let read_timeout_fut = self + .inner + .read_timeout + .map(tokio::time::sleep) + .map(Box::pin); + Pending { inner: PendingInner::Request(PendingRequest { method, @@ -2004,7 +2025,9 @@ impl Client { client: self.inner.clone(), in_flight, - timeout, + total_timeout, + read_timeout_fut, + read_timeout: self.inner.read_timeout, }), } } @@ -2210,6 +2233,7 @@ struct ClientRef { redirect_policy: redirect::Policy, referer: bool, request_timeout: Option, + read_timeout: Option, proxies: Arc>, proxies_maybe_http_auth: bool, https_only: bool, @@ -2246,6 +2270,10 @@ impl ClientRef { if let Some(ref d) = self.request_timeout { f.field("timeout", d); } + + if let Some(ref d) = self.read_timeout { + f.field("read_timeout", d); + } } } @@ -2277,7 +2305,10 @@ pin_project! { #[pin] in_flight: ResponseFuture, #[pin] - timeout: Option>>, + total_timeout: Option>>, + #[pin] + read_timeout_fut: Option>>, + read_timeout: Option, } } @@ -2292,8 +2323,12 @@ impl PendingRequest { self.project().in_flight } - fn timeout(self: Pin<&mut Self>) -> Pin<&mut Option>>> { - self.project().timeout + fn total_timeout(self: Pin<&mut Self>) -> Pin<&mut Option>>> { + self.project().total_timeout + } + + fn read_timeout(self: Pin<&mut Self>) -> Pin<&mut Option>>> { + self.project().read_timeout_fut } fn urls(self: Pin<&mut Self>) -> &mut Vec { @@ -2430,7 +2465,15 @@ impl Future for PendingRequest { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if let Some(delay) = self.as_mut().timeout().as_mut().as_pin_mut() { + if let Some(delay) = self.as_mut().total_timeout().as_mut().as_pin_mut() { + if let Poll::Ready(()) = delay.poll(cx) { + return Poll::Ready(Err( + crate::error::request(crate::error::TimedOut).with_url(self.url.clone()) + )); + } + } + + if let Some(delay) = self.as_mut().read_timeout().as_mut().as_pin_mut() { if let Poll::Ready(()) = delay.poll(cx) { return Poll::Ready(Err( crate::error::request(crate::error::TimedOut).with_url(self.url.clone()) @@ -2622,7 +2665,8 @@ impl Future for PendingRequest { res, self.url.clone(), self.client.accepts, - self.timeout.take(), + self.total_timeout.take(), + self.read_timeout, ); return Poll::Ready(Ok(res)); } diff --git a/src/async_impl/response.rs b/src/async_impl/response.rs index d2ddfc3a1..fcb25b115 100644 --- a/src/async_impl/response.rs +++ b/src/async_impl/response.rs @@ -1,6 +1,7 @@ use std::fmt; use std::net::SocketAddr; use std::pin::Pin; +use std::time::Duration; use bytes::Bytes; use http_body_util::BodyExt; @@ -37,12 +38,13 @@ impl Response { res: hyper::Response, url: Url, accepts: Accepts, - timeout: Option>>, + total_timeout: Option>>, + read_timeout: Option, ) -> Response { let (mut parts, body) = res.into_parts(); let decoder = Decoder::detect( &mut parts.headers, - super::body::response(body, timeout), + super::body::response(body, total_timeout, read_timeout), accepts, ); let res = hyper::Response::from_parts(parts, decoder); diff --git a/tests/timeouts.rs b/tests/timeouts.rs index 5ba687337..c18fecdbe 100644 --- a/tests/timeouts.rs +++ b/tests/timeouts.rs @@ -11,13 +11,13 @@ async fn client_timeout() { let server = server::http(move |_req| { async { // delay returning the response - tokio::time::sleep(Duration::from_secs(2)).await; + tokio::time::sleep(Duration::from_millis(300)).await; http::Response::default() } }); let client = reqwest::Client::builder() - .timeout(Duration::from_millis(500)) + .timeout(Duration::from_millis(100)) .build() .unwrap(); @@ -38,7 +38,7 @@ async fn request_timeout() { let server = server::http(move |_req| { async { // delay returning the response - tokio::time::sleep(Duration::from_secs(2)).await; + tokio::time::sleep(Duration::from_millis(300)).await; http::Response::default() } }); @@ -49,7 +49,7 @@ async fn request_timeout() { let res = client .get(&url) - .timeout(Duration::from_millis(500)) + .timeout(Duration::from_millis(100)) .send() .await; @@ -152,7 +152,7 @@ async fn response_timeout() { async { // immediate response, but delayed body let body = reqwest::Body::wrap_stream(futures_util::stream::once(async { - tokio::time::sleep(Duration::from_secs(2)).await; + tokio::time::sleep(Duration::from_secs(1)).await; Ok::<_, std::convert::Infallible>("Hello") })); @@ -175,6 +175,105 @@ async fn response_timeout() { assert!(err.is_timeout()); } +#[tokio::test] +async fn read_timeout_applies_to_headers() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| { + async { + // delay returning the response + tokio::time::sleep(Duration::from_millis(300)).await; + http::Response::default() + } + }); + + let client = reqwest::Client::builder() + .read_timeout(Duration::from_millis(100)) + .build() + .unwrap(); + + let url = format!("http://{}/slow", server.addr()); + + let res = client.get(&url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_timeout()); + assert_eq!(err.url().map(|u| u.as_str()), Some(url.as_str())); +} + +#[cfg(feature = "stream")] +#[tokio::test] +async fn read_timeout_applies_to_body() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| { + async { + // immediate response, but delayed body + let body = reqwest::Body::wrap_stream(futures_util::stream::once(async { + tokio::time::sleep(Duration::from_millis(300)).await; + Ok::<_, std::convert::Infallible>("Hello") + })); + + http::Response::new(body) + } + }); + + let client = reqwest::Client::builder() + .read_timeout(Duration::from_millis(100)) + .no_proxy() + .build() + .unwrap(); + + let url = format!("http://{}/slow", server.addr()); + let res = client.get(&url).send().await.expect("Failed to get"); + let body = res.text().await; + + let err = body.unwrap_err(); + + assert!(err.is_timeout()); +} + +#[cfg(feature = "stream")] +#[tokio::test] +async fn read_timeout_allows_slow_response_body() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| { + async { + // immediate response, but body that has slow chunks + + let slow = futures_util::stream::unfold(0, |state| async move { + if state < 3 { + tokio::time::sleep(Duration::from_millis(100)).await; + Some(( + Ok::<_, std::convert::Infallible>(state.to_string()), + state + 1, + )) + } else { + None + } + }); + let body = reqwest::Body::wrap_stream(slow); + + http::Response::new(body) + } + }); + + let client = reqwest::Client::builder() + .read_timeout(Duration::from_millis(200)) + //.timeout(Duration::from_millis(200)) + .no_proxy() + .build() + .unwrap(); + + let url = format!("http://{}/slow", server.addr()); + let res = client.get(&url).send().await.expect("Failed to get"); + let body = res.text().await.expect("body text"); + + assert_eq!(body, "012"); +} + /// Tests that internal client future cancels when the oneshot channel /// is canceled. #[cfg(feature = "blocking")]