Skip to content

Commit

Permalink
Merge #492
Browse files Browse the repository at this point in the history
492: Add ThreadPool::broadcast r=cuviper a=cuviper

A broadcast runs the closure on every thread in the pool, then collects
the results.  It's scheduled somewhat like a very soft interrupt -- it
won't preempt a thread's local work, but will run before it goes to
steal from any other threads.

This can be used when you want to precisely split your work per-thread,
or to set or retrieve some thread-local data in the pool, e.g. #483.

Co-authored-by: Josh Stone <cuviper@gmail.com>
  • Loading branch information
bors[bot] and cuviper authored Nov 16, 2022
2 parents 274499a + 9ef85cd commit 911d6d0
Show file tree
Hide file tree
Showing 12 changed files with 778 additions and 86 deletions.
148 changes: 148 additions & 0 deletions rayon-core/src/broadcast/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
use crate::job::{ArcJob, StackJob};
use crate::registry::{Registry, WorkerThread};
use crate::scope::ScopeLatch;
use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;

mod test;

/// Executes `op` within every thread in the current threadpool. If this is
/// called from a non-Rayon thread, it will execute in the global threadpool.
/// Any attempts to use `join`, `scope`, or parallel iterators will then operate
/// within that threadpool. When the call has completed on each thread, returns
/// a vector containing all of their return values.
///
/// For more information, see the [`ThreadPool::broadcast()`][m] method.
///
/// [m]: struct.ThreadPool.html#method.broadcast
pub fn broadcast<OP, R>(op: OP) -> Vec<R>
where
OP: Fn(BroadcastContext<'_>) -> R + Sync,
R: Send,
{
// We assert that current registry has not terminated.
unsafe { broadcast_in(op, &Registry::current()) }
}

/// Spawns an asynchronous task on every thread in this thread-pool. This task
/// will run in the implicit, global scope, which means that it may outlast the
/// current stack frame -- therefore, it cannot capture any references onto the
/// stack (you will likely need a `move` closure).
///
/// For more information, see the [`ThreadPool::spawn_broadcast()`][m] method.
///
/// [m]: struct.ThreadPool.html#method.spawn_broadcast
pub fn spawn_broadcast<OP>(op: OP)
where
OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
{
// We assert that current registry has not terminated.
unsafe { spawn_broadcast_in(op, &Registry::current()) }
}

/// Provides context to a closure called by `broadcast`.
pub struct BroadcastContext<'a> {
worker: &'a WorkerThread,

/// Make sure to prevent auto-traits like `Send` and `Sync`.
_marker: PhantomData<&'a mut dyn Fn()>,
}

impl<'a> BroadcastContext<'a> {
pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R {
let worker_thread = WorkerThread::current();
assert!(!worker_thread.is_null());
f(BroadcastContext {
worker: unsafe { &*worker_thread },
_marker: PhantomData,
})
}

/// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`).
#[inline]
pub fn index(&self) -> usize {
self.worker.index()
}

/// The number of threads receiving the broadcast in the thread pool.
///
/// # Future compatibility note
///
/// Future versions of Rayon might vary the number of threads over time, but
/// this method will always return the number of threads which are actually
/// receiving your particular `broadcast` call.
#[inline]
pub fn num_threads(&self) -> usize {
self.worker.registry().num_threads()
}
}

impl<'a> fmt::Debug for BroadcastContext<'a> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("BroadcastContext")
.field("index", &self.index())
.field("num_threads", &self.num_threads())
.field("pool_id", &self.worker.registry().id())
.finish()
}
}

