Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve LocalPoolHandle #4680

Merged
merged 8 commits into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 129 additions & 5 deletions tokio-util/src/task/spawn_pinned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,44 @@ use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio::task::{spawn_local, JoinHandle, LocalSet};

/// A handle to a local pool, used for spawning `!Send` tasks.
/// A cloneable handle to a local pool, used for spawning `!Send` tasks.
///
/// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread
/// in the pool. Consequently you can also use [`tokio::task::spawn_local`] (which will
/// execute on the same thread) inside the Future you supply to the various spawn methods
/// of `LocalPoolHandle`,
///
/// [`tokio::task::LocalSet`]: tokio::task::LocalSet
/// [`tokio::task::spawn_local`]: tokio::task::spawn_local
///
/// # Examples
///
/// ```
/// use std::rc::Rc;
/// use tokio::{self, task };
/// use tokio_util::task::LocalPoolHandle;
///
/// #[tokio::main(flavor = "current_thread")]
/// async fn main() {
/// let pool = LocalPoolHandle::new(5);
///
/// let output = pool.spawn_pinned(|| {
/// // `data` is !Send + !Sync
/// let data = Rc::new("local data");
/// let data_clone = data.clone();
///
/// async move {
/// task::spawn_local(async move {
/// println!("{}", data_clone);
/// });
///
/// data.to_string()
/// }
/// }).await.unwrap();
/// println!("output: {}", output);
/// }
/// ```
///
#[derive(Clone)]
pub struct LocalPoolHandle {
pool: Arc<LocalPool>,
Expand All @@ -33,6 +70,22 @@ impl LocalPoolHandle {
LocalPoolHandle { pool }
}

/// Returns the number of threads of the Pool.
#[inline]
pub fn num_threads(&self) -> usize {
self.pool.workers.len()
}

/// Returns the number of tasks scheduled on each worker. The indices of the
/// worker threads correspond to the indices of the returned `Vec`.
pub fn get_task_loads_for_each_worker(&self) -> Vec<usize> {
self.pool
.workers
.iter()
.map(|worker| worker.task_count.load(Ordering::SeqCst))
.collect::<Vec<_>>()
}

/// Spawn a task onto a worker thread and pin it there so it can't be moved
/// off of the thread. Note that the future is not [`Send`], but the
/// [`FnOnce`] which creates it is.
Expand Down Expand Up @@ -69,7 +122,60 @@ impl LocalPoolHandle {
Fut: Future + 'static,
Fut::Output: Send + 'static,
{
self.pool.spawn_pinned(create_task)
self.pool
.spawn_pinned(create_task, WorkerChoice::LeastBurdened)
}

/// Differs from `spawn_pinned` only in that you can choose a specific worker thread
/// of the pool, whereas `spawn_pinned` chooses the worker with the smallest
/// number of tasks scheduled.
///
/// A worker thread is chosen by index. Indices are 0 based and the largest index
/// is given by `num_threads() - 1`
///
/// # Panics
///
/// This method panics if the index is out of bounds.
///
/// # Examples
///
/// This method can be used to spawn a task on all worker threads of the pool:
///
/// ```
/// use tokio_util::task::LocalPoolHandle;
///
/// #[tokio::main]
/// async fn main() {
/// const NUM_WORKERS: usize = 3;
/// let pool = LocalPoolHandle::new(NUM_WORKERS);
/// let handles = (0..pool.num_threads())
/// .map(|worker_idx| {
/// pool.spawn_pinned_by_idx(
/// || {
/// async {
/// "test"
/// }
/// },
/// worker_idx,
/// )
/// })
/// .collect::<Vec<_>>();
///
/// for handle in handles {
/// handle.await.unwrap();
/// }
/// }
/// ```
///
pub fn spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output>
where
F: FnOnce() -> Fut,
F: Send + 'static,
Fut: Future + 'static,
Fut::Output: Send + 'static,
{
self.pool
.spawn_pinned(create_task, WorkerChoice::ByIdx(idx))
}
}

Expand All @@ -79,22 +185,33 @@ impl Debug for LocalPoolHandle {
}
}

enum WorkerChoice {
LeastBurdened,
ByIdx(usize),
}

struct LocalPool {
workers: Vec<LocalWorkerHandle>,
}

impl LocalPool {
/// Spawn a `?Send` future onto a worker
fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
fn spawn_pinned<F, Fut>(
&self,
create_task: F,
worker_choice: WorkerChoice,
) -> JoinHandle<Fut::Output>
where
F: FnOnce() -> Fut,
F: Send + 'static,
Fut: Future + 'static,
Fut::Output: Send + 'static,
{
let (sender, receiver) = oneshot::channel();

let (worker, job_guard) = self.find_and_incr_least_burdened_worker();
let (worker, job_guard) = match worker_choice {
WorkerChoice::LeastBurdened => self.find_and_incr_least_burdened_worker(),
WorkerChoice::ByIdx(idx) => self.find_worker_by_idx(idx),
};
let worker_spawner = worker.spawner.clone();

// Spawn a future onto the worker's runtime so we can immediately return
Expand Down Expand Up @@ -206,6 +323,13 @@ impl LocalPool {
}
}
}

fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) {
let worker = &self.workers[idx];
worker.task_count.fetch_add(1, Ordering::SeqCst);

(worker, JobCountGuard(Arc::clone(&worker.task_count)))
}
}

/// Automatically decrements a worker's job count when a job finishes (when
Expand Down
43 changes: 43 additions & 0 deletions tokio-util/tests/spawn_pinned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use std::rc::Rc;
use std::sync::Arc;
use tokio::sync::Barrier;
use tokio_util::task;

/// Simple test of running a !Send future via spawn_pinned
Expand Down Expand Up @@ -191,3 +192,45 @@ async fn tasks_are_balanced() {
// be on separate workers/threads.
assert_ne!(thread_id1, thread_id2);
}

#[tokio::test]
async fn spawn_by_idx() {
let pool = task::LocalPoolHandle::new(3);
let barrier = Arc::new(Barrier::new(4));
let barrier1 = barrier.clone();
let barrier2 = barrier.clone();
let barrier3 = barrier.clone();

let handle1 = pool.spawn_pinned_by_idx(
|| async move {
barrier1.wait().await;
std::thread::current().id()
},
0,
);
let _ = pool.spawn_pinned_by_idx(
|| async move {
barrier2.wait().await;
std::thread::current().id()
},
0,
);
let handle2 = pool.spawn_pinned_by_idx(
|| async move {
barrier3.wait().await;
std::thread::current().id()
},
1,
);

let loads = pool.get_task_loads_for_each_worker();
barrier.wait().await;
assert_eq!(loads[0], 2);
assert_eq!(loads[1], 1);
assert_eq!(loads[2], 0);

let thread_id1 = handle1.await.unwrap();
let thread_id2 = handle2.await.unwrap();

assert_ne!(thread_id1, thread_id2);
}