Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JoinSet #4335

Merged
merged 21 commits into from
Feb 1, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add detach_all and abort_all
Darksonn committed Jan 24, 2022
commit 8bbe89c6f1ad64d8e53633ff8c9f419fb2c56a17
46 changes: 46 additions & 0 deletions tokio/src/runtime/tests/loom_task_set.rs
Original file line number Diff line number Diff line change
@@ -32,3 +32,49 @@ fn test_task_set() {
drop(rt);
});
}

#[test]
fn abort_all_during_completion() {
use std::sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
};

// These booleans assert that at least one execution had the task complete first, and that at
// least one execution had the task be cancelled before it completed.
let complete_happened = Arc::new(AtomicBool::new(false));
let cancel_happened = Arc::new(AtomicBool::new(false));

{
let complete_happened = complete_happened.clone();
let cancel_happened = cancel_happened.clone();
loom::model(move || {
let rt = Builder::new_multi_thread()
.worker_threads(1)
.build()
.unwrap();

let mut set = TaskSet::new();

rt.block_on(async {
set.spawn(async { () });
set.abort_all();

match set.join_one().await {
Ok(Some(())) => complete_happened.store(true, SeqCst),
Err(err) if err.is_cancelled() => cancel_happened.store(true, SeqCst),
Err(err) => panic!("fail: {}", err),
Ok(None) => unreachable!(),
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
}

assert!(matches!(set.join_one().await, Ok(None)));
});

drop(set);
drop(rt);
});
}

assert!(complete_happened.load(SeqCst));
assert!(cancel_happened.load(SeqCst));
}
19 changes: 18 additions & 1 deletion tokio/src/task/task_set.rs
Original file line number Diff line number Diff line change
@@ -7,7 +7,8 @@ use crate::runtime::Handle;
use crate::task::{JoinError, JoinHandle, LocalSet};
use crate::util::IdleNotifiedSet;

/// A collection of tasks spawned on a Tokio runtime.
/// A collection of tasks spawned on a Tokio runtime. A `TaskSet` is not ordered, and the tasks
/// will be returned in the order they completed.
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
///
/// All of the tasks must have the same return type `T`.
///
@@ -132,6 +133,22 @@ impl<T: 'static> TaskSet<T> {
crate::future::poll_fn(|cx| self.poll_join_one(cx)).await
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems like it would also be nice to have APIs to await the completion of all the tasks in the task set. It seems like we might want a method that returns a future that just awaits completion, and maybe also some way to collect all of the future return values. We can think more about that in a future PR, though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that the details are unclear wrt. returning a Vec<Result<T, JoinError>> because it is not #[must_use] and I'm somewhat tired of seeing people ignoring errors when using join_all.

/// Abort all tasks on this `TaskSet`.
///
/// This does not remove the tasks from the `TaskSet`. To wait for the tasks to complete
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
/// cancellation, you should call `join_one` in a loop until the `TaskSet` is empty.
pub fn abort_all(&mut self) {
self.inner.for_each(|jh| jh.abort());
}

/// Remove all tasks from this `TaskSet` without aborting them.
///
/// The tasks removed by this call will continue to run in the background even if the `TaskSet`
/// is dropped.
pub fn detach_all(&mut self) {
self.inner.drain(drop);
}

/// Poll for one of the tasks in the set to complete.
///
/// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the
43 changes: 43 additions & 0 deletions tokio/src/util/idle_notified_set.rs
Original file line number Diff line number Diff line change
@@ -199,6 +199,49 @@ impl<T> IdleNotifiedSet<T> {
}
}

