diff --git a/tokio-util/src/task/join_map.rs b/tokio-util/src/task/join_map.rs index c9aed537d4b..1fbe274a2f8 100644 --- a/tokio-util/src/task/join_map.rs +++ b/tokio-util/src/task/join_map.rs @@ -5,6 +5,7 @@ use std::collections::hash_map::RandomState; use std::fmt; use std::future::Future; use std::hash::{BuildHasher, Hash, Hasher}; +use std::marker::PhantomData; use tokio::runtime::Handle; use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet}; @@ -626,6 +627,19 @@ where } } + /// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order. + /// + /// If a task has completed, but its output hasn't yet been consumed by a + /// call to [`join_next`], this method will still return its key. + /// + /// [`join_next`]: fn@Self::join_next + pub fn keys(&self) -> JoinMapKeys<'_, K, V> { + JoinMapKeys { + iter: self.tasks_by_key.keys(), + _value: PhantomData, + } + } + /// Returns `true` if this `JoinMap` contains a task for the provided key. /// /// If the task has completed, but its output hasn't yet been consumed by a @@ -859,3 +873,32 @@ impl PartialEq for Key { } impl Eq for Key {} + +/// An iterator over the keys of a [`JoinMap`]. +#[derive(Debug, Clone)] +pub struct JoinMapKeys<'a, K, V> { + iter: hashbrown::hash_map::Keys<'a, Key, AbortHandle>, + /// To make it easier to change JoinMap in the future, keep V as a generic + /// parameter. + _value: PhantomData<&'a V>, +} + +impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> { + type Item = &'a K; + + fn next(&mut self) -> Option<&'a K> { + self.iter.next().map(|key| &key.key) + } + + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> { + fn len(&self) -> usize { + self.iter.len() + } +} + +impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {} diff --git a/tokio-util/src/task/mod.rs b/tokio-util/src/task/mod.rs index de41dd5dbe8..a5f94a898e2 100644 --- a/tokio-util/src/task/mod.rs +++ b/tokio-util/src/task/mod.rs @@ -9,4 +9,4 @@ pub use spawn_pinned::LocalPoolHandle; #[cfg(tokio_unstable)] #[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "rt"))))] -pub use join_map::JoinMap; +pub use join_map::{JoinMap, JoinMapKeys}; diff --git a/tokio-util/tests/task_join_map.rs b/tokio-util/tests/task_join_map.rs index cef08b20252..1ab5f9ba832 100644 --- a/tokio-util/tests/task_join_map.rs +++ b/tokio-util/tests/task_join_map.rs @@ -109,6 +109,30 @@ async fn alternating() { } } +#[tokio::test] +async fn test_keys() { + use std::collections::HashSet; + + let mut map = JoinMap::new(); + + assert_eq!(map.len(), 0); + map.spawn(1, async {}); + assert_eq!(map.len(), 1); + map.spawn(2, async {}); + assert_eq!(map.len(), 2); + + let keys = map.keys().collect::>(); + assert!(keys.contains(&1)); + assert!(keys.contains(&2)); + + let _ = map.join_next().await.unwrap(); + let _ = map.join_next().await.unwrap(); + + assert_eq!(map.len(), 0); + let keys = map.keys().collect::>(); + assert!(keys.is_empty()); +} + #[tokio::test(start_paused = true)] async fn abort_by_key() { let mut map = JoinMap::new();