Skip to content

Commit

Permalink
feat: add ClientBuilder::read_timeout(dur) (#2241)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmonstar authored Apr 15, 2024
1 parent e99da85 commit 1af8945
Show file tree
Hide file tree
Showing 4 changed files with 258 additions and 32 deletions.
113 changes: 97 additions & 16 deletions src/async_impl/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;

use bytes::Bytes;
use http_body::Body as HttpBody;
use http_body_util::combinators::BoxBody;
//use sync_wrapper::SyncWrapper;
use pin_project_lite::pin_project;
#[cfg(feature = "stream")]
use tokio::fs::File;
use tokio::time::Sleep;
Expand All @@ -23,13 +25,26 @@ enum Inner {
Streaming(BoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>>),
}

/// A body with a total timeout.
///
/// The timeout does not reset upon each chunk, but rather requires the whole
/// body be streamed before the deadline is reached.
pub(crate) struct TotalTimeoutBody<B> {
inner: B,
timeout: Pin<Box<Sleep>>,
pin_project! {
/// A body with a total timeout.
///
/// The timeout does not reset upon each chunk, but rather requires the whole
/// body be streamed before the deadline is reached.
pub(crate) struct TotalTimeoutBody<B> {
#[pin]
inner: B,
timeout: Pin<Box<Sleep>>,
}
}

pin_project! {
pub(crate) struct ReadTimeoutBody<B> {
#[pin]
inner: B,
#[pin]
sleep: Option<Sleep>,
timeout: Duration,
}
}

/// Converts any `impl Body` into a `impl Stream` of just its DATA frames.
Expand Down Expand Up @@ -289,23 +304,32 @@ pub(crate) fn total_timeout<B>(body: B, timeout: Pin<Box<Sleep>>) -> TotalTimeou
}
}

pub(crate) fn with_read_timeout<B>(body: B, timeout: Duration) -> ReadTimeoutBody<B> {
ReadTimeoutBody {
inner: body,
sleep: None,
timeout,
}
}

impl<B> hyper::body::Body for TotalTimeoutBody<B>
where
B: hyper::body::Body + Unpin,
B: hyper::body::Body,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
type Data = B::Data;
type Error = crate::Error;

fn poll_frame(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Option<Result<hyper::body::Frame<Self::Data>, Self::Error>>> {
if let Poll::Ready(()) = self.timeout.as_mut().poll(cx) {
let this = self.project();
if let Poll::Ready(()) = this.timeout.as_mut().poll(cx) {
return Poll::Ready(Some(Err(crate::error::body(crate::error::TimedOut))));
}
Poll::Ready(
futures_core::ready!(Pin::new(&mut self.inner).poll_frame(cx))
futures_core::ready!(this.inner.poll_frame(cx))
.map(|opt_chunk| opt_chunk.map_err(crate::error::body)),
)
}
Expand All @@ -321,22 +345,79 @@ where
}
}

impl<B> hyper::body::Body for ReadTimeoutBody<B>
where
B: hyper::body::Body,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
type Data = B::Data;
type Error = crate::Error;

fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Option<Result<hyper::body::Frame<Self::Data>, Self::Error>>> {
let mut this = self.project();

// Start the `Sleep` if not active.
let sleep_pinned = if let Some(some) = this.sleep.as_mut().as_pin_mut() {
some
} else {
this.sleep.set(Some(tokio::time::sleep(*this.timeout)));
this.sleep.as_mut().as_pin_mut().unwrap()
};

// Error if the timeout has expired.
if let Poll::Ready(()) = sleep_pinned.poll(cx) {
return Poll::Ready(Some(Err(crate::error::body(crate::error::TimedOut))));
}

let item = futures_core::ready!(this.inner.poll_frame(cx))
.map(|opt_chunk| opt_chunk.map_err(crate::error::body));
// a ready frame means timeout is reset
this.sleep.set(None);
Poll::Ready(item)
}

#[inline]
fn size_hint(&self) -> http_body::SizeHint {
self.inner.size_hint()
}

#[inline]
fn is_end_stream(&self) -> bool {
self.inner.is_end_stream()
}
}

pub(crate) type ResponseBody =
http_body_util::combinators::BoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>>;

pub(crate) fn response(
body: hyper::body::Incoming,
timeout: Option<Pin<Box<Sleep>>>,
deadline: Option<Pin<Box<Sleep>>>,
read_timeout: Option<Duration>,
) -> ResponseBody {
use http_body_util::BodyExt;

if let Some(timeout) = timeout {
total_timeout(body, timeout).map_err(Into::into).boxed()
} else {
body.map_err(Into::into).boxed()
match (deadline, read_timeout) {
(Some(total), Some(read)) => {
let body = with_read_timeout(body, read).map_err(box_err);
total_timeout(body, total).map_err(box_err).boxed()
}
(Some(total), None) => total_timeout(body, total).map_err(box_err).boxed(),
(None, Some(read)) => with_read_timeout(body, read).map_err(box_err).boxed(),
(None, None) => body.map_err(box_err).boxed(),
}
}

fn box_err<E>(err: E) -> Box<dyn std::error::Error + Send + Sync>
where
E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
err.into()
}

// ===== impl DataStream =====