/// Execute `op` on every thread in the pool. It will be executed on each
/// thread when they have nothing else to do locally, before they try to
/// steal work from other threads. This function will not return until all
/// threads have completed the `op`.
///
/// Unsafe because `registry` must not yet have terminated.
pub(super) unsafe fn broadcast_in<OP, R>(op: OP, registry: &Arc<Registry>) -> Vec<R>
where
OP: Fn(BroadcastContext<'_>) -> R + Sync,
R: Send,
{
let f = move |injected: bool| {
debug_assert!(injected);
BroadcastContext::with(&op)
};

let n_threads = registry.num_threads();
let current_thread = WorkerThread::current().as_ref();
let latch = ScopeLatch::with_count(n_threads, current_thread);
let jobs: Vec<_> = (0..n_threads).map(|_| StackJob::new(&f, &latch)).collect();
let job_refs = jobs.iter().map(|job| job.as_job_ref());

registry.inject_broadcast(job_refs);

// Wait for all jobs to complete, then collect the results, maybe propagating a panic.
latch.wait(current_thread);
jobs.into_iter().map(|job| job.into_result()).collect()
}

/// Execute `op` on every thread in the pool. It will be executed on each
/// thread when they have nothing else to do locally, before they try to
/// steal work from other threads. This function returns immediately after
/// injecting the jobs.
///
/// Unsafe because `registry` must not yet have terminated.
pub(super) unsafe fn spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>)
where
OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
{
let job = ArcJob::new({
let registry = Arc::clone(registry);
move || {
registry.catch_unwind(|| BroadcastContext::with(&op));
registry.terminate(); // (*) permit registry to terminate now
}
});

let n_threads = registry.num_threads();
let job_refs = (0..n_threads).map(|_| {
// Ensure that registry cannot terminate until this job has executed
// on each thread. This ref is decremented at the (*) above.
registry.increment_terminate_count();

ArcJob::as_static_job_ref(&job)
});

registry.inject_broadcast(job_refs);
}
202 changes: 202 additions & 0 deletions rayon-core/src/broadcast/test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
#![cfg(test)]

use crate::ThreadPoolBuilder;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::{thread, time};

#[test]
fn broadcast_global() {
let v = crate::broadcast(|ctx| ctx.index());
assert!(v.into_iter().eq(0..crate::current_num_threads()));
}

#[test]
fn spawn_broadcast_global() {
let (tx, rx) = crossbeam_channel::unbounded();
crate::spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap());

let mut v: Vec<_> = rx.into_iter().collect();
v.sort_unstable();
assert!(v.into_iter().eq(0..crate::current_num_threads()));
}

#[test]
fn broadcast_pool() {
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
let v = pool.broadcast(|ctx| ctx.index());
assert!(v.into_iter().eq(0..7));
}

#[test]
fn spawn_broadcast_pool() {
let (tx, rx) = crossbeam_channel::unbounded();
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
pool.spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap());

let mut v: Vec<_> = rx.into_iter().collect();
v.sort_unstable();
assert!(v.into_iter().eq(0..7));
}

#[test]
fn broadcast_self() {
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
let v = pool.install(|| crate::broadcast(|ctx| ctx.index()));
assert!(v.into_iter().eq(0..7));
}

#[test]
fn spawn_broadcast_self() {
let (tx, rx) = crossbeam_channel::unbounded();
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
pool.spawn(|| crate::spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap()));

let mut v: Vec<_> = rx.into_iter().collect();
v.sort_unstable();
assert!(v.into_iter().eq(0..7));
}

#[test]
fn broadcast_mutual() {
let count = AtomicUsize::new(0);
let pool1 = ThreadPoolBuilder::new().num_threads(3).build().unwrap();
let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
pool1.install(|| {
pool2.broadcast(|_| {
pool1.broadcast(|_| {
count.fetch_add(1, Ordering::Relaxed);
})
})
});
assert_eq!(count.into_inner(), 3 * 7);
}

#[test]
fn spawn_broadcast_mutual() {
let (tx, rx) = crossbeam_channel::unbounded();
let pool1 = Arc::new(ThreadPoolBuilder::new().num_threads(3).build().unwrap());
let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
pool1.spawn({
let pool1 = Arc::clone(&pool1);
move || {
pool2.spawn_broadcast(move |_| {
let tx = tx.clone();
pool1.spawn_broadcast(move |_| tx.send(()).unwrap())
})
}
});
assert_eq!(rx.into_iter().count(), 3 * 7);
}

