@@ -9,7 +9,7 @@ use std::pin::Pin;
9
9
use std:: ptr;
10
10
use std:: sync:: atomic:: AtomicUsize ;
11
11
use std:: sync:: atomic:: Ordering :: { Acquire , SeqCst } ;
12
- use std:: sync:: { Arc , Mutex , Weak } ;
12
+ use std:: sync:: { Arc , Mutex , MutexGuard , Weak } ;
13
13
14
14
/// Future for the [`shared`](super::FutureExt::shared) method.
15
15
#[ must_use = "futures do nothing unless you `.await` or poll them" ]
@@ -81,6 +81,7 @@ const IDLE: usize = 0;
81
81
const POLLING : usize = 1 ;
82
82
const COMPLETE : usize = 2 ;
83
83
const POISONED : usize = 3 ;
84
+ const WOKEN_DURING_POLLING : usize = 4 ;
84
85
85
86
const NULL_WAKER_KEY : usize = usize:: MAX ;
86
87
@@ -197,36 +198,47 @@ where
197
198
}
198
199
}
199
200
200
- impl < Fut > Inner < Fut >
201
- where
202
- Fut : Future ,
203
- Fut :: Output : Clone ,
204
- {
205
- /// Registers the current task to receive a wakeup when we are awoken.
206
- fn record_waker ( & self , waker_key : & mut usize , cx : & mut Context < ' _ > ) {
207
- let mut wakers_guard = self . notifier . wakers . lock ( ) . unwrap ( ) ;
208
-
209
- let wakers_mut = wakers_guard. as_mut ( ) ;
210
-
211
- let wakers = match wakers_mut {
212
- Some ( wakers) => wakers,
213
- None => return ,
214
- } ;
215
-
216
- let new_waker = cx. waker ( ) ;
201
+ /// Registers the current task to receive a wakeup when we are awoken.
202
+ fn record_waker (
203
+ wakers_guard : & mut MutexGuard < ' _ , Option < Slab < Option < Waker > > > > ,
204
+ waker_key : & mut usize ,
205
+ cx : & mut Context < ' _ > ,
206
+ ) {
207
+ let wakers = match wakers_guard. as_mut ( ) {
208
+ Some ( wakers) => wakers,
209
+ None => return ,
210
+ } ;
211
+
212
+ let new_waker = cx. waker ( ) ;
213
+
214
+ if * waker_key == NULL_WAKER_KEY {
215
+ * waker_key = wakers. insert ( Some ( new_waker. clone ( ) ) ) ;
216
+ } else {
217
+ match wakers[ * waker_key] {
218
+ Some ( ref old_waker) if new_waker. will_wake ( old_waker) => { }
219
+ // Could use clone_from here, but Waker doesn't specialize it.
220
+ ref mut slot => * slot = Some ( new_waker. clone ( ) ) ,
221
+ }
222
+ }
223
+ debug_assert ! ( * waker_key != NULL_WAKER_KEY ) ;
224
+ }
217
225
218
- if * waker_key == NULL_WAKER_KEY {
219
- * waker_key = wakers. insert ( Some ( new_waker. clone ( ) ) ) ;
220
- } else {
221
- match wakers[ * waker_key] {
222
- Some ( ref old_waker) if new_waker. will_wake ( old_waker) => { }
223
- // Could use clone_from here, but Waker doesn't specialize it.
224
- ref mut slot => * slot = Some ( new_waker. clone ( ) ) ,
226
+ /// Wakes all tasks that are registered to be woken.
227
+ fn wake_all ( waker_guard : & mut MutexGuard < ' _ , Option < Slab < Option < Waker > > > > ) {
228
+ if let Some ( wakers) = waker_guard. as_mut ( ) {
229
+ for ( _key, opt_waker) in wakers {
230
+ if let Some ( waker) = opt_waker. take ( ) {
231
+ waker. wake ( ) ;
225
232
}
226
233
}
227
- debug_assert ! ( * waker_key != NULL_WAKER_KEY ) ;
228
234
}
235
+ }
229
236
237
+ impl < Fut > Inner < Fut >
238
+ where
239
+ Fut : Future ,
240
+ Fut :: Output : Clone ,
241
+ {
230
242
/// Safety: callers must first ensure that `inner.state`
231
243
/// is `COMPLETE`
232
244
unsafe fn take_or_clone_output ( self : Arc < Self > ) -> Fut :: Output {
@@ -268,18 +280,22 @@ where
268
280
return unsafe { Poll :: Ready ( inner. take_or_clone_output ( ) ) } ;
269
281
}
270
282
271
- inner. record_waker ( & mut this. waker_key , cx) ;
283
+ // Guard the state transition with mutex too
284
+ let mut wakers_guard = inner. notifier . wakers . lock ( ) . unwrap ( ) ;
285
+ record_waker ( & mut wakers_guard, & mut this. waker_key , cx) ;
272
286
273
- match inner
287
+ let prev = inner
274
288
. notifier
275
289
. state
276
290
. compare_exchange ( IDLE , POLLING , SeqCst , SeqCst )
277
- . unwrap_or_else ( |x| x)
278
- {
291
+ . unwrap_or_else ( |x| x) ;
292
+ drop ( wakers_guard) ;
293
+
294
+ match prev {
279
295
IDLE => {
280
296
// Lock acquired, fall through
281
297
}
282
- POLLING => {
298
+ POLLING | WOKEN_DURING_POLLING => {
283
299
// Another task is currently polling, at this point we just want
284
300
// to ensure that the waker for this task is registered
285
301
this. inner = Some ( inner) ;
@@ -324,15 +340,21 @@ where
324
340
325
341
match poll_result {
326
342
Poll :: Pending => {
327
- if inner. notifier . state . compare_exchange ( POLLING , IDLE , SeqCst , SeqCst ) . is_ok ( )
328
- {
329
- // Success
330
- drop ( reset) ;
331
- this. inner = Some ( inner) ;
332
- return Poll :: Pending ;
333
- } else {
334
- unreachable ! ( )
343
+ match inner. notifier . state . compare_exchange ( POLLING , IDLE , SeqCst , SeqCst ) {
344
+ Ok ( POLLING ) => { } // success
345
+ Err ( WOKEN_DURING_POLLING ) => {
346
+ // waker has been called inside future.poll, need to wake any new wakers registered
347
+ let mut wakers = inner. notifier . wakers . lock ( ) . unwrap ( ) ;
348
+ wake_all ( & mut wakers) ;
349
+ let prev = inner. notifier . state . swap ( IDLE , SeqCst ) ;
350
+ assert_eq ! ( prev, WOKEN_DURING_POLLING ) ;
351
+ drop ( wakers) ;
352
+ }
353
+ _ => unreachable ! ( ) ,
335
354
}
355
+ drop ( reset) ;
356
+ this. inner = Some ( inner) ;
357
+ return Poll :: Pending ;
336
358
}
337
359
Poll :: Ready ( output) => output,
338
360
}
@@ -387,14 +409,9 @@ where
387
409
388
410
impl ArcWake for Notifier {
389
411
fn wake_by_ref ( arc_self : & Arc < Self > ) {
390
- let wakers = & mut * arc_self. wakers . lock ( ) . unwrap ( ) ;
391
- if let Some ( wakers) = wakers. as_mut ( ) {
392
- for ( _key, opt_waker) in wakers {
393
- if let Some ( waker) = opt_waker. take ( ) {
394
- waker. wake ( ) ;
395
- }
396
- }
397
- }
412
+ let mut wakers = arc_self. wakers . lock ( ) . unwrap ( ) ;
413
+ let _ = arc_self. state . compare_exchange ( POLLING , WOKEN_DURING_POLLING , SeqCst , SeqCst ) ;
414
+ wake_all ( & mut wakers) ;
398
415
}
399
416
}
400
417
0 commit comments