Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Twist, d=1 #573

Merged
merged 11 commits into from
Feb 20, 2025
13 changes: 12 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions jolt-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ tokio = { version = "1.38.0", optional = true }
alloy-primitives = "0.7.6"
alloy-sol-types = "0.7.6"
once_cell = "1.19.0"
rand_distr = "0.4.3"

[dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }
Expand Down
73 changes: 73 additions & 0 deletions jolt-core/src/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ use crate::poly::commitment::commitment_scheme::CommitmentScheme;
use crate::poly::commitment::hyperkzg::HyperKZG;
use crate::poly::commitment::zeromorph::Zeromorph;
use crate::subprotocols::shout::ShoutProof;
use crate::subprotocols::twist::{TwistAlgorithm, TwistProof};
use crate::utils::math::Math;
use crate::utils::transcript::{KeccakTranscript, Transcript};
use ark_bn254::{Bn254, Fr};
use ark_std::test_rng;
use rand_core::RngCore;
use rand_distr::{Distribution, Zipf};
use serde::Serialize;

#[derive(Debug, Copy, Clone, clap::ValueEnum)]
Expand All @@ -26,6 +28,7 @@ pub enum BenchType {
Sha3,
Sha2Chain,
Shout,
Twist,
}

#[allow(unreachable_patterns)] // good errors on new BenchTypes
Expand All @@ -47,6 +50,7 @@ pub fn benchmarks(
fibonacci::<Fr, Zeromorph<Bn254, KeccakTranscript>, KeccakTranscript>()
}
BenchType::Shout => shout::<Fr, KeccakTranscript>(),
BenchType::Twist => twist::<Fr, KeccakTranscript>(),
_ => panic!("BenchType does not have a mapping"),
},
PCSType::HyperKZG => match bench_type {
Expand All @@ -59,6 +63,7 @@ pub fn benchmarks(
fibonacci::<Fr, HyperKZG<Bn254, KeccakTranscript>, KeccakTranscript>()
}
BenchType::Shout => shout::<Fr, KeccakTranscript>(),
BenchType::Twist => twist::<Fr, KeccakTranscript>(),
_ => panic!("BenchType does not have a mapping"),
},
_ => panic!("PCS Type does not have a mapping"),
Expand Down Expand Up @@ -105,6 +110,74 @@ where
tasks
}

fn twist<F, ProofTranscript>() -> Vec<(tracing::Span, Box<dyn FnOnce()>)>
where
F: JoltField,
ProofTranscript: Transcript,
{
let small_value_lookup_tables = F::compute_lookup_tables();
F::initialize_lookup_tables(small_value_lookup_tables);

let mut tasks = Vec::new();

const K: usize = 1 << 10;
const T: usize = 1 << 20;
const ZIPF_S: f64 = 0.0;
let zipf = Zipf::new(K as u64, ZIPF_S).unwrap();

let mut rng = test_rng();

let mut registers = [0u32; K];
let mut read_addresses: Vec<usize> = Vec::with_capacity(T);
let mut read_values: Vec<u32> = Vec::with_capacity(T);
let mut write_addresses: Vec<usize> = Vec::with_capacity(T);
let mut write_values: Vec<u32> = Vec::with_capacity(T);
let mut write_increments: Vec<i64> = Vec::with_capacity(T);
for _ in 0..T {
// Random read register
let read_address = zipf.sample(&mut rng) as usize - 1;
// Random write register
let write_address = zipf.sample(&mut rng) as usize - 1;
read_addresses.push(read_address);
write_addresses.push(write_address);
// Read the value currently in the read register
read_values.push(registers[read_address]);
// Random write value
let write_value = rng.next_u32();
write_values.push(write_value);
// The increment is the difference between the new value and the old value
let write_increment = (write_value as i64) - (registers[write_address] as i64);
write_increments.push(write_increment);
// Write the new value to the write register
registers[write_address] = write_value;
}

let mut prover_transcript = ProofTranscript::new(b"test_transcript");
let r: Vec<F> = prover_transcript.challenge_vector(K.log_2());
let r_prime: Vec<F> = prover_transcript.challenge_vector(T.log_2());

let task = move || {
let _proof = TwistProof::prove(
read_addresses,
read_values,
write_addresses,
write_values,
write_increments,
r.clone(),
r_prime.clone(),
&mut prover_transcript,
TwistAlgorithm::Local,
);
};

tasks.push((
tracing::info_span!("Twist d=1"),
Box::new(task) as Box<dyn FnOnce()>,
));

tasks
}

