Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gkr2 #207

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions arith/polynomials/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ edition = "2021"

[dependencies]
arith = { path = "../" }
gkr_field_config = { path = "../../config/gkr_field_config" }
mpi_config = { path = "../../config/mpi_config" }

ark-std.workspace = true
criterion.workspace = true
Expand Down
132 changes: 130 additions & 2 deletions arith/polynomials/src/mle.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
use std::ops::{Index, IndexMut, Mul};
use std::{
cmp,
marker::PhantomData,
ops::{Index, IndexMut, Mul},
};

use arith::Field;
use arith::{Field, SimdField};
use ark_std::{log2, rand::RngCore};
use gkr_field_config::GKRFieldConfig;
use mpi_config::MPIConfig;

use crate::{EqPolynomial, MultilinearExtension, MutableMultilinearExtension};

Expand Down Expand Up @@ -205,3 +211,125 @@ impl<F: Field> MutableMultilinearExtension<F> for MultiLinearPoly<F> {
}
}
}

#[derive(Debug, Clone, Default)]
pub struct MultiLinearPolyExpander<C: GKRFieldConfig> {
_config: PhantomData<C>,
}

/// Some dedicated mle implementations for GKRFieldConfig
/// Take into consideration the simd challenge and the mpi challenge
///
/// This is more efficient than the generic implementation by avoiding
/// unnecessary conversions between field types
impl<C: GKRFieldConfig> MultiLinearPolyExpander<C> {
pub fn new() -> Self {
Self {
_config: PhantomData,
}
}

#[inline]
pub fn eval_circuit_vals_at_challenge(
evals: &[C::SimdCircuitField],
x: &[C::ChallengeField],
scratch: &mut [C::Field],
) -> C::Field {
assert_eq!(1 << x.len(), evals.len());
assert!(scratch.len() >= evals.len());

if x.is_empty() {
C::simd_circuit_field_into_field(&evals[0])
} else {
for i in 0..(evals.len() >> 1) {
scratch[i] = C::field_add_simd_circuit_field(
&C::simd_circuit_field_mul_challenge_field(
&(evals[i * 2 + 1] - evals[i * 2]),
&x[0],
),
&evals[i * 2],
);
}

let mut cur_eval_size = evals.len() >> 2;
for r in x.iter().skip(1) {
for i in 0..cur_eval_size {
scratch[i] = scratch[i * 2] + (scratch[i * 2 + 1] - scratch[i * 2]).scale(r);
}
cur_eval_size >>= 1;
}
scratch[0]
}
}

/// This assumes each mpi core hold their own evals, and collectively
/// compute the global evaluation.
/// Mostly used by the prover run with `mpiexec`
#[inline]
pub fn collectively_eval_circuit_vals_at_expander_challenge(
local_evals: &[C::SimdCircuitField],
x: &[C::ChallengeField],
x_simd: &[C::ChallengeField],
x_mpi: &[C::ChallengeField],
scratch_field: &mut [C::Field],
scratch_challenge_field: &mut [C::ChallengeField],
mpi_config: &MPIConfig,
) -> C::ChallengeField {
assert!(scratch_challenge_field.len() >= 1 << cmp::max(x_simd.len(), x_mpi.len()));

let local_simd = Self::eval_circuit_vals_at_challenge(local_evals, x, scratch_field);
let local_simd_unpacked = local_simd.unpack();
let local_v = MultiLinearPoly::evaluate_with_buffer(
&local_simd_unpacked,
x_simd,
scratch_challenge_field,
);

if mpi_config.is_root() {
let mut claimed_v_gathering_buffer =
vec![C::ChallengeField::zero(); mpi_config.world_size()];
mpi_config.gather_vec(&vec![local_v], &mut claimed_v_gathering_buffer);
MultiLinearPoly::evaluate_with_buffer(
&claimed_v_gathering_buffer,
x_mpi,
scratch_challenge_field,
)
} else {
mpi_config.gather_vec(&vec![local_v], &mut vec![]);
C::ChallengeField::zero()
}
}

/// This assumes only a single core holds all the evals, and evaluate it locally
/// mostly used by the verifier
#[inline]
pub fn single_core_eval_circuit_vals_at_expander_challenge(
global_vals: &[C::SimdCircuitField],
x: &[C::ChallengeField],
x_simd: &[C::ChallengeField],
x_mpi: &[C::ChallengeField],
) -> C::ChallengeField {
let local_poly_size = global_vals.len() >> x_mpi.len();
assert_eq!(local_poly_size, 1 << x.len());

let mut scratch_field = vec![C::Field::default(); local_poly_size];
let mut scratch_challenge_field =
vec![C::ChallengeField::default(); 1 << cmp::max(x_simd.len(), x_mpi.len())];
let local_evals = global_vals
.chunks(local_poly_size)
.map(|local_vals| {
let local_simd =
Self::eval_circuit_vals_at_challenge(local_vals, x, &mut scratch_field);
let local_simd_unpacked = local_simd.unpack();
MultiLinearPoly::evaluate_with_buffer(
&local_simd_unpacked,
x_simd,
&mut scratch_challenge_field,
)
})
.collect::<Vec<C::ChallengeField>>();

let mut scratch = vec![C::ChallengeField::default(); local_evals.len()];
MultiLinearPoly::evaluate_with_buffer(&local_evals, x_mpi, &mut scratch)
}
}
33 changes: 0 additions & 33 deletions config/gkr_field_config/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,37 +73,4 @@ pub trait GKRFieldConfig: Default + Debug + Clone + Send + Sync + 'static {
fn get_field_pack_size() -> usize {
Self::SimdCircuitField::PACK_SIZE
}

