Skip to content

Commit a03ac56

Browse files
Zekun LiZekun Li
Zekun Li
authored and
Zekun Li
committed
Shared: fix shared futures missing wake up
1 parent 7211cb7 commit a03ac56

File tree

2 files changed

+116
-48
lines changed

2 files changed

+116
-48
lines changed

futures-util/src/future/future/shared.rs

+64-47
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::pin::Pin;
99
use std::ptr;
1010
use std::sync::atomic::AtomicUsize;
1111
use std::sync::atomic::Ordering::{Acquire, SeqCst};
12-
use std::sync::{Arc, Mutex, Weak};
12+
use std::sync::{Arc, Mutex, MutexGuard, Weak};
1313

1414
/// Future for the [`shared`](super::FutureExt::shared) method.
1515
#[must_use = "futures do nothing unless you `.await` or poll them"]
@@ -81,6 +81,7 @@ const IDLE: usize = 0;
8181
const POLLING: usize = 1;
8282
const COMPLETE: usize = 2;
8383
const POISONED: usize = 3;
84+
const WOKEN_DURING_POLLING: usize = 4;
8485

8586
const NULL_WAKER_KEY: usize = usize::MAX;
8687

@@ -197,36 +198,47 @@ where
197198
}
198199
}
199200

200-
impl<Fut> Inner<Fut>
201-
where
202-
Fut: Future,
203-
Fut::Output: Clone,
204-
{
205-
/// Registers the current task to receive a wakeup when we are awoken.
206-
fn record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>) {
207-
let mut wakers_guard = self.notifier.wakers.lock().unwrap();
208-
209-
let wakers_mut = wakers_guard.as_mut();
210-
211-
let wakers = match wakers_mut {
212-
Some(wakers) => wakers,
213-
None => return,
214-
};
215-
216-
let new_waker = cx.waker();
201+
/// Registers the current task to receive a wakeup when we are awoken.
202+
fn record_waker(
203+
wakers_guard: &mut MutexGuard<'_, Option<Slab<Option<Waker>>>>,
204+
waker_key: &mut usize,
205+
cx: &mut Context<'_>,
206+
) {
207+
let wakers = match wakers_guard.as_mut() {
208+
Some(wakers) => wakers,
209+
None => return,
210+
};
211+
212+
let new_waker = cx.waker();
213+
214+
if *waker_key == NULL_WAKER_KEY {
215+
*waker_key = wakers.insert(Some(new_waker.clone()));
216+
} else {
217+
match wakers[*waker_key] {
218+
Some(ref old_waker) if new_waker.will_wake(old_waker) => {}
219+
// Could use clone_from here, but Waker doesn't specialize it.
220+
ref mut slot => *slot = Some(new_waker.clone()),
221+
}
222+
}
223+
debug_assert!(*waker_key != NULL_WAKER_KEY);
224+
}
217225

218-
if *waker_key == NULL_WAKER_KEY {
219-
*waker_key = wakers.insert(Some(new_waker.clone()));
220-
} else {
221-
match wakers[*waker_key] {
222-
Some(ref old_waker) if new_waker.will_wake(old_waker) => {}
223-
// Could use clone_from here, but Waker doesn't specialize it.
224-
ref mut slot => *slot = Some(new_waker.clone()),
226+
/// Wakes all tasks that are registered to be woken.
227+
fn wake_all(waker_guard: &mut MutexGuard<'_, Option<Slab<Option<Waker>>>>) {
228+
if let Some(wakers) = waker_guard.as_mut() {
229+
for (_key, opt_waker) in wakers {
230+
if let Some(waker) = opt_waker.take() {
231+
waker.wake();
225232
}
226233
}
227-
debug_assert!(*waker_key != NULL_WAKER_KEY);
228234
}
235+
}
229236

237+
impl<Fut> Inner<Fut>
238+
where
239+
Fut: Future,
240+
Fut::Output: Clone,
241+
{
230242
/// Safety: callers must first ensure that `inner.state`
231243
/// is `COMPLETE`
232244
unsafe fn take_or_clone_output(self: Arc<Self>) -> Fut::Output {
@@ -268,18 +280,22 @@ where
268280
return unsafe { Poll::Ready(inner.take_or_clone_output()) };
269281
}
270282

271-
inner.record_waker(&mut this.waker_key, cx);
283+
// Guard the state transition with mutex too
284+
let mut wakers_guard = inner.notifier.wakers.lock().unwrap();
285+
record_waker(&mut wakers_guard, &mut this.waker_key, cx);
272286

273-
match inner
287+
let prev = inner
274288
.notifier
275289
.state
276290
.compare_exchange(IDLE, POLLING, SeqCst, SeqCst)
277-
.unwrap_or_else(|x| x)
278-
{
291+
.unwrap_or_else(|x| x);
292+
drop(wakers_guard);
293+
294+
match prev {
279295
IDLE => {
280296
// Lock acquired, fall through
281297
}
282-
POLLING => {
298+
POLLING | WOKEN_DURING_POLLING => {
283299
// Another task is currently polling, at this point we just want
284300
// to ensure that the waker for this task is registered
285301
this.inner = Some(inner);
@@ -324,15 +340,21 @@ where
324340

325341
match poll_result {
326342
Poll::Pending => {
327-
if inner.notifier.state.compare_exchange(POLLING, IDLE, SeqCst, SeqCst).is_ok()
328-
{
329-
// Success
330-
drop(reset);
331-
this.inner = Some(inner);
332-
return Poll::Pending;
333-
} else {
334-
unreachable!()
343+
match inner.notifier.state.compare_exchange(POLLING, IDLE, SeqCst, SeqCst) {
344+
Ok(POLLING) => {} // success
345+
Err(WOKEN_DURING_POLLING) => {
346+
// waker has been called inside future.poll, need to wake any new wakers registered
347+
let mut wakers = inner.notifier.wakers.lock().unwrap();
348+
wake_all(&mut wakers);
349+
let prev = inner.notifier.state.swap(IDLE, SeqCst);
350+
assert_eq!(prev, WOKEN_DURING_POLLING);
351+
drop(wakers);
352+
}
353+
_ => unreachable!(),
335354
}
355+
drop(reset);
356+
this.inner = Some(inner);
357+
return Poll::Pending;
336358
}
337359
Poll::Ready(output) => output,
338360
}
@@ -387,14 +409,9 @@ where
387409

388410
impl ArcWake for Notifier {
389411
fn wake_by_ref(arc_self: &Arc<Self>) {
390-
let wakers = &mut *arc_self.wakers.lock().unwrap();
391-
if let Some(wakers) = wakers.as_mut() {
392-
for (_key, opt_waker) in wakers {
393-
if let Some(waker) = opt_waker.take() {
394-
waker.wake();
395-
}
396-
}
397-
}
412+
let mut wakers = arc_self.wakers.lock().unwrap();
413+
let _ = arc_self.state.compare_exchange(POLLING, WOKEN_DURING_POLLING, SeqCst, SeqCst);
414+
wake_all(&mut wakers);
398415
}
399416
}
400417

futures/tests/future_shared.rs

+52-1
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ use futures::executor::{block_on, LocalPool};
33
use futures::future::{self, FutureExt, LocalFutureObj, TryFutureExt};
44
use futures::task::LocalSpawn;
55
use std::cell::{Cell, RefCell};
6+
use std::future::Future;
67
use std::panic::AssertUnwindSafe;
8+
use std::pin::Pin;
79
use std::rc::Rc;
8-
use std::task::Poll;
10+
use std::task::{Context, Poll};
911
use std::thread;
1012

1113
struct CountClone(Rc<Cell<i32>>);
@@ -271,3 +273,52 @@ fn poll_while_panic() {
271273
let _s = S {};
272274
panic!("test_marker");
273275
}
276+
277+
#[test]
278+
fn shared_futures_woken_during_polling() {
279+
async fn yield_now() {
280+
/// Yield implementation
281+
struct YieldNow {
282+
yielded: bool,
283+
}
284+
285+
impl Future for YieldNow {
286+
type Output = ();
287+
288+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
289+
if self.yielded {
290+
return Poll::Ready(());
291+
}
292+
293+
self.yielded = true;
294+
cx.waker().wake_by_ref();
295+
Poll::Pending
296+
}
297+
}
298+
299+
YieldNow { yielded: false }.await
300+
}
301+
fn test() {
302+
let f1 = yield_now().shared();
303+
let f2 = f1.clone();
304+
let x1 = thread::spawn(move || {
305+
block_on(async move {
306+
f1.now_or_never();
307+
})
308+
});
309+
let x2 = thread::spawn(move || {
310+
block_on(async move {
311+
f2.await;
312+
})
313+
});
314+
x1.join().ok();
315+
x2.join().ok();
316+
}
317+
318+
for _ in 0..10 {
319+
print!(".");
320+
for _ in 0..10000 {
321+
test();
322+
}
323+
}
324+
}

0 commit comments

Comments
 (0)