fn fibonacci<F, PCS, ProofTranscript>() -> Vec<(tracing::Span, Box<dyn FnOnce()>)>
where
F: JoltField,
Expand Down
37 changes: 18 additions & 19 deletions jolt-core/src/poly/eq_poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,9 @@ impl<F: JoltField> EqPolynomial<F> {

#[tracing::instrument(skip_all, name = "EqPolynomial::evals")]
pub fn evals(r: &[F]) -> Vec<F> {
let ell = r.len();

match ell {
0..=PARALLEL_THRESHOLD => Self::evals_serial(r, ell, None),
_ => Self::evals_parallel(r, ell, None),
match r.len() {
0..=PARALLEL_THRESHOLD => Self::evals_serial(r, None),
_ => Self::evals_parallel(r, None),
}
}

Expand All @@ -45,19 +43,19 @@ impl<F: JoltField> EqPolynomial<F> {
/// the dynamic programming tree to R^2 instead of 1.
#[tracing::instrument(skip_all, name = "EqPolynomial::evals_with_r2")]
pub fn evals_with_r2(r: &[F]) -> Vec<F> {
let ell = r.len();

match ell {
0..=PARALLEL_THRESHOLD => Self::evals_serial(r, ell, F::montgomery_r2()),
_ => Self::evals_parallel(r, ell, F::montgomery_r2()),
match r.len() {
0..=PARALLEL_THRESHOLD => Self::evals_serial(r, F::montgomery_r2()),
_ => Self::evals_parallel(r, F::montgomery_r2()),
}
}

/// Computes evals serially. Uses less memory (and fewer allocations) than `evals_parallel`.
fn evals_serial(r: &[F], ell: usize, r2: Option<F>) -> Vec<F> {
let mut evals: Vec<F> = vec![r2.unwrap_or(F::one()); ell.pow2()];
/// Computes the table of coefficients:
/// scaling_factor * eq(r, x) for all x in {0, 1}^n
/// serially. More efficient for short `r`.
fn evals_serial(r: &[F], scaling_factor: Option<F>) -> Vec<F> {
let mut evals: Vec<F> = vec![scaling_factor.unwrap_or(F::one()); r.len().pow2()];
let mut size = 1;
for j in 0..ell {
for j in 0..r.len() {
// in each iteration, we double the size of chis
size *= 2;
for i in (0..size).rev().step_by(2) {
Expand All @@ -70,14 +68,15 @@ impl<F: JoltField> EqPolynomial<F> {
evals
}

/// Computes evals in parallel. Uses more memory and allocations than `evals_serial`, but
/// evaluates biggest layers of the dynamic programming tree in parallel.
/// Computes the table of coefficients:
/// scaling_factor * eq(r, x) for all x in {0, 1}^n
/// computing biggest layers of the dynamic programming tree in parallel.
#[tracing::instrument(skip_all, "EqPolynomial::evals_parallel")]
pub fn evals_parallel(r: &[F], ell: usize, r2: Option<F>) -> Vec<F> {
let final_size = (2usize).pow(ell as u32);
pub fn evals_parallel(r: &[F], scaling_factor: Option<F>) -> Vec<F> {
let final_size = r.len().pow2();
let mut evals: Vec<F> = unsafe_allocate_zero_vec(final_size);
let mut size = 1;
evals[0] = r2.unwrap_or(F::one());
evals[0] = scaling_factor.unwrap_or(F::one());

for r in r.iter().rev() {
let (evals_left, evals_right) = evals.split_at_mut(size);
Expand Down
1 change: 1 addition & 0 deletions jolt-core/src/subprotocols/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod grand_product_quarks;
pub mod shout;
pub mod sparse_grand_product;
pub mod sumcheck;
pub mod twist;

#[derive(Clone, Copy, Debug, Default)]
pub enum QuarkHybridLayerDepth {
Expand Down
Loading