From db1d90453c04740176c354002b56da6f8cb30f2c Mon Sep 17 00:00:00 2001 From: Alice Ryhl Date: Tue, 9 Mar 2021 19:22:36 +0100 Subject: [PATCH] util: fuse PollSemaphore (#3578) --- tokio-util/src/sync/poll_semaphore.rs | 13 +++++----- tokio-util/tests/poll_semaphore.rs | 36 +++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 6 deletions(-) create mode 100644 tokio-util/tests/poll_semaphore.rs diff --git a/tokio-util/src/sync/poll_semaphore.rs b/tokio-util/src/sync/poll_semaphore.rs index d4594d03e86..6b22b0d633f 100644 --- a/tokio-util/src/sync/poll_semaphore.rs +++ b/tokio-util/src/sync/poll_semaphore.rs @@ -55,12 +55,13 @@ impl PollSemaphore { /// the `Waker` from the `Context` passed to the most recent call is /// scheduled to receive a wakeup. pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll> { - match ready!(self.permit_fut.poll(cx)) { - Ok(permit) => { - let next_fut = Arc::clone(&self.semaphore).acquire_owned(); - self.permit_fut.set(next_fut); - Poll::Ready(Some(permit)) - } + let result = ready!(self.permit_fut.poll(cx)); + + let next_fut = Arc::clone(&self.semaphore).acquire_owned(); + self.permit_fut.set(next_fut); + + match result { + Ok(permit) => Poll::Ready(Some(permit)), Err(_closed) => Poll::Ready(None), } } diff --git a/tokio-util/tests/poll_semaphore.rs b/tokio-util/tests/poll_semaphore.rs new file mode 100644 index 00000000000..0fdb3a446f7 --- /dev/null +++ b/tokio-util/tests/poll_semaphore.rs @@ -0,0 +1,36 @@ +use std::future::Future; +use std::sync::Arc; +use std::task::Poll; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tokio_util::sync::PollSemaphore; + +type SemRet = Option; + +fn semaphore_poll<'a>( + sem: &'a mut PollSemaphore, +) -> tokio_test::task::Spawn + 'a> { + let fut = futures::future::poll_fn(move |cx| sem.poll_acquire(cx)); + tokio_test::task::spawn(fut) +} + +#[tokio::test] +async fn it_works() { + let sem = Arc::new(Semaphore::new(1)); + let mut poll_sem = PollSemaphore::new(sem.clone()); + + let permit = sem.acquire().await.unwrap(); + let mut poll = semaphore_poll(&mut poll_sem); + assert!(poll.poll().is_pending()); + drop(permit); + + assert!(matches!(poll.poll(), Poll::Ready(Some(_)))); + drop(poll); + + sem.close(); + + assert!(semaphore_poll(&mut poll_sem).await.is_none()); + + // Check that it is fused. + assert!(semaphore_poll(&mut poll_sem).await.is_none()); + assert!(semaphore_poll(&mut poll_sem).await.is_none()); +}