From 96c1783f92c4173230a27f7d3054a1de7d919a3c Mon Sep 17 00:00:00 2001 From: enpsi Date: Tue, 5 Nov 2024 14:14:56 +0100 Subject: [PATCH 01/14] feat: SIMD support for GKR2 prover --- gkr/src/prover/gkr_square.rs | 56 +++++++++++-- gkr/src/prover/linear_gkr.rs | 3 +- sumcheck/src/prover_helper/simd_gate.rs | 81 +++++++++++++++++++ .../src/prover_helper/sumcheck_gkr_square.rs | 72 ++++++++++++++--- sumcheck/src/sumcheck.rs | 53 ++++++------ 5 files changed, 227 insertions(+), 38 deletions(-) diff --git a/gkr/src/prover/gkr_square.rs b/gkr/src/prover/gkr_square.rs index 761eaf3f..0a34f8b5 100644 --- a/gkr/src/prover/gkr_square.rs +++ b/gkr/src/prover/gkr_square.rs @@ -1,9 +1,12 @@ // an implementation of the GKR^2 protocol //! This module implements the core GKR^2 IOP. +use arith::{Field, SimdField}; use ark_std::{end_timer, start_timer}; use circuit::Circuit; use gkr_field_config::GKRFieldConfig; +use mpi_config::MPIConfig; +use polynomials::MultiLinearPoly; use sumcheck::{sumcheck_prove_gkr_square_layer, ProverScratchPad}; use transcript::Transcript; @@ -11,7 +14,12 @@ pub fn gkr_square_prove>( circuit: &Circuit, sp: &mut ProverScratchPad, transcript: &mut T, -) -> (C::Field, Vec) { + mpi_config: &MPIConfig, +) -> ( + C::ChallengeField, + Vec, + Vec, +) { let timer = start_timer!(|| "gkr^2 prove"); let layer_num = circuit.layers.len(); @@ -20,11 +28,49 @@ pub fn gkr_square_prove>( rz0.push(transcript.generate_challenge_field_element()); } - let circuit_output = &circuit.layers.last().unwrap().output_vals; - let claimed_v = C::eval_circuit_vals_at_challenge(circuit_output, &rz0, &mut sp.hg_evals); + let mut r_simd = vec![]; + for _i in 0..C::get_field_pack_size().trailing_zeros() { + r_simd.push(transcript.generate_challenge_field_element()); + } + log::trace!("Initial r_simd: {:?}", r_simd); + + // TODO: MPI support + assert_eq!( + mpi_config.world_size().trailing_zeros(), + 0, + "MPI not supported yet" + ); + let mut r_mpi = vec![]; + for _ in 0..mpi_config.world_size().trailing_zeros() { + r_mpi.push(transcript.generate_challenge_field_element()); + } + + let output_vals = &circuit.layers.last().unwrap().output_vals; + let claimed_v_simd = C::eval_circuit_vals_at_challenge(output_vals, &rz0, &mut sp.hg_evals); + let claimed_v_local = MultiLinearPoly::::evaluate_with_buffer( + &claimed_v_simd.unpack(), + &r_simd, + &mut sp.eq_evals_at_r_simd0, + ); + + let claimed_v = if mpi_config.is_root() { + let mut claimed_v_gathering_buffer = + vec![C::ChallengeField::zero(); mpi_config.world_size()]; + mpi_config.gather_vec(&vec![claimed_v_local], &mut claimed_v_gathering_buffer); + MultiLinearPoly::evaluate_with_buffer( + &claimed_v_gathering_buffer, + &r_mpi, + &mut sp.eq_evals_at_r_mpi0, + ) + } else { + mpi_config.gather_vec(&vec![claimed_v_local], &mut vec![]); + C::ChallengeField::zero() + }; + log::trace!("Claimed v: {:?}", claimed_v); for i in (0..layer_num).rev() { - rz0 = sumcheck_prove_gkr_square_layer(&circuit.layers[i], &rz0, transcript, sp); + (rz0, r_simd) = + sumcheck_prove_gkr_square_layer(&circuit.layers[i], &rz0, &r_simd, transcript, sp); log::trace!("Layer {} proved", i); log::trace!("rz0.0: {:?}", rz0[0]); @@ -33,5 +79,5 @@ pub fn gkr_square_prove>( } end_timer!(timer); - (claimed_v, rz0) + (claimed_v, rz0, r_simd) } diff --git a/gkr/src/prover/linear_gkr.rs b/gkr/src/prover/linear_gkr.rs index 5aca67f4..70795a5e 100644 --- a/gkr/src/prover/linear_gkr.rs +++ b/gkr/src/prover/linear_gkr.rs @@ -115,7 +115,8 @@ impl Prover { let mut rmpi = vec![]; if self.config.gkr_scheme == GKRScheme::GkrSquare { - (_, rx) = gkr_square_prove(c, &mut self.sp, &mut transcript); + (claimed_v, rx, rsimd) = + gkr_square_prove(c, &mut self.sp, &mut transcript, &self.config.mpi_config); } else { (claimed_v, rx, ry, rsimd, rmpi) = gkr_prove(c, &mut self.sp, &mut transcript, &self.config.mpi_config); diff --git a/sumcheck/src/prover_helper/simd_gate.rs b/sumcheck/src/prover_helper/simd_gate.rs index 11d8a292..049e75b6 100644 --- a/sumcheck/src/prover_helper/simd_gate.rs +++ b/sumcheck/src/prover_helper/simd_gate.rs @@ -82,6 +82,87 @@ impl SumcheckSimdProdGateHelper { [p0, p1, p2, p3] } + /// Evaluate the GKR2 sumcheck polynomial at a SIMD variable, + /// after x-sumcheck rounds have fixed the x variables. The + /// polynomial is degree (D-1) in the SIMD variables. + pub(crate) fn gkr2_poly_eval_at( + &self, + var_idx: usize, + bk_eq: &[C::ChallengeField], + bk_v_simd: &[C::ChallengeField], + add_eval: C::ChallengeField, + pow_5_eval: C::ChallengeField, + ) -> [C::ChallengeField; D] { + let mut p = [C::ChallengeField::zero(); D]; + let mut p_add = [C::ChallengeField::zero(); 3]; + let eval_size = 1 << (self.var_num - var_idx - 1); + + for i in 0..eval_size { + // witness polynomial along current variable + let mut f_v = [C::ChallengeField::zero(); D]; + // eq polynomial along current variable + let mut eq_v = [C::ChallengeField::zero(); D]; + f_v[0] = bk_v_simd[i * 2]; + f_v[1] = bk_v_simd[i * 2 + 1]; + eq_v[0] = bk_eq[i * 2]; + eq_v[1] = bk_eq[i * 2 + 1]; + + // Evaluate term eq(A, r_z) * Pow5(r_z, r_x) * V(A, r_x)^5 + let delta_f = f_v[1] - f_v[0]; + let delta_eq = eq_v[1] - eq_v[0]; + for i in 2..D { + f_v[i] = f_v[i - 1] + delta_f; + eq_v[i] = eq_v[i - 1] + delta_eq; + } + for i in 0..D { + let pow5 = f_v[i].square().square() * f_v[i]; + p[i] += pow_5_eval * pow5 * eq_v[i]; + } + + // Evaluate term eq(A, r_z) * Add(r_z, r_x) * V(A, r_x) + p_add[0] += add_eval * f_v[0] * eq_v[0]; + p_add[1] += add_eval * f_v[1] * eq_v[1]; + // Intermediate term for p_add[2] + p_add[2] += add_eval * (f_v[0] + f_v[1]) * (eq_v[0] + eq_v[1]); + } + p_add[2] = p_add[1].mul_by_6() + p_add[0].mul_by_3() - p_add[2].double(); + + // Interpolate p_add into 7 points, add to p + Self::interpolate_3::(&p_add, &mut p); + p + } + + // Function to interpolate a quadratic polynomial and update an array of points + fn interpolate_3( + p_add: &[C::ChallengeField; 3], + p: &mut [C::ChallengeField; D], + ) { + // Calculate coefficients for the interpolating polynomial + let p_add_coef_0 = p_add[0]; + let p_add_coef_2 = C::challenge_mul_circuit_field( + &(p_add[2] - p_add[1] - p_add[1] + p_add[0]), + &C::CircuitField::INV_2, + ); + + let p_add_coef_1 = p_add[1] - p_add_coef_0 - p_add_coef_2; + + // Update the p array by evaluating the interpolated polynomial at different points + // and adding the results to the existing values + p[0] += p_add_coef_0; + p[1] += p_add_coef_0 + p_add_coef_1 + p_add_coef_2; + p[2] += p_add_coef_0 + p_add_coef_1.double() + p_add_coef_2.double().double(); + p[3] += p_add_coef_0 + p_add_coef_1.mul_by_3() + p_add_coef_2.mul_by_3().mul_by_3(); + p[4] += p_add_coef_0 + + p_add_coef_1.double().double() + + C::challenge_mul_circuit_field(&p_add_coef_2, &C::CircuitField::from(16)); + p[5] += p_add_coef_0 + + p_add_coef_1.mul_by_5() + + C::challenge_mul_circuit_field(&p_add_coef_2, &C::CircuitField::from(25)); + p[6] += p_add_coef_0 + + p_add_coef_1.mul_by_3().double() + + C::challenge_mul_circuit_field(&p_add_coef_2, &C::CircuitField::from(36)); + } + #[inline] pub(crate) fn receive_challenge( &mut self, diff --git a/sumcheck/src/prover_helper/sumcheck_gkr_square.rs b/sumcheck/src/prover_helper/sumcheck_gkr_square.rs index b6075491..d6de7b14 100644 --- a/sumcheck/src/prover_helper/sumcheck_gkr_square.rs +++ b/sumcheck/src/prover_helper/sumcheck_gkr_square.rs @@ -1,24 +1,27 @@ -use arith::Field; +use crate::{unpack_and_combine, ProverScratchPad}; +use arith::{Field, SimdField}; use circuit::CircuitLayer; use gkr_field_config::GKRFieldConfig; use polynomials::EqPolynomial; -use crate::ProverScratchPad; - -use super::power_gate::SumcheckPowerGateHelper; +use super::{power_gate::SumcheckPowerGateHelper, simd_gate::SumcheckSimdProdGateHelper}; // todo: Move D to GKRFieldConfig pub(crate) struct SumcheckGkrSquareHelper<'a, C: GKRFieldConfig, const D: usize> { pub(crate) rx: Vec, + pub(crate) r_simd_var: Vec, layer: &'a CircuitLayer, sp: &'a mut ProverScratchPad, rz0: &'a [C::ChallengeField], + r_simd: &'a [C::ChallengeField], _input_var_num: usize, _output_var_num: usize, + pub(crate) simd_var_num: usize, x_helper: SumcheckPowerGateHelper, + simd_helper: SumcheckSimdProdGateHelper, } impl<'a, C: GKRFieldConfig, const D: usize> SumcheckGkrSquareHelper<'a, C, D> { @@ -26,25 +29,32 @@ impl<'a, C: GKRFieldConfig, const D: usize> SumcheckGkrSquareHelper<'a, C, D> { pub(crate) fn new( layer: &'a CircuitLayer, rz0: &'a [C::ChallengeField], + r_simd: &'a [C::ChallengeField], sp: &'a mut ProverScratchPad, ) -> Self { + let simd_var_num = C::get_field_pack_size().trailing_zeros() as usize; + SumcheckGkrSquareHelper { rx: vec![], + r_simd_var: vec![], layer, sp, rz0, + r_simd, _input_var_num: layer.input_var_num, _output_var_num: layer.output_var_num, + simd_var_num, x_helper: SumcheckPowerGateHelper::new(layer.input_var_num), + simd_helper: SumcheckSimdProdGateHelper::new(simd_var_num), } } #[inline] - pub(crate) fn poly_evals_at(&self, var_idx: usize) -> [C::Field; D] { - self.x_helper.poly_eval_at::( + pub(crate) fn poly_evals_at_x(&self, var_idx: usize) -> [C::ChallengeField; D] { + let evals = self.x_helper.poly_eval_at::( var_idx, &self.sp.v_evals, &self.sp.hg_evals_5, @@ -52,11 +62,27 @@ impl<'a, C: GKRFieldConfig, const D: usize> SumcheckGkrSquareHelper<'a, C, D> { &self.layer.input_vals, &self.sp.gate_exists_5, &self.sp.gate_exists_1, + ); + let mut simd_combined = [C::ChallengeField::zero(); D]; + for (combined, simd_val) in simd_combined.iter_mut().zip(evals.iter()) { + *combined = unpack_and_combine(simd_val, &self.sp.eq_evals_at_r_simd0); + } + simd_combined + } + + #[inline] + pub(crate) fn poly_evals_at_simd(&self, var_idx: usize) -> [C::ChallengeField; D] { + self.simd_helper.gkr2_poly_eval_at::( + var_idx, + &self.sp.eq_evals_at_r_simd0, + &self.sp.simd_var_v_evals, + self.sp.hg_evals_1[0], + self.sp.hg_evals_5[0], ) } #[inline] - pub(crate) fn receive_challenge(&mut self, var_idx: usize, r: C::ChallengeField) { + pub(crate) fn receive_x_challenge(&mut self, var_idx: usize, r: C::ChallengeField) { self.x_helper.receive_challenge::( var_idx, r, @@ -71,9 +97,37 @@ impl<'a, C: GKRFieldConfig, const D: usize> SumcheckGkrSquareHelper<'a, C, D> { self.rx.push(r); } + #[inline] + pub(crate) fn receive_simd_challenge(&mut self, var_idx: usize, r: C::ChallengeField) { + self.simd_helper.receive_challenge::( + var_idx, + r, + &mut self.sp.eq_evals_at_r_simd0, + &mut self.sp.simd_var_v_evals, + &mut self.sp.simd_var_hg_evals, + ); + self.r_simd_var.push(r); + } + #[inline(always)] - pub(crate) fn vx_claim(&self) -> C::Field { - self.sp.v_evals[0] + pub(crate) fn vx_claim(&self) -> C::ChallengeField { + self.sp.simd_var_v_evals[0] + } + + #[inline] + pub(crate) fn prepare_simd(&mut self) { + EqPolynomial::::eq_eval_at( + self.r_simd, + &C::ChallengeField::one(), + &mut self.sp.eq_evals_at_r_simd0, + &mut self.sp.eq_evals_first_half, + &mut self.sp.eq_evals_second_half, + ); + } + + #[inline] + pub(crate) fn prepare_simd_var_vals(&mut self) { + self.sp.simd_var_v_evals = self.sp.v_evals[0].unpack(); } #[inline] diff --git a/sumcheck/src/sumcheck.rs b/sumcheck/src/sumcheck.rs index 9ad374c8..13c2356e 100644 --- a/sumcheck/src/sumcheck.rs +++ b/sumcheck/src/sumcheck.rs @@ -1,4 +1,3 @@ -use arith::FieldSerde; use circuit::CircuitLayer; use gkr_field_config::GKRFieldConfig; use mpi_config::MPIConfig; @@ -99,41 +98,49 @@ pub fn sumcheck_prove_gkr_layer>( layer: &CircuitLayer, rz0: &[C::ChallengeField], + r_simd: &[C::ChallengeField], transcript: &mut T, sp: &mut ProverScratchPad, -) -> Vec { +) -> (Vec, Vec) { const D: usize = 7; - let mut helper = SumcheckGkrSquareHelper::new(layer, rz0, sp); + let mut helper = SumcheckGkrSquareHelper::new(layer, rz0, r_simd, sp); + + helper.prepare_simd(); + helper.prepare_g_x_vals(); + // x-variable sumcheck rounds for i_var in 0..layer.input_var_num { - if i_var == 0 { - helper.prepare_g_x_vals(); - } - let evals: [C::Field; D] = helper.poly_evals_at(i_var); + let evals: [C::ChallengeField; D] = helper.poly_evals_at_x(i_var); for deg in 0..D { - let mut buf = vec![]; - evals[deg].serialize_into(&mut buf).unwrap(); - transcript.append_u8_slice(&buf); + transcript.append_field_element(&evals[deg]); } - let r = transcript.generate_challenge_field_element(); - log::trace!("i_var={} evals: {:?} r: {:?}", i_var, evals, r); + log::trace!("x i_var={} evals: {:?} r: {:?}", i_var, evals, r); + + helper.receive_x_challenge(i_var, r); + } - helper.receive_challenge(i_var, r); - if i_var == layer.input_var_num - 1 { - log::trace!("vx claim: {:?}", helper.vx_claim()); - let mut buf = vec![]; - helper.vx_claim().serialize_into(&mut buf).unwrap(); - transcript.append_u8_slice(&buf); + // Unpack SIMD witness polynomial evaluations + helper.prepare_simd_var_vals(); + + // SIMD-variable sumcheck rounds + for i_var in 0..helper.simd_var_num { + let evals = helper.poly_evals_at_simd(i_var); + + for deg in 0..D { + transcript.append_field_element(&evals[deg]); } + let r = transcript.generate_challenge_field_element(); + + log::trace!("SIMD i_var={} evals: {:?} r: {:?}", i_var, evals, r); + + helper.receive_simd_challenge(i_var, r); } - log::trace!("claimed vx = {:?}", helper.vx_claim()); - let mut buf = vec![]; - helper.vx_claim().serialize_into(&mut buf).unwrap(); - transcript.append_u8_slice(&buf); + log::trace!("vx claim: {:?}", helper.vx_claim()); + transcript.append_field_element(&helper.vx_claim()); - helper.rx + (helper.rx, helper.r_simd_var) } From 0f866ab924d019f122277d07d4b0c756ad39c1af Mon Sep 17 00:00:00 2001 From: enpsi Date: Tue, 5 Nov 2024 14:52:06 +0100 Subject: [PATCH 02/14] wip: gkr2 correctness test --- gkr/src/tests/gkr_correctness.rs | 167 ++++++++++++++++++++++++++++++- 1 file changed, 165 insertions(+), 2 deletions(-) diff --git a/gkr/src/tests/gkr_correctness.rs b/gkr/src/tests/gkr_correctness.rs index 939ee7c7..8ee993f6 100644 --- a/gkr/src/tests/gkr_correctness.rs +++ b/gkr/src/tests/gkr_correctness.rs @@ -3,8 +3,8 @@ use std::panic::AssertUnwindSafe; use std::time::Instant; use std::{fs, panic}; -use arith::{Field, FieldSerde}; -use circuit::Circuit; +use arith::{Field, FieldSerde, SimdField}; +use circuit::{Circuit, CircuitLayer, CoefType, GateUni}; use config::{Config, FiatShamirHashType, GKRConfig, GKRScheme, PolynomialCommitmentType}; use config_macros::declare_gkr_config; use gkr_field_config::{BN254Config, FieldType, GF2ExtConfig, GKRFieldConfig, M31ExtConfig}; @@ -244,3 +244,166 @@ fn test_gkr_correctness_helper(config: &Config, write_proof println!("============== end ==============="); } } + +/// A simple GKR2 test circuit: +/// ```text +/// N_0_0 N_0_1 Layer 0 (Output) +/// x11 / \ / | \ +/// N_1_0 N_1_1 N_1_2 N_1_3 Layer 1 +/// | | / | | +/// Pow5| | / | | +/// N_2_0 N_2_1 N_2_2 N_2_3 Layer 2 (Input) +/// ``` +/// (Unmarked lines are `+` gates with coeff 1) +pub fn gkr_square_test_circuit() -> Circuit { + let mut circuit = Circuit::default(); + + // Layer 1 + let mut l1 = CircuitLayer { + input_var_num: 2, + output_var_num: 2, + ..Default::default() + }; + // N_1_0 += (N_2_0)^5 + l1.uni.push(GateUni { + i_ids: [0], + o_id: 0, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12345, + }); + + // N_1_1 += N_2_1 + l1.uni.push(GateUni { + i_ids: [1], + o_id: 1, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + // N_1_2 += N_2_1 + l1.uni.push(GateUni { + i_ids: [1], + o_id: 2, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + // N_1_2 += N_2_2 + l1.uni.push(GateUni { + i_ids: [2], + o_id: 2, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + // N_1_3 += N_2_3 + l1.uni.push(GateUni { + i_ids: [3], + o_id: 3, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + circuit.layers.push(l1); + + // Output layer + let mut output_layer = CircuitLayer { + input_var_num: 2, + output_var_num: 1, + ..Default::default() + }; + // N_0_0 += 11 * N_1_0 + output_layer.uni.push(GateUni { + i_ids: [0], + o_id: 0, + coef: C::CircuitField::from(11), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + // N_0_0 += N_1_1 + output_layer.uni.push(GateUni { + i_ids: [1], + o_id: 0, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + // N_0_1 += N_1_1 + output_layer.uni.push(GateUni { + i_ids: [1], + o_id: 1, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + // N_0_1 += N_1_2 + output_layer.uni.push(GateUni { + i_ids: [2], + o_id: 1, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + // N_0_1 += N_1_3 + output_layer.uni.push(GateUni { + i_ids: [3], + o_id: 1, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + circuit.layers.push(output_layer); + + circuit.identify_rnd_coefs(); + circuit +} + +#[test] +fn gkr_square_correctness() { + declare_gkr_config!( + GkrConfigType, + FieldType::M31, + FiatShamirHashType::SHA256, + PolynomialCommitmentType::Raw + ); + env_logger::init(); + type GkrFieldConfigType = ::FieldConfig; + + let mut circuit = gkr_square_test_circuit::(); + // Set input layers with N_2_0 = 3, N_2_1 = 5, N_2_2 = 7, + // and N_2_3 varying from 0 to 15 + let final_vals = (0..16).map(|x| x.into()).collect::>(); + let final_vals = ::SimdCircuitField::pack(&final_vals); + circuit.layers[0].input_vals = vec![2.into(), 3.into(), 5.into(), final_vals]; + + let config = Config::::new(GKRScheme::GkrSquare, MPIConfig::default()); + do_prove_verify(config, &mut circuit); +} + +fn do_prove_verify(config: Config, circuit: &mut Circuit) { + circuit.evaluate(); + + let (pcs_params, pcs_proving_key, pcs_verification_key, mut pcs_scratch) = + expander_pcs_init_testing_only::( + circuit.log_input_size(), + &config.mpi_config, + ); + + // Prove + let mut prover = Prover::new(&config); + prover.prepare_mem(&circuit); + let (claimed_v, proof) = prover.prove(circuit, &pcs_params, &pcs_proving_key, &mut pcs_scratch); + + // Verify + let verifier = Verifier::new(&config); + let public_input = vec![]; + assert!(verifier.verify( + circuit, + &public_input, + &claimed_v, + &pcs_params, + &pcs_verification_key, + &proof + )) +} From 38983bef089189c871c27d01a19495383b26737d Mon Sep 17 00:00:00 2001 From: enpsi Date: Tue, 5 Nov 2024 18:57:36 +0100 Subject: [PATCH 03/14] feat: GKR2 verifier --- gkr/src/verifier.rs | 122 ++++++++++++++++-------- gkr/src/verifier/gkr_square.rs | 163 ++++++++++++++++++++++++++++++++ sumcheck/src/scratch_pad.rs | 31 ++++++ sumcheck/src/verifier_helper.rs | 54 ++++++++++- 4 files changed, 325 insertions(+), 45 deletions(-) create mode 100644 gkr/src/verifier/gkr_square.rs diff --git a/gkr/src/verifier.rs b/gkr/src/verifier.rs index 712ba76e..f22572cb 100644 --- a/gkr/src/verifier.rs +++ b/gkr/src/verifier.rs @@ -6,7 +6,7 @@ use std::{ use arith::{Field, FieldSerde}; use ark_std::{end_timer, start_timer}; use circuit::{Circuit, CircuitLayer}; -use config::{Config, GKRConfig}; +use config::{Config, GKRConfig, GKRScheme}; use gkr_field_config::GKRFieldConfig; use mpi_config::MPIConfig; use poly_commit::{ExpanderGKRChallenge, PCSForExpanderGKR, StructuredReferenceString}; @@ -16,6 +16,9 @@ use transcript::{Proof, Transcript}; #[cfg(feature = "grinding")] use crate::grind; +mod gkr_square; +pub use gkr_square::gkr_square_verify; + #[inline(always)] fn verify_sumcheck_step>( mut proof_reader: impl Read, @@ -40,6 +43,10 @@ fn verify_sumcheck_step>( *claimed_sum = GKRVerifierHelper::degree_2_eval(&ps, r, sp); } else if degree == 3 { *claimed_sum = GKRVerifierHelper::degree_3_eval(&ps, r, sp); + } else if degree == 6 { + *claimed_sum = GKRVerifierHelper::degree_6_eval(&ps, r, sp); + } else { + panic!("unsupported degree"); } verified @@ -296,46 +303,79 @@ impl Verifier { circuit.fill_rnd_coefs(&mut transcript); - let (mut verified, rz0, rz1, r_simd, r_mpi, claimed_v0, claimed_v1) = gkr_verify( - &self.config.mpi_config, - circuit, - public_input, - claimed_v, - &mut transcript, - &mut cursor, - ); - - log::info!("GKR verification: {}", verified); - - verified &= self.get_pcs_opening_from_proof_and_verify( - pcs_params, - pcs_verification_key, - &commitment, - &ExpanderGKRChallenge { - x: rz0, - x_simd: r_simd.clone(), - x_mpi: r_mpi.clone(), - }, - &claimed_v0, - &mut transcript, - &mut cursor, - ); - - if let Some(rz1) = rz1 { - verified &= self.get_pcs_opening_from_proof_and_verify( - pcs_params, - pcs_verification_key, - &commitment, - &ExpanderGKRChallenge { - x: rz1, - x_simd: r_simd, - x_mpi: r_mpi, - }, - &claimed_v1.unwrap(), - &mut transcript, - &mut cursor, - ); - } + let verified = match self.config.gkr_scheme { + GKRScheme::Vanilla => { + let (mut verified, rz0, rz1, r_simd, r_mpi, claimed_v0, claimed_v1) = gkr_verify( + &self.config.mpi_config, + circuit, + public_input, + claimed_v, + &mut transcript, + &mut cursor, + ); + + log::info!("GKR verification: {}", verified); + + verified &= self.get_pcs_opening_from_proof_and_verify( + pcs_params, + pcs_verification_key, + &commitment, + &ExpanderGKRChallenge { + x: rz0, + x_simd: r_simd.clone(), + x_mpi: r_mpi.clone(), + }, + &claimed_v0, + &mut transcript, + &mut cursor, + ); + + if let Some(rz1) = rz1 { + verified &= self.get_pcs_opening_from_proof_and_verify( + pcs_params, + pcs_verification_key, + &commitment, + &ExpanderGKRChallenge { + x: rz1, + x_simd: r_simd, + x_mpi: r_mpi, + }, + &claimed_v1.unwrap(), + &mut transcript, + &mut cursor, + ); + } + + verified + } + GKRScheme::GkrSquare => { + let (mut verified, rz, r_simd, r_mpi, claimed_v) = gkr_square_verify( + &self.config.mpi_config, + circuit, + public_input, + claimed_v, + &mut transcript, + &mut cursor, + ); + + log::info!("GKR verification: {}", verified); + + verified &= self.get_pcs_opening_from_proof_and_verify( + pcs_params, + pcs_verification_key, + &commitment, + &ExpanderGKRChallenge { + x: rz, + x_simd: r_simd.clone(), + x_mpi: r_mpi.clone(), + }, + &claimed_v, + &mut transcript, + &mut cursor, + ); + verified + } + }; end_timer!(timer); diff --git a/gkr/src/verifier/gkr_square.rs b/gkr/src/verifier/gkr_square.rs new file mode 100644 index 00000000..7d4bc685 --- /dev/null +++ b/gkr/src/verifier/gkr_square.rs @@ -0,0 +1,163 @@ +use super::verify_sumcheck_step; +use arith::{Field, FieldSerde}; +use ark_std::{end_timer, start_timer}; +use circuit::{Circuit, CircuitLayer}; +use gkr_field_config::GKRFieldConfig; +use mpi_config::MPIConfig; +use std::{io::Read, vec}; +use sumcheck::{GKRVerifierHelper, VerifierScratchPad}; +use transcript::Transcript; + +#[allow(clippy::type_complexity)] +pub fn gkr_square_verify>( + mpi_config: &MPIConfig, + circuit: &Circuit, + public_input: &[C::SimdCircuitField], + claimed_v: &C::ChallengeField, + transcript: &mut T, + mut proof_reader: impl Read, +) -> ( + bool, + Vec, + Vec, + Vec, + C::ChallengeField, +) { + let timer = start_timer!(|| "gkr verify"); + let mut sp = VerifierScratchPad::::new(circuit, mpi_config.world_size()); + + let layer_num = circuit.layers.len(); + let mut rz = vec![]; + let mut r_simd = vec![]; + let mut r_mpi = vec![]; + + for _ in 0..circuit.layers.last().unwrap().output_var_num { + rz.push(transcript.generate_challenge_field_element()); + } + log::trace!("rz {:?}", rz); + + for _ in 0..C::get_field_pack_size().trailing_zeros() { + r_simd.push(transcript.generate_challenge_field_element()); + } + log::trace!("r_simd {:?}", r_simd); + + // TODO: MPI support + assert_eq!( + mpi_config.world_size().trailing_zeros(), + 0, + "MPI not supported yet" + ); + for _ in 0..mpi_config.world_size().trailing_zeros() { + r_mpi.push(transcript.generate_challenge_field_element()); + } + + let mut verified = true; + let mut current_claim = *claimed_v; + log::trace!("Starting claim: {:?}", current_claim); + for i in (0..layer_num).rev() { + let cur_verified; + (cur_verified, rz, r_simd, r_mpi, current_claim) = sumcheck_verify_gkr_square_layer( + mpi_config, + &circuit.layers[i], + public_input, + &rz, + &r_simd, + &r_mpi, + current_claim, + &mut proof_reader, + transcript, + &mut sp, + i == layer_num - 1, + ); + verified &= cur_verified; + } + end_timer!(timer); + (verified, rz, r_simd, r_mpi, current_claim) +} + +#[allow(clippy::too_many_arguments)] +#[allow(clippy::type_complexity)] +#[allow(clippy::unnecessary_unwrap)] +fn sumcheck_verify_gkr_square_layer>( + mpi_config: &MPIConfig, + layer: &CircuitLayer, + public_input: &[C::SimdCircuitField], + rz: &[C::ChallengeField], + r_simd: &Vec, + r_mpi: &Vec, + current_claim: C::ChallengeField, + mut proof_reader: impl Read, + transcript: &mut T, + sp: &mut VerifierScratchPad, + is_output_layer: bool, +) -> ( + bool, + Vec, + Vec, + Vec, + C::ChallengeField, +) { + // GKR2 with Power5 gate has degree 6 polynomial + let degree = 6; + + GKRVerifierHelper::prepare_layer(layer, &None, rz, &None, r_simd, r_mpi, sp, is_output_layer); + + let var_num = layer.input_var_num; + let mut sum = current_claim; + sum -= GKRVerifierHelper::eval_cst(&layer.const_, public_input, sp); + + let mut rx = vec![]; + let mut r_simd_var = vec![]; + let mut r_mpi_var = vec![]; + let mut verified = true; + + for i_var in 0..var_num { + verified &= verify_sumcheck_step::( + &mut proof_reader, + degree, + transcript, + &mut sum, + &mut rx, + sp, + ); + log::trace!("x {} var, verified? {}", i_var, verified); + } + GKRVerifierHelper::set_rx(&rx, sp); + + for i_var in 0..C::get_field_pack_size().trailing_zeros() { + verified &= verify_sumcheck_step::( + &mut proof_reader, + degree, + transcript, + &mut sum, + &mut r_simd_var, + sp, + ); + log::trace!("simd {} var, verified? {}", i_var, verified); + } + GKRVerifierHelper::set_r_simd_xy(&r_simd_var, sp); + + // TODO: nontrivial MPI support + for _i_var in 0..mpi_config.world_size().trailing_zeros() { + verified &= verify_sumcheck_step::( + &mut proof_reader, + 3, + transcript, + &mut sum, + &mut r_mpi_var, + sp, + ); + // println!("{} mpi var, verified? {}", _i_var, verified); + } + GKRVerifierHelper::set_r_mpi_xy(&r_mpi_var, sp); + + let v_claim = C::ChallengeField::deserialize_from(&mut proof_reader).unwrap(); + + sum -= v_claim * GKRVerifierHelper::eval_pow_1(&layer.uni, sp) + + v_claim.exp(5) * GKRVerifierHelper::eval_pow_5(&layer.uni, sp); + transcript.append_field_element(&v_claim); + + verified &= sum == C::ChallengeField::ZERO; + + (verified, rx, r_simd_var, r_mpi_var, v_claim) +} diff --git a/sumcheck/src/scratch_pad.rs b/sumcheck/src/scratch_pad.rs index 79cc08af..569b3c60 100644 --- a/sumcheck/src/scratch_pad.rs +++ b/sumcheck/src/scratch_pad.rs @@ -91,6 +91,9 @@ pub struct VerifierScratchPad { pub gf2_deg2_eval_coef: C::ChallengeField, // 1 / x(x - 1) pub deg3_eval_at: [C::ChallengeField; 4], pub deg3_lag_denoms_inv: [C::ChallengeField; 4], + // ====== for deg6 eval ====== + pub deg6_eval_at: [C::ChallengeField; 7], + pub deg6_lag_denoms_inv: [C::ChallengeField; 7], } impl VerifierScratchPad { @@ -143,6 +146,32 @@ impl VerifierScratchPad { deg3_lag_denoms_inv[i] = denominator.inv().unwrap(); } + let deg6_eval_at = if C::FIELD_TYPE == FieldType::GF2 { + panic!("GF2 not supported yet"); + } else { + [ + C::ChallengeField::ZERO, + C::ChallengeField::ONE, + C::ChallengeField::from(2), + C::ChallengeField::from(3), + C::ChallengeField::from(4), + C::ChallengeField::from(5), + C::ChallengeField::from(6), + ] + }; + + let mut deg6_lag_denoms_inv = [C::ChallengeField::ZERO; 7]; + for i in 0..7 { + let mut denominator = C::ChallengeField::ONE; + for j in 0..7 { + if j == i { + continue; + } + denominator *= deg6_eval_at[i] - deg6_eval_at[j]; + } + deg6_lag_denoms_inv[i] = denominator.inv().unwrap(); + } + Self { eq_evals_at_rz0: vec![C::ChallengeField::zero(); max_io_size], eq_evals_at_r_simd: vec![C::ChallengeField::zero(); simd_size], @@ -168,6 +197,8 @@ impl VerifierScratchPad { gf2_deg2_eval_coef, deg3_eval_at, deg3_lag_denoms_inv, + deg6_eval_at, + deg6_lag_denoms_inv, } } } diff --git a/sumcheck/src/verifier_helper.rs b/sumcheck/src/verifier_helper.rs index 8d41b1c1..8ccc9ee4 100644 --- a/sumcheck/src/verifier_helper.rs +++ b/sumcheck/src/verifier_helper.rs @@ -1,5 +1,5 @@ use arith::{ExtensionField, Field}; -use circuit::{CircuitLayer, CoefType, GateAdd, GateConst, GateMul}; +use circuit::{CircuitLayer, CoefType, GateAdd, GateConst, GateMul, GateUni}; use gkr_field_config::{FieldType, GKRFieldConfig}; use polynomials::EqPolynomial; @@ -140,6 +140,39 @@ impl GKRVerifierHelper { v * sp.eq_r_simd_r_simd_xy * sp.eq_r_mpi_r_mpi_xy } + /// GKR2 equivalent of `eval_add`. (Note that GKR2 uses pow1 gates instead of add gates) + #[inline(always)] + pub fn eval_pow_1( + gates: &[GateUni], + sp: &VerifierScratchPad, + ) -> C::ChallengeField { + let mut v = C::ChallengeField::zero(); + for gate in gates { + // Gates of type 12346 represent an add gate + if gate.gate_type == 12346 { + v += sp.eq_evals_at_rz0[gate.o_id] + * C::challenge_mul_circuit_field(&sp.eq_evals_at_rx[gate.i_ids[0]], &gate.coef); + } + } + v * sp.eq_r_simd_r_simd_xy + } + + #[inline(always)] + pub fn eval_pow_5( + gates: &[GateUni], + sp: &VerifierScratchPad, + ) -> C::ChallengeField { + let mut v = C::ChallengeField::zero(); + for gate in gates { + // Gates of type 12345 represent a pow5 gate + if gate.gate_type == 12345 { + v += sp.eq_evals_at_rz0[gate.o_id] + * C::challenge_mul_circuit_field(&sp.eq_evals_at_rx[gate.i_ids[0]], &gate.coef); + } + } + v * sp.eq_r_simd_r_simd_xy + } + #[inline(always)] pub fn set_rx(rx: &[C::ChallengeField], sp: &mut VerifierScratchPad) { EqPolynomial::::eq_eval_at( @@ -219,13 +252,26 @@ impl GKRVerifierHelper { Self::lag_eval(vals, x, sp) } + #[inline(always)] + pub fn degree_6_eval( + vals: &[C::ChallengeField], + x: C::ChallengeField, + sp: &VerifierScratchPad, + ) -> C::ChallengeField { + Self::lag_eval(vals, x, sp) + } + #[inline(always)] fn lag_eval( vals: &[C::ChallengeField], x: C::ChallengeField, sp: &VerifierScratchPad, ) -> C::ChallengeField { - assert_eq!(sp.deg3_eval_at.len(), vals.len()); + let (evals, lag_denoms_inv) = match vals.len() { + 4 => (sp.deg3_eval_at.to_vec(), sp.deg3_lag_denoms_inv.to_vec()), + 7 => (sp.deg6_eval_at.to_vec(), sp.deg6_lag_denoms_inv.to_vec()), + _ => panic!("unsupported degree"), + }; let mut v = C::ChallengeField::ZERO; for i in 0..vals.len() { @@ -234,9 +280,9 @@ impl GKRVerifierHelper { if j == i { continue; } - numerator *= x - sp.deg3_eval_at[j]; + numerator *= x - evals[j]; } - v += numerator * sp.deg3_lag_denoms_inv[i] * vals[i]; + v += numerator * lag_denoms_inv[i] * vals[i]; } v } From a53832d771f692316303468fd294580c18ebd618 Mon Sep 17 00:00:00 2001 From: enpsi Date: Wed, 6 Nov 2024 15:02:21 +0100 Subject: [PATCH 04/14] public input test coverage --- gkr/src/tests/gkr_correctness.rs | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/gkr/src/tests/gkr_correctness.rs b/gkr/src/tests/gkr_correctness.rs index 8ee993f6..73660e34 100644 --- a/gkr/src/tests/gkr_correctness.rs +++ b/gkr/src/tests/gkr_correctness.rs @@ -4,7 +4,7 @@ use std::time::Instant; use std::{fs, panic}; use arith::{Field, FieldSerde, SimdField}; -use circuit::{Circuit, CircuitLayer, CoefType, GateUni}; +use circuit::{Circuit, CircuitLayer, CoefType, GateConst, GateUni}; use config::{Config, FiatShamirHashType, GKRConfig, GKRScheme, PolynomialCommitmentType}; use config_macros::declare_gkr_config; use gkr_field_config::{BN254Config, FieldType, GF2ExtConfig, GKRFieldConfig, M31ExtConfig}; @@ -250,8 +250,8 @@ fn test_gkr_correctness_helper(config: &Config, write_proof /// N_0_0 N_0_1 Layer 0 (Output) /// x11 / \ / | \ /// N_1_0 N_1_1 N_1_2 N_1_3 Layer 1 -/// | | / | | -/// Pow5| | / | | +/// | | / | | \ +/// Pow5| | / | | PI[0] /// N_2_0 N_2_1 N_2_2 N_2_3 Layer 2 (Input) /// ``` /// (Unmarked lines are `+` gates with coeff 1) @@ -264,6 +264,14 @@ pub fn gkr_square_test_circuit() -> Circuit { output_var_num: 2, ..Default::default() }; + // N_1_3 += PI[0] (public input) + l1.const_.push(GateConst { + i_ids: [], + o_id: 3, + coef: C::CircuitField::from(1), + coef_type: CoefType::PublicInput(0), + gate_type: 0, + }); // N_1_0 += (N_2_0)^5 l1.uni.push(GateUni { i_ids: [0], @@ -376,6 +384,8 @@ fn gkr_square_correctness() { let final_vals = (0..16).map(|x| x.into()).collect::>(); let final_vals = ::SimdCircuitField::pack(&final_vals); circuit.layers[0].input_vals = vec![2.into(), 3.into(), 5.into(), final_vals]; + // Set public input PI[0] = 13 + circuit.public_input = vec![13.into()]; let config = Config::::new(GKRScheme::GkrSquare, MPIConfig::default()); do_prove_verify(config, &mut circuit); @@ -397,7 +407,7 @@ fn do_prove_verify(config: Config, circuit: &mut Circuit Date: Fri, 15 Nov 2024 15:56:16 +0100 Subject: [PATCH 05/14] feat: MPI support for GKR2 --- gkr/src/prover/gkr_square.rs | 20 ++-- gkr/src/prover/linear_gkr.rs | 2 +- gkr/src/verifier/gkr_square.rs | 20 ++-- .../src/prover_helper/sumcheck_gkr_square.rs | 96 +++++++++++++++++-- sumcheck/src/sumcheck.rs | 39 ++++---- sumcheck/src/utils.rs | 6 +- sumcheck/src/verifier_helper.rs | 4 +- 7 files changed, 135 insertions(+), 52 deletions(-) diff --git a/gkr/src/prover/gkr_square.rs b/gkr/src/prover/gkr_square.rs index 0a34f8b5..279aee80 100644 --- a/gkr/src/prover/gkr_square.rs +++ b/gkr/src/prover/gkr_square.rs @@ -19,6 +19,7 @@ pub fn gkr_square_prove>( C::ChallengeField, Vec, Vec, + Vec, ) { let timer = start_timer!(|| "gkr^2 prove"); let layer_num = circuit.layers.len(); @@ -34,12 +35,6 @@ pub fn gkr_square_prove>( } log::trace!("Initial r_simd: {:?}", r_simd); - // TODO: MPI support - assert_eq!( - mpi_config.world_size().trailing_zeros(), - 0, - "MPI not supported yet" - ); let mut r_mpi = vec![]; for _ in 0..mpi_config.world_size().trailing_zeros() { r_mpi.push(transcript.generate_challenge_field_element()); @@ -69,8 +64,15 @@ pub fn gkr_square_prove>( log::trace!("Claimed v: {:?}", claimed_v); for i in (0..layer_num).rev() { - (rz0, r_simd) = - sumcheck_prove_gkr_square_layer(&circuit.layers[i], &rz0, &r_simd, transcript, sp); + (rz0, r_simd, r_mpi) = sumcheck_prove_gkr_square_layer( + &circuit.layers[i], + &rz0, + &r_simd, + &r_mpi, + transcript, + sp, + mpi_config, + ); log::trace!("Layer {} proved", i); log::trace!("rz0.0: {:?}", rz0[0]); @@ -79,5 +81,5 @@ pub fn gkr_square_prove>( } end_timer!(timer); - (claimed_v, rz0, r_simd) + (claimed_v, rz0, r_simd, r_mpi) } diff --git a/gkr/src/prover/linear_gkr.rs b/gkr/src/prover/linear_gkr.rs index 70795a5e..9265665f 100644 --- a/gkr/src/prover/linear_gkr.rs +++ b/gkr/src/prover/linear_gkr.rs @@ -115,7 +115,7 @@ impl Prover { let mut rmpi = vec![]; if self.config.gkr_scheme == GKRScheme::GkrSquare { - (claimed_v, rx, rsimd) = + (claimed_v, rx, rsimd, rmpi) = gkr_square_prove(c, &mut self.sp, &mut transcript, &self.config.mpi_config); } else { (claimed_v, rx, ry, rsimd, rmpi) = diff --git a/gkr/src/verifier/gkr_square.rs b/gkr/src/verifier/gkr_square.rs index 7d4bc685..9fe09318 100644 --- a/gkr/src/verifier/gkr_square.rs +++ b/gkr/src/verifier/gkr_square.rs @@ -34,22 +34,15 @@ pub fn gkr_square_verify>( for _ in 0..circuit.layers.last().unwrap().output_var_num { rz.push(transcript.generate_challenge_field_element()); } - log::trace!("rz {:?}", rz); - for _ in 0..C::get_field_pack_size().trailing_zeros() { r_simd.push(transcript.generate_challenge_field_element()); } - log::trace!("r_simd {:?}", r_simd); - - // TODO: MPI support - assert_eq!( - mpi_config.world_size().trailing_zeros(), - 0, - "MPI not supported yet" - ); for _ in 0..mpi_config.world_size().trailing_zeros() { r_mpi.push(transcript.generate_challenge_field_element()); } + log::trace!("Initial rz0: {:?}", rz); + log::trace!("Initial r_simd: {:?}", r_simd); + log::trace!("Initial r_mpi: {:?}", r_mpi); let mut verified = true; let mut current_claim = *claimed_v; @@ -69,6 +62,7 @@ pub fn gkr_square_verify>( &mut sp, i == layer_num - 1, ); + log::trace!("Layer {} verified? {}", i, cur_verified); verified &= cur_verified; } end_timer!(timer); @@ -137,21 +131,21 @@ fn sumcheck_verify_gkr_square_layer( &mut proof_reader, - 3, + degree, transcript, &mut sum, &mut r_mpi_var, sp, ); - // println!("{} mpi var, verified? {}", _i_var, verified); + log::trace!("{} mpi var, verified? {}", _i_var, verified); } GKRVerifierHelper::set_r_mpi_xy(&r_mpi_var, sp); let v_claim = C::ChallengeField::deserialize_from(&mut proof_reader).unwrap(); + log::trace!("v_claim: {:?}", v_claim); sum -= v_claim * GKRVerifierHelper::eval_pow_1(&layer.uni, sp) + v_claim.exp(5) * GKRVerifierHelper::eval_pow_5(&layer.uni, sp); diff --git a/sumcheck/src/prover_helper/sumcheck_gkr_square.rs b/sumcheck/src/prover_helper/sumcheck_gkr_square.rs index d6de7b14..b5195788 100644 --- a/sumcheck/src/prover_helper/sumcheck_gkr_square.rs +++ b/sumcheck/src/prover_helper/sumcheck_gkr_square.rs @@ -2,6 +2,7 @@ use crate::{unpack_and_combine, ProverScratchPad}; use arith::{Field, SimdField}; use circuit::CircuitLayer; use gkr_field_config::GKRFieldConfig; +use mpi_config::MPIConfig; use polynomials::EqPolynomial; use super::{power_gate::SumcheckPowerGateHelper, simd_gate::SumcheckSimdProdGateHelper}; @@ -10,11 +11,13 @@ use super::{power_gate::SumcheckPowerGateHelper, simd_gate::SumcheckSimdProdGate pub(crate) struct SumcheckGkrSquareHelper<'a, C: GKRFieldConfig, const D: usize> { pub(crate) rx: Vec, pub(crate) r_simd_var: Vec, + pub(crate) r_mpi_var: Vec, layer: &'a CircuitLayer, sp: &'a mut ProverScratchPad, rz0: &'a [C::ChallengeField], r_simd: &'a [C::ChallengeField], + r_mpi: &'a [C::ChallengeField], _input_var_num: usize, _output_var_num: usize, @@ -22,6 +25,9 @@ pub(crate) struct SumcheckGkrSquareHelper<'a, C: GKRFieldConfig, const D: usize> x_helper: SumcheckPowerGateHelper, simd_helper: SumcheckSimdProdGateHelper, + mpi_helper: SumcheckSimdProdGateHelper, + + mpi_config: &'a MPIConfig, } impl<'a, C: GKRFieldConfig, const D: usize> SumcheckGkrSquareHelper<'a, C, D> { @@ -30,18 +36,22 @@ impl<'a, C: GKRFieldConfig, const D: usize> SumcheckGkrSquareHelper<'a, C, D> { layer: &'a CircuitLayer, rz0: &'a [C::ChallengeField], r_simd: &'a [C::ChallengeField], + r_mpi: &'a [C::ChallengeField], sp: &'a mut ProverScratchPad, + mpi_config: &'a MPIConfig, ) -> Self { let simd_var_num = C::get_field_pack_size().trailing_zeros() as usize; SumcheckGkrSquareHelper { rx: vec![], r_simd_var: vec![], + r_mpi_var: vec![], layer, sp, rz0, r_simd, + r_mpi, _input_var_num: layer.input_var_num, _output_var_num: layer.output_var_num, @@ -49,12 +59,16 @@ impl<'a, C: GKRFieldConfig, const D: usize> SumcheckGkrSquareHelper<'a, C, D> { x_helper: SumcheckPowerGateHelper::new(layer.input_var_num), simd_helper: SumcheckSimdProdGateHelper::new(simd_var_num), + mpi_helper: SumcheckSimdProdGateHelper::new( + mpi_config.world_size().trailing_zeros() as usize + ), + mpi_config, } } #[inline] pub(crate) fn poly_evals_at_x(&self, var_idx: usize) -> [C::ChallengeField; D] { - let evals = self.x_helper.poly_eval_at::( + let local_vals_simd = self.x_helper.poly_eval_at::( var_idx, &self.sp.v_evals, &self.sp.hg_evals_5, @@ -63,22 +77,56 @@ impl<'a, C: GKRFieldConfig, const D: usize> SumcheckGkrSquareHelper<'a, C, D> { &self.sp.gate_exists_5, &self.sp.gate_exists_1, ); - let mut simd_combined = [C::ChallengeField::zero(); D]; - for (combined, simd_val) in simd_combined.iter_mut().zip(evals.iter()) { - *combined = unpack_and_combine(simd_val, &self.sp.eq_evals_at_r_simd0); + + // SIMD + let local_vals = local_vals_simd + .iter() + .map(|p| unpack_and_combine(p, &self.sp.eq_evals_at_r_simd0)) + .collect::>(); + + // MPI + let global_vals = self + .mpi_config + .coef_combine_vec(&local_vals, &self.sp.eq_evals_at_r_mpi0); + if self.mpi_config.is_root() { + global_vals.try_into().unwrap() + } else { + [C::ChallengeField::ZERO; D] } - simd_combined } #[inline] pub(crate) fn poly_evals_at_simd(&self, var_idx: usize) -> [C::ChallengeField; D] { - self.simd_helper.gkr2_poly_eval_at::( + let local_vals = self.simd_helper.gkr2_poly_eval_at::( var_idx, &self.sp.eq_evals_at_r_simd0, &self.sp.simd_var_v_evals, self.sp.hg_evals_1[0], self.sp.hg_evals_5[0], - ) + ); + let global_vals = self + .mpi_config + .coef_combine_vec(&local_vals.to_vec(), &self.sp.eq_evals_at_r_mpi0); + if self.mpi_config.is_root() { + global_vals.try_into().unwrap() + } else { + [C::ChallengeField::ZERO; D] + } + } + + pub(crate) fn poly_evals_at_mpi(&mut self, var_idx: usize) -> [C::ChallengeField; D] { + assert!(var_idx < self.mpi_config.world_size().trailing_zeros() as usize); + let mut evals = self.mpi_helper.gkr2_poly_eval_at::( + var_idx, + &self.sp.eq_evals_at_r_mpi0, + &mut self.sp.mpi_var_v_evals, + self.sp.hg_evals_1[0], + self.sp.hg_evals_5[0], + ); + for eval in evals.iter_mut() { + *eval *= self.sp.eq_evals_at_r_simd0[0]; + } + evals } #[inline] @@ -109,9 +157,21 @@ impl<'a, C: GKRFieldConfig, const D: usize> SumcheckGkrSquareHelper<'a, C, D> { self.r_simd_var.push(r); } + #[inline] + pub(crate) fn receive_mpi_challenge(&mut self, var_idx: usize, r: C::ChallengeField) { + self.mpi_helper.receive_challenge::( + var_idx, + r, + &mut self.sp.eq_evals_at_r_mpi0, + &mut self.sp.mpi_var_v_evals, + &mut self.sp.mpi_var_hg_evals, + ); + self.r_mpi_var.push(r); + } + #[inline(always)] pub(crate) fn vx_claim(&self) -> C::ChallengeField { - self.sp.simd_var_v_evals[0] + self.sp.mpi_var_v_evals[0] } #[inline] @@ -125,11 +185,31 @@ impl<'a, C: GKRFieldConfig, const D: usize> SumcheckGkrSquareHelper<'a, C, D> { ); } + #[inline] + pub(crate) fn prepare_mpi(&mut self) { + // TODO: No need to evaluate it at all world ranks, remove redundancy later. + EqPolynomial::::eq_eval_at( + self.r_mpi, + &C::ChallengeField::one(), + &mut self.sp.eq_evals_at_r_mpi0, + &mut self.sp.eq_evals_first_half, + &mut self.sp.eq_evals_second_half, + ); + } + #[inline] pub(crate) fn prepare_simd_var_vals(&mut self) { self.sp.simd_var_v_evals = self.sp.v_evals[0].unpack(); } + #[inline] + pub(crate) fn prepare_mpi_var_vals(&mut self) { + self.mpi_config.gather_vec( + &vec![self.sp.simd_var_v_evals[0]], + &mut self.sp.mpi_var_v_evals, + ); + } + #[inline] pub(crate) fn prepare_g_x_vals(&mut self) { let uni = &self.layer.uni; // univariate things like square, pow5, etc. diff --git a/sumcheck/src/sumcheck.rs b/sumcheck/src/sumcheck.rs index 13c2356e..c8116bd7 100644 --- a/sumcheck/src/sumcheck.rs +++ b/sumcheck/src/sumcheck.rs @@ -99,26 +99,28 @@ pub fn sumcheck_prove_gkr_square_layer, rz0: &[C::ChallengeField], r_simd: &[C::ChallengeField], + r_mpi: &[C::ChallengeField], transcript: &mut T, sp: &mut ProverScratchPad, -) -> (Vec, Vec) { + mpi_config: &MPIConfig, +) -> ( + Vec, + Vec, + Vec, +) { const D: usize = 7; - let mut helper = SumcheckGkrSquareHelper::new(layer, rz0, r_simd, sp); + let mut helper = + SumcheckGkrSquareHelper::::new(layer, rz0, r_simd, r_mpi, sp, mpi_config); helper.prepare_simd(); + helper.prepare_mpi(); helper.prepare_g_x_vals(); // x-variable sumcheck rounds for i_var in 0..layer.input_var_num { - let evals: [C::ChallengeField; D] = helper.poly_evals_at_x(i_var); - - for deg in 0..D { - transcript.append_field_element(&evals[deg]); - } - let r = transcript.generate_challenge_field_element(); - + let evals = helper.poly_evals_at_x(i_var); + let r = transcript_io::(mpi_config, &evals, transcript); log::trace!("x i_var={} evals: {:?} r: {:?}", i_var, evals, r); - helper.receive_x_challenge(i_var, r); } @@ -128,19 +130,20 @@ pub fn sumcheck_prove_gkr_square_layer(mpi_config, &evals, transcript); log::trace!("SIMD i_var={} evals: {:?} r: {:?}", i_var, evals, r); - helper.receive_simd_challenge(i_var, r); } + helper.prepare_mpi_var_vals(); + for i_var in 0..mpi_config.world_size().trailing_zeros() as usize { + let evals = helper.poly_evals_at_mpi(i_var); + let r = transcript_io::(mpi_config, &evals, transcript); + helper.receive_mpi_challenge(i_var, r); + } + log::trace!("vx claim: {:?}", helper.vx_claim()); transcript.append_field_element(&helper.vx_claim()); - (helper.rx, helper.r_simd_var) + (helper.rx, helper.r_simd_var, helper.r_mpi_var) } diff --git a/sumcheck/src/utils.rs b/sumcheck/src/utils.rs index 37398a6d..6736d5ea 100644 --- a/sumcheck/src/utils.rs +++ b/sumcheck/src/utils.rs @@ -30,7 +30,11 @@ where F: Field, T: Transcript, { - assert!(ps.len() == 3 || ps.len() == 4); // 3 for x, y; 4 for simd var + // 3 for x, y; 4 for simd var; 7 for pow5, 9 for pow7 + assert!( + ps.len() == 3 || ps.len() == 4 || ps.len() == 7 || ps.len() == 9, + "Unexpected polynomial size" + ); for p in ps { transcript.append_field_element(p); } diff --git a/sumcheck/src/verifier_helper.rs b/sumcheck/src/verifier_helper.rs index 8ccc9ee4..1a4ba0ab 100644 --- a/sumcheck/src/verifier_helper.rs +++ b/sumcheck/src/verifier_helper.rs @@ -154,7 +154,7 @@ impl GKRVerifierHelper { * C::challenge_mul_circuit_field(&sp.eq_evals_at_rx[gate.i_ids[0]], &gate.coef); } } - v * sp.eq_r_simd_r_simd_xy + v * sp.eq_r_simd_r_simd_xy * sp.eq_r_mpi_r_mpi_xy } #[inline(always)] @@ -170,7 +170,7 @@ impl GKRVerifierHelper { * C::challenge_mul_circuit_field(&sp.eq_evals_at_rx[gate.i_ids[0]], &gate.coef); } } - v * sp.eq_r_simd_r_simd_xy + v * sp.eq_r_simd_r_simd_xy * sp.eq_r_mpi_r_mpi_xy } #[inline(always)] From 18ba6c3cd53e1eb2847a936606cb9dbd17dbd164 Mon Sep 17 00:00:00 2001 From: enpsi Date: Fri, 15 Nov 2024 16:08:04 +0100 Subject: [PATCH 06/14] chore: update gkr2 correctness test (M31 only) --- gkr/src/tests/gkr_correctness.rs | 35 +++++++++++++++++++------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/gkr/src/tests/gkr_correctness.rs b/gkr/src/tests/gkr_correctness.rs index 73660e34..e5308a0c 100644 --- a/gkr/src/tests/gkr_correctness.rs +++ b/gkr/src/tests/gkr_correctness.rs @@ -368,7 +368,7 @@ pub fn gkr_square_test_circuit() -> Circuit { } #[test] -fn gkr_square_correctness() { +fn gkr_square_correctness_test() { declare_gkr_config!( GkrConfigType, FieldType::M31, @@ -377,18 +377,23 @@ fn gkr_square_correctness() { ); env_logger::init(); type GkrFieldConfigType = ::FieldConfig; + let mpi_config = MPIConfig::new(); + let config = Config::::new(GKRScheme::GkrSquare, mpi_config.clone()); let mut circuit = gkr_square_test_circuit::(); // Set input layers with N_2_0 = 3, N_2_1 = 5, N_2_2 = 7, // and N_2_3 varying from 0 to 15 - let final_vals = (0..16).map(|x| x.into()).collect::>(); + let mut final_vals = (0..16).map(|x| x.into()).collect::>(); // Add variety for MPI participants + final_vals[0] += ::CircuitField::from( + config.mpi_config.world_rank as u32, + ); let final_vals = ::SimdCircuitField::pack(&final_vals); circuit.layers[0].input_vals = vec![2.into(), 3.into(), 5.into(), final_vals]; // Set public input PI[0] = 13 circuit.public_input = vec![13.into()]; - let config = Config::::new(GKRScheme::GkrSquare, MPIConfig::default()); do_prove_verify(config, &mut circuit); + MPIConfig::finalize(); } fn do_prove_verify(config: Config, circuit: &mut Circuit) { @@ -405,15 +410,17 @@ fn do_prove_verify(config: Config, circuit: &mut Circuit Date: Fri, 15 Nov 2024 16:46:35 +0100 Subject: [PATCH 07/14] wip: GF2 field --- sumcheck/src/scratch_pad.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/sumcheck/src/scratch_pad.rs b/sumcheck/src/scratch_pad.rs index 569b3c60..cabdeae1 100644 --- a/sumcheck/src/scratch_pad.rs +++ b/sumcheck/src/scratch_pad.rs @@ -147,7 +147,20 @@ impl VerifierScratchPad { } let deg6_eval_at = if C::FIELD_TYPE == FieldType::GF2 { - panic!("GF2 not supported yet"); + // TODO: Does this correctly define Lagrange poly for GF2? + [ + C::ChallengeField::ZERO, + C::ChallengeField::ONE, + C::ChallengeField::X, + C::ChallengeField::X.mul_by_x(), + C::ChallengeField::X.mul_by_x().mul_by_x(), + C::ChallengeField::X.mul_by_x().mul_by_x().mul_by_x(), + C::ChallengeField::X + .mul_by_x() + .mul_by_x() + .mul_by_x() + .mul_by_x(), + ] } else { [ C::ChallengeField::ZERO, From 781b0cba96df49bc8e2e5a956dce93c763fff888 Mon Sep 17 00:00:00 2001 From: enpsi Date: Fri, 15 Nov 2024 16:51:53 +0100 Subject: [PATCH 08/14] clippy --- gkr/src/prover/gkr_square.rs | 1 + gkr/src/tests/gkr_correctness.rs | 2 +- sumcheck/src/prover_helper/sumcheck_gkr_square.rs | 2 +- sumcheck/src/sumcheck.rs | 1 + sumcheck/src/verifier_helper.rs | 1 + 5 files changed, 5 insertions(+), 2 deletions(-) diff --git a/gkr/src/prover/gkr_square.rs b/gkr/src/prover/gkr_square.rs index 279aee80..24f35fc5 100644 --- a/gkr/src/prover/gkr_square.rs +++ b/gkr/src/prover/gkr_square.rs @@ -10,6 +10,7 @@ use polynomials::MultiLinearPoly; use sumcheck::{sumcheck_prove_gkr_square_layer, ProverScratchPad}; use transcript::Transcript; +#[allow(clippy::type_complexity)] pub fn gkr_square_prove>( circuit: &Circuit, sp: &mut ProverScratchPad, diff --git a/gkr/src/tests/gkr_correctness.rs b/gkr/src/tests/gkr_correctness.rs index e5308a0c..79949525 100644 --- a/gkr/src/tests/gkr_correctness.rs +++ b/gkr/src/tests/gkr_correctness.rs @@ -407,7 +407,7 @@ fn do_prove_verify(config: Config, circuit: &mut Circuit SumcheckGkrSquareHelper<'a, C, D> { let mut evals = self.mpi_helper.gkr2_poly_eval_at::( var_idx, &self.sp.eq_evals_at_r_mpi0, - &mut self.sp.mpi_var_v_evals, + &self.sp.mpi_var_v_evals, self.sp.hg_evals_1[0], self.sp.hg_evals_5[0], ); diff --git a/sumcheck/src/sumcheck.rs b/sumcheck/src/sumcheck.rs index c8116bd7..5baabbfb 100644 --- a/sumcheck/src/sumcheck.rs +++ b/sumcheck/src/sumcheck.rs @@ -95,6 +95,7 @@ pub fn sumcheck_prove_gkr_layer>( layer: &CircuitLayer, rz0: &[C::ChallengeField], diff --git a/sumcheck/src/verifier_helper.rs b/sumcheck/src/verifier_helper.rs index 1a4ba0ab..e6ddcc72 100644 --- a/sumcheck/src/verifier_helper.rs +++ b/sumcheck/src/verifier_helper.rs @@ -262,6 +262,7 @@ impl GKRVerifierHelper { } #[inline(always)] + #[allow(clippy::needless_range_loop)] fn lag_eval( vals: &[C::ChallengeField], x: C::ChallengeField, From c6fd57e42dcb765c1dd3d1a71dc76459629ec6bf Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 24 Feb 2025 19:53:32 -0600 Subject: [PATCH 09/14] fmt --- gkr/src/prover/linear_gkr.rs | 5 +---- gkr/src/verifier.rs | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/gkr/src/prover/linear_gkr.rs b/gkr/src/prover/linear_gkr.rs index d83ffe84..31879a31 100644 --- a/gkr/src/prover/linear_gkr.rs +++ b/gkr/src/prover/linear_gkr.rs @@ -111,11 +111,8 @@ impl Prover { c.fill_rnd_coefs(&mut transcript); c.evaluate(); - let mut claimed_v = ::ChallengeField::default(); - let rx; + let (claimed_v, rx, rsimd, rmpi); let mut ry = None; - let mut rsimd = vec![]; - let mut rmpi = vec![]; let gkr_prove_timer = Timer::new("gkr prove", self.config.mpi_config.is_root()); if self.config.gkr_scheme == GKRScheme::GkrSquare { diff --git a/gkr/src/verifier.rs b/gkr/src/verifier.rs index 4f6c39ad..f789abca 100644 --- a/gkr/src/verifier.rs +++ b/gkr/src/verifier.rs @@ -340,7 +340,7 @@ impl Verifier { if let Some(rz1) = rz1 { transcript_verifier_sync(&mut transcript, &self.config.mpi_config); - verified &= self.get_pcs_opening_from_proof_and_verify( + verified &= self.get_pcs_opening_from_proof_and_verify( pcs_params, pcs_verification_key, &commitment, From 85bace6268b21bebdd46e0b9acc36c7b3dce8980 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 24 Feb 2025 20:05:21 -0600 Subject: [PATCH 10/14] separate testing file to avoid mpi issue --- gkr/src/tests/gkr_correctness.rs | 180 --------------------------- gkr/src/tests/gkr_square.rs | 207 +++++++++++++++++++++++++++++++ 2 files changed, 207 insertions(+), 180 deletions(-) create mode 100644 gkr/src/tests/gkr_square.rs diff --git a/gkr/src/tests/gkr_correctness.rs b/gkr/src/tests/gkr_correctness.rs index bd1a8091..7e78f293 100644 --- a/gkr/src/tests/gkr_correctness.rs +++ b/gkr/src/tests/gkr_correctness.rs @@ -290,183 +290,3 @@ fn test_gkr_correctness_helper(config: &Config, write_proof println!("============== end ==============="); } } - -/// A simple GKR2 test circuit: -/// ```text -/// N_0_0 N_0_1 Layer 0 (Output) -/// x11 / \ / | \ -/// N_1_0 N_1_1 N_1_2 N_1_3 Layer 1 -/// | | / | | \ -/// Pow5| | / | | PI[0] -/// N_2_0 N_2_1 N_2_2 N_2_3 Layer 2 (Input) -/// ``` -/// (Unmarked lines are `+` gates with coeff 1) -pub fn gkr_square_test_circuit() -> Circuit { - let mut circuit = Circuit::default(); - - // Layer 1 - let mut l1 = CircuitLayer { - input_var_num: 2, - output_var_num: 2, - ..Default::default() - }; - // N_1_3 += PI[0] (public input) - l1.const_.push(GateConst { - i_ids: [], - o_id: 3, - coef: C::CircuitField::from(1), - coef_type: CoefType::PublicInput(0), - gate_type: 0, - }); - // N_1_0 += (N_2_0)^5 - l1.uni.push(GateUni { - i_ids: [0], - o_id: 0, - coef: C::CircuitField::from(1), - coef_type: CoefType::Constant, - gate_type: 12345, - }); - - // N_1_1 += N_2_1 - l1.uni.push(GateUni { - i_ids: [1], - o_id: 1, - coef: C::CircuitField::from(1), - coef_type: CoefType::Constant, - gate_type: 12346, - }); - // N_1_2 += N_2_1 - l1.uni.push(GateUni { - i_ids: [1], - o_id: 2, - coef: C::CircuitField::from(1), - coef_type: CoefType::Constant, - gate_type: 12346, - }); - // N_1_2 += N_2_2 - l1.uni.push(GateUni { - i_ids: [2], - o_id: 2, - coef: C::CircuitField::from(1), - coef_type: CoefType::Constant, - gate_type: 12346, - }); - // N_1_3 += N_2_3 - l1.uni.push(GateUni { - i_ids: [3], - o_id: 3, - coef: C::CircuitField::from(1), - coef_type: CoefType::Constant, - gate_type: 12346, - }); - circuit.layers.push(l1); - - // Output layer - let mut output_layer = CircuitLayer { - input_var_num: 2, - output_var_num: 1, - ..Default::default() - }; - // N_0_0 += 11 * N_1_0 - output_layer.uni.push(GateUni { - i_ids: [0], - o_id: 0, - coef: C::CircuitField::from(11), - coef_type: CoefType::Constant, - gate_type: 12346, - }); - // N_0_0 += N_1_1 - output_layer.uni.push(GateUni { - i_ids: [1], - o_id: 0, - coef: C::CircuitField::from(1), - coef_type: CoefType::Constant, - gate_type: 12346, - }); - // N_0_1 += N_1_1 - output_layer.uni.push(GateUni { - i_ids: [1], - o_id: 1, - coef: C::CircuitField::from(1), - coef_type: CoefType::Constant, - gate_type: 12346, - }); - // N_0_1 += N_1_2 - output_layer.uni.push(GateUni { - i_ids: [2], - o_id: 1, - coef: C::CircuitField::from(1), - coef_type: CoefType::Constant, - gate_type: 12346, - }); - // N_0_1 += N_1_3 - output_layer.uni.push(GateUni { - i_ids: [3], - o_id: 1, - coef: C::CircuitField::from(1), - coef_type: CoefType::Constant, - gate_type: 12346, - }); - circuit.layers.push(output_layer); - - circuit.identify_rnd_coefs(); - circuit -} - -#[test] -fn gkr_square_correctness_test() { - declare_gkr_config!( - GkrConfigType, - FieldType::M31, - FiatShamirHashType::SHA256, - PolynomialCommitmentType::Raw - ); - env_logger::init(); - type GkrFieldConfigType = ::FieldConfig; - let mpi_config = MPIConfig::new(); - let config = Config::::new(GKRScheme::GkrSquare, mpi_config.clone()); - - let mut circuit = gkr_square_test_circuit::(); - // Set input layers with N_2_0 = 3, N_2_1 = 5, N_2_2 = 7, - // and N_2_3 varying from 0 to 15 - let mut final_vals = (0..16).map(|x| x.into()).collect::>(); // Add variety for MPI participants - final_vals[0] += ::CircuitField::from( - config.mpi_config.world_rank as u32, - ); - let final_vals = ::SimdCircuitField::pack(&final_vals); - circuit.layers[0].input_vals = vec![2.into(), 3.into(), 5.into(), final_vals]; - // Set public input PI[0] = 13 - circuit.public_input = vec![13.into()]; - - do_prove_verify(config, &mut circuit); - MPIConfig::finalize(); -} - -fn do_prove_verify(config: Config, circuit: &mut Circuit) { - circuit.evaluate(); - - let (pcs_params, pcs_proving_key, pcs_verification_key, mut pcs_scratch) = - expander_pcs_init_testing_only::( - circuit.log_input_size(), - &config.mpi_config, - ); - - // Prove - let mut prover = Prover::new(&config); - prover.prepare_mem(circuit); - let (claimed_v, proof) = prover.prove(circuit, &pcs_params, &pcs_proving_key, &mut pcs_scratch); - - // Verify if root process - if config.mpi_config.is_root() { - let verifier = Verifier::new(&config); - let public_input = circuit.public_input.clone(); - assert!(verifier.verify( - circuit, - &public_input, - &claimed_v, - &pcs_params, - &pcs_verification_key, - &proof - )) - } -} diff --git a/gkr/src/tests/gkr_square.rs b/gkr/src/tests/gkr_square.rs new file mode 100644 index 00000000..494f65bd --- /dev/null +++ b/gkr/src/tests/gkr_square.rs @@ -0,0 +1,207 @@ +use std::io::Write; +use std::panic::AssertUnwindSafe; +use std::time::Instant; +use std::{fs, panic}; + +use arith::{Field, FieldSerde, SimdField}; +use circuit::{Circuit, CircuitLayer, CoefType, GateConst, GateUni}; +use config::{Config, FiatShamirHashType, GKRConfig, GKRScheme, PolynomialCommitmentType}; +use config_macros::declare_gkr_config; +use field_hashers::{MiMC5FiatShamirHasher, PoseidonFiatShamirHasher}; +use gf2::GF2x128; +use gkr_field_config::{BN254Config, FieldType, GF2ExtConfig, GKRFieldConfig, M31ExtConfig}; +use halo2curves::bn256::G1Affine; +use mersenne31::M31x16; +use mpi_config::{root_println, MPIConfig}; +use poly_commit::{expander_pcs_init_testing_only, HyraxPCS, OrionPCSForGKR, RawExpanderGKR}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha12Rng; +use sha2::Digest; +use transcript::{BytesHashTranscript, FieldHashTranscript, Keccak256hasher, SHA256hasher}; + +use crate::{utils::*, Prover, Verifier}; + +const PCS_TESTING_SEED_U64: u64 = 114514; + + +/// A simple GKR2 test circuit: +/// ```text +/// N_0_0 N_0_1 Layer 0 (Output) +/// x11 / \ / | \ +/// N_1_0 N_1_1 N_1_2 N_1_3 Layer 1 +/// | | / | | \ +/// Pow5| | / | | PI[0] +/// N_2_0 N_2_1 N_2_2 N_2_3 Layer 2 (Input) +/// ``` +/// (Unmarked lines are `+` gates with coeff 1) +pub fn gkr_square_test_circuit() -> Circuit { + let mut circuit = Circuit::default(); + + // Layer 1 + let mut l1 = CircuitLayer { + input_var_num: 2, + output_var_num: 2, + ..Default::default() + }; + // N_1_3 += PI[0] (public input) + l1.const_.push(GateConst { + i_ids: [], + o_id: 3, + coef: C::CircuitField::from(1), + coef_type: CoefType::PublicInput(0), + gate_type: 0, + }); + // N_1_0 += (N_2_0)^5 + l1.uni.push(GateUni { + i_ids: [0], + o_id: 0, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12345, + }); + + // N_1_1 += N_2_1 + l1.uni.push(GateUni { + i_ids: [1], + o_id: 1, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + // N_1_2 += N_2_1 + l1.uni.push(GateUni { + i_ids: [1], + o_id: 2, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + // N_1_2 += N_2_2 + l1.uni.push(GateUni { + i_ids: [2], + o_id: 2, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + // N_1_3 += N_2_3 + l1.uni.push(GateUni { + i_ids: [3], + o_id: 3, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + circuit.layers.push(l1); + + // Output layer + let mut output_layer = CircuitLayer { + input_var_num: 2, + output_var_num: 1, + ..Default::default() + }; + // N_0_0 += 11 * N_1_0 + output_layer.uni.push(GateUni { + i_ids: [0], + o_id: 0, + coef: C::CircuitField::from(11), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + // N_0_0 += N_1_1 + output_layer.uni.push(GateUni { + i_ids: [1], + o_id: 0, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + // N_0_1 += N_1_1 + output_layer.uni.push(GateUni { + i_ids: [1], + o_id: 1, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + // N_0_1 += N_1_2 + output_layer.uni.push(GateUni { + i_ids: [2], + o_id: 1, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + // N_0_1 += N_1_3 + output_layer.uni.push(GateUni { + i_ids: [3], + o_id: 1, + coef: C::CircuitField::from(1), + coef_type: CoefType::Constant, + gate_type: 12346, + }); + circuit.layers.push(output_layer); + + circuit.identify_rnd_coefs(); + circuit +} + +#[test] +fn gkr_square_correctness_test() { + declare_gkr_config!( + GkrConfigType, + FieldType::M31, + FiatShamirHashType::SHA256, + PolynomialCommitmentType::Raw + ); + env_logger::init(); + type GkrFieldConfigType = ::FieldConfig; + let mpi_config = MPIConfig::new(); + let config = Config::::new(GKRScheme::GkrSquare, mpi_config.clone()); + + let mut circuit = gkr_square_test_circuit::(); + // Set input layers with N_2_0 = 3, N_2_1 = 5, N_2_2 = 7, + // and N_2_3 varying from 0 to 15 + let mut final_vals = (0..16).map(|x| x.into()).collect::>(); // Add variety for MPI participants + final_vals[0] += ::CircuitField::from( + config.mpi_config.world_rank as u32, + ); + let final_vals = ::SimdCircuitField::pack(&final_vals); + circuit.layers[0].input_vals = vec![2.into(), 3.into(), 5.into(), final_vals]; + // Set public input PI[0] = 13 + circuit.public_input = vec![13.into()]; + + do_prove_verify(config, &mut circuit); + MPIConfig::finalize(); +} + +fn do_prove_verify(config: Config, circuit: &mut Circuit) { + circuit.evaluate(); + + let mut rng = ChaCha12Rng::seed_from_u64(PCS_TESTING_SEED_U64); + let (pcs_params, pcs_proving_key, pcs_verification_key, mut pcs_scratch) = + expander_pcs_init_testing_only::( + circuit.log_input_size(), + &config.mpi_config, + &mut rng, + ); + + // Prove + let mut prover = Prover::new(&config); + prover.prepare_mem(circuit); + let (claimed_v, proof) = prover.prove(circuit, &pcs_params, &pcs_proving_key, &mut pcs_scratch); + + // Verify if root process + if config.mpi_config.is_root() { + let verifier = Verifier::new(&config); + let public_input = circuit.public_input.clone(); + assert!(verifier.verify( + circuit, + &public_input, + &claimed_v, + &pcs_params, + &pcs_verification_key, + &proof + )) + } +} From 974e3b76cd581819d8fc88fd89461c2ab7a55961 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 24 Feb 2025 20:06:56 -0600 Subject: [PATCH 11/14] clippy auto fix --- gkr/src/tests/gkr_correctness.rs | 4 ++-- poly_commit/src/orion/base_field_tests.rs | 2 +- poly_commit/src/orion/simd_field_tests.rs | 2 +- poly_commit/src/orion/utils.rs | 2 +- poly_commit/tests/test_hyrax.rs | 4 ++-- poly_commit/tests/test_orion.rs | 4 ++-- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/gkr/src/tests/gkr_correctness.rs b/gkr/src/tests/gkr_correctness.rs index 7e78f293..5e68888d 100644 --- a/gkr/src/tests/gkr_correctness.rs +++ b/gkr/src/tests/gkr_correctness.rs @@ -3,8 +3,8 @@ use std::panic::AssertUnwindSafe; use std::time::Instant; use std::{fs, panic}; -use arith::{Field, FieldSerde, SimdField}; -use circuit::{Circuit, CircuitLayer, CoefType, GateConst, GateUni}; +use arith::{Field, FieldSerde}; +use circuit::Circuit; use config::{Config, FiatShamirHashType, GKRConfig, GKRScheme, PolynomialCommitmentType}; use config_macros::declare_gkr_config; use field_hashers::{MiMC5FiatShamirHasher, PoseidonFiatShamirHasher}; diff --git a/poly_commit/src/orion/base_field_tests.rs b/poly_commit/src/orion/base_field_tests.rs index 80d453f2..e09b7845 100644 --- a/poly_commit/src/orion/base_field_tests.rs +++ b/poly_commit/src/orion/base_field_tests.rs @@ -23,7 +23,7 @@ where let mut interleaved_codewords: Vec<_> = poly .coeffs .chunks(msg_size) - .flat_map(|msg| orion_srs.code_instance.encode(&msg).unwrap()) + .flat_map(|msg| orion_srs.code_instance.encode(msg).unwrap()) .collect(); let mut scratch = vec![F::ZERO; row_num * orion_srs.codeword_len()]; diff --git a/poly_commit/src/orion/simd_field_tests.rs b/poly_commit/src/orion/simd_field_tests.rs index 5f30fa2e..5fcabd1d 100644 --- a/poly_commit/src/orion/simd_field_tests.rs +++ b/poly_commit/src/orion/simd_field_tests.rs @@ -29,7 +29,7 @@ where let mut interleaved_codewords: Vec<_> = poly .coeffs .chunks(msg_size) - .flat_map(|msg| orion_srs.code_instance.encode(&msg).unwrap()) + .flat_map(|msg| orion_srs.code_instance.encode(msg).unwrap()) .collect(); let mut scratch = vec![SimdF::ZERO; row_num * orion_srs.codeword_len()]; diff --git a/poly_commit/src/orion/utils.rs b/poly_commit/src/orion/utils.rs index 025a4d88..c571e745 100644 --- a/poly_commit/src/orion/utils.rs +++ b/poly_commit/src/orion/utils.rs @@ -503,7 +503,7 @@ mod tests { let mut table = SubsetSumLUTs::new(8, 1); table.build(&weights); - let actual_lut_inner_prod = table.lookup_and_sum(&vec![simd_bases]); + let actual_lut_inner_prod = table.lookup_and_sum(&[simd_bases]); assert_eq!(expected_simd_inner_prod, actual_lut_inner_prod) } diff --git a/poly_commit/tests/test_hyrax.rs b/poly_commit/tests/test_hyrax.rs index f3965395..6fbb6058 100644 --- a/poly_commit/tests/test_hyrax.rs +++ b/poly_commit/tests/test_hyrax.rs @@ -72,10 +72,10 @@ fn test_hyrax_for_expander_gkr_generics(mpi_config_ref: &MPIConfig, total_num_va HyraxPCS>, >( &num_vars_in_each_poly, - &mpi_config_ref, + mpi_config_ref, &mut transcript, &local_poly, - &vec![challenge_point], + &[challenge_point], ); } diff --git a/poly_commit/tests/test_orion.rs b/poly_commit/tests/test_orion.rs index 8e970c68..93273074 100644 --- a/poly_commit/tests/test_orion.rs +++ b/poly_commit/tests/test_orion.rs @@ -151,10 +151,10 @@ fn test_orion_for_expander_gkr_generics( OrionSIMDFieldPCS, >( &num_vars_in_each_poly, - &mpi_config_ref, + mpi_config_ref, &mut transcript, &local_poly, - &vec![challenge_point], + &[challenge_point], ); } From c12cfbd34f890590fbbcb1c2c23453fbbb12f8aa Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 24 Feb 2025 20:30:19 -0600 Subject: [PATCH 12/14] refactor expander-dedicated mle --- Cargo.lock | 2 + arith/polynomials/Cargo.toml | 2 + arith/polynomials/src/mle.rs | 134 +++++++++++++++++- config/gkr_field_config/src/traits.rs | 33 ----- crosslayer_prototype/src/gkr.rs | 8 +- gkr/src/prover/gkr.rs | 31 ++-- gkr/src/prover/gkr_square.rs | 31 ++-- poly_commit/src/orion/simd_field_agg_tests.rs | 15 +- poly_commit/src/raw.rs | 41 ++---- poly_commit/tests/common.rs | 10 +- 10 files changed, 188 insertions(+), 119 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 68513bb9..0becd3dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1865,7 +1865,9 @@ dependencies = [ "arith", "ark-std", "criterion", + "gkr_field_config", "halo2curves", + "mpi_config", ] [[package]] diff --git a/arith/polynomials/Cargo.toml b/arith/polynomials/Cargo.toml index 85bbde37..d6fe38c8 100644 --- a/arith/polynomials/Cargo.toml +++ b/arith/polynomials/Cargo.toml @@ -5,6 +5,8 @@ edition = "2021" [dependencies] arith = { path = "../" } +gkr_field_config = { path = "../../config/gkr_field_config" } +mpi_config = { path = "../../config/mpi_config" } ark-std.workspace = true criterion.workspace = true diff --git a/arith/polynomials/src/mle.rs b/arith/polynomials/src/mle.rs index 991a163e..9bf9c55c 100644 --- a/arith/polynomials/src/mle.rs +++ b/arith/polynomials/src/mle.rs @@ -1,7 +1,13 @@ -use std::ops::{Index, IndexMut, Mul}; +use std::{ + cmp, + marker::PhantomData, + ops::{Index, IndexMut, Mul}, +}; -use arith::Field; +use arith::{Field, SimdField}; use ark_std::{log2, rand::RngCore}; +use gkr_field_config::GKRFieldConfig; +use mpi_config::MPIConfig; use crate::{EqPolynomial, MultilinearExtension, MutableMultilinearExtension}; @@ -205,3 +211,127 @@ impl MutableMultilinearExtension for MultiLinearPoly { } } } + +#[derive(Debug, Clone, Default)] +pub struct MultiLinearPolyExpander { + _config: PhantomData, +} + +/// Some dedicated mle implementations for GKRFieldConfig +/// Take into consideration the simd challenge and the mpi challenge +/// +/// This is more efficient than the generic implementation by avoiding +/// unnecessary conversions between field types +impl MultiLinearPolyExpander { + pub fn new() -> Self { + Self { + _config: PhantomData, + } + } + + #[inline] + pub fn eval_circuit_vals_at_challenge( + evals: &[C::SimdCircuitField], + x: &[C::ChallengeField], + scratch: &mut [C::Field], + ) -> C::Field { + assert_eq!(1 << x.len(), evals.len()); + assert!(scratch.len() >= evals.len()); + + if x.is_empty() { + C::simd_circuit_field_into_field(&evals[0]) + } else { + for i in 0..(evals.len() >> 1) { + scratch[i] = C::field_add_simd_circuit_field( + &C::simd_circuit_field_mul_challenge_field( + &(evals[i * 2 + 1] - evals[i * 2]), + &x[0], + ), + &evals[i * 2], + ); + } + + let mut cur_eval_size = evals.len() >> 2; + for r in x.iter().skip(1) { + for i in 0..cur_eval_size { + scratch[i] = scratch[i * 2] + (scratch[i * 2 + 1] - scratch[i * 2]).scale(r); + } + cur_eval_size >>= 1; + } + scratch[0] + } + } + + /// This assumes each mpi core hold their own evals, and collectively + /// compute the global evaluation. + /// Mostly used by the prover run with `mpiexec` + #[inline] + pub fn collectively_eval_circuit_vals_at_expander_challenge( + local_evals: &[C::SimdCircuitField], + x: &[C::ChallengeField], + x_simd: &[C::ChallengeField], + x_mpi: &[C::ChallengeField], + scratch_field: &mut [C::Field], + scratch_challenge_field: &mut [C::ChallengeField], + mpi_config: &MPIConfig, + ) -> C::ChallengeField { + assert!(scratch_challenge_field.len() >= 1 << cmp::max(x_simd.len(), x_mpi.len())); + + let local_simd = Self::eval_circuit_vals_at_challenge(local_evals, x, scratch_field); + let local_simd_unpacked = local_simd.unpack(); + let local_v = MultiLinearPoly::evaluate_with_buffer( + &local_simd_unpacked, + x_simd, + scratch_challenge_field, + ); + + let global_v = if mpi_config.is_root() { + let mut claimed_v_gathering_buffer = + vec![C::ChallengeField::zero(); mpi_config.world_size()]; + mpi_config.gather_vec(&vec![local_v], &mut claimed_v_gathering_buffer); + MultiLinearPoly::evaluate_with_buffer( + &claimed_v_gathering_buffer, + &x_mpi, + scratch_challenge_field, + ) + } else { + mpi_config.gather_vec(&vec![local_v], &mut vec![]); + C::ChallengeField::zero() + }; + + global_v + } + + /// This assumes only a single core holds all the evals, and evaluate it locally + /// mostly used by the verifier + #[inline] + pub fn single_core_eval_circuit_vals_at_expander_challenge( + global_vals: &[C::SimdCircuitField], + x: &[C::ChallengeField], + x_simd: &[C::ChallengeField], + x_mpi: &[C::ChallengeField], + ) -> C::ChallengeField { + let local_poly_size = global_vals.len() >> x_mpi.len(); + assert_eq!(local_poly_size, 1 << x.len()); + + let mut scratch_field = vec![C::Field::default(); local_poly_size]; + let mut scratch_challenge_field = + vec![C::ChallengeField::default(); 1 << cmp::max(x_simd.len(), x_mpi.len())]; + let local_evals = global_vals + .chunks(local_poly_size) + .map(|local_vals| { + let local_simd = + Self::eval_circuit_vals_at_challenge(local_vals, x, &mut scratch_field); + let local_simd_unpacked = local_simd.unpack(); + MultiLinearPoly::evaluate_with_buffer( + &local_simd_unpacked, + x_simd, + &mut scratch_challenge_field, + ) + }) + .collect::>(); + + let mut scratch = vec![C::ChallengeField::default(); local_evals.len()]; + MultiLinearPoly::evaluate_with_buffer(&local_evals, x_mpi, &mut scratch) + } +} diff --git a/config/gkr_field_config/src/traits.rs b/config/gkr_field_config/src/traits.rs index ca5d29b0..96f0313f 100644 --- a/config/gkr_field_config/src/traits.rs +++ b/config/gkr_field_config/src/traits.rs @@ -73,37 +73,4 @@ pub trait GKRFieldConfig: Default + Debug + Clone + Send + Sync + 'static { fn get_field_pack_size() -> usize { Self::SimdCircuitField::PACK_SIZE } - - /// Evaluate the circuit values at the challenge - #[inline] - fn eval_circuit_vals_at_challenge( - evals: &[Self::SimdCircuitField], - x: &[Self::ChallengeField], - scratch: &mut [Self::Field], - ) -> Self::Field { - assert_eq!(1 << x.len(), evals.len()); - - if x.is_empty() { - Self::simd_circuit_field_into_field(&evals[0]) - } else { - for i in 0..(evals.len() >> 1) { - scratch[i] = Self::field_add_simd_circuit_field( - &Self::simd_circuit_field_mul_challenge_field( - &(evals[i * 2 + 1] - evals[i * 2]), - &x[0], - ), - &evals[i * 2], - ); - } - - let mut cur_eval_size = evals.len() >> 2; - for r in x.iter().skip(1) { - for i in 0..cur_eval_size { - scratch[i] = scratch[i * 2] + (scratch[i * 2 + 1] - scratch[i * 2]).scale(r); - } - cur_eval_size >>= 1; - } - scratch[0] - } - } } diff --git a/crosslayer_prototype/src/gkr.rs b/crosslayer_prototype/src/gkr.rs index 5a7db386..eaf443a4 100644 --- a/crosslayer_prototype/src/gkr.rs +++ b/crosslayer_prototype/src/gkr.rs @@ -1,6 +1,6 @@ use arith::{Field, SimdField}; use gkr_field_config::GKRFieldConfig; -use polynomials::MultiLinearPoly; +use polynomials::{MultiLinearPoly, MultiLinearPolyExpander}; use transcript::Transcript; use crate::sumcheck::{sumcheck_prove_gather_layer, sumcheck_prove_scatter_layer}; @@ -18,7 +18,11 @@ pub fn prove_gkr>( .generate_challenge_field_elements(final_layer_vals.len().trailing_zeros() as usize); let r_simd = transcript .generate_challenge_field_elements(C::get_field_pack_size().trailing_zeros() as usize); - let output_claim = C::eval_circuit_vals_at_challenge(final_layer_vals, &rz0, &mut sp.v_evals); + let output_claim = MultiLinearPolyExpander::::eval_circuit_vals_at_challenge( + final_layer_vals, + &rz0, + &mut sp.v_evals, + ); let output_claim = MultiLinearPoly::::evaluate_with_buffer( &output_claim.unpack(), &r_simd, diff --git a/gkr/src/prover/gkr.rs b/gkr/src/prover/gkr.rs index 62159604..e8f2d8b3 100644 --- a/gkr/src/prover/gkr.rs +++ b/gkr/src/prover/gkr.rs @@ -4,7 +4,7 @@ use arith::{Field, SimdField}; use circuit::Circuit; use gkr_field_config::GKRFieldConfig; use mpi_config::MPIConfig; -use polynomials::MultiLinearPoly; +use polynomials::{MultiLinearPoly, MultiLinearPolyExpander}; use sumcheck::{sumcheck_prove_gkr_layer, ProverScratchPad}; use transcript::Transcript; use utils::timer::Timer; @@ -44,27 +44,16 @@ pub fn gkr_prove>( let mut alpha = None; let output_vals = &circuit.layers.last().unwrap().output_vals; - - let claimed_v_simd = C::eval_circuit_vals_at_challenge(output_vals, &rz0, &mut sp.hg_evals); - let claimed_v_local = MultiLinearPoly::::evaluate_with_buffer( - &claimed_v_simd.unpack(), - &r_simd, - &mut sp.eq_evals_at_r_simd0, - ); - - let claimed_v = if mpi_config.is_root() { - let mut claimed_v_gathering_buffer = - vec![C::ChallengeField::zero(); mpi_config.world_size()]; - mpi_config.gather_vec(&vec![claimed_v_local], &mut claimed_v_gathering_buffer); - MultiLinearPoly::evaluate_with_buffer( - &claimed_v_gathering_buffer, + let claimed_v = + MultiLinearPolyExpander::::collectively_eval_circuit_vals_at_expander_challenge( + output_vals, + &rz0, + &r_simd, &r_mpi, - &mut sp.eq_evals_at_r_mpi0, - ) - } else { - mpi_config.gather_vec(&vec![claimed_v_local], &mut vec![]); - C::ChallengeField::zero() - }; + &mut sp.hg_evals, + &mut sp.eq_evals_first_half, // confusing name here.. + mpi_config, + ); for i in (0..layer_num).rev() { let timer = Timer::new( diff --git a/gkr/src/prover/gkr_square.rs b/gkr/src/prover/gkr_square.rs index 24f35fc5..210d07d7 100644 --- a/gkr/src/prover/gkr_square.rs +++ b/gkr/src/prover/gkr_square.rs @@ -6,7 +6,7 @@ use ark_std::{end_timer, start_timer}; use circuit::Circuit; use gkr_field_config::GKRFieldConfig; use mpi_config::MPIConfig; -use polynomials::MultiLinearPoly; +use polynomials::MultiLinearPolyExpander; use sumcheck::{sumcheck_prove_gkr_square_layer, ProverScratchPad}; use transcript::Transcript; @@ -42,26 +42,17 @@ pub fn gkr_square_prove>( } let output_vals = &circuit.layers.last().unwrap().output_vals; - let claimed_v_simd = C::eval_circuit_vals_at_challenge(output_vals, &rz0, &mut sp.hg_evals); - let claimed_v_local = MultiLinearPoly::::evaluate_with_buffer( - &claimed_v_simd.unpack(), - &r_simd, - &mut sp.eq_evals_at_r_simd0, - ); - - let claimed_v = if mpi_config.is_root() { - let mut claimed_v_gathering_buffer = - vec![C::ChallengeField::zero(); mpi_config.world_size()]; - mpi_config.gather_vec(&vec![claimed_v_local], &mut claimed_v_gathering_buffer); - MultiLinearPoly::evaluate_with_buffer( - &claimed_v_gathering_buffer, + let claimed_v = + MultiLinearPolyExpander::::collectively_eval_circuit_vals_at_expander_challenge( + output_vals, + &rz0, + &r_simd, &r_mpi, - &mut sp.eq_evals_at_r_mpi0, - ) - } else { - mpi_config.gather_vec(&vec![claimed_v_local], &mut vec![]); - C::ChallengeField::zero() - }; + &mut sp.hg_evals, + &mut sp.eq_evals_first_half, // confusing name here.. + mpi_config, + ); + log::trace!("Claimed v: {:?}", claimed_v); for i in (0..layer_num).rev() { diff --git a/poly_commit/src/orion/simd_field_agg_tests.rs b/poly_commit/src/orion/simd_field_agg_tests.rs index 9cd1c2e4..72fb88cf 100644 --- a/poly_commit/src/orion/simd_field_agg_tests.rs +++ b/poly_commit/src/orion/simd_field_agg_tests.rs @@ -7,7 +7,7 @@ use gf2_128::GF2_128; use gkr_field_config::{GF2ExtConfig, GKRFieldConfig, M31ExtConfig}; use itertools::izip; use mersenne31::{M31Ext3, M31x16}; -use polynomials::{EqPolynomial, MultiLinearPoly}; +use polynomials::{EqPolynomial, MultiLinearPoly, MultiLinearPolyExpander}; use transcript::{BytesHashTranscript, Keccak256hasher, Transcript}; use crate::{ @@ -165,12 +165,13 @@ where let aggregated_proof = orion_proof_aggregate::(&openings, &gkr_challenge.x_mpi, &mut aggregator_transcript); - let final_expected_eval = RawExpanderGKR::::eval( - &global_poly.coeffs, - &gkr_challenge.x, - &gkr_challenge.x_simd, - &gkr_challenge.x_mpi, - ); + let final_expected_eval = + MultiLinearPolyExpander::::single_core_eval_circuit_vals_at_expander_challenge( + &global_poly.coeffs, + &gkr_challenge.x, + &gkr_challenge.x_simd, + &gkr_challenge.x_mpi, + ); assert!(orion_verify_simd_field_aggregated::( num_parties, diff --git a/poly_commit/src/raw.rs b/poly_commit/src/raw.rs index f48dd43c..4bae7b90 100644 --- a/poly_commit/src/raw.rs +++ b/poly_commit/src/raw.rs @@ -7,7 +7,7 @@ use arith::{BN254Fr, ExtensionField, Field, FieldForECC, FieldSerde, FieldSerdeR use ethnum::U256; use gkr_field_config::GKRFieldConfig; use mpi_config::MPIConfig; -use polynomials::{MultiLinearPoly, MultilinearExtension}; +use polynomials::{MultiLinearPoly, MultiLinearPolyExpander, MultilinearExtension}; use rand::RngCore; use transcript::Transcript; @@ -231,36 +231,13 @@ impl> PCSForExpanderGKR bool { assert!(mpi_config.is_root()); // Only the root will verify let ExpanderGKRChallenge:: { x, x_simd, x_mpi } = x; - Self::eval(&commitment.evals, x, x_simd, x_mpi) == v - } -} - -impl> RawExpanderGKR { - pub fn eval_local( - vals: &[C::SimdCircuitField], - x: &[C::ChallengeField], - x_simd: &[C::ChallengeField], - ) -> C::ChallengeField { - let mut scratch = vec![C::Field::default(); vals.len()]; - let y_simd = C::eval_circuit_vals_at_challenge(vals, x, &mut scratch); - let y_simd_unpacked = y_simd.unpack(); - let mut scratch = vec![C::ChallengeField::default(); y_simd_unpacked.len()]; - MultiLinearPoly::evaluate_with_buffer(&y_simd_unpacked, x_simd, &mut scratch) - } - - pub fn eval( - vals: &[C::SimdCircuitField], - x: &[C::ChallengeField], - x_simd: &[C::ChallengeField], - x_mpi: &[C::ChallengeField], - ) -> C::ChallengeField { - let local_poly_size = vals.len() >> x_mpi.len(); - let local_evals = vals - .chunks(local_poly_size) - .map(|local_vals| Self::eval_local(local_vals, x, x_simd)) - .collect::>(); - - let mut scratch = vec![C::ChallengeField::default(); local_evals.len()]; - MultiLinearPoly::evaluate_with_buffer(&local_evals, x_mpi, &mut scratch) + let v_target = + MultiLinearPolyExpander::::single_core_eval_circuit_vals_at_expander_challenge( + &commitment.evals, + x, + x_simd, + x_mpi, + ); + v == v_target } } diff --git a/poly_commit/tests/common.rs b/poly_commit/tests/common.rs index 0145657b..4f7debac 100644 --- a/poly_commit/tests/common.rs +++ b/poly_commit/tests/common.rs @@ -6,7 +6,7 @@ use poly_commit::raw::RawExpanderGKR; use poly_commit::{ ExpanderGKRChallenge, PCSForExpanderGKR, PolynomialCommitmentScheme, StructuredReferenceString, }; -use polynomials::MultilinearExtension; +use polynomials::{MultiLinearPolyExpander, MultilinearExtension}; use rand::thread_rng; use transcript::Transcript; @@ -91,7 +91,13 @@ pub fn test_pcs_for_expander_gkr< if mpi_config.is_root() { // this will always pass for RawExpanderGKR, so make sure it is correct - let v = RawExpanderGKR::::eval(&coeffs_gathered, x, x_simd, x_mpi); + let v = + MultiLinearPolyExpander::::single_core_eval_circuit_vals_at_expander_challenge( + &coeffs_gathered, + x, + x_simd, + x_mpi, + ); transcript.lock_proof(); assert!(P::verify( From ce2d439ced498e0588f4796040607e3c1d64317d Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 24 Feb 2025 20:31:10 -0600 Subject: [PATCH 13/14] clippy auto fix --- arith/polynomials/src/mle.rs | 10 +++++----- gkr/src/prover/gkr.rs | 3 +-- gkr/src/prover/gkr_square.rs | 1 - poly_commit/src/orion/simd_field_agg_tests.rs | 2 +- poly_commit/src/raw.rs | 2 +- poly_commit/tests/common.rs | 1 - 6 files changed, 8 insertions(+), 11 deletions(-) diff --git a/arith/polynomials/src/mle.rs b/arith/polynomials/src/mle.rs index 9bf9c55c..d74b05d4 100644 --- a/arith/polynomials/src/mle.rs +++ b/arith/polynomials/src/mle.rs @@ -285,21 +285,21 @@ impl MultiLinearPolyExpander { scratch_challenge_field, ); - let global_v = if mpi_config.is_root() { + + + if mpi_config.is_root() { let mut claimed_v_gathering_buffer = vec![C::ChallengeField::zero(); mpi_config.world_size()]; mpi_config.gather_vec(&vec![local_v], &mut claimed_v_gathering_buffer); MultiLinearPoly::evaluate_with_buffer( &claimed_v_gathering_buffer, - &x_mpi, + x_mpi, scratch_challenge_field, ) } else { mpi_config.gather_vec(&vec![local_v], &mut vec![]); C::ChallengeField::zero() - }; - - global_v + } } /// This assumes only a single core holds all the evals, and evaluate it locally diff --git a/gkr/src/prover/gkr.rs b/gkr/src/prover/gkr.rs index e8f2d8b3..95272572 100644 --- a/gkr/src/prover/gkr.rs +++ b/gkr/src/prover/gkr.rs @@ -1,10 +1,9 @@ //! This module implements the core GKR IOP. -use arith::{Field, SimdField}; use circuit::Circuit; use gkr_field_config::GKRFieldConfig; use mpi_config::MPIConfig; -use polynomials::{MultiLinearPoly, MultiLinearPolyExpander}; +use polynomials::MultiLinearPolyExpander; use sumcheck::{sumcheck_prove_gkr_layer, ProverScratchPad}; use transcript::Transcript; use utils::timer::Timer; diff --git a/gkr/src/prover/gkr_square.rs b/gkr/src/prover/gkr_square.rs index 210d07d7..66404d95 100644 --- a/gkr/src/prover/gkr_square.rs +++ b/gkr/src/prover/gkr_square.rs @@ -1,7 +1,6 @@ // an implementation of the GKR^2 protocol //! This module implements the core GKR^2 IOP. -use arith::{Field, SimdField}; use ark_std::{end_timer, start_timer}; use circuit::Circuit; use gkr_field_config::GKRFieldConfig; diff --git a/poly_commit/src/orion/simd_field_agg_tests.rs b/poly_commit/src/orion/simd_field_agg_tests.rs index 72fb88cf..600dac39 100644 --- a/poly_commit/src/orion/simd_field_agg_tests.rs +++ b/poly_commit/src/orion/simd_field_agg_tests.rs @@ -12,7 +12,7 @@ use transcript::{BytesHashTranscript, Keccak256hasher, Transcript}; use crate::{ orion::{simd_field_agg_impl::*, utils::*, *}, - ExpanderGKRChallenge, RawExpanderGKR, + ExpanderGKRChallenge, }; #[derive(Clone)] diff --git a/poly_commit/src/raw.rs b/poly_commit/src/raw.rs index 4bae7b90..0c837fc2 100644 --- a/poly_commit/src/raw.rs +++ b/poly_commit/src/raw.rs @@ -3,7 +3,7 @@ use crate::{ ExpanderGKRChallenge, PCSEmptyType, PCSForExpanderGKR, PolynomialCommitmentScheme, StructuredReferenceString, }; -use arith::{BN254Fr, ExtensionField, Field, FieldForECC, FieldSerde, FieldSerdeResult, SimdField}; +use arith::{BN254Fr, ExtensionField, Field, FieldForECC, FieldSerde, FieldSerdeResult}; use ethnum::U256; use gkr_field_config::GKRFieldConfig; use mpi_config::MPIConfig; diff --git a/poly_commit/tests/common.rs b/poly_commit/tests/common.rs index 4f7debac..026f7a50 100644 --- a/poly_commit/tests/common.rs +++ b/poly_commit/tests/common.rs @@ -2,7 +2,6 @@ use arith::{ExtensionField, Field}; use ark_std::test_rng; use gkr_field_config::GKRFieldConfig; use mpi_config::MPIConfig; -use poly_commit::raw::RawExpanderGKR; use poly_commit::{ ExpanderGKRChallenge, PCSForExpanderGKR, PolynomialCommitmentScheme, StructuredReferenceString, }; From 068bb816f5953783844c1d2c0d914c0a3d9708a5 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 24 Feb 2025 20:34:19 -0600 Subject: [PATCH 14/14] how comes clippy auto fix invalidates fmt... --- arith/polynomials/src/mle.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/arith/polynomials/src/mle.rs b/arith/polynomials/src/mle.rs index d74b05d4..a30bbda5 100644 --- a/arith/polynomials/src/mle.rs +++ b/arith/polynomials/src/mle.rs @@ -285,8 +285,6 @@ impl MultiLinearPolyExpander { scratch_challenge_field, ); - - if mpi_config.is_root() { let mut claimed_v_gathering_buffer = vec![C::ChallengeField::zero(); mpi_config.world_size()];