diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index e596375d630..f1903a5aad5 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -41,10 +41,12 @@ full = [ "rt-core", "rt-util", "rt-threaded", + "scope", "signal", "stream", "sync", "time", + "futures" ] blocking = ["rt-core"] @@ -73,6 +75,7 @@ rt-threaded = [ "num_cpus", "rt-core", ] +scope = [] signal = [ "io-driver", "lazy_static", @@ -99,6 +102,7 @@ pin-project-lite = "0.1.1" # Everything else is optional... fnv = { version = "1.0.6", optional = true } futures-core = { version = "0.3.0", optional = true } +futures = { version = "0.3.0", optional = true } lazy_static = { version = "1.0.2", optional = true } memchr = { version = "2.2", optional = true } mio = { version = "0.6.20", optional = true } @@ -122,7 +126,7 @@ optional = true [dev-dependencies] tokio-test = { version = "0.2.0", path = "../tokio-test" } -futures = { version = "0.3.0", features = ["async-await"] } +futures = { version = "0.3.0", features = ["async-await", "executor"] } futures-test = "0.3.0" proptest = "0.9.4" tempfile = "3.1.0" diff --git a/tokio/src/macros/cfg.rs b/tokio/src/macros/cfg.rs index 288f58d2f40..2cc2c4f7f9f 100644 --- a/tokio/src/macros/cfg.rs +++ b/tokio/src/macros/cfg.rs @@ -204,6 +204,16 @@ macro_rules! cfg_process { } } +macro_rules! cfg_scope { + ($($item:item)*) => { + $( + #[cfg(feature = "scope")] + #[cfg_attr(docsrs, doc(cfg(feature = "scope")))] + $item + )* + } +} + macro_rules! cfg_signal { ($($item:item)*) => { $( diff --git a/tokio/src/sync/mod.rs b/tokio/src/sync/mod.rs index 3d96106d2df..9509a585415 100644 --- a/tokio/src/sync/mod.rs +++ b/tokio/src/sync/mod.rs @@ -460,6 +460,11 @@ cfg_sync! { mod task; pub(crate) use task::AtomicWaker; + cfg_unstable! { + mod wait_group; + pub(crate) use wait_group::{SharedWaitGroup}; + } + pub mod watch; } diff --git a/tokio/src/sync/wait_group.rs b/tokio/src/sync/wait_group.rs new file mode 100644 index 00000000000..8ed5f46a529 --- /dev/null +++ b/tokio/src/sync/wait_group.rs @@ -0,0 +1,338 @@ +//! An asynchronously awaitable WaitGroup which allows to wait for running tasks +//! to complete. + +use crate::{ + loom::sync::{Arc, Mutex}, + util::linked_list::{self, LinkedList}, +}; +use std::{ + cell::UnsafeCell, + future::Future, + marker::PhantomPinned, + pin::Pin, + ptr::NonNull, + task::{Context, Poll, Waker}, +}; + +/// A synchronization primitive which allows to wait until all tracked tasks +/// have finished. +/// +/// Tasks can wait for tracked tasks to finish by obtaining a Future via `wait`. +/// This Future will get fulfilled when no tasks are running anymore. +pub(crate) struct WaitGroup { + inner: Mutex, +} + +// The Group can be sent to other threads as long as it's not borrowed +unsafe impl Send for WaitGroup {} +// The Group is thread-safe as long as the utilized Mutex is thread-safe +unsafe impl Sync for WaitGroup {} + +impl core::fmt::Debug for WaitGroup { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WaitGroup").finish() + } +} + +impl WaitGroup { + /// Creates a new WaitGroup + pub(crate) fn new() -> WaitGroup { + WaitGroup { + inner: Mutex::new(GroupState::new(0)), + } + } + + /// Adds a pending task to the WaitGroup + pub(crate) fn add(&self) { + self.inner.lock().unwrap().add() + } + + /// Removes a task that has finished from the WaitGroup + pub(crate) fn remove(&self) { + self.inner.lock().unwrap().remove() + } + + /// Returns a future that gets fulfilled when all tracked tasks complete + pub(crate) fn wait(&self) -> WaitGroupFuture<'_> { + WaitGroupFuture { + group: Some(self), + waiter: UnsafeCell::new(Waiter::new()), + } + } + + unsafe fn try_wait(&self, waiter: &mut UnsafeCell, cx: &mut Context<'_>) -> Poll<()> { + let mut guard = self.inner.lock().unwrap(); + // Safety: The wait node is only accessed inside the Mutex + let waiter = &mut *waiter.get(); + guard.try_wait(waiter, cx) + } + + fn remove_waiter(&self, waiter: &mut UnsafeCell) { + let mut guard = self.inner.lock().unwrap(); + // Safety: The wait node is only accessed inside the Mutex + let waiter = unsafe { &mut *waiter.get() }; + guard.remove_waiter(waiter) + } +} + +/// A Future that is resolved once the corresponding WaitGroup has reached +/// 0 active tasks. +#[must_use = "futures do nothing unless polled"] +pub(crate) struct WaitGroupFuture<'a> { + /// The WaitGroup that is associated with this WaitGroupFuture + group: Option<&'a WaitGroup>, + /// Node for waiting at the group + waiter: UnsafeCell, +} + +// Safety: Futures can be sent between threads, since the underlying +// group is thread-safe (Sync), which allows to poll/register/unregister from +// a different thread. +unsafe impl<'a> Send for WaitGroupFuture<'a> {} + +impl<'a> core::fmt::Debug for WaitGroupFuture<'a> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WaitGroupFuture").finish() + } +} + +impl Future for WaitGroupFuture<'_> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + // It might be possible to use Pin::map_unchecked here instead of the two unsafe APIs. + // However this didn't seem to work for some borrow checker reasons + + // Safety: The next operations are safe, because Pin promises us that + // the address of the wait queue entry inside WaitGroupFuture is stable, + // and we don't move any fields inside the future until it gets dropped. + let mut_self: &mut WaitGroupFuture<'_> = unsafe { Pin::get_unchecked_mut(self) }; + + let group = mut_self + .group + .expect("polled WaitGroupFuture after completion"); + + let poll_res = unsafe { group.try_wait(&mut mut_self.waiter, cx) }; + + if let Poll::Ready(()) = poll_res { + mut_self.group = None; + } + + poll_res + } +} + +impl<'a> Drop for WaitGroupFuture<'a> { + fn drop(&mut self) { + // If this WaitGroupFuture has been polled and it was added to the + // wait queue at the group, it must be removed before dropping. + // Otherwise the group would access invalid memory. + if let Some(ev) = self.group { + ev.remove_waiter(&mut self.waiter); + } + } +} + +/// A cloneable [`WaitGroup`] +/// +/// When tasks are added to this [`WaitGroup`] a [`WaitGroupReleaser`] will be +/// returned, which will automatically decrement the count of active tasks in +/// the [`SharedWaitGroup`] when dropped. +#[derive(Clone)] +pub(crate) struct SharedWaitGroup { + inner: Arc, +} + +impl SharedWaitGroup { + /// Creates a new [`SharedWaitGroup`] + pub(crate) fn new() -> Self { + Self { + inner: Arc::new(WaitGroup::new()), + } + } + + /// Registers a task at the [`SharedWaitGroup`] + /// + /// The method returns a [`WaitGroupReleaser`] which is intended to be dropped + /// once the task completes. + #[must_use] + pub(crate) fn add(&self) -> WaitGroupReleaser { + self.inner.add(); + WaitGroupReleaser { + inner: self.inner.clone(), + } + } + + /// Returns a [`Future`] which will complete once all tasks which have been + /// previously added have dropped their [`WaitGroupReleaser`] and are thereby + /// deemed as finished. + pub(crate) fn wait_future(&self) -> WaitGroupFuture<'_> { + self.inner.wait() + } +} + +/// A handle which tracks an active task which is monitored by the [`SharedWaitGroup`]. +/// When this object is dropped, the task will be automatically be marked as +/// completed inside the [`SharedWaitGroup`]. +pub(crate) struct WaitGroupReleaser { + inner: Arc, +} + +impl Drop for WaitGroupReleaser { + fn drop(&mut self) { + self.inner.remove(); + } +} + +/// Tracks how the future had interacted with the group +#[derive(PartialEq)] +enum PollState { + /// The task has never interacted with the group. + New, + /// The task was added to the wait queue at the group. + Waiting, + /// The task has been polled to completion. + Done, +} + +/// A `Waiter` allows a task to wait o the `WaitGroup`. A `Waiter` is a node +/// in a linked list which is managed through the `WaitGroup`. +/// Access to this struct is synchronized through the mutex in the WaitGroup. +struct Waiter { + /// Intrusive linked-list pointers + pointers: linked_list::Pointers, + /// The task handle of the waiting task + waker: Option, + /// Current polling state + state: PollState, + /// Should not be `Unpin`. + _p: PhantomPinned, +} + +impl Waiter { + /// Creates a new Waiter + fn new() -> Waiter { + Waiter { + pointers: linked_list::Pointers::new(), + waker: None, + state: PollState::New, + _p: PhantomPinned, + } + } +} + +/// # Safety +/// +/// `Waiter` is forced to be !Unpin. +unsafe impl linked_list::Link for Waiter { + type Handle = NonNull; + type Target = Waiter; + + fn as_raw(handle: &NonNull) -> NonNull { + *handle + } + + unsafe fn from_raw(ptr: NonNull) -> NonNull { + ptr + } + + unsafe fn pointers(mut target: NonNull) -> NonNull> { + NonNull::from(&mut target.as_mut().pointers) + } +} + +/// Internal state of the `WaitGroup` +struct GroupState { + count: usize, + waiters: LinkedList, +} + +impl GroupState { + fn new(count: usize) -> GroupState { + GroupState { + count, + waiters: LinkedList::new(), + } + } + + fn add(&mut self) { + self.count += 1; + } + + fn remove(&mut self) { + if self.count == 0 { + return; + } + self.count -= 1; + if self.count != 0 { + return; + } + + // Wakeup all waiters + while let Some(mut waiter) = self.waiters.pop_back() { + // Safety: waiters lock is held + let waiter = unsafe { waiter.as_mut() }; + if let Some(handle) = (*waiter).waker.take() { + handle.wake(); + } + (*waiter).state = PollState::Done; + } + } + + /// Checks how many tasks are running. If none are running, this returns + /// `Poll::Ready` immediately. + /// If tasks are running, the WaitGroupFuture gets added to the wait + /// queue at the group, and will be signalled once the tasks completed. + /// This function is only safe as long as the `waiter`s address is guaranteed + /// to be stable until it gets removed from the queue. + unsafe fn try_wait(&mut self, waiter: &mut Waiter, cx: &mut Context<'_>) -> Poll<()> { + match waiter.state { + PollState::New => { + if self.count == 0 { + // The group is already signaled + waiter.state = PollState::Done; + Poll::Ready(()) + } else { + // Added the task to the wait queue + waiter.waker = Some(cx.waker().clone()); + waiter.state = PollState::Waiting; + self.waiters.push_front(waiter.into()); + Poll::Pending + } + } + PollState::Waiting => { + // The WaitGroupFuture is already in the queue. + // The group can't have reached 0 tasks, since this would change the + // waitstate inside the mutex. However the caller might have + // passed a different `Waker`. In this case we need to update it. + if waiter + .waker + .as_ref() + .map_or(true, |stored_waker| !stored_waker.will_wake(cx.waker())) + { + waiter.waker = Some(cx.waker().clone()); + } + + Poll::Pending + } + PollState::Done => { + // We have been woken up by the group. + // This does not guarantee that the group still has 0 running tasks. + Poll::Ready(()) + } + } + } + + fn remove_waiter(&mut self, waiter: &mut Waiter) { + // WaitGroupFuture only needs to get removed if it has been added to + // the wait queue of the WaitGroup. This has happened in the PollState::Waiting case. + if let PollState::Waiting = waiter.state { + if unsafe { self.waiters.remove(waiter.into()).is_none() } { + // Panic if the address isn't found. This can only happen if the contract was + // violated, e.g. the Waiter got moved after the initial poll. + panic!("Future could not be removed from wait queue"); + } + waiter.state = PollState::Done; + } + } +} diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs index 5c89393a5e2..08dcb0bc966 100644 --- a/tokio/src/task/mod.rs +++ b/tokio/src/task/mod.rs @@ -240,3 +240,13 @@ cfg_rt_util! { mod task_local; pub use task_local::LocalKey; } + +cfg_scope! { + // scope requires `CancellationToken` + cfg_sync! { + cfg_unstable! { + mod scope; + pub use scope::{scope, Scope, ScopeConfig, ScopeConfigBuilder, ScopeHandle, ScopedJoinHandle}; + } + } +} diff --git a/tokio/src/task/scope.rs b/tokio/src/task/scope.rs new file mode 100644 index 00000000000..c81c43348b2 --- /dev/null +++ b/tokio/src/task/scope.rs @@ -0,0 +1,439 @@ +//! Tools for structuring concurrent tasks +//! +//! Tokio tasks can run completely independent of each other. However it is +//! often useful to group tasks which try to fulfill a common goal. +//! These groups of tasks should share the same lifetime. If the task group is +//! no longer needed all tasks should stop. If one task errors, the other tasks +//! might no longer be needed and should also be cancelled. +//! +//! The utilities inside this module allow to group tasks by following the +//! concept of structured concurrency. + +use crate::{ + sync::{CancellationToken, SharedWaitGroup}, + task::{JoinError, JoinHandle} +}; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +#[derive(Clone)] +struct ScopeState { + wait_group: SharedWaitGroup, + config: ScopeConfig, +} + +impl ScopeState { + fn new(config: ScopeConfig) -> Self { + Self { + config, + wait_group: SharedWaitGroup::new(), + } + } + + fn is_cancelled(&self) -> bool { + self.config.cancellation_token.is_cancelled() + } +} + +pin_project! { + /// Allows to wait for a child task to join + pub struct ScopedJoinHandle<'scope, T> { + #[pin] + handle: JoinHandle, + phantom: core::marker::PhantomData<&'scope ()>, + } +} + +impl<'scope, T> Future for ScopedJoinHandle<'scope, T> { + // The actual type is Result, JoinError> + // However the cancellation will only happen at the exit of the scope. This + // means in all cases the user still has a handle to the task, the task can + // not be cancelled yet. + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().handle.poll(cx) + } +} + +struct CancelTasksGuard<'a> { + scope: &'a CancellationToken, +} + +impl<'a> Drop for CancelTasksGuard<'a> { + fn drop(&mut self) { + self.scope.cancel(); + } +} + +struct WaitForTasksToJoinGuard<'a> { + wait_group: &'a SharedWaitGroup, + drop_behavior: ScopeDropBehavior, + enabled: bool, +} + +impl<'a> WaitForTasksToJoinGuard<'a> { + fn disarm(&mut self) { + self.enabled = false; + } +} + +impl<'a> Drop for WaitForTasksToJoinGuard<'a> { + fn drop(&mut self) { + if !self.enabled { + return; + } + + match self.drop_behavior { + ScopeDropBehavior::BlockToCompletion => { + let wait_fut = self.wait_group.wait_future(); + + // TODOs: + // - This should not have a futures dependency + // - This might block multithreaded runtimes, since the tasks might need + // the current executor thread to make progress, due to dependening on + // its IO handles. We need to do something along task::block_in_place + // to solve this. + futures::executor::block_on(wait_fut); + } + ScopeDropBehavior::Panic => { + panic!("Scope was dropped before child tasks run to completion"); + } + ScopeDropBehavior::Abort => { + eprintln!("[ERROR] A scope was dropped without being awaited"); + std::process::abort(); + } + ScopeDropBehavior::ContinueTasks => { + // Do nothing + } + } + } +} + +/// A handle to the scope, which allows to spawn child tasks +#[derive(Clone)] +pub struct ScopeHandle { + scope: ScopeState, +} + +impl core::fmt::Debug for ScopeHandle { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("ScopeHandle").finish() + } +} + +impl ScopeHandle { + /// Returns a reference to the `CancellationToken` which signals whether the + /// scope had been cancelled. + pub fn cancellation_token(&self) -> &CancellationToken { + &self.scope.config.cancellation_token + } + + /// Creates a child scope. + /// The child scope inherit all properties of this scope + pub fn child(&self) -> Scope { + Scope::with_parent(self.clone()) + } + + /// Spawns a task on the scope + pub fn spawn<'inner, T, R>(&'inner self, task: T) -> ScopedJoinHandle<'inner, R> + where + T: Future + Send + 'static, + R: Send + 'static, + T: 'inner, + { + let spawn_handle = + crate::runtime::context::spawn_handle().expect("Spawn handle must be available"); + + let releaser = self.scope.wait_group.add(); + + let child_task = { + spawn_handle.spawn(async move { + // Drop this at the end of the task to signal we are done and unblock + // the WaitGroup + let _wait_group_releaser = releaser; + + // Execute the child task + task.await + }) + }; + + // Since `Scope` is `Sync` and `Send`, cancellations can happen at any time + // in case of invalid use. Therefore we only check cancellations once: + // After the task has been spawned. Since the cancellation is already set, + // we need to wait for the task to complete. Then we panic due to invalid + // API usage. + if self.scope.is_cancelled() { + futures::executor::block_on(async { + let _ = child_task.await; + }); + panic!("Spawn on cancelled Scope"); + } + + ScopedJoinHandle { + handle: child_task, + phantom: core::marker::PhantomData, + } + } + + /// Spawns a task on the scope, which will get automatically cancelled if + /// the `CancellationToken` which is associated with the current `ScopeHandle` + /// gets cancelled. + pub fn spawn_cancellable<'inner, T, R>( + &'inner self, + task: T, + ) -> ScopedJoinHandle<'inner, Result> + where + T: Future + Send + 'static, + R: Send + 'static, + T: 'inner, + { + let cancel_token = self.cancellation_token().clone(); + + self.spawn(async move { + crate::pin!(task); + use futures::FutureExt; + + futures::select! { + _ = cancel_token.cancelled().fuse() => { + // The child task was cancelled + Err(CancellationError{}) + }, + res = task.fuse() => { + Ok(res) + } + } + }) + } +} + +/// Defines how a scope will behave if the `Future` it returns get dropped +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum ScopeDropBehavior { + /// When a scope is dropped while tasks are outstanding, the current thread + /// will panic. Since this will not wait for child tasks to complete, the + /// child tasks can outlive the parent in this case. + Panic, + /// When a scope is dropped while tasks are outstanding, the process will be + /// aborted. + Abort, + /// When a scope is dropped while tasks are outstanding, the current thread + /// will be blocked until the tasks in the `scope` completed. This option + /// is only available in multithreaded tokio runtimes, and is the default there. + BlockToCompletion, + /// Ignore that the scope got dropped and continue to run the child tasks. + /// Choosing this option will break structured concurrency. It is therefore + /// not recommended to pick the option. + ContinueTasks, +} + +/// Advanced configuration options for `scope` +#[derive(Debug, Clone)] +pub struct ScopeConfig { + drop_behavior: ScopeDropBehavior, + cancellation_token: CancellationToken, +} + +/// Allows to configure a new `Scope` +#[derive(Debug)] +pub struct ScopeConfigBuilder { + /// The parent scope if available + parent: Option, + /// Drop behavior overwrite + drop_behavior: Option, +} + +impl ScopeConfigBuilder { + /// Creates a new scope which is treated as a child scope of the given + /// [`ScopeHandle`]. The new scope will inherit all properties of the parent + /// scope. In addition tasks inside the new scope will get cancelled when + /// the parent scope gets cancelled. + pub fn with_parent(parent: ScopeHandle) -> Self { + Self { + parent: Some(parent), + drop_behavior: None, + } + } + + /// Creates a new scope which is detached from any parent scope. + /// Tasks spawned on this `scope` will not get cancelled if any parent scope + /// gets cancelled. Instead those tasks would only get cancelled if the + /// scope itself gets cancelled. + pub fn detached() -> Self { + Self { + parent: None, + drop_behavior: Some(ScopeDropBehavior::Panic), + } + } + + /// If the scope is dropped instead of being awaited, the thread which + /// performing the drop of the scope will block until all child tasks in the + /// scope will have run to completion. + pub fn block_to_completion(&mut self) { + self.drop_behavior = Some(ScopeDropBehavior::BlockToCompletion); + } + + /// If the scope is dropped instead of being awaited, child tasks will + /// continue to run. This breaks "structured concurrency", since child tasks + /// are now able to outlive the parent task. + pub fn continue_tasks_on_drop(&mut self) { + self.drop_behavior = Some(ScopeDropBehavior::ContinueTasks); + } + + /// If a scope `Future` gets dropped instead of awaited, the current process + /// will be aborted. This settings tries to provide higher guarantees about + /// child tasks not outliving their parent tasks. + pub fn abort_on_drop(&mut self) { + self.drop_behavior = Some(ScopeDropBehavior::Abort); + } + + /// If a scope `Future` gets dropped instead of awaited, the current + /// thread will `panic!` in order to indicate that the scope did not + /// correctly wait for child tasks to complete, and that there exist detached + /// child tasks as a result of this action. + pub fn panic_on_drop(&mut self) { + self.drop_behavior = Some(ScopeDropBehavior::Panic); + } + + /// Builds the configuration for the scope + pub fn build(self) -> Result { + // Get defaults + + // Generate a `CancellationToken`. If a parent scope and an associated + // cancellation token exists, we create a child token from it. + let cancellation_token = if let Some(parent) = &self.parent { + parent.scope.config.cancellation_token.child_token() + } else { + CancellationToken::new() + }; + + let mut drop_behavior = match &self.parent { + Some(parent) => parent.scope.config.drop_behavior, + None => ScopeDropBehavior::Panic, + }; + + // Apply overwrites + if let Some(behavior) = self.drop_behavior { + drop_behavior = behavior + }; + + Ok(ScopeConfig { + cancellation_token, + drop_behavior, + }) + } +} + +#[derive(Debug)] +pub enum ScopeConfigBuilderError { +} + +/// A builder with allows to build and enter a new task scope. +#[derive(Debug)] +pub struct Scope { + /// Configuration options for the scope + config: ScopeConfig, +} + +impl Scope { + /// Creates a new scope which is treated as a child scope of the given + /// [`ScopeHandle`]. The new scope will inherit all properties of the parent + /// scope. In addition tasks inside the new scope will get cancelled when + /// the parent scope gets cancelled. + pub fn with_parent(parent: ScopeHandle) -> Self { + Self::with_config(ScopeConfigBuilder::with_parent(parent).build().expect("Inherited config can not fail")) + } + + /// Creates a new scope which is detached from any parent scope. + /// Tasks spawned on this `scope` will not get cancelled if any parent scope + /// gets cancelled. Instead those tasks would only get cancelled if the + /// scope itself gets cancelled. + pub fn detached() -> Self { + Self::with_config(ScopeConfigBuilder::detached().build().expect("Default config can not fail")) + } + + /// Creates a `Scope` with the given configuration + pub fn with_config(config: ScopeConfig) -> Self { + Self { + config + } + } + + /// Creates a [`scope`] with custom options + /// + /// The method behaves like [`scope`], but the cancellation and `Drop` behavior + /// for the [`scope`] are configurable. See [`ScopeConfig`] for details. + pub async fn enter(self, scope_func: F) -> R + where + F: FnOnce(ScopeHandle) -> Fut, + Fut: Future + Send, + { + let scope_state = ScopeState::new(self.config); + let wait_fut = scope_state.wait_group.wait_future(); + + // This guard will be called be executed if the scope gets dropped while + // it is still executing. + let mut wait_for_tasks_guard = WaitForTasksToJoinGuard { + wait_group: &scope_state.wait_group, + enabled: true, + drop_behavior: scope_state.config.drop_behavior, + }; + + let scoped_result = { + // This guard will call `.cancel()` on the `CancellationToken` we + // just created. + let _cancel_guard = CancelTasksGuard { + scope: &scope_state.config.cancellation_token, + }; + + let handle = ScopeHandle { + scope: scope_state.clone(), + }; + + // Execute the scope handler, which gets passed a handle to the newly + // created scope + scope_func(handle).await + }; + + // Wait for all remaining tasks inside the scope to complete + wait_fut.await; + + // The tasks have completed. We do not need to wait for them to complete + // in the `Drop` guard. + wait_for_tasks_guard.disarm(); + + scoped_result + } +} + +/// Creates a task scope with default options. +/// +/// The `scope` allows to spawn child tasks so that the lifetime of child tasks +/// is constrained within the scope. +/// +/// A closure which accepts a [`ScopeHandle`] object and which returns a [`Future`] +/// needs to be passed to `scope`. The [`ScopeHandle`] can be used to spawn child +/// tasks. +/// +/// If the provided `Future` had been polled to completion, all child tasks which +/// have been spawned via the `ScopeHandle` will be cancelled. +/// +/// `scope` returns a [`Future`] which should be awaited. The `await` will only +/// complete once all child tasks that have been spawned via the provided +/// [`ScopeHandle`] have joined. Thereby the `scope` does not allow child tasks +/// to outlive their parent task, as long as the `scope` is awaited. +pub async fn scope(scope_func: F) -> R +where + F: FnOnce(ScopeHandle) -> Fut, + Fut: Future + Send, +{ + Scope::detached().enter(scope_func).await +} + +/// Error type which is returned when a task is cancelled +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq)] +pub struct CancellationError {} diff --git a/tokio/tests/task_scope.rs b/tokio/tests/task_scope.rs new file mode 100644 index 00000000000..209047ae121 --- /dev/null +++ b/tokio/tests/task_scope.rs @@ -0,0 +1,341 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use futures::{select, FutureExt}; +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::{Duration, Instant}, +}; +use tokio::{ + task::{scope, Scope, ScopeConfigBuilder}, + time::delay_for, +}; + +#[tokio::test] +async fn unused_scope() { + let scope = scope(|_scope| async {}); + drop(scope); +} + +#[tokio::test] +async fn spawn_and_return_result() { + let result = scope(|scope| async move { + let handle = scope.spawn(async { 123u32 }); + handle.await + }) + .await; + assert_eq!(123u32, result.unwrap()); +} + +#[derive(Clone)] +struct AtomicFlag(Arc); + +impl AtomicFlag { + fn new() -> Self { + AtomicFlag(Arc::new(AtomicBool::new(false))) + } + + fn is_set(&self) -> bool { + self.0.load(Ordering::Acquire) + } + + fn set(&self) { + self.0.store(true, Ordering::Release); + } +} + +struct SetFlagOnDropGuard { + flag: AtomicFlag, +} + +impl Drop for SetFlagOnDropGuard { + fn drop(&mut self) { + self.flag.set(); + } +} + +#[tokio::test] +async fn cancel_and_wait_for_child_task() { + let flag = AtomicFlag::new(); + let flag_clone = flag.clone(); + + let result = scope(|scope| async move { + let handle = scope.spawn(async { + delay_for(Duration::from_millis(20)).await; + 123u32 + }); + + scope.spawn_cancellable(async { + let _guard = SetFlagOnDropGuard { flag: flag_clone }; + loop { + tokio::task::yield_now().await; + } + }); + + handle.await + }) + .await; + assert_eq!(123u32, result.unwrap()); + + // Check that the second task was cancelled + assert_eq!(true, flag.is_set()); +} + +#[tokio::test] +async fn cancels_nested_scopes() { + let flag = AtomicFlag::new(); + let flag_clone = flag.clone(); + + let result = scope(|scope| async move { + let ct = scope.cancellation_token().clone(); + + let handle = scope.spawn(async move { + delay_for(Duration::from_millis(200)).await; + // Cancelling the parent scope should also cancel the task + // which is running insie a child scope + ct.cancel(); + 123u32 + }); + + scope + .child() + .enter(|child_scope| async move { + let _ = child_scope + .spawn_cancellable(async { + let _guard = SetFlagOnDropGuard { flag: flag_clone }; + loop { + tokio::task::yield_now().await; + } + }) + .await; + }) + .await; + + handle.await + }) + .await; + assert_eq!(123u32, result.unwrap()); + + // Check that the second task was cancelled + assert_eq!(true, flag.is_set()); +} + +#[test] +fn block_until_non_joined_tasks_complete() { + let mut runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let flag = AtomicFlag::new(); + let flag_clone = flag.clone(); + + let start_time = Instant::now(); + + let mut scope_config = ScopeConfigBuilder::detached(); + scope_config.block_to_completion(); + + let scope_fut = Scope::with_config(scope_config.build().unwrap()).enter(|scope| { + async move { + let handle = scope.spawn(async { + delay_for(Duration::from_millis(20)).await; + 123u32 + }); + + scope.spawn(async move { + // Use block_in_place makes the task not cancellable + tokio::task::block_in_place(|| { + std::thread::sleep(Duration::from_millis(100)); + }); + flag_clone.set(); + }); + + handle.await + } + }); + + select! { + _ = scope_fut.fuse() => { + panic!("Scope should not complete"); + }, + _ = delay_for(Duration::from_millis(50)).fuse() => { + // Drop the scope here + }, + }; + + assert!(start_time.elapsed() >= Duration::from_millis(100)); + + // Check that the second task run to completion + assert_eq!(true, flag.is_set()); + }); +} + +#[should_panic] +#[test] +fn panic_if_active_scope_is_dropped() { + let mut runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let mut scope_config = ScopeConfigBuilder::detached(); + scope_config.panic_on_drop(); + + let scope_fut = Scope::with_config(scope_config.build().unwrap()).enter(|scope| { + async move { + let handle = scope.spawn(async { + delay_for(Duration::from_millis(20)).await; + 123u32 + }); + + scope.spawn(async move { + // Use block_in_place makes the task not cancellable + tokio::task::block_in_place(|| { + std::thread::sleep(Duration::from_millis(100)); + }); + }); + + handle.await + } + }); + + select! { + _ = scope_fut.fuse() => { + panic!("Scope should not complete"); + }, + _ = delay_for(Duration::from_millis(50)).fuse() => { + // Drop the scope here + }, + }; + }); +} + +#[test] +fn child_tasks_can_continue_to_run_if_configured() { + let mut runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let flag = AtomicFlag::new(); + let flag_clone = flag.clone(); + + let mut scope_config = ScopeConfigBuilder::detached(); + scope_config.continue_tasks_on_drop(); + + let start_time = Instant::now(); + let scope_fut = Scope::with_config(scope_config.build().unwrap()).enter(|scope| { + async move { + let handle = scope.spawn(async { + delay_for(Duration::from_millis(20)).await; + 123u32 + }); + + scope.spawn(async move { + // Use block_in_place makes the task not cancellable + tokio::task::block_in_place(|| { + std::thread::sleep(Duration::from_millis(100)); + }); + flag_clone.set(); + }); + + handle.await + } + }); + + select! { + _ = scope_fut.fuse() => { + panic!("Scope should not complete"); + }, + _ = delay_for(Duration::from_millis(50)).fuse() => { + // Drop the scope here + }, + }; + + let elapsed = start_time.elapsed(); + assert!(elapsed >= Duration::from_millis(50) && elapsed < Duration::from_millis(100)); + assert_eq!(false, flag.is_set()); + + // Wait until the leaked task run to completion + delay_for(Duration::from_millis(60)).await; + assert_eq!(true, flag.is_set()); + }); +} + +#[test] +fn clone_scope_handles_and_cancel_child() { + let mut runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let drop_flag = AtomicFlag::new(); + let drop_flag_clone = drop_flag.clone(); + let completion_flag = AtomicFlag::new(); + let completion_flag_clone = completion_flag.clone(); + + scope(|scope| { + async move { + let cloned_handle = scope.clone(); + + let join_handle = scope.spawn(async move { + delay_for(Duration::from_millis(20)).await; + // Spawn another task - which is not awaited but should get + // cancelled through the inherited CancellationToken + let _join_handle = cloned_handle.spawn_cancellable(async move { + let _guard = SetFlagOnDropGuard { + flag: drop_flag_clone, + }; + + delay_for(Duration::from_millis(50)).await; + // This should not get executed, since the inital task exits before + // and this task gets cancelled. + completion_flag_clone.set(); + }); + }); + + let _ = join_handle.await; + } + }) + .await; + + assert_eq!(true, drop_flag.is_set()); + assert_eq!(false, completion_flag.is_set()); + }); +} + +#[test] +fn clone_scope_handles_and_wait_for_child() { + let mut runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let drop_flag = AtomicFlag::new(); + let drop_flag_clone = drop_flag.clone(); + let completion_flag = AtomicFlag::new(); + let completion_flag_clone = completion_flag.clone(); + + let mut scope_config = ScopeConfigBuilder::detached(); + scope_config.continue_tasks_on_drop(); + + let start_time = Instant::now(); + Scope::with_config(scope_config.build().unwrap()) + .enter(|scope| { + async move { + let cloned_handle = scope.clone(); + + let join_handle = scope.spawn(async move { + delay_for(Duration::from_millis(20)).await; + // Spawn another task - which is not awaited + let _join_handle = cloned_handle.spawn(async move { + let _guard = SetFlagOnDropGuard { + flag: drop_flag_clone, + }; + + delay_for(Duration::from_millis(50)).await; + // This should get executed, since tasks are allowed to run + // to completion. + completion_flag_clone.set(); + }); + }); + + let _ = join_handle.await; + } + }) + .await; + + assert!(start_time.elapsed() >= Duration::from_millis(70)); + + assert_eq!(true, drop_flag.is_set()); + assert_eq!(true, completion_flag.is_set()); + }); +}