diff --git a/futures-util/src/lock/bilock.rs b/futures-util/src/lock/bilock.rs index 2174079c83..7ddc66ad2c 100644 --- a/futures-util/src/lock/bilock.rs +++ b/futures-util/src/lock/bilock.rs @@ -3,11 +3,11 @@ use alloc::boxed::Box; use alloc::sync::Arc; use core::cell::UnsafeCell; -use core::fmt; use core::ops::{Deref, DerefMut}; use core::pin::Pin; -use core::sync::atomic::AtomicUsize; +use core::sync::atomic::AtomicPtr; use core::sync::atomic::Ordering::SeqCst; +use core::{fmt, ptr}; #[cfg(feature = "bilock")] use futures_core::future::Future; use futures_core::task::{Context, Poll, Waker}; @@ -41,7 +41,7 @@ pub struct BiLock { #[derive(Debug)] struct Inner { - state: AtomicUsize, + state: AtomicPtr, value: Option>, } @@ -61,7 +61,10 @@ impl BiLock { /// Similarly, reuniting the lock and extracting the inner value is only /// possible when `T` is `Unpin`. pub fn new(t: T) -> (Self, Self) { - let arc = Arc::new(Inner { state: AtomicUsize::new(0), value: Some(UnsafeCell::new(t)) }); + let arc = Arc::new(Inner { + state: AtomicPtr::new(ptr::null_mut()), + value: Some(UnsafeCell::new(t)), + }); (Self { arc: arc.clone() }, Self { arc }) } @@ -87,7 +90,8 @@ impl BiLock { pub fn poll_lock(&self, cx: &mut Context<'_>) -> Poll> { let mut waker = None; loop { - match self.arc.state.swap(1, SeqCst) { + let n = self.arc.state.swap(invalid_ptr(1), SeqCst); + match n as usize { // Woohoo, we grabbed the lock! 0 => return Poll::Ready(BiLockGuard { bilock: self }), @@ -96,8 +100,8 @@ impl BiLock { // A task was previously blocked on this lock, likely our task, // so we need to update that task. - n => unsafe { - let mut prev = Box::from_raw(n as *mut Waker); + _ => unsafe { + let mut prev = Box::from_raw(n); *prev = cx.waker().clone(); waker = Some(prev); }, @@ -105,9 +109,9 @@ impl BiLock { // type ascription for safety's sake! let me: Box = waker.take().unwrap_or_else(|| Box::new(cx.waker().clone())); - let me = Box::into_raw(me) as usize; + let me = Box::into_raw(me); - match self.arc.state.compare_exchange(1, me, SeqCst, SeqCst) { + match self.arc.state.compare_exchange(invalid_ptr(1), me, SeqCst, SeqCst) { // The lock is still locked, but we've now parked ourselves, so // just report that we're scheduled to receive a notification. Ok(_) => return Poll::Pending, @@ -115,8 +119,8 @@ impl BiLock { // Oops, looks like the lock was unlocked after our swap above // and before the compare_exchange. Deallocate what we just // allocated and go through the loop again. - Err(0) => unsafe { - waker = Some(Box::from_raw(me as *mut Waker)); + Err(n) if n.is_null() => unsafe { + waker = Some(Box::from_raw(me)); }, // The top of this loop set the previous state to 1, so if we @@ -125,7 +129,7 @@ impl BiLock { // but we're trying to acquire the lock and there's only one // other reference of the lock, so it should be impossible for // that task to ever block itself. - Err(n) => panic!("invalid state: {}", n), + Err(n) => panic!("invalid state: {}", n as usize), } } } @@ -164,7 +168,8 @@ impl BiLock { } fn unlock(&self) { - match self.arc.state.swap(0, SeqCst) { + let n = self.arc.state.swap(ptr::null_mut(), SeqCst); + match n as usize { // we've locked the lock, shouldn't be possible for us to see an // unlocked lock. 0 => panic!("invalid unlocked state"), @@ -174,8 +179,8 @@ impl BiLock { // Another task has parked themselves on this lock, let's wake them // up as its now their turn. - n => unsafe { - Box::from_raw(n as *mut Waker).wake(); + _ => unsafe { + Box::from_raw(n).wake(); }, } } @@ -189,7 +194,7 @@ impl Inner { impl Drop for Inner { fn drop(&mut self) { - assert_eq!(self.state.load(SeqCst), 0); + assert!(self.state.load(SeqCst).is_null()); } } @@ -277,3 +282,12 @@ impl<'a, T> Future for BiLockAcquire<'a, T> { self.bilock.poll_lock(cx) } } + +// Based on core::ptr::invalid_mut. Equivalent to `addr as *mut T`, but is strict-provenance compatible. +#[allow(clippy::useless_transmute)] +#[inline] +fn invalid_ptr(addr: usize) -> *mut T { + // SAFETY: every valid integer is also a valid pointer (as long as you don't dereference that + // pointer). + unsafe { core::mem::transmute(addr) } +}