diff --git a/src/field.rs b/src/field.rs
index 514fc3dfa1..1427c33c9f 100644
--- a/src/field.rs
+++ b/src/field.rs
@@ -6,7 +6,7 @@
 //! relating this field to the expresions of the language.
 use clap::ValueEnum;
 use ff::{PrimeField, PrimeFieldBits};
-use nova::provider::bn256_grumpkin::bn256;
+use nova::provider::bn256_grumpkin::{bn256, grumpkin};
 use serde::{Deserialize, Serialize};
 use std::convert::TryFrom;
 use std::hash::Hash;
@@ -272,6 +272,10 @@ impl LurkField for bn256::Scalar {
     const FIELD: LanguageField = LanguageField::BN256;
 }
 
+impl LurkField for grumpkin::Scalar {
+    const FIELD: LanguageField = LanguageField::Grumpkin;
+}
+
 // The impl LurkField for grumpkin::Scalar is technically possible, but voluntarily omitted to avoid confusion.
 
 // For working around the orphan trait impl rule
diff --git a/src/lem/circuit.rs b/src/lem/circuit.rs
index 7de35510b0..023a7b66a8 100644
--- a/src/lem/circuit.rs
+++ b/src/lem/circuit.rs
@@ -1594,10 +1594,10 @@ impl Func {
         };
 
         // fixed cost for each slot
-        let slot_constraints = 289 * self.slots_count.hash4
-            + 337 * self.slots_count.hash6
-            + 388 * self.slots_count.hash8
-            + 265 * self.slots_count.commitment
+        let slot_constraints = store.hash4_cost() * self.slots_count.hash4
+            + store.hash6_cost() * self.slots_count.hash6
+            + store.hash8_cost() * self.slots_count.hash8
+            + store.hash3_cost() * self.slots_count.commitment
             + bit_decomp_cost * self.slots_count.bit_decomp;
         let num_constraints = recurse(&self.body, globals, store, false);
         slot_constraints + num_constraints + globals.len()
diff --git a/src/lem/multiframe.rs b/src/lem/multiframe.rs
index 20b7b14482..54713ca87a 100644
--- a/src/lem/multiframe.rs
+++ b/src/lem/multiframe.rs
@@ -14,7 +14,7 @@ use crate::{
     coprocessor::Coprocessor,
     error::{ProofError, ReductionError},
     eval::{lang::Lang, Meta},
-    field::LurkField,
+    field::{LanguageField, LurkField},
    proof::{
        nova::{CurveCycleEquipped, G1, G2},
        supernova::{FoldingConfig, C2},
@@ -140,6 +140,30 @@ fn assert_eq_ptrs_aptrs(
     Ok(())
 }
 
+// Hardcoded slot witness sizes, empirically collected
+const BIT_DECOMP_PALLAS_WITNESS_SIZE: usize = 298;
+const BIT_DECOMP_VESTA_WITNESS_SIZE: usize = 301;
+const BIT_DECOMP_BN256_WITNESS_SIZE: usize = 354;
+const BIT_DECOMP_GRUMPKIN_WITNESS_SIZE: usize = 364;
+
+/// Computes the witness size for a `SlotType`. Note that the witness size for
+/// bit decomposition depends on the field we're in.
+#[inline]
+fn compute_witness_size<F: LurkField>(slot_type: &SlotType, store: &Store<F>) -> usize {
+    match slot_type {
+        SlotType::Hash4 => store.hash4_cost() + 4, // 4 preimg elts
+        SlotType::Hash6 => store.hash6_cost() + 6, // 6 preimg elts
+        SlotType::Hash8 => store.hash8_cost() + 8, // 8 preimg elts
+        SlotType::Commitment => store.hash3_cost() + 3, // 3 preimg elts
+        SlotType::BitDecomp => match F::FIELD {
+            LanguageField::Pallas => BIT_DECOMP_PALLAS_WITNESS_SIZE,
+            LanguageField::Vesta => BIT_DECOMP_VESTA_WITNESS_SIZE,
+            LanguageField::BN256 => BIT_DECOMP_BN256_WITNESS_SIZE,
+            LanguageField::Grumpkin => BIT_DECOMP_GRUMPKIN_WITNESS_SIZE,
+        },
+    }
+}
+
 /// Generates the witnesses for all slots in `frames`. Since many slots are fed
 /// with dummy data, we cache their (dummy) witnesses for extra speed
 fn generate_slots_witnesses(
@@ -160,11 +184,25 @@ fn generate_slots_witnesses(
         .into_iter()
         .for_each(|(sd_vec, st)| sd_vec.iter().for_each(|sd| slots_data.push((sd, st))));
     });
+    // precompute these values
+    let hash4_witness_size = compute_witness_size(&SlotType::Hash4, store);
+    let hash6_witness_size = compute_witness_size(&SlotType::Hash6, store);
+    let hash8_witness_size = compute_witness_size(&SlotType::Hash8, store);
+    let commitment_witness_size = compute_witness_size(&SlotType::Commitment, store);
+    let bit_decomp_witness_size = compute_witness_size(&SlotType::BitDecomp, store);
+    // fast getter for the precomputed values
+    let get_witness_size = |slot_type| match slot_type {
+        SlotType::Hash4 => hash4_witness_size,
+        SlotType::Hash6 => hash6_witness_size,
+        SlotType::Hash8 => hash8_witness_size,
+        SlotType::Commitment => commitment_witness_size,
+        SlotType::BitDecomp => bit_decomp_witness_size,
+    };
     // cache dummy slots witnesses with `Arc` for speedy clones
     let dummy_witnesses_cache: FrozenMap<_, Box<Arc<SlotWitness<F>>>> = FrozenMap::default();
     let gen_slot_witness = |(slot_idx, (slot_data, slot_type))| {
         let mk_witness = || {
-            let mut witness = WitnessCS::new();
+            let mut witness = WitnessCS::with_capacity(1, get_witness_size(slot_type));
             let allocations = allocate_slot(&mut witness, slot_data, slot_idx, slot_type, store)
                 .expect("slot allocations failed");
             Arc::new(SlotWitness {
@@ -896,7 +934,8 @@ where
 #[cfg(test)]
 mod tests {
     use bellpepper_core::test_cs::TestConstraintSystem;
-    use pasta_curves::Fq;
+    use nova::provider::bn256_grumpkin::{bn256::Scalar as Bn, grumpkin::Scalar as Gr};
+    use pasta_curves::{Fp, Fq};
 
     use crate::{
         eval::lang::Coproc,
@@ -908,6 +947,36 @@ mod tests {
 
     use super::*;
 
+    /// Asserts that the computed witness sizes are correct across all slot types
+    /// and fields used in Lurk
+    #[test]
+    fn test_get_witness_size() {
+        fn assert_sizes<F: LurkField>() {
+            [
+                SlotType::Hash4,
+                SlotType::Hash6,
+                SlotType::Hash8,
+                SlotType::Commitment,
+                SlotType::BitDecomp,
+            ]
+            .into_par_iter()
+            .for_each(|slot_type| {
+                let store = Store::<F>::default();
+                let mut w = WitnessCS::<F>::new();
+                let computed_size = compute_witness_size::<F>(&slot_type, &store);
+                allocate_slot(&mut w, &None, 0, slot_type, &store).unwrap();
+                assert_eq!(w.aux_assignment().len(), computed_size);
+            });
+        }
+        (0..4).into_par_iter().for_each(|i| match i {
+            0 => assert_sizes::<Fp>(),
+            1 => assert_sizes::<Fq>(),
+            2 => assert_sizes::<Bn>(),
+            3 => assert_sizes::<Gr>(),
+            _ => unreachable!(),
+        });
+    }
+
     #[test]
     fn test_sequential_and_parallel_witnesses_equivalences() {
         let lurk_step = eval_step();
diff --git a/src/lem/store.rs b/src/lem/store.rs
index 0d21f9baf4..ef9795ce0f 100644
--- a/src/lem/store.rs
+++ b/src/lem/store.rs
@@ -1,10 +1,12 @@
 use anyhow::{bail, Result};
 use arc_swap::ArcSwap;
+use bellpepper::util_cs::witness_cs::SizedWitness;
 use elsa::{
     sync::{FrozenMap, FrozenVec},
     sync_index_set::FrozenIndexSet,
 };
 use indexmap::IndexSet;
+use neptune::Poseidon;
 use nom::{sequence::preceded, Parser};
 use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
 use std::{cell::RefCell, rc::Rc, sync::Arc};
@@ -102,6 +104,30 @@ impl<F: LurkField> Default for Store<F> {
 }
 
 impl<F: LurkField> Store<F> {
+    /// Cost of poseidon hash with arity 3, including the input
+    #[inline]
+    pub fn hash3_cost(&self) -> usize {
+        Poseidon::new(self.poseidon_cache.constants.c3()).num_aux() + 1
+    }
+
+    /// Cost of poseidon hash with arity 4, including the input
+    #[inline]
+    pub fn hash4_cost(&self) -> usize {
+        Poseidon::new(self.poseidon_cache.constants.c4()).num_aux() + 1
+    }
+
+    /// Cost of poseidon hash with arity 6, including the input
+    #[inline]
+    pub fn hash6_cost(&self) -> usize {
+        Poseidon::new(self.poseidon_cache.constants.c6()).num_aux() + 1
+    }
+
+    /// Cost of poseidon hash with arity 8, including the input
+    #[inline]
+    pub fn hash8_cost(&self) -> usize {
+        Poseidon::new(self.poseidon_cache.constants.c8()).num_aux() + 1
+    }
+
     /// Creates a `Ptr` that's a parent of two children
     pub fn intern_2_ptrs(&self, tag: Tag, a: Ptr, b: Ptr) -> Ptr {
        let (idx, inserted) = self.tuple2.insert_probe(Box::new((a, b)));
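
For reference, the `num_aux() + 1` quantity used by the new `hashN_cost` helpers can also be reproduced with neptune directly, outside the `Store`. The sketch below is illustrative only: the helper name is hypothetical, and it assumes neptune's `Poseidon`/`PoseidonConstants`, bellpepper's `SizedWitness` trait, and `generic_array` typenum arities are available, as the patch's own imports suggest. It computes the Hash4 slot witness size the same way `compute_witness_size` does for `SlotType::Hash4`.

use bellpepper::util_cs::witness_cs::SizedWitness;
use generic_array::typenum::U4;
use neptune::poseidon::{Poseidon, PoseidonConstants};
use pasta_curves::Fq;

// Hypothetical helper mirroring `Store::hash4_cost` plus the preimage slots,
// i.e. what `compute_witness_size` returns for `SlotType::Hash4`.
fn hash4_slot_witness_size() -> usize {
    let constants = PoseidonConstants::<Fq, U4>::new();
    // `num_aux()` counts the auxiliary witness variables of the hash gadget;
    // the `+ 1` follows the definition of `hash4_cost` in the patch, and the
    // `+ 4` accounts for the four preimage elements allocated in the slot.
    Poseidon::new(&constants).num_aux() + 1 + 4
}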