diff --git a/benches/Cargo.toml b/benches/Cargo.toml index f4a1d8fb71e..75723ce2105 100644 --- a/benches/Cargo.toml +++ b/benches/Cargo.toml @@ -22,3 +22,14 @@ harness = false name = "scheduler" path = "scheduler.rs" harness = false + + +[[bench]] +name = "sync_rwlock" +path = "sync_rwlock.rs" +harness = false + +[[bench]] +name = "sync_semaphore" +path = "sync_semaphore.rs" +harness = false diff --git a/benches/sync_rwlock.rs b/benches/sync_rwlock.rs new file mode 100644 index 00000000000..4eca9807b2e --- /dev/null +++ b/benches/sync_rwlock.rs @@ -0,0 +1,147 @@ +use bencher::{black_box, Bencher}; +use std::sync::Arc; +use tokio::{sync::RwLock, task}; + +fn read_uncontended(b: &mut Bencher) { + let mut rt = tokio::runtime::Builder::new() + .core_threads(6) + .threaded_scheduler() + .build() + .unwrap(); + + let lock = Arc::new(RwLock::new(())); + b.iter(|| { + let lock = lock.clone(); + rt.block_on(async move { + for _ in 0..6 { + let read = lock.read().await; + black_box(read); + } + }) + }); +} + +fn read_concurrent_uncontended_multi(b: &mut Bencher) { + let mut rt = tokio::runtime::Builder::new() + .core_threads(6) + .threaded_scheduler() + .build() + .unwrap(); + + async fn task(lock: Arc>) { + let read = lock.read().await; + black_box(read); + } + + let lock = Arc::new(RwLock::new(())); + b.iter(|| { + let lock = lock.clone(); + rt.block_on(async move { + let j = tokio::try_join! { + task::spawn(task(lock.clone())), + task::spawn(task(lock.clone())), + task::spawn(task(lock.clone())), + task::spawn(task(lock.clone())), + task::spawn(task(lock.clone())), + task::spawn(task(lock.clone())) + }; + j.unwrap(); + }) + }); +} + +fn read_concurrent_uncontended(b: &mut Bencher) { + let mut rt = tokio::runtime::Builder::new() + .basic_scheduler() + .build() + .unwrap(); + + async fn task(lock: Arc>) { + let read = lock.read().await; + black_box(read); + } + + let lock = Arc::new(RwLock::new(())); + b.iter(|| { + let lock = lock.clone(); + rt.block_on(async move { + tokio::join! { + task(lock.clone()), + task(lock.clone()), + task(lock.clone()), + task(lock.clone()), + task(lock.clone()), + task(lock.clone()) + }; + }) + }); +} + +fn read_concurrent_contended_multi(b: &mut Bencher) { + let mut rt = tokio::runtime::Builder::new() + .core_threads(6) + .threaded_scheduler() + .build() + .unwrap(); + + async fn task(lock: Arc>) { + let read = lock.read().await; + black_box(read); + } + + let lock = Arc::new(RwLock::new(())); + b.iter(|| { + let lock = lock.clone(); + rt.block_on(async move { + let write = lock.write().await; + let j = tokio::try_join! { + async move { drop(write); Ok(()) }, + task::spawn(task(lock.clone())), + task::spawn(task(lock.clone())), + task::spawn(task(lock.clone())), + task::spawn(task(lock.clone())), + task::spawn(task(lock.clone())), + }; + j.unwrap(); + }) + }); +} + +fn read_concurrent_contended(b: &mut Bencher) { + let mut rt = tokio::runtime::Builder::new() + .basic_scheduler() + .build() + .unwrap(); + + async fn task(lock: Arc>) { + let read = lock.read().await; + black_box(read); + } + + let lock = Arc::new(RwLock::new(())); + b.iter(|| { + let lock = lock.clone(); + rt.block_on(async move { + let write = lock.write().await; + tokio::join! { + async move { drop(write) }, + task(lock.clone()), + task(lock.clone()), + task(lock.clone()), + task(lock.clone()), + task(lock.clone()), + }; + }) + }); +} + +bencher::benchmark_group!( + sync_rwlock, + read_uncontended, + read_concurrent_uncontended, + read_concurrent_uncontended_multi, + read_concurrent_contended, + read_concurrent_contended_multi +); + +bencher::benchmark_main!(sync_rwlock); diff --git a/benches/sync_semaphore.rs b/benches/sync_semaphore.rs new file mode 100644 index 00000000000..c43311c0d35 --- /dev/null +++ b/benches/sync_semaphore.rs @@ -0,0 +1,130 @@ +use bencher::Bencher; +use std::sync::Arc; +use tokio::{sync::Semaphore, task}; + +fn uncontended(b: &mut Bencher) { + let mut rt = tokio::runtime::Builder::new() + .core_threads(6) + .threaded_scheduler() + .build() + .unwrap(); + + let s = Arc::new(Semaphore::new(10)); + b.iter(|| { + let s = s.clone(); + rt.block_on(async move { + for _ in 0..6 { + let permit = s.acquire().await; + drop(permit); + } + }) + }); +} + +async fn task(s: Arc) { + let permit = s.acquire().await; + drop(permit); +} + +fn uncontended_concurrent_multi(b: &mut Bencher) { + let mut rt = tokio::runtime::Builder::new() + .core_threads(6) + .threaded_scheduler() + .build() + .unwrap(); + + let s = Arc::new(Semaphore::new(10)); + b.iter(|| { + let s = s.clone(); + rt.block_on(async move { + let j = tokio::try_join! { + task::spawn(task(s.clone())), + task::spawn(task(s.clone())), + task::spawn(task(s.clone())), + task::spawn(task(s.clone())), + task::spawn(task(s.clone())), + task::spawn(task(s.clone())) + }; + j.unwrap(); + }) + }); +} + +fn uncontended_concurrent_single(b: &mut Bencher) { + let mut rt = tokio::runtime::Builder::new() + .basic_scheduler() + .build() + .unwrap(); + + let s = Arc::new(Semaphore::new(10)); + b.iter(|| { + let s = s.clone(); + rt.block_on(async move { + tokio::join! { + task(s.clone()), + task(s.clone()), + task(s.clone()), + task(s.clone()), + task(s.clone()), + task(s.clone()) + }; + }) + }); +} + +fn contended_concurrent_multi(b: &mut Bencher) { + let mut rt = tokio::runtime::Builder::new() + .core_threads(6) + .threaded_scheduler() + .build() + .unwrap(); + + let s = Arc::new(Semaphore::new(5)); + b.iter(|| { + let s = s.clone(); + rt.block_on(async move { + let j = tokio::try_join! { + task::spawn(task(s.clone())), + task::spawn(task(s.clone())), + task::spawn(task(s.clone())), + task::spawn(task(s.clone())), + task::spawn(task(s.clone())), + task::spawn(task(s.clone())) + }; + j.unwrap(); + }) + }); +} + +fn contended_concurrent_single(b: &mut Bencher) { + let mut rt = tokio::runtime::Builder::new() + .basic_scheduler() + .build() + .unwrap(); + + let s = Arc::new(Semaphore::new(5)); + b.iter(|| { + let s = s.clone(); + rt.block_on(async move { + tokio::join! { + task(s.clone()), + task(s.clone()), + task(s.clone()), + task(s.clone()), + task(s.clone()), + task(s.clone()) + }; + }) + }); +} + +bencher::benchmark_group!( + sync_semaphore, + uncontended, + uncontended_concurrent_multi, + uncontended_concurrent_single, + contended_concurrent_multi, + contended_concurrent_single +); + +bencher::benchmark_main!(sync_semaphore); diff --git a/tokio/src/coop.rs b/tokio/src/coop.rs index e4cb0224160..19302559350 100644 --- a/tokio/src/coop.rs +++ b/tokio/src/coop.rs @@ -46,6 +46,8 @@ // NOTE: The doctests in this module are ignored since the whole module is (currently) private. use std::cell::Cell; +use std::future::Future; +use std::pin::Pin; use std::task::{Context, Poll}; /// Constant used to determine how much "work" a task is allowed to do without yielding. @@ -250,6 +252,74 @@ pub async fn proceed() { poll_fn(|cx| poll_proceed(cx)).await; } +pin_project_lite::pin_project! { + /// A future that cooperatively yields to the task scheduler when polling, + /// if the task's budget is exhausted. + /// + /// Internally, this is simply a future combinator which calls + /// [`poll_proceed`] in its `poll` implementation before polling the wrapped + /// future. + /// + /// # Examples + /// + /// ```rust,ignore + /// # #[tokio::main] + /// # async fn main() { + /// use tokio::coop::CoopFutureExt; + /// + /// async { /* ... */ } + /// .cooperate() + /// .await; + /// # } + /// ``` + /// + /// [`poll_proceed`]: fn.poll_proceed.html + #[derive(Debug)] + #[allow(unreachable_pub, dead_code)] + pub struct CoopFuture { + #[pin] + future: F, + } +} + +impl Future for CoopFuture { + type Output = F::Output; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + ready!(poll_proceed(cx)); + self.project().future.poll(cx) + } +} + +impl CoopFuture { + /// Returns a new `CoopFuture` wrapping the given future. + /// + #[allow(unreachable_pub, dead_code)] + pub fn new(future: F) -> Self { + Self { future } + } +} + +// Currently only used by `tokio::sync`; and if we make this combinator public, +// it should probably be on the `FutureExt` trait instead. +cfg_sync! { + /// Extension trait providing `Future::cooperate` extension method. + /// + /// Note: if/when the co-op API becomes public, this method should probably be + /// provided by `FutureExt`, instead. + pub(crate) trait CoopFutureExt: Future { + /// Wrap `self` to cooperatively yield to the scheduler when polling, if the + /// task's budget is exhausted. + fn cooperate(self) -> CoopFuture + where + Self: Sized, + { + CoopFuture::new(self) + } + } + + impl CoopFutureExt for F where F: Future {} +} + #[cfg(all(test, not(loom)))] mod test { use super::*; diff --git a/tokio/src/sync/batch_semaphore.rs b/tokio/src/sync/batch_semaphore.rs new file mode 100644 index 00000000000..d89ac6ab885 --- /dev/null +++ b/tokio/src/sync/batch_semaphore.rs @@ -0,0 +1,553 @@ +//! # Implementation Details +//! +//! The semaphore is implemented using an intrusive linked list of waiters. An +//! atomic counter tracks the number of available permits. If the semaphore does +//! not contain the required number of permits, the task attempting to acquire +//! permits places its waker at the end of a queue. When new permits are made +//! available (such as by releasing an initial acquisition), they are assigned +//! to the task at the front of the queue, waking that task if its requested +//! number of permits is met. +//! +//! Because waiters are enqueued at the back of the linked list and dequeued +//! from the front, the semaphore is fair. Tasks trying to acquire large numbers +//! of permits at a time will always be woken eventually, even if many other +//! tasks are acquiring smaller numbers of permits. This means that in a +//! use-case like tokio's read-write lock, writers will not be starved by +//! readers. +use crate::loom::cell::CausalCell; +use crate::loom::sync::{atomic::AtomicUsize, Mutex, MutexGuard}; +use crate::util::linked_list::{self, LinkedList}; + +use std::future::Future; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::ptr::NonNull; +use std::sync::atomic::Ordering::*; +use std::task::Poll::*; +use std::task::{Context, Poll, Waker}; +use std::{cmp, fmt}; + +/// An asynchronous counting semaphore which permits waiting on multiple permits at once. +pub(crate) struct Semaphore { + waiters: Mutex, + /// The current number of available permits in the semaphore. + permits: AtomicUsize, +} + +struct Waitlist { + queue: LinkedList, + closed: bool, +} + +/// Error returned by `Semaphore::try_acquire`. +#[derive(Debug)] +pub(crate) enum TryAcquireError { + Closed, + NoPermits, +} +/// Error returned by `Semaphore::acquire`. +#[derive(Debug)] +pub(crate) struct AcquireError(()); + +pub(crate) struct Acquire<'a> { + node: Waiter, + semaphore: &'a Semaphore, + num_permits: u16, + queued: bool, +} + +/// An entry in the wait queue. +struct Waiter { + /// The current state of the waiter. + /// + /// This is either the number of remaining permits required by + /// the waiter, or a flag indicating that the waiter is not yet queued. + state: AtomicUsize, + + /// The waker to notify the task awaiting permits. + /// + /// # Safety + /// + /// This may only be accessed while the wait queue is locked. + waker: CausalCell>, + + /// Intrusive linked-list pointers. + /// + /// # Safety + /// + /// This may only be accessed while the wait queue is locked. + /// + /// TODO: Ideally, we would be able to use loom to enforce that + /// this isn't accessed concurrently. However, it is difficult to + /// use a `CausalCell` here, since the `Link` trait requires _returning_ + /// references to `Pointers`, and `CausalCell` requires that checked access + /// take place inside a closure. We should consider changing `Pointers` to + /// use `CausalCell` internally. + pointers: linked_list::Pointers, + + /// Should not be `Unpin`. + _p: PhantomPinned, +} + +impl Semaphore { + /// The maximum number of permits which a semaphore can hold. + /// + /// Note that this reserves three bits of flags in the permit counter, but + /// we only actually use one of them. However, the previous semaphore + /// implementation used three bits, so we will continue to reserve them to + /// avoid a breaking change if additional flags need to be aadded in the + /// future. + pub(crate) const MAX_PERMITS: usize = std::usize::MAX >> 3; + const CLOSED: usize = 1; + const PERMIT_SHIFT: usize = 1; + + /// Creates a new semaphore with the initial number of permits + pub(crate) fn new(permits: usize) -> Self { + assert!( + permits <= Self::MAX_PERMITS, + "a semaphore may not have more than MAX_PERMITS permits ({})", + Self::MAX_PERMITS + ); + Self { + permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT), + waiters: Mutex::new(Waitlist { + queue: LinkedList::new(), + closed: false, + }), + } + } + + /// Returns the current number of available permits + pub(crate) fn available_permits(&self) -> usize { + self.permits.load(Acquire) >> Self::PERMIT_SHIFT + } + + /// Adds `n` new permits to the semaphore. + pub(crate) fn release(&self, added: usize) { + if added == 0 { + return; + } + + // Assign permits to the wait queue, returning a list containing all the + // waiters at the back of the queue that received enough permits to wake + // up. + let notified = self.add_permits_locked(added, self.waiters.lock().unwrap()); + + // Once we release the lock, notify all woken waiters. + notify_all(notified); + } + + /// Closes the semaphore. This prevents the semaphore from issuing new + /// permits and notifies all pending waiters. + // This will be used once the bounded MPSC is updated to use the new + // semaphore implementation. + #[allow(dead_code)] + pub(crate) fn close(&self) { + let notified = { + let mut waiters = self.waiters.lock().unwrap(); + // If the semaphore's permits counter has enough permits for an + // unqueued waiter to acquire all the permits it needs immediately, + // it won't touch the wait list. Therefore, we have to set a bit on + // the permit counter as well. However, we must do this while + // holding the lock --- otherwise, if we set the bit and then wait + // to acquire the lock we'll enter an inconsistent state where the + // permit counter is closed, but the wait list is not. + self.permits.fetch_or(Self::CLOSED, Release); + waiters.closed = true; + waiters.queue.take_all() + }; + notify_all(notified) + } + + pub(crate) fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> { + let mut curr = self.permits.load(Acquire); + let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT; + loop { + // Has the semaphore closed?git + if curr & Self::CLOSED > 0 { + return Err(TryAcquireError::Closed); + } + + // Are there enough permits remaining? + if curr < num_permits { + return Err(TryAcquireError::NoPermits); + } + + let next = curr - num_permits; + + match self.permits.compare_exchange(curr, next, AcqRel, Acquire) { + Ok(_) => return Ok(()), + Err(actual) => curr = actual, + } + } + } + + pub(crate) fn acquire(&self, num_permits: u16) -> Acquire<'_> { + Acquire::new(self, num_permits) + } + + /// Release `rem` permits to the semaphore's wait list, starting from the + /// end of the queue. + /// + /// This returns a new `LinkedList` containing all the waiters that received + /// enough permits to be notified. Once the lock on the wait list is + /// released, this list should be drained and the waiters in it notified. + /// + /// If `rem` exceeds the number of permits needed by the wait list, the + /// remainder are assigned back to the semaphore. + fn add_permits_locked( + &self, + mut rem: usize, + mut waiters: MutexGuard<'_, Waitlist>, + ) -> LinkedList { + // Starting from the back of the wait queue, assign each waiter as many + // permits as it needs until we run out of permits to assign. + let mut last = None; + for waiter in waiters.queue.iter().rev() { + // Was the waiter assigned enough permits to wake it? + if !waiter.assign_permits(&mut rem) { + break; + } + last = Some(NonNull::from(waiter)); + } + + // If we assigned permits to all the waiters in the queue, and there are + // still permits left over, assign them back to the semaphore. + if rem > 0 { + let permits = rem << Self::PERMIT_SHIFT; + assert!( + permits < Self::MAX_PERMITS, + "cannot add more than MAX_PERMITS permits ({})", + Self::MAX_PERMITS + ); + let prev = self.permits.fetch_add(rem << Self::PERMIT_SHIFT, Release); + assert!( + prev + permits <= Self::MAX_PERMITS, + "number of added permits ({}) would overflow MAX_PERMITS ({})", + rem, + Self::MAX_PERMITS + ); + } + + // Split off the queue at the last waiter that was satisfied, creating a + // new list. Once we release the lock, we'll drain this list and notify + // the waiters in it. + if let Some(waiter) = last { + // Safety: it's only safe to call `split_back` with a pointer to a + // node in the same list as the one we call `split_back` on. Since + // we got the waiter pointer from the list's iterator, this is fine. + unsafe { waiters.queue.split_back(waiter) } + } else { + LinkedList::new() + } + } + + fn poll_acquire( + &self, + cx: &mut Context<'_>, + num_permits: u16, + node: Pin<&mut Waiter>, + queued: bool, + ) -> Poll> { + let mut acquired = 0; + + let needed = if queued { + node.state.load(Acquire) << Self::PERMIT_SHIFT + } else { + (num_permits as usize) << Self::PERMIT_SHIFT + }; + + let mut lock = None; + // First, try to take the requested number of permits from the + // semaphore. + let mut curr = self.permits.load(Acquire); + let mut waiters = loop { + // Has the semaphore closed? + if curr & Self::CLOSED > 0 { + return Ready(Err(AcquireError::closed())); + } + + let mut remaining = 0; + let total = curr + .checked_add(acquired) + .expect("number of permits must not overflow"); + let (next, acq) = if total >= needed { + let next = curr - (needed - acquired); + (next, needed >> Self::PERMIT_SHIFT) + } else { + remaining = (needed - acquired) - curr; + (0, curr >> Self::PERMIT_SHIFT) + }; + + if remaining > 0 && lock.is_none() { + // No permits were immediately available, so this permit will + // (probably) need to wait. We'll need to acquire a lock on the + // wait queue before continuing. We need to do this _before_ the + // CAS that sets the new value of the semaphore's `permits` + // counter. Otherwise, if we subtract the permits and then + // acquire the lock, we might miss additional permits being + // added while waiting for the lock. + lock = Some(self.waiters.lock().unwrap()); + } + + match self.permits.compare_exchange(curr, next, AcqRel, Acquire) { + Ok(_) => { + acquired += acq; + if remaining == 0 { + if !queued { + return Ready(Ok(())); + } else if lock.is_none() { + break self.waiters.lock().unwrap(); + } + } + break lock.expect("lock must be acquired before waiting"); + } + Err(actual) => curr = actual, + } + }; + + if waiters.closed { + return Ready(Err(AcquireError::closed())); + } + + if node.assign_permits(&mut acquired) { + self.add_permits_locked(acquired, waiters); + return Ready(Ok(())); + } + + assert_eq!(acquired, 0); + + // Otherwise, register the waker & enqueue the node. + node.waker.with_mut(|waker| { + // Safety: the wait list is locked, so we may modify the waker. + let waker = unsafe { &mut *waker }; + // Do we need to register the new waker? + if waker + .as_ref() + .map(|waker| !waker.will_wake(cx.waker())) + .unwrap_or(true) + { + *waker = Some(cx.waker().clone()); + } + }); + + // If the waiter is not already in the wait queue, enqueue it. + if !queued { + let node = unsafe { + let node = Pin::into_inner_unchecked(node) as *mut _; + NonNull::new_unchecked(node) + }; + + waiters.queue.push_front(node); + } + + Pending + } +} + +impl fmt::Debug for Semaphore { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Semaphore") + .field("permits", &self.permits.load(Relaxed)) + .finish() + } +} + +/// Pop all waiters from `list`, starting at the end of the queue, and notify +/// them. +fn notify_all(mut list: LinkedList) { + while let Some(waiter) = list.pop_back() { + let waker = unsafe { waiter.as_ref().waker.with_mut(|waker| (*waker).take()) }; + + waker + .expect("if a node is in the wait list, it must have a waker") + .wake(); + } +} + +impl Waiter { + fn new(num_permits: u16) -> Self { + Waiter { + waker: CausalCell::new(None), + state: AtomicUsize::new(num_permits as usize), + pointers: linked_list::Pointers::new(), + _p: PhantomPinned, + } + } + + /// Assign permits to the waiter. + /// + /// Returns `true` if the waiter should be removed from the queue + fn assign_permits(&self, n: &mut usize) -> bool { + let mut curr = self.state.load(Acquire); + loop { + let assign = cmp::min(curr, *n); + let next = curr - assign; + match self.state.compare_exchange(curr, next, AcqRel, Acquire) { + Ok(_) => { + *n -= assign; + return next == 0; + } + Err(actual) => curr = actual, + } + } + } +} + +impl Future for Acquire<'_> { + type Output = Result<(), AcquireError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let (node, semaphore, needed, queued) = self.project(); + match semaphore.poll_acquire(cx, needed, node, *queued) { + Pending => { + *queued = true; + Pending + } + Ready(r) => { + r?; + *queued = false; + Ready(Ok(())) + } + } + } +} + +impl<'a> Acquire<'a> { + fn new(semaphore: &'a Semaphore, num_permits: u16) -> Self { + Self { + node: Waiter::new(num_permits), + semaphore, + num_permits, + queued: false, + } + } + + fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, u16, &mut bool) { + fn is_unpin() {} + unsafe { + // Safety: all fields other than `node` are `Unpin` + + is_unpin::<&Semaphore>(); + is_unpin::<&mut bool>(); + is_unpin::(); + + let this = self.get_unchecked_mut(); + ( + Pin::new_unchecked(&mut this.node), + &this.semaphore, + this.num_permits, + &mut this.queued, + ) + } + } +} + +impl Drop for Acquire<'_> { + fn drop(&mut self) { + // If the future is completed, there is no node in the wait list, so we + // can skip acquiring the lock. + if !self.queued { + return; + } + + // This is where we ensure safety. The future is being dropped, + // which means we must ensure that the waiter entry is no longer stored + // in the linked list. + let mut waiters = match self.semaphore.waiters.lock() { + Ok(lock) => lock, + // Removing the node from the linked list is necessary to ensure + // safety. Even if the lock was poisoned, we need to make sure it is + // removed from the linked list before dropping it --- otherwise, + // the list will contain a dangling pointer to this node. + Err(e) => e.into_inner(), + }; + + // remove the entry from the list + let node = NonNull::from(&mut self.node); + // Safety: we have locked the wait list. + unsafe { waiters.queue.remove(node) }; + + let acquired_permits = self.num_permits as usize - self.node.state.load(Acquire); + if acquired_permits > 0 { + let notified = self.semaphore.add_permits_locked(acquired_permits, waiters); + notify_all(notified); + } + } +} + +// ===== impl AcquireError ==== + +impl AcquireError { + fn closed() -> AcquireError { + AcquireError(()) + } +} + +impl fmt::Display for AcquireError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "semaphore closed") + } +} + +impl std::error::Error for AcquireError {} + +// ===== impl TryAcquireError ===== + +impl TryAcquireError { + /// Returns `true` if the error was caused by a closed semaphore. + #[allow(dead_code)] // may be used later! + pub(crate) fn is_closed(&self) -> bool { + match self { + TryAcquireError::Closed => true, + _ => false, + } + } + + /// Returns `true` if the error was caused by calling `try_acquire` on a + /// semaphore with no available permits. + #[allow(dead_code)] // may be used later! + pub(crate) fn is_no_permits(&self) -> bool { + match self { + TryAcquireError::NoPermits => true, + _ => false, + } + } +} + +impl fmt::Display for TryAcquireError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TryAcquireError::Closed => write!(fmt, "{}", "semaphore closed"), + TryAcquireError::NoPermits => write!(fmt, "{}", "no permits available"), + } + } +} + +impl std::error::Error for TryAcquireError {} + +/// # Safety +/// +/// `Waiter` is forced to be !Unpin. +unsafe impl linked_list::Link for Waiter { + // XXX: ideally, we would be able to use `Pin` here, to enforce the + // invariant that list entries may not move while in the list. However, we + // can't do this currently, as using `Pin<&'a mut Waiter>` as the `Handle` + // type would require `Semaphore` to be generic over a lifetime. We can't + // use `Pin<*mut Waiter>`, as raw pointers are `Unpin` regardless of whether + // or not they dereference to an `!Unpin` target. + type Handle = NonNull; + type Target = Waiter; + + fn as_raw(handle: &Self::Handle) -> NonNull { + *handle + } + + unsafe fn from_raw(ptr: NonNull) -> NonNull { + ptr + } + + unsafe fn pointers(mut target: NonNull) -> NonNull> { + NonNull::from(&mut target.as_mut().pointers) + } +} diff --git a/tokio/src/sync/mod.rs b/tokio/src/sync/mod.rs index 5d7b29aed99..0607f78ad42 100644 --- a/tokio/src/sync/mod.rs +++ b/tokio/src/sync/mod.rs @@ -435,6 +435,7 @@ cfg_sync! { pub mod oneshot; + pub(crate) mod batch_semaphore; pub(crate) mod semaphore_ll; mod semaphore; pub use semaphore::{Semaphore, SemaphorePermit}; diff --git a/tokio/src/sync/mutex.rs b/tokio/src/sync/mutex.rs index 4aceb000d1f..dac5ac16c42 100644 --- a/tokio/src/sync/mutex.rs +++ b/tokio/src/sync/mutex.rs @@ -78,9 +78,8 @@ //! //! [`Mutex`]: struct.Mutex.html //! [`MutexGuard`]: struct.MutexGuard.html - -use crate::future::poll_fn; -use crate::sync::semaphore_ll as semaphore; +use crate::coop::CoopFutureExt; +use crate::sync::batch_semaphore as semaphore; use std::cell::UnsafeCell; use std::error::Error; @@ -108,7 +107,6 @@ pub struct Mutex { /// will succeed yet again. pub struct MutexGuard<'a, T> { lock: &'a Mutex, - permit: semaphore::Permit, } // As long as T: Send, it's fine to send and share Mutex between threads. @@ -137,8 +135,10 @@ impl Error for TryLockError {} #[test] #[cfg(not(loom))] fn bounds() { - fn check() {} - check::>(); + fn check_send() {} + fn check_unpin() {} + check_send::>(); + check_unpin::>(); } impl Mutex { @@ -152,30 +152,18 @@ impl Mutex { /// A future that resolves on acquiring the lock and returns the `MutexGuard`. pub async fn lock(&self) -> MutexGuard<'_, T> { - let mut guard = MutexGuard { - lock: self, - permit: semaphore::Permit::new(), - }; - poll_fn(|cx| { - // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); - - guard.permit.poll_acquire(cx, 1, &self.s) - }) - .await - .unwrap_or_else(|_| { + self.s.acquire(1).cooperate().await.unwrap_or_else(|_| { // The semaphore was closed. but, we never explicitly close it, and we have a // handle to it through the Arc, which means that this can never happen. unreachable!() }); - guard + MutexGuard { lock: self } } /// Tries to acquire the lock pub fn try_lock(&self) -> Result, TryLockError> { - let mut permit = semaphore::Permit::new(); - match permit.try_acquire(1, &self.s) { - Ok(_) => Ok(MutexGuard { lock: self, permit }), + match self.s.try_acquire(1) { + Ok(_) => Ok(MutexGuard { lock: self }), Err(_) => Err(TryLockError(())), } } @@ -188,7 +176,7 @@ impl Mutex { impl<'a, T> Drop for MutexGuard<'a, T> { fn drop(&mut self) { - self.permit.release(1, &self.lock.s); + self.lock.s.release(1) } } @@ -210,14 +198,12 @@ where impl<'a, T> Deref for MutexGuard<'a, T> { type Target = T; fn deref(&self) -> &Self::Target { - assert!(self.permit.is_acquired()); unsafe { &*self.lock.c.get() } } } impl<'a, T> DerefMut for MutexGuard<'a, T> { fn deref_mut(&mut self) -> &mut Self::Target { - assert!(self.permit.is_acquired()); unsafe { &mut *self.lock.c.get() } } } diff --git a/tokio/src/sync/rwlock.rs b/tokio/src/sync/rwlock.rs index 97921b9fded..7cce69a5c5d 100644 --- a/tokio/src/sync/rwlock.rs +++ b/tokio/src/sync/rwlock.rs @@ -1,8 +1,7 @@ -use crate::future::poll_fn; -use crate::sync::semaphore_ll::{AcquireError, Permit, Semaphore}; +use crate::coop::CoopFutureExt; +use crate::sync::batch_semaphore::{AcquireError, Semaphore}; use std::cell::UnsafeCell; use std::ops; -use std::task::{Context, Poll}; #[cfg(not(loom))] const MAX_READS: usize = 32; @@ -109,29 +108,42 @@ pub struct RwLockWriteGuard<'a, T> { #[derive(Debug)] struct ReleasingPermit<'a, T> { num_permits: u16, - permit: Permit, lock: &'a RwLock, } impl<'a, T> ReleasingPermit<'a, T> { - fn poll_acquire( - &mut self, - cx: &mut Context<'_>, - s: &Semaphore, - ) -> Poll> { - // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); - - self.permit.poll_acquire(cx, self.num_permits, s) + async fn acquire( + lock: &'a RwLock, + num_permits: u16, + ) -> Result, AcquireError> { + lock.s.acquire(num_permits).cooperate().await?; + Ok(Self { num_permits, lock }) } } impl<'a, T> Drop for ReleasingPermit<'a, T> { fn drop(&mut self) { - self.permit.release(self.num_permits, &self.lock.s); + self.lock.s.release(self.num_permits as usize); } } +#[test] +#[cfg(not(loom))] +fn bounds() { + fn check_send() {} + fn check_sync() {} + fn check_unpin() {} + check_send::>(); + check_sync::>(); + check_unpin::>(); + + check_sync::>(); + check_unpin::>(); + + check_sync::>(); + check_unpin::>(); +} + // As long as T: Send + Sync, it's fine to send and share RwLock between threads. // If T were not Send, sending and sharing a RwLock would be bad, since you can access T through // RwLock. @@ -189,19 +201,11 @@ impl RwLock { ///} /// ``` pub async fn read(&self) -> RwLockReadGuard<'_, T> { - let mut permit = ReleasingPermit { - num_permits: 1, - permit: Permit::new(), - lock: self, - }; - - poll_fn(|cx| permit.poll_acquire(cx, &self.s)) - .await - .unwrap_or_else(|_| { - // The semaphore was closed. but, we never explicitly close it, and we have a - // handle to it through the Arc, which means that this can never happen. - unreachable!() - }); + let permit = ReleasingPermit::acquire(self, 1).await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); RwLockReadGuard { lock: self, permit } } @@ -228,13 +232,7 @@ impl RwLock { ///} /// ``` pub async fn write(&self) -> RwLockWriteGuard<'_, T> { - let mut permit = ReleasingPermit { - num_permits: MAX_READS as u16, - permit: Permit::new(), - lock: self, - }; - - poll_fn(|cx| permit.poll_acquire(cx, &self.s)) + let permit = ReleasingPermit::acquire(self, MAX_READS as u16) .await .unwrap_or_else(|_| { // The semaphore was closed. but, we never explicitly close it, and we have a diff --git a/tokio/src/sync/semaphore.rs b/tokio/src/sync/semaphore.rs index ec43bc522b0..e34e49cc7fe 100644 --- a/tokio/src/sync/semaphore.rs +++ b/tokio/src/sync/semaphore.rs @@ -1,5 +1,5 @@ -use super::semaphore_ll as ll; // low level implementation -use crate::future::poll_fn; +use super::batch_semaphore as ll; // low level implementation +use crate::coop::CoopFutureExt; /// Counting semaphore performing asynchronous permit aquisition. /// @@ -23,8 +23,7 @@ pub struct Semaphore { #[derive(Debug)] pub struct SemaphorePermit<'a> { sem: &'a Semaphore, - // the low level permit - ll_permit: ll::Permit, + permits: u16, } /// Error returned from the [`Semaphore::try_acquire`] function. @@ -36,6 +35,14 @@ pub struct SemaphorePermit<'a> { #[derive(Debug)] pub struct TryAcquireError(()); +#[test] +#[cfg(not(loom))] +fn bounds() { + fn check_unpin() {} + check_unpin::(); + check_unpin::>(); +} + impl Semaphore { /// Creates a new semaphore with the initial number of permits pub fn new(permits: usize) -> Self { @@ -51,33 +58,24 @@ impl Semaphore { /// Adds `n` new permits to the semaphore. pub fn add_permits(&self, n: usize) { - self.ll_sem.add_permits(n); + self.ll_sem.release(n); } /// Acquires permit from the semaphore pub async fn acquire(&self) -> SemaphorePermit<'_> { - let mut permit = SemaphorePermit { + self.ll_sem.acquire(1).cooperate().await.unwrap(); + SemaphorePermit { sem: &self, - ll_permit: ll::Permit::new(), - }; - poll_fn(|cx| { - // Keep track of task budget - ready!(crate::coop::poll_proceed(cx)); - - permit.ll_permit.poll_acquire(cx, 1, &self.ll_sem) - }) - .await - .unwrap(); - permit + permits: 1, + } } /// Tries to acquire a permit form the semaphore pub fn try_acquire(&self) -> Result, TryAcquireError> { - let mut ll_permit = ll::Permit::new(); - match ll_permit.try_acquire(1, &self.ll_sem) { + match self.ll_sem.try_acquire(1) { Ok(_) => Ok(SemaphorePermit { sem: self, - ll_permit, + permits: 1, }), Err(_) => Err(TryAcquireError(())), } @@ -89,12 +87,12 @@ impl<'a> SemaphorePermit<'a> { /// This can be used to reduce the amount of permits available from a /// semaphore. pub fn forget(mut self) { - self.ll_permit.forget(1); + self.permits = 0; } } impl<'a> Drop for SemaphorePermit<'_> { fn drop(&mut self) { - self.ll_permit.release(1, &self.sem.ll_sem); + self.sem.add_permits(self.permits as usize); } } diff --git a/tokio/src/sync/semaphore_ll.rs b/tokio/src/sync/semaphore_ll.rs index 69fd4a6a5d3..b56f21a8135 100644 --- a/tokio/src/sync/semaphore_ll.rs +++ b/tokio/src/sync/semaphore_ll.rs @@ -610,6 +610,7 @@ impl Permit { } /// Returns `true` if the permit has been acquired + #[allow(dead_code)] // may be used later pub(crate) fn is_acquired(&self) -> bool { match self.state { PermitState::Acquired(num) if num > 0 => true, diff --git a/tokio/src/sync/tests/loom_semaphore_batch.rs b/tokio/src/sync/tests/loom_semaphore_batch.rs new file mode 100644 index 00000000000..4c1936c5998 --- /dev/null +++ b/tokio/src/sync/tests/loom_semaphore_batch.rs @@ -0,0 +1,171 @@ +use crate::sync::batch_semaphore::*; + +use futures::future::poll_fn; +use loom::future::block_on; +use loom::sync::atomic::AtomicUsize; +use loom::thread; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::Ordering::SeqCst; +use std::sync::Arc; +use std::task::Poll::Ready; +use std::task::{Context, Poll}; + +#[test] +fn basic_usage() { + const NUM: usize = 2; + + struct Shared { + semaphore: Semaphore, + active: AtomicUsize, + } + + async fn actor(shared: Arc) { + shared.semaphore.acquire(1).await.unwrap(); + let actual = shared.active.fetch_add(1, SeqCst); + assert!(actual <= NUM - 1); + + let actual = shared.active.fetch_sub(1, SeqCst); + assert!(actual <= NUM); + shared.semaphore.release(1); + } + + loom::model(|| { + let shared = Arc::new(Shared { + semaphore: Semaphore::new(NUM), + active: AtomicUsize::new(0), + }); + + for _ in 0..NUM { + let shared = shared.clone(); + + thread::spawn(move || { + block_on(actor(shared)); + }); + } + + block_on(actor(shared)); + }); +} + +#[test] +fn release() { + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(1)); + + { + let semaphore = semaphore.clone(); + thread::spawn(move || { + block_on(semaphore.acquire(1)).unwrap(); + semaphore.release(1); + }); + } + + block_on(semaphore.acquire(1)).unwrap(); + + semaphore.release(1); + }); +} + +#[test] +fn basic_closing() { + const NUM: usize = 2; + + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(1)); + + for _ in 0..NUM { + let semaphore = semaphore.clone(); + + thread::spawn(move || { + for _ in 0..2 { + block_on(semaphore.acquire(1)).map_err(|_| ())?; + + semaphore.release(1); + } + + Ok::<(), ()>(()) + }); + } + + semaphore.close(); + }); +} + +#[test] +fn concurrent_close() { + const NUM: usize = 3; + + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(1)); + + for _ in 0..NUM { + let semaphore = semaphore.clone(); + + thread::spawn(move || { + block_on(semaphore.acquire(1)).map_err(|_| ())?; + semaphore.release(1); + semaphore.close(); + + Ok::<(), ()>(()) + }); + } + }); +} + +#[test] +fn batch() { + let mut b = loom::model::Builder::new(); + b.preemption_bound = Some(1); + + b.check(|| { + let semaphore = Arc::new(Semaphore::new(10)); + let active = Arc::new(AtomicUsize::new(0)); + let mut ths = vec![]; + + for _ in 0..2 { + let semaphore = semaphore.clone(); + let active = active.clone(); + + ths.push(thread::spawn(move || { + for n in &[4, 10, 8] { + block_on(semaphore.acquire(*n)).unwrap(); + + active.fetch_add(*n as usize, SeqCst); + + let num_active = active.load(SeqCst); + assert!(num_active <= 10); + + thread::yield_now(); + + active.fetch_sub(*n as usize, SeqCst); + + semaphore.release(*n as usize); + } + })); + } + + for th in ths.into_iter() { + th.join().unwrap(); + } + + assert_eq!(10, semaphore.available_permits()); + }); +} + +#[test] +fn release_during_acquire() { + loom::model(|| { + let semaphore = Arc::new(Semaphore::new(10)); + semaphore + .try_acquire(8) + .expect("try_acquire should succeed; semaphore uncontended"); + let semaphore2 = semaphore.clone(); + let thread = thread::spawn(move || block_on(semaphore2.acquire(4)).unwrap()); + + semaphore.release(8); + thread.join().unwrap(); + semaphore.release(4); + assert_eq!(10, semaphore.available_permits()); + }) +} diff --git a/tokio/src/sync/tests/mod.rs b/tokio/src/sync/tests/mod.rs index 7225ce9c58c..d571754c011 100644 --- a/tokio/src/sync/tests/mod.rs +++ b/tokio/src/sync/tests/mod.rs @@ -1,6 +1,7 @@ cfg_not_loom! { mod atomic_waker; mod semaphore_ll; + mod semaphore_batch; } cfg_loom! { @@ -10,5 +11,6 @@ cfg_loom! { mod loom_mpsc; mod loom_notify; mod loom_oneshot; + mod loom_semaphore_batch; mod loom_semaphore_ll; } diff --git a/tokio/src/sync/tests/semaphore_batch.rs b/tokio/src/sync/tests/semaphore_batch.rs new file mode 100644 index 00000000000..60f3f231e76 --- /dev/null +++ b/tokio/src/sync/tests/semaphore_batch.rs @@ -0,0 +1,250 @@ +use crate::sync::batch_semaphore::Semaphore; +use tokio_test::*; + +#[test] +fn poll_acquire_one_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + // Polling for a permit succeeds immediately + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + assert_eq!(s.available_permits(), 99); +} + +#[test] +fn poll_acquire_many_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + // Polling for a permit succeeds immediately + assert_ready_ok!(task::spawn(s.acquire(5)).poll()); + assert_eq!(s.available_permits(), 95); + + assert_ready_ok!(task::spawn(s.acquire(5)).poll()); + assert_eq!(s.available_permits(), 90); +} + +#[test] +fn try_acquire_one_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + assert_ok!(s.try_acquire(1)); + assert_eq!(s.available_permits(), 99); + + assert_ok!(s.try_acquire(1)); + assert_eq!(s.available_permits(), 98); +} + +#[test] +fn try_acquire_many_available() { + let s = Semaphore::new(100); + assert_eq!(s.available_permits(), 100); + + assert_ok!(s.try_acquire(5)); + assert_eq!(s.available_permits(), 95); + + assert_ok!(s.try_acquire(5)); + assert_eq!(s.available_permits(), 90); +} + +#[test] +fn poll_acquire_one_unavailable() { + let s = Semaphore::new(1); + + // Acquire the first permit + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + assert_eq!(s.available_permits(), 0); + + let mut acquire_2 = task::spawn(s.acquire(1)); + // Try to acquire the second permit + assert_pending!(acquire_2.poll()); + assert_eq!(s.available_permits(), 0); + + s.release(1); + + assert_eq!(s.available_permits(), 0); + assert!(acquire_2.is_woken()); + assert_ready_ok!(acquire_2.poll()); + assert_eq!(s.available_permits(), 0); + + s.release(1); + assert_eq!(s.available_permits(), 1); +} + +#[test] +fn poll_acquire_many_unavailable() { + let s = Semaphore::new(5); + + // Acquire the first permit + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + assert_eq!(s.available_permits(), 4); + + // Try to acquire the second permit + let mut acquire_2 = task::spawn(s.acquire(5)); + assert_pending!(acquire_2.poll()); + assert_eq!(s.available_permits(), 0); + + // Try to acquire the third permit + let mut acquire_3 = task::spawn(s.acquire(3)); + assert_pending!(acquire_3.poll()); + assert_eq!(s.available_permits(), 0); + + s.release(1); + + assert_eq!(s.available_permits(), 0); + assert!(acquire_2.is_woken()); + assert_ready_ok!(acquire_2.poll()); + + assert!(!acquire_3.is_woken()); + assert_eq!(s.available_permits(), 0); + + s.release(1); + assert!(!acquire_3.is_woken()); + assert_eq!(s.available_permits(), 0); + + s.release(2); + assert!(acquire_3.is_woken()); + + assert_ready_ok!(acquire_3.poll()); +} + +#[test] +fn try_acquire_one_unavailable() { + let s = Semaphore::new(1); + + // Acquire the first permit + assert_ok!(s.try_acquire(1)); + assert_eq!(s.available_permits(), 0); + + assert_err!(s.try_acquire(1)); + + s.release(1); + + assert_eq!(s.available_permits(), 1); + assert_ok!(s.try_acquire(1)); + + s.release(1); + assert_eq!(s.available_permits(), 1); +} + +#[test] +fn try_acquire_many_unavailable() { + let s = Semaphore::new(5); + + // Acquire the first permit + assert_ok!(s.try_acquire(1)); + assert_eq!(s.available_permits(), 4); + + assert_err!(s.try_acquire(5)); + + s.release(1); + assert_eq!(s.available_permits(), 5); + + assert_ok!(s.try_acquire(5)); + + s.release(1); + assert_eq!(s.available_permits(), 1); + + s.release(1); + assert_eq!(s.available_permits(), 2); +} + +#[test] +fn poll_acquire_one_zero_permits() { + let s = Semaphore::new(0); + assert_eq!(s.available_permits(), 0); + + // Try to acquire the permit + let mut acquire = task::spawn(s.acquire(1)); + assert_pending!(acquire.poll()); + + s.release(1); + + assert!(acquire.is_woken()); + assert_ready_ok!(acquire.poll()); +} + +#[test] +#[should_panic] +fn validates_max_permits() { + use std::usize; + Semaphore::new((usize::MAX >> 2) + 1); +} + +#[test] +fn close_semaphore_prevents_acquire() { + let s = Semaphore::new(5); + s.close(); + + assert_eq!(5, s.available_permits()); + + assert_ready_err!(task::spawn(s.acquire(1)).poll()); + assert_eq!(5, s.available_permits()); + + assert_ready_err!(task::spawn(s.acquire(1)).poll()); + assert_eq!(5, s.available_permits()); +} + +#[test] +fn close_semaphore_notifies_permit1() { + let s = Semaphore::new(0); + let mut acquire = task::spawn(s.acquire(1)); + + assert_pending!(acquire.poll()); + + s.close(); + + assert!(acquire.is_woken()); + assert_ready_err!(acquire.poll()); +} + +#[test] +fn close_semaphore_notifies_permit2() { + let s = Semaphore::new(2); + + // Acquire a couple of permits + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + assert_ready_ok!(task::spawn(s.acquire(1)).poll()); + + let mut acquire3 = task::spawn(s.acquire(1)); + let mut acquire4 = task::spawn(s.acquire(1)); + assert_pending!(acquire3.poll()); + assert_pending!(acquire4.poll()); + + s.close(); + + assert!(acquire3.is_woken()); + assert!(acquire4.is_woken()); + + assert_ready_err!(acquire3.poll()); + assert_ready_err!(acquire4.poll()); + + assert_eq!(0, s.available_permits()); + + s.release(1); + + assert_eq!(1, s.available_permits()); + + assert_ready_err!(task::spawn(s.acquire(1)).poll()); + + s.release(1); + + assert_eq!(2, s.available_permits()); +} + +#[test] +fn cancel_acquire_releases_permits() { + let s = Semaphore::new(10); + let _permit1 = s.try_acquire(4).expect("uncontended try_acquire succeeds"); + assert_eq!(6, s.available_permits()); + + let mut acquire = task::spawn(s.acquire(8)); + assert_pending!(acquire.poll()); + + assert_eq!(0, s.available_permits()); + drop(acquire); + + assert_eq!(6, s.available_permits()); + assert_ok!(s.try_acquire(6)); +} diff --git a/tokio/src/util/linked_list.rs b/tokio/src/util/linked_list.rs index 07c25fe983a..1a48803273a 100644 --- a/tokio/src/util/linked_list.rs +++ b/tokio/src/util/linked_list.rs @@ -4,6 +4,7 @@ //! structure's APIs are `unsafe` as they require the caller to ensure the //! specified node is actually contained by the list. +use core::fmt; use core::mem::ManuallyDrop; use core::ptr::NonNull; @@ -11,7 +12,6 @@ use core::ptr::NonNull; /// /// Currently, the list is not emptied on drop. It is the caller's /// responsibility to ensure the list is empty before dropping it. -#[derive(Debug)] pub(crate) struct LinkedList { /// Linked list head head: Option>, @@ -53,7 +53,6 @@ pub(crate) unsafe trait Link { } /// Previous / next pointers -#[derive(Debug)] pub(crate) struct Pointers { /// The previous node in the list. null if there is no previous node. prev: Option>, @@ -81,7 +80,7 @@ impl LinkedList { // The value should not be dropped, it is being inserted into the list let val = ManuallyDrop::new(val); let ptr = T::as_raw(&*val); - + assert_ne!(self.head, Some(ptr)); unsafe { T::pointers(ptr).as_mut().next = self.head; T::pointers(ptr).as_mut().prev = None; @@ -165,32 +164,98 @@ impl LinkedList { } } -// ===== impl Iter ===== +cfg_sync! { + impl LinkedList { + /// Splits this list off at `node`, returning a new list with `node` at its + /// front. + /// + /// If `node` is at the the front of this list, then this list will be empty after + /// splitting. If `node` is the last node in this list, then the returned + /// list will contain only `node`. + /// + /// # Safety + /// + /// The caller **must** ensure that `node` is currently contained by + /// `self` or not contained by any other list. + pub(crate) unsafe fn split_back(&mut self, node: NonNull) -> Self { + let new_tail = T::pointers(node).as_mut().prev.take().map(|prev| { + T::pointers(prev).as_mut().next = None; + prev + }); + if new_tail.is_none() { + self.head = None; + } + let tail = std::mem::replace(&mut self.tail, new_tail); + Self { + head: Some(node), + tail, + } + } -cfg_rt_threaded! { - use core::marker::PhantomData; + /// Takes all entries from this list, returning a new list. + /// + /// This list will be left empty. + pub(crate) fn take_all(&mut self) -> Self { + Self { + head: self.head.take(), + tail: self.tail.take(), + } + } + } +} - pub(crate) struct Iter<'a, T: Link> { - curr: Option>, - _p: PhantomData<&'a T>, +impl fmt::Debug for LinkedList { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("LinkedList") + .field("head", &self.head) + .field("tail", &self.tail) + .finish() } +} - impl LinkedList { - pub(crate) fn iter(&self) -> Iter<'_, T> { - Iter { - curr: self.head, - _p: PhantomData, - } +// ===== impl Iter ===== + +#[cfg(any(feature = "sync", feature = "rt-threaded"))] +pub(crate) struct Iter<'a, T: Link> { + curr: Option>, + #[cfg(feature = "sync")] + curr_back: Option>, + _p: core::marker::PhantomData<&'a T>, +} + +#[cfg(any(feature = "sync", feature = "rt-threaded"))] +impl LinkedList { + pub(crate) fn iter(&self) -> Iter<'_, T> { + Iter { + curr: self.head, + #[cfg(feature = "sync")] + curr_back: self.tail, + _p: core::marker::PhantomData, } } +} - impl<'a, T: Link> Iterator for Iter<'a, T> { - type Item = &'a T::Target; +#[cfg(any(feature = "sync", feature = "rt-threaded"))] +impl<'a, T: Link> Iterator for Iter<'a, T> { + type Item = &'a T::Target; + + fn next(&mut self) -> Option<&'a T::Target> { + let curr = self.curr?; + // safety: the pointer references data contained by the list + self.curr = unsafe { T::pointers(curr).as_ref() }.next; + + // safety: the value is still owned by the linked list. + Some(unsafe { &*curr.as_ptr() }) + } +} + +cfg_sync! { + impl<'a, T: Link> DoubleEndedIterator for Iter<'a, T> { + fn next_back(&mut self) -> Option<&'a T::Target> { + let curr = self.curr_back?; - fn next(&mut self) -> Option<&'a T::Target> { - let curr = self.curr?; // safety: the pointer references data contained by the list - self.curr = unsafe { T::pointers(curr).as_ref() }.next; + self.curr_back = unsafe { T::pointers(curr).as_ref() }.prev; // safety: the value is still owned by the linked list. Some(unsafe { &*curr.as_ptr() }) @@ -210,6 +275,15 @@ impl Pointers { } } +impl fmt::Debug for Pointers { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Pointers") + .field("prev", &self.prev) + .field("next", &self.next) + .finish() + } +} + #[cfg(test)] #[cfg(not(loom))] mod tests { @@ -217,6 +291,7 @@ mod tests { use std::pin::Pin; + #[derive(Debug)] struct Entry { pointers: Pointers, val: i32, @@ -489,6 +564,86 @@ mod tests { assert!(i.next().is_none()); } + #[test] + fn split_back() { + let a = entry(1); + let b = entry(2); + let c = entry(3); + let d = entry(4); + + { + let mut list1 = LinkedList::<&Entry>::new(); + + push_all( + &mut list1, + &[a.as_ref(), b.as_ref(), c.as_ref(), d.as_ref()], + ); + let mut list2 = unsafe { list1.split_back(ptr(&a)) }; + + assert_eq!([2, 3, 4].to_vec(), collect_list(&mut list1)); + assert_eq!([1].to_vec(), collect_list(&mut list2)); + } + + { + let mut list1 = LinkedList::<&Entry>::new(); + + push_all( + &mut list1, + &[a.as_ref(), b.as_ref(), c.as_ref(), d.as_ref()], + ); + let mut list2 = unsafe { list1.split_back(ptr(&b)) }; + + assert_eq!([3, 4].to_vec(), collect_list(&mut list1)); + assert_eq!([1, 2].to_vec(), collect_list(&mut list2)); + } + + { + let mut list1 = LinkedList::<&Entry>::new(); + + push_all( + &mut list1, + &[a.as_ref(), b.as_ref(), c.as_ref(), d.as_ref()], + ); + let mut list2 = unsafe { list1.split_back(ptr(&c)) }; + + assert_eq!([4].to_vec(), collect_list(&mut list1)); + assert_eq!([1, 2, 3].to_vec(), collect_list(&mut list2)); + } + + { + let mut list1 = LinkedList::<&Entry>::new(); + + push_all( + &mut list1, + &[a.as_ref(), b.as_ref(), c.as_ref(), d.as_ref()], + ); + let mut list2 = unsafe { list1.split_back(ptr(&d)) }; + + assert_eq!(Vec::::new(), collect_list(&mut list1)); + assert_eq!([1, 2, 3, 4].to_vec(), collect_list(&mut list2)); + } + } + + #[test] + fn take_all() { + let mut list1 = LinkedList::<&Entry>::new(); + let a = entry(1); + let b = entry(2); + + list1.push_front(a.as_ref()); + list1.push_front(b.as_ref()); + + assert!(!list1.is_empty()); + + let mut list2 = list1.take_all(); + + assert!(list1.is_empty()); + assert!(!list2.is_empty()); + + assert_eq!(Vec::::new(), collect_list(&mut list1)); + assert_eq!([1, 2].to_vec(), collect_list(&mut list2)); + } + proptest::proptest! { #[test] fn fuzz_linked_list(ops: Vec) {