Skip to content

Commit

Permalink
Small code organization improvements (#206)
Browse files Browse the repository at this point in the history
* refactor: Deleted a redundant `ScalarMul` helper trait

* refactor: Refactor `to_transcript_bytes`

* refactor: refactor R1CS Shape checking in Spartan checks

- Introduced a new function `check_regular_shape` in `r1cs.rs` to enforce regularity conditions necessary for Spartan-class SNARKs.

* refactor: Refactor sumcheck.rs prove_quad_* for readability

- Extracted the calculation of evaluation points to its new function `compute_eval_points`, enhancing code reusability within `prove_quad` and `prove_quad_batch` functions.
  • Loading branch information
huitseeker authored Jul 21, 2023
1 parent 87499b3 commit a62bccf
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 54 deletions.
10 changes: 10 additions & 0 deletions src/r1cs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,16 @@ impl<G: Group> R1CSShape<G> {
})
}

// Checks regularity conditions on the R1CSShape, required in Spartan-class SNARKs
// Panics if num_cons, num_vars, or num_io are not powers of two, or if num_io > num_vars
#[inline]
pub(crate) fn check_regular_shape(&self) {
assert_eq!(self.num_cons.next_power_of_two(), self.num_cons);
assert_eq!(self.num_vars.next_power_of_two(), self.num_vars);
assert_eq!(self.num_io.next_power_of_two(), self.num_io);
assert!(self.num_io < self.num_vars);
}

