Skip to content

Commit

Permalink
Avoid memory copy in z_mat
Browse files Browse the repository at this point in the history
  • Loading branch information
Kunming Jiang committed Dec 18, 2024
1 parent 6edc0b0 commit 93e30d1
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 36 deletions.
2 changes: 1 addition & 1 deletion spartan_parallel/src/custom_dense_mlpoly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub fn rev_bits(q: usize, max_num_proofs: usize) -> usize {
}

impl<S: SpartanExtensionField> DensePolynomialPqx<S> {
// Assume z_mat is of form (p, q_rev, x), construct DensePoly
// Assume z_mat is of form (p, q_rev, x_rev), construct DensePoly
pub fn new(
z_mat: Vec<Vec<Vec<Vec<S>>>>,
num_proofs: Vec<usize>,
Expand Down
18 changes: 3 additions & 15 deletions spartan_parallel/src/r1csinstance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,7 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSInstance<S> {

Az[p] = (0..num_proofs[p])
.into_par_iter()
.map(|q_rev| {
// Reverse the bits of q
let q_step = max_num_proofs / num_proofs[p];
let q = rev_bits(q_rev * q_step, max_num_proofs);

.map(|q| {
vec![self.A_list[p_inst].multiply_vec_disjoint_rounds(
max_num_cons,
num_cons[p_inst].clone(),
Expand All @@ -267,11 +263,7 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSInstance<S> {
.collect();
Bz[p] = (0..num_proofs[p])
.into_par_iter()
.map(|q_rev| {
// Reverse the bits of q
let q_step = max_num_proofs / num_proofs[p];
let q = rev_bits(q_rev * q_step, max_num_proofs);

.map(|q| {
vec![self.B_list[p_inst].multiply_vec_disjoint_rounds(
max_num_cons,
num_cons[p_inst].clone(),
Expand All @@ -283,11 +275,7 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSInstance<S> {
.collect();
Cz[p] = (0..num_proofs[p])
.into_par_iter()
.map(|q_rev| {
// Reverse the bits of q
let q_step = max_num_proofs / num_proofs[p];
let q = rev_bits(q_rev * q_step, max_num_proofs);

.map(|q| {
vec![self.C_list[p_inst].multiply_vec_disjoint_rounds(
max_num_cons,
num_cons[p_inst].clone(),
Expand Down
43 changes: 25 additions & 18 deletions spartan_parallel/src/r1csproof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use super::r1csinstance::R1CSInstance;
use super::sumcheck::SumcheckInstanceProof;
use super::timer::Timer;
use super::transcript::ProofTranscript;
use crate::custom_dense_mlpoly::rev_bits;
use crate::scalar::SpartanExtensionField;
use crate::{ProverWitnessSecInfo, VerifierWitnessSecInfo};
use merlin::Transcript;
Expand Down Expand Up @@ -178,22 +179,28 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {

// append input to variables to create a single vector z
let timer_tmp = Timer::new("prove_z_mat_gen");
let mut z_mat: Vec<Vec<Vec<Vec<S>>>> = Vec::new();
for p in 0..num_instances {
z_mat.push(Vec::new());
for q in 0..num_proofs[p] {
z_mat[p].push(vec![vec![S::field_zero(); num_inputs[p]]; num_witness_secs]);
for w in 0..witness_secs.len() {
let ws = witness_secs[w];
let p_w = if ws.w_mat.len() == 1 { 0 } else { p };
let q_w = if ws.w_mat[p_w].len() == 1 { 0 } else { q };
// Only append the first num_inputs_entries of w_mat[p][q]
for i in 0..min(ws.num_inputs[p_w], num_inputs[p]) {
z_mat[p][q][w][i] = ws.w_mat[p_w][q_w][i];
let z_mat_rev = {
let mut z_mat: Vec<Vec<Vec<Vec<S>>>> = Vec::new();
for p in 0..num_instances {
z_mat.push(vec![vec![vec![S::field_zero(); num_inputs[p]]; num_witness_secs]; num_proofs[p]]);
let q_step = max_num_proofs / num_proofs[p];
for q in 0..num_proofs[p] {
let q_rev = rev_bits(q, max_num_proofs) / q_step;
for w in 0..witness_secs.len() {
let ws = witness_secs[w];
let p_w = if ws.w_mat.len() == 1 { 0 } else { p };
let q_w = if ws.w_mat[p_w].len() == 1 { 0 } else { q };
let y_step = max_num_inputs / num_inputs[p];
// Only append the first num_inputs_entries of w_mat[p][q]
for i in 0..min(ws.num_inputs[p_w], num_inputs[p]) {
let y_rev = rev_bits(i, max_num_inputs) / y_step;
z_mat[p][q_rev][w][y_rev] = ws.w_mat[p_w][q_w][i];
}
}
}
}
}
z_mat
};
timer_tmp.stop();

// derive the verifier's challenge \tau
Expand Down Expand Up @@ -221,7 +228,7 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
max_num_inputs,
num_cons,
block_num_cons.clone(),
&z_mat,
&z_mat_rev,
);
timer_tmp.stop();

Expand Down Expand Up @@ -252,7 +259,7 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
timer_tmp.stop();
timer_sc_proof_phase1.stop();

let (tau_claim, Az_claim, Bz_claim, Cz_claim) = (
let (_tau_claim, Az_claim, Bz_claim, Cz_claim) = (
&(poly_tau_p[0] * poly_tau_q[0] * poly_tau_x[0]),
&poly_Az.index(0, 0, 0, 0),
&poly_Bz.index(0, 0, 0, 0),
Expand Down Expand Up @@ -320,8 +327,8 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {

let timer_tmp = Timer::new("prove_z_gen");
// Construct a p * q * len(z) matrix Z and bound it to r_q
let mut Z_poly = DensePolynomialPqx::new_rev(
&z_mat,
let mut Z_poly = DensePolynomialPqx::new(
z_mat_rev,
num_proofs.clone(),
max_num_proofs,
num_inputs.clone(),
Expand Down Expand Up @@ -586,7 +593,7 @@ impl<S: SpartanExtensionField + Send + Sync> R1CSProof<S> {
S::append_field_to_transcript(b"Cz_claim", transcript, Cz_claim);

// debug_zk
// assert_eq!(taus_bound_rx * (Az_claim * Bz_claim - Cz_claim), claim_post_phase_1);
assert_eq!(taus_bound_rx * (Az_claim * Bz_claim - Cz_claim), claim_post_phase_1);

// derive three public challenges and then derive a joint claim
let r_A: S = transcript.challenge_scalar(b"challenge_Az");
Expand Down
9 changes: 7 additions & 2 deletions spartan_parallel/src/sparse_mlpoly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ impl<S: SpartanExtensionField> SparseMatPolynomial<S> {
max_num_rows: usize,
num_rows: usize,
max_num_cols: usize,
_num_cols: usize,
num_cols: usize,
z: &Vec<Vec<S>>
) -> Vec<S> {
let step_r = max_num_rows / num_rows;
Expand All @@ -417,7 +417,12 @@ impl<S: SpartanExtensionField> SparseMatPolynomial<S> {
let row = self.M[i].row;
let col = self.M[i].col;
let val = self.M[i].val.clone();
(row, val * z[col / max_num_cols][col % max_num_cols])
let w = col / max_num_cols;
let y = col % max_num_cols;
// Z expresses y in reverse bits order, so have to find the correct y
let y_step = max_num_cols / num_cols;
let y_rev = rev_bits(y, max_num_cols) / y_step;
(row, val * z[w][y_rev])
})
.fold(vec![S::field_zero(); num_rows], |mut Mz, (r, v)| {
// Reverse the bits of r. r_rev is a multiple of step_r
Expand Down

0 comments on commit 93e30d1

Please sign in to comment.