Skip to content

Commit

Permalink
Improving Shplonk implementation (#326)
Browse files Browse the repository at this point in the history
* feat: Avoid redundant operations with constant polynomial

* feat: Avoid redundant clones

* feat: Parallelized verifier

* chore: Requested changes
  • Loading branch information
storojs72 committed Feb 20, 2024
1 parent 07d9b1e commit 1f453c9
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 61 deletions.
190 changes: 130 additions & 60 deletions src/provider/shplonk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@ use crate::{CommitmentEngineTrait, NovaError};
use ff::{Field, PrimeFieldBits};
use group::{Curve, Group as group_Group};
use pairing::{Engine, MillerLoopResult, MultiMillerLoop};
use rayon::iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
};
use rayon::prelude::*;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::marker::PhantomData;

use crate::provider::hyperkzg::EvaluationEngine as HyperKZG;
use crate::spartan::math::Math;
use group::prime::PrimeCurveAffine;
use itertools::Itertools;
use ref_cast::RefCast as _;
Expand Down Expand Up @@ -60,7 +64,7 @@ where
transcript.squeeze(b"a").unwrap()
}

fn compute_pi_polynomials(hat_P: &[E::Fr], point: &[E::Fr], eval: &E::Fr) -> Vec<Vec<E::Fr>> {
fn compute_pi_polynomials(hat_P: &[E::Fr], point: &[E::Fr]) -> Vec<Vec<E::Fr>> {
let mut polys: Vec<Vec<E::Fr>> = Vec::new();
polys.push(hat_P.to_vec());

Expand All @@ -78,26 +82,20 @@ where
polys.push(Pi);
}

// TODO avoid including last constant polynomial, known to verifier
polys.push(vec![*eval]);

assert_eq!(polys.len(), 1 + (hat_P.len() as f32).log2().ceil() as usize);
assert_eq!(polys.len(), hat_P.len().log_2());

polys
}

