From d6f562a81ad319b6da88bda73d37967ab46cd5f7 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Sun, 25 Dec 2022 23:23:27 -0500 Subject: [PATCH] feat: parallelize (cpu) shplonk prover --- halo2_proofs/src/poly/kzg/multiopen/shplonk.rs | 8 ++++---- .../src/poly/kzg/multiopen/shplonk/prover.rs | 15 +++++++++++---- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/halo2_proofs/src/poly/kzg/multiopen/shplonk.rs b/halo2_proofs/src/poly/kzg/multiopen/shplonk.rs index 125936229e..459c4a7e1b 100644 --- a/halo2_proofs/src/poly/kzg/multiopen/shplonk.rs +++ b/halo2_proofs/src/poly/kzg/multiopen/shplonk.rs @@ -80,7 +80,7 @@ where assert_eq!(*point, query.get_point()); } // All points appear in queries - let super_point_set: Vec = rotation_point_map.values().cloned().collect(); + let super_point_set: Vec = rotation_point_map.values().copied().collect(); // Collect rotation sets for each commitment // Example elements in the vector: @@ -90,7 +90,7 @@ where // (C_3, {r_2, r_3, r_4}), // ... let mut commitment_rotation_set_map: Vec<(Q::Commitment, Vec)> = vec![]; - for query in queries.clone() { + for query in queries.iter() { let rotation = query.get_point(); if let Some(pos) = commitment_rotation_set_map .iter() @@ -114,8 +114,7 @@ where let mut rotation_set_commitment_map = Vec::<(Vec<_>, Vec)>::new(); for (commitment, rotation_set) in commitment_rotation_set_map.iter() { if let Some(pos) = rotation_set_commitment_map.iter().position(|(set, _)| { - BTreeSet::::from_iter(set.iter().cloned()) - == BTreeSet::::from_iter(rotation_set.iter().cloned()) + BTreeSet::<&F>::from_iter(set.iter()) == BTreeSet::<&F>::from_iter(rotation_set.iter()) }) { let (_, commitments) = &mut rotation_set_commitment_map[pos]; if !commitments.contains(commitment) { @@ -126,6 +125,7 @@ where } } + // TODO: parallelize let rotation_sets = rotation_set_commitment_map .into_iter() .map(|(rotations, commitments)| { diff --git a/halo2_proofs/src/poly/kzg/multiopen/shplonk/prover.rs b/halo2_proofs/src/poly/kzg/multiopen/shplonk/prover.rs index 2585d9ab69..46626a4ab4 100644 --- a/halo2_proofs/src/poly/kzg/multiopen/shplonk/prover.rs +++ b/halo2_proofs/src/poly/kzg/multiopen/shplonk/prover.rs @@ -17,6 +17,8 @@ use ff::Field; use group::Curve; use halo2curves::pairing::Engine; use rand_core::RngCore; +use rayon::iter::ParallelIterator; +use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator}; use std::fmt::Debug; use std::io::{self, Write}; use std::marker::PhantomData; @@ -171,11 +173,11 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme> ); let rotation_sets: Vec> = rotation_sets - .iter() + .par_iter() .map(|rotation_set| { let commitments: Vec> = rotation_set .commitments - .iter() + .par_iter() .map(|commitment_data| commitment_data.extend(rotation_set.points.clone())) .collect(); rotation_set.extend(commitments) @@ -184,9 +186,13 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme> let v: ChallengeV<_> = transcript.squeeze_challenge_scalar(); - let quotient_polynomials = rotation_sets.iter().map(quotient_contribution); + let quotient_polynomials = rotation_sets + .par_iter() + .map(quotient_contribution) + .collect::>(); let h_x: Polynomial = quotient_polynomials + .into_iter() .zip(powers(*v)) .map(|(poly, power_of_v)| poly * power_of_v) .reduce(|acc, poly| acc + &poly) @@ -235,7 +241,7 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme> Vec>, Vec, ) = rotation_sets - .into_iter() + .into_par_iter() .map(linearisation_contribution) .unzip(); @@ -249,6 +255,7 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme> let l_x = l_x - &(h_x * zt_eval); // sanity check + #[cfg(debug_assertions)] { let must_be_zero = eval_polynomial(&l_x.values[..], *u); assert_eq!(must_be_zero, E::Scalar::zero());