Skip to content

Commit

Permalink
Make the custom thread pool generic over the input and output types.
Browse files Browse the repository at this point in the history
  • Loading branch information
gendx committed Dec 5, 2024
1 parent e48722d commit 00db2ae
Show file tree
Hide file tree
Showing 6 changed files with 285 additions and 109 deletions.
12 changes: 7 additions & 5 deletions src/meek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
use crate::arithmetic::{Integer, IntegerRef, Rational, RationalRef};
use crate::cli::Parallel;
use crate::parallelism::{RangeStrategy, ThreadPool};
use crate::parallelism::RangeStrategy;
use crate::types::{Election, ElectionResult};
use crate::vote_count::{self, VoteCount};
use crate::vote_count::{self, VoteCount, VoteCountingThreadPool};
use log::{debug, info, log_enabled, warn, Level};
use std::fmt::{self, Debug, Display};
use std::io;
Expand Down Expand Up @@ -111,7 +111,7 @@ where
.expect("A positive number of threads must be spawned, but the available parallelism is zero threads")
});
info!("Spawning {num_threads} threads");
let thread_pool = ThreadPool::new(
let thread_pool = VoteCountingThreadPool::new(
thread_scope,
num_threads,
if disable_work_stealing {
Expand Down Expand Up @@ -200,7 +200,9 @@ pub struct State<'scope, 'e, I, R> {
/// Pre-computed Pascal triangle. Set only if the "equalized counting" is
/// enabled.
pascal: Option<&'e [Vec<I>]>,
thread_pool: Option<ThreadPool<'scope, I, R>>,
/// Thread pool where to schedule vote counting. Set only if the parallel
/// mode is set to "custom".
thread_pool: Option<VoteCountingThreadPool<'scope, I, R>>,
_phantom: PhantomData<I>,
}

Expand All @@ -224,7 +226,7 @@ where
parallel: Parallel,
force_positive_surplus: bool,
pascal: Option<&'a [Vec<I>]>,
thread_pool: Option<ThreadPool<'scope, I, R>>,
thread_pool: Option<VoteCountingThreadPool<'scope, I, R>>,
) -> State<'scope, 'a, I, R> {
State {
election,
Expand Down
79 changes: 78 additions & 1 deletion src/parallelism/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,81 @@
mod range;
mod thread_pool;

pub use thread_pool::{RangeStrategy, ThreadPool};
pub use thread_pool::{RangeStrategy, ThreadAccumulator, ThreadPool};

#[cfg(test)]
mod test {
use super::*;
use std::num::NonZeroUsize;

/// Example of accumulator that computes a sum of integers.
struct SumAccumulator;

impl ThreadAccumulator<u64, u64> for SumAccumulator {
type Accumulator<'a> = u64;

fn init(&self) -> u64 {
0
}

fn process_item(&self, accumulator: &mut u64, _index: usize, x: &u64) {
*accumulator += *x;
}

fn finalize(&self, accumulator: u64) -> u64 {
accumulator
}
}

macro_rules! parallelism_tests {
( $mod:ident, $range_strategy:expr, $($case:ident,)+ ) => {
mod $mod {
use super::*;

$(
#[test]
fn $case() {
$crate::parallelism::test::$case($range_strategy);
}
)+
}
};
}

macro_rules! all_parallelism_tests {
( $mod:ident, $range_strategy:expr ) => {
parallelism_tests!($mod, $range_strategy, test_sum_integers, test_sum_twice,);
};
}

all_parallelism_tests!(fixed, RangeStrategy::Fixed);
all_parallelism_tests!(work_stealing, RangeStrategy::WorkStealing);

fn test_sum_integers(range_strategy: RangeStrategy) {
let input = (0..=10_000).collect::<Vec<u64>>();
let num_threads = NonZeroUsize::try_from(4).unwrap();
let sum = std::thread::scope(|scope| {
let thread_pool = ThreadPool::new(scope, num_threads, range_strategy, &input, || {
SumAccumulator
});
thread_pool.process_inputs().reduce(|a, b| a + b).unwrap()
});
assert_eq!(sum, 5_000 * 10_001);
}

fn test_sum_twice(range_strategy: RangeStrategy) {
let input = (0..=10_000).collect::<Vec<u64>>();
let num_threads = NonZeroUsize::try_from(4).unwrap();
let (sum1, sum2) = std::thread::scope(|scope| {
let thread_pool = ThreadPool::new(scope, num_threads, range_strategy, &input, || {
SumAccumulator
});
// The same input can be processed multiple times on the thread pool.
let sum1 = thread_pool.process_inputs().reduce(|a, b| a + b).unwrap();
let sum2 = thread_pool.process_inputs().reduce(|a, b| a + b).unwrap();
(sum1, sum2)
});
assert_eq!(sum1, 5_000 * 10_001);
assert_eq!(sum2, 5_000 * 10_001);
}
}
Loading

0 comments on commit 00db2ae

Please sign in to comment.