diff --git a/tokio-stream/src/stream_ext.rs b/tokio-stream/src/stream_ext.rs index 52d32024ff3..a4ab8a03676 100644 --- a/tokio-stream/src/stream_ext.rs +++ b/tokio-stream/src/stream_ext.rs @@ -57,8 +57,10 @@ use try_next::TryNext; cfg_time! { pub(crate) mod timeout; + pub(crate) mod timeout_repeating; use timeout::Timeout; - use tokio::time::Duration; + use timeout_repeating::TimeoutRepeating; + use tokio::time::{Duration, Interval}; mod throttle; use throttle::{throttle, Throttle}; mod chunks_timeout; @@ -924,7 +926,9 @@ pub trait StreamExt: Stream { /// If the wrapped stream yields a value before the deadline is reached, the /// value is returned. Otherwise, an error is returned. The caller may decide /// to continue consuming the stream and will eventually get the next source - /// stream value once it becomes available. + /// stream value once it becomes available. See + /// [`timeout_repeating`](StreamExt::timeout_repeating) for an alternative + /// where the timeouts will repeat. /// /// # Notes /// @@ -971,6 +975,25 @@ pub trait StreamExt: Stream { /// assert_eq!(int_stream.try_next().await, Ok(None)); /// # } /// ``` + /// + /// Once a timeout error is received, no further events will be received + /// unless the wrapped stream yields a value (timeouts do not repeat). + /// + /// ``` + /// # #[tokio::main(flavor = "current_thread", start_paused = true)] + /// # async fn main() { + /// use tokio_stream::{StreamExt, wrappers::IntervalStream}; + /// use std::time::Duration; + /// let interval_stream = IntervalStream::new(tokio::time::interval(Duration::from_millis(100))); + /// let timeout_stream = interval_stream.timeout(Duration::from_millis(10)); + /// tokio::pin!(timeout_stream); + /// + /// // Only one timeout will be received between values in the source stream. + /// assert!(timeout_stream.try_next().await.is_ok()); + /// assert!(timeout_stream.try_next().await.is_err(), "expected one timeout"); + /// assert!(timeout_stream.try_next().await.is_ok(), "expected no more timeouts"); + /// # } + /// ``` #[cfg(all(feature = "time"))] #[cfg_attr(docsrs, doc(cfg(feature = "time")))] fn timeout(self, duration: Duration) -> Timeout @@ -980,6 +1003,95 @@ pub trait StreamExt: Stream { Timeout::new(self, duration) } + /// Applies a per-item timeout to the passed stream. + /// + /// `timeout_repeating()` takes an [`Interval`](tokio::time::Interval) that + /// controls the time each element of the stream has to complete before + /// timing out. + /// + /// If the wrapped stream yields a value before the deadline is reached, the + /// value is returned. Otherwise, an error is returned. The caller may decide + /// to continue consuming the stream and will eventually get the next source + /// stream value once it becomes available. Unlike `timeout()`, if no value + /// becomes available before the deadline is reached, additional errors are + /// returned at the specified interval. See [`timeout`](StreamExt::timeout) + /// for an alternative where the timeouts do not repeat. + /// + /// # Notes + /// + /// This function consumes the stream passed into it and returns a + /// wrapped version of it. + /// + /// Polling the returned stream will continue to poll the inner stream even + /// if one or more items time out. + /// + /// # Examples + /// + /// Suppose we have a stream `int_stream` that yields 3 numbers (1, 2, 3): + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() { + /// use tokio_stream::{self as stream, StreamExt}; + /// use std::time::Duration; + /// # let int_stream = stream::iter(1..=3); + /// + /// let int_stream = int_stream.timeout_repeating(tokio::time::interval(Duration::from_secs(1))); + /// tokio::pin!(int_stream); + /// + /// // When no items time out, we get the 3 elements in succession: + /// assert_eq!(int_stream.try_next().await, Ok(Some(1))); + /// assert_eq!(int_stream.try_next().await, Ok(Some(2))); + /// assert_eq!(int_stream.try_next().await, Ok(Some(3))); + /// assert_eq!(int_stream.try_next().await, Ok(None)); + /// + /// // If the second item times out, we get an error and continue polling the stream: + /// # let mut int_stream = stream::iter(vec![Ok(1), Err(()), Ok(2), Ok(3)]); + /// assert_eq!(int_stream.try_next().await, Ok(Some(1))); + /// assert!(int_stream.try_next().await.is_err()); + /// assert_eq!(int_stream.try_next().await, Ok(Some(2))); + /// assert_eq!(int_stream.try_next().await, Ok(Some(3))); + /// assert_eq!(int_stream.try_next().await, Ok(None)); + /// + /// // If we want to stop consuming the source stream the first time an + /// // element times out, we can use the `take_while` operator: + /// # let int_stream = stream::iter(vec![Ok(1), Err(()), Ok(2), Ok(3)]); + /// let mut int_stream = int_stream.take_while(Result::is_ok); + /// + /// assert_eq!(int_stream.try_next().await, Ok(Some(1))); + /// assert_eq!(int_stream.try_next().await, Ok(None)); + /// # } + /// ``` + /// + /// Timeout errors will be continuously produced at the specified interval + /// until the wrapped stream yields a value. + /// + /// ``` + /// # #[tokio::main(flavor = "current_thread", start_paused = true)] + /// # async fn main() { + /// use tokio_stream::{StreamExt, wrappers::IntervalStream}; + /// use std::time::Duration; + /// let interval_stream = IntervalStream::new(tokio::time::interval(Duration::from_millis(23))); + /// let timeout_stream = interval_stream.timeout_repeating(tokio::time::interval(Duration::from_millis(9))); + /// tokio::pin!(timeout_stream); + /// + /// // Multiple timeouts will be received between values in the source stream. + /// assert!(timeout_stream.try_next().await.is_ok()); + /// assert!(timeout_stream.try_next().await.is_err(), "expected one timeout"); + /// assert!(timeout_stream.try_next().await.is_err(), "expected a second timeout"); + /// // Will eventually receive another value from the source stream... + /// assert!(timeout_stream.try_next().await.is_ok(), "expected non-timeout"); + /// # } + /// ``` + #[cfg(all(feature = "time"))] + #[cfg_attr(docsrs, doc(cfg(feature = "time")))] + fn timeout_repeating(self, interval: Interval) -> TimeoutRepeating + where + Self: Sized, + { + TimeoutRepeating::new(self, interval) + } + /// Slows down a stream by enforcing a delay between items. /// /// The underlying timer behind this utility has a granularity of one millisecond. diff --git a/tokio-stream/src/stream_ext/timeout.rs b/tokio-stream/src/stream_ext/timeout.rs index a440d203ec4..17d1349022e 100644 --- a/tokio-stream/src/stream_ext/timeout.rs +++ b/tokio-stream/src/stream_ext/timeout.rs @@ -23,7 +23,7 @@ pin_project! { } } -/// Error returned by `Timeout`. +/// Error returned by `Timeout` and `TimeoutRepeating`. #[derive(Debug, PartialEq, Eq)] pub struct Elapsed(()); diff --git a/tokio-stream/src/stream_ext/timeout_repeating.rs b/tokio-stream/src/stream_ext/timeout_repeating.rs new file mode 100644 index 00000000000..253d2fd677e --- /dev/null +++ b/tokio-stream/src/stream_ext/timeout_repeating.rs @@ -0,0 +1,56 @@ +use crate::stream_ext::Fuse; +use crate::{Elapsed, Stream}; +use tokio::time::Interval; + +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Stream returned by the [`timeout_repeating`](super::StreamExt::timeout_repeating) method. + #[must_use = "streams do nothing unless polled"] + #[derive(Debug)] + pub struct TimeoutRepeating { + #[pin] + stream: Fuse, + #[pin] + interval: Interval, + } +} + +impl TimeoutRepeating { + pub(super) fn new(stream: S, interval: Interval) -> Self { + TimeoutRepeating { + stream: Fuse::new(stream), + interval, + } + } +} + +impl Stream for TimeoutRepeating { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut me = self.project(); + + match me.stream.poll_next(cx) { + Poll::Ready(v) => { + if v.is_some() { + me.interval.reset(); + } + return Poll::Ready(v.map(Ok)); + } + Poll::Pending => {} + }; + + ready!(me.interval.poll_tick(cx)); + Poll::Ready(Some(Err(Elapsed::new()))) + } + + fn size_hint(&self) -> (usize, Option) { + let (lower, _) = self.stream.size_hint(); + + // The timeout stream may insert an error an infinite number of times. + (lower, None) + } +}