From 1f453c9778568ee0e1d285a65fb30adfe6c07d3a Mon Sep 17 00:00:00 2001 From: Artem Storozhuk Date: Tue, 20 Feb 2024 16:16:53 +0200 Subject: [PATCH] Improving Shplonk implementation (#326) * feat: Avoid redundant operations with constant polynomial * feat: Avoid redundant clones * feat: Parallelized verifier * chore: Requested changes --- src/provider/shplonk.rs | 190 +++++++++++++++++++++++++++------------- src/spartan/mod.rs | 2 +- 2 files changed, 131 insertions(+), 61 deletions(-) diff --git a/src/provider/shplonk.rs b/src/provider/shplonk.rs index 73ad55a21..7ebb92265 100644 --- a/src/provider/shplonk.rs +++ b/src/provider/shplonk.rs @@ -12,11 +12,15 @@ use crate::{CommitmentEngineTrait, NovaError}; use ff::{Field, PrimeFieldBits}; use group::{Curve, Group as group_Group}; use pairing::{Engine, MillerLoopResult, MultiMillerLoop}; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator, +}; use rayon::prelude::*; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::marker::PhantomData; use crate::provider::hyperkzg::EvaluationEngine as HyperKZG; +use crate::spartan::math::Math; use group::prime::PrimeCurveAffine; use itertools::Itertools; use ref_cast::RefCast as _; @@ -60,7 +64,7 @@ where transcript.squeeze(b"a").unwrap() } - fn compute_pi_polynomials(hat_P: &[E::Fr], point: &[E::Fr], eval: &E::Fr) -> Vec> { + fn compute_pi_polynomials(hat_P: &[E::Fr], point: &[E::Fr]) -> Vec> { let mut polys: Vec> = Vec::new(); polys.push(hat_P.to_vec()); @@ -78,26 +82,20 @@ where polys.push(Pi); } - // TODO avoid including last constant polynomial, known to verifier - polys.push(vec![*eval]); - - assert_eq!(polys.len(), 1 + (hat_P.len() as f32).log2().ceil() as usize); + assert_eq!(polys.len(), hat_P.len().log_2()); polys } fn compute_commitments( ck: &UniversalKZGParam, - C: &Commitment, + _C: &Commitment, polys: &[Vec], ) -> Vec { - // TODO avoid computing commitment to constant polynomial - let mut comms: Vec = (1..polys.len()) + let comms: Vec = (1..polys.len()) .into_par_iter() .map(|i| >::commit(ck, &polys[i]).comm) .collect(); - // TODO avoid inserting commitment known to verifier - comms.insert(0, C.comm); let mut comms_affine: Vec = vec![E::G1Affine::identity(); comms.len()]; NE::GE::batch_normalize(&comms, &mut comms_affine); @@ -169,7 +167,7 @@ where C: &Commitment, hat_P: &[E::Fr], point: &[E::Fr], - eval: &E::Fr, + _eval: &E::Fr, ) -> Result, NovaError> { let x: Vec = point.to_vec(); let ell = x.len(); @@ -177,7 +175,7 @@ where assert_eq!(n, 1 << ell); // Phase 1 (similar to hyperkzg) - let polys = Self::compute_pi_polynomials(hat_P, point, eval); + let polys = Self::compute_pi_polynomials(hat_P, point); let comms = Self::compute_commitments(ck, C, &polys); // Phase 2 (similar to hyperkzg) @@ -226,9 +224,9 @@ where fn verify( vk: &KZGVerifierKey, transcript: &mut ::TE, - _C: &Commitment, + C: &Commitment, point: &[E::Fr], - _P_of_x: &E::Fr, + P_of_x: &E::Fr, pi: &EvaluationArgument, ) -> Result<(), NovaError> { let r = HyperKZG::::compute_challenge(&pi.comms, transcript); @@ -241,39 +239,22 @@ where return Err(NovaError::ProofVerifyError); } - // TODO: - // insert _P_of_x into every pi.evals_i[last] - // insert _C into pi.comms[0] - // compute commitment for eval and insert it into pi.comms[last] + let mut comms = pi.comms.to_vec(); + comms.insert(0, C.comm.to_affine()); let q = HyperKZG::::get_batch_challenge(&pi.evals, transcript); - //let q_powers = HyperKZG::::batch_challenge_powers(q, pi.comms.len()); - let R_x = UniPoly::new(pi.R_x.clone()); - let mut evals_at_r = vec![]; - let mut evals_at_minus_r = vec![]; - let mut evals_at_r_squared = vec![]; - for (i, evals_i) in pi.evals.iter().enumerate() { - if i == 0 { - evals_at_r = evals_i.clone(); - } - if i == 1 { - evals_at_minus_r = evals_i.clone(); - } - if i == 2 { - evals_at_r_squared = evals_i.clone(); - } - - let batched_eval = UniPoly::ref_cast(evals_i).evaluate(&q); - + let verification_failed = pi.evals.iter().zip_eq(u.iter()).any(|(evals_i, u_i)| { // here we check correlation between R polynomial and batched evals, e.g.: // 1) R(r) == eval at r // 2) R(-r) == eval at -r // 3) R(r^2) == eval at r^2 - if batched_eval != R_x.evaluate(&u[i]) { - return Err(NovaError::ProofVerifyError); - } + let batched_eval = UniPoly::ref_cast(evals_i).evaluate(&q); + batched_eval != R_x.evaluate(u_i) + }); + if verification_failed { + return Err(NovaError::ProofVerifyError); } // here we check that Pi polynomials were correctly constructed by the prover, using 'r' as a random point, e.g: @@ -282,23 +263,33 @@ where // P_i+1(r^2) == (1 - point_i) * P_i_even + point_i * P_i_odd -> should hold, according to Gemini transformation let mut point = point.to_vec(); point.reverse(); + + let r_mul_2 = E::Fr::from(2) * r; #[allow(clippy::disallowed_methods)] - for (index, ((eval_r, eval_minus_r), eval_r_squared)) in evals_at_r - .iter() - .zip_eq(evals_at_minus_r.iter()) - // TODO: Ask Adrian if we need evals_at_r_squared[0] for some additional checks - .zip(evals_at_r_squared[1..].iter()) + let verification_failed = pi.evals[0] + .par_iter() + .chain(&[*P_of_x]) + .zip_eq(pi.evals[1].par_iter().chain(&[*P_of_x])) + .zip(pi.evals[2][1..].par_iter().chain(&[*P_of_x])) .enumerate() - { - let even = (*eval_r + eval_minus_r) * (E::Fr::from(2).invert().unwrap()); - let odd = (*eval_r - eval_minus_r) * ((E::Fr::from(2) * r).invert().unwrap()); + .any(|(index, ((eval_r, eval_minus_r), eval_r_squared))| { + // some optimisation to avoid using expensive inversions: + // P_i+1(r^2) == (1 - point_i) * (P_i(r) + P_i(-r)) * 1/2 + point_i * (P_i(r) - P_i(-r)) * 1/2 * r + // is equivalent to: + // 2 * r * P_i+1(r^2) == r * (1 - point_i) * (P_i(r) + P_i(-r)) + point_i * (P_i(r) - P_i(-r)) + + let even = *eval_r + eval_minus_r; + let odd = *eval_r - eval_minus_r; + let right = r * ((E::Fr::ONE - point[index]) * even) + (point[index] * odd); + let left = *eval_r_squared * r_mul_2; + left != right + }); - if *eval_r_squared != ((E::Fr::ONE - point[index]) * even) + (point[index] * odd) { - return Err(NovaError::ProofVerifyError); - } + if verification_failed { + return Err(NovaError::ProofVerifyError); } - let C_P: E::G1 = pi.comms.par_iter().map(|comm| comm.to_curve()).rlc(&q); + let C_P: E::G1 = comms.par_iter().map(|comm| comm.to_curve()).rlc(&q); let C_Q = pi.C_Q; let C_H = pi.C_H; let r_squared = u[2]; @@ -352,10 +343,11 @@ mod tests { C: &Commitment, poly: &[Fr], point: &[Fr], - eval: &Fr, + _eval: &Fr, ) { - let polys = EvaluationEngine::::compute_pi_polynomials(poly, point, eval); - let comms = EvaluationEngine::::compute_commitments(ck, C, &polys); + let polys = EvaluationEngine::::compute_pi_polynomials(poly, point); + let mut comms = EvaluationEngine::::compute_commitments(ck, C, &polys); + comms.insert(0, C.comm.to_affine()); let q = Fr::from(8165763); let q_powers = HyperKZG::::batch_challenge_powers(q, polys.len()); @@ -404,8 +396,8 @@ mod tests { assert_eq!(C_K_expected, C_K.to_affine()); } - fn test_k_polynomial_correctness(poly: &[Fr], point: &[Fr], eval: &Fr) { - let polys = EvaluationEngine::::compute_pi_polynomials(poly, point, eval); + fn test_k_polynomial_correctness(poly: &[Fr], point: &[Fr], _eval: &Fr) { + let polys = EvaluationEngine::::compute_pi_polynomials(poly, point); let q = Fr::from(8165763); let batched_Pi: UniPoly = polys.clone().into_iter().map(UniPoly::new).rlc(&q); @@ -428,8 +420,8 @@ mod tests { assert_eq!(Fr::from(0), K_x.evaluate(&a)); } - fn test_d_polynomial_correctness(poly: &[Fr], point: &[Fr], eval: &Fr) { - let polys = EvaluationEngine::::compute_pi_polynomials(poly, point, eval); + fn test_d_polynomial_correctness(poly: &[Fr], point: &[Fr], _eval: &Fr) { + let polys = EvaluationEngine::::compute_pi_polynomials(poly, point); let q = Fr::from(8165763); let batched_Pi: UniPoly = polys.clone().into_iter().map(UniPoly::new).rlc(&q); @@ -471,8 +463,8 @@ mod tests { assert_eq!(Q_x, Q_x_recomputed); } - fn test_batching_property_on_evaluation(poly: &[Fr], point: &[Fr], eval: &Fr) { - let polys = EvaluationEngine::::compute_pi_polynomials(poly, point, eval); + fn test_batching_property_on_evaluation(poly: &[Fr], point: &[Fr], _eval: &Fr) { + let polys = EvaluationEngine::::compute_pi_polynomials(poly, point); let q = Fr::from(97652); let u = [Fr::from(10), Fr::from(20), Fr::from(50)]; @@ -648,4 +640,82 @@ mod tests { ) .is_err()); } + + #[test] + fn test_shplonk_pcs_negative_wrong_commitment() { + let n = 8; + // poly = [1, 2, 1, 4, 1, 2, 1, 4] + let poly = vec![ + Fr::ONE, + Fr::from(2), + Fr::from(1), + Fr::from(4), + Fr::ONE, + Fr::from(2), + Fr::from(1), + Fr::from(4), + ]; + // point = [4,3,8] + let point = vec![Fr::from(4), Fr::from(3), Fr::from(8)]; + // eval = 57 + let eval = Fr::from(57); + + // altered_poly = [85, 84, 83, 82, 81, 80, 79, 78] + let altered_poly = vec![ + Fr::from(85), + Fr::from(84), + Fr::from(83), + Fr::from(82), + Fr::from(81), + Fr::from(80), + Fr::from(79), + Fr::from(78), + ]; + + let ck: CommitmentKey = + as CommitmentEngineTrait>::setup(b"test", n); + + let C1: Commitment = KZGCommitmentEngine::commit(&ck, &poly); // correct commitment + let C2: Commitment = KZGCommitmentEngine::commit(&ck, &altered_poly); // wrong commitment + + test_negative_inner_commitment(&poly, &point, &eval, &ck, &C1, &C2); // here we check detection when proof and commitment do not correspond + test_negative_inner_commitment(&poly, &point, &eval, &ck, &C2, &C2); // here we check detection when proof was built with wrong commitment + } + + fn test_negative_inner_commitment( + poly: &[Fr], + point: &[Fr], + eval: &Fr, + ck: &CommitmentKey, + C_prover: &Commitment, + C_verifier: &Commitment, + ) { + let ck = Arc::new(ck.clone()); + let (pk, vk): (KZGProverKey, KZGVerifierKey) = + EvaluationEngine::::setup(ck.clone()); + + let mut prover_transcript = Keccak256Transcript::new(b"TestEval"); + let mut verifier_transcript = Keccak256Transcript::::new(b"TestEval"); + + let proof = EvaluationEngine::::prove( + &ck, + &pk, + &mut prover_transcript, + C_prover, + poly, + point, + eval, + ) + .unwrap(); + + assert!(EvaluationEngine::::verify( + &vk, + &mut verifier_transcript, + C_verifier, + point, + eval, + &proof + ) + .is_err()); + } } diff --git a/src/spartan/mod.rs b/src/spartan/mod.rs index 50cbd6a38..9b38adb24 100644 --- a/src/spartan/mod.rs +++ b/src/spartan/mod.rs @@ -10,7 +10,7 @@ pub mod batched; pub mod batched_ppsnark; #[macro_use] mod macros; -mod math; +pub(crate) mod math; pub mod polys; pub mod ppsnark; pub mod snark;