Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for !Send tasks #1216

Merged
merged 2 commits into from
Jan 18, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions crates/bevy_tasks/src/single_threaded_task_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ impl TaskPool {
});
FakeTask
}

/// Spawns a future that need not be `Send` onto this pool.
///
/// In this single-threaded implementation there are no worker threads, so this
/// simply delegates to `spawn`; it exists so callers can use the same API as the
/// multi-threaded `TaskPool`. Returns the same `FakeTask` handle as `spawn`.
pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> FakeTask
where
T: 'static,
{
self.spawn(future)
}
}

#[derive(Debug)]
Expand Down
161 changes: 114 additions & 47 deletions crates/bevy_tasks/src/task_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ pub struct TaskPool {
}

impl TaskPool {
// Per-thread executor used by `spawn_local` / `Scope::spawn_local`: tasks spawned
// here are not required to be `Send`, so they must be ticked on the thread that
// spawned them (see the `try_tick` calls in `scope`). The `'static` lifetime
// covers futures spawned directly via `spawn_local`; scoped futures are
// transmuted to it inside `scope`.
thread_local! {
static LOCAL_EXECUTOR: async_executor::LocalExecutor<'static> = async_executor::LocalExecutor::new();
}

/// Create a `TaskPool` with the default configuration.
pub fn new() -> Self {
TaskPoolBuilder::new().build()
Expand Down Expand Up @@ -162,58 +166,63 @@ impl TaskPool {
F: FnOnce(&mut Scope<'scope, T>) + 'scope + Send,
T: Send + 'static,
{
// SAFETY: This function blocks until all futures complete, so this future must return
// before this function returns. However, rust has no way of knowing
// this so we must convert to 'static here to appease the compiler as it is unable to
// validate safety.
let executor: &async_executor::Executor = &*self.executor;
let executor: &'scope async_executor::Executor = unsafe { mem::transmute(executor) };

let mut scope = Scope {
executor,
spawned: Vec::new(),
};

f(&mut scope);

if scope.spawned.is_empty() {
Vec::default()
} else if scope.spawned.len() == 1 {
vec![future::block_on(&mut scope.spawned[0])]
} else {
let fut = async move {
let mut results = Vec::with_capacity(scope.spawned.len());
for task in scope.spawned {
results.push(task.await);
}

results
TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
// SAFETY: This function blocks until all futures complete, so this future must return
// before this function returns. However, rust has no way of knowing
// this so we must convert to 'static here to appease the compiler as it is unable to
// validate safety.
let executor: &async_executor::Executor = &*self.executor;
let executor: &'scope async_executor::Executor = unsafe { mem::transmute(executor) };
let local_executor: &'scope async_executor::LocalExecutor =
unsafe { mem::transmute(local_executor) };
let mut scope = Scope {
executor,
local_executor,
spawned: Vec::new(),
};

// Pin the future on the stack.
pin!(fut);
f(&mut scope);

if scope.spawned.is_empty() {
Vec::default()
} else if scope.spawned.len() == 1 {
vec![future::block_on(&mut scope.spawned[0])]
} else {
let fut = async move {
let mut results = Vec::with_capacity(scope.spawned.len());
for task in scope.spawned {
results.push(task.await);
}

// SAFETY: This function blocks until all futures complete, so we do not read/write the
// data from futures outside of the 'scope lifetime. However, rust has no way of knowing
// this so we must convert to 'static here to appease the compiler as it is unable to
// validate safety.
let fut: Pin<&mut (dyn Future<Output = Vec<T>> + Send)> = fut;
let fut: Pin<&'static mut (dyn Future<Output = Vec<T>> + Send + 'static)> =
unsafe { mem::transmute(fut) };

// The thread that calls scope() will participate in driving tasks in the pool forward
// until the tasks that are spawned by this scope() call complete. (If the caller of scope()
// happens to be a thread in this thread pool, and we only have one thread in the pool, then
// simply calling future::block_on(spawned) would deadlock.)
let mut spawned = self.executor.spawn(fut);
loop {
if let Some(result) = future::block_on(future::poll_once(&mut spawned)) {
break result;
}
results
};

self.executor.try_tick();
// Pin the futures on the stack.
pin!(fut);

// SAFETY: This function blocks until all futures complete, so we do not read/write the
// data from futures outside of the 'scope lifetime. However, rust has no way of knowing
// this so we must convert to 'static here to appease the compiler as it is unable to
// validate safety.
let fut: Pin<&mut (dyn Future<Output = Vec<T>>)> = fut;
let fut: Pin<&'static mut (dyn Future<Output = Vec<T>> + 'static)> =
unsafe { mem::transmute(fut) };

// The thread that calls scope() will participate in driving tasks in the pool forward
// until the tasks that are spawned by this scope() call complete. (If the caller of scope()
// happens to be a thread in this thread pool, and we only have one thread in the pool, then
// simply calling future::block_on(spawned) would deadlock.)
let mut spawned = local_executor.spawn(fut);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like we're running all of the tasks on the local executor, regardless of the context they were spawned with. Maybe we should just have two scopes for simplicity / to avoid potential overhead when only one context is needed? Ex: pool.scope() and pool.local_scope()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah wait. We're awaiting (and i assume that means polling) using the local executor only. Which might be fine? I'm not sure! Separate scopes would probably still be safest.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. The tasks run on the executor where they were spawned and are only awaited on the local executor. Doing it that way makes the code significantly cleaner because otherwise you have to basically duplicate every line because the executors have different types with no unifying trait.

I see that the tests I've added don't actually prove that this is running the tasks on the right threads. Let me add a test that demonstrates that. If you're still uncomfortable with it I can move to awaiting two groups of futures since I agree that is more obvious.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm less worried about the parallelism (spawning the task with the proper executor is proof enough for me). I'm more worried about using the local executor to poll tasks that it didn't spawn. That seems a bit weird / might produce undefined behavior. I'm not sure how that's implemented (and even if the current implementation works, can we be sure that will always be the case?)

Copy link
Member Author

@alec-deason alec-deason Jan 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a valid concern. Let me go read more in async-executor/std::future and see what the guarantees actually are.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I am convinced this is a reasonable thing to do because of the way async-task is designed. When you spawn a future via their API they produce two structs, a Task and a Runnable. The Task, which is the only part we interact with directly, is used only to check for completion and retrieve the result, it doesn't participate in polling the underlying future at all. The Runnable, which is retained by the Executor/LocalExecutor is what actually polls the future when try_tick is called. So using the Task in a composite future that waits for results is exactly what it is intended for. Turning the crank on the executor to actually run the main futures is a completely separate process.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alrighty, I'm decently convinced at this point. Let's get this merged!

loop {
if let Some(result) = future::block_on(future::poll_once(&mut spawned)) {
break result;
};

self.executor.try_tick();
local_executor.try_tick();
}
}
}
})
}

/// Spawns a static future onto the thread pool. The returned Task is a future. It can also be
Expand All @@ -225,6 +234,13 @@ impl TaskPool {
{
Task::new(self.executor.spawn(future))
}

/// Spawns a future that need not be `Send` onto the calling thread's local
/// executor (`LOCAL_EXECUTOR`), returning a `Task` handle for its result.
///
/// Unlike `spawn`, the future only has to be `'static`, not `Send`; it is
/// spawned on — and, presumably, polled by — the current thread's
/// `LocalExecutor` rather than the pool-wide executor.
pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> Task<T>
where
T: 'static,
{
Task::new(TaskPool::LOCAL_EXECUTOR.with(|executor| executor.spawn(future)))
}
}

impl Default for TaskPool {
Expand All @@ -236,6 +252,7 @@ impl Default for TaskPool {
/// Scope passed to the closure given to [`TaskPool::scope`]; collects the tasks
/// spawned during the scope so `scope` can block until they all complete.
#[derive(Debug)]
pub struct Scope<'scope, T> {
// Pool-wide executor: `spawn` places `Send` tasks here.
executor: &'scope async_executor::Executor<'scope>,
// Thread-local executor: `spawn_local` places non-`Send` tasks here.
local_executor: &'scope async_executor::LocalExecutor<'scope>,
// Handles for every spawned task, awaited (in spawn order) by `scope`.
spawned: Vec<async_executor::Task<T>>,
}

Expand All @@ -244,6 +261,11 @@ impl<'scope, T: Send + 'scope> Scope<'scope, T> {
let task = self.executor.spawn(f);
self.spawned.push(task);
}

/// Spawns a future that need not be `Send` onto the thread-local executor.
/// Its output is collected alongside the results of `spawn`ed tasks and
/// returned from `TaskPool::scope`.
pub fn spawn_local<Fut: Future<Output = T> + 'scope>(&mut self, f: Fut) {
let task = self.local_executor.spawn(f);
self.spawned.push(task);
}
}