/// Evaluate the circuit values at the challenge
#[inline]
fn eval_circuit_vals_at_challenge(
evals: &[Self::SimdCircuitField],
x: &[Self::ChallengeField],
scratch: &mut [Self::Field],
) -> Self::Field {
assert_eq!(1 << x.len(), evals.len());

if x.is_empty() {
Self::simd_circuit_field_into_field(&evals[0])
} else {
for i in 0..(evals.len() >> 1) {
scratch[i] = Self::field_add_simd_circuit_field(
&Self::simd_circuit_field_mul_challenge_field(
&(evals[i * 2 + 1] - evals[i * 2]),
&x[0],
),
&evals[i * 2],
);
}

let mut cur_eval_size = evals.len() >> 2;
for r in x.iter().skip(1) {
for i in 0..cur_eval_size {
scratch[i] = scratch[i * 2] + (scratch[i * 2 + 1] - scratch[i * 2]).scale(r);
}
cur_eval_size >>= 1;
}
scratch[0]
}
}
}
8 changes: 6 additions & 2 deletions crosslayer_prototype/src/gkr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use arith::{Field, SimdField};
use gkr_field_config::GKRFieldConfig;
use polynomials::MultiLinearPoly;
use polynomials::{MultiLinearPoly, MultiLinearPolyExpander};
use transcript::Transcript;

