From 37092fbb6e01bf341b83f2cb44dbd7e12e3ec408 Mon Sep 17 00:00:00 2001 From: Matthias Einwag Date: Wed, 19 Feb 2020 23:58:39 -0800 Subject: [PATCH] WIP: task::scope using implicit scopes This change adds task::scope as a mechanism for supporting structured concurrency as described in #1879. This version of the scope implementation makes use of implicit scopes, which are propgated within the task system through task local storage. Ever task spawned via `scope::spawn` or `scope::spawn_cancellable` is automatically attached to it's current scope without having to explicitly attach to it. This provides stronger guarantees, since child tasks in this model will never be able to outlive the parent - there is no `ScopeHandle` available to spawn a task on a certain scope after this is finished. One drawback of this approach is however that since no `ScopeHandle` is available, we also can't tie the lifetime of tasks and their `JoinHandle`s to this scope. This makes it less likely that we could borrowing data from the parent task using this approach. One benefit however is that there seems to be an interesting migration path from tokios current task system to this scoped approach: - Using `tokio::spawn` could in the future be equivalent to spawning a task on the runtimes implicit top level scope. The task would not be force-cancellable, in the same fashion as tasks spawned via `scope::spawn` are not cancellable. - Shutting down the runtime could be equivalent to leaving a scope: The remaining running tasks get a graceful cancellation signal and the scope would wait for those tasks to finish. - However since the Runtime would never have to force-cancel a task (people would opt into this behavior using `scope::spawn_cancellable`) the `JoinError` could be removed from the "normal" spawn API. It is still available for cancellable spawns. --- tokio/Cargo.toml | 4 + tokio/src/macros/cfg.rs | 10 + tokio/src/sync/mod.rs | 5 + tokio/src/sync/wait_group.rs | 338 +++++++++++++++++++++++ tokio/src/task/mod.rs | 9 + tokio/src/task/scope.rs | 522 +++++++++++++++++++++++++++++++++++ tokio/tests/task_scope.rs | 202 ++++++++++++++ 7 files changed, 1090 insertions(+) create mode 100644 tokio/src/sync/wait_group.rs create mode 100644 tokio/src/task/scope.rs create mode 100644 tokio/tests/task_scope.rs diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index e596375d630..835a29b8c7a 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -41,10 +41,12 @@ full = [ "rt-core", "rt-util", "rt-threaded", + "scope", "signal", "stream", "sync", "time", + "futures" ] blocking = ["rt-core"] @@ -73,6 +75,7 @@ rt-threaded = [ "num_cpus", "rt-core", ] +scope = [] signal = [ "io-driver", "lazy_static", @@ -99,6 +102,7 @@ pin-project-lite = "0.1.1" # Everything else is optional... fnv = { version = "1.0.6", optional = true } futures-core = { version = "0.3.0", optional = true } +futures = { version = "0.3.0", optional = true } lazy_static = { version = "1.0.2", optional = true } memchr = { version = "2.2", optional = true } mio = { version = "0.6.20", optional = true } diff --git a/tokio/src/macros/cfg.rs b/tokio/src/macros/cfg.rs index 288f58d2f40..2cc2c4f7f9f 100644 --- a/tokio/src/macros/cfg.rs +++ b/tokio/src/macros/cfg.rs @@ -204,6 +204,16 @@ macro_rules! cfg_process { } } +macro_rules! cfg_scope { + ($($item:item)*) => { + $( + #[cfg(feature = "scope")] + #[cfg_attr(docsrs, doc(cfg(feature = "scope")))] + $item + )* + } +} + macro_rules! cfg_signal { ($($item:item)*) => { $( diff --git a/tokio/src/sync/mod.rs b/tokio/src/sync/mod.rs index 3d96106d2df..9509a585415 100644 --- a/tokio/src/sync/mod.rs +++ b/tokio/src/sync/mod.rs @@ -460,6 +460,11 @@ cfg_sync! { mod task; pub(crate) use task::AtomicWaker; + cfg_unstable! { + mod wait_group; + pub(crate) use wait_group::{SharedWaitGroup}; + } + pub mod watch; } diff --git a/tokio/src/sync/wait_group.rs b/tokio/src/sync/wait_group.rs new file mode 100644 index 00000000000..8ed5f46a529 --- /dev/null +++ b/tokio/src/sync/wait_group.rs @@ -0,0 +1,338 @@ +//! An asynchronously awaitable WaitGroup which allows to wait for running tasks +//! to complete. + +use crate::{ + loom::sync::{Arc, Mutex}, + util::linked_list::{self, LinkedList}, +}; +use std::{ + cell::UnsafeCell, + future::Future, + marker::PhantomPinned, + pin::Pin, + ptr::NonNull, + task::{Context, Poll, Waker}, +}; + +/// A synchronization primitive which allows to wait until all tracked tasks +/// have finished. +/// +/// Tasks can wait for tracked tasks to finish by obtaining a Future via `wait`. +/// This Future will get fulfilled when no tasks are running anymore. +pub(crate) struct WaitGroup { + inner: Mutex, +} + +// The Group can be sent to other threads as long as it's not borrowed +unsafe impl Send for WaitGroup {} +// The Group is thread-safe as long as the utilized Mutex is thread-safe +unsafe impl Sync for WaitGroup {} + +impl core::fmt::Debug for WaitGroup { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WaitGroup").finish() + } +} + +impl WaitGroup { + /// Creates a new WaitGroup + pub(crate) fn new() -> WaitGroup { + WaitGroup { + inner: Mutex::new(GroupState::new(0)), + } + } + + /// Adds a pending task to the WaitGroup + pub(crate) fn add(&self) { + self.inner.lock().unwrap().add() + } + + /// Removes a task that has finished from the WaitGroup + pub(crate) fn remove(&self) { + self.inner.lock().unwrap().remove() + } + + /// Returns a future that gets fulfilled when all tracked tasks complete + pub(crate) fn wait(&self) -> WaitGroupFuture<'_> { + WaitGroupFuture { + group: Some(self), + waiter: UnsafeCell::new(Waiter::new()), + } + } + + unsafe fn try_wait(&self, waiter: &mut UnsafeCell, cx: &mut Context<'_>) -> Poll<()> { + let mut guard = self.inner.lock().unwrap(); + // Safety: The wait node is only accessed inside the Mutex + let waiter = &mut *waiter.get(); + guard.try_wait(waiter, cx) + } + + fn remove_waiter(&self, waiter: &mut UnsafeCell) { + let mut guard = self.inner.lock().unwrap(); + // Safety: The wait node is only accessed inside the Mutex + let waiter = unsafe { &mut *waiter.get() }; + guard.remove_waiter(waiter) + } +} + +/// A Future that is resolved once the corresponding WaitGroup has reached +/// 0 active tasks. +#[must_use = "futures do nothing unless polled"] +pub(crate) struct WaitGroupFuture<'a> { + /// The WaitGroup that is associated with this WaitGroupFuture + group: Option<&'a WaitGroup>, + /// Node for waiting at the group + waiter: UnsafeCell, +} + +// Safety: Futures can be sent between threads, since the underlying +// group is thread-safe (Sync), which allows to poll/register/unregister from +// a different thread. +unsafe impl<'a> Send for WaitGroupFuture<'a> {} + +impl<'a> core::fmt::Debug for WaitGroupFuture<'a> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WaitGroupFuture").finish() + } +} + +impl Future for WaitGroupFuture<'_> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + // It might be possible to use Pin::map_unchecked here instead of the two unsafe APIs. + // However this didn't seem to work for some borrow checker reasons + + // Safety: The next operations are safe, because Pin promises us that + // the address of the wait queue entry inside WaitGroupFuture is stable, + // and we don't move any fields inside the future until it gets dropped. + let mut_self: &mut WaitGroupFuture<'_> = unsafe { Pin::get_unchecked_mut(self) }; + + let group = mut_self + .group + .expect("polled WaitGroupFuture after completion"); + + let poll_res = unsafe { group.try_wait(&mut mut_self.waiter, cx) }; + + if let Poll::Ready(()) = poll_res { + mut_self.group = None; + } + + poll_res + } +} + +impl<'a> Drop for WaitGroupFuture<'a> { + fn drop(&mut self) { + // If this WaitGroupFuture has been polled and it was added to the + // wait queue at the group, it must be removed before dropping. + // Otherwise the group would access invalid memory. + if let Some(ev) = self.group { + ev.remove_waiter(&mut self.waiter); + } + } +} + +/// A cloneable [`WaitGroup`] +/// +/// When tasks are added to this [`WaitGroup`] a [`WaitGroupReleaser`] will be +/// returned, which will automatically decrement the count of active tasks in +/// the [`SharedWaitGroup`] when dropped. +#[derive(Clone)] +pub(crate) struct SharedWaitGroup { + inner: Arc, +} + +impl SharedWaitGroup { + /// Creates a new [`SharedWaitGroup`] + pub(crate) fn new() -> Self { + Self { + inner: Arc::new(WaitGroup::new()), + } + } + + /// Registers a task at the [`SharedWaitGroup`] + /// + /// The method returns a [`WaitGroupReleaser`] which is intended to be dropped + /// once the task completes. + #[must_use] + pub(crate) fn add(&self) -> WaitGroupReleaser { + self.inner.add(); + WaitGroupReleaser { + inner: self.inner.clone(), + } + } + + /// Returns a [`Future`] which will complete once all tasks which have been + /// previously added have dropped their [`WaitGroupReleaser`] and are thereby + /// deemed as finished. + pub(crate) fn wait_future(&self) -> WaitGroupFuture<'_> { + self.inner.wait() + } +} + +/// A handle which tracks an active task which is monitored by the [`SharedWaitGroup`]. +/// When this object is dropped, the task will be automatically be marked as +/// completed inside the [`SharedWaitGroup`]. +pub(crate) struct WaitGroupReleaser { + inner: Arc, +} + +impl Drop for WaitGroupReleaser { + fn drop(&mut self) { + self.inner.remove(); + } +} + +/// Tracks how the future had interacted with the group +#[derive(PartialEq)] +enum PollState { + /// The task has never interacted with the group. + New, + /// The task was added to the wait queue at the group. + Waiting, + /// The task has been polled to completion. + Done, +} + +/// A `Waiter` allows a task to wait o the `WaitGroup`. A `Waiter` is a node +/// in a linked list which is managed through the `WaitGroup`. +/// Access to this struct is synchronized through the mutex in the WaitGroup. +struct Waiter { + /// Intrusive linked-list pointers + pointers: linked_list::Pointers, + /// The task handle of the waiting task + waker: Option, + /// Current polling state + state: PollState, + /// Should not be `Unpin`. + _p: PhantomPinned, +} + +impl Waiter { + /// Creates a new Waiter + fn new() -> Waiter { + Waiter { + pointers: linked_list::Pointers::new(), + waker: None, + state: PollState::New, + _p: PhantomPinned, + } + } +} + +/// # Safety +/// +/// `Waiter` is forced to be !Unpin. +unsafe impl linked_list::Link for Waiter { + type Handle = NonNull; + type Target = Waiter; + + fn as_raw(handle: &NonNull) -> NonNull { + *handle + } + + unsafe fn from_raw(ptr: NonNull) -> NonNull { + ptr + } + + unsafe fn pointers(mut target: NonNull) -> NonNull> { + NonNull::from(&mut target.as_mut().pointers) + } +} + +/// Internal state of the `WaitGroup` +struct GroupState { + count: usize, + waiters: LinkedList, +} + +impl GroupState { + fn new(count: usize) -> GroupState { + GroupState { + count, + waiters: LinkedList::new(), + } + } + + fn add(&mut self) { + self.count += 1; + } + + fn remove(&mut self) { + if self.count == 0 { + return; + } + self.count -= 1; + if self.count != 0 { + return; + } + + // Wakeup all waiters + while let Some(mut waiter) = self.waiters.pop_back() { + // Safety: waiters lock is held + let waiter = unsafe { waiter.as_mut() }; + if let Some(handle) = (*waiter).waker.take() { + handle.wake(); + } + (*waiter).state = PollState::Done; + } + } + + /// Checks how many tasks are running. If none are running, this returns + /// `Poll::Ready` immediately. + /// If tasks are running, the WaitGroupFuture gets added to the wait + /// queue at the group, and will be signalled once the tasks completed. + /// This function is only safe as long as the `waiter`s address is guaranteed + /// to be stable until it gets removed from the queue. + unsafe fn try_wait(&mut self, waiter: &mut Waiter, cx: &mut Context<'_>) -> Poll<()> { + match waiter.state { + PollState::New => { + if self.count == 0 { + // The group is already signaled + waiter.state = PollState::Done; + Poll::Ready(()) + } else { + // Added the task to the wait queue + waiter.waker = Some(cx.waker().clone()); + waiter.state = PollState::Waiting; + self.waiters.push_front(waiter.into()); + Poll::Pending + } + } + PollState::Waiting => { + // The WaitGroupFuture is already in the queue. + // The group can't have reached 0 tasks, since this would change the + // waitstate inside the mutex. However the caller might have + // passed a different `Waker`. In this case we need to update it. + if waiter + .waker + .as_ref() + .map_or(true, |stored_waker| !stored_waker.will_wake(cx.waker())) + { + waiter.waker = Some(cx.waker().clone()); + } + + Poll::Pending + } + PollState::Done => { + // We have been woken up by the group. + // This does not guarantee that the group still has 0 running tasks. + Poll::Ready(()) + } + } + } + + fn remove_waiter(&mut self, waiter: &mut Waiter) { + // WaitGroupFuture only needs to get removed if it has been added to + // the wait queue of the WaitGroup. This has happened in the PollState::Waiting case. + if let PollState::Waiting = waiter.state { + if unsafe { self.waiters.remove(waiter.into()).is_none() } { + // Panic if the address isn't found. This can only happen if the contract was + // violated, e.g. the Waiter got moved after the initial poll. + panic!("Future could not be removed from wait queue"); + } + waiter.state = PollState::Done; + } + } +} diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs index 5c89393a5e2..3dfa2561e74 100644 --- a/tokio/src/task/mod.rs +++ b/tokio/src/task/mod.rs @@ -240,3 +240,12 @@ cfg_rt_util! { mod task_local; pub use task_local::LocalKey; } + +cfg_scope! { + // scope requires `CancellationToken` + cfg_sync! { + cfg_unstable! { + pub mod scope; + } + } +} diff --git a/tokio/src/task/scope.rs b/tokio/src/task/scope.rs new file mode 100644 index 00000000000..c5219f0f18e --- /dev/null +++ b/tokio/src/task/scope.rs @@ -0,0 +1,522 @@ +//! Tools for structuring concurrent tasks +//! +//! Tokio tasks can run completely independent of each other. However it is +//! often useful to group tasks which try to fulfill a common goal. +//! These groups of tasks should share the same lifetime. If the task group is +//! no longer needed all tasks should stop. If one task errors, the other tasks +//! might no longer be needed and should also be cancelled. +//! +//! The utilities inside this module allow to group tasks by following the +//! concept of structured concurrency. + +use crate::{ + sync::{CancellationToken, SharedWaitGroup}, + task::{JoinError, JoinHandle}, +}; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +/// Creates and enters a task scope +/// +/// The `scope` allows to spawn child tasks so that the lifetime of child tasks +/// is constrained within the scope. +/// +/// If the provided `Future` had been polled to completion, all child tasks which +/// have been spawned via [`scope::spawn`] and [`scope::spawn_cancellable`] are +/// guaranteed to have run to completion. +/// +/// `enter` returns a [`Future`] which must be awaited. The `await` will only +/// complete once all child tasks that have been spawned via the provided +/// [`ScopeHandle`] have joined. Thereby the `scope` does not allow child tasks +/// to outlive their parent task, as long as the future returned from +/// `scope::enter` is awaited. +/// +/// The `Future` returned from `enter` will evaluate the value which is returned +/// from the async function inside the scope. +/// +/// Since scopes need to run to completion they should not be be started on +/// tasks which have been spawned using the [`spawn_cancellable`] function, since +/// this function will force-cancel running tasks on it. +/// +/// Instead new scopes should always be created from gracefully cancellable tasks +/// which have been stared using the [`scope::spawn`] method. +/// +/// # Examples +/// +/// ```no_run +/// use tokio::task::scope; +/// +/// #[tokio::main] +/// async fn scope_with_graceful_cancellation() { +/// let result = scope::enter(async move { +/// // This is the main task which will finish after 20ms +/// let handle = scope::spawn(async { +/// tokio::time::delay_for(std::time::Duration::from_millis(20)).await; +/// println!("Cancelling"); +/// scope::current_cancellation_token().cancel(); +/// 123u32 +/// }); +/// +/// // Spawn a long running task which is not intended to run to completion +/// let _ = scope::spawn(async { +/// let ct = scope::current_cancellation_token(); +/// tokio::select! { +/// _ = ct.cancelled() => { +/// // This branch will be taken once the scope is left +/// println!("task was cancelled"); +/// }, +/// _ = tokio::time::delay_for(std::time::Duration::from_secs(3600)) => { +/// panic!("This task should not run to completion"); +/// }, +/// } +/// }).await; +/// +/// // Wait for the main task. After this finishes the scope will end. +/// // Thereby the remaining task will get cancelled, and awaited before +/// // `scope::enter` returns. +/// handle.await +/// }) +/// .await; +/// +/// assert_eq!(123, result.unwrap()); +/// } +/// ``` +pub async fn enter(scope_fut: Fut) -> R +where + Fut: Future + Send, +{ + let child_scope = match CURRENT_SCOPE.try_with(|scope_handle| scope_handle.clone()) { + Ok(scope) => scope.child(), + Err(_) => Scope::detached(), + }; + + child_scope.enter(scope_fut).await +} + +/// Spawns a task on the current scope which will run to completion. +/// +/// If the parent scope is cancelled the task will be informed via through a +/// [`CancellationToken`] whose cancellation state can be queried using +/// [`scope::current_cancellation_token`]. If cancellation is requested, the +/// task should return as early as possible. +pub fn spawn(task: T) -> ScopedJoinHandle +where + T: Future + Send + 'static, + R: Send + 'static, + T: 'static, +{ + let current_scope_handle = CURRENT_SCOPE + .try_with(|scope_handle| scope_handle.clone()) + .unwrap(); + + current_scope_handle.spawn(task) +} + +/// Spawns a task on the current scope which will automatically get force-cancelled +/// if the parent if the `scope` gets cancelled. That spawned task therefore is +/// not guaranteed to run to completion. +/// +/// Spawning a task using [`scope::spawn_cancellable`] is equivalent to spawning +/// it with [`scope::spawn`] and aborting execution when the tasks +/// [`CancellationToken`] was signalled: +/// +/// ```no_run +/// # use std::future::Future; +/// use tokio::task::scope; +/// +/// fn spawn_cancellable(task: T) -> scope::ScopedJoinHandle> +/// where +/// T: Future + Send + 'static, +/// R: Send + 'static, +/// T: 'static, +/// { +/// scope::spawn(async { +/// let ct = scope::current_cancellation_token(); +/// tokio::select! { +/// _ = ct.cancelled() => { +/// Err(scope::CancellationError{}) +/// }, +/// result = task => { +/// Ok(result) +/// }, +/// } +/// }) +/// } +/// ``` +/// +/// On tasks spawned via `spawn_cancellable` no new task scopes should be created +/// via `scope::enter`, since they are not guaranteed to run to completion. If the +/// a task gets force cancelled while a scope is active inside the task a +/// runtime panic will be emitted. +pub fn spawn_cancellable(task: T) -> ScopedJoinHandle> +where + T: Future + Send + 'static, + R: Send + 'static, + T: 'static, +{ + let current_scope_handle = CURRENT_SCOPE + .try_with(|scope_handle| scope_handle.clone()) + .unwrap(); + + current_scope_handle.spawn_cancellable(task) +} + +/// Returns the [`CancellationToken`] which is associated with the currently +/// running task and `scope`. +/// If the current `scope` gets cancelled the `CancellationToken` will be signalled +pub fn current_cancellation_token() -> CancellationToken { + // TODO: We could also return an Option here, but that is + // somewhat inconvenient to use. + // Or we wrap Option also in a struct which can product + // a `.cancelled()` `Future`, which would never resolve in case the token + // is `None`. + CURRENT_SCOPE + .try_with(|scope_handle| scope_handle.cancellation_token().clone()) + .unwrap_or_else(|_| CancellationToken::new()) +} + +/// Holds the current task-local [`ScopeHandle`] +static CURRENT_SCOPE: crate::task::LocalKey = { + std::thread_local! { + static __KEY: std::cell::RefCell> = std::cell::RefCell::new(None); + } + + crate::task::LocalKey { inner: __KEY } +}; + +/// Error type which is returned when a force-cancellable task was cancelled and +/// did not run to completion. +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq)] +pub struct CancellationError {} + +pin_project! { + /// Allows to wait for a child task to join + pub struct ScopedJoinHandle { + #[pin] + handle: JoinHandle, + } +} + +impl Future for ScopedJoinHandle { + // TODO: Assuming the runtime semantics are adapted to suite structured + /// concurrency better, the `JoinError` might not be necessary here. + /// - For gracefully cancelled tasks the runtime would need to wait until the + /// tasks finished. In this case the task would never be aborted and + /// JoinError is not necessary. + /// - For forcefully cancelled tasks which have been spawned using + /// `spawn_cancellable` the join error in form of JoinError/CancellationError + /// is still necessary - but not a nested version. + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().handle.poll(cx) + } +} + +#[derive(Clone)] +struct ScopeState { + wait_group: SharedWaitGroup, + config: ScopeConfig, +} + +impl ScopeState { + fn new(config: ScopeConfig) -> Self { + Self { + config, + wait_group: SharedWaitGroup::new(), + } + } +} + +struct CancelTasksGuard<'a> { + scope: &'a CancellationToken, +} + +impl<'a> Drop for CancelTasksGuard<'a> { + fn drop(&mut self) { + self.scope.cancel(); + } +} + +struct WaitForTasksToJoinGuard<'a> { + _wait_group: &'a SharedWaitGroup, + drop_behavior: ScopeDropBehavior, + enabled: bool, +} + +impl<'a> WaitForTasksToJoinGuard<'a> { + fn disarm(&mut self) { + self.enabled = false; + } +} + +impl<'a> Drop for WaitForTasksToJoinGuard<'a> { + fn drop(&mut self) { + if !self.enabled { + return; + } + + match self.drop_behavior { + ScopeDropBehavior::Panic => { + panic!("Scope was dropped before child tasks run to completion"); + } + } + } +} + +/// A handle to the scope, which allows to spawn child tasks +#[derive(Clone)] +struct ScopeHandle { + scope: ScopeState, +} + +impl core::fmt::Debug for ScopeHandle { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("ScopeHandle").finish() + } +} + +impl ScopeHandle { + /// Returns a reference to the `CancellationToken` which signals whether the + /// scope had been cancelled. + fn cancellation_token(&self) -> &CancellationToken { + &self.scope.config.cancellation_token + } + + /// Creates a child scope. + /// The child scope inherit all properties of this scope + fn child(&self) -> Scope { + Scope::with_parent(self.clone()) + } + + /// Spawns a task on the scope + fn spawn(&self, task: T) -> ScopedJoinHandle + where + T: Future + Send + 'static, + R: Send + 'static, + { + let spawn_handle = + crate::runtime::context::spawn_handle().expect("Spawn handle must be available"); + + // Add a wait handle + // This must happen BEFORE we spawn the child task - otherwise this is + // would be racy. + let releaser = self.scope.wait_group.add(); + + let self_clone = self.clone(); + let child_task: JoinHandle = spawn_handle.spawn(async move { + // Drop this at the end of the task to signal we are done and unblock + // the WaitGroup + let _wait_group_releaser = releaser; + + // Set the thread local scope handle so that the child task inherits + // the properties from the parent and execute it. + // TODO: In case the properties would already be required for spawning + // (e.g. in order to bind a task to a certain runtime thread) this + // place would already be too late. This properties would rather + // need to be passed to `runtime::context::spawn_handle()`. + CURRENT_SCOPE.scope(self_clone, task).await + }); + + ScopedJoinHandle { handle: child_task } + } + + /// Spawns a task on the scope, which will get automatically cancelled if + /// the `CancellationToken` which is associated with the current `ScopeHandle` + /// gets cancelled. + fn spawn_cancellable<'inner, T, R>( + &'inner self, + task: T, + ) -> ScopedJoinHandle> + where + T: Future + Send + 'static, + R: Send + 'static, + T: 'inner, + { + let cancel_token = self.cancellation_token().clone(); + + self.spawn(async move { + crate::pin!(task); + use futures::FutureExt; + + // TODO: This should use `tokio::select!`. But using this macro from + // inside the tokio crate just produces wonderful error messages. + futures::select! { + _ = cancel_token.cancelled().fuse() => { + // The child task was cancelled + Err(CancellationError{}) + }, + res = task.fuse() => { + Ok(res) + } + } + }) + } +} + +/// Defines how a scope will behave if the `Future` it returns get dropped +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum ScopeDropBehavior { + /// When a scope is dropped while tasks are outstanding, the current thread + /// will panic. Since this will not wait for child tasks to complete, the + /// child tasks can outlive the parent in this case. + Panic, +} + +/// Advanced configuration options for `scope` +#[derive(Debug, Clone)] +struct ScopeConfig { + drop_behavior: ScopeDropBehavior, + cancellation_token: CancellationToken, +} + +/// Allows to configure a new `Scope` +#[derive(Debug)] +struct ScopeConfigBuilder { + /// The parent scope if available + parent: Option, + /// Drop behavior overwrite + drop_behavior: Option, +} + +impl ScopeConfigBuilder { + /// Creates a new scope which is treated as a child scope of the given + /// [`ScopeHandle`]. The new scope will inherit all properties of the parent + /// scope. In addition tasks inside the new scope will get cancelled when + /// the parent scope gets cancelled. + fn with_parent(parent: ScopeHandle) -> Self { + Self { + parent: Some(parent), + drop_behavior: None, + } + } + + /// Creates a new scope which is detached from any parent scope. + /// Tasks spawned on this `scope` will not get cancelled if any parent scope + /// gets cancelled. Instead those tasks would only get cancelled if the + /// scope itself gets cancelled. + fn detached() -> Self { + Self { + parent: None, + drop_behavior: Some(ScopeDropBehavior::Panic), + } + } + + /// Builds the configuration for the scope + fn build(self) -> Result { + // Get defaults + + // Generate a `CancellationToken`. If a parent scope and an associated + // cancellation token exists, we create a child token from it. + let cancellation_token = if let Some(parent) = &self.parent { + parent.scope.config.cancellation_token.child_token() + } else { + CancellationToken::new() + }; + + let mut drop_behavior = match &self.parent { + Some(parent) => parent.scope.config.drop_behavior, + None => ScopeDropBehavior::Panic, + }; + + // Apply overwrites + if let Some(behavior) = self.drop_behavior { + drop_behavior = behavior + }; + + Ok(ScopeConfig { + cancellation_token, + drop_behavior, + }) + } +} + +#[derive(Debug)] +enum ScopeConfigBuilderError {} + +/// A builder with allows to build and enter a new task scope. +#[derive(Debug)] +struct Scope { + /// Configuration options for the scope + config: ScopeConfig, +} + +impl Scope { + /// Creates a new scope which is treated as a child scope of the given + /// [`ScopeHandle`]. The new scope will inherit all properties of the parent + /// scope. In addition tasks inside the new scope will get cancelled when + /// the parent scope gets cancelled. + fn with_parent(parent: ScopeHandle) -> Self { + Self::with_config( + ScopeConfigBuilder::with_parent(parent) + .build() + .expect("Inherited config can not fail"), + ) + } + + /// Creates a new scope which is detached from any parent scope. + /// Tasks spawned on this `scope` will not get cancelled if any parent scope + /// gets cancelled. Instead those tasks would only get cancelled if the + /// scope itself gets cancelled. + fn detached() -> Self { + Self::with_config( + ScopeConfigBuilder::detached() + .build() + .expect("Default config can not fail"), + ) + } + + /// Creates a `Scope` with the given configuration + fn with_config(config: ScopeConfig) -> Self { + Self { config } + } + + /// Creates a [`scope`] with custom options + /// + /// The method behaves like [`scope`], but the cancellation and `Drop` behavior + /// for the [`scope`] are configurable. See [`ScopeConfig`] for details. + async fn enter(self, scope_fut: Fut) -> R + where + Fut: Future + Send, + { + let scope_state = ScopeState::new(self.config); + let wait_fut = scope_state.wait_group.wait_future(); + + // This guard will be called be executed if the scope gets dropped while + // it is still executing. + let mut wait_for_tasks_guard = WaitForTasksToJoinGuard { + _wait_group: &scope_state.wait_group, + enabled: true, + drop_behavior: scope_state.config.drop_behavior, + }; + + let scoped_result = { + // This guard will call `.cancel()` on the `CancellationToken` we + // just created. + let _cancel_guard = CancelTasksGuard { + scope: &scope_state.config.cancellation_token, + }; + + let handle = ScopeHandle { + scope: scope_state.clone(), + }; + + // Execute the scope handler, which gets passed a handle to the newly + // created scope + CURRENT_SCOPE.scope(handle, scope_fut).await + }; + + // Wait for all remaining tasks inside the scope to complete + wait_fut.await; + + // The tasks have completed. We do not need to wait for them to complete + // in the `Drop` guard. + wait_for_tasks_guard.disarm(); + + scoped_result + } +} diff --git a/tokio/tests/task_scope.rs b/tokio/tests/task_scope.rs new file mode 100644 index 00000000000..3a1e22a8e7f --- /dev/null +++ b/tokio/tests/task_scope.rs @@ -0,0 +1,202 @@ +#![warn(rust_2018_idioms)] +#![cfg(tokio_unstable)] +#![cfg(feature = "full")] + +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::{Duration, Instant}, +}; +use tokio::{select, task::scope, time::delay_for}; + +#[derive(Clone)] +struct AtomicFlag(Arc); + +impl AtomicFlag { + fn new() -> Self { + AtomicFlag(Arc::new(AtomicBool::new(false))) + } + + fn is_set(&self) -> bool { + self.0.load(Ordering::Acquire) + } + + fn set(&self) { + self.0.store(true, Ordering::Release); + } +} + +struct SetFlagOnDropGuard { + flag: AtomicFlag, +} + +impl Drop for SetFlagOnDropGuard { + fn drop(&mut self) { + self.flag.set(); + } +} + +#[tokio::test] +async fn unused_scope() { + let scope = scope::enter(async {}); + drop(scope); +} + +#[tokio::test] +async fn spawn_and_return_result() { + let result = scope::enter(async move { + let handle = scope::spawn(async { + tokio::time::delay_for(std::time::Duration::from_millis(500)).await; + 123u32 + }); + handle.await + }) + .await; + assert_eq!(123u32, result.unwrap()); +} + +#[tokio::test] +async fn cancel_and_wait_for_child_task() { + let flag = AtomicFlag::new(); + let flag_clone = flag.clone(); + + let result = scope::enter(async move { + let handle = scope::spawn(async { + delay_for(Duration::from_millis(20)).await; + 123u32 + }); + + scope::spawn_cancellable(async { + let _guard = SetFlagOnDropGuard { flag: flag_clone }; + loop { + tokio::task::yield_now().await; + } + }); + + handle.await + }) + .await; + assert_eq!(123u32, result.unwrap()); + + // Check that the second task was cancelled + assert_eq!(true, flag.is_set()); +} + +#[tokio::test] +async fn graceful_cancellation() { + let result = scope::enter(async move { + scope::spawn(async { + delay_for(Duration::from_millis(20)).await; + scope::current_cancellation_token().cancel(); + 123u32 + }); + + scope::spawn(async { + let ct = scope::current_cancellation_token(); + select! { + _ = ct.cancelled() => { + 1 + }, + _ = delay_for(Duration::from_millis(5000)) => { + 2 + }, + } + }) + .await + }) + .await; + assert_eq!(1, result.unwrap()); +} + +#[tokio::test] +async fn cancels_nested_scopes() { + let flag = AtomicFlag::new(); + let flag_clone = flag.clone(); + + let result = scope::enter(async move { + let ct = scope::current_cancellation_token(); + + let handle = scope::spawn(async move { + delay_for(Duration::from_millis(200)).await; + // Cancelling the parent scope should also cancel the task + // which is running insie a child scope + ct.cancel(); + 123u32 + }); + + scope::enter(async move { + dbg!("Start of scope"); + let _ = scope::spawn_cancellable(async { + let _guard = SetFlagOnDropGuard { flag: flag_clone }; + loop { + tokio::task::yield_now().await; + } + }) + .await; + }) + .await; + + handle.await + }) + .await; + assert_eq!(123u32, result.unwrap()); + + // Check that the second task was cancelled + assert_eq!(true, flag.is_set()); +} + +#[tokio::test] +async fn wait_until_non_joined_tasks_complete() { + let flag = AtomicFlag::new(); + let flag_clone = flag.clone(); + let start_time = Instant::now(); + + let _ = scope::enter(async move { + let handle = scope::spawn(async { + delay_for(Duration::from_millis(20)).await; + 123u32 + }); + + scope::spawn(async move { + tokio::time::delay_for(Duration::from_millis(100)).await; + flag_clone.set(); + }); + + handle.await + }) + .await; + + assert!(start_time.elapsed() >= Duration::from_millis(100)); + + // Check that the second task run to completion + assert_eq!(true, flag.is_set()); +} + +#[should_panic] +#[tokio::test] +async fn panic_if_active_scope_is_dropped() { + let scope_fut = scope::enter(async move { + let handle = scope::spawn(async { + delay_for(Duration::from_millis(20)).await; + 123u32 + }); + + // Spawn a long running task which prevents the task from finishing + scope::spawn(async move { + tokio::time::delay_for(Duration::from_millis(1000)).await; + }); + + handle.await + }); + + select! { + _ = scope_fut => { + panic!("Scope should not complete"); + }, + _ = delay_for(Duration::from_millis(50)) => { + // Drop the scope here + }, + }; +}