#[cfg(test)]
Expand Down Expand Up @@ -281,4 +303,49 @@ mod tests {
assert_eq!(outputs.len(), 100);
assert_eq!(count.load(Ordering::Relaxed), 100);
}

#[test]
pub fn test_mixed_spawn_local_and_spawn() {
    let pool = TaskPool::new();

    // Heap-allocate a value and borrow it from every task: exercises scoped
    // borrows of non-'static data across both spawn flavors.
    let boxed = Box::new(42);
    let value = &*boxed;

    let local_hits = Arc::new(AtomicI32::new(0));
    let pooled_hits = Arc::new(AtomicI32::new(0));

    let outputs = pool.scope(|scope| {
        for idx in 0..100 {
            // Alternate between the pool-wide executor and the local one.
            if idx % 2 == 0 {
                let hits = pooled_hits.clone();
                scope.spawn(async move {
                    if *value == 42 {
                        hits.fetch_add(1, Ordering::Relaxed);
                        *value
                    } else {
                        panic!("not 42!?!?")
                    }
                });
            } else {
                let hits = local_hits.clone();
                scope.spawn_local(async move {
                    if *value == 42 {
                        hits.fetch_add(1, Ordering::Relaxed);
                        *value
                    } else {
                        panic!("not 42!?!?")
                    }
                });
            }
        }
    });

    // Every task observed the borrowed value and returned it.
    outputs.iter().for_each(|output| assert_eq!(*output, 42));
    assert_eq!(outputs.len(), 100);

    // Exactly half the tasks ran through each spawn path.
    assert_eq!(local_hits.load(Ordering::Relaxed), 50);
    assert_eq!(pooled_hits.load(Ordering::Relaxed), 50);
}
}