From 94b508bfcb1d1f986f34f9b12d799e5a73f65f22 Mon Sep 17 00:00:00 2001 From: Oleg Nosov Date: Sun, 6 Feb 2022 10:33:11 +0300 Subject: [PATCH] Basic `StreamExt::{flatten_unordered, flat_map_unordered}` impls (#2083) --- futures-util/benches/flatten_unordered.rs | 66 +++ .../src/stream/stream/flatten_unordered.rs | 501 ++++++++++++++++++ futures-util/src/stream/stream/mod.rs | 118 ++++- futures/tests/stream.rs | 272 +++++++++- 4 files changed, 955 insertions(+), 2 deletions(-) create mode 100644 futures-util/benches/flatten_unordered.rs create mode 100644 futures-util/src/stream/stream/flatten_unordered.rs diff --git a/futures-util/benches/flatten_unordered.rs b/futures-util/benches/flatten_unordered.rs new file mode 100644 index 0000000000..64d5f9a4e3 --- /dev/null +++ b/futures-util/benches/flatten_unordered.rs @@ -0,0 +1,66 @@ +#![feature(test)] + +extern crate test; +use crate::test::Bencher; + +use futures::channel::oneshot; +use futures::executor::block_on; +use futures::future::{self, FutureExt}; +use futures::stream::{self, StreamExt}; +use futures::task::Poll; +use std::collections::VecDeque; +use std::thread; + +#[bench] +fn oneshot_streams(b: &mut Bencher) { + const STREAM_COUNT: usize = 10_000; + const STREAM_ITEM_COUNT: usize = 1; + + b.iter(|| { + let mut txs = VecDeque::with_capacity(STREAM_COUNT); + let mut rxs = Vec::new(); + + for _ in 0..STREAM_COUNT { + let (tx, rx) = oneshot::channel(); + txs.push_back(tx); + rxs.push(rx); + } + + thread::spawn(move || { + let mut last = 1; + while let Some(tx) = txs.pop_front() { + let _ = tx.send(stream::iter(last..last + STREAM_ITEM_COUNT)); + last += STREAM_ITEM_COUNT; + } + }); + + let mut flatten = stream::unfold(rxs.into_iter(), |mut vals| { + async { + if let Some(next) = vals.next() { + let val = next.await.unwrap(); + Some((val, vals)) + } else { + None + } + } + .boxed() + }) + .flatten_unordered(None); + + block_on(future::poll_fn(move |cx| { + let mut count = 0; + loop { + match flatten.poll_next_unpin(cx) { + Poll::Ready(None) => break, + Poll::Ready(Some(_)) => { + count += 1; + } + _ => {} + } + } + assert_eq!(count, STREAM_COUNT * STREAM_ITEM_COUNT); + + Poll::Ready(()) + })) + }); +} diff --git a/futures-util/src/stream/stream/flatten_unordered.rs b/futures-util/src/stream/stream/flatten_unordered.rs new file mode 100644 index 0000000000..6361565f2b --- /dev/null +++ b/futures-util/src/stream/stream/flatten_unordered.rs @@ -0,0 +1,501 @@ +use alloc::sync::Arc; +use core::{ + cell::UnsafeCell, + convert::identity, + fmt, + num::NonZeroUsize, + pin::Pin, + sync::atomic::{AtomicU8, Ordering}, +}; + +use pin_project_lite::pin_project; + +use futures_core::{ + future::Future, + ready, + stream::{FusedStream, Stream}, + task::{Context, Poll, Waker}, +}; +#[cfg(feature = "sink")] +use futures_sink::Sink; +use futures_task::{waker, ArcWake}; + +use crate::stream::FuturesUnordered; + +/// There is nothing to poll and stream isn't being +/// polled or waking at the moment. +const NONE: u8 = 0; + +/// Inner streams need to be polled. +const NEED_TO_POLL_INNER_STREAMS: u8 = 1; + +/// The base stream needs to be polled. +const NEED_TO_POLL_STREAM: u8 = 0b10; + +/// It needs to poll base stream and inner streams. +const NEED_TO_POLL_ALL: u8 = NEED_TO_POLL_INNER_STREAMS | NEED_TO_POLL_STREAM; + +/// The current stream is being polled at the moment. +const POLLING: u8 = 0b100; + +/// Inner streams are being woken at the moment. +const WAKING_INNER_STREAMS: u8 = 0b1000; + +/// The base stream is being woken at the moment. +const WAKING_STREAM: u8 = 0b10000; + +/// The base stream and inner streams are being woken at the moment. +const WAKING_ALL: u8 = WAKING_STREAM | WAKING_INNER_STREAMS; + +/// The stream was waked and will be polled. +const WOKEN: u8 = 0b100000; + +/// Determines what needs to be polled, and is stream being polled at the +/// moment or not. +#[derive(Clone, Debug)] +struct SharedPollState { + state: Arc, +} + +impl SharedPollState { + /// Constructs new `SharedPollState` with the given state. + fn new(value: u8) -> SharedPollState { + SharedPollState { state: Arc::new(AtomicU8::new(value)) } + } + + /// Attempts to start polling, returning stored state in case of success. + /// Returns `None` if some waker is waking at the moment. + fn start_polling( + &self, + ) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> { + let value = self + .state + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| { + if value & WAKING_ALL == NONE { + Some(POLLING) + } else { + None + } + }) + .ok()?; + let bomb = PollStateBomb::new(self, SharedPollState::reset); + + Some((value, bomb)) + } + + /// Starts the waking process and performs bitwise or with the given value. + fn start_waking( + &self, + to_poll: u8, + waking: u8, + ) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> { + let value = self + .state + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| { + let next_value = value | to_poll; + + if value & (WOKEN | POLLING) == NONE { + Some(next_value | waking) + } else if next_value != value { + Some(next_value) + } else { + None + } + }) + .ok()?; + + if value & (WOKEN | POLLING) == NONE { + let bomb = PollStateBomb::new(self, move |state| state.stop_waking(waking)); + + Some((value, bomb)) + } else { + None + } + } + + /// Sets current state to + /// - `!POLLING` allowing to use wakers + /// - `WOKEN` if the state was changed during `POLLING` phase as waker will be called, + /// or `will_be_woken` flag supplied + /// - `!WAKING_ALL` as + /// * Wakers called during the `POLLING` phase won't propagate their calls + /// * `POLLING` phase can't start if some of the wakers are active + /// So no wrapped waker can touch the inner waker's cell, it's safe to poll again. + fn stop_polling(&self, to_poll: u8, will_be_woken: bool) -> u8 { + self.state + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |mut value| { + let mut next_value = to_poll; + + value &= NEED_TO_POLL_ALL; + if value != NONE || will_be_woken { + next_value |= WOKEN; + } + next_value |= value; + + Some(next_value & !POLLING & !WAKING_ALL) + }) + .unwrap() + } + + /// Toggles state to non-waking, allowing to start polling. + fn stop_waking(&self, waking: u8) -> u8 { + self.state + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| { + let next_value = value & !waking; + + if value & WAKING_ALL == waking { + Some(next_value | WOKEN) + } else if next_value != value { + Some(next_value) + } else { + None + } + }) + .unwrap_or_else(identity) + } + + /// Resets current state allowing to poll the stream and wake up wakers. + fn reset(&self) -> u8 { + self.state.swap(NEED_TO_POLL_ALL, Ordering::AcqRel) + } +} + +/// Used to execute some function on the given state when dropped. +struct PollStateBomb<'a, F: FnOnce(&SharedPollState) -> u8> { + state: &'a SharedPollState, + drop: Option, +} + +impl<'a, F: FnOnce(&SharedPollState) -> u8> PollStateBomb<'a, F> { + /// Constructs new bomb with the given state. + fn new(state: &'a SharedPollState, drop: F) -> Self { + Self { state, drop: Some(drop) } + } + + /// Deactivates bomb, forces it to not call provided function when dropped. + fn deactivate(mut self) { + self.drop.take(); + } + + /// Manually fires the bomb, returning supplied state. + fn fire(mut self) -> Option { + self.drop.take().map(|drop| (drop)(self.state)) + } +} + +impl u8> Drop for PollStateBomb<'_, F> { + fn drop(&mut self) { + if let Some(drop) = self.drop.take() { + (drop)(self.state); + } + } +} + +/// Will update state with the provided value on `wake_by_ref` call +/// and then, if there is a need, call `inner_waker`. +struct InnerWaker { + inner_waker: UnsafeCell>, + poll_state: SharedPollState, + need_to_poll: u8, +} + +unsafe impl Send for InnerWaker {} +unsafe impl Sync for InnerWaker {} + +impl InnerWaker { + /// Replaces given waker's inner_waker for polling stream/futures which will + /// update poll state on `wake_by_ref` call. Use only if you need several + /// contexts. + /// + /// ## Safety + /// + /// This function will modify waker's `inner_waker` via `UnsafeCell`, so + /// it should be used only during `POLLING` phase. + unsafe fn replace_waker(self_arc: &mut Arc, cx: &Context<'_>) -> Waker { + *self_arc.inner_waker.get() = cx.waker().clone().into(); + waker(self_arc.clone()) + } + + /// Attempts to start the waking process for the waker with the given value. + /// If succeeded, then the stream isn't yet woken and not being polled at the moment. + fn start_waking(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> { + self.poll_state.start_waking(self.need_to_poll, self.waking_state()) + } + + /// Returns the corresponding waking state toggled by this waker. + fn waking_state(&self) -> u8 { + self.need_to_poll << 3 + } +} + +impl ArcWake for InnerWaker { + fn wake_by_ref(self_arc: &Arc) { + if let Some((_, state_bomb)) = self_arc.start_waking() { + // Safety: now state is not `POLLING` + let waker_opt = unsafe { self_arc.inner_waker.get().as_ref().unwrap() }; + + if let Some(inner_waker) = waker_opt.clone() { + // Stop waking to allow polling stream + let poll_state_value = state_bomb.fire().unwrap(); + + // Here we want to call waker only if stream isn't woken yet and + // also to optimize the case when two wakers are called at the same time. + // + // In this case the best strategy will be to propagate only the latest waker's awake, + // and then poll both entities in a single `poll_next` call + if poll_state_value & (WOKEN | WAKING_ALL) == self_arc.waking_state() { + // Wake up inner waker + inner_waker.wake(); + } + } + } + } +} + +pin_project! { + /// Future which contains optional stream. + /// + /// If it's `Some`, it will attempt to call `poll_next` on it, + /// returning `Some((item, next_item_fut))` in case of `Poll::Ready(Some(...))` + /// or `None` in case of `Poll::Ready(None)`. + /// + /// If `poll_next` will return `Poll::Pending`, it will be forwarded to + /// the future and current task will be notified by waker. + #[must_use = "futures do nothing unless you `.await` or poll them"] + struct PollStreamFut { + #[pin] + stream: Option, + } +} + +impl PollStreamFut { + /// Constructs new `PollStreamFut` using given `stream`. + fn new(stream: impl Into>) -> Self { + Self { stream: stream.into() } + } +} + +impl Future for PollStreamFut { + type Output = Option<(St::Item, PollStreamFut)>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut stream = self.project().stream; + + let item = if let Some(stream) = stream.as_mut().as_pin_mut() { + ready!(stream.poll_next(cx)) + } else { + None + }; + let next_item_fut = PollStreamFut::new(stream.get_mut().take()); + let out = item.map(|item| (item, next_item_fut)); + + Poll::Ready(out) + } +} + +pin_project! { + /// Stream for the [`flatten_unordered`](super::StreamExt::flatten_unordered) + /// method. + #[project = FlattenUnorderedProj] + #[must_use = "streams do nothing unless polled"] + pub struct FlattenUnordered where St: Stream { + #[pin] + inner_streams: FuturesUnordered>, + #[pin] + stream: St, + poll_state: SharedPollState, + limit: Option, + is_stream_done: bool, + inner_streams_waker: Arc, + stream_waker: Arc, + } +} + +impl fmt::Debug for FlattenUnordered +where + St: Stream + fmt::Debug, + St::Item: Stream + fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FlattenUnordered") + .field("poll_state", &self.poll_state) + .field("inner_streams", &self.inner_streams) + .field("limit", &self.limit) + .field("stream", &self.stream) + .field("is_stream_done", &self.is_stream_done) + .finish() + } +} + +impl FlattenUnordered +where + St: Stream, + St::Item: Stream + Unpin, +{ + pub(super) fn new(stream: St, limit: Option) -> FlattenUnordered { + let poll_state = SharedPollState::new(NEED_TO_POLL_STREAM); + + FlattenUnordered { + inner_streams: FuturesUnordered::new(), + stream, + is_stream_done: false, + limit: limit.and_then(NonZeroUsize::new), + inner_streams_waker: Arc::new(InnerWaker { + inner_waker: UnsafeCell::new(None), + poll_state: poll_state.clone(), + need_to_poll: NEED_TO_POLL_INNER_STREAMS, + }), + stream_waker: Arc::new(InnerWaker { + inner_waker: UnsafeCell::new(None), + poll_state: poll_state.clone(), + need_to_poll: NEED_TO_POLL_STREAM, + }), + poll_state, + } + } + + delegate_access_inner!(stream, St, ()); +} + +impl FlattenUnorderedProj<'_, St> +where + St: Stream, +{ + /// Checks if current `inner_streams` size is less than optional limit. + fn is_exceeded_limit(&self) -> bool { + self.limit.map_or(false, |limit| self.inner_streams.len() >= limit.get()) + } +} + +impl FusedStream for FlattenUnordered +where + St: FusedStream, + St::Item: FusedStream + Unpin, +{ + fn is_terminated(&self) -> bool { + self.stream.is_terminated() && self.inner_streams.is_empty() + } +} + +impl Stream for FlattenUnordered +where + St: Stream, + St::Item: Stream + Unpin, +{ + type Item = ::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut next_item = None; + let mut need_to_poll_next = NONE; + + let mut this = self.as_mut().project(); + + let (mut poll_state_value, state_bomb) = match this.poll_state.start_polling() { + Some(value) => value, + _ => { + // Waker was called, just wait for the next poll + return Poll::Pending; + } + }; + + if poll_state_value & NEED_TO_POLL_STREAM != NONE { + // Safety: now state is `POLLING`. + let stream_waker = unsafe { InnerWaker::replace_waker(this.stream_waker, cx) }; + + // Here we need to poll the base stream. + // + // To improve performance, we will attempt to place as many items as we can + // to the `FuturesUnordered` bucket before polling inner streams + loop { + if this.is_exceeded_limit() || *this.is_stream_done { + // We either exceeded the limit or the stream is exhausted + if !*this.is_stream_done { + // The stream needs to be polled in the next iteration + need_to_poll_next |= NEED_TO_POLL_STREAM; + } + + break; + } else { + match this.stream.as_mut().poll_next(&mut Context::from_waker(&stream_waker)) { + Poll::Ready(Some(inner_stream)) => { + // Add new stream to the inner streams bucket + this.inner_streams.as_mut().push(PollStreamFut::new(inner_stream)); + // Inner streams must be polled afterward + poll_state_value |= NEED_TO_POLL_INNER_STREAMS; + } + Poll::Ready(None) => { + // Mark the stream as done + *this.is_stream_done = true; + } + Poll::Pending => { + break; + } + } + } + } + } + + if poll_state_value & NEED_TO_POLL_INNER_STREAMS != NONE { + // Safety: now state is `POLLING`. + let inner_streams_waker = + unsafe { InnerWaker::replace_waker(this.inner_streams_waker, cx) }; + + match this + .inner_streams + .as_mut() + .poll_next(&mut Context::from_waker(&inner_streams_waker)) + { + Poll::Ready(Some(Some((item, next_item_fut)))) => { + // Push next inner stream item future to the list of inner streams futures + this.inner_streams.as_mut().push(next_item_fut); + // Take the received item + next_item = Some(item); + // On the next iteration, inner streams must be polled again + need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS; + } + Poll::Ready(Some(None)) => { + // On the next iteration, inner streams must be polled again + need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS; + } + _ => {} + } + } + + // We didn't have any `poll_next` panic, so it's time to deactivate the bomb + state_bomb.deactivate(); + + let mut force_wake = + // we need to poll the stream and didn't reach the limit yet + need_to_poll_next & NEED_TO_POLL_STREAM != NONE && !this.is_exceeded_limit() + // or we need to poll inner streams again + || need_to_poll_next & NEED_TO_POLL_INNER_STREAMS != NONE; + + // Stop polling and swap the latest state + poll_state_value = this.poll_state.stop_polling(need_to_poll_next, force_wake); + // If state was changed during `POLLING` phase, need to manually call a waker + force_wake |= poll_state_value & NEED_TO_POLL_ALL != NONE; + + let is_done = *this.is_stream_done && this.inner_streams.is_empty(); + + if next_item.is_some() || is_done { + Poll::Ready(next_item) + } else { + if force_wake { + cx.waker().wake_by_ref(); + } + + Poll::Pending + } + } +} + +// Forwarding impl of Sink from the underlying stream +#[cfg(feature = "sink")] +impl Sink for FlattenUnordered +where + St: Stream + Sink, +{ + type Error = St::Error; + + delegate_sink!(stream, Item); +} diff --git a/futures-util/src/stream/stream/mod.rs b/futures-util/src/stream/stream/mod.rs index a7313401d5..642b91e823 100644 --- a/futures-util/src/stream/stream/mod.rs +++ b/futures-util/src/stream/stream/mod.rs @@ -197,6 +197,25 @@ mod buffered; #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 pub use self::buffered::Buffered; +#[cfg(not(futures_no_atomic_cas))] +#[cfg(feature = "alloc")] +mod flatten_unordered; + +#[cfg(not(futures_no_atomic_cas))] +#[cfg(feature = "alloc")] +#[allow(unreachable_pub)] +pub use self::flatten_unordered::FlattenUnordered; + +#[cfg(not(futures_no_atomic_cas))] +#[cfg(feature = "alloc")] +delegate_all!( + /// Stream for the [`flat_map_unordered`](StreamExt::flat_map_unordered) method. + FlatMapUnordered( + FlattenUnordered> + ): Debug + Sink + Stream + FusedStream + AccessInner[St, (. .)] + New[|x: St, limit: Option, f: F| FlattenUnordered::new(Map::new(x, f), limit)] + where St: Stream, U: Stream, U: Unpin, F: FnMut(St::Item) -> U +); + #[cfg(not(futures_no_atomic_cas))] #[cfg(feature = "alloc")] mod for_each_concurrent; @@ -754,13 +773,57 @@ pub trait StreamExt: Stream { assert_stream::<::Item, _>(Flatten::new(self)) } + /// Flattens a stream of streams into just one continuous stream. Polls + /// inner streams concurrently. + /// + /// # Examples + /// + /// ``` + /// # futures::executor::block_on(async { + /// use futures::channel::mpsc; + /// use futures::stream::StreamExt; + /// use std::thread; + /// + /// let (tx1, rx1) = mpsc::unbounded(); + /// let (tx2, rx2) = mpsc::unbounded(); + /// let (tx3, rx3) = mpsc::unbounded(); + /// + /// thread::spawn(move || { + /// tx1.unbounded_send(1).unwrap(); + /// tx1.unbounded_send(2).unwrap(); + /// }); + /// thread::spawn(move || { + /// tx2.unbounded_send(3).unwrap(); + /// tx2.unbounded_send(4).unwrap(); + /// }); + /// thread::spawn(move || { + /// tx3.unbounded_send(rx1).unwrap(); + /// tx3.unbounded_send(rx2).unwrap(); + /// }); + /// + /// let mut output = rx3.flatten_unordered(None).collect::>().await; + /// output.sort(); + /// + /// assert_eq!(output, vec![1, 2, 3, 4]); + /// # }); + /// ``` + #[cfg(not(futures_no_atomic_cas))] + #[cfg(feature = "alloc")] + fn flatten_unordered(self, limit: impl Into>) -> FlattenUnordered + where + Self::Item: Stream + Unpin, + Self: Sized, + { + FlattenUnordered::new(self, limit.into()) + } + /// Maps a stream like [`StreamExt::map`] but flattens nested `Stream`s. /// /// [`StreamExt::map`] is very useful, but if it produces a `Stream` instead, /// you would have to chain combinators like `.map(f).flatten()` while this /// combinator provides ability to write `.flat_map(f)` instead of chaining. /// - /// The provided closure which produce inner streams is executed over all elements + /// The provided closure which produces inner streams is executed over all elements /// of stream as last inner stream is terminated and next stream item is available. /// /// Note that this function consumes the stream passed into it and returns a @@ -788,6 +851,59 @@ pub trait StreamExt: Stream { assert_stream::(FlatMap::new(self, f)) } + /// Maps a stream like [`StreamExt::map`] but flattens nested `Stream`s + /// and polls them concurrently, yielding items in any order, as they made + /// available. + /// + /// [`StreamExt::map`] is very useful, but if it produces `Stream`s + /// instead, and you need to poll all of them concurrently, you would + /// have to use something like `for_each_concurrent` and merge values + /// by hand. This combinator provides ability to collect all values + /// from concurrently polled streams into one stream. + /// + /// The first argument is an optional limit on the number of concurrently + /// polled streams. If this limit is not `None`, no more than `limit` streams + /// will be polled concurrently. The `limit` argument is of type + /// `Into>`, and so can be provided as either `None`, + /// `Some(10)`, or just `10`. Note: a limit of zero is interpreted as + /// no limit at all, and will have the same result as passing in `None`. + /// + /// The provided closure which produces inner streams is executed over + /// all elements of stream as next stream item is available and limit + /// of concurrently processed streams isn't exceeded. + /// + /// Note that this function consumes the stream passed into it and + /// returns a wrapped version of it. + /// + /// # Examples + /// + /// ``` + /// # futures::executor::block_on(async { + /// use futures::stream::{self, StreamExt}; + /// + /// let stream = stream::iter(1..5); + /// let stream = stream.flat_map_unordered(1, |x| stream::iter(vec![x; x])); + /// let mut values = stream.collect::>().await; + /// values.sort(); + /// + /// assert_eq!(vec![1usize, 2, 2, 3, 3, 3, 4, 4, 4, 4], values); + /// # }); + /// ``` + #[cfg(not(futures_no_atomic_cas))] + #[cfg(feature = "alloc")] + fn flat_map_unordered( + self, + limit: impl Into>, + f: F, + ) -> FlatMapUnordered + where + U: Stream + Unpin, + F: FnMut(Self::Item) -> U, + Self: Sized, + { + FlatMapUnordered::new(self, limit.into(), f) + } + /// Combinator similar to [`StreamExt::fold`] that holds internal state /// and produces a new stream. /// diff --git a/futures/tests/stream.rs b/futures/tests/stream.rs index 0d453d1752..71ec654bfb 100644 --- a/futures/tests/stream.rs +++ b/futures/tests/stream.rs @@ -1,10 +1,14 @@ +use std::iter; +use std::sync::Arc; + use futures::channel::mpsc; use futures::executor::block_on; use futures::future::{self, Future}; +use futures::lock::Mutex; use futures::sink::SinkExt; use futures::stream::{self, StreamExt}; use futures::task::Poll; -use futures::FutureExt; +use futures::{ready, FutureExt}; use futures_test::task::noop_context; #[test] @@ -49,6 +53,272 @@ fn scan() { }); } +#[test] +fn flatten_unordered() { + use futures::executor::block_on; + use futures::stream::*; + use futures::task::*; + use std::convert::identity; + use std::pin::Pin; + use std::thread; + use std::time::Duration; + + struct DataStream { + data: Vec, + polled: bool, + wake_immediately: bool, + } + + impl Stream for DataStream { + type Item = u8; + + fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { + if !self.polled { + if !self.wake_immediately { + let waker = ctx.waker().clone(); + let sleep_time = + Duration::from_millis(*self.data.first().unwrap_or(&0) as u64 / 10); + thread::spawn(move || { + thread::sleep(sleep_time); + waker.wake_by_ref(); + }); + } else { + ctx.waker().wake_by_ref(); + } + self.polled = true; + Poll::Pending + } else { + self.polled = false; + Poll::Ready(self.data.pop()) + } + } + } + + struct Interchanger { + polled: bool, + base: u8, + wake_immediately: bool, + } + + impl Stream for Interchanger { + type Item = DataStream; + + fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { + if !self.polled { + self.polled = true; + if !self.wake_immediately { + let waker = ctx.waker().clone(); + let sleep_time = Duration::from_millis(self.base as u64); + thread::spawn(move || { + thread::sleep(sleep_time); + waker.wake_by_ref(); + }); + } else { + ctx.waker().wake_by_ref(); + } + Poll::Pending + } else { + let data: Vec<_> = (0..6).rev().map(|v| v + self.base * 6).collect(); + self.base += 1; + self.polled = false; + Poll::Ready(Some(DataStream { + polled: false, + data, + wake_immediately: self.wake_immediately && self.base % 2 == 0, + })) + } + } + } + + // basic behaviour + { + block_on(async { + let st = stream::iter(vec![ + stream::iter(0..=4u8), + stream::iter(6..=10), + stream::iter(10..=12), + ]); + + let fl_unordered = st.flatten_unordered(3).collect::>().await; + + assert_eq!(fl_unordered, vec![0, 6, 10, 1, 7, 11, 2, 8, 12, 3, 9, 4, 10]); + }); + + block_on(async { + let st = stream::iter(vec![ + stream::iter(0..=4u8), + stream::iter(6..=10), + stream::iter(0..=2), + ]); + + let mut fm_unordered = st + .flat_map_unordered(1, |s| s.filter(|v| futures::future::ready(v % 2 == 0))) + .collect::>() + .await; + + fm_unordered.sort_unstable(); + + assert_eq!(fm_unordered, vec![0, 0, 2, 2, 4, 6, 8, 10]); + }); + } + + // wake up immediately + { + block_on(async { + let mut fl_unordered = Interchanger { polled: false, base: 0, wake_immediately: true } + .take(10) + .map(|s| s.map(identity)) + .flatten_unordered(10) + .collect::>() + .await; + + fl_unordered.sort_unstable(); + + assert_eq!(fl_unordered, (0..60).collect::>()); + }); + + block_on(async { + let mut fm_unordered = Interchanger { polled: false, base: 0, wake_immediately: true } + .take(10) + .flat_map_unordered(10, |s| s.map(identity)) + .collect::>() + .await; + + fm_unordered.sort_unstable(); + + assert_eq!(fm_unordered, (0..60).collect::>()); + }); + } + + // wake up after delay + { + block_on(async { + let mut fl_unordered = Interchanger { polled: false, base: 0, wake_immediately: false } + .take(10) + .map(|s| s.map(identity)) + .flatten_unordered(10) + .collect::>() + .await; + + fl_unordered.sort_unstable(); + + assert_eq!(fl_unordered, (0..60).collect::>()); + }); + + block_on(async { + let mut fm_unordered = Interchanger { polled: false, base: 0, wake_immediately: false } + .take(10) + .flat_map_unordered(10, |s| s.map(identity)) + .collect::>() + .await; + + fm_unordered.sort_unstable(); + + assert_eq!(fm_unordered, (0..60).collect::>()); + }); + + block_on(async { + let (mut fm_unordered, mut fl_unordered) = futures_util::join!( + Interchanger { polled: false, base: 0, wake_immediately: false } + .take(10) + .flat_map_unordered(10, |s| s.map(identity)) + .collect::>(), + Interchanger { polled: false, base: 0, wake_immediately: false } + .take(10) + .map(|s| s.map(identity)) + .flatten_unordered(10) + .collect::>() + ); + + fm_unordered.sort_unstable(); + fl_unordered.sort_unstable(); + + assert_eq!(fm_unordered, fl_unordered); + assert_eq!(fm_unordered, (0..60).collect::>()); + }); + } + + // waker panics + { + let stream = Arc::new(Mutex::new( + Interchanger { polled: false, base: 0, wake_immediately: true } + .take(10) + .flat_map_unordered(10, |s| s.map(identity)), + )); + + struct PanicWaker; + + impl ArcWake for PanicWaker { + fn wake_by_ref(_arc_self: &Arc) { + panic!("WAKE UP"); + } + } + + std::thread::spawn({ + let stream = stream.clone(); + move || { + let mut st = poll_fn(|cx| { + let mut lock = ready!(stream.lock().poll_unpin(cx)); + + let panic_waker = waker(Arc::new(PanicWaker)); + let mut panic_cx = Context::from_waker(&panic_waker); + let _ = ready!(lock.poll_next_unpin(&mut panic_cx)); + + Poll::Ready(Some(())) + }); + + block_on(st.next()) + } + }) + .join() + .unwrap_err(); + + block_on(async move { + let mut values: Vec<_> = stream.lock().await.by_ref().collect().await; + values.sort_unstable(); + + assert_eq!(values, (0..60).collect::>()); + }); + } + + // stream panics + { + let st = stream::iter(iter::once( + once(Box::pin(async { panic!("Polled") })).left_stream::(), + )) + .chain( + Interchanger { polled: false, base: 0, wake_immediately: true } + .map(|stream| stream.right_stream()) + .take(10), + ); + + let stream = Arc::new(Mutex::new(st.flatten_unordered(10))); + + std::thread::spawn({ + let stream = stream.clone(); + move || { + let mut st = poll_fn(|cx| { + let mut lock = ready!(stream.lock().poll_unpin(cx)); + let data = ready!(lock.poll_next_unpin(cx)); + + Poll::Ready(data) + }); + + block_on(st.next()) + } + }) + .join() + .unwrap_err(); + + block_on(async move { + let mut values: Vec<_> = stream.lock().await.by_ref().collect().await; + values.sort_unstable(); + + assert_eq!(values, (0..60).collect::>()); + }); + } +} + #[test] fn take_until() { fn make_stop_fut(stop_on: u32) -> impl Future {