diff --git a/crates/bevy_tasks/src/single_threaded_task_pool.rs b/crates/bevy_tasks/src/single_threaded_task_pool.rs
index 6b515b1326d13..89137cb141226 100644
--- a/crates/bevy_tasks/src/single_threaded_task_pool.rs
+++ b/crates/bevy_tasks/src/single_threaded_task_pool.rs
@@ -102,6 +102,13 @@ impl TaskPool {
         });
         FakeTask
     }
+
+    pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> FakeTask
+    where
+        T: 'static,
+    {
+        self.spawn(future)
+    }
 }
 
 #[derive(Debug)]
diff --git a/crates/bevy_tasks/src/task_pool.rs b/crates/bevy_tasks/src/task_pool.rs
index febd8351c9a29..39e049eed2372 100644
--- a/crates/bevy_tasks/src/task_pool.rs
+++ b/crates/bevy_tasks/src/task_pool.rs
@@ -95,6 +95,10 @@ pub struct TaskPool {
 }
 
 impl TaskPool {
+    thread_local! {
+        static LOCAL_EXECUTOR: async_executor::LocalExecutor<'static> = async_executor::LocalExecutor::new();
+    }
+
     /// Create a `TaskPool` with the default configuration.
     pub fn new() -> Self {
         TaskPoolBuilder::new().build()
@@ -162,58 +166,63 @@ impl TaskPool {
         F: FnOnce(&mut Scope<'scope, T>) + 'scope + Send,
         T: Send + 'static,
     {
-        // SAFETY: This function blocks until all futures complete, so this future must return
-        // before this function returns. However, rust has no way of knowing
-        // this so we must convert to 'static here to appease the compiler as it is unable to
-        // validate safety.
-        let executor: &async_executor::Executor = &*self.executor;
-        let executor: &'scope async_executor::Executor = unsafe { mem::transmute(executor) };
-
-        let mut scope = Scope {
-            executor,
-            spawned: Vec::new(),
-        };
-
-        f(&mut scope);
-
-        if scope.spawned.is_empty() {
-            Vec::default()
-        } else if scope.spawned.len() == 1 {
-            vec![future::block_on(&mut scope.spawned[0])]
-        } else {
-            let fut = async move {
-                let mut results = Vec::with_capacity(scope.spawned.len());
-                for task in scope.spawned {
-                    results.push(task.await);
-                }
-
-                results
+        TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
+            // SAFETY: This function blocks until all futures complete, so this future must return
+            // before this function returns. However, rust has no way of knowing
+            // this so we must convert to 'static here to appease the compiler as it is unable to
+            // validate safety.
+            let executor: &async_executor::Executor = &*self.executor;
+            let executor: &'scope async_executor::Executor = unsafe { mem::transmute(executor) };
+            let local_executor: &'scope async_executor::LocalExecutor =
+                unsafe { mem::transmute(local_executor) };
+            let mut scope = Scope {
+                executor,
+                local_executor,
+                spawned: Vec::new(),
             };
 
-            // Pin the future on the stack.
-            pin!(fut);
+            f(&mut scope);
+
+            if scope.spawned.is_empty() {
+                Vec::default()
+            } else if scope.spawned.len() == 1 {
+                vec![future::block_on(&mut scope.spawned[0])]
+            } else {
+                let fut = async move {
+                    let mut results = Vec::with_capacity(scope.spawned.len());
+                    for task in scope.spawned {
+                        results.push(task.await);
+                    }
 
-            // SAFETY: This function blocks until all futures complete, so we do not read/write the
-            // data from futures outside of the 'scope lifetime. However, rust has no way of knowing
-            // this so we must convert to 'static here to appease the compiler as it is unable to
-            // validate safety.
-            let fut: Pin<&mut (dyn Future<Output = Vec<T>> + Send)> = fut;
-            let fut: Pin<&'static mut (dyn Future<Output = Vec<T>> + Send + 'static)> =
-                unsafe { mem::transmute(fut) };
-
-            // The thread that calls scope() will participate in driving tasks in the pool forward
-            // until the tasks that are spawned by this scope() call complete. (If the caller of scope()
-            // happens to be a thread in this thread pool, and we only have one thread in the pool, then
-            // simply calling future::block_on(spawned) would deadlock.)
-            let mut spawned = self.executor.spawn(fut);
-            loop {
-                if let Some(result) = future::block_on(future::poll_once(&mut spawned)) {
-                    break result;
-                }
+                    results
+                };
 
-                self.executor.try_tick();
+                // Pin the futures on the stack.
+                pin!(fut);
+
+                // SAFETY: This function blocks until all futures complete, so we do not read/write the
+                // data from futures outside of the 'scope lifetime. However, rust has no way of knowing
+                // this so we must convert to 'static here to appease the compiler as it is unable to
+                // validate safety.
+                let fut: Pin<&mut (dyn Future<Output = Vec<T>>)> = fut;
+                let fut: Pin<&'static mut (dyn Future<Output = Vec<T>> + 'static)> =
+                    unsafe { mem::transmute(fut) };
+
+                // The thread that calls scope() will participate in driving tasks in the pool forward
+                // until the tasks that are spawned by this scope() call complete. (If the caller of scope()
+                // happens to be a thread in this thread pool, and we only have one thread in the pool, then
+                // simply calling future::block_on(spawned) would deadlock.)
+                let mut spawned = local_executor.spawn(fut);
+                loop {
+                    if let Some(result) = future::block_on(future::poll_once(&mut spawned)) {
+                        break result;
+                    };
+
+                    self.executor.try_tick();
+                    local_executor.try_tick();
+                }
             }
-        }
+        })
     }
 
     /// Spawns a static future onto the thread pool. The returned Task is a future. It can also be
@@ -225,6 +234,13 @@ impl TaskPool {
     {
         Task::new(self.executor.spawn(future))
     }
+
+    pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> Task<T>
+    where
+        T: 'static,
+    {
+        Task::new(TaskPool::LOCAL_EXECUTOR.with(|executor| executor.spawn(future)))
+    }
 }
 
 impl Default for TaskPool {
@@ -236,6 +252,7 @@
 #[derive(Debug)]
 pub struct Scope<'scope, T> {
     executor: &'scope async_executor::Executor<'scope>,
+    local_executor: &'scope async_executor::LocalExecutor<'scope>,
     spawned: Vec<async_executor::Task<T>>,
 }
 
@@ -244,12 +261,20 @@ impl<'scope, T: Send + 'scope> Scope<'scope, T> {
         let task = self.executor.spawn(f);
         self.spawned.push(task);
     }
+
+    pub fn spawn_local<Fut: Future<Output = T> + 'scope>(&mut self, f: Fut) {
+        let task = self.local_executor.spawn(f);
+        self.spawned.push(task);
+    }
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
-    use std::sync::atomic::{AtomicI32, Ordering};
+    use std::sync::{
+        atomic::{AtomicBool, AtomicI32, Ordering},
+        Barrier,
+    };
 
     #[test]
     pub fn test_spawn() {
@@ -281,4 +306,85 @@
         assert_eq!(outputs.len(), 100);
         assert_eq!(count.load(Ordering::Relaxed), 100);
     }
+
+    #[test]
+    pub fn test_mixed_spawn_local_and_spawn() {
+        let pool = TaskPool::new();
+
+        let foo = Box::new(42);
+        let foo = &*foo;
+
+        let local_count = Arc::new(AtomicI32::new(0));
+        let non_local_count = Arc::new(AtomicI32::new(0));
+
+        let outputs = pool.scope(|scope| {
+            for i in 0..100 {
+                if i % 2 == 0 {
+                    let count_clone = non_local_count.clone();
+                    scope.spawn(async move {
+                        if *foo != 42 {
+                            panic!("not 42!?!?")
+                        } else {
+                            count_clone.fetch_add(1, Ordering::Relaxed);
+                            *foo
+                        }
+                    });
+                } else {
+                    let count_clone = local_count.clone();
+                    scope.spawn_local(async move {
+                        if *foo != 42 {
+                            panic!("not 42!?!?")
+                        } else {
+                            count_clone.fetch_add(1, Ordering::Relaxed);
+                            *foo
+                        }
+                    });
+                }
+            }
+        });
+
+        for output in &outputs {
+            assert_eq!(*output, 42);
+        }
+
+        assert_eq!(outputs.len(), 100);
+        assert_eq!(local_count.load(Ordering::Relaxed), 50);
+        assert_eq!(non_local_count.load(Ordering::Relaxed), 50);
+    }
+
+    #[test]
+    pub fn test_thread_locality() {
+        let pool = Arc::new(TaskPool::new());
+        let count = Arc::new(AtomicI32::new(0));
+        let barrier = Arc::new(Barrier::new(101));
+        let thread_check_failed = Arc::new(AtomicBool::new(false));
+
+        for _ in 0..100 {
+            let inner_barrier = barrier.clone();
+            let count_clone = count.clone();
+            let inner_pool = pool.clone();
+            let inner_thread_check_failed = thread_check_failed.clone();
+            std::thread::spawn(move || {
+                inner_pool.scope(|scope| {
+                    let inner_count_clone = count_clone.clone();
+                    scope.spawn(async move {
+                        inner_count_clone.fetch_add(1, Ordering::Release);
+                    });
+                    let spawner = std::thread::current().id();
+                    let inner_count_clone = count_clone.clone();
+                    scope.spawn_local(async move {
+                        inner_count_clone.fetch_add(1, Ordering::Release);
+                        if std::thread::current().id() != spawner {
+                            // NOTE: This check is using an atomic rather than simply panicking the thread to avoid deadlocking the barrier on failure
+                            inner_thread_check_failed.store(true, Ordering::Release);
+                        }
+                    });
+                });
+                inner_barrier.wait();
+            });
+        }
+        barrier.wait();
+        assert!(!thread_check_failed.load(Ordering::Acquire));
+        assert_eq!(count.load(Ordering::Acquire), 200);
+    }
 }
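
Usage note (not part of the patch): a minimal sketch of how the API added above might be exercised, assuming the usual `bevy_tasks::TaskPool` re-export. Tasks spawned through `Scope::spawn_local` go to the thread-local executor and therefore run on the thread that called `scope()`, while `Scope::spawn` tasks may be picked up by any pool thread (or by the caller via `try_tick`).

use bevy_tasks::TaskPool;

fn main() {
    let pool = TaskPool::new();
    let caller = std::thread::current().id();

    // Results come back in spawn order once every task has completed.
    let results: Vec<bool> = pool.scope(|scope| {
        // May be executed by any worker thread in the pool (or by the caller ticking the executor).
        scope.spawn(async move { std::thread::current().id() == caller });

        // Driven by the thread-local executor, i.e. the thread that called scope().
        scope.spawn_local(async move { std::thread::current().id() == caller });
    });

    // The spawn_local task always observes the caller's thread id; the plain spawn task may not.
    assert!(results[1]);
}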