Skip to content

Gkr2 #207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Mar 2, 2025
Merged

Gkr2 #207

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:

- uses: Swatinem/rust-cache@v2
with:
prefix-key: "mpi-v5.0.6"
prefix-key: "mpi-v5.0.7"

- name: Install MPI for MacOS workflow
if: matrix.os == 'macos-latest'
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ jobs:
- name: Cache Rust dependencies
uses: Swatinem/rust-cache@v2
with:
prefix-key: "mpi-v5.0.6"
prefix-key: "mpi-v5.0.7"

- uses: actions/setup-go@v5
if: matrix.os != '7950x3d'
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/nightly_e2e._yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
- uses: Swatinem/rust-cache@v2
with:
# The prefix cache key, this can be changed to start a new cache manually.
prefix-key: "mpi-v5.0.6" # update me if brew formula changes to a new version
prefix-key: "mpi-v5.0.7" # update me if brew formula changes to a new version

- name: Run tests
run: |
Expand Down
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions arith/polynomials/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ edition = "2021"

[dependencies]
arith = { path = "../" }
gkr_field_config = { path = "../../config/gkr_field_config" }
mpi_config = { path = "../../config/mpi_config" }

ark-std.workspace = true
criterion.workspace = true
Expand Down
132 changes: 130 additions & 2 deletions arith/polynomials/src/mle.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
use std::ops::{Index, IndexMut, Mul};
use std::{
cmp,
marker::PhantomData,
ops::{Index, IndexMut, Mul},
};

use arith::Field;
use arith::{Field, SimdField};
use ark_std::{log2, rand::RngCore};
use gkr_field_config::GKRFieldConfig;
use mpi_config::MPIConfig;

use crate::{EqPolynomial, MultilinearExtension, MutableMultilinearExtension};

Expand Down Expand Up @@ -205,3 +211,125 @@ impl<F: Field> MutableMultilinearExtension<F> for MultiLinearPoly<F> {
}
}
}

#[derive(Debug, Clone, Default)]
pub struct MultiLinearPolyExpander<C: GKRFieldConfig> {
_config: PhantomData<C>,
}

/// Some dedicated mle implementations for GKRFieldConfig
/// Take into consideration the simd challenge and the mpi challenge
///
/// This is more efficient than the generic implementation by avoiding
/// unnecessary conversions between field types
impl<C: GKRFieldConfig> MultiLinearPolyExpander<C> {
pub fn new() -> Self {
Self {
_config: PhantomData,
}
}

#[inline]
pub fn eval_circuit_vals_at_challenge(
evals: &[C::SimdCircuitField],
x: &[C::ChallengeField],
scratch: &mut [C::Field],
) -> C::Field {
assert_eq!(1 << x.len(), evals.len());
assert!(scratch.len() >= evals.len());

if x.is_empty() {
C::simd_circuit_field_into_field(&evals[0])
} else {
for i in 0..(evals.len() >> 1) {
scratch[i] = C::field_add_simd_circuit_field(
&C::simd_circuit_field_mul_challenge_field(
&(evals[i * 2 + 1] - evals[i * 2]),
&x[0],
),
&evals[i * 2],
);
}

let mut cur_eval_size = evals.len() >> 2;
for r in x.iter().skip(1) {
for i in 0..cur_eval_size {
scratch[i] = scratch[i * 2] + (scratch[i * 2 + 1] - scratch[i * 2]).scale(r);
}
cur_eval_size >>= 1;
}
scratch[0]
}
}

/// This assumes each mpi core hold their own evals, and collectively
/// compute the global evaluation.
/// Mostly used by the prover run with `mpiexec`
#[inline]
pub fn collectively_eval_circuit_vals_at_expander_challenge(
local_evals: &[C::SimdCircuitField],
x: &[C::ChallengeField],
x_simd: &[C::ChallengeField],
x_mpi: &[C::ChallengeField],
scratch_field: &mut [C::Field],
scratch_challenge_field: &mut [C::ChallengeField],
mpi_config: &MPIConfig,
) -> C::ChallengeField {
assert!(scratch_challenge_field.len() >= 1 << cmp::max(x_simd.len(), x_mpi.len()));

let local_simd = Self::eval_circuit_vals_at_challenge(local_evals, x, scratch_field);
let local_simd_unpacked = local_simd.unpack();
let local_v = MultiLinearPoly::evaluate_with_buffer(
&local_simd_unpacked,
x_simd,
scratch_challenge_field,
);

if mpi_config.is_root() {
let mut claimed_v_gathering_buffer =
vec![C::ChallengeField::zero(); mpi_config.world_size()];
mpi_config.gather_vec(&vec![local_v], &mut claimed_v_gathering_buffer);
MultiLinearPoly::evaluate_with_buffer(
&claimed_v_gathering_buffer,
x_mpi,
scratch_challenge_field,
)
} else {
mpi_config.gather_vec(&vec![local_v], &mut vec![]);
C::ChallengeField::zero()
}
}

/// This assumes only a single core holds all the evals, and evaluate it locally
/// mostly used by the verifier
#[inline]
pub fn single_core_eval_circuit_vals_at_expander_challenge(
global_vals: &[C::SimdCircuitField],
x: &[C::ChallengeField],
x_simd: &[C::ChallengeField],
x_mpi: &[C::ChallengeField],
) -> C::ChallengeField {
let local_poly_size = global_vals.len() >> x_mpi.len();
assert_eq!(local_poly_size, 1 << x.len());

let mut scratch_field = vec![C::Field::default(); local_poly_size];
let mut scratch_challenge_field =
vec![C::ChallengeField::default(); 1 << cmp::max(x_simd.len(), x_mpi.len())];
let local_evals = global_vals
.chunks(local_poly_size)
.map(|local_vals| {
let local_simd =
Self::eval_circuit_vals_at_challenge(local_vals, x, &mut scratch_field);
let local_simd_unpacked = local_simd.unpack();
MultiLinearPoly::evaluate_with_buffer(
&local_simd_unpacked,
x_simd,
&mut scratch_challenge_field,
)
})
.collect::<Vec<C::ChallengeField>>();

let mut scratch = vec![C::ChallengeField::default(); local_evals.len()];
MultiLinearPoly::evaluate_with_buffer(&local_evals, x_mpi, &mut scratch)
}
}
33 changes: 0 additions & 33 deletions config/gkr_field_config/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,37 +73,4 @@ pub trait GKRFieldConfig: Default + Debug + Clone + Send + Sync + 'static {
fn get_field_pack_size() -> usize {
Self::SimdCircuitField::PACK_SIZE
}

