diff --git a/src/semaphore.rs b/src/semaphore.rs index 5a328b9..d036305 100644 --- a/src/semaphore.rs +++ b/src/semaphore.rs @@ -89,9 +89,7 @@ impl Semaphore { listener: None, } } -} -impl Semaphore { /// Attempts to get an owned permit for a concurrent operation. /// /// If the permit could not be acquired at this time, then [`None`] is returned. Otherwise, an @@ -152,6 +150,30 @@ impl Semaphore { listener: None, } } + + /// Adds `n` additional permits to the semaphore. + /// + /// # Examples + /// + /// ``` + /// use async_lock::Semaphore; + /// + /// # futures_lite::future::block_on(async { + /// let s = Semaphore::new(1); + /// + /// let _guard = s.acquire().await; + /// assert!(s.try_acquire().is_none()); + /// + /// s.add_permits(2); + /// + /// let _guard = s.acquire().await; + /// let _guard = s.acquire().await; + /// # }); + /// ``` + pub fn add_permits(&self, n: usize) { + self.count.fetch_add(n, Ordering::AcqRel); + self.event.notify(n); + } } /// The future returned by [`Semaphore::acquire`]. diff --git a/tests/semaphore.rs b/tests/semaphore.rs index 89f8d5a..0779b68 100644 --- a/tests/semaphore.rs +++ b/tests/semaphore.rs @@ -1,12 +1,20 @@ mod common; -use std::sync::{mpsc, Arc}; +use std::future::Future; +use std::mem::forget; +use std::pin::Pin; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + mpsc, Arc, +}; +use std::task::Context; +use std::task::Poll; use std::thread; use common::check_yields_when_contended; use async_lock::Semaphore; -use futures_lite::future; +use futures_lite::{future, pin}; #[test] fn try_acquire() { @@ -105,3 +113,60 @@ fn yields_when_contended() { let s = Arc::new(s); check_yields_when_contended(s.try_acquire_arc().unwrap(), s.acquire_arc()); } + +#[test] +fn add_permits() { + static COUNTER: AtomicUsize = AtomicUsize::new(0); + + let s = Arc::new(Semaphore::new(0)); + let (tx, rx) = mpsc::channel::<()>(); + + for _ in 0..50 { + let s = s.clone(); + let tx = tx.clone(); + + thread::spawn(move || { + future::block_on(async { + let perm = s.acquire().await; + forget(perm); + COUNTER.fetch_add(1, Ordering::Relaxed); + drop(tx); + }) + }); + } + + assert_eq!(COUNTER.load(Ordering::Relaxed), 0); + + s.add_permits(50); + + drop(tx); + let _ = rx.recv(); + + assert_eq!(COUNTER.load(Ordering::Relaxed), 50); +} + +#[test] +fn add_permits_2() { + future::block_on(AddPermitsTest); +} + +struct AddPermitsTest; + +impl Future for AddPermitsTest { + type Output = (); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let s = Semaphore::new(0); + let acq = s.acquire(); + pin!(acq); + let acq_2 = s.acquire(); + pin!(acq_2); + assert!(acq.as_mut().poll(cx).is_pending()); + assert!(acq_2.as_mut().poll(cx).is_pending()); + s.add_permits(1); + let g = acq.poll(cx); + assert!(g.is_ready()); + assert!(acq_2.poll(cx).is_pending()); + + Poll::Ready(()) + } +}