From e15ba66a4d08acca8ca47785db1c666b74a73bb4 Mon Sep 17 00:00:00 2001 From: Taiki Endo Date: Mon, 26 Feb 2024 23:27:39 +0900 Subject: [PATCH] Add 'static bound to waker_ref --- futures-task/src/waker.rs | 12 +-- futures-task/src/waker_ref.rs | 2 +- .../src/stream/futures_unordered/mod.rs | 3 +- .../src/stream/futures_unordered/task.rs | 95 ++++++++++++++++++- 4 files changed, 101 insertions(+), 11 deletions(-) diff --git a/futures-task/src/waker.rs b/futures-task/src/waker.rs index 4569c8581a..b1410a1927 100644 --- a/futures-task/src/waker.rs +++ b/futures-task/src/waker.rs @@ -3,7 +3,7 @@ use alloc::sync::Arc; use core::mem; use core::task::{RawWaker, RawWakerVTable, Waker}; -pub(super) fn waker_vtable() -> &'static RawWakerVTable { +pub(super) fn waker_vtable() -> &'static RawWakerVTable { &RawWakerVTable::new( clone_arc_raw::, wake_arc_raw::, @@ -28,7 +28,7 @@ where // FIXME: panics on Arc::clone / refcount changes could wreak havoc on the // code here. We should guard against this by aborting. -unsafe fn increase_refcount(data: *const ()) { +unsafe fn increase_refcount(data: *const ()) { // Retain Arc, but don't touch refcount by wrapping in ManuallyDrop let arc = mem::ManuallyDrop::new(unsafe { Arc::::from_raw(data.cast::()) }); // Now increase refcount, but don't drop new refcount either @@ -36,23 +36,23 @@ unsafe fn increase_refcount(data: *const ()) { } // used by `waker_ref` -unsafe fn clone_arc_raw(data: *const ()) -> RawWaker { +unsafe fn clone_arc_raw(data: *const ()) -> RawWaker { unsafe { increase_refcount::(data) } RawWaker::new(data, waker_vtable::()) } -unsafe fn wake_arc_raw(data: *const ()) { +unsafe fn wake_arc_raw(data: *const ()) { let arc: Arc = unsafe { Arc::from_raw(data.cast::()) }; ArcWake::wake(arc); } // used by `waker_ref` -unsafe fn wake_by_ref_arc_raw(data: *const ()) { +unsafe fn wake_by_ref_arc_raw(data: *const ()) { // Retain Arc, but don't touch refcount by wrapping in ManuallyDrop let arc = mem::ManuallyDrop::new(unsafe { Arc::::from_raw(data.cast::()) }); ArcWake::wake_by_ref(&arc); } -unsafe fn drop_arc_raw(data: *const ()) { +unsafe fn drop_arc_raw(data: *const ()) { drop(unsafe { Arc::::from_raw(data.cast::()) }) } diff --git a/futures-task/src/waker_ref.rs b/futures-task/src/waker_ref.rs index aac4109577..5957b4d46a 100644 --- a/futures-task/src/waker_ref.rs +++ b/futures-task/src/waker_ref.rs @@ -54,7 +54,7 @@ impl Deref for WakerRef<'_> { #[inline] pub fn waker_ref(wake: &Arc) -> WakerRef<'_> where - W: ArcWake, + W: ArcWake + 'static, { // simply copy the pointer instead of using Arc::into_raw, // as we don't actually keep a refcount by using ManuallyDrop.< diff --git a/futures-util/src/stream/futures_unordered/mod.rs b/futures-util/src/stream/futures_unordered/mod.rs index bcb0962ef4..2d4f15158f 100644 --- a/futures-util/src/stream/futures_unordered/mod.rs +++ b/futures-util/src/stream/futures_unordered/mod.rs @@ -510,7 +510,8 @@ impl Stream for FuturesUnordered { // We are only interested in whether the future is awoken before it // finishes polling, so reset the flag here. task.woken.store(false, Relaxed); - let waker = Task::waker_ref(task); + // SAFETY: see the comments of Bomb and this block. + let waker = unsafe { Task::waker_ref(task) }; let mut cx = Context::from_waker(&waker); // Safety: We won't move the future ever again diff --git a/futures-util/src/stream/futures_unordered/task.rs b/futures-util/src/stream/futures_unordered/task.rs index ec2114effa..2ae4cb7d85 100644 --- a/futures-util/src/stream/futures_unordered/task.rs +++ b/futures-util/src/stream/futures_unordered/task.rs @@ -5,7 +5,7 @@ use core::sync::atomic::{AtomicBool, AtomicPtr}; use super::abort::abort; use super::ReadyToRunQueue; -use crate::task::{waker_ref, ArcWake, WakerRef}; +use crate::task::ArcWake; pub(super) struct Task { // The future @@ -77,8 +77,8 @@ impl ArcWake for Task { impl Task { /// Returns a waker reference for this task without cloning the Arc. - pub(super) fn waker_ref(this: &Arc) -> WakerRef<'_> { - waker_ref(this) + pub(super) unsafe fn waker_ref(this: &Arc) -> waker_ref::WakerRef<'_> { + unsafe { waker_ref::waker_ref(this) } } /// Spins until `next_all` is no longer set to `pending_next_all`. @@ -123,3 +123,92 @@ impl Drop for Task { } } } + +mod waker_ref { + use alloc::sync::Arc; + use core::marker::PhantomData; + use core::mem; + use core::mem::ManuallyDrop; + use core::ops::Deref; + use core::task::{RawWaker, RawWakerVTable, Waker}; + use futures_task::ArcWake; + + pub(crate) struct WakerRef<'a> { + waker: ManuallyDrop, + _marker: PhantomData<&'a ()>, + } + + impl WakerRef<'_> { + #[inline] + fn new_unowned(waker: ManuallyDrop) -> Self { + Self { waker, _marker: PhantomData } + } + } + + impl Deref for WakerRef<'_> { + type Target = Waker; + + #[inline] + fn deref(&self) -> &Waker { + &self.waker + } + } + + /// Copy of `future_task::waker_ref` without `W: 'static` bound. + /// + /// # Safety + /// + /// The caller must guarantee that use-after-free will not occur. + #[inline] + pub(crate) unsafe fn waker_ref(wake: &Arc) -> WakerRef<'_> + where + W: ArcWake, + { + // simply copy the pointer instead of using Arc::into_raw, + // as we don't actually keep a refcount by using ManuallyDrop.< + let ptr = Arc::as_ptr(wake).cast::<()>(); + + let waker = + ManuallyDrop::new(unsafe { Waker::from_raw(RawWaker::new(ptr, waker_vtable::())) }); + WakerRef::new_unowned(waker) + } + + fn waker_vtable() -> &'static RawWakerVTable { + &RawWakerVTable::new( + clone_arc_raw::, + wake_arc_raw::, + wake_by_ref_arc_raw::, + drop_arc_raw::, + ) + } + + // FIXME: panics on Arc::clone / refcount changes could wreak havoc on the + // code here. We should guard against this by aborting. + + unsafe fn increase_refcount(data: *const ()) { + // Retain Arc, but don't touch refcount by wrapping in ManuallyDrop + let arc = mem::ManuallyDrop::new(unsafe { Arc::::from_raw(data.cast::()) }); + // Now increase refcount, but don't drop new refcount either + let _arc_clone: mem::ManuallyDrop<_> = arc.clone(); + } + + unsafe fn clone_arc_raw(data: *const ()) -> RawWaker { + unsafe { increase_refcount::(data) } + RawWaker::new(data, waker_vtable::()) + } + + unsafe fn wake_arc_raw(data: *const ()) { + let arc: Arc = unsafe { Arc::from_raw(data.cast::()) }; + ArcWake::wake(arc); + } + + unsafe fn wake_by_ref_arc_raw(data: *const ()) { + // Retain Arc, but don't touch refcount by wrapping in ManuallyDrop + let arc = mem::ManuallyDrop::new(unsafe { Arc::::from_raw(data.cast::()) }); + ArcWake::wake_by_ref(&arc); + } + + unsafe fn drop_arc_raw(data: *const ()) { + drop(unsafe { Arc::::from_raw(data.cast::()) }) + } +}