pub fn multiply_vec(
&self,
z: &[G::Scalar],
Expand Down
5 changes: 1 addition & 4 deletions src/spartan/ppsnark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -967,10 +967,7 @@ impl<G: Group, EE: EvaluationEngineTrait<G, CE = G::CE>> RelaxedR1CSSNARKTrait<G
let mut w_u_vec = Vec::new();

// sanity check that R1CSShape has certain size characteristics
assert_eq!(pk.S.num_cons.next_power_of_two(), pk.S.num_cons);
assert_eq!(pk.S.num_vars.next_power_of_two(), pk.S.num_vars);
assert_eq!(pk.S.num_io.next_power_of_two(), pk.S.num_io);
assert!(pk.S.num_io < pk.S.num_vars);
pk.S.check_regular_shape();

// append the verifier key (which includes commitment to R1CS matrices) and the RelaxedR1CSInstance to the transcript
transcript.absorb(b"vk", &pk.vk_digest);
Expand Down
5 changes: 1 addition & 4 deletions src/spartan/snark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,7 @@ impl<G: Group, EE: EvaluationEngineTrait<G, CE = G::CE>> RelaxedR1CSSNARKTrait<G
let mut transcript = G::TE::new(b"RelaxedR1CSSNARK");

// sanity check that R1CSShape has certain size characteristics
assert_eq!(pk.S.num_cons.next_power_of_two(), pk.S.num_cons);
assert_eq!(pk.S.num_vars.next_power_of_two(), pk.S.num_vars);
assert_eq!(pk.S.num_io.next_power_of_two(), pk.S.num_io);
assert!(pk.S.num_io < pk.S.num_vars);
pk.S.check_regular_shape();

// append the digest of vk (which includes R1CS matrices) and the RelaxedR1CSInstance to the transcript
transcript.absorb(b"vk", &pk.vk_digest);
Expand Down
65 changes: 31 additions & 34 deletions src/spartan/sumcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,34 @@ impl<G: Group> SumcheckProof<G> {
Ok((e, r))
}

#[inline]
fn compute_eval_points<F>(
poly_A: &MultilinearPolynomial<G::Scalar>,
poly_B: &MultilinearPolynomial<G::Scalar>,
comb_func: &F,
) -> (G::Scalar, G::Scalar)
where
F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar + Sync,
{
let len = poly_A.len() / 2;
(0..len)
.into_par_iter()
.map(|i| {
// eval 0: bound_func is A(low)
let eval_point_0 = comb_func(&poly_A[i], &poly_B[i]);

// eval 2: bound_func is -A(low) + 2*A(high)
let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
let eval_point_2 = comb_func(&poly_A_bound_point, &poly_B_bound_point);
(eval_point_0, eval_point_2)
})
.reduce(
|| (G::Scalar::ZERO, G::Scalar::ZERO),
|a, b| (a.0 + b.0, a.1 + b.1),
)
}

pub fn prove_quad<F>(
claim: &G::Scalar,
num_rounds: usize,
Expand All @@ -77,25 +105,7 @@ impl<G: Group> SumcheckProof<G> {
let mut claim_per_round = *claim;
for _ in 0..num_rounds {
let poly = {
let len = poly_A.len() / 2;

// Make an iterator returning the contributions to the evaluations
let (eval_point_0, eval_point_2) = (0..len)
.into_par_iter()
.map(|i| {
// eval 0: bound_func is A(low)
let eval_point_0 = comb_func(&poly_A[i], &poly_B[i]);

// eval 2: bound_func is -A(low) + 2*A(high)
let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
let eval_point_2 = comb_func(&poly_A_bound_point, &poly_B_bound_point);
(eval_point_0, eval_point_2)
})
.reduce(
|| (G::Scalar::ZERO, G::Scalar::ZERO),
|a, b| (a.0 + b.0, a.1 + b.1),
);
let (eval_point_0, eval_point_2) = Self::compute_eval_points(poly_A, poly_B, &comb_func);

let evals = vec![eval_point_0, claim_per_round - eval_point_0, eval_point_2];
UniPoly::from_evals(&evals)
Expand Down Expand Up @@ -136,7 +146,7 @@ impl<G: Group> SumcheckProof<G> {
transcript: &mut G::TE,
) -> Result<(Self, Vec<G::Scalar>, (Vec<G::Scalar>, Vec<G::Scalar>)), NovaError>
where
F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar,
F: Fn(&G::Scalar, &G::Scalar) -> G::Scalar + Sync,
{
let mut e = *claim;
let mut r: Vec<G::Scalar> = Vec::new();
Expand All @@ -146,20 +156,7 @@ impl<G: Group> SumcheckProof<G> {
let mut evals: Vec<(G::Scalar, G::Scalar)> = Vec::new();

for (poly_A, poly_B) in poly_A_vec.iter().zip(poly_B_vec.iter()) {
let mut eval_point_0 = G::Scalar::ZERO;
let mut eval_point_2 = G::Scalar::ZERO;

let len = poly_A.len() / 2;
for i in 0..len {
// eval 0: bound_func is A(low)
eval_point_0 += comb_func(&poly_A[i], &poly_B[i]);

// eval 2: bound_func is -A(low) + 2*A(high)
let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
eval_point_2 += comb_func(&poly_A_bound_point, &poly_B_bound_point);
}

let (eval_point_0, eval_point_2) = Self::compute_eval_points(poly_A, poly_B, &comb_func);
evals.push((eval_point_0, eval_point_2));
}

Expand Down
10 changes: 3 additions & 7 deletions src/traits/commitment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ use crate::{
};
use core::{
fmt::Debug,
ops::{Add, AddAssign, Mul, MulAssign},
ops::{Add, AddAssign},
};
use serde::{Deserialize, Serialize};

use super::ScalarMul;

/// Defines basic operations on commitments
pub trait CommitmentOps<Rhs = Self, Output = Self>:
Add<Rhs, Output = Output> + AddAssign<Rhs>
Expand All @@ -31,12 +33,6 @@ impl<T, Rhs, Output> CommitmentOpsOwned<Rhs, Output> for T where
{
}

/// A helper trait for types implementing a multiplication of a commitment with a scalar
pub trait ScalarMul<Rhs, Output = Self>: Mul<Rhs, Output = Output> + MulAssign<Rhs> {}

impl<T, Rhs, Output> ScalarMul<Rhs, Output> for T where T: Mul<Rhs, Output = Output> + MulAssign<Rhs>
{}

/// This trait defines the behavior of the commitment
pub trait CommitmentTrait<G: Group>:
Clone
Expand Down
8 changes: 3 additions & 5 deletions src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,9 @@ pub trait PrimeFieldExt: PrimeField {

impl<G: Group, T: TranscriptReprTrait<G>> TranscriptReprTrait<G> for &[T] {
fn to_transcript_bytes(&self) -> Vec<u8> {
(0..self.len())
.map(|i| self[i].to_transcript_bytes())
.collect::<Vec<_>>()
.into_iter()
.flatten()
self
.iter()
.flat_map(|t| t.to_transcript_bytes())
.collect::<Vec<u8>>()
}
}
Expand Down

0 comments on commit a62bccf

Please sign in to comment.