Skip to content

Commit

Permalink
Make task::Builder::spawn* methods fallible (#4823)
Browse files Browse the repository at this point in the history
  • Loading branch information
ipetkov authored Jul 12, 2022
1 parent de686b5 commit 3b6c74a
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 38 deletions.
2 changes: 1 addition & 1 deletion tokio/src/runtime/blocking/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! compilation.

mod pool;
pub(crate) use pool::{spawn_blocking, BlockingPool, Mandatory, Spawner, Task};
pub(crate) use pool::{spawn_blocking, BlockingPool, Mandatory, SpawnError, Spawner, Task};

cfg_fs! {
pub(crate) use pool::spawn_mandatory_blocking;
Expand Down
26 changes: 23 additions & 3 deletions tokio/src/runtime/blocking/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::runtime::{Builder, Callback, ToHandle};

use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::io;
use std::time::Duration;

pub(crate) struct BlockingPool {
Expand Down Expand Up @@ -82,6 +83,25 @@ pub(crate) enum Mandatory {
NonMandatory,
}

pub(crate) enum SpawnError {
/// Pool is shutting down and the task was not scheduled
ShuttingDown,
/// There are no worker threads available to take the task
/// and the OS failed to spawn a new one
NoThreads(io::Error),
}

impl From<SpawnError> for io::Error {
fn from(e: SpawnError) -> Self {
match e {
SpawnError::ShuttingDown => {
io::Error::new(io::ErrorKind::Other, "blocking pool shutting down")
}
SpawnError::NoThreads(e) => e,
}
}
}

impl Task {
pub(crate) fn new(task: task::UnownedTask<NoopSchedule>, mandatory: Mandatory) -> Task {
Task { task, mandatory }
Expand Down Expand Up @@ -221,7 +241,7 @@ impl fmt::Debug for BlockingPool {
// ===== impl Spawner =====

impl Spawner {
pub(crate) fn spawn(&self, task: Task, rt: &dyn ToHandle) -> Result<(), ()> {
pub(crate) fn spawn(&self, task: Task, rt: &dyn ToHandle) -> Result<(), SpawnError> {
let mut shared = self.inner.shared.lock();

if shared.shutdown {
Expand All @@ -231,7 +251,7 @@ impl Spawner {
task.task.shutdown();

// no need to even push this task; it would never get picked up
return Err(());
return Err(SpawnError::ShuttingDown);
}

shared.queue.push_back(task);
Expand Down Expand Up @@ -262,7 +282,7 @@ impl Spawner {
Err(e) => {
// The OS refused to spawn the thread and there is no thread
// to pick up the task that has just been pushed to the queue.
panic!("OS can't spawn worker thread: {}", e)
return Err(SpawnError::NoThreads(e));
}
}
}
Expand Down
19 changes: 13 additions & 6 deletions tokio/src/runtime/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,15 +341,22 @@ impl HandleInner {
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let (join_handle, _was_spawned) = if cfg!(debug_assertions)
let (join_handle, spawn_result) = if cfg!(debug_assertions)
&& std::mem::size_of::<F>() > 2048
{
self.spawn_blocking_inner(Box::new(func), blocking::Mandatory::NonMandatory, None, rt)
} else {
self.spawn_blocking_inner(func, blocking::Mandatory::NonMandatory, None, rt)
};

join_handle
match spawn_result {
Ok(()) => join_handle,
// Compat: do not panic here, return the join_handle even though it will never resolve
Err(blocking::SpawnError::ShuttingDown) => join_handle,
Err(blocking::SpawnError::NoThreads(e)) => {
panic!("OS can't spawn worker thread: {}", e)
}
}
}

cfg_fs! {
Expand All @@ -363,7 +370,7 @@ impl HandleInner {
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let (join_handle, was_spawned) = if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
let (join_handle, spawn_result) = if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
self.spawn_blocking_inner(
Box::new(func),
blocking::Mandatory::Mandatory,
Expand All @@ -379,7 +386,7 @@ impl HandleInner {
)
};

if was_spawned {
if spawn_result.is_ok() {
Some(join_handle)
} else {
None
Expand All @@ -394,7 +401,7 @@ impl HandleInner {
is_mandatory: blocking::Mandatory,
name: Option<&str>,
rt: &dyn ToHandle,
) -> (JoinHandle<R>, bool)
) -> (JoinHandle<R>, Result<(), blocking::SpawnError>)
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
Expand Down Expand Up @@ -424,7 +431,7 @@ impl HandleInner {
let spawned = self
.blocking_spawner
.spawn(blocking::Task::new(task, is_mandatory), rt);
(handle, spawned.is_ok())
(handle, spawned)
}
}

Expand Down
41 changes: 27 additions & 14 deletions tokio/src/task/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
runtime::{context, Handle},
task::{JoinHandle, LocalSet},
};
use std::future::Future;
use std::{future::Future, io};

/// Factory which is used to configure the properties of a new task.
///
Expand Down Expand Up @@ -48,7 +48,7 @@ use std::future::Future;
/// .spawn(async move {
/// // Process each socket concurrently.
/// process(socket).await
/// });
/// })?;
/// }
/// }
/// ```
Expand Down Expand Up @@ -83,12 +83,12 @@ impl<'a> Builder<'a> {
/// See [`task::spawn`](crate::task::spawn) for
/// more details.
#[track_caller]
pub fn spawn<Fut>(self, future: Fut) -> JoinHandle<Fut::Output>
pub fn spawn<Fut>(self, future: Fut) -> io::Result<JoinHandle<Fut::Output>>
where
Fut: Future + Send + 'static,
Fut::Output: Send + 'static,
{
super::spawn::spawn_inner(future, self.name)
Ok(super::spawn::spawn_inner(future, self.name))
}

/// Spawn a task with this builder's settings on the provided [runtime
Expand All @@ -99,12 +99,16 @@ impl<'a> Builder<'a> {
/// [runtime handle]: crate::runtime::Handle
/// [`Handle::spawn`]: crate::runtime::Handle::spawn
#[track_caller]
pub fn spawn_on<Fut>(&mut self, future: Fut, handle: &Handle) -> JoinHandle<Fut::Output>
pub fn spawn_on<Fut>(
&mut self,
future: Fut,
handle: &Handle,
) -> io::Result<JoinHandle<Fut::Output>>
where
Fut: Future + Send + 'static,
Fut::Output: Send + 'static,
{
handle.spawn_named(future, self.name)
Ok(handle.spawn_named(future, self.name))
}

/// Spawns `!Send` a task on the current [`LocalSet`] with this builder's
Expand All @@ -122,12 +126,12 @@ impl<'a> Builder<'a> {
/// [`task::spawn_local`]: crate::task::spawn_local
/// [`LocalSet`]: crate::task::LocalSet
#[track_caller]
pub fn spawn_local<Fut>(self, future: Fut) -> JoinHandle<Fut::Output>
pub fn spawn_local<Fut>(self, future: Fut) -> io::Result<JoinHandle<Fut::Output>>
where
Fut: Future + 'static,
Fut::Output: 'static,
{
super::local::spawn_local_inner(future, self.name)
Ok(super::local::spawn_local_inner(future, self.name))
}

/// Spawns `!Send` a task on the provided [`LocalSet`] with this builder's
Expand All @@ -138,12 +142,16 @@ impl<'a> Builder<'a> {
/// [`LocalSet::spawn_local`]: crate::task::LocalSet::spawn_local
/// [`LocalSet`]: crate::task::LocalSet
#[track_caller]
pub fn spawn_local_on<Fut>(self, future: Fut, local_set: &LocalSet) -> JoinHandle<Fut::Output>
pub fn spawn_local_on<Fut>(
self,
future: Fut,
local_set: &LocalSet,
) -> io::Result<JoinHandle<Fut::Output>>
where
Fut: Future + 'static,
Fut::Output: 'static,
{
local_set.spawn_named(future, self.name)
Ok(local_set.spawn_named(future, self.name))
}

/// Spawns blocking code on the blocking threadpool.
Expand All @@ -155,7 +163,10 @@ impl<'a> Builder<'a> {
/// See [`task::spawn_blocking`](crate::task::spawn_blocking)
/// for more details.
#[track_caller]
pub fn spawn_blocking<Function, Output>(self, function: Function) -> JoinHandle<Output>
pub fn spawn_blocking<Function, Output>(
self,
function: Function,
) -> io::Result<JoinHandle<Output>>
where
Function: FnOnce() -> Output + Send + 'static,
Output: Send + 'static,
Expand All @@ -174,18 +185,20 @@ impl<'a> Builder<'a> {
self,
function: Function,
handle: &Handle,
) -> JoinHandle<Output>
) -> io::Result<JoinHandle<Output>>
where
Function: FnOnce() -> Output + Send + 'static,
Output: Send + 'static,
{
use crate::runtime::Mandatory;
let (join_handle, _was_spawned) = handle.as_inner().spawn_blocking_inner(
let (join_handle, spawn_result) = handle.as_inner().spawn_blocking_inner(
function,
Mandatory::NonMandatory,
self.name,
handle,
);
join_handle

spawn_result?;
Ok(join_handle)
}
}
25 changes: 14 additions & 11 deletions tokio/src/task/join_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,15 @@ impl<T: 'static> JoinSet<T> {
/// use tokio::task::JoinSet;
///
/// #[tokio::main]
/// async fn main() {
/// async fn main() -> std::io::Result<()> {
/// let mut set = JoinSet::new();
///
/// // Use the builder to configure a task's name before spawning it.
/// set.build_task()
/// .name("my_task")
/// .spawn(async { /* ... */ });
/// .spawn(async { /* ... */ })?;
///
/// Ok(())
/// }
/// ```
#[cfg(all(tokio_unstable, feature = "tracing"))]
Expand Down Expand Up @@ -377,13 +379,13 @@ impl<'a, T: 'static> Builder<'a, T> {
///
/// [`AbortHandle`]: crate::task::AbortHandle
#[track_caller]
pub fn spawn<F>(self, future: F) -> AbortHandle
pub fn spawn<F>(self, future: F) -> std::io::Result<AbortHandle>
where
F: Future<Output = T>,
F: Send + 'static,
T: Send,
{
self.joinset.insert(self.builder.spawn(future))
Ok(self.joinset.insert(self.builder.spawn(future)?))
}

/// Spawn the provided task on the provided [runtime handle] with this
Expand All @@ -397,13 +399,13 @@ impl<'a, T: 'static> Builder<'a, T> {
/// [`AbortHandle`]: crate::task::AbortHandle
/// [runtime handle]: crate::runtime::Handle
#[track_caller]
pub fn spawn_on<F>(mut self, future: F, handle: &Handle) -> AbortHandle
pub fn spawn_on<F>(mut self, future: F, handle: &Handle) -> std::io::Result<AbortHandle>
where
F: Future<Output = T>,
F: Send + 'static,
T: Send,
{
self.joinset.insert(self.builder.spawn_on(future, handle))
Ok(self.joinset.insert(self.builder.spawn_on(future, handle)?))
}

/// Spawn the provided task on the current [`LocalSet`] with this builder's
Expand All @@ -420,12 +422,12 @@ impl<'a, T: 'static> Builder<'a, T> {
/// [`LocalSet`]: crate::task::LocalSet
/// [`AbortHandle`]: crate::task::AbortHandle
#[track_caller]
pub fn spawn_local<F>(self, future: F) -> AbortHandle
pub fn spawn_local<F>(self, future: F) -> std::io::Result<AbortHandle>
where
F: Future<Output = T>,
F: 'static,
{
self.joinset.insert(self.builder.spawn_local(future))
Ok(self.joinset.insert(self.builder.spawn_local(future)?))
}

/// Spawn the provided task on the provided [`LocalSet`] with this builder's
Expand All @@ -438,13 +440,14 @@ impl<'a, T: 'static> Builder<'a, T> {
/// [`LocalSet`]: crate::task::LocalSet
/// [`AbortHandle`]: crate::task::AbortHandle
#[track_caller]
pub fn spawn_local_on<F>(self, future: F, local_set: &LocalSet) -> AbortHandle
pub fn spawn_local_on<F>(self, future: F, local_set: &LocalSet) -> std::io::Result<AbortHandle>
where
F: Future<Output = T>,
F: 'static,
{
self.joinset
.insert(self.builder.spawn_local_on(future, local_set))
Ok(self
.joinset
.insert(self.builder.spawn_local_on(future, local_set)?))
}
}

Expand Down
20 changes: 17 additions & 3 deletions tokio/tests/task_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod tests {
let result = Builder::new()
.name("name")
.spawn(async { "task executed" })
.unwrap()
.await;

assert_eq!(result.unwrap(), "task executed");
Expand All @@ -21,6 +22,7 @@ mod tests {
let result = Builder::new()
.name("name")
.spawn_blocking(|| "task executed")
.unwrap()
.await;

assert_eq!(result.unwrap(), "task executed");
Expand All @@ -34,6 +36,7 @@ mod tests {
Builder::new()
.name("name")
.spawn_local(async move { unsend_data })
.unwrap()
.await
})
.await;
Expand All @@ -43,14 +46,20 @@ mod tests {

#[test]
async fn spawn_without_name() {
let result = Builder::new().spawn(async { "task executed" }).await;
let result = Builder::new()
.spawn(async { "task executed" })
.unwrap()
.await;

assert_eq!(result.unwrap(), "task executed");
}

#[test]
async fn spawn_blocking_without_name() {
let result = Builder::new().spawn_blocking(|| "task executed").await;
let result = Builder::new()
.spawn_blocking(|| "task executed")
.unwrap()
.await;

assert_eq!(result.unwrap(), "task executed");
}
Expand All @@ -59,7 +68,12 @@ mod tests {
async fn spawn_local_without_name() {
let unsend_data = Rc::new("task executed");
let result = LocalSet::new()
.run_until(async move { Builder::new().spawn_local(async move { unsend_data }).await })
.run_until(async move {
Builder::new()
.spawn_local(async move { unsend_data })
.unwrap()
.await
})
.await;

assert_eq!(*result.unwrap(), "task executed");
Expand Down

0 comments on commit 3b6c74a

Please sign in to comment.