/// Evaluate the circuit values at the challenge
#[inline]
fn eval_circuit_vals_at_challenge(
evals: &[Self::SimdCircuitField],
x: &[Self::ChallengeField],
scratch: &mut [Self::Field],
) -> Self::Field {
assert_eq!(1 << x.len(), evals.len());

if x.is_empty() {
Self::simd_circuit_field_into_field(&evals[0])
} else {
for i in 0..(evals.len() >> 1) {
scratch[i] = Self::field_add_simd_circuit_field(
&Self::simd_circuit_field_mul_challenge_field(
&(evals[i * 2 + 1] - evals[i * 2]),
&x[0],
),
&evals[i * 2],
);
}

let mut cur_eval_size = evals.len() >> 2;
for r in x.iter().skip(1) {
for i in 0..cur_eval_size {
scratch[i] = scratch[i * 2] + (scratch[i * 2 + 1] - scratch[i * 2]).scale(r);
}
cur_eval_size >>= 1;
}
scratch[0]
}
}
}
2 changes: 1 addition & 1 deletion config/mpi_config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ impl MPIConfig {
ret
} else {
self.gather_vec(local_vec, &mut vec![]);
vec![]
vec![F::ZERO; local_vec.len()]
}
}

Expand Down
8 changes: 6 additions & 2 deletions crosslayer_prototype/src/gkr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use arith::{Field, SimdField};
use gkr_field_config::GKRFieldConfig;
use polynomials::MultiLinearPoly;
use polynomials::{MultiLinearPoly, MultiLinearPolyExpander};
use transcript::Transcript;