#[test]
fn broadcast_mutual_sleepy() {
let count = AtomicUsize::new(0);
let pool1 = ThreadPoolBuilder::new().num_threads(3).build().unwrap();
let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
pool1.install(|| {
thread::sleep(time::Duration::from_secs(1));
pool2.broadcast(|_| {
thread::sleep(time::Duration::from_secs(1));
pool1.broadcast(|_| {
thread::sleep(time::Duration::from_millis(100));
count.fetch_add(1, Ordering::Relaxed);
})
})
});
assert_eq!(count.into_inner(), 3 * 7);
}

#[test]
fn spawn_broadcast_mutual_sleepy() {
let (tx, rx) = crossbeam_channel::unbounded();
let pool1 = Arc::new(ThreadPoolBuilder::new().num_threads(3).build().unwrap());
let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
pool1.spawn({
let pool1 = Arc::clone(&pool1);
move || {
thread::sleep(time::Duration::from_secs(1));
pool2.spawn_broadcast(move |_| {
let tx = tx.clone();
thread::sleep(time::Duration::from_secs(1));
pool1.spawn_broadcast(move |_| {
thread::sleep(time::Duration::from_millis(100));
tx.send(()).unwrap();
})
})
}
});
assert_eq!(rx.into_iter().count(), 3 * 7);
}

#[test]
fn broadcast_panic_one() {
let count = AtomicUsize::new(0);
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
let result = crate::unwind::halt_unwinding(|| {
pool.broadcast(|ctx| {
count.fetch_add(1, Ordering::Relaxed);
if ctx.index() == 3 {
panic!("Hello, world!");
}
})
});
assert_eq!(count.into_inner(), 7);
assert!(result.is_err(), "broadcast panic should propagate!");
}

#[test]
fn spawn_broadcast_panic_one() {
let (tx, rx) = crossbeam_channel::unbounded();
let (panic_tx, panic_rx) = crossbeam_channel::unbounded();
let pool = ThreadPoolBuilder::new()
.num_threads(7)
.panic_handler(move |e| panic_tx.send(e).unwrap())
.build()
.unwrap();
pool.spawn_broadcast(move |ctx| {
tx.send(()).unwrap();
if ctx.index() == 3 {
panic!("Hello, world!");
}
});
drop(pool); // including panic_tx
assert_eq!(rx.into_iter().count(), 7);
assert_eq!(panic_rx.into_iter().count(), 1);
}

#[test]
fn broadcast_panic_many() {
let count = AtomicUsize::new(0);
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
let result = crate::unwind::halt_unwinding(|| {
pool.broadcast(|ctx| {
count.fetch_add(1, Ordering::Relaxed);
if ctx.index() % 2 == 0 {
panic!("Hello, world!");
}
})
});
assert_eq!(count.into_inner(), 7);
assert!(result.is_err(), "broadcast panic should propagate!");
}

#[test]
fn spawn_broadcast_panic_many() {
let (tx, rx) = crossbeam_channel::unbounded();
let (panic_tx, panic_rx) = crossbeam_channel::unbounded();
let pool = ThreadPoolBuilder::new()
.num_threads(7)
.panic_handler(move |e| panic_tx.send(e).unwrap())
.build()
.unwrap();
pool.spawn_broadcast(move |ctx| {
tx.send(()).unwrap();
if ctx.index() % 2 == 0 {
panic!("Hello, world!");
}
});
drop(pool); // including panic_tx
assert_eq!(rx.into_iter().count(), 7);
assert_eq!(panic_rx.into_iter().count(), 4);
}
Loading

0 comments on commit 911d6d0

Please sign in to comment.