diff --git a/src/barrier.rs b/src/barrier.rs index 344254d..ff803c7 100644 --- a/src/barrier.rs +++ b/src/barrier.rs @@ -1,5 +1,12 @@ -use event_listener::Event; +use event_listener::{Event, EventListener}; +use futures_lite::ready; +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use crate::futures::Lock; use crate::Mutex; /// A counter to synchronize multiple tasks at the same time. @@ -72,24 +79,103 @@ impl Barrier { /// }); /// } /// ``` - pub async fn wait(&self) -> BarrierWaitResult { - let mut state = self.state.lock().await; - let local_gen = state.generation_id; - state.count += 1; - - if state.count < self.n { - while local_gen == state.generation_id && state.count < self.n { - let listener = self.event.listen(); - drop(state); - listener.await; - state = self.state.lock().await; + pub fn wait(&self) -> BarrierWait<'_> { + BarrierWait { + barrier: self, + lock: Some(self.state.lock()), + state: WaitState::Initial, + } + } +} + +/// The future returned by [`Barrier::wait()`]. +pub struct BarrierWait<'a> { + /// The barrier to wait on. + barrier: &'a Barrier, + + /// The ongoing mutex lock operation we are blocking on. + lock: Option<Lock<'a, State>>, + + /// The current state of the future. + state: WaitState, +} + +impl fmt::Debug for BarrierWait<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("BarrierWait { .. }") + } +} + +enum WaitState { + /// We are getting the original values of the state. + Initial, + + /// We are waiting for the listener to complete. + Waiting { evl: EventListener, local_gen: u64 }, + + /// Waiting to re-acquire the lock to check the state again. + Reacquiring(u64), +} + +impl Future for BarrierWait<'_> { + type Output = BarrierWaitResult; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.get_mut(); + + loop { + match this.state { + WaitState::Initial => { + // See if the lock is ready yet. + let mut state = ready!(Pin::new(this.lock.as_mut().unwrap()).poll(cx)); + this.lock = None; + + let local_gen = state.generation_id; + state.count += 1; + + if state.count < this.barrier.n { + // We need to wait for the event. + this.state = WaitState::Waiting { + evl: this.barrier.event.listen(), + local_gen, + }; + } else { + // We are the last one. + state.count = 0; + state.generation_id = state.generation_id.wrapping_add(1); + this.barrier.event.notify(std::usize::MAX); + return Poll::Ready(BarrierWaitResult { is_leader: true }); + } + } + + WaitState::Waiting { + ref mut evl, + local_gen, + } => { + ready!(Pin::new(evl).poll(cx)); + + // We are now re-acquiring the mutex. + this.lock = Some(this.barrier.state.lock()); + this.state = WaitState::Reacquiring(local_gen); + } + + WaitState::Reacquiring(local_gen) => { + // Acquire the local state again. + let state = ready!(Pin::new(this.lock.as_mut().unwrap()).poll(cx)); + this.lock = None; + + if local_gen == state.generation_id && state.count < this.barrier.n { + // We need to wait for the event again. + this.state = WaitState::Waiting { + evl: this.barrier.event.listen(), + local_gen, + }; + } else { + // We are ready, but not the leader. + return Poll::Ready(BarrierWaitResult { is_leader: false }); + } + } } - BarrierWaitResult { is_leader: false } - } else { - state.count = 0; - state.generation_id = state.generation_id.wrapping_add(1); - self.event.notify(std::usize::MAX); - BarrierWaitResult { is_leader: true } } } } diff --git a/src/lib.rs b/src/lib.rs index 78a1ff5..8d14e63 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,3 +20,12 @@ pub use mutex::{Mutex, MutexGuard, MutexGuardArc}; pub use once_cell::OnceCell; pub use rwlock::{RwLock, RwLockReadGuard, RwLockUpgradableReadGuard, RwLockWriteGuard}; pub use semaphore::{Semaphore, SemaphoreGuard, SemaphoreGuardArc}; + +pub mod futures { + //! Named futures for use with `async_lock` primitives. + + pub use crate::barrier::BarrierWait; + pub use crate::mutex::{Lock, LockArc}; + pub use crate::rwlock::{Read, UpgradableRead, Upgrade, Write}; + pub use crate::semaphore::{Acquire, AcquireArc}; +} diff --git a/src/mutex.rs b/src/mutex.rs index d596a60..88cc890 100644 --- a/src/mutex.rs +++ b/src/mutex.rs @@ -1,10 +1,15 @@ +use std::borrow::Borrow; use std::cell::UnsafeCell; use std::fmt; use std::future::Future; +use std::marker::PhantomData; +use std::mem; use std::ops::{Deref, DerefMut}; +use std::pin::Pin; use std::process; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::task::{Context, Poll}; // Note: we cannot use `target_family = "wasm"` here because it requires Rust 1.54. #[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))] @@ -12,7 +17,8 @@ use std::time::{Duration, Instant}; use std::usize; -use event_listener::Event; +use event_listener::{Event, EventListener}; +use futures_lite::ready; /// An async mutex. /// @@ -103,114 +109,10 @@ impl<T: ?Sized> Mutex<T> { /// # }) /// ``` #[inline] - pub async fn lock(&self) -> MutexGuard<'_, T> { - if let Some(guard) = self.try_lock() { - return guard; - } - self.acquire_slow().await; - MutexGuard(self) - } - - /// Slow path for acquiring the mutex. - #[cold] - async fn acquire_slow(&self) { - // Get the current time. - #[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))] - let start = Instant::now(); - - loop { - // Start listening for events. - let listener = self.lock_ops.listen(); - - // Try locking if nobody is being starved. - match self - .state - .compare_exchange(0, 1, Ordering::Acquire, Ordering::Acquire) - .unwrap_or_else(|x| x) - { - // Lock acquired! - 0 => return, - - // Lock is held and nobody is starved. - 1 => {} - - // Somebody is starved. - _ => break, - } - - // Wait for a notification. - listener.await; - - // Try locking if nobody is being starved. - match self - .state - .compare_exchange(0, 1, Ordering::Acquire, Ordering::Acquire) - .unwrap_or_else(|x| x) - { - // Lock acquired! - 0 => return, - - // Lock is held and nobody is starved. - 1 => {} - - // Somebody is starved. - _ => { - // Notify the first listener in line because we probably received a - // notification that was meant for a starved task. - self.lock_ops.notify(1); - break; - } - } - - // If waiting for too long, fall back to a fairer locking strategy that will prevent - // newer lock operations from starving us forever. - #[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))] - if start.elapsed() > Duration::from_micros(500) { - break; - } - } - - // Increment the number of starved lock operations. - if self.state.fetch_add(2, Ordering::Release) > usize::MAX / 2 { - // In case of potential overflow, abort. - process::abort(); - } - - // Decrement the counter when exiting this function. - let _call = CallOnDrop(|| { - self.state.fetch_sub(2, Ordering::Release); - }); - - loop { - // Start listening for events. - let listener = self.lock_ops.listen(); - - // Try locking if nobody else is being starved. - match self - .state - .compare_exchange(2, 2 | 1, Ordering::Acquire, Ordering::Acquire) - .unwrap_or_else(|x| x) - { - // Lock acquired! - 2 => return, - - // Lock is held by someone. - s if s % 2 == 1 => {} - - // Lock is available. - _ => { - // Be fair: notify the first listener and then go wait in line. - self.lock_ops.notify(1); - } - } - - // Wait for a notification. - listener.await; - - // Try acquiring the lock without waiting for others. - if self.state.fetch_or(1, Ordering::Acquire) % 2 == 0 { - return; - } + pub fn lock(&self) -> Lock<'_, T> { + Lock { + mutex: self, + acquire_slow: None, } } @@ -265,14 +167,6 @@ impl<T: ?Sized> Mutex<T> { } impl<T: ?Sized> Mutex<T> { - async fn lock_arc_impl(self: Arc<Self>) -> MutexGuardArc<T> { - if let Some(guard) = self.try_lock_arc() { - return guard; - } - self.acquire_slow().await; - MutexGuardArc(self) - } - /// Acquires the mutex and clones a reference to it. /// /// Returns an owned guard that releases the mutex when dropped. @@ -290,8 +184,8 @@ impl<T: ?Sized> Mutex<T> { /// # }) /// ``` #[inline] - pub fn lock_arc(self: &Arc<Self>) -> impl Future<Output = MutexGuardArc<T>> { - self.clone().lock_arc_impl() + pub fn lock_arc(self: &Arc<Self>) -> LockArc<T> { + LockArc(LockArcInnards::Unpolled(self.clone())) } /// Attempts to acquire the mutex and clone a reference to it. @@ -353,6 +247,295 @@ impl<T: Default + ?Sized> Default for Mutex<T> { } } +/// The future returned by [`Mutex::lock`]. +pub struct Lock<'a, T: ?Sized> { + /// Reference to the mutex. + mutex: &'a Mutex<T>, + + /// The future that waits for the mutex to become available. + acquire_slow: Option<AcquireSlow<&'a Mutex<T>, T>>, +} + +impl<'a, T: ?Sized> Unpin for Lock<'a, T> {} + +impl<T: ?Sized> fmt::Debug for Lock<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Lock { .. }") + } +} + +impl<'a, T: ?Sized> Future for Lock<'a, T> { + type Output = MutexGuard<'a, T>; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.get_mut(); + + loop { + match this.acquire_slow.as_mut() { + None => { + // Try the fast path before trying to register slowly. + match this.mutex.try_lock() { + Some(guard) => return Poll::Ready(guard), + None => { + this.acquire_slow = Some(AcquireSlow::new(this.mutex)); + } + } + } + + Some(acquire_slow) => { + // Continue registering slowly. + let value = ready!(Pin::new(acquire_slow).poll(cx)); + return Poll::Ready(MutexGuard(value)); + } + } + } + } +} + +/// The future returned by [`Mutex::lock_arc`]. +pub struct LockArc<T: ?Sized>(LockArcInnards<T>); + +enum LockArcInnards<T: ?Sized> { + /// We have not tried to poll the fast path yet. + Unpolled(Arc<Mutex<T>>), + + /// We are acquiring the mutex through the slow path. + AcquireSlow(AcquireSlow<Arc<Mutex<T>>, T>), + + /// Empty hole to make taking easier. + Empty, +} + +impl<T: ?Sized> Unpin for LockArc<T> {} + +impl<T: ?Sized> fmt::Debug for LockArc<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("LockArc { .. }") + } +} + +impl<T: ?Sized> Future for LockArc<T> { + type Output = MutexGuardArc<T>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.get_mut(); + + loop { + match mem::replace(&mut this.0, LockArcInnards::Empty) { + LockArcInnards::Unpolled(mutex) => { + // Try the fast path before trying to register slowly. + match mutex.try_lock_arc() { + Some(guard) => return Poll::Ready(guard), + None => { + *this = LockArc(LockArcInnards::AcquireSlow(AcquireSlow::new( + mutex.clone(), + ))); + } + } + } + + LockArcInnards::AcquireSlow(mut acquire_slow) => { + // Continue registering slowly. + let value = match Pin::new(&mut acquire_slow).poll(cx) { + Poll::Pending => { + *this = LockArc(LockArcInnards::AcquireSlow(acquire_slow)); + return Poll::Pending; + } + Poll::Ready(value) => value, + }; + return Poll::Ready(MutexGuardArc(value)); + } + + LockArcInnards::Empty => panic!("future polled after completion"), + } + } + } +} + +/// Future for acquiring the mutex slowly. +struct AcquireSlow<B: Borrow<Mutex<T>>, T: ?Sized> { + /// Reference to the mutex. + mutex: Option<B>, + + /// The event listener waiting on the mutex. + listener: Option<EventListener>, + + /// The point at which the mutex lock was started. + #[cfg(not(any(target_arch = "wasm32", target_os = "wasm64")))] + start: Option<Instant>, + + /// This lock operation is starving. + starved: bool, + + /// Capture the `T` lifetime. + _marker: PhantomData<T>, +} + +impl<B: Borrow<Mutex<T>> + Unpin, T: ?Sized> Unpin for AcquireSlow<B, T> {} + +impl<T: ?Sized, B: Borrow<Mutex<T>>> AcquireSlow<B, T> { + /// Create a new `AcquireSlow` future. + #[cold] + fn new(mutex: B) -> Self { + AcquireSlow { + mutex: Some(mutex), + listener: None, + #[cfg(not(any(target_arch = "wasm32", target_os = "wasm64")))] + start: None, + starved: false, + _marker: PhantomData, + } + } + + /// Take the mutex reference out, decrementing the counter if necessary. + fn take_mutex(&mut self) -> Option<B> { + let mutex = self.mutex.take(); + + if self.starved { + if let Some(mutex) = mutex.as_ref() { + // Decrement this counter before we exit. + mutex.borrow().state.fetch_sub(2, Ordering::Release); + } + } + + mutex + } +} + +impl<T: ?Sized, B: Unpin + Borrow<Mutex<T>>> Future for AcquireSlow<B, T> { + type Output = B; + + #[cold] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = &mut *self; + #[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))] + let start = *this.start.get_or_insert_with(Instant::now); + let mutex = this + .mutex + .as_ref() + .expect("future polled after completion") + .borrow(); + + // Only use this hot loop if we aren't currently starved. + if !this.starved { + loop { + // Start listening for events. + match &mut this.listener { + listener @ None => { + // Start listening for events. + *listener = Some(mutex.lock_ops.listen()); + + // Try locking if nobody is being starved. + match mutex + .state + .compare_exchange(0, 1, Ordering::Acquire, Ordering::Acquire) + .unwrap_or_else(|x| x) + { + // Lock acquired! + 0 => return Poll::Ready(this.take_mutex().unwrap()), + + // Lock is held and nobody is starved. + 1 => {} + + // Somebody is starved. + _ => break, + } + } + Some(ref mut listener) => { + // Wait for a notification. + ready!(Pin::new(listener).poll(cx)); + this.listener = None; + + // Try locking if nobody is being starved. + match mutex + .state + .compare_exchange(0, 1, Ordering::Acquire, Ordering::Acquire) + .unwrap_or_else(|x| x) + { + // Lock acquired! + 0 => return Poll::Ready(this.take_mutex().unwrap()), + + // Lock is held and nobody is starved. + 1 => {} + + // Somebody is starved. + _ => { + // Notify the first listener in line because we probably received a + // notification that was meant for a starved task. + mutex.lock_ops.notify(1); + break; + } + } + + // If waiting for too long, fall back to a fairer locking strategy that will prevent + // newer lock operations from starving us forever. + #[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))] + if start.elapsed() > Duration::from_micros(500) { + break; + } + } + } + } + + // Increment the number of starved lock operations. + if mutex.state.fetch_add(2, Ordering::Release) > usize::MAX / 2 { + // In case of potential overflow, abort. + process::abort(); + } + + // Indicate that we are now starving and will use a fairer locking strategy. + this.starved = true; + } + + // Fairer locking loop. + loop { + match &mut this.listener { + listener @ None => { + // Start listening for events. + *listener = Some(mutex.lock_ops.listen()); + + // Try locking if nobody else is being starved. + match mutex + .state + .compare_exchange(2, 2 | 1, Ordering::Acquire, Ordering::Acquire) + .unwrap_or_else(|x| x) + { + // Lock acquired! + 2 => return Poll::Ready(this.take_mutex().unwrap()), + + // Lock is held by someone. + s if s % 2 == 1 => {} + + // Lock is available. + _ => { + // Be fair: notify the first listener and then go wait in line. + mutex.lock_ops.notify(1); + } + } + } + Some(ref mut listener) => { + // Wait for a notification. + ready!(Pin::new(listener).poll(cx)); + this.listener = None; + + // Try acquiring the lock without waiting for others. + if mutex.state.fetch_or(1, Ordering::Acquire) % 2 == 0 { + return Poll::Ready(this.take_mutex().unwrap()); + } + } + } + } + } +} + +impl<T: ?Sized, B: Borrow<Mutex<T>>> Drop for AcquireSlow<B, T> { + fn drop(&mut self) { + // Make sure the starvation counter is decremented. + self.take_mutex(); + } +} + /// A guard that releases the mutex when dropped. pub struct MutexGuard<'a, T: ?Sized>(&'a Mutex<T>); diff --git a/src/rwlock.rs b/src/rwlock.rs index f3a9c79..3d5795b 100644 --- a/src/rwlock.rs +++ b/src/rwlock.rs @@ -1,12 +1,17 @@ use std::cell::UnsafeCell; use std::fmt; +use std::future::Future; use std::mem; use std::ops::{Deref, DerefMut}; +use std::pin::Pin; use std::process; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; -use event_listener::Event; +use event_listener::{Event, EventListener}; +use futures_lite::ready; +use crate::futures::Lock; use crate::{Mutex, MutexGuard}; const WRITER_BIT: usize = 1; @@ -170,42 +175,11 @@ impl<T: ?Sized> RwLock<T> { /// assert!(lock.try_read().is_some()); /// # }) /// ``` - pub async fn read(&self) -> RwLockReadGuard<'_, T> { - let mut state = self.state.load(Ordering::Acquire); - - loop { - if state & WRITER_BIT == 0 { - // Make sure the number of readers doesn't overflow. - if state > std::isize::MAX as usize { - process::abort(); - } - - // If nobody is holding a write lock or attempting to acquire it, increment the - // number of readers. - match self.state.compare_exchange( - state, - state + ONE_READER, - Ordering::AcqRel, - Ordering::Acquire, - ) { - Ok(_) => return RwLockReadGuard(self), - Err(s) => state = s, - } - } else { - // Start listening for "no writer" events. - let listener = self.no_writer.listen(); - - // Check again if there's a writer. - if self.state.load(Ordering::SeqCst) & WRITER_BIT != 0 { - // Wait until the writer is dropped. - listener.await; - // Notify the next reader waiting in line. - self.no_writer.notify(1); - } - - // Reload the state. - state = self.state.load(Ordering::Acquire); - } + pub fn read(&self) -> Read<'_, T> { + Read { + lock: self, + state: self.state.load(Ordering::Acquire), + listener: None, } } @@ -289,33 +263,10 @@ impl<T: ?Sized> RwLock<T> { /// *writer = 2; /// # }) /// ``` - pub async fn upgradable_read(&self) -> RwLockUpgradableReadGuard<'_, T> { - // First grab the mutex. - let lock = self.mutex.lock().await; - - let mut state = self.state.load(Ordering::Acquire); - - // Make sure the number of readers doesn't overflow. - if state > std::isize::MAX as usize { - process::abort(); - } - - // Increment the number of readers. - loop { - match self.state.compare_exchange( - state, - state + ONE_READER, - Ordering::AcqRel, - Ordering::Acquire, - ) { - Ok(_) => { - return RwLockUpgradableReadGuard { - reader: RwLockReadGuard(self), - reserved: lock, - } - } - Err(s) => state = s, - } + pub fn upgradable_read(&self) -> UpgradableRead<'_, T> { + UpgradableRead { + lock: self, + acquire: self.mutex.lock(), } } @@ -372,30 +323,11 @@ impl<T: ?Sized> RwLock<T> { /// assert!(lock.try_read().is_none()); /// # }) /// ``` - pub async fn write(&self) -> RwLockWriteGuard<'_, T> { - // First grab the mutex. - let lock = self.mutex.lock().await; - - // Set `WRITER_BIT` and create a guard that unsets it in case this future is canceled. - self.state.fetch_or(WRITER_BIT, Ordering::SeqCst); - let guard = RwLockWriteGuard { - writer: RwLockWriteGuardInner(self), - reserved: lock, - }; - - // If there are readers, we need to wait for them to finish. - while self.state.load(Ordering::SeqCst) != WRITER_BIT { - // Start listening for "no readers" events. - let listener = self.no_readers.listen(); - - // Check again if there are readers. - if self.state.load(Ordering::Acquire) != WRITER_BIT { - // Wait for the readers to finish. - listener.await; - } + pub fn write(&self) -> Write<'_, T> { + Write { + lock: self, + state: WriteState::Acquiring(self.mutex.lock()), } - - guard } /// Returns a mutable reference to the inner value. @@ -448,6 +380,230 @@ impl<T: Default + ?Sized> Default for RwLock<T> { } } +/// The future returned by [`RwLock::read`]. +pub struct Read<'a, T: ?Sized> { + /// The lock that is being acquired. + lock: &'a RwLock<T>, + + /// The last-observed state of the lock. + state: usize, + + /// The listener for the "no writers" event. + listener: Option<EventListener>, +} + +impl<T: ?Sized> fmt::Debug for Read<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Read { .. }") + } +} + +impl<T: ?Sized> Unpin for Read<'_, T> {} + +impl<'a, T: ?Sized> Future for Read<'a, T> { + type Output = RwLockReadGuard<'a, T>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.get_mut(); + + loop { + if this.state & WRITER_BIT == 0 { + // Make sure the number of readers doesn't overflow. + if this.state > std::isize::MAX as usize { + process::abort(); + } + + // If nobody is holding a write lock or attempting to acquire it, increment the + // number of readers. + match this.lock.state.compare_exchange( + this.state, + this.state + ONE_READER, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => return Poll::Ready(RwLockReadGuard(this.lock)), + Err(s) => this.state = s, + } + } else { + // Start listening for "no writer" events. + let load_ordering = match &mut this.listener { + listener @ None => { + *listener = Some(this.lock.no_writer.listen()); + + // Make sure there really is no writer. + Ordering::SeqCst + } + + Some(ref mut listener) => { + // Wait for the writer to finish. + ready!(Pin::new(listener).poll(cx)); + this.listener = None; + + // Notify the next reader waiting in list. + this.lock.no_writer.notify(1); + + // Check the state again. + Ordering::Acquire + } + }; + + // Reload the state. + this.state = this.lock.state.load(load_ordering); + } + } + } +} + +/// The future returned by [`RwLock::upgradable_read`]. +pub struct UpgradableRead<'a, T: ?Sized> { + /// The lock that is being acquired. + lock: &'a RwLock<T>, + + /// The mutex we are trying to acquire. + acquire: Lock<'a, ()>, +} + +impl<T: ?Sized> fmt::Debug for UpgradableRead<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("UpgradableRead { .. }") + } +} + +impl<T: ?Sized> Unpin for UpgradableRead<'_, T> {} + +impl<'a, T: ?Sized> Future for UpgradableRead<'a, T> { + type Output = RwLockUpgradableReadGuard<'a, T>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.get_mut(); + + // Acquire the mutex. + let mutex_guard = ready!(Pin::new(&mut this.acquire).poll(cx)); + + let mut state = this.lock.state.load(Ordering::Acquire); + + // Make sure the number of readers doesn't overflow. + if state > std::isize::MAX as usize { + process::abort(); + } + + // Increment the number of readers. + loop { + match this.lock.state.compare_exchange( + state, + state + ONE_READER, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + return Poll::Ready(RwLockUpgradableReadGuard { + reader: RwLockReadGuard(this.lock), + reserved: mutex_guard, + }); + } + Err(s) => state = s, + } + } + } +} + +/// The future returned by [`RwLock::write`]. +pub struct Write<'a, T: ?Sized> { + /// The lock that is being acquired. + lock: &'a RwLock<T>, + + /// Current state fof this future. + state: WriteState<'a, T>, +} + +enum WriteState<'a, T: ?Sized> { + /// We are currently acquiring the inner mutex. + Acquiring(Lock<'a, ()>), + + /// We are currently waiting for readers to finish. + WaitingReaders { + /// Our current write guard. + guard: Option<RwLockWriteGuard<'a, T>>, + + /// The listener for the "no readers" event. + listener: Option<EventListener>, + }, +} + +impl<T: ?Sized> fmt::Debug for Write<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Write { .. }") + } +} + +impl<T: ?Sized> Unpin for Write<'_, T> {} + +impl<'a, T: ?Sized> Future for Write<'a, T> { + type Output = RwLockWriteGuard<'a, T>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.get_mut(); + + loop { + match &mut this.state { + WriteState::Acquiring(lock) => { + // First grab the mutex. + let mutex_guard = ready!(Pin::new(lock).poll(cx)); + + // Set `WRITER_BIT` and create a guard that unsets it in case this future is canceled. + let new_state = this.lock.state.fetch_or(WRITER_BIT, Ordering::SeqCst); + let guard = RwLockWriteGuard { + writer: RwLockWriteGuardInner(this.lock), + reserved: mutex_guard, + }; + + // If we just acquired the writer lock, return it. + if new_state == WRITER_BIT { + return Poll::Ready(guard); + } + + // Start waiting for the readers to finish. + this.state = WriteState::WaitingReaders { + guard: Some(guard), + listener: Some(this.lock.no_readers.listen()), + }; + } + + WriteState::WaitingReaders { + guard, + ref mut listener, + } => { + let load_ordering = if listener.is_some() { + Ordering::Acquire + } else { + Ordering::SeqCst + }; + + // Check the state again. + if this.lock.state.load(load_ordering) == WRITER_BIT { + // We are the only ones holding the lock, return it. + return Poll::Ready(guard.take().unwrap()); + } + + // Wait for the readers to finish. + match listener { + None => { + // Register a listener. + *listener = Some(this.lock.no_readers.listen()); + } + + Some(ref mut evl) => { + // Wait for the readers to finish. + ready!(Pin::new(evl).poll(cx)); + *listener = None; + } + }; + } + } + } + } +} + /// A guard that releases the read lock when dropped. pub struct RwLockReadGuard<'a, T: ?Sized>(&'a RwLock<T>); @@ -585,7 +741,7 @@ impl<'a, T: ?Sized> RwLockUpgradableReadGuard<'a, T> { /// *writer = 2; /// # }) /// ``` - pub async fn upgrade(guard: Self) -> RwLockWriteGuard<'a, T> { + pub fn upgrade(guard: Self) -> Upgrade<'a, T> { // Set `WRITER_BIT` and decrement the number of readers at the same time. guard .reader @@ -596,19 +752,10 @@ impl<'a, T: ?Sized> RwLockUpgradableReadGuard<'a, T> { // Convert into a write guard that unsets `WRITER_BIT` in case this future is canceled. let guard = guard.into_writer(); - // If there are readers, we need to wait for them to finish. - while guard.writer.0.state.load(Ordering::SeqCst) != WRITER_BIT { - // Start listening for "no readers" events. - let listener = guard.writer.0.no_readers.listen(); - - // Check again if there are readers. - if guard.writer.0.state.load(Ordering::Acquire) != WRITER_BIT { - // Wait for the readers to finish. - listener.await; - } + Upgrade { + guard: Some(guard), + listener: None, } - - guard } } @@ -632,6 +779,67 @@ impl<T: ?Sized> Deref for RwLockUpgradableReadGuard<'_, T> { } } +/// The future returned by [`RwLockUpgradableReadGuard::upgrade`]. +pub struct Upgrade<'a, T: ?Sized> { + /// The guard that we are upgrading to. + guard: Option<RwLockWriteGuard<'a, T>>, + + /// The event listener we are waiting on. + listener: Option<EventListener>, +} + +impl<T: ?Sized> fmt::Debug for Upgrade<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Upgrade").finish() + } +} + +impl<T: ?Sized> Unpin for Upgrade<'_, T> {} + +impl<'a, T: ?Sized> Future for Upgrade<'a, T> { + type Output = RwLockWriteGuard<'a, T>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.get_mut(); + let guard = this + .guard + .as_mut() + .expect("cannot poll future after completion"); + + // If there are readers, we need to wait for them to finish. + loop { + let load_ordering = if this.listener.is_some() { + Ordering::Acquire + } else { + Ordering::SeqCst + }; + + // See if the number of readers is zero. + let state = guard.writer.0.state.load(load_ordering); + if state == WRITER_BIT { + break; + } + + // If there are readers, wait for them to finish. + match &mut this.listener { + listener @ None => { + // Start listening for "no readers" events. + *listener = Some(guard.writer.0.no_readers.listen()); + } + + Some(ref mut listener) => { + // Wait for the readers to finish. + ready!(Pin::new(listener).poll(cx)); + this.listener = None; + } + } + } + + // We are done. + Poll::Ready(this.guard.take().unwrap()) + } +} + struct RwLockWriteGuardInner<'a, T: ?Sized>(&'a RwLock<T>); impl<T: ?Sized> Drop for RwLockWriteGuardInner<'_, T> { diff --git a/src/semaphore.rs b/src/semaphore.rs index 094b061..15482c4 100644 --- a/src/semaphore.rs +++ b/src/semaphore.rs @@ -1,8 +1,12 @@ +use std::fmt; use std::future::Future; +use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::task::{Context, Poll}; -use event_listener::Event; +use event_listener::{Event, EventListener}; +use futures_lite::ready; /// A counter for limiting the number of concurrent operations. #[derive(Debug)] @@ -80,18 +84,10 @@ impl Semaphore { /// let guard = s.acquire().await; /// # }); /// ``` - pub async fn acquire(&self) -> SemaphoreGuard<'_> { - let mut listener = None; - - loop { - if let Some(guard) = self.try_acquire() { - return guard; - } - - match listener.take() { - None => listener = Some(self.event.listen()), - Some(l) => l.await, - } + pub fn acquire(&self) -> Acquire<'_> { + Acquire { + semaphore: self, + listener: None, } } } @@ -136,21 +132,6 @@ impl Semaphore { } } - async fn acquire_arc_impl(self: Arc<Self>) -> SemaphoreGuardArc { - let mut listener = None; - - loop { - if let Some(guard) = self.try_acquire_arc() { - return guard; - } - - match listener.take() { - None => listener = Some(self.event.listen()), - Some(l) => l.await, - } - } - } - /// Waits for an owned permit for a concurrent operation. /// /// Returns a guard that releases the permit when dropped. @@ -166,8 +147,100 @@ impl Semaphore { /// let guard = s.acquire_arc().await; /// # }); /// ``` - pub fn acquire_arc(self: &Arc<Self>) -> impl Future<Output = SemaphoreGuardArc> { - self.clone().acquire_arc_impl() + pub fn acquire_arc(self: &Arc<Self>) -> AcquireArc { + AcquireArc { + semaphore: self.clone(), + listener: None, + } + } +} + +/// The future returned by [`Semaphore::acquire`]. +pub struct Acquire<'a> { + /// The semaphore being acquired. + semaphore: &'a Semaphore, + + /// The listener waiting on the semaphore. + listener: Option<EventListener>, +} + +impl fmt::Debug for Acquire<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Acquire { .. }") + } +} + +impl Unpin for Acquire<'_> {} + +impl<'a> Future for Acquire<'a> { + type Output = SemaphoreGuard<'a>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.get_mut(); + + loop { + match this.semaphore.try_acquire() { + Some(guard) => return Poll::Ready(guard), + None => { + // Wait on the listener. + match &mut this.listener { + listener @ None => { + *listener = Some(this.semaphore.event.listen()); + } + Some(ref mut listener) => { + ready!(Pin::new(listener).poll(cx)); + this.listener = None; + } + } + } + } + } + } +} + +/// The future returned by [`Semaphore::acquire_arc`]. +pub struct AcquireArc { + /// The semaphore being acquired. + semaphore: Arc<Semaphore>, + + /// The listener waiting on the semaphore. + listener: Option<EventListener>, +} + +impl fmt::Debug for AcquireArc { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("AcquireArc { .. }") + } +} + +impl Unpin for AcquireArc {} + +impl Future for AcquireArc { + type Output = SemaphoreGuardArc; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.get_mut(); + + loop { + match this.semaphore.try_acquire_arc() { + Some(guard) => { + this.listener = None; + return Poll::Ready(guard); + } + None => { + // Wait on the listener. + match &mut this.listener.take() { + listener @ None => { + *listener = Some(this.semaphore.event.listen()); + } + Some(ref mut listener) => { + ready!(Pin::new(listener).poll(cx)); + this.listener = None; + } + } + } + } + } } }