Skip to content

Commit

Permalink
remove IdxError and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
b-naber committed May 17, 2022
1 parent 2d8f30c commit 71c63c0
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 43 deletions.
2 changes: 1 addition & 1 deletion tokio-util/src/task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#[cfg(tokio_unstable)]
mod join_map;
mod spawn_pinned;
pub use spawn_pinned::{LocalPoolHandle, WorkerIdxError};
pub use spawn_pinned::LocalPoolHandle;

#[cfg(tokio_unstable)]
#[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "rt"))))]
Expand Down
42 changes: 4 additions & 38 deletions tokio-util/src/task/spawn_pinned.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use futures_util::future::{AbortHandle, Abortable};
use std::error::Error;
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::future::Future;
Expand All @@ -10,27 +9,6 @@ use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio::task::{spawn_local, JoinHandle, LocalSet};

/// Error Type for out-of-bounds indexing error in [`LocalPoolHandle::spawn_pinned_by_idx`].
///
/// [`LocalPoolHandle::spawn_pinned_by_idx`]: LocalPoolHandle::spawn_pinned_by_idx
#[derive(Debug)]
pub struct WorkerIdxError {
idx: usize,
num_workers: usize,
}

impl fmt::Display for WorkerIdxError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"Index {} out of bounds, only {} workers in pool",
self.idx, self.num_workers
)
}
}

impl Error for WorkerIdxError {}

/// A cloneable handle to a local pool, used for spawning `!Send` tasks.
///
/// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread
Expand Down Expand Up @@ -156,27 +134,15 @@ impl LocalPoolHandle {
/// is given by `num_threads() - 1`
///
/// Returns a `WorkerIdxError` if the provided index is out of bounds.
pub fn spawn_pinned_by_idx<F, Fut>(
&self,
create_task: F,
idx: usize,
) -> Result<JoinHandle<Fut::Output>, WorkerIdxError>
pub fn spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output>
where
F: FnOnce() -> Fut,
F: Send + 'static,
Fut: Future + 'static,
Fut::Output: Send + 'static,
{
if idx >= self.pool.workers.len() {
return Err(WorkerIdxError {
idx,
num_workers: self.pool.workers.len(),
});
}

Ok(self
.pool
.spawn_pinned(create_task, WorkerChoice::ByIdx(idx)))
self.pool
.spawn_pinned(create_task, WorkerChoice::ByIdx(idx))
}

/// Spawn a task on every worker thread in the pool and pin it so that it
Expand Down Expand Up @@ -212,7 +178,7 @@ impl LocalPoolHandle {
Fut::Output: Send + 'static,
{
(0..self.pool.workers.len())
.map(|idx| self.spawn_pinned_by_idx(create_task.clone(), idx).unwrap())
.map(|idx| self.spawn_pinned_by_idx(create_task.clone(), idx))
.collect::<Vec<_>>()
}
}
Expand Down
12 changes: 8 additions & 4 deletions tokio-util/tests/spawn_pinned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ async fn tasks_are_balanced() {
#[tokio::test]
async fn spawn_by_idx() {
let pool = task::LocalPoolHandle::new(3);
let barrier = Arc::new(Barrier::new(3));
let barrier = Arc::new(Barrier::new(4));
let barrier1 = barrier.clone();
let barrier2 = barrier.clone();
let barrier3 = barrier.clone();
Expand Down Expand Up @@ -229,8 +229,8 @@ async fn spawn_by_idx() {
assert_eq!(loads[1], 1);
assert_eq!(loads[2], 0);

let thread_id1 = handle1.unwrap().await.unwrap();
let thread_id2 = handle2.unwrap().await.unwrap();
let thread_id1 = handle1.await.unwrap();
let thread_id2 = handle2.await.unwrap();

assert_ne!(thread_id1, thread_id2);
}
Expand All @@ -242,7 +242,7 @@ async fn spawn_on_all_workers() {
let barrier = Arc::new(Barrier::new(NUM_WORKERS + 1));
let barrier_clone = barrier.clone();

let _ = pool.spawn_pinned_on_all_workers(|| async move {
let handles = pool.spawn_pinned_on_all_workers(|| async move {
barrier_clone.wait().await;

"test"
Expand All @@ -253,4 +253,8 @@ async fn spawn_on_all_workers() {
assert_eq!(loads[0], 1);
assert_eq!(loads[1], 1);
assert_eq!(loads[2], 1);

let _ = handles
.into_iter()
.map(|handle| async { handle.await.unwrap() });
}

0 comments on commit 71c63c0

Please sign in to comment.