Skip to content

Commit

Permalink
Don't ignore empty state polling (#2728)
Browse files Browse the repository at this point in the history
* Don't ignore empty state polling

* Test case

* Start polling in a loop to ensure we don't wait for an outdated waker
  • Loading branch information
olegnn authored and taiki-e committed Mar 30, 2023
1 parent 472f556 commit 8e24adc
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 39 deletions.
13 changes: 6 additions & 7 deletions futures-util/src/stream/stream/flatten_unordered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ impl SharedPollState {
}

/// Attempts to start polling, returning stored state in case of success.
/// Returns `None` if either waker is waking at the moment or state is empty.
/// Returns `None` if either waker is waking at the moment.
fn start_polling(
&self,
) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> {
let value = self
.state
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
if value & WAKING == NONE && value & NEED_TO_POLL_ALL != NONE {
if value & WAKING == NONE {
Some(POLLING)
} else {
None
Expand Down Expand Up @@ -405,11 +405,10 @@ where

let mut this = self.as_mut().project();

let (mut poll_state_value, state_bomb) = match this.poll_state.start_polling() {
Some(value) => value,
_ => {
// Waker was called, just wait for the next poll
return Poll::Pending;
// Attempt to start polling, in case some waker is holding the lock, wait in loop
let (mut poll_state_value, state_bomb) = loop {
if let Some(value) = this.poll_state.start_polling() {
break value;
}
};

Expand Down
95 changes: 63 additions & 32 deletions futures/tests/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,46 +325,77 @@ fn flatten_unordered() {
});
}

fn timeout<I: Clone>(time: Duration, value: I) -> impl Future<Output = I> {
let ready = Arc::new(AtomicBool::new(false));
let mut spawned = false;

future::poll_fn(move |cx| {
if !spawned {
let waker = cx.waker().clone();
let ready = ready.clone();

std::thread::spawn(move || {
std::thread::sleep(time);
ready.store(true, Ordering::Release);

waker.wake_by_ref()
});
spawned = true;
}

if ready.load(Ordering::Acquire) {
Poll::Ready(value.clone())
} else {
Poll::Pending
}
})
}

fn build_nested_fu<S: Stream + Unpin>(st: S) -> impl Stream<Item = S::Item> + Unpin
where
S::Item: Clone,
{
let inner = st
.then(|item| timeout(Duration::from_millis(50), item))
.enumerate()
.map(|(idx, value)| {
stream::once(if idx % 2 == 0 {
future::ready(value).left_future()
} else {
timeout(Duration::from_millis(100), value).right_future()
})
})
.flatten_unordered(None);

stream::once(future::ready(inner)).flatten_unordered(None)
}

// nested `flatten_unordered`
let te = ThreadPool::new().unwrap();
let handle = te
let base_handle = te
.spawn_with_handle(async move {
let inner = stream::iter(0..10)
.then(|_| {
let task = Arc::new(AtomicBool::new(false));
let mut spawned = false;

future::poll_fn(move |cx| {
if !spawned {
let waker = cx.waker().clone();
let task = task.clone();

std::thread::spawn(move || {
std::thread::sleep(Duration::from_millis(500));
task.store(true, Ordering::Release);

waker.wake_by_ref()
});
spawned = true;
}

if task.load(Ordering::Acquire) {
Poll::Ready(Some(()))
} else {
Poll::Pending
}
})
})
.map(|_| stream::once(future::ready(())))
.flatten_unordered(None);
let fu = build_nested_fu(stream::iter(1..=10));

let stream = stream::once(future::ready(inner)).flatten_unordered(None);
assert_eq!(fu.count().await, 10);
})
.unwrap();

block_on(base_handle);

let empty_state_move_handle = te
.spawn_with_handle(async move {
let mut fu = build_nested_fu(stream::iter(1..10));
{
let mut cx = noop_context();
let _ = fu.poll_next_unpin(&mut cx);
let _ = fu.poll_next_unpin(&mut cx);
}

assert_eq!(stream.count().await, 10);
assert_eq!(fu.count().await, 9);
})
.unwrap();

block_on(handle);
block_on(empty_state_move_handle);
}

#[test]
Expand Down

0 comments on commit 8e24adc

Please sign in to comment.