impl<B> futures_core::Stream for DataStream<B>
Expand Down
62 changes: 53 additions & 9 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ struct Config {
auto_sys_proxy: bool,
redirect_policy: redirect::Policy,
referer: bool,
read_timeout: Option<Duration>,
timeout: Option<Duration>,
#[cfg(feature = "__tls")]
root_certs: Vec<Certificate>,
Expand Down Expand Up @@ -204,6 +205,7 @@ impl ClientBuilder {
auto_sys_proxy: true,
redirect_policy: redirect::Policy::default(),
referer: true,
read_timeout: None,
timeout: None,
#[cfg(feature = "__tls")]
root_certs: Vec::new(),
Expand Down Expand Up @@ -739,6 +741,7 @@ impl ClientBuilder {
headers: config.headers,
redirect_policy: config.redirect_policy,
referer: config.referer,
read_timeout: config.read_timeout,
request_timeout: config.timeout,
proxies,
proxies_maybe_http_auth,
Expand Down Expand Up @@ -1028,17 +1031,29 @@ impl ClientBuilder {

// Timeout options

/// Enables a request timeout.
/// Enables a total request timeout.
///
/// The timeout is applied from when the request starts connecting until the
/// response body has finished.
/// response body has finished. Also considered a total deadline.
///
/// Default is no timeout.
pub fn timeout(mut self, timeout: Duration) -> ClientBuilder {
self.config.timeout = Some(timeout);
self
}

/// Enables a read timeout.
///
/// The timeout applies to each read operation, and resets after a
/// successful read. This is more appropriate for detecting stalled
/// connections when the size isn't known beforehand.
///
/// Default is no timeout.
pub fn read_timeout(mut self, timeout: Duration) -> ClientBuilder {
self.config.read_timeout = Some(timeout);
self
}

/// Set a timeout for only the connect phase of a `Client`.
///
/// Default is `None`.
Expand Down Expand Up @@ -1985,11 +2000,17 @@ impl Client {
}
};

let timeout = timeout
let total_timeout = timeout
.or(self.inner.request_timeout)
.map(tokio::time::sleep)
.map(Box::pin);

let read_timeout_fut = self
.inner
.read_timeout
.map(tokio::time::sleep)
.map(Box::pin);

Pending {
inner: PendingInner::Request(PendingRequest {
method,
Expand All @@ -2004,7 +2025,9 @@ impl Client {
client: self.inner.clone(),

in_flight,
timeout,
total_timeout,
read_timeout_fut,
read_timeout: self.inner.read_timeout,
}),
}
}
Expand Down Expand Up @@ -2210,6 +2233,7 @@ struct ClientRef {
redirect_policy: redirect::Policy,
referer: bool,
request_timeout: Option<Duration>,
read_timeout: Option<Duration>,
proxies: Arc<Vec<Proxy>>,
proxies_maybe_http_auth: bool,
https_only: bool,
Expand Down Expand Up @@ -2246,6 +2270,10 @@ impl ClientRef {
if let Some(ref d) = self.request_timeout {
f.field("timeout", d);
}

if let Some(ref d) = self.read_timeout {
f.field("read_timeout", d);
}
}
}

Expand Down Expand Up @@ -2277,7 +2305,10 @@ pin_project! {
#[pin]
in_flight: ResponseFuture,
#[pin]
timeout: Option<Pin<Box<Sleep>>>,
total_timeout: Option<Pin<Box<Sleep>>>,
#[pin]
read_timeout_fut: Option<Pin<Box<Sleep>>>,
read_timeout: Option<Duration>,
}
}

Expand All @@ -2292,8 +2323,12 @@ impl PendingRequest {
self.project().in_flight
}

fn timeout(self: Pin<&mut Self>) -> Pin<&mut Option<Pin<Box<Sleep>>>> {
self.project().timeout
fn total_timeout(self: Pin<&mut Self>) -> Pin<&mut Option<Pin<Box<Sleep>>>> {
self.project().total_timeout
}

fn read_timeout(self: Pin<&mut Self>) -> Pin<&mut Option<Pin<Box<Sleep>>>> {
self.project().read_timeout_fut
}

fn urls(self: Pin<&mut Self>) -> &mut Vec<Url> {
Expand Down Expand Up @@ -2430,7 +2465,15 @@ impl Future for PendingRequest {
type Output = Result<Response, crate::Error>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Some(delay) = self.as_mut().timeout().as_mut().as_pin_mut() {
if let Some(delay) = self.as_mut().total_timeout().as_mut().as_pin_mut() {
if let Poll::Ready(()) = delay.poll(cx) {
return Poll::Ready(Err(
crate::error::request(crate::error::TimedOut).with_url(self.url.clone())
));
}
}

if let Some(delay) = self.as_mut().read_timeout().as_mut().as_pin_mut() {
if let Poll::Ready(()) = delay.poll(cx) {
return Poll::Ready(Err(
crate::error::request(crate::error::TimedOut).with_url(self.url.clone())
Expand Down Expand Up @@ -2622,7 +2665,8 @@ impl Future for PendingRequest {
res,
self.url.clone(),
self.client.accepts,
self.timeout.take(),
self.total_timeout.take(),
self.read_timeout,
);
return Poll::Ready(Ok(res));
}
Expand Down
6 changes: 4 additions & 2 deletions src/async_impl/response.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fmt;
use std::net::SocketAddr;
use std::pin::Pin;
use std::time::Duration;

use bytes::Bytes;
use http_body_util::BodyExt;
Expand Down Expand Up @@ -37,12 +38,13 @@ impl Response {
res: hyper::Response<hyper::body::Incoming>,
url: Url,
accepts: Accepts,
timeout: Option<Pin<Box<Sleep>>>,
total_timeout: Option<Pin<Box<Sleep>>>,
read_timeout: Option<Duration>,
) -> Response {
let (mut parts, body) = res.into_parts();
let decoder = Decoder::detect(
&mut parts.headers,
super::body::response(body, timeout),
super::body::response(body, total_timeout, read_timeout),
accepts,
);
let res = hyper::Response::from_parts(parts, decoder);
Expand Down
Loading

0 comments on commit 1af8945

Please sign in to comment.