/// Call a function on every element in this list.
pub(crate) fn for_each<F: FnMut(&mut T)>(&mut self, mut func: F) {
fn get_ptrs<T>(list: &mut LinkedList<T>, ptrs: &mut Vec<*mut T>) {
let mut node = list.last();

while let Some(entry) = node {
ptrs.push(entry.value.with_mut(|ptr| {
let ptr: *mut ManuallyDrop<T> = ptr;
hawkw marked this conversation as resolved.
Show resolved Hide resolved
let ptr: *mut T = ptr.cast();
ptr
}));

let prev = entry.pointers.get_prev();
node = prev.map(|prev| unsafe { prev.as_ref() });
}
}

// Atomically get a raw pointer to the value of every entry.
//
// Since this only locks the mutex once, its not possible for a value to
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
// get moved from the idle list to the notified list during the
// operation, which would otherwise result in some value being listed
// twice.
let mut ptrs = Vec::with_capacity(self.len());
{
let mut lock = self.lists.inner.lock();

get_ptrs(&mut lock.idle, &mut ptrs);
get_ptrs(&mut lock.notified, &mut ptrs);
}
debug_assert_eq!(ptrs.len(), ptrs.capacity());

for ptr in ptrs {
// Safety: When we grabbed the pointers, the entries were in one of
// the two lists. This means that their value was valid at the time,
// and it must still be valid because we are the IdleNotifiedSet,
// and only we can remove an entry from the two lists. (It's
// possible that an entry is moved from one list to the other during
// this loop, but that is ok.)
func(unsafe { &mut *ptr });
}
}

/// Remove all entries in both lists, applying some function to each element.
///
/// The closure is called on all elements even if it panics. Having it panic
5 changes: 3 additions & 2 deletions tokio/src/util/linked_list.rs
Original file line number Diff line number Diff line change
@@ -219,6 +219,7 @@ impl<L: Link> fmt::Debug for LinkedList<L, L::Target> {

#[cfg(any(
feature = "fs",
feature = "rt",
all(unix, feature = "process"),
feature = "signal",
feature = "sync",
@@ -296,15 +297,15 @@ impl<T> Pointers<T> {
}
}

fn get_prev(&self) -> Option<NonNull<T>> {
pub(crate) fn get_prev(&self) -> Option<NonNull<T>> {
// SAFETY: prev is the first field in PointersInner, which is #[repr(C)].
unsafe {
let inner = self.inner.get();
let prev = inner as *const Option<NonNull<T>>;
ptr::read(prev)
}
}
fn get_next(&self) -> Option<NonNull<T>> {
pub(crate) fn get_next(&self) -> Option<NonNull<T>> {
// SAFETY: next is the second field in PointersInner, which is #[repr(C)].
unsafe {
let inner = self.inner.get();
37 changes: 37 additions & 0 deletions tokio/tests/task_set.rs
Original file line number Diff line number Diff line change
@@ -17,6 +17,13 @@ fn rt() -> tokio::runtime::Runtime {
async fn test_with_sleep() {
let mut set = TaskSet::new();

for i in 0..10 {
set.spawn(async move { i });
assert_eq!(set.len(), 1 + i);
}
set.detach_all();
assert_eq!(set.len(), 0);

assert!(matches!(set.join_one().await, Ok(None)));

for i in 0..10 {
@@ -153,3 +160,33 @@ async fn task_set_coop() {
assert!(coop_count >= 1);
assert_eq!(count, TASK_NUM);
}

#[tokio::test(start_paused = true)]
async fn abort_all() {
let mut set: TaskSet<()> = TaskSet::new();

for _ in 0..5 {
set.spawn(futures::future::pending());
}
for _ in 0..5 {
set.spawn(async {
tokio::time::sleep(Duration::from_secs(1)).await;
});
}

// The task set will now have 5 pending tasks and 5 ready tasks.
tokio::time::sleep(Duration::from_secs(2)).await;

set.abort_all();
assert_eq!(set.len(), 10);

let mut count = 0;
while let Some(res) = set.join_one().await.transpose() {
if let Err(err) = res {
assert!(err.is_cancelled());
}
count += 1;
}
assert_eq!(count, 10);
assert_eq!(set.len(), 0);
}