Skip to content

Commit

Permalink
task: add join_all method to JoinSet (#6784)
Browse files Browse the repository at this point in the history
Adds join_all method to JoinSet. join_all consumes JoinSet and awaits
the completion of all tasks on it, returning the results of the tasks in
a vec. An error or panic in the task will cause join_all to panic,
canceling all other tasks.

Fixes: #6664
  • Loading branch information
hmaka authored Aug 26, 2024
1 parent 1ac8dff commit cc70a21
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 1 deletion.
75 changes: 74 additions & 1 deletion tokio/src/task/join_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
//! of spawned tasks and allows asynchronously awaiting the output of those
//! tasks as they complete. See the documentation for the [`JoinSet`] type for
//! details.
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, panic};

use crate::runtime::Handle;
#[cfg(tokio_unstable)]
Expand Down Expand Up @@ -374,6 +374,79 @@ impl<T: 'static> JoinSet<T> {
while self.join_next().await.is_some() {}
}

/// Awaits the completion of all tasks in this `JoinSet`, returning a vector of their results.
///
/// The results will be stored in the order they completed not the order they were spawned.
/// This is a convenience method that is equivalent to calling [`join_next`] in
/// a loop. If any tasks on the `JoinSet` fail with an [`JoinError`], then this call
/// to `join_all` will panic and all remaining tasks on the `JoinSet` are
/// cancelled. To handle errors in any other way, manually call [`join_next`]
/// in a loop.
///
/// # Examples
///
/// Spawn multiple tasks and `join_all` them.
///
/// ```
/// use tokio::task::JoinSet;
/// use std::time::Duration;
///
/// #[tokio::main]
/// async fn main() {
/// let mut set = JoinSet::new();
///
/// for i in 0..3 {
/// set.spawn(async move {
/// tokio::time::sleep(Duration::from_secs(3 - i)).await;
/// i
/// });
/// }
///
/// let output = set.join_all().await;
/// assert_eq!(output, vec![2, 1, 0]);
/// }
/// ```
///
/// Equivalent implementation of `join_all`, using [`join_next`] and loop.
///
/// ```
/// use tokio::task::JoinSet;
/// use std::panic;
///
/// #[tokio::main]
/// async fn main() {
/// let mut set = JoinSet::new();
///
/// for i in 0..3 {
/// set.spawn(async move {i});
/// }
///
/// let mut output = Vec::new();
/// while let Some(res) = set.join_next().await{
/// match res {
/// Ok(t) => output.push(t),
/// Err(err) if err.is_panic() => panic::resume_unwind(err.into_panic()),
/// Err(err) => panic!("{err}"),
/// }
/// }
/// assert_eq!(output.len(),3);
/// }
/// ```
/// [`join_next`]: fn@Self::join_next
/// [`JoinError::id`]: fn@crate::task::JoinError::id
pub async fn join_all(mut self) -> Vec<T> {
let mut output = Vec::with_capacity(self.len());

while let Some(res) = self.join_next().await {
match res {
Ok(t) => output.push(t),
Err(err) if err.is_panic() => panic::resume_unwind(err.into_panic()),
Err(err) => panic!("{err}"),
}
}
output
}

/// Aborts all tasks on this `JoinSet`.
///
/// This does not remove the tasks from the `JoinSet`. To wait for the tasks to complete
Expand Down
40 changes: 40 additions & 0 deletions tokio/tests/task_join_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,46 @@ fn runtime_gone() {
.is_cancelled());
}

#[tokio::test]
async fn join_all() {
let mut set: JoinSet<i32> = JoinSet::new();

for _ in 0..5 {
set.spawn(async { 1 });
}
let res: Vec<i32> = set.join_all().await;

assert_eq!(res.len(), 5);
for itm in res.into_iter() {
assert_eq!(itm, 1)
}
}

#[cfg(panic = "unwind")]
#[tokio::test(start_paused = true)]
async fn task_panics() {
let mut set: JoinSet<()> = JoinSet::new();

let (tx, mut rx) = oneshot::channel();
assert_eq!(set.len(), 0);

set.spawn(async move {
tokio::time::sleep(Duration::from_secs(2)).await;
tx.send(()).unwrap();
});
assert_eq!(set.len(), 1);

set.spawn(async {
tokio::time::sleep(Duration::from_secs(1)).await;
panic!();
});
assert_eq!(set.len(), 2);

let panic = tokio::spawn(set.join_all()).await.unwrap_err();
assert!(rx.try_recv().is_err());
assert!(panic.is_panic());
}

#[tokio::test(start_paused = true)]
async fn abort_all() {
let mut set: JoinSet<()> = JoinSet::new();
Expand Down

0 comments on commit cc70a21

Please sign in to comment.