-
Notifications
You must be signed in to change notification settings - Fork 507
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
492: Add ThreadPool::broadcast r=cuviper a=cuviper A broadcast runs the closure on every thread in the pool, then collects the results. It's scheduled somewhat like a very soft interrupt -- it won't preempt a thread's local work, but will run before it goes to steal from any other threads. This can be used when you want to precisely split your work per-thread, or to set or retrieve some thread-local data in the pool, e.g. #483. Co-authored-by: Josh Stone <cuviper@gmail.com>
- Loading branch information
Showing
12 changed files
with
778 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
use crate::job::{ArcJob, StackJob}; | ||
use crate::registry::{Registry, WorkerThread}; | ||
use crate::scope::ScopeLatch; | ||
use std::fmt; | ||
use std::marker::PhantomData; | ||
use std::sync::Arc; | ||
|
||
mod test; | ||
|
||
/// Executes `op` within every thread in the current threadpool. If this is | ||
/// called from a non-Rayon thread, it will execute in the global threadpool. | ||
/// Any attempts to use `join`, `scope`, or parallel iterators will then operate | ||
/// within that threadpool. When the call has completed on each thread, returns | ||
/// a vector containing all of their return values. | ||
/// | ||
/// For more information, see the [`ThreadPool::broadcast()`][m] method. | ||
/// | ||
/// [m]: struct.ThreadPool.html#method.broadcast | ||
pub fn broadcast<OP, R>(op: OP) -> Vec<R> | ||
where | ||
OP: Fn(BroadcastContext<'_>) -> R + Sync, | ||
R: Send, | ||
{ | ||
// We assert that current registry has not terminated. | ||
unsafe { broadcast_in(op, &Registry::current()) } | ||
} | ||
|
||
/// Spawns an asynchronous task on every thread in this thread-pool. This task | ||
/// will run in the implicit, global scope, which means that it may outlast the | ||
/// current stack frame -- therefore, it cannot capture any references onto the | ||
/// stack (you will likely need a `move` closure). | ||
/// | ||
/// For more information, see the [`ThreadPool::spawn_broadcast()`][m] method. | ||
/// | ||
/// [m]: struct.ThreadPool.html#method.spawn_broadcast | ||
pub fn spawn_broadcast<OP>(op: OP) | ||
where | ||
OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static, | ||
{ | ||
// We assert that current registry has not terminated. | ||
unsafe { spawn_broadcast_in(op, &Registry::current()) } | ||
} | ||
|
||
/// Provides context to a closure called by `broadcast`. | ||
pub struct BroadcastContext<'a> { | ||
worker: &'a WorkerThread, | ||
|
||
/// Make sure to prevent auto-traits like `Send` and `Sync`. | ||
_marker: PhantomData<&'a mut dyn Fn()>, | ||
} | ||
|
||
impl<'a> BroadcastContext<'a> { | ||
pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R { | ||
let worker_thread = WorkerThread::current(); | ||
assert!(!worker_thread.is_null()); | ||
f(BroadcastContext { | ||
worker: unsafe { &*worker_thread }, | ||
_marker: PhantomData, | ||
}) | ||
} | ||
|
||
/// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`). | ||
#[inline] | ||
pub fn index(&self) -> usize { | ||
self.worker.index() | ||
} | ||
|
||
/// The number of threads receiving the broadcast in the thread pool. | ||
/// | ||
/// # Future compatibility note | ||
/// | ||
/// Future versions of Rayon might vary the number of threads over time, but | ||
/// this method will always return the number of threads which are actually | ||
/// receiving your particular `broadcast` call. | ||
#[inline] | ||
pub fn num_threads(&self) -> usize { | ||
self.worker.registry().num_threads() | ||
} | ||
} | ||
|
||
impl<'a> fmt::Debug for BroadcastContext<'a> { | ||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
fmt.debug_struct("BroadcastContext") | ||
.field("index", &self.index()) | ||
.field("num_threads", &self.num_threads()) | ||
.field("pool_id", &self.worker.registry().id()) | ||
.finish() | ||
} | ||
} | ||
|
||
/// Execute `op` on every thread in the pool. It will be executed on each | ||
/// thread when they have nothing else to do locally, before they try to | ||
/// steal work from other threads. This function will not return until all | ||
/// threads have completed the `op`. | ||
/// | ||
/// Unsafe because `registry` must not yet have terminated. | ||
pub(super) unsafe fn broadcast_in<OP, R>(op: OP, registry: &Arc<Registry>) -> Vec<R> | ||
where | ||
OP: Fn(BroadcastContext<'_>) -> R + Sync, | ||
R: Send, | ||
{ | ||
let f = move |injected: bool| { | ||
debug_assert!(injected); | ||
BroadcastContext::with(&op) | ||
}; | ||
|
||
let n_threads = registry.num_threads(); | ||
let current_thread = WorkerThread::current().as_ref(); | ||
let latch = ScopeLatch::with_count(n_threads, current_thread); | ||
let jobs: Vec<_> = (0..n_threads).map(|_| StackJob::new(&f, &latch)).collect(); | ||
let job_refs = jobs.iter().map(|job| job.as_job_ref()); | ||
|
||
registry.inject_broadcast(job_refs); | ||
|
||
// Wait for all jobs to complete, then collect the results, maybe propagating a panic. | ||
latch.wait(current_thread); | ||
jobs.into_iter().map(|job| job.into_result()).collect() | ||
} | ||
|
||
/// Execute `op` on every thread in the pool. It will be executed on each | ||
/// thread when they have nothing else to do locally, before they try to | ||
/// steal work from other threads. This function returns immediately after | ||
/// injecting the jobs. | ||
/// | ||
/// Unsafe because `registry` must not yet have terminated. | ||
pub(super) unsafe fn spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>) | ||
where | ||
OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static, | ||
{ | ||
let job = ArcJob::new({ | ||
let registry = Arc::clone(registry); | ||
move || { | ||
registry.catch_unwind(|| BroadcastContext::with(&op)); | ||
registry.terminate(); // (*) permit registry to terminate now | ||
} | ||
}); | ||
|
||
let n_threads = registry.num_threads(); | ||
let job_refs = (0..n_threads).map(|_| { | ||
// Ensure that registry cannot terminate until this job has executed | ||
// on each thread. This ref is decremented at the (*) above. | ||
registry.increment_terminate_count(); | ||
|
||
ArcJob::as_static_job_ref(&job) | ||
}); | ||
|
||
registry.inject_broadcast(job_refs); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
#![cfg(test)] | ||
|
||
use crate::ThreadPoolBuilder; | ||
use std::sync::atomic::{AtomicUsize, Ordering}; | ||
use std::sync::Arc; | ||
use std::{thread, time}; | ||
|
||
#[test] | ||
fn broadcast_global() { | ||
let v = crate::broadcast(|ctx| ctx.index()); | ||
assert!(v.into_iter().eq(0..crate::current_num_threads())); | ||
} | ||
|
||
#[test] | ||
fn spawn_broadcast_global() { | ||
let (tx, rx) = crossbeam_channel::unbounded(); | ||
crate::spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap()); | ||
|
||
let mut v: Vec<_> = rx.into_iter().collect(); | ||
v.sort_unstable(); | ||
assert!(v.into_iter().eq(0..crate::current_num_threads())); | ||
} | ||
|
||
#[test] | ||
fn broadcast_pool() { | ||
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); | ||
let v = pool.broadcast(|ctx| ctx.index()); | ||
assert!(v.into_iter().eq(0..7)); | ||
} | ||
|
||
#[test] | ||
fn spawn_broadcast_pool() { | ||
let (tx, rx) = crossbeam_channel::unbounded(); | ||
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); | ||
pool.spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap()); | ||
|
||
let mut v: Vec<_> = rx.into_iter().collect(); | ||
v.sort_unstable(); | ||
assert!(v.into_iter().eq(0..7)); | ||
} | ||
|
||
#[test] | ||
fn broadcast_self() { | ||
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); | ||
let v = pool.install(|| crate::broadcast(|ctx| ctx.index())); | ||
assert!(v.into_iter().eq(0..7)); | ||
} | ||
|
||
#[test] | ||
fn spawn_broadcast_self() { | ||
let (tx, rx) = crossbeam_channel::unbounded(); | ||
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); | ||
pool.spawn(|| crate::spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap())); | ||
|
||
let mut v: Vec<_> = rx.into_iter().collect(); | ||
v.sort_unstable(); | ||
assert!(v.into_iter().eq(0..7)); | ||
} | ||
|
||
#[test] | ||
fn broadcast_mutual() { | ||
let count = AtomicUsize::new(0); | ||
let pool1 = ThreadPoolBuilder::new().num_threads(3).build().unwrap(); | ||
let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); | ||
pool1.install(|| { | ||
pool2.broadcast(|_| { | ||
pool1.broadcast(|_| { | ||
count.fetch_add(1, Ordering::Relaxed); | ||
}) | ||
}) | ||
}); | ||
assert_eq!(count.into_inner(), 3 * 7); | ||
} | ||
|
||
#[test] | ||
fn spawn_broadcast_mutual() { | ||
let (tx, rx) = crossbeam_channel::unbounded(); | ||
let pool1 = Arc::new(ThreadPoolBuilder::new().num_threads(3).build().unwrap()); | ||
let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); | ||
pool1.spawn({ | ||
let pool1 = Arc::clone(&pool1); | ||
move || { | ||
pool2.spawn_broadcast(move |_| { | ||
let tx = tx.clone(); | ||
pool1.spawn_broadcast(move |_| tx.send(()).unwrap()) | ||
}) | ||
} | ||
}); | ||
assert_eq!(rx.into_iter().count(), 3 * 7); | ||
} | ||
|
||
#[test] | ||
fn broadcast_mutual_sleepy() { | ||
let count = AtomicUsize::new(0); | ||
let pool1 = ThreadPoolBuilder::new().num_threads(3).build().unwrap(); | ||
let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); | ||
pool1.install(|| { | ||
thread::sleep(time::Duration::from_secs(1)); | ||
pool2.broadcast(|_| { | ||
thread::sleep(time::Duration::from_secs(1)); | ||
pool1.broadcast(|_| { | ||
thread::sleep(time::Duration::from_millis(100)); | ||
count.fetch_add(1, Ordering::Relaxed); | ||
}) | ||
}) | ||
}); | ||
assert_eq!(count.into_inner(), 3 * 7); | ||
} | ||
|
||
#[test] | ||
fn spawn_broadcast_mutual_sleepy() { | ||
let (tx, rx) = crossbeam_channel::unbounded(); | ||
let pool1 = Arc::new(ThreadPoolBuilder::new().num_threads(3).build().unwrap()); | ||
let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); | ||
pool1.spawn({ | ||
let pool1 = Arc::clone(&pool1); | ||
move || { | ||
thread::sleep(time::Duration::from_secs(1)); | ||
pool2.spawn_broadcast(move |_| { | ||
let tx = tx.clone(); | ||
thread::sleep(time::Duration::from_secs(1)); | ||
pool1.spawn_broadcast(move |_| { | ||
thread::sleep(time::Duration::from_millis(100)); | ||
tx.send(()).unwrap(); | ||
}) | ||
}) | ||
} | ||
}); | ||
assert_eq!(rx.into_iter().count(), 3 * 7); | ||
} | ||
|
||
#[test] | ||
fn broadcast_panic_one() { | ||
let count = AtomicUsize::new(0); | ||
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); | ||
let result = crate::unwind::halt_unwinding(|| { | ||
pool.broadcast(|ctx| { | ||
count.fetch_add(1, Ordering::Relaxed); | ||
if ctx.index() == 3 { | ||
panic!("Hello, world!"); | ||
} | ||
}) | ||
}); | ||
assert_eq!(count.into_inner(), 7); | ||
assert!(result.is_err(), "broadcast panic should propagate!"); | ||
} | ||
|
||
#[test] | ||
fn spawn_broadcast_panic_one() { | ||
let (tx, rx) = crossbeam_channel::unbounded(); | ||
let (panic_tx, panic_rx) = crossbeam_channel::unbounded(); | ||
let pool = ThreadPoolBuilder::new() | ||
.num_threads(7) | ||
.panic_handler(move |e| panic_tx.send(e).unwrap()) | ||
.build() | ||
.unwrap(); | ||
pool.spawn_broadcast(move |ctx| { | ||
tx.send(()).unwrap(); | ||
if ctx.index() == 3 { | ||
panic!("Hello, world!"); | ||
} | ||
}); | ||
drop(pool); // including panic_tx | ||
assert_eq!(rx.into_iter().count(), 7); | ||
assert_eq!(panic_rx.into_iter().count(), 1); | ||
} | ||
|
||
#[test] | ||
fn broadcast_panic_many() { | ||
let count = AtomicUsize::new(0); | ||
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); | ||
let result = crate::unwind::halt_unwinding(|| { | ||
pool.broadcast(|ctx| { | ||
count.fetch_add(1, Ordering::Relaxed); | ||
if ctx.index() % 2 == 0 { | ||
panic!("Hello, world!"); | ||
} | ||
}) | ||
}); | ||
assert_eq!(count.into_inner(), 7); | ||
assert!(result.is_err(), "broadcast panic should propagate!"); | ||
} | ||
|
||
#[test] | ||
fn spawn_broadcast_panic_many() { | ||
let (tx, rx) = crossbeam_channel::unbounded(); | ||
let (panic_tx, panic_rx) = crossbeam_channel::unbounded(); | ||
let pool = ThreadPoolBuilder::new() | ||
.num_threads(7) | ||
.panic_handler(move |e| panic_tx.send(e).unwrap()) | ||
.build() | ||
.unwrap(); | ||
pool.spawn_broadcast(move |ctx| { | ||
tx.send(()).unwrap(); | ||
if ctx.index() % 2 == 0 { | ||
panic!("Hello, world!"); | ||
} | ||
}); | ||
drop(pool); // including panic_tx | ||
assert_eq!(rx.into_iter().count(), 7); | ||
assert_eq!(panic_rx.into_iter().count(), 4); | ||
} |
Oops, something went wrong.