diff --git a/uni-stark/src/check_constraints.rs b/uni-stark/src/check_constraints.rs index 6a7b48cd1..1d86bd6fb 100644 --- a/uni-stark/src/check_constraints.rs +++ b/uni-stark/src/check_constraints.rs @@ -15,7 +15,7 @@ pub(crate) fn check_constraints( air: &A, preprocessed: &RowMajorMatrix, stages: Vec<&RowMajorMatrix>, - public_values: &Vec<&Vec>, + public_values: &Vec, challenges: Vec<&Vec>, ) where F: Field, @@ -75,7 +75,7 @@ pub struct DebugConstraintBuilder<'a, F: Field> { preprocessed: VerticalPair, RowMajorMatrixView<'a, F>>, challenges: Vec<&'a Vec>, stages: Vec, RowMajorMatrixView<'a, F>>>, - public_values: &'a [&'a Vec], + public_values: &'a [F], is_first_row: F, is_last_row: F, is_transition: F, @@ -134,7 +134,7 @@ impl<'a, F: Field> AirBuilderWithPublicValues for DebugConstraintBuilder<'a, F> type PublicVar = Self::F; fn public_values(&self) -> &[Self::PublicVar] { - self.stage_public_values(0) + self.public_values } } @@ -147,10 +147,6 @@ impl<'a, F: Field> PairBuilder for DebugConstraintBuilder<'a, F> { impl<'a, F: Field> MultistageAirBuilder for DebugConstraintBuilder<'a, F> { type Challenge = Self::Expr; - fn stage_public_values(&self, stage: usize) -> &[Self::F] { - self.public_values[stage] - } - fn stage_trace(&self, stage: usize) -> Self::M { self.stages[stage] } diff --git a/uni-stark/src/folder.rs b/uni-stark/src/folder.rs index 307337922..1ffec17b9 100644 --- a/uni-stark/src/folder.rs +++ b/uni-stark/src/folder.rs @@ -13,7 +13,7 @@ pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> { pub challenges: Vec>>, pub stages: Vec>>, pub preprocessed: RowMajorMatrix>, - pub public_values: &'a Vec>>, + pub public_values: &'a Vec>, pub is_first_row: PackedVal, pub is_last_row: PackedVal, pub is_transition: PackedVal, @@ -28,7 +28,7 @@ pub struct VerifierConstraintFolder<'a, SC: StarkGenericConfig> { pub challenges: Vec>>, pub stages: Vec>, pub preprocessed: ViewPair<'a, SC::Challenge>, - pub public_values: Vec<&'a Vec>>, + pub public_values: &'a Vec>, pub is_first_row: SC::Challenge, pub is_last_row: SC::Challenge, pub is_transition: SC::Challenge, @@ -73,7 +73,7 @@ impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for ProverConstraint type PublicVar = Val; fn public_values(&self) -> &[Self::PublicVar] { - self.stage_public_values(0) + self.public_values } } @@ -87,9 +87,6 @@ impl<'a, SC: StarkGenericConfig> MultistageAirBuilder for ProverConstraintFolder fn stage_challenges(&self, stage: usize) -> &[Self::Challenge] { &self.challenges[stage] } - fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] { - &self.public_values[stage] - } } impl<'a, SC: StarkGenericConfig> PairBuilder for ProverConstraintFolder<'a, SC> { @@ -135,7 +132,7 @@ impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for VerifierConstrai type PublicVar = Val; fn public_values(&self) -> &[Self::PublicVar] { - self.stage_public_values(0) + self.public_values } } @@ -149,9 +146,6 @@ impl<'a, SC: StarkGenericConfig> MultistageAirBuilder for VerifierConstraintFold fn stage_challenges(&self, stage: usize) -> &[Self::Challenge] { &self.challenges[stage] } - fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] { - self.public_values[stage] - } } impl<'a, SC: StarkGenericConfig> PairBuilder for VerifierConstraintFolder<'a, SC> { diff --git a/uni-stark/src/proof.rs b/uni-stark/src/proof.rs index 2ec7dbe12..491e4b87d 100644 --- a/uni-stark/src/proof.rs +++ b/uni-stark/src/proof.rs @@ -59,7 +59,6 @@ pub struct ProcessedStage { pub(crate) commitment: Com, pub(crate) prover_data: PcsProverData, pub(crate) challenge_values: Vec>, - pub(crate) public_values: Vec>, #[cfg(debug_assertions)] pub(crate) trace: RowMajorMatrix>, } diff --git a/uni-stark/src/prover.rs b/uni-stark/src/prover.rs index fb5a13cc2..ed57f9db1 100644 --- a/uni-stark/src/prover.rs +++ b/uni-stark/src/prover.rs @@ -40,6 +40,8 @@ pub fn prove< air: &A, challenger: &mut SC::Challenger, main_trace: RowMajorMatrix>, + #[allow(clippy::ptr_arg)] + // we do not use `&[Val]` in order to keep the same API public_values: &Vec>, ) -> Proof where @@ -47,6 +49,12 @@ where A: MultiStageAir>> + for<'a> MultiStageAir>, { + let public_values = public_values + .iter() + .enumerate() + .map(|(index, value)| (index, *value)) + .collect(); + prove_with_key( config, None, @@ -54,7 +62,7 @@ where challenger, main_trace, &UnusedCallback, - public_values, + &public_values, ) } @@ -72,9 +80,7 @@ pub fn prove_with_key< challenger: &mut SC::Challenger, stage_0_trace: RowMajorMatrix>, next_stage_trace_callback: &C, - #[allow(clippy::ptr_arg)] - // we do not use `&[Val]` in order to keep the same API - stage_0_public_values: &Vec>, + stage_0_public_values: &Vec<(usize, Val)>, ) -> Proof where SC: StarkGenericConfig, @@ -85,10 +91,6 @@ where let degree = stage_0_trace.height(); let log_degree = log2_strict_usize(degree); - let log_quotient_degree = - get_log_quotient_degree::, A>(air, &[stage_0_public_values.len()]); - let quotient_degree = 1 << log_quotient_degree; - let stage_count = >>::stage_count(air); let pcs = config.pcs(); @@ -103,6 +105,7 @@ where }; let mut state: ProverState = ProverState::new(pcs, trace_domain, challenger); + state.add_public_values(stage_0_public_values.iter().cloned()); let mut stage = Stage { trace: stage_0_trace, challenge_count: >>::stage_challenge_count(air, 0), @@ -146,6 +149,12 @@ where // sanity check that we processed as many stages as expected assert_eq!(state.processed_stages.len(), stage_count); + let public_values: Vec<_> = state + .public_values + .into_iter() + .map(|v| v.unwrap()) + .collect(); + // with the witness complete, check the constraints #[cfg(debug_assertions)] crate::check_constraints::check_constraints( @@ -153,11 +162,7 @@ where &air.preprocessed_trace() .unwrap_or(RowMajorMatrix::new(Default::default(), 0)), state.processed_stages.iter().map(|s| &s.trace).collect(), - &state - .processed_stages - .iter() - .map(|s| &s.public_values) - .collect(), + &public_values, state .processed_stages .iter() @@ -165,14 +170,8 @@ where .collect(), ); - let log_quotient_degree = get_log_quotient_degree::, A>( - air, - &state - .processed_stages - .iter() - .map(|s| s.public_values.len()) - .collect::>(), - ); + let log_quotient_degree = get_log_quotient_degree::, A>(air, public_values.len()); + let quotient_degree = 1 << log_quotient_degree; let challenger = &mut state.challenger; @@ -197,12 +196,6 @@ where .map(|stage| stage.challenge_values.clone()) .collect(); - let public_values = state - .processed_stages - .iter() - .map(|stage| stage.public_values.clone()) - .collect(); - let quotient_values = quotient_values( air, &public_values, @@ -312,7 +305,7 @@ where #[instrument(name = "compute quotient polynomial", skip_all)] fn quotient_values<'a, SC, A, Mat>( air: &A, - public_values: &'a Vec>>, + public_values: &'a Vec>, trace_domain: Domain, quotient_domain: Domain, preprocessed_on_quotient_domain: Option, @@ -416,6 +409,7 @@ pub struct ProverState<'a, SC: StarkGenericConfig> { pub(crate) challenger: &'a mut SC::Challenger, pub(crate) pcs: &'a ::Pcs, pub(crate) trace_domain: Domain, + pub(crate) public_values: Vec>>, } impl<'a, SC: StarkGenericConfig> ProverState<'a, SC> { @@ -429,6 +423,24 @@ impl<'a, SC: StarkGenericConfig> ProverState<'a, SC> { challenger, pcs, trace_domain, + public_values: Default::default(), + } + } + + pub(crate) fn add_public_values( + &mut self, + public_values: impl IntoIterator)>, + ) { + for (index, value) in public_values { + if self.public_values.len() <= index + 1 { + self.public_values.resize(index + 1, None); + } + match self.public_values[index] { + Some(_) => panic!("public value at index {index} is already set"), + None => { + self.public_values[index] = Some(value); + } + } } } @@ -446,11 +458,9 @@ impl<'a, SC: StarkGenericConfig> ProverState<'a, SC> { .map(|_| self.challenger.sample()) .collect(); - // observe the public inputs for this stage - self.challenger.observe_slice(&stage.public_values); + self.add_public_values(stage.public_values); self.processed_stages.push(ProcessedStage { - public_values: stage.public_values, prover_data, commitment, challenge_values, @@ -467,20 +477,24 @@ pub struct Stage { /// the number of challenges to be drawn at the end of this stage pub(crate) challenge_count: usize, /// the public values for this stage - pub(crate) public_values: Vec>, + pub(crate) public_values: Vec<(usize, Val)>, } pub struct CallbackResult { /// the trace for this stage pub(crate) trace: RowMajorMatrix, /// the values of the public inputs of this stage - pub(crate) public_values: Vec, + pub(crate) public_values: Vec<(usize, T)>, /// the values of the challenges drawn at the previous stage pub(crate) challenges: Vec, } impl CallbackResult { - pub fn new(trace: RowMajorMatrix, public_values: Vec, challenges: Vec) -> Self { + pub fn new( + trace: RowMajorMatrix, + public_values: Vec<(usize, T)>, + challenges: Vec, + ) -> Self { Self { trace, public_values, diff --git a/uni-stark/src/symbolic_builder.rs b/uni-stark/src/symbolic_builder.rs index 767ad1961..98e397581 100644 --- a/uni-stark/src/symbolic_builder.rs +++ b/uni-stark/src/symbolic_builder.rs @@ -13,13 +13,13 @@ use crate::traits::{MultiStageAir, MultistageAirBuilder}; use crate::Entry; #[instrument(name = "infer log of constraint degree", skip_all)] -pub fn get_log_quotient_degree(air: &A, public_values_counts: &[usize]) -> usize +pub fn get_log_quotient_degree(air: &A, num_public_values: usize) -> usize where F: Field, A: MultiStageAir>, { // We pad to at least degree 2, since a quotient argument doesn't make sense with smaller degrees. - let constraint_degree = get_max_constraint_degree(air, public_values_counts).max(2); + let constraint_degree = get_max_constraint_degree(air, num_public_values).max(2); // The quotient's actual degree is approximately (max_constraint_degree - 1) n, // where subtracting 1 comes from division by the zerofier. @@ -28,12 +28,12 @@ where } #[instrument(name = "infer constraint degree", skip_all, level = "debug")] -pub fn get_max_constraint_degree(air: &A, public_values_counts: &[usize]) -> usize +pub fn get_max_constraint_degree(air: &A, num_public_values: usize) -> usize where F: Field, A: MultiStageAir>, { - get_symbolic_constraints(air, public_values_counts) + get_symbolic_constraints(air, num_public_values) .iter() .map(|c| c.degree_multiple()) .max() @@ -43,7 +43,7 @@ where #[instrument(name = "evaluate constraints symbolically", skip_all, level = "debug")] pub fn get_symbolic_constraints( air: &A, - public_values_counts: &[usize], + num_public_values: usize, ) -> Vec> where F: Field, @@ -58,7 +58,7 @@ where let mut builder = SymbolicAirBuilder::new( air.preprocessed_width(), &widths, - public_values_counts, + num_public_values, challenges, ); air.eval(&mut builder); @@ -71,7 +71,7 @@ pub struct SymbolicAirBuilder { challenges: Vec>>, preprocessed: RowMajorMatrix>, stages: Vec>>, - public_values: Vec>>, + public_values: Vec>, constraints: Vec>, } @@ -79,7 +79,7 @@ impl SymbolicAirBuilder { pub(crate) fn new( preprocessed_width: usize, stage_widths: &[usize], - public_value_counts: &[usize], + num_public_values: usize, challenges: Vec, ) -> Self { let prep_values = [0, 1] @@ -115,18 +115,8 @@ impl SymbolicAirBuilder { .collect() }) .collect(); - let mut public_value_index = 0; - let public_values = public_value_counts - .iter() - .map(|count| { - (0..*count) - .map(|_| { - let res = SymbolicVariable::new(Entry::Public, public_value_index); - public_value_index += 1; - res - }) - .collect() - }) + let public_values = (0..num_public_values) + .map(move |index| SymbolicVariable::new(Entry::Public, index)) .collect(); Self { challenges, @@ -177,17 +167,13 @@ impl AirBuilderWithPublicValues for SymbolicAirBuilder { type PublicVar = SymbolicVariable; fn public_values(&self) -> &[Self::PublicVar] { - self.stage_public_values(0) + &self.public_values } } impl MultistageAirBuilder for SymbolicAirBuilder { type Challenge = Self::Var; - fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] { - &self.public_values[stage] - } - fn stage_trace(&self, stage: usize) -> Self::M { self.stages[stage].clone() } diff --git a/uni-stark/src/traits.rs b/uni-stark/src/traits.rs index 8314b0ceb..8e759736f 100644 --- a/uni-stark/src/traits.rs +++ b/uni-stark/src/traits.rs @@ -8,13 +8,6 @@ pub trait MultistageAirBuilder: AirBuilderWithPublicValues { /// Challenges from each stage, drawn from the base field fn stage_challenges(&self, stage: usize) -> &[Self::Challenge]; - - fn stage_public_values(&self, stage: usize) -> &[Self::PublicVar] { - match stage { - 0 => self.public_values(), - _ => unimplemented!(), - } - } } pub trait MultiStageAir: Air { diff --git a/uni-stark/src/verifier.rs b/uni-stark/src/verifier.rs index 420fd3a8d..f89361bef 100644 --- a/uni-stark/src/verifier.rs +++ b/uni-stark/src/verifier.rs @@ -30,7 +30,7 @@ where A: MultiStageAir>> + for<'a> MultiStageAir>, { - verify_with_key(config, None, air, challenger, proof, vec![public_values]) + verify_with_key(config, None, air, challenger, proof, public_values) } #[instrument(skip_all)] @@ -40,7 +40,7 @@ pub fn verify_with_key( air: &A, challenger: &mut SC::Challenger, proof: &Proof, - public_values: Vec<&Vec>>, + public_values: &Vec>, ) -> Result<(), VerificationError>> where SC: StarkGenericConfig, @@ -56,13 +56,7 @@ where } = proof; let degree = 1 << degree_bits; - let log_quotient_degree = get_log_quotient_degree::, A>( - air, - &public_values - .iter() - .map(|values| values.len()) - .collect::>(), - ); + let log_quotient_degree = get_log_quotient_degree::, A>(air, public_values.len()); let quotient_degree = 1 << log_quotient_degree; let stages = proof.commitments.stages.len(); @@ -118,13 +112,13 @@ where commitments .stages .iter() - .zip(&public_values) .zip(challenge_counts) - .for_each(|((commitment, public_values), challenge_count)| { + .for_each(|(commitment, challenge_count)| { challenger.observe(commitment.clone()); challenges.push((0..*challenge_count).map(|_| challenger.sample()).collect()); - challenger.observe_slice(public_values); }); + + challenger.observe_slice(public_values); let alpha: SC::Challenge = challenger.sample_ext_element(); challenger.observe(commitments.quotient_chunks.clone());