diff --git a/core/src/air/trace.rs b/core/src/air/trace.rs index a89ecc51ea..48432fcbef 100644 --- a/core/src/air/trace.rs +++ b/core/src/air/trace.rs @@ -2,14 +2,30 @@ use p3_air::BaseAir; use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; -use crate::runtime::ExecutionRecord; +use crate::runtime::{ExecutionRecord, Program}; +/// An AIR that is part of a Risc-V AIR arithmetization. pub trait MachineAir: BaseAir { + /// A unique identifier for this AIR as part of a machine. fn name(&self) -> String; + /// Generate the trace for a given execution record. + /// + /// The mutable borrow of `record` allows a `MachineAir` to store additional information in the + /// record, such as inserting events for other AIRs to process. fn generate_trace(&self, record: &mut ExecutionRecord) -> RowMajorMatrix; fn shard(&self, input: &ExecutionRecord, outputs: &mut Vec); fn include(&self, record: &ExecutionRecord) -> bool; + + /// The number of preprocessed columns in the trace. + fn preprocessed_width(&self) -> usize { + 0 + } + + #[allow(unused_variables)] + fn preprocessed_trace(&self, program: &Program) -> Option> { + None + } } diff --git a/core/src/bytes/trace.rs b/core/src/bytes/trace.rs index 6819147ba3..168c6ca699 100644 --- a/core/src/bytes/trace.rs +++ b/core/src/bytes/trace.rs @@ -2,7 +2,10 @@ use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; use super::{air::BYTE_MULT_INDICES, ByteChip}; -use crate::{air::MachineAir, runtime::ExecutionRecord}; +use crate::{ + air::MachineAir, + runtime::{ExecutionRecord, Program}, +}; pub const NUM_ROWS: usize = 1 << 16; @@ -33,4 +36,16 @@ impl MachineAir for ByteChip { trace } + + fn preprocessed_width(&self) -> usize { + 10 + } + + fn preprocessed_trace(&self, _program: &Program) -> Option> { + let values = (0..10 * NUM_ROWS) + .map(|i| F::from_canonical_usize(i)) + .collect(); + + Some(RowMajorMatrix::new(values, 10)) + } } diff --git a/core/src/lib.rs b/core/src/lib.rs index 9c1d1e27a7..8f6f9def60 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -37,7 +37,7 @@ use p3_matrix::dense::RowMajorMatrix; use runtime::{Program, Runtime}; use serde::de::DeserializeOwned; use serde::Serialize; -use stark::{MainData, OpeningProof, ProgramVerificationError, Proof}; +use stark::{OpeningProof, ProgramVerificationError, Proof, ShardMainData}; use stark::{RiscvStark, StarkGenericConfig}; use std::fs; use utils::{prove_core, BabyBearBlake3, StarkUtils}; @@ -96,7 +96,7 @@ impl CurtaProver { OpeningProof: Send + Sync, >>::Commitment: Send + Sync, >>::ProverData: Send + Sync, - MainData: Serialize + DeserializeOwned, + ShardMainData: Serialize + DeserializeOwned, ::Val: p3_field::PrimeField32, { let program = Program::from(elf); @@ -121,8 +121,9 @@ impl CurtaVerifier { ) -> Result<(), ProgramVerificationError> { let config = BabyBearBlake3::new(); let mut challenger = config.challenger(); - let (machine, prover_data) = RiscvStark::init(config); - machine.verify(&mut challenger, &proof.proof) + let machine = RiscvStark::new(config); + let (_, vk) = machine.setup(&Program::from(elf)); + machine.verify(&vk, &proof.proof, &mut challenger) } /// Verify a proof generated by `CurtaProver` with a custom config. 
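For orientation, here is a minimal sketch of the kind of table the new `preprocessed_trace` hook is meant to produce, following the `ByteChip` pattern above. The two-column layout, the helper name, and the use of `program.instructions` are illustrative assumptions, not part of this change.

use p3_field::Field;
use p3_matrix::dense::RowMajorMatrix;

use crate::runtime::Program;

// Hypothetical helper: a two-column preprocessed table with one row per program
// instruction (column 0 = row index, column 1 = a constant flag). A chip's
// `preprocessed_trace` implementation would return a matrix of this shape, and
// `preprocessed_width` would return the matching width (here, 2).
fn example_preprocessed_trace<F: Field>(program: &Program) -> Option<RowMajorMatrix<F>> {
    let width = 2;
    let values = (0..program.instructions.len())
        .flat_map(|i| [F::from_canonical_usize(i), F::one()])
        .collect::<Vec<_>>();
    Some(RowMajorMatrix::new(values, width))
}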
@@ -138,12 +139,14 @@ impl CurtaVerifier { OpeningProof: Send + Sync, >>::Commitment: Send + Sync, >>::ProverData: Send + Sync, - MainData: Serialize + DeserializeOwned, + ShardMainData: Serialize + DeserializeOwned, ::Val: p3_field::PrimeField32, { let mut challenger = config.challenger(); - let (machine, prover_data) = RiscvStark::init(config); - machine.verify(&mut challenger, &proof.proof) + let machine = RiscvStark::new(config); + + let (_, vk) = machine.setup(&Program::from(elf)); + machine.verify(&vk, &proof.proof, &mut challenger) } } diff --git a/core/src/memory/global.rs b/core/src/memory/global.rs index cd91c5ef45..cd1dfd6c4e 100644 --- a/core/src/memory/global.rs +++ b/core/src/memory/global.rs @@ -219,7 +219,7 @@ mod tests { let mut runtime = Runtime::new(program); runtime.run(); - let (machine, _prover_data) = RiscvStark::init(BabyBearPoseidon2::new()); + let machine = RiscvStark::new(BabyBearPoseidon2::new()); debug_interactions_with_all_chips( &machine.chips(), &runtime.record, @@ -234,7 +234,7 @@ mod tests { let mut runtime = Runtime::new(program); runtime.run(); - let (machine, _prover_data) = RiscvStark::init(BabyBearPoseidon2::new()); + let machine = RiscvStark::new(BabyBearPoseidon2::new()); debug_interactions_with_all_chips( &machine.chips(), &runtime.record, diff --git a/core/src/runtime/record.rs b/core/src/runtime/record.rs index 975be8abf7..02774b876c 100644 --- a/core/src/runtime/record.rs +++ b/core/src/runtime/record.rs @@ -28,6 +28,9 @@ pub struct ExecutionRecord { /// A trace of the CPU events which get emitted during execution. pub cpu_events: Vec, + /// Multiplicity counts for each instruction in the program. + pub instruction_counts: HashMap, + /// A trace of the ADD, and ADDI events. pub add_events: Vec, diff --git a/core/src/stark/chip.rs b/core/src/stark/chip.rs index fd6e019909..d250d4587d 100644 --- a/core/src/stark/chip.rs +++ b/core/src/stark/chip.rs @@ -6,7 +6,7 @@ use p3_util::log2_ceil_usize; use crate::{ air::{CurtaAirBuilder, MachineAir, MultiTableAirBuilder}, lookup::{Interaction, InteractionBuilder}, - runtime::ExecutionRecord, + runtime::{ExecutionRecord, Program}, }; use super::{ @@ -156,6 +156,14 @@ where fn include(&self, record: &ExecutionRecord) -> bool { self.air.include(record) } + + fn preprocessed_trace(&self, program: &Program) -> Option> { + >::preprocessed_trace(&self.air, program) + } + + fn preprocessed_width(&self) -> usize { + self.air.preprocessed_width() + } } // Implement AIR directly on Chip, evaluating both execution and permutation constraints. 
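The `instruction_counts` field added to `ExecutionRecord` above has its type parameters stripped in this rendering; a plausible reading is a map from an instruction's pc to the number of times it executed. A hedged sketch of how the runtime could maintain it under that assumption (the key choice and the helper are hypothetical):

use crate::runtime::ExecutionRecord;

// Hypothetical: assumes `instruction_counts: HashMap<u32, usize>` keyed by the
// instruction's program counter.
fn bump_instruction_count(record: &mut ExecutionRecord, pc: u32) {
    *record.instruction_counts.entry(pc).or_insert(0) += 1;
}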
@@ -201,6 +209,14 @@ impl<'a, SC: StarkGenericConfig> MachineAir for ChipRef<'a, SC> { fn include(&self, record: &ExecutionRecord) -> bool { as MachineAir>::include(self.air, record) } + + fn preprocessed_trace(&self, program: &Program) -> Option> { + as MachineAir>::preprocessed_trace(self.air, program) + } + + fn preprocessed_width(&self) -> usize { + as MachineAir>::preprocessed_width(self.air) + } } impl<'a, 'b, SC: StarkGenericConfig> Air> for ChipRef<'a, SC> { diff --git a/core/src/stark/machine.rs b/core/src/stark/machine.rs index abe8bcf7c6..bf1b629831 100644 --- a/core/src/stark/machine.rs +++ b/core/src/stark/machine.rs @@ -14,6 +14,7 @@ use crate::memory::MemoryChipKind; use crate::memory::MemoryGlobalChip; use crate::program::ProgramChip; use crate::runtime::ExecutionRecord; +use crate::runtime::Program; use crate::syscall::precompiles::blake3::Blake3CompressInnerChip; use crate::syscall::precompiles::edwards::EdAddAssignChip; use crate::syscall::precompiles::edwards::EdDecompressChip; @@ -27,22 +28,18 @@ use crate::utils::ec::edwards::ed25519::Ed25519Parameters; use crate::utils::ec::edwards::EdwardsCurve; use crate::utils::ec::weierstrass::secp256k1::Secp256k1Parameters; use crate::utils::ec::weierstrass::SWCurve; -use p3_air::BaseAir; use p3_challenger::CanObserve; use p3_commit::Pcs; use p3_field::AbstractField; use p3_field::Field; use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; -use p3_maybe_rayon::prelude::*; -use serde::de::DeserializeOwned; -use serde::Serialize; +use p3_matrix::Dimensions; +use p3_matrix::Matrix; use super::Chip; use super::ChipRef; use super::Com; -use super::MainData; -use super::OpeningProof; use super::PcsProverData; use super::Proof; use super::Prover; @@ -50,13 +47,18 @@ use super::StarkGenericConfig; use super::VerificationError; use super::Verifier; -pub struct ProverData { - pub preprocessed_traces: Vec>>, - pub preprocessed_data: Option>, +pub struct ProvingKey { + pub data: PcsProverData, + pub byte_trace: RowMajorMatrix, + // TODO: + // program_trace: RowMajorMatrix, } -pub struct PublicParameters { - pub preprocessed_commitment: Option>, +pub struct VerifyingKey { + pub commit: Com, + pub byte_dimensions: Dimensions, + // TODO: + // program_dimensions: Dimensions, } pub struct RiscvStark { @@ -88,16 +90,13 @@ pub struct RiscvStark { memory_init: Chip, memory_finalize: Chip, program_memory_init: Chip, - - // Commitment to the preprocessed data - preprocessed_commitment: Option>, } impl RiscvStark where SC::Val: PrimeField32, { - pub fn init(config: SC) -> (Self, ProverData) { + pub fn new(config: SC) -> Self { let program = Chip::new(ProgramChip::default()); let cpu = Chip::new(CpuChip::default()); let sha_extend = Chip::new(ShaExtendChip::default()); @@ -125,7 +124,7 @@ where let memory_finalize = Chip::new(MemoryGlobalChip::new(MemoryChipKind::Finalize)); let program_memory_init = Chip::new(MemoryGlobalChip::new(MemoryChipKind::Program)); - let mut machine = Self { + Self { config, program, cpu, @@ -151,40 +150,10 @@ where memory_init, memory_finalize, program_memory_init, - - preprocessed_commitment: None, - }; - - // Compute commitments to the preprocessed data - let preprocessed_traces = machine - .chips() - .iter() - .map(|chip| chip.preprocessed_trace()) - .collect::>(); - let traces = preprocessed_traces - .iter() - .flatten() - .cloned() - .collect::>(); - let (commit, data) = if !traces.is_empty() { - Some(machine.config.pcs().commit_batches(traces)) - } else { - None } - .unzip(); - - // Store the commitments 
in the machine - machine.preprocessed_commitment = commit; - - ( - machine, - ProverData { - preprocessed_traces, - preprocessed_data: data, - }, - ) } + /// Get an array containing a `ChipRef` for all the chips of this RISC-V STARK machine. pub fn chips(&self) -> [ChipRef; 24] { [ self.program.as_ref(), @@ -214,24 +183,27 @@ where ] } - /// Prove the program. + /// The setup preprocessing phase. /// - /// The function returns a vector of segment proofs, one for each segment, and a global proof. - pub fn prove
<P>
( - &self, - prover_data: &ProverData, - record: &mut ExecutionRecord, - challenger: &mut SC::Challenger, - ) -> Proof - where - P: Prover, - SC: Send + Sync, - SC::Challenger: Clone, - >>::Commitment: Send + Sync, - >>::ProverData: Send + Sync, - MainData: Serialize + DeserializeOwned, - OpeningProof: Send + Sync, - { + /// Given a program, this function generates the proving and verifying keys. The keys correspond + /// to the program code and other preprocessed colunms such as lookup tables. + pub fn setup(&self, program: &Program) -> (ProvingKey, VerifyingKey) { + let byte_trace = self.byte.preprocessed_trace(program).unwrap(); + + let (commit, data) = self.config.pcs().commit_batches(vec![byte_trace.clone()]); + + // TODO: commit to the program trace as well. + + let verifying_key = VerifyingKey { + commit, + byte_dimensions: byte_trace.dimensions(), + }; + let proving_key = ProvingKey { data, byte_trace }; + + (proving_key, verifying_key) + } + + pub fn shard(&self, record: &mut ExecutionRecord) -> Vec { // Get the local and global chips. let chips = self.chips(); @@ -258,49 +230,35 @@ where chip.shard(record, &mut shards); }); - tracing::info!("Generating and commiting traces for each shard."); - // Generate and commit the traces for each segment. - let (shard_commits, shard_data) = P::commit_shards(&self.config, &mut shards, &chips); + shards + } - // Observe the challenges for each segment. - tracing::info_span!("observing all challenges").in_scope(|| { - shard_commits.into_iter().for_each(|commitment| { - challenger.observe(commitment); - }); - }); + /// Prove the execution record is valid. + /// + /// Given a proving key `pk` and a matching execution record `record`, this function generates + /// a STARK proof that the execution record is valid. + pub fn prove>( + &self, + pk: &ProvingKey, + record: &mut ExecutionRecord, + challenger: &mut SC::Challenger, + ) -> Proof { + tracing::info!("Sharding the execution record."); + let mut shards = self.shard(record); - // Generate a proof for each segment. Note that we clone the challenger so we can observe - // identical global challenges across the segments. 
- let shard_proofs = shard_data - .into_par_iter() - .map(|data| { - let data = tracing::info_span!("materializing data") - .in_scope(|| data.materialize().expect("failed to load shard main data")); - let chips = self - .chips() - .into_iter() - .filter(|chip| data.chip_ids.contains(&chip.name())) - .collect::>(); - tracing::info_span!("proving shard").in_scope(|| { - P::prove_shard( - &self.config, - &mut challenger.clone(), - &chips, - data, - &prover_data.preprocessed_traces, - &prover_data.preprocessed_data, - ) - }) - }) - .collect::>(); - - Proof { shard_proofs } + tracing::info!("Generating the shard proofs."); + P::prove_shards(self, pk, &mut shards, challenger) + } + + pub const fn config(&self) -> &SC { + &self.config } pub fn verify( &self, - challenger: &mut SC::Challenger, + _vk: &VerifyingKey, proof: &Proof, + challenger: &mut SC::Challenger, ) -> Result<(), ProgramVerificationError> where SC::Val: PrimeField32, @@ -363,19 +321,19 @@ pub mod tests { use crate::runtime::Opcode; use crate::runtime::Program; use crate::utils; - use crate::utils::prove; + use crate::utils::run_test; use crate::utils::setup_logger; #[test] fn test_simple_prove() { let program = simple_program(); - prove(program); + run_test(program).unwrap(); } #[test] fn test_ecall_lwa_prove() { let program = ecall_lwa_program(); - prove(program); + run_test(program).unwrap(); } #[test] @@ -396,7 +354,7 @@ pub mod tests { Instruction::new(*shift_op, 31, 29, 3, false, false), ]; let program = Program::new(instructions, 0, 0); - prove(program); + run_test(program).unwrap(); } } } @@ -409,7 +367,7 @@ pub mod tests { Instruction::new(Opcode::SUB, 31, 30, 29, false, false), ]; let program = Program::new(instructions, 0, 0); - prove(program); + run_test(program).unwrap(); } #[test] @@ -421,7 +379,7 @@ pub mod tests { Instruction::new(Opcode::ADD, 31, 30, 29, false, false), ]; let program = Program::new(instructions, 0, 0); - prove(program); + run_test(program).unwrap(); } #[test] @@ -443,7 +401,7 @@ pub mod tests { Instruction::new(*mul_op, 31, 30, 29, false, false), ]; let program = Program::new(instructions, 0, 0); - prove(program); + run_test(program).unwrap(); } } } @@ -458,7 +416,7 @@ pub mod tests { Instruction::new(*lt_op, 31, 30, 29, false, false), ]; let program = Program::new(instructions, 0, 0); - prove(program); + run_test(program).unwrap(); } } @@ -473,7 +431,7 @@ pub mod tests { Instruction::new(*bitwise_op, 31, 30, 29, false, false), ]; let program = Program::new(instructions, 0, 0); - prove(program); + run_test(program).unwrap(); } } @@ -495,7 +453,7 @@ pub mod tests { Instruction::new(*div_rem_op, 31, 29, 30, false, false), ]; let program = Program::new(instructions, 0, 0); - prove(program); + run_test(program).unwrap(); } } } @@ -504,12 +462,12 @@ pub mod tests { fn test_fibonacci_prove() { setup_logger(); let program = fibonacci_program(); - prove(program); + run_test(program).unwrap(); } #[test] fn test_simple_memory_program_prove() { let program = simple_memory_program(); - prove(program); + run_test(program).unwrap(); } } diff --git a/core/src/stark/mod.rs b/core/src/stark/mod.rs index 1f4cedde73..5e0cb81609 100644 --- a/core/src/stark/mod.rs +++ b/core/src/stark/mod.rs @@ -5,6 +5,7 @@ mod folder; mod machine; mod permutation; mod prover; +mod quotient; mod types; mod util; mod verifier; @@ -17,6 +18,7 @@ pub use folder::*; pub use machine::*; pub use permutation::*; pub use prover::*; +pub use quotient::*; pub use types::*; pub use verifier::*; diff --git a/core/src/stark/prover.rs 
b/core/src/stark/prover.rs index 68d9b9c7bd..5bb4b5b4c1 100644 --- a/core/src/stark/prover.rs +++ b/core/src/stark/prover.rs @@ -1,15 +1,16 @@ +use super::ProvingKey; +use super::{quotient_values, RiscvStark}; use itertools::izip; #[cfg(not(feature = "perf"))] use p3_air::BaseAir; -use p3_air::{Air, TwoRowMatrixView}; use p3_challenger::{CanObserve, FieldChallenger}; use p3_commit::{Pcs, UnivariatePcs, UnivariatePcsWithLde}; -use p3_field::{cyclic_subgroup_coset_known_order, AbstractExtensionField, AbstractField, Field}; -use p3_field::{ExtensionField, PackedField, PrimeField}; +use p3_field::{AbstractExtensionField, AbstractField}; +use p3_field::{ExtensionField, PrimeField}; use p3_field::{PrimeField32, TwoAdicField}; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::MatrixRows; -use p3_matrix::{Matrix, MatrixGet, MatrixRowSlices}; +use p3_matrix::{Matrix, MatrixRowSlices}; use p3_maybe_rayon::prelude::*; use p3_util::log2_ceil_usize; use p3_util::log2_strict_usize; @@ -17,9 +18,7 @@ use serde::de::DeserializeOwned; use serde::Serialize; use std::marker::PhantomData; -use super::folder::ProverConstraintFolder; use super::util::decompose_and_flatten; -use super::zerofier_coset::ZerofierOnCoset; use super::{types::*, ChipRef, StarkGenericConfig}; use crate::air::MachineAir; use crate::runtime::ExecutionRecord; @@ -29,28 +28,82 @@ use crate::utils::env; #[cfg(not(feature = "perf"))] use crate::stark::debug_constraints; -pub trait Prover +pub trait Prover { + fn prove_shards( + machine: &RiscvStark, + pk: &ProvingKey, + shards: &mut Vec, + challenger: &mut SC::Challenger, + ) -> Proof; +} + +impl Prover for LocalProver where - SC: StarkGenericConfig, + SC::Val: PrimeField32 + TwoAdicField + Send + Sync, + SC: StarkGenericConfig + Send + Sync, + SC::Challenger: Clone, + Com: Send + Sync, + PcsProverData: Send + Sync, + PcsProof: Send + Sync, + ShardMainData: Serialize + DeserializeOwned, { - fn commit_shards( - config: &SC, + fn prove_shards( + machine: &RiscvStark, + pk: &ProvingKey, shards: &mut Vec, - chips: &[ChipRef], - ) -> ( - Vec<>>::Commitment>, - Vec>, - ) - where - F: PrimeField + TwoAdicField + PrimeField32, - EF: ExtensionField, - SC: StarkGenericConfig + Send + Sync, - SC::Challenger: Clone, - >>::Commitment: Send + Sync, - >>::ProverData: Send + Sync, - MainData: Serialize + DeserializeOwned; + challenger: &mut SC::Challenger, + ) -> Proof { + let config = machine.config(); + let all_chips = machine.chips(); + tracing::info!("Generating and commiting traces for each shard."); + // Generate and commit the traces for each segment. + let (shard_commits, shard_data) = Self::commit_shards(config, shards, &all_chips); + + // Observe the challenges for each segment. + tracing::info_span!("observing all challenges").in_scope(|| { + shard_commits.into_iter().for_each(|commitment| { + challenger.observe(commitment); + }); + }); + + // Generate a proof for each segment. Note that we clone the challenger so we can observe + // identical global challenges across the segments. 
+ let shard_proofs = shard_data + .into_par_iter() + .map(|data| { + let data = tracing::info_span!("materializing data") + .in_scope(|| data.materialize().expect("failed to load shard main data")); + let chips = all_chips + .iter() + .filter(|chip| data.chip_ids.contains(&chip.name())) + .collect::>(); + tracing::info_span!("proving shard").in_scope(|| { + Self::prove_shard(config, pk, &chips, data, &mut challenger.clone()) + }) + }) + .collect::>(); - fn commit_main(config: &SC, chips: &[ChipRef], shard: &mut ExecutionRecord) -> MainData + Proof { shard_proofs } + } +} + +pub struct LocalProver(PhantomData); + +impl LocalProver +where + SC::Val: PrimeField + TwoAdicField + PrimeField32, + SC: StarkGenericConfig + Send + Sync, + SC::Challenger: Clone, + Com: Send + Sync, + PcsProverData: Send + Sync, + ShardMainData: Serialize + DeserializeOwned, +{ + fn commit_main( + config: &SC, + chips: &[ChipRef], + shard: &mut ExecutionRecord, + index: usize, + ) -> ShardMainData where SC::Val: PrimeField32, { @@ -75,30 +128,30 @@ where .map(|chip| chip.name()) .collect::>(); - MainData { + ShardMainData { traces, main_commit, main_data, chip_ids, + index, } } /// Prove the program for the given shard and given a commitment to the main data. fn prove_shard( config: &SC, + _pk: &ProvingKey, + chips: &[&ChipRef], + shard_data: ShardMainData, challenger: &mut SC::Challenger, - chips: &[ChipRef], - main_data: MainData, - preprocessed_traces: &[Option>], - _preprocessed_data: &Option>, ) -> ShardProof where SC::Val: PrimeField32, SC: Send + Sync, - MainData: DeserializeOwned, + ShardMainData: DeserializeOwned, { // Get the traces. - let traces = main_data.traces; + let traces = shard_data.traces; let log_degrees = traces .iter() @@ -130,12 +183,11 @@ where .par_iter() .zip(receives.par_iter()) .zip(traces.par_iter()) - .zip(preprocessed_traces.par_iter()) - .map(|(((send, rec), main_trace), prep_trace)| { + .map(|((send, rec), main_trace)| { let perm_trace = generate_permutation_trace( send, rec, - prep_trace, + &None, main_trace, &permutation_challenges, ); @@ -183,7 +235,7 @@ where let main_ldes = tracing::info_span!("get main ldes").in_scope(|| { config .pcs() - .get_ldes(&main_data.main_data) + .get_ldes(&shard_data.main_data) .into_iter() .map(|lde| lde.vertically_strided(1 << log_stride_for_quotient, 0)) .collect::>() @@ -203,9 +255,9 @@ where (0..chips.len()) .into_par_iter() .map(|i| { - Self::quotient_values( + quotient_values( config, - &chips[i], + chips[i], cumulative_sums[i], log_degrees[i], &main_ldes[i], @@ -275,7 +327,7 @@ where let (openings, opening_proof) = tracing::info_span!("open multi batches").in_scope(|| { config.pcs().open_multi_batches( &[ - (&main_data.main_data, &trace_opening_points), + (&shard_data.main_data, &trace_opening_points), (&permutation_data, &trace_opening_points), ("ient_data, "ient_opening_points), ], @@ -283,41 +335,6 @@ where ) }); - // Checking the shapes of openings match our expectations. - // - // This is a sanity check to make sure we are using the API correctly. We should remove this - // once everything is stable. - - #[cfg(not(feature = "perf"))] - { - // Check for the correct number of opening collections. - assert_eq!(openings.len(), 3); - - // Check the shape of the main trace openings. 
- assert_eq!(openings[0].len(), chips.len()); - for (chip, opening) in chips.iter().zip(openings[0].iter()) { - let width = chip.width(); - assert_eq!(opening.len(), 2); - assert_eq!(opening[0].len(), width); - assert_eq!(opening[1].len(), width); - } - // Check the shape of the permutation trace openings. - assert_eq!(openings[1].len(), chips.len()); - for (perm, opening) in permutation_traces.iter().zip(openings[1].iter()) { - let width = perm.width() * SC::Challenge::D; - assert_eq!(opening.len(), 2); - assert_eq!(opening[0].len(), width); - assert_eq!(opening[1].len(), width); - } - // Check the shape of the quotient openings. - assert_eq!(openings[2].len(), chips.len()); - for opening in openings[2].iter() { - let width = SC::Challenge::D << log_quotient_degree; - assert_eq!(opening.len(), 1); - assert_eq!(opening[0].len(), width); - } - } - #[cfg(feature = "perf")] { // Collect the opened values for each chip. @@ -364,8 +381,9 @@ where .collect::>(); ShardProof:: { + index: shard_data.index, commitment: ShardCommitment { - main_commit: main_data.main_commit.clone(), + main_commit: shard_data.main_commit.clone(), permutation_commit, quotient_commit, }, @@ -400,155 +418,13 @@ where }; } - #[allow(clippy::too_many_arguments)] - fn quotient_values( - config: &SC, - chip: &ChipRef, - cumulative_sum: SC::Challenge, - degree_bits: usize, - main_lde: &MainLde, - permutation_lde: &PermLde, - perm_challenges: &[SC::Challenge], - alpha: SC::Challenge, - ) -> Vec - where - SC: StarkGenericConfig, - MainLde: MatrixGet + Sync, - PermLde: MatrixGet + Sync, - { - let degree = 1 << degree_bits; - let quotient_degree_bits = chip.log_quotient_degree(); - let quotient_size_bits = degree_bits + quotient_degree_bits; - let quotient_size = 1 << quotient_size_bits; - let g_subgroup = SC::Val::two_adic_generator(degree_bits); - let g_extended = SC::Val::two_adic_generator(quotient_size_bits); - let subgroup_last = g_subgroup.inverse(); - let coset_shift = config.pcs().coset_shift(); - let next_step = 1 << quotient_degree_bits; - - let coset: Vec<_> = - cyclic_subgroup_coset_known_order(g_extended, coset_shift, quotient_size).collect(); - - let zerofier_on_coset = - ZerofierOnCoset::new(degree_bits, quotient_degree_bits, coset_shift); - - // Evaluations of L_first(x) = Z_H(x) / (x - 1) on our coset s H. 
- let lagrange_first_evals = zerofier_on_coset.lagrange_basis_unnormalized(0); - let lagrange_last_evals = zerofier_on_coset.lagrange_basis_unnormalized(degree - 1); - - let ext_degree = SC::Challenge::D; - - (0..quotient_size) - .into_par_iter() - .step_by(SC::PackedVal::WIDTH) - .flat_map_iter(|i_local_start| { - let wrap = |i| i % quotient_size; - let i_next_start = wrap(i_local_start + next_step); - let i_range = i_local_start..i_local_start + SC::PackedVal::WIDTH; - - let x = *SC::PackedVal::from_slice(&coset[i_range.clone()]); - let is_transition = x - subgroup_last; - let is_first_row = - *SC::PackedVal::from_slice(&lagrange_first_evals[i_range.clone()]); - let is_last_row = *SC::PackedVal::from_slice(&lagrange_last_evals[i_range]); - - let local: Vec<_> = (0..main_lde.width()) - .map(|col| { - SC::PackedVal::from_fn(|offset| { - let row = wrap(i_local_start + offset); - main_lde.get(row, col) - }) - }) - .collect(); - let next: Vec<_> = (0..main_lde.width()) - .map(|col| { - SC::PackedVal::from_fn(|offset| { - let row = wrap(i_next_start + offset); - main_lde.get(row, col) - }) - }) - .collect(); - - let perm_local: Vec<_> = (0..permutation_lde.width()) - .step_by(ext_degree) - .map(|col| { - SC::PackedChallenge::from_base_fn(|i| { - SC::PackedVal::from_fn(|offset| { - let row = wrap(i_local_start + offset); - permutation_lde.get(row, col + i) - }) - }) - }) - .collect(); - - let perm_next: Vec<_> = (0..permutation_lde.width()) - .step_by(ext_degree) - .map(|col| { - SC::PackedChallenge::from_base_fn(|i| { - SC::PackedVal::from_fn(|offset| { - let row = wrap(i_next_start + offset); - permutation_lde.get(row, col + i) - }) - }) - }) - .collect(); - - let accumulator = SC::PackedChallenge::zero(); - let mut folder = ProverConstraintFolder { - preprocessed: TwoRowMatrixView { - local: &[], - next: &[], - }, - main: TwoRowMatrixView { - local: &local, - next: &next, - }, - perm: TwoRowMatrixView { - local: &perm_local, - next: &perm_next, - }, - perm_challenges, - cumulative_sum, - is_first_row, - is_last_row, - is_transition, - alpha, - accumulator, - }; - chip.eval(&mut folder); - - // quotient(x) = constraints(x) / Z_H(x) - let zerofier_inv: SC::PackedVal = - zerofier_on_coset.eval_inverse_packed(i_local_start); - let quotient = folder.accumulator * zerofier_inv; - - // "Transpose" D packed base coefficients into WIDTH scalar extension coefficients. 
- (0..SC::PackedVal::WIDTH).map(move |idx_in_packing| { - let quotient_value = (0..>::D) - .map(|coeff_idx| { - quotient.as_base_slice()[coeff_idx].as_slice()[idx_in_packing] - }) - .collect::>(); - SC::Challenge::from_base_slice("ient_value) - }) - }) - .collect() - } -} - -pub struct LocalProver(PhantomData); - -impl Prover for LocalProver -where - SC: StarkGenericConfig, -{ fn commit_shards( config: &SC, shards: &mut Vec, chips: &[ChipRef], ) -> ( Vec<>>::Commitment>, - Vec>, + Vec>, ) where F: PrimeField + TwoAdicField + PrimeField32, @@ -557,7 +433,7 @@ where SC::Challenger: Clone, >>::Commitment: Send + Sync, >>::ProverData: Send + Sync, - MainData: Serialize + DeserializeOwned, + ShardMainData: Serialize + DeserializeOwned, { let num_shards = shards.len(); tracing::info!("num_shards={}", num_shards); @@ -568,9 +444,10 @@ where tracing::info_span!("commit main for all shards").in_scope(|| { shards .into_par_iter() - .map(|shard| { + .enumerate() + .map(|(i, shard)| { let data = tracing::info_span!("shard commit main", shard = shard.index) - .in_scope(|| Self::commit_main(config, chips, shard)); + .in_scope(|| Self::commit_main(config, chips, shard, i)); let commitment = data.main_commit.clone(); let file = tempfile::tempfile().unwrap(); let data = if num_shards > save_disk_threshold { @@ -592,8 +469,8 @@ where let bytes_written = shard_main_data .iter() .map(|data| match data { - MainDataWrapper::InMemory(_) => 0, - MainDataWrapper::TempFile(_, bytes_written) => *bytes_written, + ShardMainDataWrapper::InMemory(_) => 0, + ShardMainDataWrapper::TempFile(_, bytes_written) => *bytes_written, }) .sum::(); if bytes_written > 0 { diff --git a/core/src/stark/quotient.rs b/core/src/stark/quotient.rs new file mode 100644 index 0000000000..5865ccd260 --- /dev/null +++ b/core/src/stark/quotient.rs @@ -0,0 +1,143 @@ +use super::folder::ProverConstraintFolder; +use p3_air::Air; +use p3_air::TwoRowMatrixView; +use p3_commit::UnivariatePcsWithLde; +use p3_field::AbstractExtensionField; +use p3_field::AbstractField; +use p3_field::PackedField; +use p3_field::{cyclic_subgroup_coset_known_order, Field, TwoAdicField}; +use p3_matrix::MatrixGet; +use p3_maybe_rayon::prelude::*; + +use super::{zerofier_coset::ZerofierOnCoset, ChipRef, StarkGenericConfig}; + +#[allow(clippy::too_many_arguments)] +pub fn quotient_values( + config: &SC, + chip: &ChipRef, + cumulative_sum: SC::Challenge, + degree_bits: usize, + main_lde: &MainLde, + permutation_lde: &PermLde, + perm_challenges: &[SC::Challenge], + alpha: SC::Challenge, +) -> Vec +where + SC: StarkGenericConfig, + SC::Val: TwoAdicField, + MainLde: MatrixGet + Sync, + PermLde: MatrixGet + Sync, +{ + let degree = 1 << degree_bits; + let quotient_degree_bits = chip.log_quotient_degree(); + let quotient_size_bits = degree_bits + quotient_degree_bits; + let quotient_size = 1 << quotient_size_bits; + let g_subgroup = SC::Val::two_adic_generator(degree_bits); + let g_extended = SC::Val::two_adic_generator(quotient_size_bits); + let subgroup_last = g_subgroup.inverse(); + let coset_shift = config.pcs().coset_shift(); + let next_step = 1 << quotient_degree_bits; + + let coset: Vec<_> = + cyclic_subgroup_coset_known_order(g_extended, coset_shift, quotient_size).collect(); + + let zerofier_on_coset = ZerofierOnCoset::new(degree_bits, quotient_degree_bits, coset_shift); + + // Evaluations of L_first(x) = Z_H(x) / (x - 1) on our coset s H. 
+ let lagrange_first_evals = zerofier_on_coset.lagrange_basis_unnormalized(0); + let lagrange_last_evals = zerofier_on_coset.lagrange_basis_unnormalized(degree - 1); + + let ext_degree = SC::Challenge::D; + + (0..quotient_size) + .into_par_iter() + .step_by(SC::PackedVal::WIDTH) + .flat_map_iter(|i_local_start| { + let wrap = |i| i % quotient_size; + let i_next_start = wrap(i_local_start + next_step); + let i_range = i_local_start..i_local_start + SC::PackedVal::WIDTH; + + let x = *SC::PackedVal::from_slice(&coset[i_range.clone()]); + let is_transition = x - subgroup_last; + let is_first_row = *SC::PackedVal::from_slice(&lagrange_first_evals[i_range.clone()]); + let is_last_row = *SC::PackedVal::from_slice(&lagrange_last_evals[i_range]); + + let local: Vec<_> = (0..main_lde.width()) + .map(|col| { + SC::PackedVal::from_fn(|offset| { + let row = wrap(i_local_start + offset); + main_lde.get(row, col) + }) + }) + .collect(); + let next: Vec<_> = (0..main_lde.width()) + .map(|col| { + SC::PackedVal::from_fn(|offset| { + let row = wrap(i_next_start + offset); + main_lde.get(row, col) + }) + }) + .collect(); + + let perm_local: Vec<_> = (0..permutation_lde.width()) + .step_by(ext_degree) + .map(|col| { + SC::PackedChallenge::from_base_fn(|i| { + SC::PackedVal::from_fn(|offset| { + let row = wrap(i_local_start + offset); + permutation_lde.get(row, col + i) + }) + }) + }) + .collect(); + + let perm_next: Vec<_> = (0..permutation_lde.width()) + .step_by(ext_degree) + .map(|col| { + SC::PackedChallenge::from_base_fn(|i| { + SC::PackedVal::from_fn(|offset| { + let row = wrap(i_next_start + offset); + permutation_lde.get(row, col + i) + }) + }) + }) + .collect(); + + let accumulator = SC::PackedChallenge::zero(); + let mut folder = ProverConstraintFolder { + preprocessed: TwoRowMatrixView { + local: &[], + next: &[], + }, + main: TwoRowMatrixView { + local: &local, + next: &next, + }, + perm: TwoRowMatrixView { + local: &perm_local, + next: &perm_next, + }, + perm_challenges, + cumulative_sum, + is_first_row, + is_last_row, + is_transition, + alpha, + accumulator, + }; + chip.eval(&mut folder); + + // quotient(x) = constraints(x) / Z_H(x) + let zerofier_inv: SC::PackedVal = zerofier_on_coset.eval_inverse_packed(i_local_start); + let quotient = folder.accumulator * zerofier_inv; + + // "Transpose" D packed base coefficients into WIDTH scalar extension coefficients. 
+ (0..SC::PackedVal::WIDTH).map(move |idx_in_packing| { + let quotient_value = (0..>::D) + .map(|coeff_idx| quotient.as_base_slice()[coeff_idx].as_slice()[idx_in_packing]) + .collect::>(); + SC::Challenge::from_base_slice("ient_value) + }) + }) + .collect() +} diff --git a/core/src/stark/types.rs b/core/src/stark/types.rs index 60712e58de..f5530f5642 100644 --- a/core/src/stark/types.rs +++ b/core/src/stark/types.rs @@ -24,41 +24,41 @@ type ValMat = RowMajorMatrix>; pub type Com = <::Pcs as Pcs, ValMat>>::Commitment; pub type PcsProverData = <::Pcs as Pcs, ValMat>>::ProverData; +pub type PcsProof = <::Pcs as Pcs, ValMat>>::Proof; pub type QuotientOpenedValues = Vec; #[derive(Serialize, Deserialize)] -#[serde(bound( - serialize = "SC: StarkGenericConfig", - deserialize = "SC: StarkGenericConfig" -))] -pub struct MainData { +#[serde(bound(serialize = "PcsProverData: Serialize"))] +#[serde(bound(deserialize = "PcsProverData: Deserialize<'de>"))] +pub struct ShardMainData { pub traces: Vec>, pub main_commit: Com, - #[serde(bound(serialize = "PcsProverData: Serialize"))] - #[serde(bound(deserialize = "PcsProverData: Deserialize<'de>"))] pub main_data: PcsProverData, pub chip_ids: Vec, + pub index: usize, } -impl MainData { +impl ShardMainData { pub fn new( traces: Vec>, main_commit: Com, main_data: PcsProverData, chip_ids: Vec, + index: usize, ) -> Self { Self { traces, main_commit, main_data, chip_ids, + index, } } - pub fn save(&self, file: File) -> Result, Error> + pub fn save(&self, file: File) -> Result, Error> where - MainData: Serialize, + ShardMainData: Serialize, { let mut writer = BufWriter::new(&file); bincode::serialize_into(&mut writer, self)?; @@ -66,26 +66,26 @@ impl MainData { let metadata = file.metadata()?; let bytes_written = metadata.len(); trace!( - "wrote {} while saving MainData", + "wrote {} while saving ShardMainData", Size::from_bytes(bytes_written) ); - Ok(MainDataWrapper::TempFile(file, bytes_written)) + Ok(ShardMainDataWrapper::TempFile(file, bytes_written)) } - pub fn to_in_memory(self) -> MainDataWrapper { - MainDataWrapper::InMemory(self) + pub fn to_in_memory(self) -> ShardMainDataWrapper { + ShardMainDataWrapper::InMemory(self) } } -pub enum MainDataWrapper { - InMemory(MainData), +pub enum ShardMainDataWrapper { + InMemory(ShardMainData), TempFile(File, u64), } -impl MainDataWrapper { - pub fn materialize(self) -> Result, Error> +impl ShardMainDataWrapper { + pub fn materialize(self) -> Result, Error> where - MainData: DeserializeOwned, + ShardMainData: DeserializeOwned, { match self { Self::InMemory(data) => Ok(data), @@ -130,6 +130,7 @@ pub struct ShardOpenedValues { #[cfg(feature = "perf")] #[derive(Serialize)] pub struct ShardProof { + pub index: usize, pub commitment: ShardCommitment>, pub opened_values: ShardOpenedValues>, pub opening_proof: OpeningProof, diff --git a/core/src/syscall/precompiles/blake3/compress/mod.rs b/core/src/syscall/precompiles/blake3/compress/mod.rs index 2709ce5985..1d7df224c6 100644 --- a/core/src/syscall/precompiles/blake3/compress/mod.rs +++ b/core/src/syscall/precompiles/blake3/compress/mod.rs @@ -113,7 +113,7 @@ pub mod compress_tests { use crate::runtime::Opcode; use crate::runtime::Register; use crate::runtime::SyscallCode; - use crate::utils::prove; + use crate::utils::run_test; use crate::utils::setup_logger; use crate::utils::tests::BLAKE3_COMPRESS_ELF; use crate::Program; @@ -166,13 +166,13 @@ pub mod compress_tests { fn prove_babybear() { setup_logger(); let program = blake3_compress_internal_program(); - 
prove(program); + run_test(program).unwrap(); } #[test] fn test_blake3_compress_inner_elf() { setup_logger(); let program = Program::from(BLAKE3_COMPRESS_ELF); - prove(program); + run_test(program).unwrap(); } } diff --git a/core/src/syscall/precompiles/keccak256/mod.rs b/core/src/syscall/precompiles/keccak256/mod.rs index f3d05e4fbe..6337509e4e 100644 --- a/core/src/syscall/precompiles/keccak256/mod.rs +++ b/core/src/syscall/precompiles/keccak256/mod.rs @@ -87,8 +87,9 @@ pub mod permute_tests { let mut runtime = Runtime::new(program); runtime.run(); - let (machine, prover_data) = RiscvStark::init(config); - machine.prove::>(&prover_data, &mut runtime.record, &mut challenger); + let machine = RiscvStark::new(config); + let (pk, _) = machine.setup(runtime.program.as_ref()); + machine.prove::>(&pk, &mut runtime.record, &mut challenger); } #[test] diff --git a/core/src/syscall/precompiles/sha256/compress/mod.rs b/core/src/syscall/precompiles/sha256/compress/mod.rs index bb01f41b49..35cbc3b725 100644 --- a/core/src/syscall/precompiles/sha256/compress/mod.rs +++ b/core/src/syscall/precompiles/sha256/compress/mod.rs @@ -72,8 +72,9 @@ pub mod compress_tests { let mut runtime = Runtime::new(program); runtime.run(); - let (machine, prover_data) = RiscvStark::init(config); + let machine = RiscvStark::new(config); - machine.prove::>(&prover_data, &mut runtime.record, &mut challenger); + let (pk, _) = machine.setup(runtime.program.as_ref()); + machine.prove::>(&pk, &mut runtime.record, &mut challenger); } } diff --git a/core/src/syscall/precompiles/sha256/extend/mod.rs b/core/src/syscall/precompiles/sha256/extend/mod.rs index 00b723bf78..0e0fcab411 100644 --- a/core/src/syscall/precompiles/sha256/extend/mod.rs +++ b/core/src/syscall/precompiles/sha256/extend/mod.rs @@ -89,8 +89,9 @@ pub mod extend_tests { let mut runtime = Runtime::new(program); runtime.run(); - let (machine, prover_data) = RiscvStark::init(config); + let machine = RiscvStark::new(config); - machine.prove::>(&prover_data, &mut runtime.record, &mut challenger); + let (pk, _) = machine.setup(runtime.program.as_ref()); + machine.prove::>(&pk, &mut runtime.record, &mut challenger); } } diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs index 8e67912d46..986159c9bb 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs @@ -338,13 +338,13 @@ where mod tests { use crate::{ runtime::Program, - utils::{prove, setup_logger, tests::SECP256K1_ADD_ELF}, + utils::{run_test, setup_logger, tests::SECP256K1_ADD_ELF}, }; #[test] fn test_secp256k1_add_simple() { setup_logger(); let program = Program::from(SECP256K1_ADD_ELF); - prove(program); + run_test(program).unwrap(); } } diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs index f491275bad..cfaeb3e17d 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs @@ -343,13 +343,13 @@ pub mod tests { use crate::{ runtime::Program, - utils::{prove, setup_logger, tests::SECP256K1_DOUBLE_ELF}, + utils::{run_test, setup_logger, tests::SECP256K1_DOUBLE_ELF}, }; #[test] fn test_secp256k1_double_simple() { setup_logger(); let program = Program::from(SECP256K1_DOUBLE_ELF); - prove(program); + run_test(program).unwrap(); } } diff --git a/core/src/utils/prove.rs 
b/core/src/utils/prove.rs index dccb6059a6..26f52fb368 100644 --- a/core/src/utils/prove.rs +++ b/core/src/utils/prove.rs @@ -3,7 +3,7 @@ use std::time::Instant; use crate::utils::poseidon2_instance::RC_16_30; use crate::{ runtime::{Program, Runtime}, - stark::{LocalProver, MainData, OpeningProof}, + stark::{LocalProver, OpeningProof, ShardMainData}, stark::{RiscvStark, StarkGenericConfig}, }; pub use baby_bear_blake3::BabyBearBlake3; @@ -46,6 +46,38 @@ pub fn prove(program: Program) -> crate::stark::Proof { prove_core(config, &mut runtime) } +#[cfg(test)] +pub fn run_test(program: Program) -> Result<(), crate::stark::ProgramVerificationError> { + let mut runtime = tracing::info_span!("runtime.run(...)").in_scope(|| { + let mut runtime = Runtime::new(program); + runtime.run(); + runtime + }); + let config = BabyBearBlake3::new(); + + let machine = RiscvStark::new(config); + let (pk, vk) = machine.setup(runtime.program.as_ref()); + let mut challenger = machine.config().challenger(); + + let start = Instant::now(); + let proof = tracing::info_span!("runtime.prove(...)") + .in_scope(|| machine.prove::>(&pk, &mut runtime.record, &mut challenger)); + let cycles = runtime.state.global_clk; + let time = start.elapsed().as_millis(); + let nb_bytes = bincode::serialize(&proof).unwrap().len(); + + tracing::info!( + "cycles={}, e2e={}, khz={:.2}, proofSize={}", + cycles, + time, + (cycles as f64 / time as f64), + Size::from_bytes(nb_bytes), + ); + + let mut challenger = machine.config().challenger(); + machine.verify(&vk, &proof, &mut challenger) +} + pub fn prove_elf(elf: &[u8]) -> crate::stark::Proof { let program = Program::from(elf); prove(program) @@ -60,23 +92,23 @@ where OpeningProof: Send + Sync, >>::Commitment: Send + Sync, >>::ProverData: Send + Sync, - MainData: Serialize + DeserializeOwned, + ShardMainData: Serialize + DeserializeOwned, ::Val: PrimeField32, { let mut challenger = config.challenger(); let start = Instant::now(); - let (machine, prover_data) = RiscvStark::init(config.clone()); + let machine = RiscvStark::new(config); + let (pk, _) = machine.setup(runtime.program.as_ref()); // Because proving modifies the shard, clone beforehand if we debug interactions. #[cfg(not(feature = "perf"))] let shard = runtime.record.clone(); // Prove the program. - let proof = tracing::info_span!("prove").in_scope(|| { - machine.prove::>(&prover_data, &mut runtime.record, &mut challenger) - }); + let proof = tracing::info_span!("runtime.prove(...)") + .in_scope(|| machine.prove::>(&pk, &mut runtime.record, &mut challenger)); let cycles = runtime.state.global_clk; let time = start.elapsed().as_millis(); let nb_bytes = bincode::serialize(&proof).unwrap().len();
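Taken together, the diff replaces the old `init`/`prove` flow with an explicit setup phase that derives a proving key and a verifying key from the program. A hedged end-to-end sketch of the new flow, using only calls that appear in this diff; the wrapper function and its name are illustrative, and the turbofish on `prove` is reconstructed since generic arguments are stripped in this rendering.

use crate::{
    runtime::{Program, Runtime},
    stark::{LocalProver, ProgramVerificationError, RiscvStark},
    utils::{BabyBearBlake3, StarkUtils},
};

// Illustrative wrapper around the new key-based API: run the program, derive the
// proving/verifying keys from it, prove the sharded execution record, then verify.
fn prove_and_verify(elf: &[u8]) -> Result<(), ProgramVerificationError> {
    let program = Program::from(elf);
    let mut runtime = Runtime::new(program);
    runtime.run();

    let machine = RiscvStark::new(BabyBearBlake3::new());
    let (pk, vk) = machine.setup(runtime.program.as_ref());

    let mut challenger = machine.config().challenger();
    let proof = machine.prove::<LocalProver<_>>(&pk, &mut runtime.record, &mut challenger);

    let mut challenger = machine.config().challenger();
    machine.verify(&vk, &proof, &mut challenger)
}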