From 00db2aefb4734e1a751f7f3dac28373323618f19 Mon Sep 17 00:00:00 2001 From: Guillaume Endignoux Date: Fri, 24 Nov 2023 16:42:28 +0100 Subject: [PATCH] Make the custom thread pool generic over the input and output types. --- src/meek.rs | 12 +- src/parallelism/mod.rs | 79 +++++++- src/parallelism/thread_pool.rs | 178 ++++++++---------- src/vote_count.rs | 121 +++++++++++- .../bigint/ballot_sum_overflows.bigfixed9.err | 2 +- .../ballot_sum_overflows.fixed9.custom.err | 2 +- 6 files changed, 285 insertions(+), 109 deletions(-) diff --git a/src/meek.rs b/src/meek.rs index af525f0..8bca663 100644 --- a/src/meek.rs +++ b/src/meek.rs @@ -17,9 +17,9 @@ use crate::arithmetic::{Integer, IntegerRef, Rational, RationalRef}; use crate::cli::Parallel; -use crate::parallelism::{RangeStrategy, ThreadPool}; +use crate::parallelism::RangeStrategy; use crate::types::{Election, ElectionResult}; -use crate::vote_count::{self, VoteCount}; +use crate::vote_count::{self, VoteCount, VoteCountingThreadPool}; use log::{debug, info, log_enabled, warn, Level}; use std::fmt::{self, Debug, Display}; use std::io; @@ -111,7 +111,7 @@ where .expect("A positive number of threads must be spawned, but the available parallelism is zero threads") }); info!("Spawning {num_threads} threads"); - let thread_pool = ThreadPool::new( + let thread_pool = VoteCountingThreadPool::new( thread_scope, num_threads, if disable_work_stealing { @@ -200,7 +200,9 @@ pub struct State<'scope, 'e, I, R> { /// Pre-computed Pascal triangle. Set only if the "equalized counting" is /// enabled. pascal: Option<&'e [Vec]>, - thread_pool: Option>, + /// Thread pool where to schedule vote counting. Set only if the parallel + /// mode is set to "custom". + thread_pool: Option>, _phantom: PhantomData, } @@ -224,7 +226,7 @@ where parallel: Parallel, force_positive_surplus: bool, pascal: Option<&'a [Vec]>, - thread_pool: Option>, + thread_pool: Option>, ) -> State<'scope, 'a, I, R> { State { election, diff --git a/src/parallelism/mod.rs b/src/parallelism/mod.rs index c667da3..ad92908 100644 --- a/src/parallelism/mod.rs +++ b/src/parallelism/mod.rs @@ -17,4 +17,81 @@ mod range; mod thread_pool; -pub use thread_pool::{RangeStrategy, ThreadPool}; +pub use thread_pool::{RangeStrategy, ThreadAccumulator, ThreadPool}; + +#[cfg(test)] +mod test { + use super::*; + use std::num::NonZeroUsize; + + /// Example of accumulator that computes a sum of integers. + struct SumAccumulator; + + impl ThreadAccumulator for SumAccumulator { + type Accumulator<'a> = u64; + + fn init(&self) -> u64 { + 0 + } + + fn process_item(&self, accumulator: &mut u64, _index: usize, x: &u64) { + *accumulator += *x; + } + + fn finalize(&self, accumulator: u64) -> u64 { + accumulator + } + } + + macro_rules! parallelism_tests { + ( $mod:ident, $range_strategy:expr, $($case:ident,)+ ) => { + mod $mod { + use super::*; + + $( + #[test] + fn $case() { + $crate::parallelism::test::$case($range_strategy); + } + )+ + } + }; + } + + macro_rules! all_parallelism_tests { + ( $mod:ident, $range_strategy:expr ) => { + parallelism_tests!($mod, $range_strategy, test_sum_integers, test_sum_twice,); + }; + } + + all_parallelism_tests!(fixed, RangeStrategy::Fixed); + all_parallelism_tests!(work_stealing, RangeStrategy::WorkStealing); + + fn test_sum_integers(range_strategy: RangeStrategy) { + let input = (0..=10_000).collect::>(); + let num_threads = NonZeroUsize::try_from(4).unwrap(); + let sum = std::thread::scope(|scope| { + let thread_pool = ThreadPool::new(scope, num_threads, range_strategy, &input, || { + SumAccumulator + }); + thread_pool.process_inputs().reduce(|a, b| a + b).unwrap() + }); + assert_eq!(sum, 5_000 * 10_001); + } + + fn test_sum_twice(range_strategy: RangeStrategy) { + let input = (0..=10_000).collect::>(); + let num_threads = NonZeroUsize::try_from(4).unwrap(); + let (sum1, sum2) = std::thread::scope(|scope| { + let thread_pool = ThreadPool::new(scope, num_threads, range_strategy, &input, || { + SumAccumulator + }); + // The same input can be processed multiple times on the thread pool. + let sum1 = thread_pool.process_inputs().reduce(|a, b| a + b).unwrap(); + let sum2 = thread_pool.process_inputs().reduce(|a, b| a + b).unwrap(); + (sum1, sum2) + }); + assert_eq!(sum1, 5_000 * 10_001); + assert_eq!(sum2, 5_000 * 10_001); + } +} diff --git a/src/parallelism/thread_pool.rs b/src/parallelism/thread_pool.rs index 5dbb457..e7ee72c 100644 --- a/src/parallelism/thread_pool.rs +++ b/src/parallelism/thread_pool.rs @@ -14,12 +14,9 @@ //! A hand-rolled thread pool, customized for the vote counting problem. -use crate::arithmetic::{Integer, IntegerRef, Rational, RationalRef}; -use crate::parallelism::range::{ +use super::range::{ FixedRangeFactory, Range, RangeFactory, RangeOrchestrator, WorkStealingRangeFactory, }; -use crate::types::Election; -use crate::vote_count::{VoteAccumulator, VoteCount}; use log::{debug, error, warn}; // Platforms that support `libc::sched_setaffinity()`. #[cfg(any( @@ -34,9 +31,8 @@ use nix::{ }; use std::cell::Cell; use std::num::NonZeroUsize; -use std::ops::DerefMut; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, Condvar, Mutex, MutexGuard, PoisonError, RwLock}; +use std::sync::{Arc, Condvar, Mutex, MutexGuard, PoisonError}; use std::thread::{Scope, ScopedJoinHandle}; /// Status of the main thread. @@ -129,10 +125,11 @@ impl Status { } } -/// A thread pool tied to a scope, that can perform vote counting rounds. -pub struct ThreadPool<'scope, I, R> { - /// Handles to all the threads in the pool. - threads: Vec>, +/// A thread pool tied to a scope, that can process inputs into the given output +/// type. +pub struct ThreadPool<'scope, Output> { + /// Handles to all the worker threads in the pool. + threads: Vec>, /// Number of worker threads active in the current round. num_active_threads: Arc, /// Color of the current round. @@ -141,21 +138,18 @@ pub struct ThreadPool<'scope, I, R> { worker_status: Arc>, /// Status of the main thread. main_status: Arc>, - /// Storage for the keep factors, used as input of the current round by the - /// worker threads. - keep_factors: Arc>>, /// Orchestrator for the work ranges distributed to the threads. This is a /// dynamic object to avoid making the range type a parameter of /// everything. range_orchestrator: Box, } -/// Handle to a thread in the pool. -struct Thread<'scope, I, R> { +/// Handle to a worker thread in the pool. +struct WorkerThreadHandle<'scope, Output> { /// Thread handle object. handle: ScopedJoinHandle<'scope, ()>, /// Storage for this thread's computation output. - output: Arc>>>, + output: Arc>>, } /// Strategy to distribute ranges of work items among threads. @@ -166,48 +160,47 @@ pub enum RangeStrategy { WorkStealing, } -impl<'scope, I, R> ThreadPool<'scope, I, R> -where - I: Integer + Send + Sync + 'scope, - for<'a> &'a I: IntegerRef, - R: Rational + Send + Sync + 'scope, - for<'a> &'a R: RationalRef<&'a I, R>, -{ - /// Creates a new pool tied to the given scope, with the given number of - /// threads and references to the necessary election inputs. - pub fn new<'e>( - thread_scope: &'scope Scope<'scope, 'e>, +impl<'scope, Output: Send + 'scope> ThreadPool<'scope, Output> { + /// Creates a new pool tied to the given scope, spawning the given number of + /// threads and using the given input slice. + pub fn new<'env, Input: Sync, Accum: ThreadAccumulator + Send + 'scope>( + thread_scope: &'scope Scope<'scope, 'env>, num_threads: NonZeroUsize, range_strategy: RangeStrategy, - election: &'e Election, - pascal: Option<&'e [Vec]>, + input: &'env [Input], + new_accumulator: impl Fn() -> Accum, ) -> Self { let num_threads: usize = num_threads.into(); - let num_ballots = election.ballots.len(); + let input_len = input.len(); match range_strategy { RangeStrategy::Fixed => Self::new_with_factory( thread_scope, num_threads, - FixedRangeFactory::new(num_ballots, num_threads), - election, - pascal, + FixedRangeFactory::new(input_len, num_threads), + input, + new_accumulator, ), RangeStrategy::WorkStealing => Self::new_with_factory( thread_scope, num_threads, - WorkStealingRangeFactory::new(num_ballots, num_threads), - election, - pascal, + WorkStealingRangeFactory::new(input_len, num_threads), + input, + new_accumulator, ), } } - fn new_with_factory<'e, RnFactory: RangeFactory>( - thread_scope: &'scope Scope<'scope, 'e>, + fn new_with_factory< + 'env, + RnFactory: RangeFactory, + Input: Sync, + Accum: ThreadAccumulator + Send + 'scope, + >( + thread_scope: &'scope Scope<'scope, 'env>, num_threads: usize, range_factory: RnFactory, - election: &'e Election, - pascal: Option<&'e [Vec]>, + input: &'env [Input], + new_accumulator: impl Fn() -> Accum, ) -> Self where RnFactory::Rn: 'scope + Send, @@ -217,7 +210,6 @@ where let num_active_threads = Arc::new(AtomicUsize::new(0)); let worker_status = Arc::new(Status::new(WorkerStatus::Round(color))); let main_status = Arc::new(Status::new(MainStatus::Waiting)); - let keep_factors = Arc::new(RwLock::new(Vec::new())); #[cfg(not(any( target_os = "android", @@ -234,13 +226,12 @@ where num_active_threads: num_active_threads.clone(), worker_status: worker_status.clone(), main_status: main_status.clone(), - election, - pascal, - keep_factors: keep_factors.clone(), range: range_factory.range(id), + input, output: output.clone(), + accumulator: new_accumulator(), }; - Thread { + WorkerThreadHandle { handle: thread_scope.spawn(move || { #[cfg(any( target_os = "android", @@ -266,25 +257,19 @@ where .collect(); debug!("[main thread] Spawned threads"); - ThreadPool { + Self { threads, num_active_threads, round: Cell::new(color), worker_status, main_status, - keep_factors, range_orchestrator: Box::new(range_factory.orchestrator()), } } - /// Accumulates votes from the election ballots based on the given keep - /// factors. - pub fn accumulate_votes(&self, keep_factors: &[R]) -> VoteAccumulator { - { - let mut keep_factors_guard = self.keep_factors.write().unwrap(); - keep_factors_guard.clear(); - keep_factors_guard.extend_from_slice(keep_factors); - } + /// Performs a computation round, processing the input slice in parallel and + /// returning an iterator over the threads' outputs. + pub fn process_inputs(&self) -> impl Iterator + '_ { self.range_orchestrator.reset_ranges(); let num_threads = self.threads.len(); @@ -316,13 +301,11 @@ where self.threads .iter() - .map(|t| -> VoteAccumulator { t.output.lock().unwrap().take().unwrap() }) - .reduce(|a, b| a.reduce(b)) - .unwrap() + .map(move |t| t.output.lock().unwrap().take().unwrap()) } } -impl Drop for ThreadPool<'_, I, R> { +impl Drop for ThreadPool<'_, Output> { /// Joins all the threads in the pool. fn drop(&mut self) { debug!("[main thread] Notifying threads to finish..."); @@ -343,8 +326,30 @@ impl Drop for ThreadPool<'_, I, R> { } } +/// Trait representing a function to map and reduce inputs into an output. +pub trait ThreadAccumulator { + /// Type to accumulate inputs into. + type Accumulator<'a> + where + Self: 'a; + + /// Creates a new accumulator to process inputs. + fn init(&self) -> Self::Accumulator<'_>; + + /// Accumulates the given input item. + fn process_item<'a>( + &'a self, + accumulator: &mut Self::Accumulator<'a>, + index: usize, + item: &Input, + ); + + /// Converts the given accumulator into an output. + fn finalize<'a>(&'a self, accumulator: Self::Accumulator<'a>) -> Output; +} + /// Context object owned by a worker thread. -struct ThreadContext<'e, I, R, Rn: Range> { +struct ThreadContext<'env, Rn: Range, Input, Output, Accum: ThreadAccumulator> { /// Thread index. id: usize, /// Number of worker threads active in the current round. @@ -353,24 +358,18 @@ struct ThreadContext<'e, I, R, Rn: Range> { worker_status: Arc>, /// Status of the main thread. main_status: Arc>, - /// Election input. - election: &'e Election, - /// Pre-computed Pascal triangle. - pascal: Option<&'e [Vec]>, - /// Keep factors used in the current round. - keep_factors: Arc>>, - /// Range of ballots that this worker thread needs to count. + /// Range of items that this worker thread needs to process. range: Rn, - /// Storage for the votes accumulated by this thread. - output: Arc>>>, + /// Reference to the inputs to process. + input: &'env [Input], + /// Output that this thread writes to. + output: Arc>>, + /// Function to map and reduce inputs into the output. + accumulator: Accum, } -impl ThreadContext<'_, I, R, Rn> -where - I: Integer, - for<'a> &'a I: IntegerRef, - R: Rational, - for<'a> &'a R: RationalRef<&'a I, R>, +impl> + ThreadContext<'_, Rn, Input, Output, Accum> { /// Main function run by this thread. fn run(&self) { @@ -408,7 +407,14 @@ where id: self.id, main_status: &self.main_status, }; - self.count_votes(); + { + let mut accumulator = self.accumulator.init(); + for i in self.range.iter() { + self.accumulator + .process_item(&mut accumulator, i, &self.input[i]); + } + *self.output.lock().unwrap() = Some(self.accumulator.finalize(accumulator)); + } std::mem::forget(panic_notifier); let thread_count = self.num_active_threads.fetch_sub(1, Ordering::SeqCst); @@ -444,26 +450,6 @@ where } } } - - /// Computes a vote counting round. - fn count_votes(&self) { - let mut guard = self.output.lock().unwrap(); - let vote_accumulator: &mut VoteAccumulator = guard - .deref_mut() - .insert(VoteAccumulator::new(self.election.num_candidates)); - let keep_factors = self.keep_factors.read().unwrap(); - - for i in self.range.iter() { - let ballot = &self.election.ballots[i]; - VoteCount::::process_ballot( - vote_accumulator, - &keep_factors, - self.pascal, - i, - ballot, - ); - } - } } /// Object whose destructor notifies the main thread that a panic happened. diff --git a/src/vote_count.rs b/src/vote_count.rs index 82ddd09..d56f576 100644 --- a/src/vote_count.rs +++ b/src/vote_count.rs @@ -17,13 +17,16 @@ use crate::arithmetic::{Integer, IntegerRef, Rational, RationalRef}; use crate::cli::Parallel; -use crate::parallelism::ThreadPool; +use crate::parallelism::{RangeStrategy, ThreadAccumulator, ThreadPool}; use crate::types::{Ballot, Election}; use log::Level::{Trace, Warn}; use log::{debug, log_enabled, trace, warn}; use rayon::prelude::*; use std::io; use std::marker::PhantomData; +use std::num::NonZeroUsize; +use std::sync::{Arc, RwLock, RwLockReadGuard}; +use std::thread::Scope; /// Result of a vote count. #[cfg_attr(test, derive(Debug, PartialEq))] @@ -96,6 +99,116 @@ where } } +/// A thread pool tied to a scope, that can perform vote counting rounds. +pub struct VoteCountingThreadPool<'scope, I, R> { + /// Inner thread pool. + pool: ThreadPool<'scope, VoteAccumulator>, + /// Storage for the keep factors, used as input of the current round by the + /// worker threads. + keep_factors: Arc>>, +} + +impl<'scope, I, R> VoteCountingThreadPool<'scope, I, R> +where + I: Integer + Send + Sync + 'scope, + for<'a> &'a I: IntegerRef, + R: Rational + Send + Sync + 'scope, + for<'a> &'a R: RationalRef<&'a I, R>, +{ + /// Creates a new pool tied to the given scope, with the given number of + /// threads and references to the necessary election inputs. + pub fn new<'env>( + thread_scope: &'scope Scope<'scope, 'env>, + num_threads: NonZeroUsize, + range_strategy: RangeStrategy, + election: &'env Election, + pascal: Option<&'env [Vec]>, + ) -> Self { + let keep_factors = Arc::new(RwLock::new(Vec::new())); + Self { + pool: ThreadPool::new( + thread_scope, + num_threads, + range_strategy, + &election.ballots, + || ThreadVoteCounter { + num_candidates: election.num_candidates, + pascal, + keep_factors: keep_factors.clone(), + }, + ), + keep_factors, + } + } + + /// Accumulates votes from the election ballots based on the given keep + /// factors. + pub fn accumulate_votes(&self, keep_factors: &[R]) -> VoteAccumulator { + { + let mut keep_factors_guard = self.keep_factors.write().unwrap(); + keep_factors_guard.clear(); + keep_factors_guard.extend_from_slice(keep_factors); + } + + self.pool + .process_inputs() + .reduce(|a, b| a.reduce(b)) + .unwrap() + } +} + +/// Helper state to accumulate votes in a worker thread. +struct ThreadVoteCounter<'env, I, R> { + /// Number of candidates in the election. + num_candidates: usize, + /// Pre-computed Pascal triangle. + pascal: Option<&'env [Vec]>, + /// Keep factors used in the current round. + keep_factors: Arc>>, +} + +impl ThreadAccumulator> for ThreadVoteCounter<'_, I, R> +where + I: Integer, + for<'a> &'a I: IntegerRef, + R: Rational, + for<'a> &'a R: RationalRef<&'a I, R>, +{ + type Accumulator<'a> + = (VoteAccumulator, RwLockReadGuard<'a, Vec>) + where + Self: 'a, + I: 'a, + R: 'a; + + fn init(&self) -> Self::Accumulator<'_> { + ( + VoteAccumulator::new(self.num_candidates), + self.keep_factors.read().unwrap(), + ) + } + + fn process_item<'a>( + &'a self, + accumulator: &mut Self::Accumulator<'a>, + index: usize, + ballot: &Ballot, + ) { + let (vote_accumulator, keep_factors) = accumulator; + VoteCount::::process_ballot( + vote_accumulator, + keep_factors, + self.pascal, + index, + ballot, + ); + } + + fn finalize<'a>(&'a self, accumulator: Self::Accumulator<'a>) -> VoteAccumulator { + accumulator.0 + } +} + impl VoteCount where I: Integer + Send + Sync, @@ -108,7 +221,7 @@ where election: &Election, keep_factors: &[R], parallel: Parallel, - thread_pool: Option<&ThreadPool<'_, I, R>>, + thread_pool: Option<&VoteCountingThreadPool<'_, I, R>>, pascal: Option<&[Vec]>, ) -> Self { let vote_accumulator = match parallel { @@ -690,7 +803,6 @@ where mod test { use super::*; use crate::arithmetic::{ApproxRational, BigFixedDecimal9, FixedDecimal9, Integer64}; - use crate::parallelism::RangeStrategy; use crate::types::Candidate; use ::test::Bencher; use num::rational::Ratio; @@ -698,7 +810,6 @@ mod test { use std::borrow::Borrow; use std::fmt::{Debug, Display}; use std::hint::black_box; - use std::num::NonZeroUsize; macro_rules! numeric_tests { ( $typei:ty, $typer:ty, ) => {}; @@ -1107,7 +1218,7 @@ mod test { for num_threads in 1..=10 { std::thread::scope(|thread_scope| { - let thread_pool = ThreadPool::new( + let thread_pool = VoteCountingThreadPool::new( thread_scope, NonZeroUsize::new(num_threads).unwrap(), RangeStrategy::WorkStealing, diff --git a/testdata/meek/bigint/ballot_sum_overflows.bigfixed9.err b/testdata/meek/bigint/ballot_sum_overflows.bigfixed9.err index b2cd26b..1bd850e 100644 --- a/testdata/meek/bigint/ballot_sum_overflows.bigfixed9.err +++ b/testdata/meek/bigint/ballot_sum_overflows.bigfixed9.err @@ -1,3 +1,3 @@ -thread 'main' panicked at /home/runner/work/stv-rs/stv-rs/src/meek.rs:574:25: +thread 'main' panicked at /home/runner/work/stv-rs/stv-rs/src/meek.rs:576:25: assertion failed: self.to_elect > 0 note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace diff --git a/testdata/meek/bigint/ballot_sum_overflows.fixed9.custom.err b/testdata/meek/bigint/ballot_sum_overflows.fixed9.custom.err index f8487ff..c01be30 100644 --- a/testdata/meek/bigint/ballot_sum_overflows.fixed9.custom.err +++ b/testdata/meek/bigint/ballot_sum_overflows.fixed9.custom.err @@ -3,6 +3,6 @@ called `Option::unwrap()` on a `None` value note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace [ERROR stv_rs::parallelism::thread_pool] [thread 0] Detected panic in this thread, notifying the main thread [ERROR stv_rs::parallelism::thread_pool] [main thread] A worker thread panicked! -thread 'main' panicked at /home/runner/work/stv-rs/stv-rs/src/parallelism/thread_pool.rs:310:13: +thread 'main' panicked at /home/runner/work/stv-rs/stv-rs/src/parallelism/thread_pool.rs:295:13: A worker thread panicked! [ERROR stv_rs::parallelism::thread_pool] [main thread] Thread 0 joined with result: Err(Any { .. })