From a730a19c8f933254a9d571a153275e3d96391ab3 Mon Sep 17 00:00:00 2001 From: Oleg Nosov Date: Wed, 22 Mar 2023 16:50:45 +0100 Subject: [PATCH] FlattenUnordered: always replace inner wakers (#2726) --- .../src/stream/stream/flatten_unordered.rs | 23 +++++----- futures/tests/no-std/src/lib.rs | 1 + futures/tests/stream.rs | 43 +++++++++++++++++++ 3 files changed, 55 insertions(+), 12 deletions(-) diff --git a/futures-util/src/stream/stream/flatten_unordered.rs b/futures-util/src/stream/stream/flatten_unordered.rs index 484c3733aa..f5430bc309 100644 --- a/futures-util/src/stream/stream/flatten_unordered.rs +++ b/futures-util/src/stream/stream/flatten_unordered.rs @@ -209,9 +209,8 @@ impl WrappedWaker { /// /// This function will modify waker's `inner_waker` via `UnsafeCell`, so /// it should be used only during `POLLING` phase by one thread at the time. - unsafe fn replace_waker(self_arc: &mut Arc, cx: &Context<'_>) -> Waker { + unsafe fn replace_waker(self_arc: &mut Arc, cx: &Context<'_>) { *self_arc.inner_waker.get() = cx.waker().clone().into(); - waker(self_arc.clone()) } /// Attempts to start the waking process for the waker with the given value. @@ -414,6 +413,12 @@ where } }; + // Safety: now state is `POLLING`. + unsafe { + WrappedWaker::replace_waker(this.stream_waker, cx); + WrappedWaker::replace_waker(this.inner_streams_waker, cx) + }; + if poll_state_value & NEED_TO_POLL_STREAM != NONE { let mut stream_waker = None; @@ -431,13 +436,9 @@ where break; } else { - // Initialize base stream waker if it's not yet initialized - if stream_waker.is_none() { - // Safety: now state is `POLLING`. - stream_waker - .replace(unsafe { WrappedWaker::replace_waker(this.stream_waker, cx) }); - } - let mut cx = Context::from_waker(stream_waker.as_ref().unwrap()); + let mut cx = Context::from_waker( + stream_waker.get_or_insert_with(|| waker(this.stream_waker.clone())), + ); match this.stream.as_mut().poll_next(&mut cx) { Poll::Ready(Some(item)) => { @@ -475,9 +476,7 @@ where } if poll_state_value & NEED_TO_POLL_INNER_STREAMS != NONE { - // Safety: now state is `POLLING`. - let inner_streams_waker = - unsafe { WrappedWaker::replace_waker(this.inner_streams_waker, cx) }; + let inner_streams_waker = waker(this.inner_streams_waker.clone()); let mut cx = Context::from_waker(&inner_streams_waker); match this.inner_streams.as_mut().poll_next(&mut cx) { diff --git a/futures/tests/no-std/src/lib.rs b/futures/tests/no-std/src/lib.rs index 89a8fa1ff1..4e2196860b 100644 --- a/futures/tests/no-std/src/lib.rs +++ b/futures/tests/no-std/src/lib.rs @@ -1,5 +1,6 @@ #![cfg(nightly)] #![no_std] +#![allow(useless_anonymous_reexport)] #[cfg(feature = "futures-core-alloc")] #[cfg(target_has_atomic = "ptr")] diff --git a/futures/tests/stream.rs b/futures/tests/stream.rs index 5cde45833f..9d61cb60c8 100644 --- a/futures/tests/stream.rs +++ b/futures/tests/stream.rs @@ -14,6 +14,7 @@ use futures::stream::{self, StreamExt}; use futures::task::Poll; use futures::{ready, FutureExt}; use futures_core::Stream; +use futures_executor::ThreadPool; use futures_test::task::noop_context; #[test] @@ -65,6 +66,7 @@ fn flatten_unordered() { use futures::task::*; use std::convert::identity; use std::pin::Pin; + use std::sync::atomic::{AtomicBool, Ordering}; use std::thread; use std::time::Duration; @@ -322,6 +324,47 @@ fn flatten_unordered() { assert_eq!(values, (0..60).collect::>()); }); } + + // nested `flatten_unordered` + let te = ThreadPool::new().unwrap(); + let handle = te + .spawn_with_handle(async move { + let inner = stream::iter(0..10) + .then(|_| { + let task = Arc::new(AtomicBool::new(false)); + let mut spawned = false; + + future::poll_fn(move |cx| { + if !spawned { + let waker = cx.waker().clone(); + let task = task.clone(); + + std::thread::spawn(move || { + std::thread::sleep(Duration::from_millis(500)); + task.store(true, Ordering::Release); + + waker.wake_by_ref() + }); + spawned = true; + } + + if task.load(Ordering::Acquire) { + Poll::Ready(Some(())) + } else { + Poll::Pending + } + }) + }) + .map(|_| stream::once(future::ready(()))) + .flatten_unordered(None); + + let stream = stream::once(future::ready(inner)).flatten_unordered(None); + + assert_eq!(stream.count().await, 10); + }) + .unwrap(); + + block_on(handle); } #[test]