From d94765b9faaaeabd6700bb6b47e762309b32ba19 Mon Sep 17 00:00:00 2001 From: olegnn Date: Mon, 20 Mar 2023 21:05:00 +0100 Subject: [PATCH] fu: always replace inner wakers --- .../src/stream/stream/flatten_unordered.rs | 18 +++----- futures/tests/stream.rs | 41 +++++++++++++++++++ 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/futures-util/src/stream/stream/flatten_unordered.rs b/futures-util/src/stream/stream/flatten_unordered.rs index 484c3733aa..43b4d092c4 100644 --- a/futures-util/src/stream/stream/flatten_unordered.rs +++ b/futures-util/src/stream/stream/flatten_unordered.rs @@ -414,8 +414,13 @@ where } }; + // Safety: now state is `POLLING`. + let stream_waker = unsafe { WrappedWaker::replace_waker(this.stream_waker, cx) }; + let inner_streams_waker = + unsafe { WrappedWaker::replace_waker(this.inner_streams_waker, cx) }; + if poll_state_value & NEED_TO_POLL_STREAM != NONE { - let mut stream_waker = None; + let mut cx = Context::from_waker(&stream_waker); // Here we need to poll the base stream. // @@ -431,14 +436,6 @@ where break; } else { - // Initialize base stream waker if it's not yet initialized - if stream_waker.is_none() { - // Safety: now state is `POLLING`. - stream_waker - .replace(unsafe { WrappedWaker::replace_waker(this.stream_waker, cx) }); - } - let mut cx = Context::from_waker(stream_waker.as_ref().unwrap()); - match this.stream.as_mut().poll_next(&mut cx) { Poll::Ready(Some(item)) => { let next_item_fut = match Fc::next_step(item) { @@ -475,9 +472,6 @@ where } if poll_state_value & NEED_TO_POLL_INNER_STREAMS != NONE { - // Safety: now state is `POLLING`. - let inner_streams_waker = - unsafe { WrappedWaker::replace_waker(this.inner_streams_waker, cx) }; let mut cx = Context::from_waker(&inner_streams_waker); match this.inner_streams.as_mut().poll_next(&mut cx) { diff --git a/futures/tests/stream.rs b/futures/tests/stream.rs index 6a302b798e..c5e0bd4ac2 100644 --- a/futures/tests/stream.rs +++ b/futures/tests/stream.rs @@ -14,6 +14,7 @@ use futures::stream::{self, StreamExt}; use futures::task::Poll; use futures::{ready, FutureExt}; use futures_core::Stream; +use futures_executor::ThreadPool; use futures_test::task::noop_context; #[test] @@ -82,6 +83,7 @@ fn flatten_unordered() { use futures::task::*; use std::convert::identity; use std::pin::Pin; + use std::sync::atomic::{AtomicBool, Ordering}; use std::thread; use std::time::Duration; @@ -339,6 +341,45 @@ fn flatten_unordered() { assert_eq!(values, (0..60).collect::>()); }); } + + // nested `flatten_unordered` + let te = ThreadPool::new().unwrap(); + let handle = te.spawn_with_handle(async move { + let inner = stream::iter(0..10) + .then(|_| { + let task = Arc::new(AtomicBool::new(false)); + let mut spawned = false; + + future::poll_fn(move |cx| { + if !spawned { + let waker = cx.waker().clone(); + let task = task.clone(); + + std::thread::spawn(move || { + std::thread::sleep(Duration::from_millis(500)); + task.store(true, Ordering::Release); + + waker.wake_by_ref() + }); + spawned = true; + } + + if task.load(Ordering::Acquire) { + Poll::Ready(Some(())) + } else { + Poll::Pending + } + }) + }) + .map(|_| stream::once(future::ready(()))) + .flatten_unordered(None); + + let stream = stream::once(future::ready(inner)).flatten_unordered(None); + + assert_eq!(stream.count().await, 10); + }).unwrap(); + + block_on(handle); } #[test]