fn compute_commitments(
ck: &UniversalKZGParam<E>,
C: &Commitment<NE>,
_C: &Commitment<NE>,
polys: &[Vec<E::Fr>],
) -> Vec<E::G1Affine> {
// TODO avoid computing commitment to constant polynomial
let mut comms: Vec<NE::GE> = (1..polys.len())
let comms: Vec<NE::GE> = (1..polys.len())
.into_par_iter()
.map(|i| <NE::CE as CommitmentEngineTrait<NE>>::commit(ck, &polys[i]).comm)
.collect();
// TODO avoid inserting commitment known to verifier
comms.insert(0, C.comm);

let mut comms_affine: Vec<E::G1Affine> = vec![E::G1Affine::identity(); comms.len()];
NE::GE::batch_normalize(&comms, &mut comms_affine);
Expand Down Expand Up @@ -169,15 +167,15 @@ where
C: &Commitment<NE>,
hat_P: &[E::Fr],
point: &[E::Fr],
eval: &E::Fr,
_eval: &E::Fr,
) -> Result<EvaluationArgument<E>, NovaError> {
let x: Vec<E::Fr> = point.to_vec();
let ell = x.len();
let n = hat_P.len();
assert_eq!(n, 1 << ell);

// Phase 1 (similar to hyperkzg)
let polys = Self::compute_pi_polynomials(hat_P, point, eval);
let polys = Self::compute_pi_polynomials(hat_P, point);
let comms = Self::compute_commitments(ck, C, &polys);

// Phase 2 (similar to hyperkzg)
Expand Down Expand Up @@ -226,9 +224,9 @@ where
fn verify(
vk: &KZGVerifierKey<E>,
transcript: &mut <NE as NovaEngine>::TE,
_C: &Commitment<NE>,
C: &Commitment<NE>,
point: &[E::Fr],
_P_of_x: &E::Fr,
P_of_x: &E::Fr,
pi: &EvaluationArgument<E>,
) -> Result<(), NovaError> {
let r = HyperKZG::<E, NE>::compute_challenge(&pi.comms, transcript);
Expand All @@ -241,39 +239,22 @@ where
return Err(NovaError::ProofVerifyError);
}

// TODO:
// insert _P_of_x into every pi.evals_i[last]
// insert _C into pi.comms[0]
// compute commitment for eval and insert it into pi.comms[last]
let mut comms = pi.comms.to_vec();
comms.insert(0, C.comm.to_affine());

let q = HyperKZG::<E, NE>::get_batch_challenge(&pi.evals, transcript);
//let q_powers = HyperKZG::<E, NE>::batch_challenge_powers(q, pi.comms.len());

let R_x = UniPoly::new(pi.R_x.clone());

let mut evals_at_r = vec![];
let mut evals_at_minus_r = vec![];
let mut evals_at_r_squared = vec![];
for (i, evals_i) in pi.evals.iter().enumerate() {
if i == 0 {
evals_at_r = evals_i.clone();
}
if i == 1 {
evals_at_minus_r = evals_i.clone();
}
if i == 2 {
evals_at_r_squared = evals_i.clone();
}

let batched_eval = UniPoly::ref_cast(evals_i).evaluate(&q);

let verification_failed = pi.evals.iter().zip_eq(u.iter()).any(|(evals_i, u_i)| {
// here we check correlation between R polynomial and batched evals, e.g.:
// 1) R(r) == eval at r
// 2) R(-r) == eval at -r
// 3) R(r^2) == eval at r^2
if batched_eval != R_x.evaluate(&u[i]) {
return Err(NovaError::ProofVerifyError);
}
let batched_eval = UniPoly::ref_cast(evals_i).evaluate(&q);
batched_eval != R_x.evaluate(u_i)
});
if verification_failed {
return Err(NovaError::ProofVerifyError);
}

// here we check that Pi polynomials were correctly constructed by the prover, using 'r' as a random point, e.g:
Expand All @@ -282,23 +263,33 @@ where
// P_i+1(r^2) == (1 - point_i) * P_i_even + point_i * P_i_odd -> should hold, according to Gemini transformation
let mut point = point.to_vec();
point.reverse();

let r_mul_2 = E::Fr::from(2) * r;
#[allow(clippy::disallowed_methods)]
for (index, ((eval_r, eval_minus_r), eval_r_squared)) in evals_at_r
.iter()
.zip_eq(evals_at_minus_r.iter())
// TODO: Ask Adrian if we need evals_at_r_squared[0] for some additional checks
.zip(evals_at_r_squared[1..].iter())
let verification_failed = pi.evals[0]
.par_iter()
.chain(&[*P_of_x])
.zip_eq(pi.evals[1].par_iter().chain(&[*P_of_x]))
.zip(pi.evals[2][1..].par_iter().chain(&[*P_of_x]))
.enumerate()
{
let even = (*eval_r + eval_minus_r) * (E::Fr::from(2).invert().unwrap());
let odd = (*eval_r - eval_minus_r) * ((E::Fr::from(2) * r).invert().unwrap());
.any(|(index, ((eval_r, eval_minus_r), eval_r_squared))| {
// some optimisation to avoid using expensive inversions:
// P_i+1(r^2) == (1 - point_i) * (P_i(r) + P_i(-r)) * 1/2 + point_i * (P_i(r) - P_i(-r)) * 1/2 * r
// is equivalent to:
// 2 * r * P_i+1(r^2) == r * (1 - point_i) * (P_i(r) + P_i(-r)) + point_i * (P_i(r) - P_i(-r))

let even = *eval_r + eval_minus_r;
let odd = *eval_r - eval_minus_r;
let right = r * ((E::Fr::ONE - point[index]) * even) + (point[index] * odd);
let left = *eval_r_squared * r_mul_2;
left != right
});

if *eval_r_squared != ((E::Fr::ONE - point[index]) * even) + (point[index] * odd) {
return Err(NovaError::ProofVerifyError);
}
if verification_failed {
return Err(NovaError::ProofVerifyError);
}

let C_P: E::G1 = pi.comms.par_iter().map(|comm| comm.to_curve()).rlc(&q);
let C_P: E::G1 = comms.par_iter().map(|comm| comm.to_curve()).rlc(&q);
let C_Q = pi.C_Q;
let C_H = pi.C_H;
let r_squared = u[2];
Expand Down Expand Up @@ -352,10 +343,11 @@ mod tests {
C: &Commitment<NE>,
poly: &[Fr],
point: &[Fr],
eval: &Fr,
_eval: &Fr,
) {
let polys = EvaluationEngine::<E, NE>::compute_pi_polynomials(poly, point, eval);
let comms = EvaluationEngine::<E, NE>::compute_commitments(ck, C, &polys);
let polys = EvaluationEngine::<E, NE>::compute_pi_polynomials(poly, point);
let mut comms = EvaluationEngine::<E, NE>::compute_commitments(ck, C, &polys);
comms.insert(0, C.comm.to_affine());

let q = Fr::from(8165763);
let q_powers = HyperKZG::<E, NE>::batch_challenge_powers(q, polys.len());
Expand Down Expand Up @@ -404,8 +396,8 @@ mod tests {
assert_eq!(C_K_expected, C_K.to_affine());
}

fn test_k_polynomial_correctness(poly: &[Fr], point: &[Fr], eval: &Fr) {
let polys = EvaluationEngine::<E, NE>::compute_pi_polynomials(poly, point, eval);
fn test_k_polynomial_correctness(poly: &[Fr], point: &[Fr], _eval: &Fr) {
let polys = EvaluationEngine::<E, NE>::compute_pi_polynomials(poly, point);
let q = Fr::from(8165763);
let batched_Pi: UniPoly<Fr> = polys.clone().into_iter().map(UniPoly::new).rlc(&q);

Expand All @@ -428,8 +420,8 @@ mod tests {
assert_eq!(Fr::from(0), K_x.evaluate(&a));
}

fn test_d_polynomial_correctness(poly: &[Fr], point: &[Fr], eval: &Fr) {
let polys = EvaluationEngine::<E, NE>::compute_pi_polynomials(poly, point, eval);
fn test_d_polynomial_correctness(poly: &[Fr], point: &[Fr], _eval: &Fr) {
let polys = EvaluationEngine::<E, NE>::compute_pi_polynomials(poly, point);
let q = Fr::from(8165763);
let batched_Pi: UniPoly<Fr> = polys.clone().into_iter().map(UniPoly::new).rlc(&q);

Expand Down Expand Up @@ -471,8 +463,8 @@ mod tests {
assert_eq!(Q_x, Q_x_recomputed);
}

fn test_batching_property_on_evaluation(poly: &[Fr], point: &[Fr], eval: &Fr) {
let polys = EvaluationEngine::<E, NE>::compute_pi_polynomials(poly, point, eval);
fn test_batching_property_on_evaluation(poly: &[Fr], point: &[Fr], _eval: &Fr) {
let polys = EvaluationEngine::<E, NE>::compute_pi_polynomials(poly, point);

let q = Fr::from(97652);
let u = [Fr::from(10), Fr::from(20), Fr::from(50)];
Expand Down Expand Up @@ -648,4 +640,82 @@ mod tests {
)
.is_err());
}

#[test]
fn test_shplonk_pcs_negative_wrong_commitment() {
let n = 8;
// poly = [1, 2, 1, 4, 1, 2, 1, 4]
let poly = vec![
Fr::ONE,
Fr::from(2),
Fr::from(1),
Fr::from(4),
Fr::ONE,
Fr::from(2),
Fr::from(1),
Fr::from(4),
];
// point = [4,3,8]
let point = vec![Fr::from(4), Fr::from(3), Fr::from(8)];
// eval = 57
let eval = Fr::from(57);

// altered_poly = [85, 84, 83, 82, 81, 80, 79, 78]
let altered_poly = vec![
Fr::from(85),
Fr::from(84),
Fr::from(83),
Fr::from(82),
Fr::from(81),
Fr::from(80),
Fr::from(79),
Fr::from(78),
];

let ck: CommitmentKey<NE> =
<KZGCommitmentEngine<E> as CommitmentEngineTrait<NE>>::setup(b"test", n);

let C1: Commitment<NE> = KZGCommitmentEngine::commit(&ck, &poly); // correct commitment
let C2: Commitment<NE> = KZGCommitmentEngine::commit(&ck, &altered_poly); // wrong commitment

test_negative_inner_commitment(&poly, &point, &eval, &ck, &C1, &C2); // here we check detection when proof and commitment do not correspond
test_negative_inner_commitment(&poly, &point, &eval, &ck, &C2, &C2); // here we check detection when proof was built with wrong commitment
}

fn test_negative_inner_commitment(
poly: &[Fr],
point: &[Fr],
eval: &Fr,
ck: &CommitmentKey<NE>,
C_prover: &Commitment<NE>,
C_verifier: &Commitment<NE>,
) {
let ck = Arc::new(ck.clone());
let (pk, vk): (KZGProverKey<E>, KZGVerifierKey<E>) =
EvaluationEngine::<E, NE>::setup(ck.clone());

let mut prover_transcript = Keccak256Transcript::new(b"TestEval");
let mut verifier_transcript = Keccak256Transcript::<NE>::new(b"TestEval");

let proof = EvaluationEngine::<E, NE>::prove(
&ck,
&pk,
&mut prover_transcript,
C_prover,
poly,
point,
eval,
)
.unwrap();

assert!(EvaluationEngine::<E, NE>::verify(
&vk,
&mut verifier_transcript,
C_verifier,
point,
eval,
&proof
)
.is_err());
}
}
2 changes: 1 addition & 1 deletion src/spartan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub mod batched;
pub mod batched_ppsnark;
#[macro_use]
mod macros;
mod math;
pub(crate) mod math;
pub mod polys;
pub mod ppsnark;
pub mod snark;
Expand Down

1 comment on commit 1f453c9

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmarks

Table of Contents

Overview

This benchmark report shows the Arecibo GPU benchmarks.
NVIDIA L4
Intel(R) Xeon(R) CPU @ 2.20GHz
32 vCPUs
125 GB RAM
Workflow run: https://github.com/lurk-lab/arecibo/actions/runs/7974914682

Benchmark Results

RecursiveSNARK-NIVC-2

ref=07d9b1e ref=1f453c9
Prove-NumCons-6540 44.35 ms (✅ 1.00x) 44.27 ms (✅ 1.00x faster)
Verify-NumCons-6540 33.92 ms (✅ 1.00x) 33.87 ms (✅ 1.00x faster)
Prove-NumCons-1028888 317.71 ms (✅ 1.00x) 336.14 ms (✅ 1.06x slower)
Verify-NumCons-1028888 248.57 ms (✅ 1.00x) 267.20 ms (✅ 1.07x slower)

CompressedSNARK-NIVC-Commitments-2

ref=07d9b1e ref=1f453c9
Prove-NumCons-6540 10.42 s (✅ 1.00x) 10.43 s (✅ 1.00x slower)
Verify-NumCons-6540 50.80 ms (✅ 1.00x) 51.69 ms (✅ 1.02x slower)
Prove-NumCons-1028888 53.33 s (✅ 1.00x) 51.70 s (✅ 1.03x faster)
Verify-NumCons-1028888 50.98 ms (✅ 1.00x) 51.18 ms (✅ 1.00x slower)

Made with criterion-table

Please sign in to comment.