From 6353bc8151991484c670cb7db812dfe940d2c6ed Mon Sep 17 00:00:00 2001 From: xkx Date: Mon, 27 Nov 2023 17:11:25 +0800 Subject: [PATCH] parallelize mpt updates (#99) * parallelize mpt updates * fix * remove unnecessary log * add ci task for testing parallel assignment * fix all_paddings test failure * collect runtime statistics * defer invert to be batched * fix * assign_par for key_bit gadget * assign_par for canonical_repr gadget * comment * fix lock file merge errors --------- Co-authored-by: z2trillion Co-authored-by: Mason Liang --- .github/workflows/ci.yml | 14 +- Cargo.lock | 29 ++ Cargo.toml | 6 +- Makefile | 3 + src/constraint_builder/column.rs | 12 + src/gadgets/canonical_representation.rs | 93 +++++ src/gadgets/is_zero.rs | 6 +- src/gadgets/key_bit.rs | 53 ++- src/gadgets/mpt_update.rs | 499 +++++++++++++----------- src/mpt.rs | 151 +++++-- src/tests.rs | 1 + 11 files changed, 606 insertions(+), 261 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dc3ee518..7a414e9e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,19 @@ jobs: toolchain: nightly-2022-12-10 override: true - run: make test - bench: + + par-test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: nightly-2022-12-10 + override: true + - run: make test_par + + bench: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 diff --git a/Cargo.lock b/Cargo.lock index 941fa014..9521e45e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -878,6 +878,19 @@ dependencies = [ "zeroize", ] +[[package]] +name = "env_logger" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7" +dependencies = [ + "atty", + "humantime", + "log", + "regex", + "termcolor", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -1248,6 +1261,7 @@ version = "0.1.0" dependencies = [ "bencher", "criterion", + "env_logger", "ethers-core", "halo2_proofs", "hex", @@ -1415,6 +1429,12 @@ dependencies = [ "hmac 0.8.1", ] +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "iana-time-zone" version = "0.1.58" @@ -2620,6 +2640,15 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "termcolor" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff1bc3d3f05aff0403e8ac0d92ced918ec05b666a43f83297ccef5bea8a3d449" +dependencies = [ + "winapi-util", +] + [[package]] name = "textwrap" version = "0.16.0" diff --git a/Cargo.toml b/Cargo.toml index e45d45bb..6e64903c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ num-bigint = "0.4" hex = "0.4" thiserror = "1.0" log = "0.4" +env_logger = "0.9.0" mpt-zktrie = { git = "https://github.com/scroll-tech/zkevm-circuits.git", rev = "7d9bc181953cfc6e7baf82ff0ce651281fd70a8a" } rand_chacha = "0.3.0" criterion = { version = "0.4", optional = true} @@ -33,7 +34,8 @@ ethers-core = { git = "https://github.com/scroll-tech/ethers-rs.git", branch = " [features] # printout the layout of circuits for demo and some unittests print_layout = ["halo2_proofs/dev-graph"] -bench = [ "dep:criterion" ] +default = ["halo2_proofs/mock-batch-inv", "halo2_proofs/parallel_syn"] +bench = ["dep:criterion"] [dev-dependencies] # mpt-zktrie = { path = "../scroll-circuits/zktrie" } @@ -52,4 +54,4 @@ debug-assertions = true [[bench]] name = "parallel_assignment" harness = false -required-features = [ "bench" ] +required-features = ["bench"] diff --git a/Makefile b/Makefile index 1adfc3eb..e02e8cbc 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,9 @@ test: @cargo test +test_par: + PARALLEL_SYN=true cargo test -- --nocapture + fmt: @cargo fmt diff --git a/src/constraint_builder/column.rs b/src/constraint_builder/column.rs index 45351e16..75d472f3 100644 --- a/src/constraint_builder/column.rs +++ b/src/constraint_builder/column.rs @@ -1,4 +1,5 @@ use super::{BinaryQuery, Query}; +use halo2_proofs::plonk::Assigned; use halo2_proofs::{ arithmetic::FieldExt, circuit::{Region, Value}, @@ -101,6 +102,17 @@ impl AdviceColumn { ) .expect("failed assign_advice"); } + + pub fn assign_rational( + &self, + region: &mut Region<'_, F>, + offset: usize, + value: Assigned, + ) { + region + .assign_advice(|| "advice", self.0, offset, || Value::known(value)) + .expect("failed assign_advice"); + } } #[derive(Clone, Copy)] diff --git a/src/gadgets/canonical_representation.rs b/src/gadgets/canonical_representation.rs index e98f0132..e96b6c48 100644 --- a/src/gadgets/canonical_representation.rs +++ b/src/gadgets/canonical_representation.rs @@ -4,6 +4,8 @@ use super::super::constraint_builder::{ }; use super::{byte_bit::RangeCheck256Lookup, is_zero::IsZeroGadget, rlc_randomness::RlcRandomness}; use ethers_core::types::U256; +use halo2_proofs::circuit::Layouter; +use halo2_proofs::plonk::Error; use halo2_proofs::{ arithmetic::{Field, FieldExt}, circuit::{Region, Value}, @@ -212,6 +214,97 @@ impl CanonicalRepresentationConfig { } } + pub fn assign_par( + &self, + layouter: &mut impl Layouter, + randomness: Value, + values: &[Fr], + n_rows: usize, + ) { + let modulus = U256::from_str_radix(Fr::MODULUS, 16).unwrap(); + let mut modulus_bytes = [0u8; 32]; + modulus.to_big_endian(&mut modulus_bytes); + + let num_threads = std::thread::available_parallelism().unwrap().get(); + let num_values = n_rows / 32; + let zero = Fr::zero(); + log::debug!("num_real_values: {}", values.len()); + let values = values + .iter() + .chain(std::iter::repeat(&zero)) + .take(num_values) + .collect_vec(); + let chunk_size = (num_values + num_threads - 1) / num_threads; + let mut is_first_passes = vec![true; num_threads]; + let assignments = values + .chunks(chunk_size) + .zip(is_first_passes.iter_mut()) + .enumerate() + .map(|(i, (values, is_first_pass))| { + move |mut region: Region<'_, Fr>| -> Result<(), Error> { + let region = &mut region; + if *is_first_pass { + *is_first_pass = false; + let last_off = if i == 0 { + values.len() * 32 + } else { + values.len() * 32 - 1 + }; + self.value.assign(region, last_off, Fr::zero()); + return Ok(()); + } + let mut offset = if i == 0 { 1 } else { 0 }; + for value in values.iter() { + let mut bytes = value.to_bytes(); + bytes.reverse(); + let mut differences_are_zero_so_far = true; + let mut rlc = Value::known(Fr::zero()); + for (index, (byte, modulus_byte)) in + bytes.iter().zip_eq(&modulus_bytes).enumerate() + { + self.byte.assign(region, offset, u64::from(*byte)); + self.modulus_byte + .assign(region, offset, u64::from(*modulus_byte)); + + self.index + .assign(region, offset, u64::try_from(index).unwrap()); + if index.is_zero() { + self.index_is_zero.enable(region, offset); + } else if index == 31 { + self.index_is_31.enable(region, offset); + } + + let difference = + Fr::from(u64::from(*modulus_byte)) - Fr::from(u64::from(*byte)); + self.difference.assign(region, offset, difference); + self.difference_is_zero.assign(region, offset, difference); + + self.differences_are_zero_so_far.assign( + region, + offset, + differences_are_zero_so_far, + ); + differences_are_zero_so_far &= difference.is_zero_vartime(); + + self.value.assign(region, offset, **value); + + rlc = rlc * randomness + Value::known(Fr::from(u64::from(*byte))); + self.rlc.assign(region, offset, rlc); + + offset += 1 + } + } + + Ok(()) + } + }) + .collect_vec(); + + layouter + .assign_regions(|| "canonical_repr", assignments) + .unwrap(); + } + pub fn n_rows_required(values: &[Fr]) -> usize { // +1 because assigment starts on offset = 1 instead of offset = 0. values.len() * 32 + 1 diff --git a/src/gadgets/is_zero.rs b/src/gadgets/is_zero.rs index 9fe7163d..9ae835ee 100644 --- a/src/gadgets/is_zero.rs +++ b/src/gadgets/is_zero.rs @@ -1,4 +1,5 @@ use crate::constraint_builder::{AdviceColumn, BinaryQuery, ConstraintBuilder, Query}; +use halo2_proofs::plonk::Assigned; use halo2_proofs::{arithmetic::FieldExt, circuit::Region, plonk::ConstraintSystem}; use std::fmt::Debug; @@ -25,10 +26,11 @@ impl IsZeroGadget { ) where >::Error: Debug, { - self.inverse_or_zero.assign( + self.inverse_or_zero.assign_rational( region, offset, - value.try_into().unwrap().invert().unwrap_or(F::zero()), + // invert is deferred and then batched by the real/mock prover + Assigned::::from(value.try_into().unwrap()).invert(), ); } diff --git a/src/gadgets/key_bit.rs b/src/gadgets/key_bit.rs index abeece53..aef0fa75 100644 --- a/src/gadgets/key_bit.rs +++ b/src/gadgets/key_bit.rs @@ -3,9 +3,11 @@ use super::{ canonical_representation::CanonicalRepresentationLookup, }; use crate::constraint_builder::{AdviceColumn, ConstraintBuilder, Query}; +use halo2_proofs::circuit::Layouter; use halo2_proofs::{ arithmetic::FieldExt, circuit::Region, halo2curves::bn256::Fr, plonk::ConstraintSystem, }; +use itertools::Itertools; pub trait KeyBitLookup { fn lookup(&self) -> [Query; 3]; @@ -83,10 +85,22 @@ impl KeyBitConfig { } pub fn assign(&self, region: &mut Region<'_, Fr>, lookups: &[(Fr, usize, bool)]) { + self.assign_internal(region, lookups, false) + } + pub fn assign_internal( + &self, + region: &mut Region<'_, Fr>, + lookups: &[(Fr, usize, bool)], + use_par: bool, + ) { // TODO; dedup lookups for (offset, (value, index, bit)) in lookups.iter().enumerate() { // TODO: either move the disabled row to the end of the assigment or get rid of it entirely. - let offset = offset + 1; // Start assigning at offet = 1 because the first row is disabled. + let offset = if !use_par { + offset + 1 // Start assigning at offet = 1 because the first row is disabled. + } else { + offset + }; let bytes = value.to_bytes(); let index_div_8 = index / 8; // index = (31 - index/8) * 8 @@ -107,6 +121,43 @@ impl KeyBitConfig { } } + pub fn assign_par(&self, layouter: &mut impl Layouter, lookups: &[(Fr, usize, bool)]) { + let num_threads = std::thread::available_parallelism() + .expect("get num threads") + .get(); + let chunk_size = (lookups.len() + num_threads - 1) / num_threads; + let mut is_first_pass = vec![true; num_threads]; + let assignments = lookups + .chunks(chunk_size) + .zip(is_first_pass.iter_mut()) + .enumerate() + .map(|(i, (lookups, is_first_pass))| { + move |mut region: Region<'_, Fr>| { + if *is_first_pass { + *is_first_pass = false; + + if !lookups.is_empty() { + // only meant to get region's shape. + let last_off = if i == 0 { + // 1st row is disabled. + lookups.len() + } else { + lookups.len() - 1 + }; + self.byte.assign(&mut region, last_off, 0_u64); + } + return Ok(()); + } + self.assign_internal(&mut region, lookups, true); + + Ok(()) + } + }) + .collect_vec(); + + layouter.assign_regions(|| "key_bit", assignments).unwrap(); + } + pub fn n_rows_required(lookups: &[(Fr, usize, bool)]) -> usize { // +1 because assigment starts on offset = 1 instead of offset = 0. 1 + lookups.len() diff --git a/src/gadgets/mpt_update.rs b/src/gadgets/mpt_update.rs index 3689104d..75da4ec5 100644 --- a/src/gadgets/mpt_update.rs +++ b/src/gadgets/mpt_update.rs @@ -28,13 +28,14 @@ use crate::{ MPTProofType, }; use ethers_core::types::Address; +use halo2_proofs::circuit::Layouter; use halo2_proofs::{ arithmetic::{Field, FieldExt}, circuit::{Region, Value}, halo2curves::bn256::Fr, plonk::ConstraintSystem, }; -use itertools::izip; +use itertools::{izip, Itertools}; use lazy_static::lazy_static; use strum::IntoEnumIterator; @@ -359,250 +360,306 @@ impl MptUpdateConfig { proofs: &[Proof], randomness: Value, ) -> usize { - let mut n_rows = 0; + let n_rows = proofs.iter().map(|proof| proof.n_rows()).sum(); let mut offset = 1; // selector on first row is disabled. for proof in proofs { - let proof_type = MPTProofType::from(proof.claim); - let storage_key = - randomness.map(|r| rlc(&u256_to_big_endian(&proof.claim.storage_key()), r)); - let old_value = randomness.map(|r| proof.claim.old_value_assignment(r)); - let new_value = randomness.map(|r| proof.claim.new_value_assignment(r)); - - for i in 0..proof.n_rows() { - self.proof_type.assign(region, offset + i, proof_type); - self.storage_key_rlc.assign(region, offset + i, storage_key); - self.old_value.assign(region, offset + i, old_value); - self.new_value.assign(region, offset + i, new_value); - } + self.assign_single_proof(region, proof, randomness, offset); + offset += proof.n_rows(); + log::debug!("offset: {}", offset); + } - let key = account_key(proof.claim.address); - let (other_key, other_leaf_data_hash) = - // checking if type 1 or type 2 - if proof.old.key != key { - assert!(proof.new.key == key || proof.new.key == proof.old.key); - (proof.old.key, proof.old.leaf_data_hash.unwrap()) - } else if proof.new.key != key { - assert!(proof.old.key == key); - (proof.new.key, proof.new.leaf_data_hash.unwrap()) - } else { - // neither is a type 1 path - // handle type 0 and type 2 paths here: - (proof.old.key, proof.new.leaf_data_hash.unwrap_or_default()) - }; - // Assign start row - self.segment_type.assign(region, offset, SegmentType::Start); - self.path_type.assign(region, offset, PathType::Start); - self.old_hash.assign(region, offset, proof.claim.old_root); - self.new_hash.assign(region, offset, proof.claim.new_root); + let expected_offset = Self::n_rows_required(proofs); + assert!( + offset == expected_offset, + "assign used {offset} rows but {expected_offset} rows expected from `n_rows_required`", + ); - self.key.assign(region, offset, key); - self.other_key.assign(region, offset, other_key); - self.domain.assign(region, offset, HashDomain::Pair); + n_rows + } - self.intermediate_values[0].assign( - region, - offset, - Fr::from_u128(address_high(proof.claim.address)), - ); - self.intermediate_values[1].assign( - region, - offset, - u64::from(address_low(proof.claim.address)), - ); + pub fn assign_single_proof( + &self, + region: &mut Region<'_, Fr>, + proof: &Proof, + randomness: Value, + mut offset: usize, + ) { + let proof_type = MPTProofType::from(proof.claim); + let storage_key = + randomness.map(|r| rlc(&u256_to_big_endian(&proof.claim.storage_key()), r)); + let old_value = randomness.map(|r| proof.claim.old_value_assignment(r)); + let new_value = randomness.map(|r| proof.claim.new_value_assignment(r)); + + for i in 0..proof.n_rows() { + self.proof_type.assign(region, offset + i, proof_type); + self.storage_key_rlc.assign(region, offset + i, storage_key); + self.old_value.assign(region, offset + i, old_value); + self.new_value.assign(region, offset + i, new_value); + } - let rlc_fr = |x: Fr| { - let mut bytes = x.to_bytes(); - bytes.reverse(); - randomness.map(|r| rlc(&bytes, r)) + let key = account_key(proof.claim.address); + let (other_key, other_leaf_data_hash) = + // checking if type 1 or type 2 + if proof.old.key != key { + assert!(proof.new.key == key || proof.new.key == proof.old.key); + (proof.old.key, proof.old.leaf_data_hash.unwrap()) + } else if proof.new.key != key { + assert!(proof.old.key == key); + (proof.new.key, proof.new.leaf_data_hash.unwrap()) + } else { + // neither is a type 1 path + // handle type 0 and type 2 paths here: + (proof.old.key, proof.new.leaf_data_hash.unwrap_or_default()) }; + // Assign start row + self.segment_type.assign(region, offset, SegmentType::Start); + self.path_type.assign(region, offset, PathType::Start); + self.old_hash.assign(region, offset, proof.claim.old_root); + self.new_hash.assign(region, offset, proof.claim.new_root); + + self.key.assign(region, offset, key); + self.other_key.assign(region, offset, other_key); + self.domain.assign(region, offset, HashDomain::Pair); - self.second_phase_intermediate_values[0].assign( - region, - offset, - rlc_fr(proof.claim.old_root), - ); - self.second_phase_intermediate_values[1].assign( - region, - offset, - rlc_fr(proof.claim.new_root), - ); + self.intermediate_values[0].assign( + region, + offset, + Fr::from_u128(address_high(proof.claim.address)), + ); + self.intermediate_values[1].assign( + region, + offset, + u64::from(address_low(proof.claim.address)), + ); - offset += 1; + let rlc_fr = |x: Fr| { + let mut bytes = x.to_bytes(); + bytes.reverse(); + randomness.map(|r| rlc(&bytes, r)) + }; - let n_account_trie_rows = - self.assign_account_trie_rows(region, offset, &proof.account_trie_rows); - for i in 0..n_account_trie_rows { - self.key.assign(region, offset + i, key); - self.other_key.assign(region, offset + i, other_key); - } - offset += n_account_trie_rows; - - let final_path_type = proof - .address_hash_traces - .first() - .map(|(_, _, _, _, _, is_padding_open, is_padding_close)| { - match (*is_padding_open, *is_padding_close) { - (false, false) => PathType::Common, - (false, true) => PathType::ExtensionOld, - (true, false) => PathType::ExtensionNew, - (true, true) => unreachable!(), - } - }) - .unwrap_or(PathType::Common); - let (final_old_hash, final_new_hash) = match proof.address_hash_traces.first() { - None => (proof.old.hash(), proof.new.hash()), - Some((_, _, old_hash, new_hash, _, _, _)) => (*old_hash, *new_hash), - }; + self.second_phase_intermediate_values[0].assign( + region, + offset, + rlc_fr(proof.claim.old_root), + ); + self.second_phase_intermediate_values[1].assign( + region, + offset, + rlc_fr(proof.claim.new_root), + ); - if proof.old_account.is_none() && proof.new_account.is_none() { - offset -= 1; - self.is_zero_gadgets[2].assign_value_and_inverse(region, offset, key - other_key); - self.is_zero_gadgets[3].assign_value_and_inverse(region, offset, final_old_hash); + offset += 1; - self.intermediate_values[3].assign(region, offset, other_leaf_data_hash); + let n_account_trie_rows = + self.assign_account_trie_rows(region, offset, &proof.account_trie_rows); + for i in 0..n_account_trie_rows { + self.key.assign(region, offset + i, key); + self.other_key.assign(region, offset + i, other_key); + } + offset += n_account_trie_rows; + + let final_path_type = proof + .address_hash_traces + .first() + .map(|(_, _, _, _, _, is_padding_open, is_padding_close)| { + match (*is_padding_open, *is_padding_close) { + (false, false) => PathType::Common, + (false, true) => PathType::ExtensionOld, + (true, false) => PathType::ExtensionNew, + (true, true) => unreachable!(), + } + }) + .unwrap_or(PathType::Common); + let (final_old_hash, final_new_hash) = match proof.address_hash_traces.first() { + None => (proof.old.hash(), proof.new.hash()), + Some((_, _, old_hash, new_hash, _, _, _)) => (*old_hash, *new_hash), + }; - n_rows += proof.n_rows(); - offset = 1 + n_rows; - continue; // we don't need to assign any leaf rows for empty accounts - } + if proof.old_account.is_none() && proof.new_account.is_none() { + offset -= 1; + self.is_zero_gadgets[2].assign_value_and_inverse(region, offset, key - other_key); + self.is_zero_gadgets[3].assign_value_and_inverse(region, offset, final_old_hash); - let segment_types = vec![ - SegmentType::AccountLeaf0, - SegmentType::AccountLeaf1, - SegmentType::AccountLeaf2, - SegmentType::AccountLeaf3, - ]; - - let leaf_path_type = match final_path_type { - PathType::Common => { - // need to check if the old or new account is type 2 empty - match ( - final_old_hash.is_zero_vartime(), - final_new_hash.is_zero_vartime(), - ) { - (true, true) => unreachable!("proof type must be AccountDoesNotExist"), - (true, false) => PathType::ExtensionNew, - (false, true) => PathType::ExtensionOld, - (false, false) => PathType::Common, - } - } - _ => final_path_type, - }; + self.intermediate_values[3].assign(region, offset, other_leaf_data_hash); - let directions = match proof_type { - MPTProofType::NonceChanged | MPTProofType::CodeSizeExists => { - vec![true, false, false, false] - } - MPTProofType::BalanceChanged => vec![true, false, false, true], - MPTProofType::PoseidonCodeHashExists => vec![true, true], - MPTProofType::CodeHashExists => vec![true, false, true, true], - MPTProofType::StorageChanged | MPTProofType::StorageDoesNotExist => { - vec![true, false, true, false] + return; // we don't need to assign any leaf rows for empty accounts + } + + let segment_types = vec![ + SegmentType::AccountLeaf0, + SegmentType::AccountLeaf1, + SegmentType::AccountLeaf2, + SegmentType::AccountLeaf3, + ]; + + let leaf_path_type = match final_path_type { + PathType::Common => { + // need to check if the old or new account is type 2 empty + match ( + final_old_hash.is_zero_vartime(), + final_new_hash.is_zero_vartime(), + ) { + (true, true) => unreachable!("proof type must be AccountDoesNotExist"), + (true, false) => PathType::ExtensionNew, + (false, true) => PathType::ExtensionOld, + (false, false) => PathType::Common, } - MPTProofType::AccountDoesNotExist => unreachable!(), - MPTProofType::AccountDestructed => unimplemented!(), - }; - let next_offset = offset + directions.len(); - - let old_hashes = proof - .old_account_leaf_hashes() - .unwrap_or_else(|| vec![final_old_hash; 4]); - let new_hashes = proof - .new_account_leaf_hashes() - .unwrap_or_else(|| vec![final_new_hash; 4]); - let siblings = proof.account_leaf_siblings(); - - for (i, (segment_type, sibling, old_hash, new_hash, direction)) in - izip!(segment_types, siblings, old_hashes, new_hashes, directions).enumerate() - { - if i == 0 { - self.is_zero_gadgets[3].assign_value_and_inverse(region, offset, old_hash); - self.domain.assign(region, offset + i, HashDomain::Leaf); - } else { - self.domain - .assign(region, offset + i, HashDomain::AccountFields); + } + _ => final_path_type, + }; + + let directions = match proof_type { + MPTProofType::NonceChanged | MPTProofType::CodeSizeExists => { + vec![true, false, false, false] + } + MPTProofType::BalanceChanged => vec![true, false, false, true], + MPTProofType::PoseidonCodeHashExists => vec![true, true], + MPTProofType::CodeHashExists => vec![true, false, true, true], + MPTProofType::StorageChanged | MPTProofType::StorageDoesNotExist => { + vec![true, false, true, false] + } + MPTProofType::AccountDoesNotExist => unreachable!(), + MPTProofType::AccountDestructed => unimplemented!(), + }; + let next_offset = offset + directions.len(); + + let old_hashes = proof + .old_account_leaf_hashes() + .unwrap_or_else(|| vec![final_old_hash; 4]); + let new_hashes = proof + .new_account_leaf_hashes() + .unwrap_or_else(|| vec![final_new_hash; 4]); + let siblings = proof.account_leaf_siblings(); + + for (i, (segment_type, sibling, old_hash, new_hash, direction)) in + izip!(segment_types, siblings, old_hashes, new_hashes, directions).enumerate() + { + if i == 0 { + self.is_zero_gadgets[3].assign_value_and_inverse(region, offset, old_hash); + self.domain.assign(region, offset + i, HashDomain::Leaf); + } else { + self.domain + .assign(region, offset + i, HashDomain::AccountFields); + } + self.segment_type.assign(region, offset + i, segment_type); + self.path_type.assign(region, offset + i, leaf_path_type); + self.sibling.assign(region, offset + i, sibling); + self.old_hash.assign(region, offset + i, old_hash); + self.new_hash.assign(region, offset + i, new_hash); + self.direction.assign(region, offset + i, direction); + self.key.assign(region, offset + i, key); + self.other_key.assign(region, offset + i, other_key); + + match segment_type { + SegmentType::AccountLeaf0 => { + let [.., other_key_column, other_leaf_data_hash_column] = + self.intermediate_values; + other_key_column.assign(region, offset, other_key); + other_leaf_data_hash_column.assign(region, offset, other_leaf_data_hash); } - self.segment_type.assign(region, offset + i, segment_type); - self.path_type.assign(region, offset + i, leaf_path_type); - self.sibling.assign(region, offset + i, sibling); - self.old_hash.assign(region, offset + i, old_hash); - self.new_hash.assign(region, offset + i, new_hash); - self.direction.assign(region, offset + i, direction); - self.key.assign(region, offset + i, key); - self.other_key.assign(region, offset + i, other_key); - - match segment_type { - SegmentType::AccountLeaf0 => { - let [.., other_key_column, other_leaf_data_hash_column] = + SegmentType::AccountLeaf3 => { + if let ClaimKind::Storage { key, .. } | ClaimKind::IsEmpty(Some(key)) = + proof.claim.kind + { + self.key.assign(region, offset + 3, proof.storage.key()); + let [storage_key_high, storage_key_low, new_domain, ..] = self.intermediate_values; - other_key_column.assign(region, offset, other_key); - other_leaf_data_hash_column.assign(region, offset, other_leaf_data_hash); - } - SegmentType::AccountLeaf3 => { - if let ClaimKind::Storage { key, .. } | ClaimKind::IsEmpty(Some(key)) = - proof.claim.kind - { - self.key.assign(region, offset + 3, proof.storage.key()); - let [storage_key_high, storage_key_low, new_domain, ..] = - self.intermediate_values; - let [rlc_storage_key_high, rlc_storage_key_low, ..] = - self.second_phase_intermediate_values; - assign_word_rlc( - region, - offset + 3, - key, - [storage_key_high, storage_key_low], - [rlc_storage_key_high, rlc_storage_key_low], - randomness, - ); - self.other_key - .assign(region, offset + 3, proof.storage.other_key()); - new_domain.assign(region, offset + 3, HashDomain::AccountFields); - } + let [rlc_storage_key_high, rlc_storage_key_low, ..] = + self.second_phase_intermediate_values; + assign_word_rlc( + region, + offset + 3, + key, + [storage_key_high, storage_key_low], + [rlc_storage_key_high, rlc_storage_key_low], + randomness, + ); + self.other_key + .assign(region, offset + 3, proof.storage.other_key()); + new_domain.assign(region, offset + 3, HashDomain::AccountFields); } - _ => {} - }; - } - self.key.assign(region, offset, key); - self.other_key.assign(region, offset, other_key); - self.is_zero_gadgets[2].assign_value_and_inverse(region, offset, key - other_key); - if let ClaimKind::CodeHash { old, new } = proof.claim.kind { - let [old_high, old_low, new_high, new_low, ..] = self.intermediate_values; - let [old_rlc_high, old_rlc_low, new_rlc_high, new_rlc_low, ..] = - self.second_phase_intermediate_values; - if let Some(value) = old { - assign_word_rlc( - region, - offset + 3, - value, - [old_high, old_low], - [old_rlc_high, old_rlc_low], - randomness, - ); - } - if let Some(value) = new { - assign_word_rlc( - region, - offset + 3, - value, - [new_high, new_low], - [new_rlc_high, new_rlc_low], - randomness, - ); } + _ => {} }; - self.assign_storage(region, next_offset, &proof.storage, randomness); - n_rows += proof.n_rows(); - offset = 1 + n_rows; } + self.key.assign(region, offset, key); + self.other_key.assign(region, offset, other_key); + self.is_zero_gadgets[2].assign_value_and_inverse(region, offset, key - other_key); + if let ClaimKind::CodeHash { old, new } = proof.claim.kind { + let [old_high, old_low, new_high, new_low, ..] = self.intermediate_values; + let [old_rlc_high, old_rlc_low, new_rlc_high, new_rlc_low, ..] = + self.second_phase_intermediate_values; + if let Some(value) = old { + assign_word_rlc( + region, + offset + 3, + value, + [old_high, old_low], + [old_rlc_high, old_rlc_low], + randomness, + ); + } + if let Some(value) = new { + assign_word_rlc( + region, + offset + 3, + value, + [new_high, new_low], + [new_rlc_high, new_rlc_low], + randomness, + ); + } + }; + self.assign_storage(region, next_offset, &proof.storage, randomness); + } - let expected_offset = Self::n_rows_required(proofs); - debug_assert!( - offset == expected_offset, - "assign used {offset} rows but {expected_offset} rows expected from `n_rows_required`", - ); + pub(crate) fn assign_par( + &self, + layouter: &mut impl Layouter, + proofs: &[Proof], + randomness: Value, + ) -> usize { + let mut is_first_passes = vec![true; proofs.len()]; + let update_assignments = proofs + .iter() + .zip(is_first_passes.iter_mut()) + .enumerate() + .map(|(i, (proof, is_first_pass))| { + move |mut region: Region<'_, Fr>| { + let n_rows = proof.n_rows(); + let (first_off, last_off) = if i == 0 { + // The first region has (1 + proof.n_rows()) rows + (1, n_rows) + } else { + (0, n_rows - 1) + }; + if *is_first_pass { + log::debug!("n_rows for update {}: {}", i, n_rows); + *is_first_pass = false; + // just want the layouter to know this region's shape. + // we use new_hash because this col is assigned by mpt update regions + // and padding region. + self.proof_type.assign( + &mut region, + last_off, + MPTProofType::AccountDoesNotExist, + ); - n_rows + return Ok(()); + } + self.assign_single_proof(&mut region, proof, randomness, first_off); + + Ok(()) + } + }) + .collect_vec(); + + layouter + .assign_regions(|| "mpt updates", update_assignments) + .unwrap(); + + proofs.iter().map(|proof| proof.n_rows()).sum() } pub fn n_rows_required(proofs: &[Proof]) -> usize { diff --git a/src/mpt.rs b/src/mpt.rs index da923762..21fda256 100644 --- a/src/mpt.rs +++ b/src/mpt.rs @@ -22,6 +22,7 @@ use halo2_proofs::{ plonk::{Challenge, ConstraintSystem, Error, Expression, VirtualCells}, }; use itertools::Itertools; +use std::time::Instant; /// Config for MptCircuit #[derive(Clone)] @@ -120,50 +121,132 @@ impl MptCircuitConfig { let randomness = self.rlc_randomness.value(layouter); let (u32s, u64s, u128s, frs) = byte_representations(proofs); + let mpt_updates_assign_dur = Instant::now(); + let use_par = std::env::var("PARALLEL_SYN").map_or(false, |s| s == *"true"); + if use_par { + let n_assigned_rows = self.mpt_update.assign_par(layouter, proofs, randomness); + + layouter.assign_region( + || "mpt update padding rows", + |mut region| { + if n_assigned_rows == 0 { + // first row is all-zeroes row + for offset in 1..n_rows { + self.mpt_update.assign_padding_row(&mut region, offset); + } + } else { + for offset in 0..(n_rows - (1 + n_assigned_rows)) { + self.mpt_update.assign_padding_row(&mut region, offset); + } + } + Ok(()) + }, + )?; + } else { + layouter.assign_region( + || "mpt update", + |mut region| { + let n_assigned_rows = self.mpt_update.assign(&mut region, proofs, randomness); + + assert!( + 2 + n_assigned_rows <= n_rows, + "mpt circuit requires {n_assigned_rows} rows for mpt updates + 1 initial \ + all-zero row + at least 1 final padding row. Only {n_rows} rows available." + ); + + for offset in (1 + n_assigned_rows)..n_rows { + self.mpt_update.assign_padding_row(&mut region, offset); + } + + Ok(()) + }, + )?; + } + log::debug!( + "mpt updates assignment(use_par = {}) took {:?}", + use_par, + mpt_updates_assign_dur.elapsed() + ); + + if use_par { + let key_bit_time = { + let dur = Instant::now(); + self.key_bit.assign_par(layouter, &key_bit_lookups(proofs)); + dur.elapsed() + }; + log::debug!("mpt key_bit assignment took {:?}", key_bit_time); + } + + // pad canonical_representation to fixed count + // notice each input cost 32 rows in canonical_representation, and inside + // assign one extra input is added + let (keys, get_keys_time) = { + let dur = Instant::now(); + let mut keys = mpt_update_keys(proofs); + keys.sort(); + keys.dedup(); + (keys, dur.elapsed()) + }; + let total_rep_size = n_rows / 32 - 1; + assert!( + total_rep_size >= keys.len(), + "no enough space for canonical representation of all keys (need {})", + keys.len() + ); + log::debug!("get keys took {:?}", get_keys_time); + + if use_par { + let canon_repr_time = { + let dur = Instant::now(); + self.canonical_representation + .assign_par(layouter, randomness, &keys, n_rows); + dur.elapsed() + }; + log::debug!("canonical_repr assignment took {:?}", canon_repr_time); + } + layouter.assign_region( - || "mpt circuit", + || "mpt keys", |mut region| { for offset in 1..n_rows { self.selector.enable(&mut region, offset); } - // pad canonical_representation to fixed count - // notice each input cost 32 rows in canonical_representation, and inside - // assign one extra input is added - let mut keys = mpt_update_keys(proofs); - keys.sort(); - keys.dedup(); - let total_rep_size = n_rows / 32 - 1; - assert!( - total_rep_size >= keys.len(), - "no enough space for canonical representation of all keys (need {})", - keys.len() - ); + let keys_assign_dur = Instant::now(); + if !use_par { + self.canonical_representation + .assign(&mut region, randomness, &keys, n_rows); + self.key_bit.assign(&mut region, &key_bit_lookups(proofs)); + } - self.canonical_representation - .assign(&mut region, randomness, &keys, n_rows); - self.key_bit.assign(&mut region, &key_bit_lookups(proofs)); - self.byte_bit.assign(&mut region); - self.byte_representation.assign( - &mut region, - &u32s, - &u64s, - &u128s, - &frs, - randomness, + let byte_bit_time = { + let dur = Instant::now(); + self.byte_bit.assign(&mut region); + dur.elapsed() + }; + let byte_repr_time = { + let dur = Instant::now(); + self.byte_representation.assign( + &mut region, + &u32s, + &u64s, + &u128s, + &frs, + randomness, + ); + dur.elapsed() + }; + let keys_assign_time = keys_assign_dur.elapsed(); + log::debug!("keys assignment took {:?}", keys_assign_time); + log::debug!( + "byte_bit: {}", + byte_bit_time.as_micros() as f64 / keys_assign_time.as_micros() as f64 ); - - let n_assigned_rows = self.mpt_update.assign(&mut region, proofs, randomness); - - assert!( - 2 + n_assigned_rows <= n_rows, - "mpt circuit requires {n_assigned_rows} rows for mpt updates + 1 initial \ - all-zero row + at least 1 final padding row. Only {n_rows} rows available." + log::debug!( + "byte_repr: {}", + byte_repr_time.as_micros() as f64 / keys_assign_time.as_micros() as f64 ); - for offset in 1 + n_assigned_rows..n_rows { - self.mpt_update.assign_padding_row(&mut region, offset); - } self.is_final_row.enable(&mut region, n_rows - 1); Ok(()) diff --git a/src/tests.rs b/src/tests.rs index 8c56186b..38a92797 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -765,6 +765,7 @@ fn empty_storage_type_1_update_c() { #[test] fn multiple_updates() { + env_logger::init(); let witness = vec![ ( MPTProofType::StorageChanged,