From c5f78cb997e803ff26398e9fc180cad7480c5de1 Mon Sep 17 00:00:00 2001 From: Alice Ryhl Date: Mon, 20 Apr 2020 20:21:21 +0200 Subject: [PATCH 1/2] sync: add owned semaphore permit --- tokio/src/sync/mod.rs | 2 +- tokio/src/sync/semaphore.rs | 70 +++++++++++++++++++++++++-- tokio/tests/sync_semaphore_owned.rs | 75 +++++++++++++++++++++++++++++ 3 files changed, 141 insertions(+), 6 deletions(-) create mode 100644 tokio/tests/sync_semaphore_owned.rs diff --git a/tokio/src/sync/mod.rs b/tokio/src/sync/mod.rs index 0607f78ad42..4003bb69446 100644 --- a/tokio/src/sync/mod.rs +++ b/tokio/src/sync/mod.rs @@ -438,7 +438,7 @@ cfg_sync! { pub(crate) mod batch_semaphore; pub(crate) mod semaphore_ll; mod semaphore; - pub use semaphore::{Semaphore, SemaphorePermit}; + pub use semaphore::{Semaphore, SemaphorePermit, OwnedSemaphorePermit}; mod rwlock; pub use rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard}; diff --git a/tokio/src/sync/semaphore.rs b/tokio/src/sync/semaphore.rs index 4cce7e8f5bc..c1dd975f282 100644 --- a/tokio/src/sync/semaphore.rs +++ b/tokio/src/sync/semaphore.rs @@ -1,5 +1,6 @@ use super::batch_semaphore as ll; // low level implementation use crate::coop::CoopFutureExt; +use std::sync::Arc; /// Counting semaphore performing asynchronous permit aquisition. /// @@ -18,7 +19,11 @@ pub struct Semaphore { ll_sem: ll::Semaphore, } -/// A permit from the semaphore +/// A permit from the semaphore. +/// +/// This type is created by the [`acquire`] method. +/// +/// [`acquire`]: crate::sync::Semaphore::acquire() #[must_use] #[derive(Debug)] pub struct SemaphorePermit<'a> { @@ -26,6 +31,18 @@ pub struct SemaphorePermit<'a> { permits: u16, } +/// An owned permit from the semaphore. +/// +/// This type is created by the [`acquire_owned`] method. +/// +/// [`acquire_owned`]: crate::sync::Semaphore::acquire_owned() +#[must_use] +#[derive(Debug)] +pub struct OwnedSemaphorePermit { + sem: Arc, + permits: u16, +} + /// Error returned from the [`Semaphore::try_acquire`] function. /// /// A `try_acquire` operation can only fail if the semaphore has no available @@ -51,14 +68,14 @@ fn bounds() { } impl Semaphore { - /// Creates a new semaphore with the initial number of permits + /// Creates a new semaphore with the initial number of permits. pub fn new(permits: usize) -> Self { Self { ll_sem: ll::Semaphore::new(permits), } } - /// Returns the current number of available permits + /// Returns the current number of available permits. pub fn available_permits(&self) -> usize { self.ll_sem.available_permits() } @@ -68,7 +85,7 @@ impl Semaphore { self.ll_sem.release(n); } - /// Acquires permit from the semaphore + /// Acquires permit from the semaphore. pub async fn acquire(&self) -> SemaphorePermit<'_> { self.ll_sem.acquire(1).cooperate().await.unwrap(); SemaphorePermit { @@ -77,7 +94,7 @@ impl Semaphore { } } - /// Tries to acquire a permit form the semaphore + /// Tries to acquire a permit from the semaphore. pub fn try_acquire(&self) -> Result, TryAcquireError> { match self.ll_sem.try_acquire(1) { Ok(_) => Ok(SemaphorePermit { @@ -87,6 +104,34 @@ impl Semaphore { Err(_) => Err(TryAcquireError(())), } } + + /// Acquires permit from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. + /// + /// [`Arc`]: std::sync::Arc + pub async fn acquire_owned(self: Arc) -> OwnedSemaphorePermit { + self.ll_sem.acquire(1).cooperate().await.unwrap(); + OwnedSemaphorePermit { + sem: self.clone(), + permits: 1, + } + } + + /// Tries to acquire a permit from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. + /// + /// [`Arc`]: std::sync::Arc + pub fn try_acquire_owned(self: Arc) -> Result { + match self.ll_sem.try_acquire(1) { + Ok(_) => Ok(OwnedSemaphorePermit { + sem: self.clone(), + permits: 1, + }), + Err(_) => Err(TryAcquireError(())), + } + } } impl<'a> SemaphorePermit<'a> { @@ -98,8 +143,23 @@ impl<'a> SemaphorePermit<'a> { } } +impl OwnedSemaphorePermit { + /// Forgets the permit **without** releasing it back to the semaphore. + /// This can be used to reduce the amount of permits available from a + /// semaphore. + pub fn forget(mut self) { + self.permits = 0; + } +} + impl<'a> Drop for SemaphorePermit<'_> { fn drop(&mut self) { self.sem.add_permits(self.permits as usize); } } + +impl Drop for OwnedSemaphorePermit { + fn drop(&mut self) { + self.sem.add_permits(self.permits as usize); + } +} diff --git a/tokio/tests/sync_semaphore_owned.rs b/tokio/tests/sync_semaphore_owned.rs new file mode 100644 index 00000000000..001867734fe --- /dev/null +++ b/tokio/tests/sync_semaphore_owned.rs @@ -0,0 +1,75 @@ +#![cfg(feature = "full")] + +use std::sync::Arc; +use tokio::sync::Semaphore; + +#[test] +fn try_acquire() { + let sem = Arc::new(Semaphore::new(1)); + { + let p1 = sem.try_acquire_owned(); + assert!(p1.is_ok()); + let p2 = sem.try_acquire_owned(); + assert!(p2.is_err()); + } + let p3 = sem.try_acquire_owned(); + assert!(p3.is_ok()); +} + +#[tokio::test] +async fn acquire() { + let sem = Arc::new(Semaphore::new(1)); + let p1 = sem.try_acquire_owned().unwrap(); + let sem_clone = sem.clone(); + let j = tokio::spawn(async move { + let _p2 = sem_clone.acquire_owned().await; + }); + drop(p1); + j.await.unwrap(); +} + +#[tokio::test] +async fn add_permits() { + let sem = Arc::new(Semaphore::new(0)); + let sem_clone = sem.clone(); + let j = tokio::spawn(async move { + let _p2 = sem_clone.acquire_owned().await; + }); + sem.add_permits(1); + j.await.unwrap(); +} + +#[test] +fn forget() { + let sem = Arc::new(Semaphore::new(1)); + { + let p = sem.try_acquire_owned().unwrap(); + assert_eq!(sem.available_permits(), 0); + p.forget(); + assert_eq!(sem.available_permits(), 0); + } + assert_eq!(sem.available_permits(), 0); + assert!(sem.try_acquire_owned().is_err()); +} + +#[tokio::test] +async fn stresstest() { + let sem = Arc::new(Semaphore::new(5)); + let mut join_handles = Vec::new(); + for _ in 0..1000 { + let sem_clone = sem.clone(); + join_handles.push(tokio::spawn(async move { + let _p = sem_clone.acquire_owned().await; + })); + } + for j in join_handles { + j.await.unwrap(); + } + // there should be exactly 5 semaphores available now + let _p1 = sem.try_acquire_owned().unwrap(); + let _p2 = sem.try_acquire_owned().unwrap(); + let _p3 = sem.try_acquire_owned().unwrap(); + let _p4 = sem.try_acquire_owned().unwrap(); + let _p5 = sem.try_acquire_owned().unwrap(); + assert!(sem.try_acquire_owned().is_err()); +} From 68e368910d1dc310f2482baa49919b49e63b9546 Mon Sep 17 00:00:00 2001 From: Alice Ryhl Date: Mon, 20 Apr 2020 20:52:23 +0200 Subject: [PATCH 2/2] fmt --- tokio/tests/sync_semaphore_owned.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tokio/tests/sync_semaphore_owned.rs b/tokio/tests/sync_semaphore_owned.rs index 001867734fe..8ed6209f3b9 100644 --- a/tokio/tests/sync_semaphore_owned.rs +++ b/tokio/tests/sync_semaphore_owned.rs @@ -7,9 +7,9 @@ use tokio::sync::Semaphore; fn try_acquire() { let sem = Arc::new(Semaphore::new(1)); { - let p1 = sem.try_acquire_owned(); + let p1 = sem.clone().try_acquire_owned(); assert!(p1.is_ok()); - let p2 = sem.try_acquire_owned(); + let p2 = sem.clone().try_acquire_owned(); assert!(p2.is_err()); } let p3 = sem.try_acquire_owned(); @@ -19,7 +19,7 @@ fn try_acquire() { #[tokio::test] async fn acquire() { let sem = Arc::new(Semaphore::new(1)); - let p1 = sem.try_acquire_owned().unwrap(); + let p1 = sem.clone().try_acquire_owned().unwrap(); let sem_clone = sem.clone(); let j = tokio::spawn(async move { let _p2 = sem_clone.acquire_owned().await; @@ -43,7 +43,7 @@ async fn add_permits() { fn forget() { let sem = Arc::new(Semaphore::new(1)); { - let p = sem.try_acquire_owned().unwrap(); + let p = sem.clone().try_acquire_owned().unwrap(); assert_eq!(sem.available_permits(), 0); p.forget(); assert_eq!(sem.available_permits(), 0); @@ -66,10 +66,10 @@ async fn stresstest() { j.await.unwrap(); } // there should be exactly 5 semaphores available now - let _p1 = sem.try_acquire_owned().unwrap(); - let _p2 = sem.try_acquire_owned().unwrap(); - let _p3 = sem.try_acquire_owned().unwrap(); - let _p4 = sem.try_acquire_owned().unwrap(); - let _p5 = sem.try_acquire_owned().unwrap(); + let _p1 = sem.clone().try_acquire_owned().unwrap(); + let _p2 = sem.clone().try_acquire_owned().unwrap(); + let _p3 = sem.clone().try_acquire_owned().unwrap(); + let _p4 = sem.clone().try_acquire_owned().unwrap(); + let _p5 = sem.clone().try_acquire_owned().unwrap(); assert!(sem.try_acquire_owned().is_err()); }