Skip to content

Commit

Permalink
fu: always replace inner wakers
Browse files Browse the repository at this point in the history
  • Loading branch information
olegnn committed Mar 20, 2023
1 parent 8253b78 commit d94765b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 12 deletions.
18 changes: 6 additions & 12 deletions futures-util/src/stream/stream/flatten_unordered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,13 @@ where
}
};

// Safety: now state is `POLLING`.
let stream_waker = unsafe { WrappedWaker::replace_waker(this.stream_waker, cx) };
let inner_streams_waker =
unsafe { WrappedWaker::replace_waker(this.inner_streams_waker, cx) };

if poll_state_value & NEED_TO_POLL_STREAM != NONE {
let mut stream_waker = None;
let mut cx = Context::from_waker(&stream_waker);

// Here we need to poll the base stream.
//
Expand All @@ -431,14 +436,6 @@ where

break;
} else {
// Initialize base stream waker if it's not yet initialized
if stream_waker.is_none() {
// Safety: now state is `POLLING`.
stream_waker
.replace(unsafe { WrappedWaker::replace_waker(this.stream_waker, cx) });
}
let mut cx = Context::from_waker(stream_waker.as_ref().unwrap());

match this.stream.as_mut().poll_next(&mut cx) {
Poll::Ready(Some(item)) => {
let next_item_fut = match Fc::next_step(item) {
Expand Down Expand Up @@ -475,9 +472,6 @@ where
}

if poll_state_value & NEED_TO_POLL_INNER_STREAMS != NONE {
// Safety: now state is `POLLING`.
let inner_streams_waker =
unsafe { WrappedWaker::replace_waker(this.inner_streams_waker, cx) };
let mut cx = Context::from_waker(&inner_streams_waker);

match this.inner_streams.as_mut().poll_next(&mut cx) {
Expand Down
41 changes: 41 additions & 0 deletions futures/tests/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use futures::stream::{self, StreamExt};
use futures::task::Poll;
use futures::{ready, FutureExt};
use futures_core::Stream;
use futures_executor::ThreadPool;
use futures_test::task::noop_context;

#[test]
Expand Down Expand Up @@ -82,6 +83,7 @@ fn flatten_unordered() {
use futures::task::*;
use std::convert::identity;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;
use std::time::Duration;

Expand Down Expand Up @@ -339,6 +341,45 @@ fn flatten_unordered() {
assert_eq!(values, (0..60).collect::<Vec<u8>>());
});
}

// nested `flatten_unordered`
let te = ThreadPool::new().unwrap();
let handle = te.spawn_with_handle(async move {
let inner = stream::iter(0..10)
.then(|_| {
let task = Arc::new(AtomicBool::new(false));
let mut spawned = false;

future::poll_fn(move |cx| {
if !spawned {
let waker = cx.waker().clone();
let task = task.clone();

std::thread::spawn(move || {
std::thread::sleep(Duration::from_millis(500));
task.store(true, Ordering::Release);

waker.wake_by_ref()
});
spawned = true;
}

if task.load(Ordering::Acquire) {
Poll::Ready(Some(()))
} else {
Poll::Pending
}
})
})
.map(|_| stream::once(future::ready(())))
.flatten_unordered(None);

let stream = stream::once(future::ready(inner)).flatten_unordered(None);

assert_eq!(stream.count().await, 10);
}).unwrap();

block_on(handle);
}

#[test]
Expand Down

0 comments on commit d94765b

Please sign in to comment.