use crate::sumcheck::{sumcheck_prove_gather_layer, sumcheck_prove_scatter_layer};
Expand All @@ -18,7 +18,11 @@ pub fn prove_gkr<C: GKRFieldConfig, T: Transcript<C::ChallengeField>>(
.generate_challenge_field_elements(final_layer_vals.len().trailing_zeros() as usize);
let r_simd = transcript
.generate_challenge_field_elements(C::get_field_pack_size().trailing_zeros() as usize);
let output_claim = C::eval_circuit_vals_at_challenge(final_layer_vals, &rz0, &mut sp.v_evals);
let output_claim = MultiLinearPolyExpander::<C>::eval_circuit_vals_at_challenge(
final_layer_vals,
&rz0,
&mut sp.v_evals,
);
let output_claim = MultiLinearPoly::<C::ChallengeField>::evaluate_with_buffer(
&output_claim.unpack(),
&r_simd,
Expand Down
32 changes: 10 additions & 22 deletions gkr/src/prover/gkr.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
//! This module implements the core GKR IOP.

use arith::{Field, SimdField};
use circuit::Circuit;
use gkr_field_config::GKRFieldConfig;
use mpi_config::MPIConfig;
use polynomials::MultiLinearPoly;
use polynomials::MultiLinearPolyExpander;
use sumcheck::{sumcheck_prove_gkr_layer, ProverScratchPad};
use transcript::Transcript;
use utils::timer::Timer;
Expand Down Expand Up @@ -44,27 +43,16 @@ pub fn gkr_prove<C: GKRFieldConfig, T: Transcript<C::ChallengeField>>(
let mut alpha = None;

let output_vals = &circuit.layers.last().unwrap().output_vals;

let claimed_v_simd = C::eval_circuit_vals_at_challenge(output_vals, &rz0, &mut sp.hg_evals);
let claimed_v_local = MultiLinearPoly::<C::ChallengeField>::evaluate_with_buffer(
&claimed_v_simd.unpack(),
&r_simd,
&mut sp.eq_evals_at_r_simd0,
);

let claimed_v = if mpi_config.is_root() {
let mut claimed_v_gathering_buffer =
vec![C::ChallengeField::zero(); mpi_config.world_size()];
mpi_config.gather_vec(&vec![claimed_v_local], &mut claimed_v_gathering_buffer);
MultiLinearPoly::evaluate_with_buffer(
&claimed_v_gathering_buffer,
let claimed_v =
MultiLinearPolyExpander::<C>::collectively_eval_circuit_vals_at_expander_challenge(
output_vals,
&rz0,
&r_simd,
&r_mpi,
&mut sp.eq_evals_at_r_mpi0,
)
} else {
mpi_config.gather_vec(&vec![claimed_v_local], &mut vec![]);
C::ChallengeField::zero()
};
&mut sp.hg_evals,
&mut sp.eq_evals_first_half, // confusing name here..
mpi_config,
);

for i in (0..layer_num).rev() {
let timer = Timer::new(
Expand Down
49 changes: 44 additions & 5 deletions gkr/src/prover/gkr_square.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,23 @@
use ark_std::{end_timer, start_timer};
use circuit::Circuit;
use gkr_field_config::GKRFieldConfig;
use mpi_config::MPIConfig;
use polynomials::MultiLinearPolyExpander;
use sumcheck::{sumcheck_prove_gkr_square_layer, ProverScratchPad};
use transcript::Transcript;

#[allow(clippy::type_complexity)]
pub fn gkr_square_prove<C: GKRFieldConfig, T: Transcript<C::ChallengeField>>(
circuit: &Circuit<C>,
sp: &mut ProverScratchPad<C>,
transcript: &mut T,
) -> (C::Field, Vec<C::ChallengeField>) {
mpi_config: &MPIConfig,
) -> (
C::ChallengeField,
Vec<C::ChallengeField>,
Vec<C::ChallengeField>,
Vec<C::ChallengeField>,
) {
let timer = start_timer!(|| "gkr^2 prove");
let layer_num = circuit.layers.len();

Expand All @@ -20,11 +29,41 @@ pub fn gkr_square_prove<C: GKRFieldConfig, T: Transcript<C::ChallengeField>>(
rz0.push(transcript.generate_challenge_field_element());
}

let circuit_output = &circuit.layers.last().unwrap().output_vals;
let claimed_v = C::eval_circuit_vals_at_challenge(circuit_output, &rz0, &mut sp.hg_evals);
let mut r_simd = vec![];
for _i in 0..C::get_field_pack_size().trailing_zeros() {
r_simd.push(transcript.generate_challenge_field_element());
}
log::trace!("Initial r_simd: {:?}", r_simd);

let mut r_mpi = vec![];
for _ in 0..mpi_config.world_size().trailing_zeros() {
r_mpi.push(transcript.generate_challenge_field_element());
}

let output_vals = &circuit.layers.last().unwrap().output_vals;
let claimed_v =
MultiLinearPolyExpander::<C>::collectively_eval_circuit_vals_at_expander_challenge(
output_vals,
&rz0,
&r_simd,
&r_mpi,
&mut sp.hg_evals,
&mut sp.eq_evals_first_half, // confusing name here..
mpi_config,
);

log::trace!("Claimed v: {:?}", claimed_v);

for i in (0..layer_num).rev() {
rz0 = sumcheck_prove_gkr_square_layer(&circuit.layers[i], &rz0, transcript, sp);
(rz0, r_simd, r_mpi) = sumcheck_prove_gkr_square_layer(
&circuit.layers[i],
&rz0,
&r_simd,
&r_mpi,
transcript,
sp,
mpi_config,
);

log::trace!("Layer {} proved", i);
log::trace!("rz0.0: {:?}", rz0[0]);
Expand All @@ -33,5 +72,5 @@ pub fn gkr_square_prove<C: GKRFieldConfig, T: Transcript<C::ChallengeField>>(
}

end_timer!(timer);
(claimed_v, rz0)
(claimed_v, rz0, r_simd, r_mpi)
}
8 changes: 3 additions & 5 deletions gkr/src/prover/linear_gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,13 @@ impl<Cfg: GKRConfig> Prover<Cfg> {
c.fill_rnd_coefs(&mut transcript);
c.evaluate();

let mut claimed_v = <Cfg::FieldConfig as GKRFieldConfig>::ChallengeField::default();
let rx;
let (claimed_v, rx, rsimd, rmpi);
let mut ry = None;
let mut rsimd = vec![];
let mut rmpi = vec![];

let gkr_prove_timer = Timer::new("gkr prove", self.config.mpi_config.is_root());
if self.config.gkr_scheme == GKRScheme::GkrSquare {
(_, rx) = gkr_square_prove(c, &mut self.sp, &mut transcript);
(claimed_v, rx, rsimd, rmpi) =
gkr_square_prove(c, &mut self.sp, &mut transcript, &self.config.mpi_config);
} else {
(claimed_v, rx, ry, rsimd, rmpi) =
gkr_prove(c, &mut self.sp, &mut transcript, &self.config.mpi_config);
Expand Down
Loading
Loading