Skip to content

Commit

Permalink
Alternative folding scheme for bound_poly_vars_rq
Browse files Browse the repository at this point in the history
  • Loading branch information
darth-cy committed Dec 27, 2024
1 parent 142134d commit 32a2d7c
Showing 1 changed file with 68 additions and 26 deletions.
94 changes: 68 additions & 26 deletions spartan_parallel/src/custom_dense_mlpoly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ const MODE_P: usize = 1;
const MODE_Q: usize = 2;
const MODE_W: usize = 3;
const MODE_X: usize = 4;
const NUM_MULTI_THREAD_CORES: usize = 8;

// Customized Dense ML Polynomials for Data-Parallelism
// These Dense ML Polys are aimed for space-efficiency by removing the 0s for invalid (p, q, w, x) quadruple
Expand Down Expand Up @@ -209,47 +210,88 @@ impl<S: SpartanExtensionField> DensePolynomialPqx<S> {
pub fn bound_poly_vars_rq(&mut self,
r_q: &[S],
) {
let ONE = S::field_one();
let num_instances = min(self.num_instances, self.Z.len());

fn recur<S: SpartanExtensionField>(idx: usize, lvl: usize, w: usize, x: usize, env: &(
&Vec<Vec<Vec<S>>>, // self.Z[p]
&[S], // r_q
&usize, // start_idx
)) -> S {
if lvl > 0 {
(S::field_one() - env.1[lvl]) * recur(2 * idx, lvl - 1, w, x, env) + env.1[lvl] * recur(2 * idx + 1, lvl - 1, w, x, env)
} else {
env.0[env.2 + idx][w][x]
}
}

self.Z = (0..num_instances)
.into_par_iter()
// .into_par_iter()
.map(|p| {
let num_proofs = self.num_proofs[p];
let dist_size = num_proofs / min(num_proofs, NUM_MULTI_THREAD_CORES);
let num_threads = num_proofs / dist_size;

// To perform rigorous 2-fold parallelism, both num_proofs and # cores must be powers of 2
// # cores must fully divide num_proofs for even distribution
assert!(num_proofs & (num_proofs - 1) == 0);
assert!(num_threads & (num_threads - 1) == 0);

// debug_parallelism
println!("num_proofs: {:?}, num_threads: {:?}", num_proofs, num_threads);

// Determine the aggregation levels that will be performed in parallel
// The last rounds of aggregation will be done in single core
let levels = num_proofs.trailing_zeros() as usize;
let sub_levels = dist_size.trailing_zeros() as usize;
let final_levels = num_threads.trailing_zeros() as usize;
let left_over_q_len = r_q.len() - levels;

// debug_parallelism
println!("levels: {:?}, sub_levels: {:?}, final_levels: {:?}, left_over_q_len: {:?}", levels, sub_levels, final_levels, left_over_q_len);

let num_witness_secs = min(self.num_witness_secs, self.Z[p][0].len());
let num_inputs = self.num_inputs[p];

let wit = (0..num_witness_secs).into_par_iter().map(|w| {
(0..num_inputs).into_par_iter().map(|x| {
let mut np = num_proofs;
let mut x_fold = (0..num_proofs).map(|q| self.Z[p][q][w][x]).collect::<Vec<S>>();
for r in r_q {
if np == 1 {
x_fold[0] *= ONE - *r;
} else {
np /= 2;
for q in 0..np {
x_fold[q] = x_fold[2 * q] + *r * (x_fold[2 * q + 1] - x_fold[2 * q]);
}
}
}

x_fold
}).collect::<Vec<Vec<S>>>()
}).collect::<Vec<Vec<Vec<S>>>>();
// debug_parallelism
println!("num_witness_secs: {:?}, num_inputs: {:?}", num_witness_secs, num_inputs);

(0..num_proofs)
let mut sub_mats = if sub_levels > 0 {
std::iter::successors(Some(0usize), move |&x| Some(x + dist_size))
.take(NUM_MULTI_THREAD_CORES)
.collect::<Vec<usize>>()
.into_par_iter()
.map(|q| {
(0..wit.len()).map(|w| {
(0..wit[w].len()).map(|x| {
wit[w][x][q]
.map(|start_idx| {
(0..num_witness_secs).map(|w| {
(0..num_inputs).map(|x| {
recur(0, sub_levels, w, x, &(&self.Z[p], r_q, &start_idx))
}).collect::<Vec<S>>()
}).collect::<Vec<Vec<S>>>()
}).collect::<Vec<Vec<Vec<S>>>>()
} else {
self.Z[p].clone()
};

if final_levels > 0 {
sub_mats[0] = (0..num_witness_secs).map(|w| {
(0..num_inputs).map(|x| {
recur(0, final_levels, w, x, &(&sub_mats, r_q, &0))
}).collect::<Vec<S>>()
}).collect::<Vec<Vec<S>>>()
}

if left_over_q_len > 0 {
let c = r_q[(r_q.len() - left_over_q_len)..r_q.len()].iter().fold(S::field_one(), |acc, n| acc * (S::field_one() - *n));
for w in 0..sub_mats[0].len() {
for x in 0..sub_mats[0][0].len() {
sub_mats[0][w][x] *= c;
}
}
}

sub_mats
}).collect::<Vec<Vec<Vec<Vec<S>>>>>();

self.max_num_proofs /= 2usize.pow(r_q.len() as u32);
self.max_num_proofs /= 2usize.pow(r_q.len() as u32);
}

// Bound the entire "w" section to r_w in reverse
Expand Down

0 comments on commit 32a2d7c

Please sign in to comment.