use crate::sumcheck::{sumcheck_prove_gather_layer, sumcheck_prove_scatter_layer};
Expand All @@ -18,7 +18,11 @@ pub fn prove_gkr<C: GKRFieldConfig, T: Transcript<C::ChallengeField>>(
.generate_challenge_field_elements(final_layer_vals.len().trailing_zeros() as usize);
let r_simd = transcript
.generate_challenge_field_elements(C::get_field_pack_size().trailing_zeros() as usize);
let output_claim = C::eval_circuit_vals_at_challenge(final_layer_vals, &rz0, &mut sp.v_evals);
let output_claim = MultiLinearPolyExpander::<C>::eval_circuit_vals_at_challenge(
final_layer_vals,
&rz0,
&mut sp.v_evals,
);
let output_claim = MultiLinearPoly::<C::ChallengeField>::evaluate_with_buffer(
&output_claim.unpack(),
&r_simd,
Expand Down
32 changes: 10 additions & 22 deletions gkr/src/prover/gkr.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
//! This module implements the core GKR IOP.

use arith::{Field, SimdField};
use circuit::Circuit;
use gkr_field_config::GKRFieldConfig;
use mpi_config::MPIConfig;
use polynomials::MultiLinearPoly;
use polynomials::MultiLinearPolyExpander;
use sumcheck::{sumcheck_prove_gkr_layer, ProverScratchPad};
use transcript::Transcript;
use utils::timer::Timer;
Expand Down Expand Up @@ -44,27 +43,16 @@ pub fn gkr_prove<C: GKRFieldConfig, T: Transcript<C::ChallengeField>>(
let mut alpha = None;

let output_vals = &circuit.layers.last().unwrap().output_vals;

let claimed_v_simd = C::eval_circuit_vals_at_challenge(output_vals, &rz0, &mut sp.hg_evals);
let claimed_v_local = MultiLinearPoly::<C::ChallengeField>::evaluate_with_buffer(
&claimed_v_simd.unpack(),
&r_simd,
&mut sp.eq_evals_at_r_simd0,
);

let claimed_v = if mpi_config.is_root() {
let mut claimed_v_gathering_buffer =
vec![C::ChallengeField::zero(); mpi_config.world_size()];
mpi_config.gather_vec(&vec![claimed_v_local], &mut claimed_v_gathering_buffer);
MultiLinearPoly::evaluate_with_buffer(
&claimed_v_gathering_buffer,
let claimed_v =
MultiLinearPolyExpander::<C>::collectively_eval_circuit_vals_at_expander_challenge(
output_vals,
&rz0,
&r_simd,
&r_mpi,
&mut sp.eq_evals_at_r_mpi0,
)
} else {
mpi_config.gather_vec(&vec![claimed_v_local], &mut vec![]);
C::ChallengeField::zero()
};
&mut sp.hg_evals,
&mut sp.eq_evals_first_half, // confusing name here..
mpi_config,
);

for i in (0..layer_num).rev() {
let timer = Timer::new(
Expand Down
56 changes: 50 additions & 6 deletions gkr/src/prover/gkr_square.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,29 @@

use ark_std::{end_timer, start_timer};
use circuit::Circuit;
use gkr_field_config::GKRFieldConfig;
use gkr_field_config::{FieldType, GKRFieldConfig};
use mpi_config::MPIConfig;
use polynomials::MultiLinearPolyExpander;
use sumcheck::{sumcheck_prove_gkr_square_layer, ProverScratchPad};
use transcript::Transcript;

#[allow(clippy::type_complexity)]
pub fn gkr_square_prove<C: GKRFieldConfig, T: Transcript<C::ChallengeField>>(
circuit: &Circuit<C>,
sp: &mut ProverScratchPad<C>,
transcript: &mut T,
) -> (C::Field, Vec<C::ChallengeField>) {
mpi_config: &MPIConfig,
) -> (
C::ChallengeField,
Vec<C::ChallengeField>,
Vec<C::ChallengeField>,
Vec<C::ChallengeField>,
) {
assert_ne!(
C::FIELD_TYPE,
FieldType::GF2,
"GF2 is not supported in GKR^2"
);
let timer = start_timer!(|| "gkr^2 prove");
let layer_num = circuit.layers.len();

Expand All @@ -20,11 +34,41 @@ pub fn gkr_square_prove<C: GKRFieldConfig, T: Transcript<C::ChallengeField>>(
rz0.push(transcript.generate_challenge_field_element());
}

let circuit_output = &circuit.layers.last().unwrap().output_vals;
let claimed_v = C::eval_circuit_vals_at_challenge(circuit_output, &rz0, &mut sp.hg_evals);
let mut r_simd = vec![];
for _i in 0..C::get_field_pack_size().trailing_zeros() {
r_simd.push(transcript.generate_challenge_field_element());
}
log::trace!("Initial r_simd: {:?}", r_simd);

let mut r_mpi = vec![];
for _ in 0..mpi_config.world_size().trailing_zeros() {
r_mpi.push(transcript.generate_challenge_field_element());
}

let output_vals = &circuit.layers.last().unwrap().output_vals;
let claimed_v =
MultiLinearPolyExpander::<C>::collectively_eval_circuit_vals_at_expander_challenge(
output_vals,
&rz0,
&r_simd,
&r_mpi,
&mut sp.hg_evals,
&mut sp.eq_evals_first_half, // confusing name here..
mpi_config,
);

log::trace!("Claimed v: {:?}", claimed_v);

for i in (0..layer_num).rev() {
rz0 = sumcheck_prove_gkr_square_layer(&circuit.layers[i], &rz0, transcript, sp);
(rz0, r_simd, r_mpi) = sumcheck_prove_gkr_square_layer(
&circuit.layers[i],
&rz0,
&r_simd,
&r_mpi,
transcript,
sp,
mpi_config,
);

log::trace!("Layer {} proved", i);
log::trace!("rz0.0: {:?}", rz0[0]);
Expand All @@ -33,5 +77,5 @@ pub fn gkr_square_prove<C: GKRFieldConfig, T: Transcript<C::ChallengeField>>(
}

end_timer!(timer);
(claimed_v, rz0)
(claimed_v, rz0, r_simd, r_mpi)
}
Loading
Loading