From d94765b9faaaeabd6700bb6b47e762309b32ba19 Mon Sep 17 00:00:00 2001 From: olegnn Date: Mon, 20 Mar 2023 21:05:00 +0100 Subject: [PATCH 1/3] fu: always replace inner wakers --- .../src/stream/stream/flatten_unordered.rs | 18 +++----- futures/tests/stream.rs | 41 +++++++++++++++++++ 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/futures-util/src/stream/stream/flatten_unordered.rs b/futures-util/src/stream/stream/flatten_unordered.rs index 484c3733aa..43b4d092c4 100644 --- a/futures-util/src/stream/stream/flatten_unordered.rs +++ b/futures-util/src/stream/stream/flatten_unordered.rs @@ -414,8 +414,13 @@ where } }; + // Safety: now state is `POLLING`. + let stream_waker = unsafe { WrappedWaker::replace_waker(this.stream_waker, cx) }; + let inner_streams_waker = + unsafe { WrappedWaker::replace_waker(this.inner_streams_waker, cx) }; + if poll_state_value & NEED_TO_POLL_STREAM != NONE { - let mut stream_waker = None; + let mut cx = Context::from_waker(&stream_waker); // Here we need to poll the base stream. // @@ -431,14 +436,6 @@ where break; } else { - // Initialize base stream waker if it's not yet initialized - if stream_waker.is_none() { - // Safety: now state is `POLLING`. - stream_waker - .replace(unsafe { WrappedWaker::replace_waker(this.stream_waker, cx) }); - } - let mut cx = Context::from_waker(stream_waker.as_ref().unwrap()); - match this.stream.as_mut().poll_next(&mut cx) { Poll::Ready(Some(item)) => { let next_item_fut = match Fc::next_step(item) { @@ -475,9 +472,6 @@ where } if poll_state_value & NEED_TO_POLL_INNER_STREAMS != NONE { - // Safety: now state is `POLLING`. - let inner_streams_waker = - unsafe { WrappedWaker::replace_waker(this.inner_streams_waker, cx) }; let mut cx = Context::from_waker(&inner_streams_waker); match this.inner_streams.as_mut().poll_next(&mut cx) { diff --git a/futures/tests/stream.rs b/futures/tests/stream.rs index 6a302b798e..c5e0bd4ac2 100644 --- a/futures/tests/stream.rs +++ b/futures/tests/stream.rs @@ -14,6 +14,7 @@ use futures::stream::{self, StreamExt}; use futures::task::Poll; use futures::{ready, FutureExt}; use futures_core::Stream; +use futures_executor::ThreadPool; use futures_test::task::noop_context; #[test] @@ -82,6 +83,7 @@ fn flatten_unordered() { use futures::task::*; use std::convert::identity; use std::pin::Pin; + use std::sync::atomic::{AtomicBool, Ordering}; use std::thread; use std::time::Duration; @@ -339,6 +341,45 @@ fn flatten_unordered() { assert_eq!(values, (0..60).collect::>()); }); } + + // nested `flatten_unordered` + let te = ThreadPool::new().unwrap(); + let handle = te.spawn_with_handle(async move { + let inner = stream::iter(0..10) + .then(|_| { + let task = Arc::new(AtomicBool::new(false)); + let mut spawned = false; + + future::poll_fn(move |cx| { + if !spawned { + let waker = cx.waker().clone(); + let task = task.clone(); + + std::thread::spawn(move || { + std::thread::sleep(Duration::from_millis(500)); + task.store(true, Ordering::Release); + + waker.wake_by_ref() + }); + spawned = true; + } + + if task.load(Ordering::Acquire) { + Poll::Ready(Some(())) + } else { + Poll::Pending + } + }) + }) + .map(|_| stream::once(future::ready(()))) + .flatten_unordered(None); + + let stream = stream::once(future::ready(inner)).flatten_unordered(None); + + assert_eq!(stream.count().await, 10); + }).unwrap(); + + block_on(handle); } #[test] From 7e49a401c999f5371b9d9d2c74a89941a19b689f Mon Sep 17 00:00:00 2001 From: olegnn Date: Mon, 20 Mar 2023 21:39:31 +0100 Subject: [PATCH 2/3] Fix tests --- futures/tests/no-std/src/lib.rs | 1 + futures/tests/stream.rs | 64 +++++++++++++++++---------------- 2 files changed, 34 insertions(+), 31 deletions(-) diff --git a/futures/tests/no-std/src/lib.rs b/futures/tests/no-std/src/lib.rs index 89a8fa1ff1..4e2196860b 100644 --- a/futures/tests/no-std/src/lib.rs +++ b/futures/tests/no-std/src/lib.rs @@ -1,5 +1,6 @@ #![cfg(nightly)] #![no_std] +#![allow(useless_anonymous_reexport)] #[cfg(feature = "futures-core-alloc")] #[cfg(target_has_atomic = "ptr")] diff --git a/futures/tests/stream.rs b/futures/tests/stream.rs index c5e0bd4ac2..83f49579bc 100644 --- a/futures/tests/stream.rs +++ b/futures/tests/stream.rs @@ -344,40 +344,42 @@ fn flatten_unordered() { // nested `flatten_unordered` let te = ThreadPool::new().unwrap(); - let handle = te.spawn_with_handle(async move { - let inner = stream::iter(0..10) - .then(|_| { - let task = Arc::new(AtomicBool::new(false)); - let mut spawned = false; - - future::poll_fn(move |cx| { - if !spawned { - let waker = cx.waker().clone(); - let task = task.clone(); - - std::thread::spawn(move || { - std::thread::sleep(Duration::from_millis(500)); - task.store(true, Ordering::Release); - - waker.wake_by_ref() - }); - spawned = true; - } - - if task.load(Ordering::Acquire) { - Poll::Ready(Some(())) - } else { - Poll::Pending - } + let handle = te + .spawn_with_handle(async move { + let inner = stream::iter(0..10) + .then(|_| { + let task = Arc::new(AtomicBool::new(false)); + let mut spawned = false; + + future::poll_fn(move |cx| { + if !spawned { + let waker = cx.waker().clone(); + let task = task.clone(); + + std::thread::spawn(move || { + std::thread::sleep(Duration::from_millis(500)); + task.store(true, Ordering::Release); + + waker.wake_by_ref() + }); + spawned = true; + } + + if task.load(Ordering::Acquire) { + Poll::Ready(Some(())) + } else { + Poll::Pending + } + }) }) - }) - .map(|_| stream::once(future::ready(()))) - .flatten_unordered(None); + .map(|_| stream::once(future::ready(()))) + .flatten_unordered(None); - let stream = stream::once(future::ready(inner)).flatten_unordered(None); + let stream = stream::once(future::ready(inner)).flatten_unordered(None); - assert_eq!(stream.count().await, 10); - }).unwrap(); + assert_eq!(stream.count().await, 10); + }) + .unwrap(); block_on(handle); } From d4413458cde9361cc1774d82300f424f9cf41ca3 Mon Sep 17 00:00:00 2001 From: olegnn Date: Tue, 21 Mar 2023 00:21:38 +0100 Subject: [PATCH 3/3] Avoid unnecessary clones --- .../src/stream/stream/flatten_unordered.rs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/futures-util/src/stream/stream/flatten_unordered.rs b/futures-util/src/stream/stream/flatten_unordered.rs index 43b4d092c4..f5430bc309 100644 --- a/futures-util/src/stream/stream/flatten_unordered.rs +++ b/futures-util/src/stream/stream/flatten_unordered.rs @@ -209,9 +209,8 @@ impl WrappedWaker { /// /// This function will modify waker's `inner_waker` via `UnsafeCell`, so /// it should be used only during `POLLING` phase by one thread at the time. - unsafe fn replace_waker(self_arc: &mut Arc, cx: &Context<'_>) -> Waker { + unsafe fn replace_waker(self_arc: &mut Arc, cx: &Context<'_>) { *self_arc.inner_waker.get() = cx.waker().clone().into(); - waker(self_arc.clone()) } /// Attempts to start the waking process for the waker with the given value. @@ -415,12 +414,13 @@ where }; // Safety: now state is `POLLING`. - let stream_waker = unsafe { WrappedWaker::replace_waker(this.stream_waker, cx) }; - let inner_streams_waker = - unsafe { WrappedWaker::replace_waker(this.inner_streams_waker, cx) }; + unsafe { + WrappedWaker::replace_waker(this.stream_waker, cx); + WrappedWaker::replace_waker(this.inner_streams_waker, cx) + }; if poll_state_value & NEED_TO_POLL_STREAM != NONE { - let mut cx = Context::from_waker(&stream_waker); + let mut stream_waker = None; // Here we need to poll the base stream. // @@ -436,6 +436,10 @@ where break; } else { + let mut cx = Context::from_waker( + stream_waker.get_or_insert_with(|| waker(this.stream_waker.clone())), + ); + match this.stream.as_mut().poll_next(&mut cx) { Poll::Ready(Some(item)) => { let next_item_fut = match Fc::next_step(item) { @@ -472,6 +476,7 @@ where } if poll_state_value & NEED_TO_POLL_INNER_STREAMS != NONE { + let inner_streams_waker = waker(this.inner_streams_waker.clone()); let mut cx = Context::from_waker(&inner_streams_waker); match this.inner_streams.as_mut().poll_next(&mut cx) {