From c3bbe1ca8f1df461ff474a92092d487d049592f4 Mon Sep 17 00:00:00 2001 From: krushimir Date: Wed, 15 Jan 2025 22:49:18 +0100 Subject: [PATCH 01/13] feat: adds concurrent `Smt::compute_mutations` --- src/main.rs | 139 +++++++++++++------- src/merkle/smt/full/mod.rs | 12 +- src/merkle/smt/mod.rs | 245 +++++++++++++++++++++++++++++++++-- src/merkle/smt/simple/mod.rs | 2 +- src/merkle/smt/tests.rs | 128 +++++++++++++++++- 5 files changed, 460 insertions(+), 66 deletions(-) diff --git a/src/main.rs b/src/main.rs index 5ee3834c..e5f5b3c6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,8 +13,14 @@ use rand_utils::rand_value; #[clap(name = "Benchmark", about = "SMT benchmark", version, rename_all = "kebab-case")] pub struct BenchmarkCmd { /// Size of the tree - #[clap(short = 's', long = "size")] + #[clap(short = 's', long = "size", default_value = "10000")] size: usize, + /// Number of insertions + #[clap(short = 'i', long = "insertions", default_value = "10000")] + insertions: usize, + /// Number of updates + #[clap(short = 'u', long = "updates", default_value = "10000")] + updates: usize, } fn main() { @@ -25,7 +31,10 @@ fn main() { pub fn benchmark_smt() { let args = BenchmarkCmd::parse(); let tree_size = args.size; + let insertions = args.insertions; + let updates = args.updates; + assert!(updates <= insertions + tree_size, "Cannot update more than insertions + size"); // prepare the `leaves` vector for tree creation let mut entries = Vec::new(); for i in 0..tree_size { @@ -35,35 +44,41 @@ pub fn benchmark_smt() { } let mut tree = construction(entries.clone(), tree_size).unwrap(); - insertion(&mut tree).unwrap(); - batched_insertion(&mut tree).unwrap(); - batched_update(&mut tree, entries).unwrap(); + insertion(&mut tree, insertions).unwrap(); + batched_insertion(&mut tree, insertions).unwrap(); + batched_update(&mut tree, entries, updates).unwrap(); proof_generation(&mut tree).unwrap(); } /// Runs the construction benchmark for [`Smt`], returning the constructed 
tree. pub fn construction(entries: Vec<(RpoDigest, Word)>, size: usize) -> Result { + let cloned_entries = entries.clone(); println!("Running a construction benchmark:"); let now = Instant::now(); - let tree = Smt::with_entries(entries)?; + let tree = Smt::with_entries(cloned_entries)?; let elapsed = now.elapsed().as_secs_f32(); println!("Constructed a SMT with {size} key-value pairs in {elapsed:.1} seconds"); - println!("Number of leaf nodes: {}\n", tree.leaves().count()); + let now = Instant::now(); + let tree_sequential = Smt::with_entries_sequential(entries)?; + let compute_elapsed_sequential = now.elapsed().as_secs_f32(); + + assert_eq!(tree.root(), tree_sequential.root()); + println!("Constructed a SMT sequentially with {size} key-value pairs in {elapsed:.1} seconds"); + let factor = compute_elapsed_sequential / elapsed; + println!("Parallel implementation is {factor}x times faster."); Ok(tree) } /// Runs the insertion benchmark for the [`Smt`]. -pub fn insertion(tree: &mut Smt) -> Result<(), MerkleError> { - const NUM_INSERTIONS: usize = 1_000; - +pub fn insertion(tree: &mut Smt, insertions: usize) -> Result<(), MerkleError> { println!("Running an insertion benchmark:"); let size = tree.num_leaves(); let mut insertion_times = Vec::new(); - for i in 0..NUM_INSERTIONS { + for i in 0..insertions { let test_key = Rpo256::hash(&rand_value::().to_be_bytes()); let test_value = [ONE, ONE, ONE, Felt::new((size + i) as u64)]; @@ -74,22 +89,20 @@ pub fn insertion(tree: &mut Smt) -> Result<(), MerkleError> { } println!( - "An average insertion time measured by {NUM_INSERTIONS} inserts into an SMT with {size} leaves is {:.0} μs\n", + "An average insertion time measured by {insertions} inserts into an SMT with {size} leaves is {:.0} μs\n", // calculate the average - insertion_times.iter().sum::() as f64 / (NUM_INSERTIONS as f64), + insertion_times.iter().sum::() as f64 / (insertions as f64), ); Ok(()) } -pub fn batched_insertion(tree: &mut Smt) -> Result<(), MerkleError> 
{ - const NUM_INSERTIONS: usize = 1_000; - +pub fn batched_insertion(tree: &mut Smt, insertions: usize) -> Result<(), MerkleError> { println!("Running a batched insertion benchmark:"); let size = tree.num_leaves(); - let new_pairs: Vec<(RpoDigest, Word)> = (0..NUM_INSERTIONS) + let new_pairs: Vec<(RpoDigest, Word)> = (0..insertions) .map(|i| { let key = Rpo256::hash(&rand_value::().to_be_bytes()); let value = [ONE, ONE, ONE, Felt::new((size + i) as u64)]; @@ -97,24 +110,41 @@ pub fn batched_insertion(tree: &mut Smt) -> Result<(), MerkleError> { }) .collect(); + let cloned_new_pairs = new_pairs.clone(); let now = Instant::now(); - let mutations = tree.compute_mutations(new_pairs); + let mutations = tree.compute_mutations(cloned_new_pairs); let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms let now = Instant::now(); - tree.apply_mutations(mutations)?; - let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms + let mutations_sequential = tree.compute_mutations_sequential(new_pairs); + let compute_elapsed_sequential = now.elapsed().as_secs_f64() * 1000_f64; // time in ms + + assert_eq!(mutations.root(), mutations_sequential.root()); + assert_eq!(mutations.node_mutations(), mutations_sequential.node_mutations()); + assert_eq!(mutations.new_pairs(), mutations_sequential.new_pairs()); println!( - "An average insert-batch computation time measured by a {NUM_INSERTIONS}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + "An average insert-batch computation time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", compute_elapsed, - compute_elapsed * 1000_f64 / NUM_INSERTIONS as f64, // time in μs + compute_elapsed * 1000_f64 / insertions as f64, // time in μs + ); + + println!( + "An average insert-batch sequential computation time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + compute_elapsed_sequential, + 
compute_elapsed_sequential * 1000_f64 / insertions as f64, // time in μs ); + let parallel_factor = compute_elapsed_sequential / compute_elapsed; + println!("Parallel implementation is {parallel_factor}x times faster."); + + let now = Instant::now(); + tree.apply_mutations(mutations)?; + let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms println!( - "An average insert-batch application time measured by a {NUM_INSERTIONS}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + "An average insert-batch application time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", apply_elapsed, - apply_elapsed * 1000_f64 / NUM_INSERTIONS as f64, // time in μs + apply_elapsed * 1000_f64 / insertions as f64, // time in μs ); println!( @@ -127,8 +157,11 @@ pub fn batched_insertion(tree: &mut Smt) -> Result<(), MerkleError> { Ok(()) } -pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result<(), MerkleError> { - const NUM_UPDATES: usize = 1_000; +pub fn batched_update( + tree: &mut Smt, + entries: Vec<(RpoDigest, Word)>, + updates: usize, +) -> Result<(), MerkleError> { const REMOVAL_PROBABILITY: f64 = 0.2; println!("Running a batched update benchmark:"); @@ -136,41 +169,59 @@ pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result let size = tree.num_leaves(); let mut rng = thread_rng(); - let new_pairs = - entries - .into_iter() - .choose_multiple(&mut rng, NUM_UPDATES) - .into_iter() - .map(|(key, _)| { - let value = if rng.gen_bool(REMOVAL_PROBABILITY) { - EMPTY_WORD - } else { - [ONE, ONE, ONE, Felt::new(rng.gen())] - }; + let new_pairs: Vec<(RpoDigest, Word)> = entries + .into_iter() + .choose_multiple(&mut rng, updates) + .into_iter() + .map(|(key, _)| { + let value = if rng.gen_bool(REMOVAL_PROBABILITY) { + EMPTY_WORD + } else { + [ONE, ONE, ONE, Felt::new(rng.gen())] + }; - (key, value) - }); + (key, value) + }) + .collect(); - 
assert_eq!(new_pairs.len(), NUM_UPDATES); + assert_eq!(new_pairs.len(), updates); + let cloned_new_pairs = new_pairs.clone(); let now = Instant::now(); - let mutations = tree.compute_mutations(new_pairs); + let mutations = tree.compute_mutations(cloned_new_pairs); let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms + let now = Instant::now(); + let mutations_sequential = tree.compute_mutations_sequential(new_pairs); + let compute_elapsed_sequential = now.elapsed().as_secs_f64() * 1000_f64; // time in ms + + assert_eq!(mutations.root(), mutations_sequential.root()); + assert_eq!(mutations.node_mutations(), mutations_sequential.node_mutations()); + assert_eq!(mutations.new_pairs(), mutations_sequential.new_pairs()); + let now = Instant::now(); tree.apply_mutations(mutations)?; let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms println!( - "An average update-batch computation time measured by a {NUM_UPDATES}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + "An average update-batch computation time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", compute_elapsed, - compute_elapsed * 1000_f64 / NUM_UPDATES as f64, // time in μs + compute_elapsed * 1000_f64 / updates as f64, // time in μs + ); + + println!( + "An average update-batch sequential computation time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + compute_elapsed_sequential, + compute_elapsed_sequential * 1000_f64 / updates as f64, // time in μs ); + let factor = compute_elapsed_sequential / compute_elapsed; + println!("Parallel implementaton is {factor}x times faster."); + println!( - "An average update-batch application time measured by a {NUM_UPDATES}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + "An average update-batch application time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", apply_elapsed, - 
apply_elapsed * 1000_f64 / NUM_UPDATES as f64, // time in μs + apply_elapsed * 1000_f64 / updates as f64, // time in μs ); println!( diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 5cd641e4..a73ac468 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -255,7 +255,17 @@ impl Smt { >::compute_mutations(self, kv_pairs) } - /// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to this tree. + /// Sequential implementation of [`Smt::compute_mutations()`]. + pub fn compute_mutations_sequential( + &self, + kv_pairs: impl IntoIterator, + ) -> MutationSet { + >::compute_mutations_sequential(self, kv_pairs) + } + + /// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to + /// this tree and returns the reverse mutation set. Applying the reverse mutation sets to the + /// updated tree will revert the changes. /// /// # Errors /// If `mutations` was computed on a tree with a different root than this one, returns diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index ec439571..684dbdbd 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -1,4 +1,4 @@ -use alloc::{collections::BTreeMap, vec::Vec}; +use alloc::vec::Vec; use core::{hash::Hash, mem}; use num::Integer; @@ -36,6 +36,7 @@ type UnorderedMap = alloc::collections::BTreeMap; type InnerNodes = UnorderedMap; type Leaves = UnorderedMap; type NodeMutations = UnorderedMap; +type MutatedLeavesResult = (Vec>, UnorderedMap); /// An abstract description of a sparse Merkle tree. /// @@ -178,6 +179,25 @@ pub(crate) trait SparseMerkleTree { fn compute_mutations( &self, kv_pairs: impl IntoIterator, + ) -> MutationSet + where + Self: Sized + Sync, + { + #[cfg(feature = "concurrent")] + { + self.compute_mutations_subtree(kv_pairs) + } + #[cfg(not(feature = "concurrent"))] + { + self.compute_mutations_sequential(kv_pairs) + } + } + + /// Sequential version of [`SparseMerkleTree::compute_mutations()`]. 
+ /// This is the default implementation. + fn compute_mutations_sequential( + &self, + kv_pairs: impl IntoIterator, ) -> MutationSet { use NodeMutation::*; @@ -265,6 +285,60 @@ pub(crate) trait SparseMerkleTree { } } + /// Parallel implementation of [`SparseMerkleTree::compute_mutations()`]. + /// + /// This method recursively tracks mutations across 8-depth subtrees from the bottom up, + /// ultimately reconstructing the complete mutation set for the entire tree. + /// + /// The implementation is similar to [`SparseMerkleTree::build_subtrees_from_sorted_entries()`], + /// sharing the same constraint that the depth must be a multiple of 8. + #[cfg(feature = "concurrent")] + fn compute_mutations_subtree( + &self, + kv_pairs: impl IntoIterator, + ) -> MutationSet + where + Self: Sized + Sync, + { + use rayon::prelude::*; + + // Collect and sort key-value pairs by their corresponding leaf index + let mut sorted_kv_pairs: Vec<_> = kv_pairs.into_iter().collect(); + sorted_kv_pairs.sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value()); + + // Convert sorted pairs into mutated leaves and capture any new pairs + let (mut subtree_leaves, new_pairs) = self.sorted_pairs_to_mutated_leaves(sorted_kv_pairs); + let mut node_mutations = NodeMutations::default(); + + // Process each depth level in reverse, stepping by the subtree depth + for depth in (SUBTREE_DEPTH..=DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + // Parallel processing of each subtree to generate mutations and roots + let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves + .into_par_iter() + .map(|subtree| { + debug_assert!(subtree.is_sorted() && !subtree.is_empty()); + self.build_subtree_mutations(subtree, DEPTH, depth) + }) + .unzip(); + + // Prepare leaves for the next depth level + subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); + + // Aggregate all node mutations + node_mutations.extend(mutations_per_subtree.into_iter().flatten()); + 
+ debug_assert!(!subtree_leaves.is_empty()); + } + + // Finalize the mutation set with updated roots and mutations + MutationSet { + old_root: self.root(), + new_root: subtree_leaves[0][0].hash, + node_mutations, + new_pairs, + } + } + /// Applies the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to /// this tree. /// @@ -446,6 +520,149 @@ pub(crate) trait SparseMerkleTree { value: &Self::Value, ) -> Self::Leaf; + /// Computes leaves from a set of key-value pairs and current leaf values. + /// Deried from `sorted_pairs_to_leaves` + /// + /// TODO: refactor and merge functionality with `sorted_pairs_to_leaves`? + fn sorted_pairs_to_mutated_leaves( + &self, + pairs: Vec<(Self::Key, Self::Value)>, + ) -> MutatedLeavesResult { + debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value())); + + let mut accumulated_leaves = Vec::with_capacity(pairs.len() / 2); + let mut new_pairs = UnorderedMap::new(); + let mut current_leaf_buffer = Vec::new(); + + let mut iter = pairs.into_iter().peekable(); + while let Some((key, value)) = iter.next() { + let col = Self::key_to_leaf_index(&key).index.value(); + + if let Some((next_key, _)) = iter.peek() { + let next_col = Self::key_to_leaf_index(next_key).index.value(); + debug_assert!(next_col >= col); + } + + current_leaf_buffer.push((key.clone(), value)); + + // If the next pair is the same column, continue accumulating + if iter + .peek() + .is_some_and(|(next_key, _)| Self::key_to_leaf_index(next_key).index.value() == col) + { + continue; + } + + // Process buffered pairs + let leaf_pairs = mem::take(&mut current_leaf_buffer); + let mut leaf = self.get_leaf(&key); + + for (key, value) in leaf_pairs { + match new_pairs.get(&key) { + Some(existing_value) if existing_value == &value => continue, + _ => { + leaf = self.construct_prospective_leaf(leaf, &key, &value); + new_pairs.insert(key, value); + }, + } + } + + let hash = Self::hash_leaf(&leaf); + 
accumulated_leaves.push(SubtreeLeaf { col, hash }); + + debug_assert!(current_leaf_buffer.is_empty()); + } + + let leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect(); + (leaves, new_pairs) + } + + // Computes the node mutations and the root of a subtree + fn build_subtree_mutations( + &self, + mut leaves: Vec, + tree_depth: u8, + bottom_depth: u8, + ) -> (NodeMutations, SubtreeLeaf) + where + Self: Sized, + { + debug_assert!(bottom_depth <= tree_depth); + debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH)); + debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32)); + + let subtree_root_depth = bottom_depth - SUBTREE_DEPTH; + let mut node_mutations: NodeMutations = Default::default(); + let mut next_leaves: Vec = Vec::with_capacity(leaves.len() / 2); + + for current_depth in (subtree_root_depth..bottom_depth).rev() { + debug_assert!(current_depth <= bottom_depth); + + let next_depth = current_depth + 1; + let mut iter = leaves.drain(..).peekable(); + + while let Some(first_leaf) = iter.next() { + let parent_index = NodeIndex::new_unchecked(next_depth, first_leaf.col).parent(); + let parent_node = self.get_inner_node(parent_index); + let (left, right) = Self::fetch_sibling_pair(&mut iter, first_leaf, parent_node); + + let combined_node = InnerNode { left: left.hash, right: right.hash }; + let combined_hash = combined_node.hash(); + + let &empty_hash = EmptySubtreeRoots::entry(tree_depth, current_depth); + + // Add the parent node even if it is empty for proper upward updates + next_leaves.push(SubtreeLeaf { + col: parent_index.value(), + hash: combined_hash, + }); + + node_mutations.insert( + parent_index, + if combined_hash != empty_hash { + NodeMutation::Addition(combined_node) + } else { + NodeMutation::Removal + }, + ); + } + drop(iter); + leaves = mem::take(&mut next_leaves); + } + + debug_assert_eq!(leaves.len(), 1); + let root_leaf = leaves.pop().unwrap(); + (node_mutations, root_leaf) + } + + // Returns the 
sibling pair based on the first leaf and the current depth + // + // This is a helper function that is used to build the subtree mutations + // The first leaf is the leaf that we are currently processing + // The current depth is the depth of the current subtree + fn fetch_sibling_pair( + iter: &mut core::iter::Peekable>, + first_leaf: SubtreeLeaf, + parent_node: InnerNode, + ) -> (SubtreeLeaf, SubtreeLeaf) { + let is_right_node = first_leaf.col.is_odd(); + + if is_right_node { + let left_leaf = SubtreeLeaf { + col: first_leaf.col - 1, + hash: parent_node.left, + }; + (left_leaf, first_leaf) + } else { + let right_col = first_leaf.col + 1; + let right_leaf = match iter.peek().copied() { + Some(SubtreeLeaf { col, .. }) if col == right_col => iter.next().unwrap(), + _ => SubtreeLeaf { col: right_col, hash: parent_node.right }, + }; + (first_leaf, right_leaf) + } + } + /// Maps a key to a leaf index fn key_to_leaf_index(key: &Self::Key) -> LeafIndex; @@ -551,16 +768,16 @@ pub(crate) trait SparseMerkleTree { } = Self::sorted_pairs_to_leaves(entries); for current_depth in (SUBTREE_DEPTH..=DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { - let (nodes, mut subtree_roots): (Vec>, Vec) = leaf_subtrees - .into_par_iter() - .map(|subtree| { - debug_assert!(subtree.is_sorted()); - debug_assert!(!subtree.is_empty()); - - let (nodes, subtree_root) = build_subtree(subtree, DEPTH, current_depth); - (nodes, subtree_root) - }) - .unzip(); + let (nodes, mut subtree_roots): (Vec>, Vec) = + leaf_subtrees + .into_par_iter() + .map(|subtree| { + debug_assert!(subtree.is_sorted()); + debug_assert!(!subtree.is_empty()); + let (nodes, subtree_root) = build_subtree(subtree, DEPTH, current_depth); + (nodes, subtree_root) + }) + .unzip(); leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); accumulated_nodes.extend(nodes.into_iter().flatten()); @@ -919,12 +1136,12 @@ fn build_subtree( mut leaves: Vec, tree_depth: u8, bottom_depth: u8, -) -> (BTreeMap, SubtreeLeaf) { +) 
-> (UnorderedMap, SubtreeLeaf) { debug_assert!(bottom_depth <= tree_depth); debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH)); debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32)); let subtree_root = bottom_depth - SUBTREE_DEPTH; - let mut inner_nodes: BTreeMap = Default::default(); + let mut inner_nodes: UnorderedMap = Default::default(); let mut next_leaves: Vec = Vec::with_capacity(leaves.len() / 2); for next_depth in (subtree_root..bottom_depth).rev() { debug_assert!(next_depth <= bottom_depth); @@ -995,7 +1212,7 @@ pub fn build_subtree_for_bench( leaves: Vec, tree_depth: u8, bottom_depth: u8, -) -> (BTreeMap, SubtreeLeaf) { +) -> (UnorderedMap, SubtreeLeaf) { build_subtree(leaves, tree_depth, bottom_depth) } diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index 166cc982..67215adf 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -233,7 +233,7 @@ impl SimpleSmt { &self, kv_pairs: impl IntoIterator, Word)>, ) -> MutationSet, Word> { - >::compute_mutations(self, kv_pairs) + >::compute_mutations_sequential(self, kv_pairs) } /// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this diff --git a/src/merkle/smt/tests.rs b/src/merkle/smt/tests.rs index 23794c67..711a317c 100644 --- a/src/merkle/smt/tests.rs +++ b/src/merkle/smt/tests.rs @@ -1,13 +1,15 @@ use alloc::{collections::BTreeMap, vec::Vec}; +use rand::{prelude::IteratorRandom, thread_rng, Rng}; + use super::{ - build_subtree, InnerNode, LeafIndex, NodeIndex, PairComputations, SmtLeaf, SparseMerkleTree, - SubtreeLeaf, SubtreeLeavesIter, COLS_PER_SUBTREE, SUBTREE_DEPTH, + build_subtree, InnerNode, LeafIndex, NodeIndex, NodeMutations, PairComputations, SmtLeaf, + SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter, COLS_PER_SUBTREE, SUBTREE_DEPTH, }; use crate::{ hash::rpo::RpoDigest, - merkle::{Smt, SMT_DEPTH}, - Felt, Word, ONE, + merkle::{smt::UnorderedMap, Smt, SMT_DEPTH}, + Felt, Word, 
EMPTY_WORD, ONE, }; fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf { @@ -109,6 +111,28 @@ fn generate_entries(pair_count: u64) -> Vec<(RpoDigest, Word)> { .collect() } +fn generate_updates(entries: Vec<(RpoDigest, Word)>, updates: usize) -> Vec<(RpoDigest, Word)> { + const REMOVAL_PROBABILITY: f64 = 0.2; + let mut rng = thread_rng(); + + let mut sorted_entries: Vec<(RpoDigest, Word)> = entries + .into_iter() + .choose_multiple(&mut rng, updates) + .into_iter() + .map(|(key, _)| { + let value = if rng.gen_bool(REMOVAL_PROBABILITY) { + EMPTY_WORD + } else { + [ONE, ONE, ONE, Felt::new(rng.gen())] + }; + + (key, value) + }) + .collect(); + sorted_entries.sort_by_key(|(key, _)| Smt::key_to_leaf_index(key).value()); + sorted_entries +} + #[test] fn test_single_subtree() { // A single subtree's worth of leaves. @@ -222,7 +246,7 @@ fn test_singlethreaded_subtrees() { for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { // There's no flat_map_unzip(), so this is the best we can do. 
- let (nodes, mut subtree_roots): (Vec>, Vec) = leaf_subtrees + let (nodes, mut subtree_roots): (Vec>, Vec) = leaf_subtrees .into_iter() .enumerate() .map(|(i, subtree)| { @@ -324,7 +348,7 @@ fn test_multithreaded_subtrees() { } = Smt::sorted_pairs_to_leaves(entries); for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { - let (nodes, mut subtree_roots): (Vec>, Vec) = leaf_subtrees + let (nodes, mut subtree_roots): (Vec>, Vec) = leaf_subtrees .into_par_iter() .enumerate() .map(|(i, subtree)| { @@ -415,3 +439,95 @@ fn test_with_entries_parallel() { assert_eq!(smt.root(), control.root()); assert_eq!(smt, control); } + +#[test] +fn test_singlethreaded_subtree_mutations() { + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + + let entries = generate_entries(PAIR_COUNT); + let updates = generate_updates(entries.clone(), 1000); + + let tree = Smt::with_entries_sequential(entries.clone()).unwrap(); + let control = tree.compute_mutations_sequential(updates.clone()); + + let mut node_mutations = NodeMutations::default(); + + let (mut subtree_leaves, new_pairs) = tree.sorted_pairs_to_mutated_leaves(updates); + + for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + // There's no flat_map_unzip(), so this is the best we can do. + let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves + .into_iter() + .enumerate() + .map(|(i, subtree)| { + // Pre-assertions. + assert!( + subtree.is_sorted(), + "subtree {i} at bottom-depth {current_depth} is not sorted", + ); + assert!( + !subtree.is_empty(), + "subtree {i} at bottom-depth {current_depth} is empty!", + ); + + // Calculate the mutations for this subtree. + let (mutations_per_subtree, subtree_root) = + tree.build_subtree_mutations(subtree, SMT_DEPTH, current_depth); + + // Check that the mutations match the control tree. 
+ for (&index, mutation) in mutations_per_subtree.iter() { + let control_mutation = control.node_mutations().get(&index).unwrap(); + assert_eq!( + control_mutation, mutation, + "depth {} subtree {}: mutation does not match control at index {:?}", + current_depth, i, index, + ); + } + + (mutations_per_subtree, subtree_root) + }) + .unzip(); + + subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); + node_mutations.extend(mutations_per_subtree.into_iter().flatten()); + + assert!(!subtree_leaves.is_empty(), "on depth {current_depth}"); + } + + let [subtree]: [Vec<_>; 1] = subtree_leaves.try_into().unwrap(); + let [root_leaf]: [SubtreeLeaf; 1] = subtree.try_into().unwrap(); + // Check that the new root matches the control. + assert_eq!(control.new_root, root_leaf.hash); + + // Check that the node mutations match the control. + assert_eq!(control.node_mutations().len(), node_mutations.len()); + for (&index, mutation) in control.node_mutations().iter() { + let test_mutation = node_mutations.get(&index).unwrap(); + assert_eq!(test_mutation, mutation); + } + // Check that the new pairs match the control + assert_eq!(control.new_pairs.len(), new_pairs.len()); + for (&key, &value) in control.new_pairs.iter() { + let test_value = new_pairs.get(&key).unwrap(); + assert_eq!(test_value, &value); + } +} + +#[test] +#[cfg(feature = "concurrent")] +fn test_compute_mutations_parallel() { + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + + let entries = generate_entries(PAIR_COUNT); + let tree = Smt::with_entries(entries.clone()).unwrap(); + + let updates = generate_updates(entries, 1000); + + let control = tree.compute_mutations_sequential(updates.clone()); + let mutations = tree.compute_mutations(updates); + + assert_eq!(mutations.root(), control.root()); + assert_eq!(mutations.old_root(), control.old_root()); + assert_eq!(mutations.node_mutations(), control.node_mutations()); + assert_eq!(mutations.new_pairs(), control.new_pairs()); +} From 
c447c6f7429c56b0ba7d524b98268c6b8d464264 Mon Sep 17 00:00:00 2001 From: krushimir Date: Fri, 17 Jan 2025 20:14:19 +0100 Subject: [PATCH 02/13] chore: cleanup bench --- src/main.rs | 68 +++++++++-------------------------------------------- 1 file changed, 11 insertions(+), 57 deletions(-) diff --git a/src/main.rs b/src/main.rs index e5f5b3c6..07bdf15b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -34,7 +34,7 @@ pub fn benchmark_smt() { let insertions = args.insertions; let updates = args.updates; - assert!(updates <= insertions + tree_size, "Cannot update more than insertions + size"); + assert!(updates <= tree_size, "Cannot update more than `size`"); // prepare the `leaves` vector for tree creation let mut entries = Vec::new(); for i in 0..tree_size { @@ -44,30 +44,19 @@ pub fn benchmark_smt() { } let mut tree = construction(entries.clone(), tree_size).unwrap(); - insertion(&mut tree, insertions).unwrap(); - batched_insertion(&mut tree, insertions).unwrap(); - batched_update(&mut tree, entries, updates).unwrap(); + insertion(&mut tree.clone(), insertions).unwrap(); + batched_insertion(&mut tree.clone(), insertions).unwrap(); + batched_update(&mut tree.clone(), entries, updates).unwrap(); proof_generation(&mut tree).unwrap(); } /// Runs the construction benchmark for [`Smt`], returning the constructed tree. 
pub fn construction(entries: Vec<(RpoDigest, Word)>, size: usize) -> Result { - let cloned_entries = entries.clone(); println!("Running a construction benchmark:"); let now = Instant::now(); - let tree = Smt::with_entries(cloned_entries)?; + let tree = Smt::with_entries(entries)?; let elapsed = now.elapsed().as_secs_f32(); - println!("Constructed a SMT with {size} key-value pairs in {elapsed:.1} seconds"); - - let now = Instant::now(); - let tree_sequential = Smt::with_entries_sequential(entries)?; - let compute_elapsed_sequential = now.elapsed().as_secs_f32(); - - assert_eq!(tree.root(), tree_sequential.root()); - println!("Constructed a SMT sequentially with {size} key-value pairs in {elapsed:.1} seconds"); - let factor = compute_elapsed_sequential / elapsed; - println!("Parallel implementation is {factor}x times faster."); Ok(tree) } @@ -110,39 +99,22 @@ pub fn batched_insertion(tree: &mut Smt, insertions: usize) -> Result<(), Merkle }) .collect(); - let cloned_new_pairs = new_pairs.clone(); let now = Instant::now(); - let mutations = tree.compute_mutations(cloned_new_pairs); + let mutations = tree.compute_mutations(new_pairs); let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms - let now = Instant::now(); - let mutations_sequential = tree.compute_mutations_sequential(new_pairs); - let compute_elapsed_sequential = now.elapsed().as_secs_f64() * 1000_f64; // time in ms - - assert_eq!(mutations.root(), mutations_sequential.root()); - assert_eq!(mutations.node_mutations(), mutations_sequential.node_mutations()); - assert_eq!(mutations.new_pairs(), mutations_sequential.new_pairs()); - println!( - "An average insert-batch computation time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + "An average insert-batch computation time measured by a 1k-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", compute_elapsed, compute_elapsed * 1000_f64 / insertions as f64, // time in μs ); - println!( 
- "An average insert-batch sequential computation time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", - compute_elapsed_sequential, - compute_elapsed_sequential * 1000_f64 / insertions as f64, // time in μs - ); - let parallel_factor = compute_elapsed_sequential / compute_elapsed; - println!("Parallel implementation is {parallel_factor}x times faster."); - let now = Instant::now(); tree.apply_mutations(mutations)?; let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms println!( - "An average insert-batch application time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + "An average insert-batch application time measured by a 1k-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", apply_elapsed, apply_elapsed * 1000_f64 / insertions as f64, // time in μs ); @@ -186,40 +158,22 @@ pub fn batched_update( assert_eq!(new_pairs.len(), updates); - let cloned_new_pairs = new_pairs.clone(); let now = Instant::now(); - let mutations = tree.compute_mutations(cloned_new_pairs); + let mutations = tree.compute_mutations(new_pairs); let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms - let now = Instant::now(); - let mutations_sequential = tree.compute_mutations_sequential(new_pairs); - let compute_elapsed_sequential = now.elapsed().as_secs_f64() * 1000_f64; // time in ms - - assert_eq!(mutations.root(), mutations_sequential.root()); - assert_eq!(mutations.node_mutations(), mutations_sequential.node_mutations()); - assert_eq!(mutations.new_pairs(), mutations_sequential.new_pairs()); - let now = Instant::now(); tree.apply_mutations(mutations)?; let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms println!( - "An average update-batch computation time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + "An average update-batch computation time measured by a 1k-batch into an SMT with 
{size} leaves over {:.1} ms is {:.0} μs", compute_elapsed, compute_elapsed * 1000_f64 / updates as f64, // time in μs ); println!( - "An average update-batch sequential computation time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", - compute_elapsed_sequential, - compute_elapsed_sequential * 1000_f64 / updates as f64, // time in μs - ); - - let factor = compute_elapsed_sequential / compute_elapsed; - println!("Parallel implementaton is {factor}x times faster."); - - println!( - "An average update-batch application time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + "An average update-batch application time measured by a 1k-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", apply_elapsed, apply_elapsed * 1000_f64 / updates as f64, // time in μs ); From f42d597686a4f078324ed1aac67b53c7b001e4a0 Mon Sep 17 00:00:00 2001 From: Krushimir Date: Tue, 21 Jan 2025 21:03:50 +0100 Subject: [PATCH 03/13] chore: adds comment Co-authored-by: Philipp Gackstatter --- src/merkle/smt/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 684dbdbd..46e4ea11 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -602,6 +602,7 @@ pub(crate) trait SparseMerkleTree { let mut iter = leaves.drain(..).peekable(); while let Some(first_leaf) = iter.next() { + /// This constructs a valid index because next_depth will never exceed the depth of the tree. 
let parent_index = NodeIndex::new_unchecked(next_depth, first_leaf.col).parent(); let parent_node = self.get_inner_node(parent_index); let (left, right) = Self::fetch_sibling_pair(&mut iter, first_leaf, parent_node); From a76506f89dccc328369f8a842e91e4e026185743 Mon Sep 17 00:00:00 2001 From: krushimir Date: Wed, 22 Jan 2025 22:42:39 +0100 Subject: [PATCH 04/13] chore: addressing comments --- CHANGELOG.md | 1 + src/merkle/smt/full/mod.rs | 8 -- src/merkle/smt/mod.rs | 157 +++++++++++++++++++++---------------- src/merkle/smt/tests.rs | 18 ++++- 4 files changed, 107 insertions(+), 77 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eb77cf68..3d2332e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## 0.13.1 (2024-12-26) - Generate reverse mutations set on applying of mutations set, implemented serialization of `MutationsSet` (#355). +- Added parallel implementation of `Smt::compute_mutations` with better performance (#365). ## 0.13.0 (2024-11-24) diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index a73ac468..d8f09027 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -255,14 +255,6 @@ impl Smt { >::compute_mutations(self, kv_pairs) } - /// Sequential implementation of [`Smt::compute_mutations()`]. - pub fn compute_mutations_sequential( - &self, - kv_pairs: impl IntoIterator, - ) -> MutationSet { - >::compute_mutations_sequential(self, kv_pairs) - } - /// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to /// this tree and returns the reverse mutation set. Applying the reverse mutation sets to the /// updated tree will revert the changes. 
diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 46e4ea11..c5645d91 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -36,7 +36,7 @@ type UnorderedMap = alloc::collections::BTreeMap; type InnerNodes = UnorderedMap; type Leaves = UnorderedMap; type NodeMutations = UnorderedMap; -type MutatedLeavesResult = (Vec>, UnorderedMap); +type MutatedSubtreeLeaves = Vec>; /// An abstract description of a sparse Merkle tree. /// @@ -185,7 +185,7 @@ pub(crate) trait SparseMerkleTree { { #[cfg(feature = "concurrent")] { - self.compute_mutations_subtree(kv_pairs) + self.compute_mutations_concurrent(kv_pairs) } #[cfg(not(feature = "concurrent"))] { @@ -287,13 +287,21 @@ pub(crate) trait SparseMerkleTree { /// Parallel implementation of [`SparseMerkleTree::compute_mutations()`]. /// - /// This method recursively tracks mutations across 8-depth subtrees from the bottom up, - /// ultimately reconstructing the complete mutation set for the entire tree. + /// This method computes mutations by recursively processing subtrees in parallel, working from + /// the bottom up. For a tree of depth D with subtrees of depth 8, the process works as + /// follows: /// - /// The implementation is similar to [`SparseMerkleTree::build_subtrees_from_sorted_entries()`], - /// sharing the same constraint that the depth must be a multiple of 8. + /// 1. First, the input key-value pairs are sorted and grouped into subtrees based on their leaf + /// indices. Each subtree covers a range of 256 (2^8) possible leaf positions. + /// + /// 2. The subtrees containing modifications are then processed in parallel: + /// - For each modified subtree, compute node mutations from depth D up to depth D-8 + /// - Each subtree computation yields a new root at depth D-8 and its associated mutations + /// + /// 3. These subtree roots become the "leaves" for the next iteration, which processes the next + /// 8 levels up. This continues until reaching the tree's root at depth 0. 
#[cfg(feature = "concurrent")] - fn compute_mutations_subtree( + fn compute_mutations_concurrent( &self, kv_pairs: impl IntoIterator, ) -> MutationSet @@ -307,7 +315,8 @@ pub(crate) trait SparseMerkleTree { sorted_kv_pairs.sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value()); // Convert sorted pairs into mutated leaves and capture any new pairs - let (mut subtree_leaves, new_pairs) = self.sorted_pairs_to_mutated_leaves(sorted_kv_pairs); + let (mut subtree_leaves, new_pairs) = + self.sorted_pairs_to_mutated_subtree_leaves(sorted_kv_pairs); let mut node_mutations = NodeMutations::default(); // Process each depth level in reverse, stepping by the subtree depth @@ -521,63 +530,38 @@ pub(crate) trait SparseMerkleTree { ) -> Self::Leaf; /// Computes leaves from a set of key-value pairs and current leaf values. - /// Deried from `sorted_pairs_to_leaves` - /// - /// TODO: refactor and merge functionality with `sorted_pairs_to_leaves`? - fn sorted_pairs_to_mutated_leaves( + /// Derived from `sorted_pairs_to_leaves` + fn sorted_pairs_to_mutated_subtree_leaves( &self, pairs: Vec<(Self::Key, Self::Value)>, - ) -> MutatedLeavesResult { - debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value())); - - let mut accumulated_leaves = Vec::with_capacity(pairs.len() / 2); + ) -> (MutatedSubtreeLeaves, UnorderedMap) { + // Map to track new key-value pairs for mutated leaves let mut new_pairs = UnorderedMap::new(); - let mut current_leaf_buffer = Vec::new(); - let mut iter = pairs.into_iter().peekable(); - while let Some((key, value)) = iter.next() { - let col = Self::key_to_leaf_index(&key).index.value(); - - if let Some((next_key, _)) = iter.peek() { - let next_col = Self::key_to_leaf_index(next_key).index.value(); - debug_assert!(next_col >= col); - } - - current_leaf_buffer.push((key.clone(), value)); - - // If the next pair is the same column, continue accumulating - if iter - .peek() - .is_some_and(|(next_key, _)| 
Self::key_to_leaf_index(next_key).index.value() == col) - { - continue; - } - - // Process buffered pairs - let leaf_pairs = mem::take(&mut current_leaf_buffer); - let mut leaf = self.get_leaf(&key); + let accumulator = Self::process_sorted_pairs_to_leaves(pairs, |leaf_pairs| { + let mut leaf = self.get_leaf(&leaf_pairs[0].0); for (key, value) in leaf_pairs { - match new_pairs.get(&key) { - Some(existing_value) if existing_value == &value => continue, - _ => { - leaf = self.construct_prospective_leaf(leaf, &key, &value); - new_pairs.insert(key, value); - }, - } - } + // Check if the value has changed + let old_value = + new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key)); - let hash = Self::hash_leaf(&leaf); - accumulated_leaves.push(SubtreeLeaf { col, hash }); + // Skip if the value hasn't changed + if value == old_value { + continue; + } - debug_assert!(current_leaf_buffer.is_empty()); - } + // Otherwise, update the leaf and track the new key-value pair + leaf = self.construct_prospective_leaf(leaf, &key, &value); + new_pairs.insert(key, value); + } - let leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect(); - (leaves, new_pairs) + leaf + }); + (accumulator.leaves, new_pairs) } - // Computes the node mutations and the root of a subtree + /// Computes the node mutations and the root of a subtree fn build_subtree_mutations( &self, mut leaves: Vec, @@ -602,12 +586,11 @@ pub(crate) trait SparseMerkleTree { let mut iter = leaves.drain(..).peekable(); while let Some(first_leaf) = iter.next() { - /// This constructs a valid index because next_depth will never exceed the depth of the tree. + // This constructs a valid index because next_depth will never exceed the depth of + // the tree. 
let parent_index = NodeIndex::new_unchecked(next_depth, first_leaf.col).parent(); let parent_node = self.get_inner_node(parent_index); - let (left, right) = Self::fetch_sibling_pair(&mut iter, first_leaf, parent_node); - - let combined_node = InnerNode { left: left.hash, right: right.hash }; + let combined_node = Self::fetch_sibling_pair(&mut iter, first_leaf, parent_node); let combined_hash = combined_node.hash(); let &empty_hash = EmptySubtreeRoots::entry(tree_depth, current_depth); @@ -636,16 +619,17 @@ pub(crate) trait SparseMerkleTree { (node_mutations, root_leaf) } - // Returns the sibling pair based on the first leaf and the current depth - // - // This is a helper function that is used to build the subtree mutations - // The first leaf is the leaf that we are currently processing - // The current depth is the depth of the current subtree + /// Constructs an `InnerNode` representing the sibling pair of which `first_leaf` is a part: + /// - If `first_leaf` is a right child, the left child is copied from the `parent_node`. + /// - If `first_leaf` is a left child, the right child is taken from `iter` if it also mutated + /// or copied from the `parent_node`. + /// + /// Returns the `InnerNode` containing the hashes of the sibling pair. fn fetch_sibling_pair( iter: &mut core::iter::Peekable>, first_leaf: SubtreeLeaf, parent_node: InnerNode, - ) -> (SubtreeLeaf, SubtreeLeaf) { + ) -> InnerNode { let is_right_node = first_leaf.col.is_odd(); if is_right_node { @@ -653,14 +637,20 @@ pub(crate) trait SparseMerkleTree { col: first_leaf.col - 1, hash: parent_node.left, }; - (left_leaf, first_leaf) + InnerNode { + left: left_leaf.hash, + right: first_leaf.hash, + } } else { let right_col = first_leaf.col + 1; let right_leaf = match iter.peek().copied() { Some(SubtreeLeaf { col, .. 
}) if col == right_col => iter.next().unwrap(), _ => SubtreeLeaf { col: right_col, hash: parent_node.right }, }; - (first_leaf, right_leaf) + InnerNode { + left: first_leaf.hash, + right: right_leaf.hash, + } } } @@ -689,6 +679,39 @@ pub(crate) trait SparseMerkleTree { fn sorted_pairs_to_leaves( pairs: Vec<(Self::Key, Self::Value)>, ) -> PairComputations { + Self::process_sorted_pairs_to_leaves(pairs, |leaf_pairs| Self::pairs_to_leaf(leaf_pairs)) + } + + /// Processes sorted key-value pairs to compute leaves for a subtree. + /// + /// This function groups key-value pairs by their corresponding column index and processes each + /// group to construct leaves. The actual construction of the leaf is delegated to the + /// `process_leaf` callback, allowing flexibility for different use cases (e.g., creating + /// new leaves or mutating existing ones). + /// + /// # Parameters + /// - `pairs`: A vector of sorted key-value pairs. The pairs *must* be sorted by leaf index + /// column (not simply by key). If the input is not sorted correctly, the function will + /// produce incorrect results and may panic in debug mode. + /// - `process_leaf`: A callback function used to process each group of key-value pairs + /// corresponding to the same column index. The callback takes a vector of key-value pairs for + /// a single column and returns the constructed leaf for that column. + /// + /// # Returns + /// A `PairComputations` containing: + /// - `nodes`: A mapping of column indices to the constructed leaves. + /// - `leaves`: A collection of `SubtreeLeaf` structures representing the processed leaves. Each + /// `SubtreeLeaf` includes the column index and the hash of the corresponding leaf. + /// + /// # Panics + /// This function will panic in debug mode if the input `pairs` are not sorted by column index. 
+ fn process_sorted_pairs_to_leaves( + pairs: Vec<(Self::Key, Self::Value)>, + mut process_leaf: F, + ) -> PairComputations + where + F: FnMut(Vec<(Self::Key, Self::Value)>) -> Self::Leaf, + { debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value())); let mut accumulator: PairComputations = Default::default(); @@ -720,7 +743,7 @@ pub(crate) trait SparseMerkleTree { // Otherwise, the next pair is a different column, or there is no next pair. Either way // it's time to swap out our buffer. let leaf_pairs = mem::take(&mut current_leaf_buffer); - let leaf = Self::pairs_to_leaf(leaf_pairs); + let leaf = process_leaf(leaf_pairs); let hash = Self::hash_leaf(&leaf); accumulator.nodes.insert(col, leaf); diff --git a/src/merkle/smt/tests.rs b/src/merkle/smt/tests.rs index 711a317c..3d17d0ec 100644 --- a/src/merkle/smt/tests.rs +++ b/src/merkle/smt/tests.rs @@ -1,4 +1,7 @@ -use alloc::{collections::BTreeMap, vec::Vec}; +use alloc::{ + collections::{BTreeMap, BTreeSet}, + vec::Vec, +}; use rand::{prelude::IteratorRandom, thread_rng, Rng}; @@ -115,6 +118,17 @@ fn generate_updates(entries: Vec<(RpoDigest, Word)>, updates: usize) -> Vec<(Rpo const REMOVAL_PROBABILITY: f64 = 0.2; let mut rng = thread_rng(); + // Assertion to ensure input keys are unique + assert!( + entries + .iter() + .map(|(key, _)| key) + .collect::>() + .len() + == entries.len(), + "Input entries contain duplicate keys!" + ); + let mut sorted_entries: Vec<(RpoDigest, Word)> = entries .into_iter() .choose_multiple(&mut rng, updates) @@ -452,7 +466,7 @@ fn test_singlethreaded_subtree_mutations() { let mut node_mutations = NodeMutations::default(); - let (mut subtree_leaves, new_pairs) = tree.sorted_pairs_to_mutated_leaves(updates); + let (mut subtree_leaves, new_pairs) = tree.sorted_pairs_to_mutated_subtree_leaves(updates); for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { // There's no flat_map_unzip(), so this is the best we can do. 
From ec35f28dd344feb83ac4db14a64060286b3b250f Mon Sep 17 00:00:00 2001 From: Krushimir Date: Thu, 23 Jan 2025 17:32:55 +0100 Subject: [PATCH 05/13] chore: update docs Co-authored-by: Philipp Gackstatter --- src/merkle/smt/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index c5645d91..10eea88e 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -621,7 +621,7 @@ pub(crate) trait SparseMerkleTree { /// Constructs an `InnerNode` representing the sibling pair of which `first_leaf` is a part: /// - If `first_leaf` is a right child, the left child is copied from the `parent_node`. - /// - If `first_leaf` is a left child, the right child is taken from `iter` if it also mutated + /// - If `first_leaf` is a left child, the right child is taken from `iter` if it was also mutated /// or copied from the `parent_node`. /// /// Returns the `InnerNode` containing the hashes of the sibling pair. From e89daa9f8e41bec12f257ebb6da81c0b488f2b5d Mon Sep 17 00:00:00 2001 From: krushimir Date: Thu, 23 Jan 2025 17:41:05 +0100 Subject: [PATCH 06/13] chore: linting and addressing comments --- src/main.rs | 48 +++++++++++++++++++++-------------------- src/merkle/smt/mod.rs | 4 ++-- src/merkle/smt/tests.rs | 7 +----- 3 files changed, 28 insertions(+), 31 deletions(-) diff --git a/src/main.rs b/src/main.rs index 07bdf15b..b4a35f52 100644 --- a/src/main.rs +++ b/src/main.rs @@ -56,7 +56,9 @@ pub fn construction(entries: Vec<(RpoDigest, Word)>, size: usize) -> Result Result<(), MerkleError> { } println!( - "An average insertion time measured by {insertions} inserts into an SMT with {size} leaves is {:.0} μs\n", + "The average insertion time measured by {insertions} inserts into an SMT with {size} leaves is {:.0} μs\n", // calculate the average insertion_times.iter().sum::() as f64 / (insertions as f64), ); @@ -104,7 +106,7 @@ pub fn batched_insertion(tree: &mut Smt, insertions: usize) -> Result<(), Merkle let 
compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms println!( - "An average insert-batch computation time measured by a 1k-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + "The average insert-batch computation time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", compute_elapsed, compute_elapsed * 1000_f64 / insertions as f64, // time in μs ); @@ -114,13 +116,13 @@ pub fn batched_insertion(tree: &mut Smt, insertions: usize) -> Result<(), Merkle let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms println!( - "An average insert-batch application time measured by a 1k-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + "The average insert-batch application time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", apply_elapsed, apply_elapsed * 1000_f64 / insertions as f64, // time in μs ); println!( - "An average batch insertion time measured by a 1k-batch into an SMT with {size} leaves totals to {:.1} ms", + "The average batch insertion time measured by a {insertions}-batch into an SMT with {size} leaves totals to {:.1} ms", (compute_elapsed + apply_elapsed), ); @@ -141,20 +143,20 @@ pub fn batched_update( let size = tree.num_leaves(); let mut rng = thread_rng(); - let new_pairs: Vec<(RpoDigest, Word)> = entries - .into_iter() - .choose_multiple(&mut rng, updates) - .into_iter() - .map(|(key, _)| { - let value = if rng.gen_bool(REMOVAL_PROBABILITY) { - EMPTY_WORD - } else { - [ONE, ONE, ONE, Felt::new(rng.gen())] - }; - - (key, value) - }) - .collect(); + let new_pairs = + entries + .into_iter() + .choose_multiple(&mut rng, updates) + .into_iter() + .map(|(key, _)| { + let value = if rng.gen_bool(REMOVAL_PROBABILITY) { + EMPTY_WORD + } else { + [ONE, ONE, ONE, Felt::new(rng.gen())] + }; + + (key, value) + }); assert_eq!(new_pairs.len(), updates); @@ -167,19 +169,19 @@ pub fn batched_update( let apply_elapsed = 
now.elapsed().as_secs_f64() * 1000_f64; // time in ms println!( - "An average update-batch computation time measured by a 1k-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + "The average update-batch computation time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", compute_elapsed, compute_elapsed * 1000_f64 / updates as f64, // time in μs ); println!( - "An average update-batch application time measured by a 1k-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + "The average update-batch application time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", apply_elapsed, apply_elapsed * 1000_f64 / updates as f64, // time in μs ); println!( - "An average batch update time measured by a 1k-batch into an SMT with {size} leaves totals to {:.1} ms", + "The average batch update time measured by a {updates}-batch into an SMT with {size} leaves totals to {:.1} ms", (compute_elapsed + apply_elapsed), ); @@ -208,7 +210,7 @@ pub fn proof_generation(tree: &mut Smt) -> Result<(), MerkleError> { } println!( - "An average proving time measured by {NUM_PROOFS} value proofs in an SMT with {size} leaves in {:.0} μs", + "The average proving time measured by {NUM_PROOFS} value proofs in an SMT with {size} leaves in {:.0} μs", // calculate the average insertion_times.iter().sum::() as f64 / (NUM_PROOFS as f64), ); diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 10eea88e..637f15f1 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -621,8 +621,8 @@ pub(crate) trait SparseMerkleTree { /// Constructs an `InnerNode` representing the sibling pair of which `first_leaf` is a part: /// - If `first_leaf` is a right child, the left child is copied from the `parent_node`. - /// - If `first_leaf` is a left child, the right child is taken from `iter` if it was also mutated - /// or copied from the `parent_node`. 
+ /// - If `first_leaf` is a left child, the right child is taken from `iter` if it was also + /// mutated or copied from the `parent_node`. /// /// Returns the `InnerNode` containing the hashes of the sibling pair. fn fetch_sibling_pair( diff --git a/src/merkle/smt/tests.rs b/src/merkle/smt/tests.rs index 3d17d0ec..0e4a893c 100644 --- a/src/merkle/smt/tests.rs +++ b/src/merkle/smt/tests.rs @@ -120,12 +120,7 @@ fn generate_updates(entries: Vec<(RpoDigest, Word)>, updates: usize) -> Vec<(Rpo // Assertion to ensure input keys are unique assert!( - entries - .iter() - .map(|(key, _)| key) - .collect::>() - .len() - == entries.len(), + entries.iter().map(|(key, _)| key).collect::>().len() == entries.len(), "Input entries contain duplicate keys!" ); From 17b03a87f08875f4e12f4ce6357c3cc07ef6796f Mon Sep 17 00:00:00 2001 From: krushimir Date: Mon, 27 Jan 2025 21:23:02 +0100 Subject: [PATCH 07/13] docs: `SimpleSmt::compute_mutations` note --- src/merkle/smt/simple/mod.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index 67215adf..d6f0933e 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -218,6 +218,10 @@ impl SimpleSmt { /// [`SimpleSmt::apply_mutations()`] can be called in order to commit these changes to the /// Merkle tree, or [`drop()`] to discard them. /// + /// **Note:** Parallel computation is only supported for trees whose depth is a multiple of 8. + /// Since `SimpleSmt` can have a depth that isn't a multiple of 8, this method defaults to the + /// sequential implementation. 
+ /// /// # Example /// ``` /// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word}; From 1eb776957142e97388b60dacecebe8923f035fb9 Mon Sep 17 00:00:00 2001 From: krushimir Date: Wed, 29 Jan 2025 21:48:23 +0100 Subject: [PATCH 08/13] chore: addressing comments --- src/merkle/smt/full/mod.rs | 4 +--- src/merkle/smt/mod.rs | 6 +++++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index d8f09027..5cd641e4 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -255,9 +255,7 @@ impl Smt { >::compute_mutations(self, kv_pairs) } - /// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to - /// this tree and returns the reverse mutation set. Applying the reverse mutation sets to the - /// updated tree will revert the changes. + /// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to this tree. /// /// # Errors /// If `mutations` was computed on a tree with a different root than this one, returns diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 637f15f1..9e183231 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -182,6 +182,8 @@ pub(crate) trait SparseMerkleTree { ) -> MutationSet where Self: Sized + Sync, + Self::Key: Send + Sync, + Self::Value: Send + Sync, { #[cfg(feature = "concurrent")] { @@ -307,12 +309,14 @@ pub(crate) trait SparseMerkleTree { ) -> MutationSet where Self: Sized + Sync, + Self::Key: Send + Sync, + Self::Value: Send + Sync, { use rayon::prelude::*; // Collect and sort key-value pairs by their corresponding leaf index let mut sorted_kv_pairs: Vec<_> = kv_pairs.into_iter().collect(); - sorted_kv_pairs.sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value()); + sorted_kv_pairs.par_sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value()); // Convert sorted pairs into mutated leaves and capture any new pairs let (mut subtree_leaves, new_pairs) = From 
4f6f43186df1904b7c04ee66a1c7e2a23d96fa85 Mon Sep 17 00:00:00 2001 From: krushimir Date: Thu, 30 Jan 2025 13:16:37 +0100 Subject: [PATCH 09/13] chore: change the benchmark params default values --- src/main.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main.rs b/src/main.rs index b4a35f52..83daef67 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,13 +13,13 @@ use rand_utils::rand_value; #[clap(name = "Benchmark", about = "SMT benchmark", version, rename_all = "kebab-case")] pub struct BenchmarkCmd { /// Size of the tree - #[clap(short = 's', long = "size", default_value = "10000")] + #[clap(short = 's', long = "size", default_value = "1000000")] size: usize, /// Number of insertions - #[clap(short = 'i', long = "insertions", default_value = "10000")] + #[clap(short = 'i', long = "insertions", default_value = "1000")] insertions: usize, /// Number of updates - #[clap(short = 'u', long = "updates", default_value = "10000")] + #[clap(short = 'u', long = "updates", default_value = "1000")] updates: usize, } From b119df07c0f3d8b13b6a1e6ef143db50ab8ed6a5 Mon Sep 17 00:00:00 2001 From: krushimir Date: Wed, 5 Feb 2025 21:36:24 +0100 Subject: [PATCH 10/13] chore: refactor concurrent implementations --- CHANGELOG.md | 1 + Cargo.toml | 2 +- src/merkle/mod.rs | 4 +- src/merkle/smt/full/mod.rs | 601 +++++++++++++++++++++++++++++++++-- src/merkle/smt/full/tests.rs | 554 ++++++++++++++++++++++++++++++++ src/merkle/smt/mod.rs | 565 +------------------------------- src/merkle/smt/simple/mod.rs | 2 +- src/merkle/smt/tests.rs | 542 ------------------------------- 8 files changed, 1149 insertions(+), 1122 deletions(-) delete mode 100644 src/merkle/smt/tests.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index c34ef4c7..b7f7a8b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - Generate reverse mutations set on applying of mutations set, implemented serialization of `MutationsSet` (#355). 
- Added parallel implementation of `Smt::compute_mutations` with better performance (#365). +- Implemented parallel leaf hashing in `Smt::process_sorted_pairs_to_leaves` (#365). ## 0.13.0 (2024-11-24) diff --git a/Cargo.toml b/Cargo.toml index be393abf..febb4715 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ name = "store" harness = false [features] -concurrent = ["dep:rayon"] +concurrent = ["dep:rayon", "hashbrown?/rayon"] default = ["std", "concurrent"] executable = ["dep:clap", "dep:rand-utils", "std"] smt_hashmaps = ["dep:hashbrown"] diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index 2fdbc6ba..3c295a4c 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -23,9 +23,11 @@ pub use path::{MerklePath, RootPath, ValuePath}; mod smt; #[cfg(feature = "internal")] pub use smt::build_subtree_for_bench; +#[cfg(feature = "internal")] +pub use smt::SubtreeLeaf; pub use smt::{ LeafIndex, MutationSet, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, - SubtreeLeaf, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH, + SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH, }; mod mmr; diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 5cd641e4..04ed7124 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -1,8 +1,12 @@ use alloc::{collections::BTreeSet, string::ToString, vec::Vec}; +use core::mem; + +use num::Integer; use super::{ EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex, MerkleError, - MerklePath, MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, + MerklePath, MutationSet, NodeIndex, NodeMutation, NodeMutations, Rpo256, RpoDigest, + SparseMerkleTree, UnorderedMap, Word, EMPTY_WORD, }; mod error; @@ -27,6 +31,7 @@ pub const SMT_DEPTH: u8 = 64; // ================================================================================================ type Leaves = super::Leaves; +type MutatedSubtreeLeaves = Vec>; /// Sparse Merkle tree mapping 256-bit keys to 256-bit 
values. Both keys and values are represented /// by 4 field elements. @@ -81,23 +86,7 @@ impl Smt { ) -> Result { #[cfg(feature = "concurrent")] { - let mut seen_keys = BTreeSet::new(); - let entries: Vec<_> = entries - .into_iter() - .map(|(key, value)| { - if seen_keys.insert(key) { - Ok((key, value)) - } else { - Err(MerkleError::DuplicateValuesForIndex( - LeafIndex::::from(key).value(), - )) - } - }) - .collect::>()?; - if entries.is_empty() { - return Ok(Self::default()); - } - >::with_entries_par(entries) + Self::with_entries_concurrent(entries) } #[cfg(not(feature = "concurrent"))] { @@ -252,7 +241,14 @@ impl Smt { &self, kv_pairs: impl IntoIterator, ) -> MutationSet { - >::compute_mutations(self, kv_pairs) + #[cfg(feature = "concurrent")] + { + self.compute_mutations_concurrent(kv_pairs) + } + #[cfg(not(feature = "concurrent"))] + { + >::compute_mutations(self, kv_pairs) + } } /// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to this tree. @@ -323,6 +319,382 @@ impl Smt { } } +// Concurrent implementation +#[cfg(feature = "concurrent")] +impl Smt { + /// Parallel implementation of [`Smt::with_entries()`]. + /// + /// This method constructs a new sparse Merkle tree concurrently by processing subtrees in + /// parallel, working from the bottom up. The process works as follows: + /// + /// 1. First, the input key-value pairs are sorted and grouped into subtrees based on their leaf + /// indices. Each subtree covers a range of 256 (2^8) possible leaf positions. + /// + /// 2. The subtrees are then processed in parallel: + /// - For each subtree, compute the inner nodes from depth D down to depth D-8. + /// - Each subtree computation yields a new subtree root and its associated inner nodes. + /// + /// 3. These subtree roots are recursively merged to become the "leaves" for the next iteration, + /// which processes the next 8 levels up. This continues until the final root of the tree is + /// computed at depth 0. 
+ pub fn with_entries_concurrent( + entries: impl IntoIterator, + ) -> Result { + let mut seen_keys = BTreeSet::new(); + let entries: Vec<_> = entries + .into_iter() + .map(|(key, value)| { + if seen_keys.insert(key) { + Ok((key, value)) + } else { + Err(MerkleError::DuplicateValuesForIndex( + LeafIndex::::from(key).value(), + )) + } + }) + .collect::>()?; + if entries.is_empty() { + return Ok(Self::default()); + } + let (inner_nodes, leaves) = Self::build_subtrees(entries); + let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash(); + >::from_raw_parts(inner_nodes, leaves, root) + } + + /// Parallel implementation of [`Smt::compute_mutations()`]. + /// + /// This method computes mutations by recursively processing subtrees in parallel, working from + /// the bottom up. The process works as follows: + /// + /// 1. First, the input key-value pairs are sorted and grouped into subtrees based on their leaf + /// indices. Each subtree covers a range of 256 (2^8) possible leaf positions. + /// + /// 2. The subtrees containing modifications are then processed in parallel: + /// - For each modified subtree, compute node mutations from depth D up to depth D-8 + /// - Each subtree computation yields a new root at depth D-8 and its associated mutations + /// + /// 3. These subtree roots become the "leaves" for the next iteration, which processes the next + /// 8 levels up. This continues until reaching the tree's root at depth 0. 
+ pub fn compute_mutations_concurrent( + &self, + kv_pairs: impl IntoIterator, + ) -> MutationSet + where + Self: Sized + Sync, + { + use rayon::prelude::*; + + // Collect and sort key-value pairs by their corresponding leaf index + let mut sorted_kv_pairs: Vec<_> = kv_pairs.into_iter().collect(); + sorted_kv_pairs.par_sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value()); + + // Convert sorted pairs into mutated leaves and capture any new pairs + let (mut subtree_leaves, new_pairs) = + self.sorted_pairs_to_mutated_subtree_leaves(sorted_kv_pairs); + let mut node_mutations = NodeMutations::default(); + + // Process each depth level in reverse, stepping by the subtree depth + for depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + // Parallel processing of each subtree to generate mutations and roots + let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves + .into_par_iter() + .map(|subtree| { + debug_assert!(subtree.is_sorted() && !subtree.is_empty()); + self.build_subtree_mutations(subtree, SMT_DEPTH, depth) + }) + .unzip(); + + // Prepare leaves for the next depth level + subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); + + // Aggregate all node mutations + node_mutations.extend(mutations_per_subtree.into_iter().flatten()); + + debug_assert!(!subtree_leaves.is_empty()); + } + + // Finalize the mutation set with updated roots and mutations + MutationSet { + old_root: self.root(), + new_root: subtree_leaves[0][0].hash, + node_mutations, + new_pairs, + } + } + + /// Performs the initial transforms for constructing a [`SparseMerkleTree`] by composing + /// subtrees. In other words, this function takes the key-value inputs to the tree, and produces + /// the inputs to feed into [`build_subtree()`]. + /// + /// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. 
If + /// `pairs` is not correctly sorted, the returned computations will be incorrect. + /// + /// # Panics + /// With debug assertions on, this function panics if it detects that `pairs` is not correctly + /// sorted. Without debug assertions, the returned computations will be incorrect. + fn sorted_pairs_to_leaves(pairs: Vec<(RpoDigest, Word)>) -> PairComputations { + Self::process_sorted_pairs_to_leaves(pairs, Self::pairs_to_leaf) + } + + /// Computes leaves from a set of key-value pairs and current leaf values. + /// Derived from `sorted_pairs_to_leaves` + fn sorted_pairs_to_mutated_subtree_leaves( + &self, + pairs: Vec<(RpoDigest, Word)>, + ) -> (MutatedSubtreeLeaves, UnorderedMap) { + // Map to track new key-value pairs for mutated leaves + let mut new_pairs = UnorderedMap::new(); + + let accumulator = Self::process_sorted_pairs_to_leaves(pairs, |leaf_pairs| { + let mut leaf = self.get_leaf(&leaf_pairs[0].0); + + for (key, value) in leaf_pairs { + // Check if the value has changed + let old_value = + new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key)); + + // Skip if the value hasn't changed + if value == old_value { + continue; + } + + // Otherwise, update the leaf and track the new key-value pair + leaf = self.construct_prospective_leaf(leaf, &key, &value); + new_pairs.insert(key, value); + } + + leaf + }); + (accumulator.leaves, new_pairs) + } + + /// Computes the node mutations and the root of a subtree + fn build_subtree_mutations( + &self, + mut leaves: Vec, + tree_depth: u8, + bottom_depth: u8, + ) -> (NodeMutations, SubtreeLeaf) + where + Self: Sized, + { + debug_assert!(bottom_depth <= tree_depth); + debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH)); + debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32)); + + let subtree_root_depth = bottom_depth - SUBTREE_DEPTH; + let mut node_mutations: NodeMutations = Default::default(); + let mut next_leaves: Vec = Vec::with_capacity(leaves.len() / 2); + + for 
current_depth in (subtree_root_depth..bottom_depth).rev() { + debug_assert!(current_depth <= bottom_depth); + + let next_depth = current_depth + 1; + let mut iter = leaves.drain(..).peekable(); + + while let Some(first_leaf) = iter.next() { + // This constructs a valid index because next_depth will never exceed the depth of + // the tree. + let parent_index = NodeIndex::new_unchecked(next_depth, first_leaf.col).parent(); + let parent_node = self.get_inner_node(parent_index); + let combined_node = Self::fetch_sibling_pair(&mut iter, first_leaf, parent_node); + let combined_hash = combined_node.hash(); + + let &empty_hash = EmptySubtreeRoots::entry(tree_depth, current_depth); + + // Add the parent node even if it is empty for proper upward updates + next_leaves.push(SubtreeLeaf { + col: parent_index.value(), + hash: combined_hash, + }); + + node_mutations.insert( + parent_index, + if combined_hash != empty_hash { + NodeMutation::Addition(combined_node) + } else { + NodeMutation::Removal + }, + ); + } + drop(iter); + leaves = mem::take(&mut next_leaves); + } + + debug_assert_eq!(leaves.len(), 1); + let root_leaf = leaves.pop().unwrap(); + (node_mutations, root_leaf) + } + + /// Constructs an `InnerNode` representing the sibling pair of which `first_leaf` is a part: + /// - If `first_leaf` is a right child, the left child is copied from the `parent_node`. + /// - If `first_leaf` is a left child, the right child is taken from `iter` if it was also + /// mutated or copied from the `parent_node`. + /// + /// Returns the `InnerNode` containing the hashes of the sibling pair. 
+ fn fetch_sibling_pair( + iter: &mut core::iter::Peekable>, + first_leaf: SubtreeLeaf, + parent_node: InnerNode, + ) -> InnerNode { + let is_right_node = first_leaf.col.is_odd(); + + if is_right_node { + let left_leaf = SubtreeLeaf { + col: first_leaf.col - 1, + hash: parent_node.left, + }; + InnerNode { + left: left_leaf.hash, + right: first_leaf.hash, + } + } else { + let right_col = first_leaf.col + 1; + let right_leaf = match iter.peek().copied() { + Some(SubtreeLeaf { col, .. }) if col == right_col => iter.next().unwrap(), + _ => SubtreeLeaf { col: right_col, hash: parent_node.right }, + }; + InnerNode { + left: first_leaf.hash, + right: right_leaf.hash, + } + } + } + + /// Processes sorted key-value pairs to compute leaves for a subtree. + /// + /// This function groups key-value pairs by their corresponding column index and processes each + /// group to construct leaves. The actual construction of the leaf is delegated to the + /// `process_leaf` callback, allowing flexibility for different use cases (e.g., creating + /// new leaves or mutating existing ones). + /// + /// # Parameters + /// - `pairs`: A vector of sorted key-value pairs. The pairs *must* be sorted by leaf index + /// column (not simply by key). If the input is not sorted correctly, the function will + /// produce incorrect results and may panic in debug mode. + /// - `process_leaf`: A callback function used to process each group of key-value pairs + /// corresponding to the same column index. The callback takes a vector of key-value pairs for + /// a single column and returns the constructed leaf for that column. + /// + /// # Returns + /// A `PairComputations` containing: + /// - `nodes`: A mapping of column indices to the constructed leaves. + /// - `leaves`: A collection of `SubtreeLeaf` structures representing the processed leaves. Each + /// `SubtreeLeaf` includes the column index and the hash of the corresponding leaf. 
+ /// + /// # Panics + /// This function will panic in debug mode if the input `pairs` are not sorted by column index. + fn process_sorted_pairs_to_leaves( + pairs: Vec<(RpoDigest, Word)>, + mut process_leaf: F, + ) -> PairComputations + where + F: FnMut(Vec<(RpoDigest, Word)>) -> SmtLeaf, + { + use rayon::prelude::*; + debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value())); + + let mut accumulator: PairComputations = Default::default(); + + // As we iterate, we'll keep track of the kv-pairs we've seen so far that correspond to a + // single leaf. When we see a pair that's in a different leaf, we'll swap these pairs + // out and store them in our accumulated leaves. + let mut current_leaf_buffer: Vec<(RpoDigest, Word)> = Default::default(); + + let mut iter = pairs.into_iter().peekable(); + while let Some((key, value)) = iter.next() { + let col = Self::key_to_leaf_index(&key).index.value(); + let peeked_col = iter.peek().map(|(key, _v)| { + let index = Self::key_to_leaf_index(key); + let next_col = index.index.value(); + // We panic if `pairs` is not sorted by column. + debug_assert!(next_col >= col); + next_col + }); + current_leaf_buffer.push((key, value)); + + // If the next pair is the same column as this one, then we're done after adding this + // pair to the buffer. + if peeked_col == Some(col) { + continue; + } + + // Otherwise, the next pair is a different column, or there is no next pair. Either way + // it's time to swap out our buffer. 
+ let leaf_pairs = mem::take(&mut current_leaf_buffer); + let leaf = process_leaf(leaf_pairs); + + accumulator.nodes.insert(col, leaf); + + debug_assert!(current_leaf_buffer.is_empty()); + } + + // Compute the leaves from the nodes concurrently + let mut accumulated_leaves: Vec = accumulator + .nodes + .clone() + .into_par_iter() + .map(|(col, leaf)| SubtreeLeaf { col, hash: Self::hash_leaf(&leaf) }) + .collect(); + + // Sort the leaves by column + accumulated_leaves.par_sort_by_key(|leaf| leaf.col); + + // TODO: determine if there is any notable performance difference between computing + // subtree boundaries after the fact as an iterator adapter (like this), versus computing + // subtree boundaries as we go. Either way this function is only used at the beginning of a + // parallel construction, so it should not be a critical path. + accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect(); + accumulator + } + + /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs. + /// + /// `entries` need not be sorted. This function will sort them. + fn build_subtrees(mut entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) { + entries.sort_by_key(|item| { + let index = Self::key_to_leaf_index(&item.0); + index.value() + }); + Self::build_subtrees_from_sorted_entries(entries) + } + + /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs. + /// + /// This function is mostly an implementation detail of + /// [`Smt::with_entries_concurrent()`].
+ fn build_subtrees_from_sorted_entries(entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) { + use rayon::prelude::*; + + let mut accumulated_nodes: InnerNodes = Default::default(); + + let PairComputations { + leaves: mut leaf_subtrees, + nodes: initial_leaves, + } = Self::sorted_pairs_to_leaves(entries); + + for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + let (nodes, mut subtree_roots): (Vec>, Vec) = + leaf_subtrees + .into_par_iter() + .map(|subtree| { + debug_assert!(subtree.is_sorted()); + debug_assert!(!subtree.is_empty()); + let (nodes, subtree_root) = + build_subtree(subtree, SMT_DEPTH, current_depth); + (nodes, subtree_root) + }) + .unzip(); + + leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); + accumulated_nodes.extend(nodes.into_iter().flatten()); + + debug_assert!(!leaf_subtrees.is_empty()); + } + (accumulated_nodes, initial_leaves) + } +} + impl SparseMerkleTree for Smt { type Key = RpoDigest; type Value = Word; @@ -542,3 +914,194 @@ fn test_smt_serialization_deserialization() { assert_eq!(smt, Smt::read_from_bytes(&bytes).unwrap()); assert_eq!(bytes.len(), smt.get_size_hint()); } + +// SUBTREES +// ================================================================================================ + +/// A subtree is of depth 8. +const SUBTREE_DEPTH: u8 = 8; + +/// A depth-8 subtree contains 256 "columns" that can possibly be occupied. +const COLS_PER_SUBTREE: u64 = u64::pow(2, SUBTREE_DEPTH as u32); + +/// Helper struct for organizing the data we care about when computing Merkle subtrees. +/// +/// Note that these represent "conceptual" leaves of some subtree, not necessarily +/// the leaf type for the sparse Merkle tree. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] +pub struct SubtreeLeaf { + /// The 'value' field of [`NodeIndex`]. When computing a subtree, the depth is already known.
+ pub col: u64, + /// The hash of the node this `SubtreeLeaf` represents. + pub hash: RpoDigest, +} + +/// Helper struct to organize the return value of [`Smt::sorted_pairs_to_leaves()`]. +#[derive(Debug, Clone)] +pub(crate) struct PairComputations { + /// Literal leaves to be added to the sparse Merkle tree's internal mapping. + pub nodes: UnorderedMap, + /// "Conceptual" leaves that will be used for computations. + pub leaves: Vec>, +} + +// Derive requires `L` to impl Default, even though we don't actually need that. +impl Default for PairComputations { + fn default() -> Self { + Self { + nodes: Default::default(), + leaves: Default::default(), + } + } +} + +#[derive(Debug)] +struct SubtreeLeavesIter<'s> { + leaves: core::iter::Peekable>, +} +impl<'s> SubtreeLeavesIter<'s> { + fn from_leaves(leaves: &'s mut Vec) -> Self { + // TODO: determine if there is any notable performance difference between taking a Vec, + // which may need flattening first, vs storing a `Box>`. + // The latter may have self-referential properties that are impossible to express in purely + // safe Rust. + Self { leaves: leaves.drain(..).peekable() } + } +} +impl Iterator for SubtreeLeavesIter<'_> { + type Item = Vec; + + /// Each `next()` collects an entire subtree.
+ fn next(&mut self) -> Option> { + let mut subtree: Vec = Default::default(); + + let mut last_subtree_col = 0; + + while let Some(leaf) = self.leaves.peek() { + last_subtree_col = u64::max(1, last_subtree_col); + let is_exact_multiple = Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE); + let next_subtree_col = if is_exact_multiple { + u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE) + } else { + last_subtree_col.next_multiple_of(COLS_PER_SUBTREE) + }; + + last_subtree_col = leaf.col; + if leaf.col < next_subtree_col { + subtree.push(self.leaves.next().unwrap()); + } else if subtree.is_empty() { + continue; + } else { + break; + } + } + + if subtree.is_empty() { + debug_assert!(self.leaves.peek().is_none()); + return None; + } + + Some(subtree) + } +} + +// HELPER FUNCTIONS +// ================================================================================================ + +/// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and +/// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and +/// `leaves` must not contain more than one depth-8 subtree's worth of leaves. +/// +/// This function will then calculate the inner nodes above each leaf for 8 layers, as well as +/// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into +/// itself. +/// +/// # Panics +/// With debug assertions on, this function panics under invalid inputs: if `leaves` contains +/// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to +/// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified +/// maximum depth (`tree_depth`), or if `leaves` is not sorted.
+#[cfg(feature = "concurrent")] +fn build_subtree( + mut leaves: Vec, + tree_depth: u8, + bottom_depth: u8, +) -> (UnorderedMap, SubtreeLeaf) { + debug_assert!(bottom_depth <= tree_depth); + debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH)); + debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32)); + let subtree_root = bottom_depth - SUBTREE_DEPTH; + let mut inner_nodes: UnorderedMap = Default::default(); + let mut next_leaves: Vec = Vec::with_capacity(leaves.len() / 2); + for next_depth in (subtree_root..bottom_depth).rev() { + debug_assert!(next_depth <= bottom_depth); + // `next_depth` is the stuff we're making. + // `current_depth` is the stuff we have. + let current_depth = next_depth + 1; + let mut iter = leaves.drain(..).peekable(); + while let Some(first) = iter.next() { + // On non-continuous iterations, including the first iteration, `first_column` may + // be a left or right node. On subsequent continuous iterations, we will always call + // `iter.next()` twice. + // On non-continuous iterations (including the very first iteration), this column + // could be either on the left or the right. If the next iteration is not + // discontinuous with our right node, then the next iteration's + let is_right = first.col.is_odd(); + let (left, right) = if is_right { + // Discontinuous iteration: we have no left node, so it must be empty. + let left = SubtreeLeaf { + col: first.col - 1, + hash: *EmptySubtreeRoots::entry(tree_depth, current_depth), + }; + let right = first; + (left, right) + } else { + let left = first; + let right_col = first.col + 1; + let right = match iter.peek().copied() { + Some(SubtreeLeaf { col, .. }) if col == right_col => { + // Our inputs must be sorted. + debug_assert!(left.col <= col); + // The next leaf in the iterator is our sibling. Use it and consume it! + iter.next().unwrap() + }, + // Otherwise, the leaves don't contain our sibling, so our sibling must be + // empty. 
+ _ => SubtreeLeaf { + col: right_col, + hash: *EmptySubtreeRoots::entry(tree_depth, current_depth), + }, + }; + (left, right) + }; + let index = NodeIndex::new_unchecked(current_depth, left.col).parent(); + let node = InnerNode { left: left.hash, right: right.hash }; + let hash = node.hash(); + let &equivalent_empty_hash = EmptySubtreeRoots::entry(tree_depth, next_depth); + // If this hash is empty, then it doesn't become a new inner node, nor does it count + // as a leaf for the next depth. + if hash != equivalent_empty_hash { + inner_nodes.insert(index, node); + next_leaves.push(SubtreeLeaf { col: index.value(), hash }); + } + } + // Stop borrowing `leaves`, so we can swap it. + // The iterator is empty at this point anyway. + drop(iter); + // After each depth, consider the stuff we just made the new "leaves", and empty the + // other collection. + mem::swap(&mut leaves, &mut next_leaves); + } + debug_assert_eq!(leaves.len(), 1); + let root = leaves.pop().unwrap(); + (inner_nodes, root) +} + +#[cfg(feature = "internal")] +pub fn build_subtree_for_bench( + leaves: Vec, + tree_depth: u8, + bottom_depth: u8, +) -> (UnorderedMap, SubtreeLeaf) { + build_subtree(leaves, tree_depth, bottom_depth) +} diff --git a/src/merkle/smt/full/tests.rs b/src/merkle/smt/full/tests.rs index 787c01a8..0fe14076 100644 --- a/src/merkle/smt/full/tests.rs +++ b/src/merkle/smt/full/tests.rs @@ -722,3 +722,557 @@ fn apply_mutations( reversion } + +// CONCURRENT TESTS +// -------------------------------------------------------------------------------------------- + +#[cfg(feature = "concurrent")] +mod concurrent_tests { + use alloc::{ + collections::{BTreeMap, BTreeSet}, + vec::Vec, + }; + + use rand::{prelude::IteratorRandom, thread_rng, Rng}; + + use super::*; + use crate::{ + merkle::smt::{ + full::{ + build_subtree, PairComputations, SubtreeLeaf, SubtreeLeavesIter, COLS_PER_SUBTREE, + SUBTREE_DEPTH, + }, + InnerNode, NodeMutations, SparseMerkleTree, UnorderedMap, + }, + Word, ONE, + }; 
+ + fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf { + SubtreeLeaf { + col: leaf.index().index.value(), + hash: leaf.hash(), + } + } + + #[test] + fn test_sorted_pairs_to_leaves() { + let entries: Vec<(RpoDigest, Word)> = vec![ + // Subtree 0. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]), + (RpoDigest::new([ONE, ONE, ONE, Felt::new(17)]), [ONE; 4]), + // Leaf index collision. + (RpoDigest::new([ONE, ONE, Felt::new(10), Felt::new(20)]), [ONE; 4]), + (RpoDigest::new([ONE, ONE, Felt::new(20), Felt::new(20)]), [ONE; 4]), + // Subtree 1. Normal single leaf again. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]), // Subtree boundary. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(401)]), [ONE; 4]), + // Subtree 2. Another normal leaf. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(1024)]), [ONE; 4]), + ]; + + let control = Smt::with_entries_sequential(entries.clone()).unwrap(); + + let control_leaves: Vec = { + let mut entries_iter = entries.iter().cloned(); + let mut next_entry = || entries_iter.next().unwrap(); + let control_leaves = vec![ + // Subtree 0. + SmtLeaf::Single(next_entry()), + SmtLeaf::Single(next_entry()), + SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(), + // Subtree 1. + SmtLeaf::Single(next_entry()), + SmtLeaf::Single(next_entry()), + // Subtree 2. + SmtLeaf::Single(next_entry()), + ]; + assert_eq!(entries_iter.next(), None); + control_leaves + }; + + let control_subtree_leaves: Vec> = { + let mut control_leaves_iter = control_leaves.iter(); + let mut next_leaf = || control_leaves_iter.next().unwrap(); + + let control_subtree_leaves: Vec> = [ + // Subtree 0. + vec![next_leaf(), next_leaf(), next_leaf()], + // Subtree 1. + vec![next_leaf(), next_leaf()], + // Subtree 2. 
+ vec![next_leaf()], + ] + .map(|subtree| subtree.into_iter().map(smtleaf_to_subtree_leaf).collect()) + .to_vec(); + assert_eq!(control_leaves_iter.next(), None); + control_subtree_leaves + }; + + let subtrees: PairComputations = Smt::sorted_pairs_to_leaves(entries); + // This will check that the hashes, columns, and subtree assignments all match. + assert_eq!(subtrees.leaves, control_subtree_leaves); + + // Flattening and re-separating out the leaves into subtrees should have the same result. + let mut all_leaves: Vec = + subtrees.leaves.clone().into_iter().flatten().collect(); + let re_grouped: Vec> = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect(); + assert_eq!(subtrees.leaves, re_grouped); + + // Then finally we might as well check the computed leaf nodes too. + let control_leaves: BTreeMap = control + .leaves() + .map(|(index, value)| (index.index.value(), value.clone())) + .collect(); + + for (column, test_leaf) in subtrees.nodes { + if test_leaf.is_empty() { + continue; + } + let control_leaf = control_leaves + .get(&column) + .unwrap_or_else(|| panic!("no leaf node found for column {column}")); + assert_eq!(control_leaf, &test_leaf); + } + } + + // Helper for the below tests. + fn generate_entries(pair_count: u64) -> Vec<(RpoDigest, Word)> { + (0..pair_count) + .map(|i| { + let leaf_index = ((i as f64 / pair_count as f64) * (pair_count as f64)) as u64; + let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]); + let value = [ONE, ONE, ONE, Felt::new(i)]; + (key, value) + }) + .collect() + } + + fn generate_updates(entries: Vec<(RpoDigest, Word)>, updates: usize) -> Vec<(RpoDigest, Word)> { + const REMOVAL_PROBABILITY: f64 = 0.2; + let mut rng = thread_rng(); + + // Assertion to ensure input keys are unique + assert!( + entries.iter().map(|(key, _)| key).collect::>().len() == entries.len(), + "Input entries contain duplicate keys!" 
+ ); + + let mut sorted_entries: Vec<(RpoDigest, Word)> = entries + .into_iter() + .choose_multiple(&mut rng, updates) + .into_iter() + .map(|(key, _)| { + let value = if rng.gen_bool(REMOVAL_PROBABILITY) { + EMPTY_WORD + } else { + [ONE, ONE, ONE, Felt::new(rng.gen())] + }; + + (key, value) + }) + .collect(); + sorted_entries.sort_by_key(|(key, _)| Smt::key_to_leaf_index(key).value()); + sorted_entries + } + + #[test] + fn test_single_subtree() { + // A single subtree's worth of leaves. + const PAIR_COUNT: u64 = COLS_PER_SUBTREE; + + let entries = generate_entries(PAIR_COUNT); + + let control = Smt::with_entries_sequential(entries.clone()).unwrap(); + + // `entries` should already be sorted by nature of how we constructed it. + let leaves = Smt::sorted_pairs_to_leaves(entries).leaves; + let leaves = leaves.into_iter().next().unwrap(); + + let (first_subtree, subtree_root) = build_subtree(leaves, SMT_DEPTH, SMT_DEPTH); + assert!(!first_subtree.is_empty()); + + // The inner nodes computed from that subtree should match the nodes in our control tree. + for (index, node) in first_subtree.into_iter() { + let control = control.get_inner_node(index); + assert_eq!( + control, node, + "subtree-computed node at index {index:?} does not match control", + ); + } + + // The root returned should also match the equivalent node in the control tree. + let control_root_index = + NodeIndex::new(SMT_DEPTH - SUBTREE_DEPTH, subtree_root.col).expect("Valid root index"); + let control_root_node = control.get_inner_node(control_root_index); + let control_hash = control_root_node.hash(); + assert_eq!( + control_hash, subtree_root.hash, + "Subtree-computed root at index {control_root_index:?} does not match control" + ); + } + + // Test that not just can we compute a subtree correctly, but we can feed the results of one + // subtree into computing another. In other words, test that `build_subtree()` is correctly + // composable. 
+ #[test] + fn test_two_subtrees() { + // Two subtrees' worth of leaves. + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 2; + + let entries = generate_entries(PAIR_COUNT); + + let control = Smt::with_entries_sequential(entries.clone()).unwrap(); + + let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries); + // With two subtrees' worth of leaves, we should have exactly two subtrees. + let [first, second]: [Vec<_>; 2] = leaves.try_into().unwrap(); + assert_eq!(first.len() as u64, PAIR_COUNT / 2); + assert_eq!(first.len(), second.len()); + + let mut current_depth = SMT_DEPTH; + let mut next_leaves: Vec = Default::default(); + + let (first_nodes, first_root) = build_subtree(first, SMT_DEPTH, current_depth); + next_leaves.push(first_root); + + let (second_nodes, second_root) = build_subtree(second, SMT_DEPTH, current_depth); + next_leaves.push(second_root); + + // All new inner nodes + the new subtree-leaves should be 512, for one depth-cycle. + let total_computed = first_nodes.len() + second_nodes.len() + next_leaves.len(); + assert_eq!(total_computed as u64, PAIR_COUNT); + + // Verify the computed nodes of both subtrees. 
+ let computed_nodes = first_nodes.clone().into_iter().chain(second_nodes); + for (index, test_node) in computed_nodes { + let control_node = control.get_inner_node(index); + assert_eq!( + control_node, test_node, + "subtree-computed node at index {index:?} does not match control", + ); + } + + current_depth -= SUBTREE_DEPTH; + + let (nodes, root_leaf) = build_subtree(next_leaves, SMT_DEPTH, current_depth); + assert_eq!(nodes.len(), SUBTREE_DEPTH as usize); + assert_eq!(root_leaf.col, 0); + + for (index, test_node) in nodes { + let control_node = control.get_inner_node(index); + assert_eq!( + control_node, test_node, + "subtree-computed node at index {index:?} does not match control", + ); + } + + let index = NodeIndex::new(current_depth - SUBTREE_DEPTH, root_leaf.col).unwrap(); + let control_root = control.get_inner_node(index).hash(); + assert_eq!(control_root, root_leaf.hash, "Root mismatch"); + } + + #[test] + fn test_singlethreaded_subtrees() { + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + + let entries = generate_entries(PAIR_COUNT); + + let control = Smt::with_entries_sequential(entries.clone()).unwrap(); + + let mut accumulated_nodes: BTreeMap = Default::default(); + + let PairComputations { + leaves: mut leaf_subtrees, + nodes: test_leaves, + } = Smt::sorted_pairs_to_leaves(entries); + + for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + // There's no flat_map_unzip(), so this is the best we can do. + let (nodes, mut subtree_roots): (Vec>, Vec) = + leaf_subtrees + .into_iter() + .enumerate() + .map(|(i, subtree)| { + // Pre-assertions. + assert!( + subtree.is_sorted(), + "subtree {i} at bottom-depth {current_depth} is not sorted", + ); + assert!( + !subtree.is_empty(), + "subtree {i} at bottom-depth {current_depth} is empty!", + ); + + // Do actual things. + let (nodes, subtree_root) = + build_subtree(subtree, SMT_DEPTH, current_depth); + + // Post-assertions. 
+ for (&index, test_node) in nodes.iter() { + let control_node = control.get_inner_node(index); + assert_eq!( + test_node, &control_node, + "depth {} subtree {}: test node does not match control at index {:?}", + current_depth, i, index, + ); + } + + (nodes, subtree_root) + }) + .unzip(); + + // Update state between each depth iteration. + + leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); + accumulated_nodes.extend(nodes.into_iter().flatten()); + + assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); + } + + // Make sure the true leaves match, first checking length and then checking each individual + // leaf. + let control_leaves: BTreeMap<_, _> = control.leaves().collect(); + let control_leaves_len = control_leaves.len(); + let test_leaves_len = test_leaves.len(); + assert_eq!(test_leaves_len, control_leaves_len); + for (col, ref test_leaf) in test_leaves { + let index = LeafIndex::new_max_depth(col); + let &control_leaf = control_leaves.get(&index).unwrap(); + assert_eq!(test_leaf, control_leaf, "test leaf at column {col} does not match control"); + } + + // Make sure the inner nodes match, checking length first and then each individual leaf. + let control_nodes_len = control.inner_nodes().count(); + let test_nodes_len = accumulated_nodes.len(); + assert_eq!(test_nodes_len, control_nodes_len); + for (index, test_node) in accumulated_nodes.clone() { + let control_node = control.get_inner_node(index); + assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); + } + + // After the last iteration of the above for loop, we should have the new root node actually + // in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from + // `build_subtree()`. So let's check both! + + let control_root = control.get_inner_node(NodeIndex::root()); + + // That for loop should have left us with only one leaf subtree... 
+ let [leaf_subtree]: [Vec<_>; 1] = leaf_subtrees.try_into().unwrap(); + // which itself contains only one 'leaf'... + let [root_leaf]: [SubtreeLeaf; 1] = leaf_subtree.try_into().unwrap(); + // which matches the expected root. + assert_eq!(control.root(), root_leaf.hash); + + // Likewise `accumulated_nodes` should contain a node at the root index... + assert!(accumulated_nodes.contains_key(&NodeIndex::root())); + // and it should match our actual root. + let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap(); + assert_eq!(control_root, *test_root); + // And of course the root we got from each place should match. + assert_eq!(control.root(), root_leaf.hash); + } + + /// The parallel version of `test_singlethreaded_subtree()`. + #[test] + fn test_multithreaded_subtrees() { + use rayon::prelude::*; + + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + + let entries = generate_entries(PAIR_COUNT); + + let control = Smt::with_entries_sequential(entries.clone()).unwrap(); + + let mut accumulated_nodes: BTreeMap = Default::default(); + + let PairComputations { + leaves: mut leaf_subtrees, + nodes: test_leaves, + } = Smt::sorted_pairs_to_leaves(entries); + + for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + let (nodes, mut subtree_roots): (Vec>, Vec) = + leaf_subtrees + .into_par_iter() + .enumerate() + .map(|(i, subtree)| { + // Pre-assertions. + assert!( + subtree.is_sorted(), + "subtree {i} at bottom-depth {current_depth} is not sorted", + ); + assert!( + !subtree.is_empty(), + "subtree {i} at bottom-depth {current_depth} is empty!", + ); + + let (nodes, subtree_root) = + build_subtree(subtree, SMT_DEPTH, current_depth); + + // Post-assertions. 
+ for (&index, test_node) in nodes.iter() { + let control_node = control.get_inner_node(index); + assert_eq!( + test_node, &control_node, + "depth {} subtree {}: test node does not match control at index {:?}", + current_depth, i, index, + ); + } + + (nodes, subtree_root) + }) + .unzip(); + + leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); + accumulated_nodes.extend(nodes.into_iter().flatten()); + + assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); + } + + // Make sure the true leaves match, checking length first and then each individual leaf. + let control_leaves: BTreeMap<_, _> = control.leaves().collect(); + let control_leaves_len = control_leaves.len(); + let test_leaves_len = test_leaves.len(); + assert_eq!(test_leaves_len, control_leaves_len); + for (col, ref test_leaf) in test_leaves { + let index = LeafIndex::new_max_depth(col); + let &control_leaf = control_leaves.get(&index).unwrap(); + assert_eq!(test_leaf, control_leaf); + } + + // Make sure the inner nodes match, checking length first and then each individual leaf. + let control_nodes_len = control.inner_nodes().count(); + let test_nodes_len = accumulated_nodes.len(); + assert_eq!(test_nodes_len, control_nodes_len); + for (index, test_node) in accumulated_nodes.clone() { + let control_node = control.get_inner_node(index); + assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); + } + + // After the last iteration of the above for loop, we should have the new root node actually + // in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from + // `build_subtree()`. So let's check both! + + let control_root = control.get_inner_node(NodeIndex::root()); + + // That for loop should have left us with only one leaf subtree... + let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap(); + // which itself contains only one 'leaf'... 
+ let [root_leaf]: [_; 1] = leaf_subtree.try_into().unwrap(); + // which matches the expected root. + assert_eq!(control.root(), root_leaf.hash); + + // Likewise `accumulated_nodes` should contain a node at the root index... + assert!(accumulated_nodes.contains_key(&NodeIndex::root())); + // and it should match our actual root. + let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap(); + assert_eq!(control_root, *test_root); + // And of course the root we got from each place should match. + assert_eq!(control.root(), root_leaf.hash); + } + + #[test] + fn test_with_entries_concurrent() { + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + + let entries = generate_entries(PAIR_COUNT); + + let control = Smt::with_entries_sequential(entries.clone()).unwrap(); + + let smt = Smt::with_entries(entries.clone()).unwrap(); + assert_eq!(smt.root(), control.root()); + assert_eq!(smt, control); + } + + /// Concurrent mutations + #[test] + fn test_singlethreaded_subtree_mutations() { + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + + let entries = generate_entries(PAIR_COUNT); + let updates = generate_updates(entries.clone(), 1000); + + let tree = Smt::with_entries_sequential(entries.clone()).unwrap(); + let control = tree.compute_mutations_sequential(updates.clone()); + + let mut node_mutations = NodeMutations::default(); + + let (mut subtree_leaves, new_pairs) = tree.sorted_pairs_to_mutated_subtree_leaves(updates); + + for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + // There's no flat_map_unzip(), so this is the best we can do. + let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves + .into_iter() + .enumerate() + .map(|(i, subtree)| { + // Pre-assertions. 
+ assert!( + subtree.is_sorted(), + "subtree {i} at bottom-depth {current_depth} is not sorted", + ); + assert!( + !subtree.is_empty(), + "subtree {i} at bottom-depth {current_depth} is empty!", + ); + + // Calculate the mutations for this subtree. + let (mutations_per_subtree, subtree_root) = + tree.build_subtree_mutations(subtree, SMT_DEPTH, current_depth); + + // Check that the mutations match the control tree. + for (&index, mutation) in mutations_per_subtree.iter() { + let control_mutation = control.node_mutations().get(&index).unwrap(); + assert_eq!( + control_mutation, mutation, + "depth {} subtree {}: mutation does not match control at index {:?}", + current_depth, i, index, + ); + } + + (mutations_per_subtree, subtree_root) + }) + .unzip(); + + subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); + node_mutations.extend(mutations_per_subtree.into_iter().flatten()); + + assert!(!subtree_leaves.is_empty(), "on depth {current_depth}"); + } + + let [subtree]: [Vec<_>; 1] = subtree_leaves.try_into().unwrap(); + let [root_leaf]: [SubtreeLeaf; 1] = subtree.try_into().unwrap(); + // Check that the new root matches the control. + assert_eq!(control.new_root, root_leaf.hash); + + // Check that the node mutations match the control. 
+ assert_eq!(control.node_mutations().len(), node_mutations.len()); + for (&index, mutation) in control.node_mutations().iter() { + let test_mutation = node_mutations.get(&index).unwrap(); + assert_eq!(test_mutation, mutation); + } + // Check that the new pairs match the control + assert_eq!(control.new_pairs.len(), new_pairs.len()); + for (&key, &value) in control.new_pairs.iter() { + let test_value = new_pairs.get(&key).unwrap(); + assert_eq!(test_value, &value); + } + } + + #[test] + fn test_compute_mutations_parallel() { + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + + let entries = generate_entries(PAIR_COUNT); + let tree = Smt::with_entries(entries.clone()).unwrap(); + + let updates = generate_updates(entries, 1000); + + let control = tree.compute_mutations_sequential(updates.clone()); + let mutations = tree.compute_mutations(updates); + + assert_eq!(mutations.root(), control.root()); + assert_eq!(mutations.old_root(), control.old_root()); + assert_eq!(mutations.node_mutations(), control.node_mutations()); + assert_eq!(mutations.new_pairs(), control.new_pairs()); + } +} diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 9e183231..1443dbbf 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -1,7 +1,6 @@ use alloc::vec::Vec; -use core::{hash::Hash, mem}; +use core::hash::Hash; -use num::Integer; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex}; @@ -11,6 +10,10 @@ use crate::{ }; mod full; +#[cfg(feature = "internal")] +pub use full::build_subtree_for_bench; +#[cfg(feature = "internal")] +pub use full::SubtreeLeaf; pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH}; mod simple; @@ -36,7 +39,6 @@ type UnorderedMap = alloc::collections::BTreeMap; type InnerNodes = UnorderedMap; type Leaves = UnorderedMap; type NodeMutations = UnorderedMap; -type MutatedSubtreeLeaves = Vec>; /// 
An abstract description of a sparse Merkle tree. /// @@ -76,17 +78,6 @@ pub(crate) trait SparseMerkleTree { // PROVIDED METHODS // --------------------------------------------------------------------------------------------- - /// Creates a new sparse Merkle tree from an existing set of key-value pairs, in parallel. - #[cfg(feature = "concurrent")] - fn with_entries_par(entries: Vec<(Self::Key, Self::Value)>) -> Result - where - Self: Sized, - { - let (inner_nodes, leaves) = Self::build_subtrees(entries); - let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash(); - Self::from_raw_parts(inner_nodes, leaves, root) - } - /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle /// path to the leaf, as well as the leaf itself. fn open(&self, key: &Self::Key) -> Self::Opening { @@ -179,20 +170,8 @@ pub(crate) trait SparseMerkleTree { fn compute_mutations( &self, kv_pairs: impl IntoIterator, - ) -> MutationSet - where - Self: Sized + Sync, - Self::Key: Send + Sync, - Self::Value: Send + Sync, - { - #[cfg(feature = "concurrent")] - { - self.compute_mutations_concurrent(kv_pairs) - } - #[cfg(not(feature = "concurrent"))] - { - self.compute_mutations_sequential(kv_pairs) - } + ) -> MutationSet { + self.compute_mutations_sequential(kv_pairs) } /// Sequential version of [`SparseMerkleTree::compute_mutations()`]. @@ -287,71 +266,6 @@ pub(crate) trait SparseMerkleTree { } } - /// Parallel implementation of [`SparseMerkleTree::compute_mutations()`]. - /// - /// This method computes mutations by recursively processing subtrees in parallel, working from - /// the bottom up. For a tree of depth D with subtrees of depth 8, the process works as - /// follows: - /// - /// 1. First, the input key-value pairs are sorted and grouped into subtrees based on their leaf - /// indices. Each subtree covers a range of 256 (2^8) possible leaf positions. - /// - /// 2. 
The subtrees containing modifications are then processed in parallel: - /// - For each modified subtree, compute node mutations from depth D up to depth D-8 - /// - Each subtree computation yields a new root at depth D-8 and its associated mutations - /// - /// 3. These subtree roots become the "leaves" for the next iteration, which processes the next - /// 8 levels up. This continues until reaching the tree's root at depth 0. - #[cfg(feature = "concurrent")] - fn compute_mutations_concurrent( - &self, - kv_pairs: impl IntoIterator, - ) -> MutationSet - where - Self: Sized + Sync, - Self::Key: Send + Sync, - Self::Value: Send + Sync, - { - use rayon::prelude::*; - - // Collect and sort key-value pairs by their corresponding leaf index - let mut sorted_kv_pairs: Vec<_> = kv_pairs.into_iter().collect(); - sorted_kv_pairs.par_sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value()); - - // Convert sorted pairs into mutated leaves and capture any new pairs - let (mut subtree_leaves, new_pairs) = - self.sorted_pairs_to_mutated_subtree_leaves(sorted_kv_pairs); - let mut node_mutations = NodeMutations::default(); - - // Process each depth level in reverse, stepping by the subtree depth - for depth in (SUBTREE_DEPTH..=DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { - // Parallel processing of each subtree to generate mutations and roots - let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves - .into_par_iter() - .map(|subtree| { - debug_assert!(subtree.is_sorted() && !subtree.is_empty()); - self.build_subtree_mutations(subtree, DEPTH, depth) - }) - .unzip(); - - // Prepare leaves for the next depth level - subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); - - // Aggregate all node mutations - node_mutations.extend(mutations_per_subtree.into_iter().flatten()); - - debug_assert!(!subtree_leaves.is_empty()); - } - - // Finalize the mutation set with updated roots and mutations - MutationSet { - old_root: 
self.root(), - new_root: subtree_leaves[0][0].hash, - node_mutations, - new_pairs, - } - } - /// Applies the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to /// this tree. /// @@ -533,131 +447,6 @@ pub(crate) trait SparseMerkleTree { value: &Self::Value, ) -> Self::Leaf; - /// Computes leaves from a set of key-value pairs and current leaf values. - /// Derived from `sorted_pairs_to_leaves` - fn sorted_pairs_to_mutated_subtree_leaves( - &self, - pairs: Vec<(Self::Key, Self::Value)>, - ) -> (MutatedSubtreeLeaves, UnorderedMap) { - // Map to track new key-value pairs for mutated leaves - let mut new_pairs = UnorderedMap::new(); - - let accumulator = Self::process_sorted_pairs_to_leaves(pairs, |leaf_pairs| { - let mut leaf = self.get_leaf(&leaf_pairs[0].0); - - for (key, value) in leaf_pairs { - // Check if the value has changed - let old_value = - new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key)); - - // Skip if the value hasn't changed - if value == old_value { - continue; - } - - // Otherwise, update the leaf and track the new key-value pair - leaf = self.construct_prospective_leaf(leaf, &key, &value); - new_pairs.insert(key, value); - } - - leaf - }); - (accumulator.leaves, new_pairs) - } - - /// Computes the node mutations and the root of a subtree - fn build_subtree_mutations( - &self, - mut leaves: Vec, - tree_depth: u8, - bottom_depth: u8, - ) -> (NodeMutations, SubtreeLeaf) - where - Self: Sized, - { - debug_assert!(bottom_depth <= tree_depth); - debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH)); - debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32)); - - let subtree_root_depth = bottom_depth - SUBTREE_DEPTH; - let mut node_mutations: NodeMutations = Default::default(); - let mut next_leaves: Vec = Vec::with_capacity(leaves.len() / 2); - - for current_depth in (subtree_root_depth..bottom_depth).rev() { - debug_assert!(current_depth <= bottom_depth); - - let next_depth = 
current_depth + 1; - let mut iter = leaves.drain(..).peekable(); - - while let Some(first_leaf) = iter.next() { - // This constructs a valid index because next_depth will never exceed the depth of - // the tree. - let parent_index = NodeIndex::new_unchecked(next_depth, first_leaf.col).parent(); - let parent_node = self.get_inner_node(parent_index); - let combined_node = Self::fetch_sibling_pair(&mut iter, first_leaf, parent_node); - let combined_hash = combined_node.hash(); - - let &empty_hash = EmptySubtreeRoots::entry(tree_depth, current_depth); - - // Add the parent node even if it is empty for proper upward updates - next_leaves.push(SubtreeLeaf { - col: parent_index.value(), - hash: combined_hash, - }); - - node_mutations.insert( - parent_index, - if combined_hash != empty_hash { - NodeMutation::Addition(combined_node) - } else { - NodeMutation::Removal - }, - ); - } - drop(iter); - leaves = mem::take(&mut next_leaves); - } - - debug_assert_eq!(leaves.len(), 1); - let root_leaf = leaves.pop().unwrap(); - (node_mutations, root_leaf) - } - - /// Constructs an `InnerNode` representing the sibling pair of which `first_leaf` is a part: - /// - If `first_leaf` is a right child, the left child is copied from the `parent_node`. - /// - If `first_leaf` is a left child, the right child is taken from `iter` if it was also - /// mutated or copied from the `parent_node`. - /// - /// Returns the `InnerNode` containing the hashes of the sibling pair. - fn fetch_sibling_pair( - iter: &mut core::iter::Peekable>, - first_leaf: SubtreeLeaf, - parent_node: InnerNode, - ) -> InnerNode { - let is_right_node = first_leaf.col.is_odd(); - - if is_right_node { - let left_leaf = SubtreeLeaf { - col: first_leaf.col - 1, - hash: parent_node.left, - }; - InnerNode { - left: left_leaf.hash, - right: first_leaf.hash, - } - } else { - let right_col = first_leaf.col + 1; - let right_leaf = match iter.peek().copied() { - Some(SubtreeLeaf { col, .. 
}) if col == right_col => iter.next().unwrap(), - _ => SubtreeLeaf { col: right_col, hash: parent_node.right }, - }; - InnerNode { - left: first_leaf.hash, - right: right_leaf.hash, - } - } - } - /// Maps a key to a leaf index fn key_to_leaf_index(key: &Self::Key) -> LeafIndex; @@ -669,151 +458,6 @@ pub(crate) trait SparseMerkleTree { /// /// The length `path` is guaranteed to be equal to `DEPTH` fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening; - - /// Performs the initial transforms for constructing a [`SparseMerkleTree`] by composing - /// subtrees. In other words, this function takes the key-value inputs to the tree, and produces - /// the inputs to feed into [`build_subtree()`]. - /// - /// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. If - /// `pairs` is not correctly sorted, the returned computations will be incorrect. - /// - /// # Panics - /// With debug assertions on, this function panics if it detects that `pairs` is not correctly - /// sorted. Without debug assertions, the returned computations will be incorrect. - fn sorted_pairs_to_leaves( - pairs: Vec<(Self::Key, Self::Value)>, - ) -> PairComputations { - Self::process_sorted_pairs_to_leaves(pairs, |leaf_pairs| Self::pairs_to_leaf(leaf_pairs)) - } - - /// Processes sorted key-value pairs to compute leaves for a subtree. - /// - /// This function groups key-value pairs by their corresponding column index and processes each - /// group to construct leaves. The actual construction of the leaf is delegated to the - /// `process_leaf` callback, allowing flexibility for different use cases (e.g., creating - /// new leaves or mutating existing ones). - /// - /// # Parameters - /// - `pairs`: A vector of sorted key-value pairs. The pairs *must* be sorted by leaf index - /// column (not simply by key). If the input is not sorted correctly, the function will - /// produce incorrect results and may panic in debug mode. 
- /// - `process_leaf`: A callback function used to process each group of key-value pairs - /// corresponding to the same column index. The callback takes a vector of key-value pairs for - /// a single column and returns the constructed leaf for that column. - /// - /// # Returns - /// A `PairComputations` containing: - /// - `nodes`: A mapping of column indices to the constructed leaves. - /// - `leaves`: A collection of `SubtreeLeaf` structures representing the processed leaves. Each - /// `SubtreeLeaf` includes the column index and the hash of the corresponding leaf. - /// - /// # Panics - /// This function will panic in debug mode if the input `pairs` are not sorted by column index. - fn process_sorted_pairs_to_leaves( - pairs: Vec<(Self::Key, Self::Value)>, - mut process_leaf: F, - ) -> PairComputations - where - F: FnMut(Vec<(Self::Key, Self::Value)>) -> Self::Leaf, - { - debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value())); - - let mut accumulator: PairComputations = Default::default(); - let mut accumulated_leaves: Vec = Vec::with_capacity(pairs.len() / 2); - - // As we iterate, we'll keep track of the kv-pairs we've seen so far that correspond to a - // single leaf. When we see a pair that's in a different leaf, we'll swap these pairs - // out and store them in our accumulated leaves. - let mut current_leaf_buffer: Vec<(Self::Key, Self::Value)> = Default::default(); - - let mut iter = pairs.into_iter().peekable(); - while let Some((key, value)) = iter.next() { - let col = Self::key_to_leaf_index(&key).index.value(); - let peeked_col = iter.peek().map(|(key, _v)| { - let index = Self::key_to_leaf_index(key); - let next_col = index.index.value(); - // We panic if `pairs` is not sorted by column. - debug_assert!(next_col >= col); - next_col - }); - current_leaf_buffer.push((key, value)); - - // If the next pair is the same column as this one, then we're done after adding this - // pair to the buffer. 
- if peeked_col == Some(col) { - continue; - } - - // Otherwise, the next pair is a different column, or there is no next pair. Either way - // it's time to swap out our buffer. - let leaf_pairs = mem::take(&mut current_leaf_buffer); - let leaf = process_leaf(leaf_pairs); - let hash = Self::hash_leaf(&leaf); - - accumulator.nodes.insert(col, leaf); - accumulated_leaves.push(SubtreeLeaf { col, hash }); - - debug_assert!(current_leaf_buffer.is_empty()); - } - - // TODO: determine is there is any notable performance difference between computing - // subtree boundaries after the fact as an iterator adapter (like this), versus computing - // subtree boundaries as we go. Either way this function is only used at the beginning of a - // parallel construction, so it should not be a critical path. - accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect(); - accumulator - } - - /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs. - /// - /// `entries` need not be sorted. This function will sort them. - #[cfg(feature = "concurrent")] - fn build_subtrees( - mut entries: Vec<(Self::Key, Self::Value)>, - ) -> (InnerNodes, Leaves) { - entries.sort_by_key(|item| { - let index = Self::key_to_leaf_index(&item.0); - index.value() - }); - Self::build_subtrees_from_sorted_entries(entries) - } - - /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs. - /// - /// This function is mostly an implementation detail of - /// [`SparseMerkleTree::with_entries_par()`]. 
- #[cfg(feature = "concurrent")] - fn build_subtrees_from_sorted_entries( - entries: Vec<(Self::Key, Self::Value)>, - ) -> (InnerNodes, Leaves) { - use rayon::prelude::*; - - let mut accumulated_nodes: InnerNodes = Default::default(); - - let PairComputations { - leaves: mut leaf_subtrees, - nodes: initial_leaves, - } = Self::sorted_pairs_to_leaves(entries); - - for current_depth in (SUBTREE_DEPTH..=DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { - let (nodes, mut subtree_roots): (Vec>, Vec) = - leaf_subtrees - .into_par_iter() - .map(|subtree| { - debug_assert!(subtree.is_sorted()); - debug_assert!(!subtree.is_empty()); - let (nodes, subtree_root) = build_subtree(subtree, DEPTH, current_depth); - (nodes, subtree_root) - }) - .unzip(); - - leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); - accumulated_nodes.extend(nodes.into_iter().flatten()); - - debug_assert!(!leaf_subtrees.is_empty()); - } - (accumulated_nodes, initial_leaves) - } } // INNER NODE @@ -1053,198 +697,3 @@ impl De }) } } - -// SUBTREES -// ================================================================================================ - -/// A subtree is of depth 8. -const SUBTREE_DEPTH: u8 = 8; - -/// A depth-8 subtree contains 256 "columns" that can possibly be occupied. -const COLS_PER_SUBTREE: u64 = u64::pow(2, SUBTREE_DEPTH as u32); - -/// Helper struct for organizing the data we care about when computing Merkle subtrees. -/// -/// Note that these represet "conceptual" leaves of some subtree, not necessarily -/// the leaf type for the sparse Merkle tree. -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] -pub struct SubtreeLeaf { - /// The 'value' field of [`NodeIndex`]. When computing a subtree, the depth is already known. - pub col: u64, - /// The hash of the node this `SubtreeLeaf` represents. - pub hash: RpoDigest, -} - -/// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`]. 
-#[derive(Debug, Clone)] -pub(crate) struct PairComputations { - /// Literal leaves to be added to the sparse Merkle tree's internal mapping. - pub nodes: UnorderedMap, - /// "Conceptual" leaves that will be used for computations. - pub leaves: Vec>, -} - -// Derive requires `L` to impl Default, even though we don't actually need that. -impl Default for PairComputations { - fn default() -> Self { - Self { - nodes: Default::default(), - leaves: Default::default(), - } - } -} - -#[derive(Debug)] -struct SubtreeLeavesIter<'s> { - leaves: core::iter::Peekable>, -} -impl<'s> SubtreeLeavesIter<'s> { - fn from_leaves(leaves: &'s mut Vec) -> Self { - // TODO: determine if there is any notable performance difference between taking a Vec, - // which many need flattening first, vs storing a `Box>`. - // The latter may have self-referential properties that are impossible to express in purely - // safe Rust Rust. - Self { leaves: leaves.drain(..).peekable() } - } -} -impl Iterator for SubtreeLeavesIter<'_> { - type Item = Vec; - - /// Each `next()` collects an entire subtree. 
- fn next(&mut self) -> Option> { - let mut subtree: Vec = Default::default(); - - let mut last_subtree_col = 0; - - while let Some(leaf) = self.leaves.peek() { - last_subtree_col = u64::max(1, last_subtree_col); - let is_exact_multiple = Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE); - let next_subtree_col = if is_exact_multiple { - u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE) - } else { - last_subtree_col.next_multiple_of(COLS_PER_SUBTREE) - }; - - last_subtree_col = leaf.col; - if leaf.col < next_subtree_col { - subtree.push(self.leaves.next().unwrap()); - } else if subtree.is_empty() { - continue; - } else { - break; - } - } - - if subtree.is_empty() { - debug_assert!(self.leaves.peek().is_none()); - return None; - } - - Some(subtree) - } -} - -// HELPER FUNCTIONS -// ================================================================================================ - -/// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and -/// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and -/// `leaves` must not contain more than one depth-8 subtree's worth of leaves. -/// -/// This function will then calculate the inner nodes above each leaf for 8 layers, as well as -/// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into -/// itself. -/// -/// # Panics -/// With debug assertions on, this function panics under invalid inputs: if `leaves` contains -/// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to -/// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified -/// maximum depth (`DEPTH`), or if `leaves` is not sorted. 
-fn build_subtree( - mut leaves: Vec, - tree_depth: u8, - bottom_depth: u8, -) -> (UnorderedMap, SubtreeLeaf) { - debug_assert!(bottom_depth <= tree_depth); - debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH)); - debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32)); - let subtree_root = bottom_depth - SUBTREE_DEPTH; - let mut inner_nodes: UnorderedMap = Default::default(); - let mut next_leaves: Vec = Vec::with_capacity(leaves.len() / 2); - for next_depth in (subtree_root..bottom_depth).rev() { - debug_assert!(next_depth <= bottom_depth); - // `next_depth` is the stuff we're making. - // `current_depth` is the stuff we have. - let current_depth = next_depth + 1; - let mut iter = leaves.drain(..).peekable(); - while let Some(first) = iter.next() { - // On non-continuous iterations, including the first iteration, `first_column` may - // be a left or right node. On subsequent continuous iterations, we will always call - // `iter.next()` twice. - // On non-continuous iterations (including the very first iteration), this column - // could be either on the left or the right. If the next iteration is not - // discontinuous with our right node, then the next iteration's - let is_right = first.col.is_odd(); - let (left, right) = if is_right { - // Discontinuous iteration: we have no left node, so it must be empty. - let left = SubtreeLeaf { - col: first.col - 1, - hash: *EmptySubtreeRoots::entry(tree_depth, current_depth), - }; - let right = first; - (left, right) - } else { - let left = first; - let right_col = first.col + 1; - let right = match iter.peek().copied() { - Some(SubtreeLeaf { col, .. }) if col == right_col => { - // Our inputs must be sorted. - debug_assert!(left.col <= col); - // The next leaf in the iterator is our sibling. Use it and consume it! - iter.next().unwrap() - }, - // Otherwise, the leaves don't contain our sibling, so our sibling must be - // empty. 
- _ => SubtreeLeaf { - col: right_col, - hash: *EmptySubtreeRoots::entry(tree_depth, current_depth), - }, - }; - (left, right) - }; - let index = NodeIndex::new_unchecked(current_depth, left.col).parent(); - let node = InnerNode { left: left.hash, right: right.hash }; - let hash = node.hash(); - let &equivalent_empty_hash = EmptySubtreeRoots::entry(tree_depth, next_depth); - // If this hash is empty, then it doesn't become a new inner node, nor does it count - // as a leaf for the next depth. - if hash != equivalent_empty_hash { - inner_nodes.insert(index, node); - next_leaves.push(SubtreeLeaf { col: index.value(), hash }); - } - } - // Stop borrowing `leaves`, so we can swap it. - // The iterator is empty at this point anyway. - drop(iter); - // After each depth, consider the stuff we just made the new "leaves", and empty the - // other collection. - mem::swap(&mut leaves, &mut next_leaves); - } - debug_assert_eq!(leaves.len(), 1); - let root = leaves.pop().unwrap(); - (inner_nodes, root) -} - -#[cfg(feature = "internal")] -pub fn build_subtree_for_bench( - leaves: Vec, - tree_depth: u8, - bottom_depth: u8, -) -> (UnorderedMap, SubtreeLeaf) { - build_subtree(leaves, tree_depth, bottom_depth) -} - -// TESTS -// ================================================================================================ -#[cfg(test)] -mod tests; diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index d6f0933e..b4402002 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -237,7 +237,7 @@ impl SimpleSmt { &self, kv_pairs: impl IntoIterator, Word)>, ) -> MutationSet, Word> { - >::compute_mutations_sequential(self, kv_pairs) + >::compute_mutations(self, kv_pairs) } /// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this diff --git a/src/merkle/smt/tests.rs b/src/merkle/smt/tests.rs deleted file mode 100644 index 0e4a893c..00000000 --- a/src/merkle/smt/tests.rs +++ /dev/null @@ -1,542 +0,0 @@ 
-use alloc::{ - collections::{BTreeMap, BTreeSet}, - vec::Vec, -}; - -use rand::{prelude::IteratorRandom, thread_rng, Rng}; - -use super::{ - build_subtree, InnerNode, LeafIndex, NodeIndex, NodeMutations, PairComputations, SmtLeaf, - SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter, COLS_PER_SUBTREE, SUBTREE_DEPTH, -}; -use crate::{ - hash::rpo::RpoDigest, - merkle::{smt::UnorderedMap, Smt, SMT_DEPTH}, - Felt, Word, EMPTY_WORD, ONE, -}; - -fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf { - SubtreeLeaf { - col: leaf.index().index.value(), - hash: leaf.hash(), - } -} - -#[test] -fn test_sorted_pairs_to_leaves() { - let entries: Vec<(RpoDigest, Word)> = vec![ - // Subtree 0. - (RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]), - (RpoDigest::new([ONE, ONE, ONE, Felt::new(17)]), [ONE; 4]), - // Leaf index collision. - (RpoDigest::new([ONE, ONE, Felt::new(10), Felt::new(20)]), [ONE; 4]), - (RpoDigest::new([ONE, ONE, Felt::new(20), Felt::new(20)]), [ONE; 4]), - // Subtree 1. Normal single leaf again. - (RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]), // Subtree boundary. - (RpoDigest::new([ONE, ONE, ONE, Felt::new(401)]), [ONE; 4]), - // Subtree 2. Another normal leaf. - (RpoDigest::new([ONE, ONE, ONE, Felt::new(1024)]), [ONE; 4]), - ]; - - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - - let control_leaves: Vec = { - let mut entries_iter = entries.iter().cloned(); - let mut next_entry = || entries_iter.next().unwrap(); - let control_leaves = vec![ - // Subtree 0. - SmtLeaf::Single(next_entry()), - SmtLeaf::Single(next_entry()), - SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(), - // Subtree 1. - SmtLeaf::Single(next_entry()), - SmtLeaf::Single(next_entry()), - // Subtree 2. 
- SmtLeaf::Single(next_entry()), - ]; - assert_eq!(entries_iter.next(), None); - control_leaves - }; - - let control_subtree_leaves: Vec> = { - let mut control_leaves_iter = control_leaves.iter(); - let mut next_leaf = || control_leaves_iter.next().unwrap(); - - let control_subtree_leaves: Vec> = [ - // Subtree 0. - vec![next_leaf(), next_leaf(), next_leaf()], - // Subtree 1. - vec![next_leaf(), next_leaf()], - // Subtree 2. - vec![next_leaf()], - ] - .map(|subtree| subtree.into_iter().map(smtleaf_to_subtree_leaf).collect()) - .to_vec(); - assert_eq!(control_leaves_iter.next(), None); - control_subtree_leaves - }; - - let subtrees: PairComputations = Smt::sorted_pairs_to_leaves(entries); - // This will check that the hashes, columns, and subtree assignments all match. - assert_eq!(subtrees.leaves, control_subtree_leaves); - - // Flattening and re-separating out the leaves into subtrees should have the same result. - let mut all_leaves: Vec = subtrees.leaves.clone().into_iter().flatten().collect(); - let re_grouped: Vec> = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect(); - assert_eq!(subtrees.leaves, re_grouped); - - // Then finally we might as well check the computed leaf nodes too. - let control_leaves: BTreeMap = control - .leaves() - .map(|(index, value)| (index.index.value(), value.clone())) - .collect(); - - for (column, test_leaf) in subtrees.nodes { - if test_leaf.is_empty() { - continue; - } - let control_leaf = control_leaves - .get(&column) - .unwrap_or_else(|| panic!("no leaf node found for column {column}")); - assert_eq!(control_leaf, &test_leaf); - } -} - -// Helper for the below tests. 
-fn generate_entries(pair_count: u64) -> Vec<(RpoDigest, Word)> { - (0..pair_count) - .map(|i| { - let leaf_index = ((i as f64 / pair_count as f64) * (pair_count as f64)) as u64; - let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]); - let value = [ONE, ONE, ONE, Felt::new(i)]; - (key, value) - }) - .collect() -} - -fn generate_updates(entries: Vec<(RpoDigest, Word)>, updates: usize) -> Vec<(RpoDigest, Word)> { - const REMOVAL_PROBABILITY: f64 = 0.2; - let mut rng = thread_rng(); - - // Assertion to ensure input keys are unique - assert!( - entries.iter().map(|(key, _)| key).collect::>().len() == entries.len(), - "Input entries contain duplicate keys!" - ); - - let mut sorted_entries: Vec<(RpoDigest, Word)> = entries - .into_iter() - .choose_multiple(&mut rng, updates) - .into_iter() - .map(|(key, _)| { - let value = if rng.gen_bool(REMOVAL_PROBABILITY) { - EMPTY_WORD - } else { - [ONE, ONE, ONE, Felt::new(rng.gen())] - }; - - (key, value) - }) - .collect(); - sorted_entries.sort_by_key(|(key, _)| Smt::key_to_leaf_index(key).value()); - sorted_entries -} - -#[test] -fn test_single_subtree() { - // A single subtree's worth of leaves. - const PAIR_COUNT: u64 = COLS_PER_SUBTREE; - - let entries = generate_entries(PAIR_COUNT); - - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - - // `entries` should already be sorted by nature of how we constructed it. - let leaves = Smt::sorted_pairs_to_leaves(entries).leaves; - let leaves = leaves.into_iter().next().unwrap(); - - let (first_subtree, subtree_root) = build_subtree(leaves, SMT_DEPTH, SMT_DEPTH); - assert!(!first_subtree.is_empty()); - - // The inner nodes computed from that subtree should match the nodes in our control tree. 
- for (index, node) in first_subtree.into_iter() { - let control = control.get_inner_node(index); - assert_eq!( - control, node, - "subtree-computed node at index {index:?} does not match control", - ); - } - - // The root returned should also match the equivalent node in the control tree. - let control_root_index = - NodeIndex::new(SMT_DEPTH - SUBTREE_DEPTH, subtree_root.col).expect("Valid root index"); - let control_root_node = control.get_inner_node(control_root_index); - let control_hash = control_root_node.hash(); - assert_eq!( - control_hash, subtree_root.hash, - "Subtree-computed root at index {control_root_index:?} does not match control" - ); -} - -// Test that not just can we compute a subtree correctly, but we can feed the results of one -// subtree into computing another. In other words, test that `build_subtree()` is correctly -// composable. -#[test] -fn test_two_subtrees() { - // Two subtrees' worth of leaves. - const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 2; - - let entries = generate_entries(PAIR_COUNT); - - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - - let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries); - // With two subtrees' worth of leaves, we should have exactly two subtrees. - let [first, second]: [Vec<_>; 2] = leaves.try_into().unwrap(); - assert_eq!(first.len() as u64, PAIR_COUNT / 2); - assert_eq!(first.len(), second.len()); - - let mut current_depth = SMT_DEPTH; - let mut next_leaves: Vec = Default::default(); - - let (first_nodes, first_root) = build_subtree(first, SMT_DEPTH, current_depth); - next_leaves.push(first_root); - - let (second_nodes, second_root) = build_subtree(second, SMT_DEPTH, current_depth); - next_leaves.push(second_root); - - // All new inner nodes + the new subtree-leaves should be 512, for one depth-cycle. 
- let total_computed = first_nodes.len() + second_nodes.len() + next_leaves.len(); - assert_eq!(total_computed as u64, PAIR_COUNT); - - // Verify the computed nodes of both subtrees. - let computed_nodes = first_nodes.clone().into_iter().chain(second_nodes); - for (index, test_node) in computed_nodes { - let control_node = control.get_inner_node(index); - assert_eq!( - control_node, test_node, - "subtree-computed node at index {index:?} does not match control", - ); - } - - current_depth -= SUBTREE_DEPTH; - - let (nodes, root_leaf) = build_subtree(next_leaves, SMT_DEPTH, current_depth); - assert_eq!(nodes.len(), SUBTREE_DEPTH as usize); - assert_eq!(root_leaf.col, 0); - - for (index, test_node) in nodes { - let control_node = control.get_inner_node(index); - assert_eq!( - control_node, test_node, - "subtree-computed node at index {index:?} does not match control", - ); - } - - let index = NodeIndex::new(current_depth - SUBTREE_DEPTH, root_leaf.col).unwrap(); - let control_root = control.get_inner_node(index).hash(); - assert_eq!(control_root, root_leaf.hash, "Root mismatch"); -} - -#[test] -fn test_singlethreaded_subtrees() { - const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; - - let entries = generate_entries(PAIR_COUNT); - - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - - let mut accumulated_nodes: BTreeMap = Default::default(); - - let PairComputations { - leaves: mut leaf_subtrees, - nodes: test_leaves, - } = Smt::sorted_pairs_to_leaves(entries); - - for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { - // There's no flat_map_unzip(), so this is the best we can do. - let (nodes, mut subtree_roots): (Vec>, Vec) = leaf_subtrees - .into_iter() - .enumerate() - .map(|(i, subtree)| { - // Pre-assertions. 
- assert!( - subtree.is_sorted(), - "subtree {i} at bottom-depth {current_depth} is not sorted", - ); - assert!( - !subtree.is_empty(), - "subtree {i} at bottom-depth {current_depth} is empty!", - ); - - // Do actual things. - let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth); - - // Post-assertions. - for (&index, test_node) in nodes.iter() { - let control_node = control.get_inner_node(index); - assert_eq!( - test_node, &control_node, - "depth {} subtree {}: test node does not match control at index {:?}", - current_depth, i, index, - ); - } - - (nodes, subtree_root) - }) - .unzip(); - - // Update state between each depth iteration. - - leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); - accumulated_nodes.extend(nodes.into_iter().flatten()); - - assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); - } - - // Make sure the true leaves match, first checking length and then checking each individual - // leaf. - let control_leaves: BTreeMap<_, _> = control.leaves().collect(); - let control_leaves_len = control_leaves.len(); - let test_leaves_len = test_leaves.len(); - assert_eq!(test_leaves_len, control_leaves_len); - for (col, ref test_leaf) in test_leaves { - let index = LeafIndex::new_max_depth(col); - let &control_leaf = control_leaves.get(&index).unwrap(); - assert_eq!(test_leaf, control_leaf, "test leaf at column {col} does not match control"); - } - - // Make sure the inner nodes match, checking length first and then each individual leaf. 
- let control_nodes_len = control.inner_nodes().count(); - let test_nodes_len = accumulated_nodes.len(); - assert_eq!(test_nodes_len, control_nodes_len); - for (index, test_node) in accumulated_nodes.clone() { - let control_node = control.get_inner_node(index); - assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); - } - - // After the last iteration of the above for loop, we should have the new root node actually - // in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from - // `build_subtree()`. So let's check both! - - let control_root = control.get_inner_node(NodeIndex::root()); - - // That for loop should have left us with only one leaf subtree... - let [leaf_subtree]: [Vec<_>; 1] = leaf_subtrees.try_into().unwrap(); - // which itself contains only one 'leaf'... - let [root_leaf]: [SubtreeLeaf; 1] = leaf_subtree.try_into().unwrap(); - // which matches the expected root. - assert_eq!(control.root(), root_leaf.hash); - - // Likewise `accumulated_nodes` should contain a node at the root index... - assert!(accumulated_nodes.contains_key(&NodeIndex::root())); - // and it should match our actual root. - let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap(); - assert_eq!(control_root, *test_root); - // And of course the root we got from each place should match. - assert_eq!(control.root(), root_leaf.hash); -} - -/// The parallel version of `test_singlethreaded_subtree()`. 
-#[test] -#[cfg(feature = "concurrent")] -fn test_multithreaded_subtrees() { - use rayon::prelude::*; - - const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; - - let entries = generate_entries(PAIR_COUNT); - - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - - let mut accumulated_nodes: BTreeMap = Default::default(); - - let PairComputations { - leaves: mut leaf_subtrees, - nodes: test_leaves, - } = Smt::sorted_pairs_to_leaves(entries); - - for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { - let (nodes, mut subtree_roots): (Vec>, Vec) = leaf_subtrees - .into_par_iter() - .enumerate() - .map(|(i, subtree)| { - // Pre-assertions. - assert!( - subtree.is_sorted(), - "subtree {i} at bottom-depth {current_depth} is not sorted", - ); - assert!( - !subtree.is_empty(), - "subtree {i} at bottom-depth {current_depth} is empty!", - ); - - let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth); - - // Post-assertions. - for (&index, test_node) in nodes.iter() { - let control_node = control.get_inner_node(index); - assert_eq!( - test_node, &control_node, - "depth {} subtree {}: test node does not match control at index {:?}", - current_depth, i, index, - ); - } - - (nodes, subtree_root) - }) - .unzip(); - - leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); - accumulated_nodes.extend(nodes.into_iter().flatten()); - - assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); - } - - // Make sure the true leaves match, checking length first and then each individual leaf. 
- let control_leaves: BTreeMap<_, _> = control.leaves().collect(); - let control_leaves_len = control_leaves.len(); - let test_leaves_len = test_leaves.len(); - assert_eq!(test_leaves_len, control_leaves_len); - for (col, ref test_leaf) in test_leaves { - let index = LeafIndex::new_max_depth(col); - let &control_leaf = control_leaves.get(&index).unwrap(); - assert_eq!(test_leaf, control_leaf); - } - - // Make sure the inner nodes match, checking length first and then each individual leaf. - let control_nodes_len = control.inner_nodes().count(); - let test_nodes_len = accumulated_nodes.len(); - assert_eq!(test_nodes_len, control_nodes_len); - for (index, test_node) in accumulated_nodes.clone() { - let control_node = control.get_inner_node(index); - assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); - } - - // After the last iteration of the above for loop, we should have the new root node actually - // in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from - // `build_subtree()`. So let's check both! - - let control_root = control.get_inner_node(NodeIndex::root()); - - // That for loop should have left us with only one leaf subtree... - let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap(); - // which itself contains only one 'leaf'... - let [root_leaf]: [_; 1] = leaf_subtree.try_into().unwrap(); - // which matches the expected root. - assert_eq!(control.root(), root_leaf.hash); - - // Likewise `accumulated_nodes` should contain a node at the root index... - assert!(accumulated_nodes.contains_key(&NodeIndex::root())); - // and it should match our actual root. - let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap(); - assert_eq!(control_root, *test_root); - // And of course the root we got from each place should match. 
- assert_eq!(control.root(), root_leaf.hash); -} - -#[test] -#[cfg(feature = "concurrent")] -fn test_with_entries_parallel() { - const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; - - let entries = generate_entries(PAIR_COUNT); - - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - - let smt = Smt::with_entries(entries.clone()).unwrap(); - assert_eq!(smt.root(), control.root()); - assert_eq!(smt, control); -} - -#[test] -fn test_singlethreaded_subtree_mutations() { - const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; - - let entries = generate_entries(PAIR_COUNT); - let updates = generate_updates(entries.clone(), 1000); - - let tree = Smt::with_entries_sequential(entries.clone()).unwrap(); - let control = tree.compute_mutations_sequential(updates.clone()); - - let mut node_mutations = NodeMutations::default(); - - let (mut subtree_leaves, new_pairs) = tree.sorted_pairs_to_mutated_subtree_leaves(updates); - - for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { - // There's no flat_map_unzip(), so this is the best we can do. - let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves - .into_iter() - .enumerate() - .map(|(i, subtree)| { - // Pre-assertions. - assert!( - subtree.is_sorted(), - "subtree {i} at bottom-depth {current_depth} is not sorted", - ); - assert!( - !subtree.is_empty(), - "subtree {i} at bottom-depth {current_depth} is empty!", - ); - - // Calculate the mutations for this subtree. - let (mutations_per_subtree, subtree_root) = - tree.build_subtree_mutations(subtree, SMT_DEPTH, current_depth); - - // Check that the mutations match the control tree. 
- for (&index, mutation) in mutations_per_subtree.iter() { - let control_mutation = control.node_mutations().get(&index).unwrap(); - assert_eq!( - control_mutation, mutation, - "depth {} subtree {}: mutation does not match control at index {:?}", - current_depth, i, index, - ); - } - - (mutations_per_subtree, subtree_root) - }) - .unzip(); - - subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); - node_mutations.extend(mutations_per_subtree.into_iter().flatten()); - - assert!(!subtree_leaves.is_empty(), "on depth {current_depth}"); - } - - let [subtree]: [Vec<_>; 1] = subtree_leaves.try_into().unwrap(); - let [root_leaf]: [SubtreeLeaf; 1] = subtree.try_into().unwrap(); - // Check that the new root matches the control. - assert_eq!(control.new_root, root_leaf.hash); - - // Check that the node mutations match the control. - assert_eq!(control.node_mutations().len(), node_mutations.len()); - for (&index, mutation) in control.node_mutations().iter() { - let test_mutation = node_mutations.get(&index).unwrap(); - assert_eq!(test_mutation, mutation); - } - // Check that the new pairs match the control - assert_eq!(control.new_pairs.len(), new_pairs.len()); - for (&key, &value) in control.new_pairs.iter() { - let test_value = new_pairs.get(&key).unwrap(); - assert_eq!(test_value, &value); - } -} - -#[test] -#[cfg(feature = "concurrent")] -fn test_compute_mutations_parallel() { - const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; - - let entries = generate_entries(PAIR_COUNT); - let tree = Smt::with_entries(entries.clone()).unwrap(); - - let updates = generate_updates(entries, 1000); - - let control = tree.compute_mutations_sequential(updates.clone()); - let mutations = tree.compute_mutations(updates); - - assert_eq!(mutations.root(), control.root()); - assert_eq!(mutations.old_root(), control.old_root()); - assert_eq!(mutations.node_mutations(), control.node_mutations()); - assert_eq!(mutations.new_pairs(), control.new_pairs()); -} From 
bdcdd6cf5cc290507f9aa00051fd817f67e24743 Mon Sep 17 00:00:00 2001 From: krushimir Date: Wed, 5 Feb 2025 21:48:19 +0100 Subject: [PATCH 11/13] chore: remove unnecessary note --- src/merkle/smt/simple/mod.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index b4402002..166cc982 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -218,10 +218,6 @@ impl SimpleSmt { /// [`SimpleSmt::apply_mutations()`] can be called in order to commit these changes to the /// Merkle tree, or [`drop()`] to discard them. /// - /// **Note:** Parallel computation is only supported for trees whose depth is a multiple of 8. - /// Since `SimpleSmt` can have a depth that isn't a multiple of 8, this method defaults to the - /// sequential implementation. - /// /// # Example /// ``` /// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word}; From 1d7cac875aa75b24eb6ff9704ccb02b01eeaefc8 Mon Sep 17 00:00:00 2001 From: krushimir Date: Thu, 6 Feb 2025 11:40:44 +0100 Subject: [PATCH 12/13] chore: improve refactor effort --- src/merkle/smt/full/concurrent/mod.rs | 580 +++++++++++++++++++++++ src/merkle/smt/full/concurrent/tests.rs | 446 ++++++++++++++++++ src/merkle/smt/full/mod.rs | 583 +----------------------- src/merkle/smt/full/tests.rs | 554 ---------------------- 4 files changed, 1037 insertions(+), 1126 deletions(-) create mode 100644 src/merkle/smt/full/concurrent/mod.rs create mode 100644 src/merkle/smt/full/concurrent/tests.rs diff --git a/src/merkle/smt/full/concurrent/mod.rs b/src/merkle/smt/full/concurrent/mod.rs new file mode 100644 index 00000000..ea890d48 --- /dev/null +++ b/src/merkle/smt/full/concurrent/mod.rs @@ -0,0 +1,580 @@ +use alloc::{collections::BTreeSet, vec::Vec}; +use core::mem; + +use num::Integer; + +use super::{ + EmptySubtreeRoots, InnerNode, InnerNodes, LeafIndex, Leaves, MerkleError, MutationSet, + NodeIndex, RpoDigest, Smt, SmtLeaf, SparseMerkleTree, Word, SMT_DEPTH, +}; 
+use crate::merkle::smt::{NodeMutation, NodeMutations, UnorderedMap}; + +#[cfg(test)] +mod tests; + +type MutatedSubtreeLeaves = Vec>; + +impl Smt { + /// Parallel implementation of [`Smt::with_entries()`]. + /// + /// This method constructs a new sparse Merkle tree concurrently by processing subtrees in + /// parallel, working from the bottom up. The process works as follows: + /// + /// 1. First, the input key-value pairs are sorted and grouped into subtrees based on their leaf + /// indices. Each subtree covers a range of 256 (2^8) possible leaf positions. + /// + /// 2. The subtrees are then processed in parallel: + /// - For each subtree, compute the inner nodes from depth D down to depth D-8. + /// - Each subtree computation yields a new subtree root and its associated inner nodes. + /// + /// 3. These subtree roots are recursively merged to become the "leaves" for the next iteration, + /// which processes the next 8 levels up. This continues until the final root of the tree is + /// computed at depth 0. + pub(crate) fn with_entries_concurrent( + entries: impl IntoIterator, + ) -> Result { + let mut seen_keys = BTreeSet::new(); + let entries: Vec<_> = entries + .into_iter() + .map(|(key, value)| { + if seen_keys.insert(key) { + Ok((key, value)) + } else { + Err(MerkleError::DuplicateValuesForIndex( + LeafIndex::::from(key).value(), + )) + } + }) + .collect::>()?; + if entries.is_empty() { + return Ok(Self::default()); + } + let (inner_nodes, leaves) = Self::build_subtrees(entries); + let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash(); + >::from_raw_parts(inner_nodes, leaves, root) + } + + /// Parallel implementation of [`Smt::compute_mutations()`]. + /// + /// This method computes mutations by recursively processing subtrees in parallel, working from + /// the bottom up. The process works as follows: + /// + /// 1. First, the input key-value pairs are sorted and grouped into subtrees based on their leaf + /// indices. 
Each subtree covers a range of 256 (2^8) possible leaf positions. + /// + /// 2. The subtrees containing modifications are then processed in parallel: + /// - For each modified subtree, compute node mutations from depth D up to depth D-8 + /// - Each subtree computation yields a new root at depth D-8 and its associated mutations + /// + /// 3. These subtree roots become the "leaves" for the next iteration, which processes the next + /// 8 levels up. This continues until reaching the tree's root at depth 0. + pub(crate) fn compute_mutations_concurrent( + &self, + kv_pairs: impl IntoIterator, + ) -> MutationSet + where + Self: Sized + Sync, + { + use rayon::prelude::*; + + // Collect and sort key-value pairs by their corresponding leaf index + let mut sorted_kv_pairs: Vec<_> = kv_pairs.into_iter().collect(); + sorted_kv_pairs.par_sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value()); + + // Convert sorted pairs into mutated leaves and capture any new pairs + let (mut subtree_leaves, new_pairs) = + self.sorted_pairs_to_mutated_subtree_leaves(sorted_kv_pairs); + let mut node_mutations = NodeMutations::default(); + + // Process each depth level in reverse, stepping by the subtree depth + for depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + // Parallel processing of each subtree to generate mutations and roots + let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves + .into_par_iter() + .map(|subtree| { + debug_assert!(subtree.is_sorted() && !subtree.is_empty()); + self.build_subtree_mutations(subtree, SMT_DEPTH, depth) + }) + .unzip(); + + // Prepare leaves for the next depth level + subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); + + // Aggregate all node mutations + node_mutations.extend(mutations_per_subtree.into_iter().flatten()); + + debug_assert!(!subtree_leaves.is_empty()); + } + + // Finalize the mutation set with updated roots and mutations + MutationSet { + 
old_root: self.root(), + new_root: subtree_leaves[0][0].hash, + node_mutations, + new_pairs, + } + } + + /// Performs the initial transforms for constructing a [`SparseMerkleTree`] by composing + /// subtrees. In other words, this function takes the key-value inputs to the tree, and produces + /// the inputs to feed into [`build_subtree()`]. + /// + /// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. If + /// `pairs` is not correctly sorted, the returned computations will be incorrect. + /// + /// # Panics + /// With debug assertions on, this function panics if it detects that `pairs` is not correctly + /// sorted. Without debug assertions, the returned computations will be incorrect. + fn sorted_pairs_to_leaves(pairs: Vec<(RpoDigest, Word)>) -> PairComputations { + Self::process_sorted_pairs_to_leaves(pairs, Self::pairs_to_leaf) + } + + /// Computes leaves from a set of key-value pairs and current leaf values. + /// Derived from `sorted_pairs_to_leaves` + fn sorted_pairs_to_mutated_subtree_leaves( + &self, + pairs: Vec<(RpoDigest, Word)>, + ) -> (MutatedSubtreeLeaves, UnorderedMap) { + // Map to track new key-value pairs for mutated leaves + let mut new_pairs = UnorderedMap::new(); + + let accumulator = Self::process_sorted_pairs_to_leaves(pairs, |leaf_pairs| { + let mut leaf = self.get_leaf(&leaf_pairs[0].0); + + for (key, value) in leaf_pairs { + // Check if the value has changed + let old_value = + new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key)); + + // Skip if the value hasn't changed + if value == old_value { + continue; + } + + // Otherwise, update the leaf and track the new key-value pair + leaf = self.construct_prospective_leaf(leaf, &key, &value); + new_pairs.insert(key, value); + } + + leaf + }); + (accumulator.leaves, new_pairs) + } + + /// Computes the node mutations and the root of a subtree + fn build_subtree_mutations( + &self, + mut leaves: Vec, + tree_depth: u8, + bottom_depth: u8, + ) -> 
(NodeMutations, SubtreeLeaf) + where + Self: Sized, + { + debug_assert!(bottom_depth <= tree_depth); + debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH)); + debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32)); + + let subtree_root_depth = bottom_depth - SUBTREE_DEPTH; + let mut node_mutations: NodeMutations = Default::default(); + let mut next_leaves: Vec = Vec::with_capacity(leaves.len() / 2); + + for current_depth in (subtree_root_depth..bottom_depth).rev() { + debug_assert!(current_depth <= bottom_depth); + + let next_depth = current_depth + 1; + let mut iter = leaves.drain(..).peekable(); + + while let Some(first_leaf) = iter.next() { + // This constructs a valid index because next_depth will never exceed the depth of + // the tree. + let parent_index = NodeIndex::new_unchecked(next_depth, first_leaf.col).parent(); + let parent_node = self.get_inner_node(parent_index); + let combined_node = Self::fetch_sibling_pair(&mut iter, first_leaf, parent_node); + let combined_hash = combined_node.hash(); + + let &empty_hash = EmptySubtreeRoots::entry(tree_depth, current_depth); + + // Add the parent node even if it is empty for proper upward updates + next_leaves.push(SubtreeLeaf { + col: parent_index.value(), + hash: combined_hash, + }); + + node_mutations.insert( + parent_index, + if combined_hash != empty_hash { + NodeMutation::Addition(combined_node) + } else { + NodeMutation::Removal + }, + ); + } + drop(iter); + leaves = mem::take(&mut next_leaves); + } + + debug_assert_eq!(leaves.len(), 1); + let root_leaf = leaves.pop().unwrap(); + (node_mutations, root_leaf) + } + + /// Constructs an `InnerNode` representing the sibling pair of which `first_leaf` is a part: + /// - If `first_leaf` is a right child, the left child is copied from the `parent_node`. + /// - If `first_leaf` is a left child, the right child is taken from `iter` if it was also + /// mutated or copied from the `parent_node`. 
+ /// + /// Returns the `InnerNode` containing the hashes of the sibling pair. + fn fetch_sibling_pair( + iter: &mut core::iter::Peekable>, + first_leaf: SubtreeLeaf, + parent_node: InnerNode, + ) -> InnerNode { + let is_right_node = first_leaf.col.is_odd(); + + if is_right_node { + let left_leaf = SubtreeLeaf { + col: first_leaf.col - 1, + hash: parent_node.left, + }; + InnerNode { + left: left_leaf.hash, + right: first_leaf.hash, + } + } else { + let right_col = first_leaf.col + 1; + let right_leaf = match iter.peek().copied() { + Some(SubtreeLeaf { col, .. }) if col == right_col => iter.next().unwrap(), + _ => SubtreeLeaf { col: right_col, hash: parent_node.right }, + }; + InnerNode { + left: first_leaf.hash, + right: right_leaf.hash, + } + } + } + + /// Processes sorted key-value pairs to compute leaves for a subtree. + /// + /// This function groups key-value pairs by their corresponding column index and processes each + /// group to construct leaves. The actual construction of the leaf is delegated to the + /// `process_leaf` callback, allowing flexibility for different use cases (e.g., creating + /// new leaves or mutating existing ones). + /// + /// # Parameters + /// - `pairs`: A vector of sorted key-value pairs. The pairs *must* be sorted by leaf index + /// column (not simply by key). If the input is not sorted correctly, the function will + /// produce incorrect results and may panic in debug mode. + /// - `process_leaf`: A callback function used to process each group of key-value pairs + /// corresponding to the same column index. The callback takes a vector of key-value pairs for + /// a single column and returns the constructed leaf for that column. + /// + /// # Returns + /// A `PairComputations` containing: + /// - `nodes`: A mapping of column indices to the constructed leaves. + /// - `leaves`: A collection of `SubtreeLeaf` structures representing the processed leaves. 
Each + /// `SubtreeLeaf` includes the column index and the hash of the corresponding leaf. + /// + /// # Panics + /// This function will panic in debug mode if the input `pairs` are not sorted by column index. + fn process_sorted_pairs_to_leaves( + pairs: Vec<(RpoDigest, Word)>, + mut process_leaf: F, + ) -> PairComputations + where + F: FnMut(Vec<(RpoDigest, Word)>) -> SmtLeaf, + { + use rayon::prelude::*; + debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value())); + + let mut accumulator: PairComputations = Default::default(); + + // As we iterate, we'll keep track of the kv-pairs we've seen so far that correspond to a + // single leaf. When we see a pair that's in a different leaf, we'll swap these pairs + // out and store them in our accumulated leaves. + let mut current_leaf_buffer: Vec<(RpoDigest, Word)> = Default::default(); + + let mut iter = pairs.into_iter().peekable(); + while let Some((key, value)) = iter.next() { + let col = Self::key_to_leaf_index(&key).index.value(); + let peeked_col = iter.peek().map(|(key, _v)| { + let index = Self::key_to_leaf_index(key); + let next_col = index.index.value(); + // We panic if `pairs` is not sorted by column. + debug_assert!(next_col >= col); + next_col + }); + current_leaf_buffer.push((key, value)); + + // If the next pair is the same column as this one, then we're done after adding this + // pair to the buffer. + if peeked_col == Some(col) { + continue; + } + + // Otherwise, the next pair is a different column, or there is no next pair. Either way + // it's time to swap out our buffer. 
+ let leaf_pairs = mem::take(&mut current_leaf_buffer); + let leaf = process_leaf(leaf_pairs); + + accumulator.nodes.insert(col, leaf); + + debug_assert!(current_leaf_buffer.is_empty()); + } + + // Compute the leaves from the nodes concurrently + let mut accumulated_leaves: Vec = accumulator + .nodes + .clone() + .into_par_iter() + .map(|(col, leaf)| SubtreeLeaf { col, hash: Self::hash_leaf(&leaf) }) + .collect(); + + // Sort the leaves by column + accumulated_leaves.par_sort_by_key(|leaf| leaf.col); + + // TODO: determine is there is any notable performance difference between computing + // subtree boundaries after the fact as an iterator adapter (like this), versus computing + // subtree boundaries as we go. Either way this function is only used at the beginning of a + // parallel construction, so it should not be a critical path. + accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect(); + accumulator + } + + /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs. + /// + /// `entries` need not be sorted. This function will sort them. + fn build_subtrees(mut entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) { + entries.sort_by_key(|item| { + let index = Self::key_to_leaf_index(&item.0); + index.value() + }); + Self::build_subtrees_from_sorted_entries(entries) + } + + /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs. + /// + /// This function is mostly an implementation detail of + /// [`Smt::with_entries_concurrent()`]. 
+ fn build_subtrees_from_sorted_entries(entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) { + use rayon::prelude::*; + + let mut accumulated_nodes: InnerNodes = Default::default(); + + let PairComputations { + leaves: mut leaf_subtrees, + nodes: initial_leaves, + } = Self::sorted_pairs_to_leaves(entries); + + for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + let (nodes, mut subtree_roots): (Vec>, Vec) = + leaf_subtrees + .into_par_iter() + .map(|subtree| { + debug_assert!(subtree.is_sorted()); + debug_assert!(!subtree.is_empty()); + let (nodes, subtree_root) = + build_subtree(subtree, SMT_DEPTH, current_depth); + (nodes, subtree_root) + }) + .unzip(); + + leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); + accumulated_nodes.extend(nodes.into_iter().flatten()); + + debug_assert!(!leaf_subtrees.is_empty()); + } + (accumulated_nodes, initial_leaves) + } +} + +// SUBTREES +// ================================================================================================ + +/// A subtree is of depth 8. +const SUBTREE_DEPTH: u8 = 8; + +/// A depth-8 subtree contains 256 "columns" that can possibly be occupied. +const COLS_PER_SUBTREE: u64 = u64::pow(2, SUBTREE_DEPTH as u32); + +/// Helper struct for organizing the data we care about when computing Merkle subtrees. +/// +/// Note that these represet "conceptual" leaves of some subtree, not necessarily +/// the leaf type for the sparse Merkle tree. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] +pub struct SubtreeLeaf { + /// The 'value' field of [`NodeIndex`]. When computing a subtree, the depth is already known. + pub col: u64, + /// The hash of the node this `SubtreeLeaf` represents. + pub hash: RpoDigest, +} + +/// Helper struct to organize the return value of [`Smt::sorted_pairs_to_leaves()`]. 
+#[derive(Debug, Clone)] +pub(crate) struct PairComputations { + /// Literal leaves to be added to the sparse Merkle tree's internal mapping. + pub nodes: UnorderedMap, + /// "Conceptual" leaves that will be used for computations. + pub leaves: Vec>, +} + +// Derive requires `L` to impl Default, even though we don't actually need that. +impl Default for PairComputations { + fn default() -> Self { + Self { + nodes: Default::default(), + leaves: Default::default(), + } + } +} + +#[derive(Debug)] +pub(crate) struct SubtreeLeavesIter<'s> { + leaves: core::iter::Peekable>, +} +impl<'s> SubtreeLeavesIter<'s> { + fn from_leaves(leaves: &'s mut Vec) -> Self { + // TODO: determine if there is any notable performance difference between taking a Vec, + // which many need flattening first, vs storing a `Box>`. + // The latter may have self-referential properties that are impossible to express in purely + // safe Rust Rust. + Self { leaves: leaves.drain(..).peekable() } + } +} +impl Iterator for SubtreeLeavesIter<'_> { + type Item = Vec; + + /// Each `next()` collects an entire subtree. 
+ fn next(&mut self) -> Option> { + let mut subtree: Vec = Default::default(); + + let mut last_subtree_col = 0; + + while let Some(leaf) = self.leaves.peek() { + last_subtree_col = u64::max(1, last_subtree_col); + let is_exact_multiple = Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE); + let next_subtree_col = if is_exact_multiple { + u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE) + } else { + last_subtree_col.next_multiple_of(COLS_PER_SUBTREE) + }; + + last_subtree_col = leaf.col; + if leaf.col < next_subtree_col { + subtree.push(self.leaves.next().unwrap()); + } else if subtree.is_empty() { + continue; + } else { + break; + } + } + + if subtree.is_empty() { + debug_assert!(self.leaves.peek().is_none()); + return None; + } + + Some(subtree) + } +} + +// HELPER FUNCTIONS +// ================================================================================================ + +/// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and +/// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and +/// `leaves` must not contain more than one depth-8 subtree's worth of leaves. +/// +/// This function will then calculate the inner nodes above each leaf for 8 layers, as well as +/// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into +/// itself. +/// +/// # Panics +/// With debug assertions on, this function panics under invalid inputs: if `leaves` contains +/// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to +/// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified +/// maximum depth (`DEPTH`), or if `leaves` is not sorted. 
+#[cfg(feature = "concurrent")] +pub(crate) fn build_subtree( + mut leaves: Vec, + tree_depth: u8, + bottom_depth: u8, +) -> (UnorderedMap, SubtreeLeaf) { + debug_assert!(bottom_depth <= tree_depth); + debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH)); + debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32)); + let subtree_root = bottom_depth - SUBTREE_DEPTH; + let mut inner_nodes: UnorderedMap = Default::default(); + let mut next_leaves: Vec = Vec::with_capacity(leaves.len() / 2); + for next_depth in (subtree_root..bottom_depth).rev() { + debug_assert!(next_depth <= bottom_depth); + // `next_depth` is the stuff we're making. + // `current_depth` is the stuff we have. + let current_depth = next_depth + 1; + let mut iter = leaves.drain(..).peekable(); + while let Some(first) = iter.next() { + // On non-continuous iterations, including the first iteration, `first_column` may + // be a left or right node. On subsequent continuous iterations, we will always call + // `iter.next()` twice. + // On non-continuous iterations (including the very first iteration), this column + // could be either on the left or the right. If the next iteration is not + // discontinuous with our right node, then the next iteration's + let is_right = first.col.is_odd(); + let (left, right) = if is_right { + // Discontinuous iteration: we have no left node, so it must be empty. + let left = SubtreeLeaf { + col: first.col - 1, + hash: *EmptySubtreeRoots::entry(tree_depth, current_depth), + }; + let right = first; + (left, right) + } else { + let left = first; + let right_col = first.col + 1; + let right = match iter.peek().copied() { + Some(SubtreeLeaf { col, .. }) if col == right_col => { + // Our inputs must be sorted. + debug_assert!(left.col <= col); + // The next leaf in the iterator is our sibling. Use it and consume it! + iter.next().unwrap() + }, + // Otherwise, the leaves don't contain our sibling, so our sibling must be + // empty. 
+ _ => SubtreeLeaf { + col: right_col, + hash: *EmptySubtreeRoots::entry(tree_depth, current_depth), + }, + }; + (left, right) + }; + let index = NodeIndex::new_unchecked(current_depth, left.col).parent(); + let node = InnerNode { left: left.hash, right: right.hash }; + let hash = node.hash(); + let &equivalent_empty_hash = EmptySubtreeRoots::entry(tree_depth, next_depth); + // If this hash is empty, then it doesn't become a new inner node, nor does it count + // as a leaf for the next depth. + if hash != equivalent_empty_hash { + inner_nodes.insert(index, node); + next_leaves.push(SubtreeLeaf { col: index.value(), hash }); + } + } + // Stop borrowing `leaves`, so we can swap it. + // The iterator is empty at this point anyway. + drop(iter); + // After each depth, consider the stuff we just made the new "leaves", and empty the + // other collection. + mem::swap(&mut leaves, &mut next_leaves); + } + debug_assert_eq!(leaves.len(), 1); + let root = leaves.pop().unwrap(); + (inner_nodes, root) +} + +#[cfg(feature = "internal")] +pub fn build_subtree_for_bench( + leaves: Vec, + tree_depth: u8, + bottom_depth: u8, +) -> (UnorderedMap, SubtreeLeaf) { + build_subtree(leaves, tree_depth, bottom_depth) +} diff --git a/src/merkle/smt/full/concurrent/tests.rs b/src/merkle/smt/full/concurrent/tests.rs new file mode 100644 index 00000000..c000a245 --- /dev/null +++ b/src/merkle/smt/full/concurrent/tests.rs @@ -0,0 +1,446 @@ +use alloc::{ + collections::{BTreeMap, BTreeSet}, + vec::Vec, +}; + +use rand::{prelude::IteratorRandom, thread_rng, Rng}; + +use super::{ + build_subtree, InnerNode, LeafIndex, NodeIndex, NodeMutations, PairComputations, RpoDigest, + Smt, SmtLeaf, SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter, UnorderedMap, COLS_PER_SUBTREE, + SMT_DEPTH, SUBTREE_DEPTH, +}; +use crate::{merkle::smt::Felt, Word, EMPTY_WORD, ONE}; + +fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf { + SubtreeLeaf { + col: leaf.index().index.value(), + hash: leaf.hash(), + } +} + 
+#[test] +fn test_sorted_pairs_to_leaves() { + let entries: Vec<(RpoDigest, Word)> = vec![ + // Subtree 0. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]), + (RpoDigest::new([ONE, ONE, ONE, Felt::new(17)]), [ONE; 4]), + // Leaf index collision. + (RpoDigest::new([ONE, ONE, Felt::new(10), Felt::new(20)]), [ONE; 4]), + (RpoDigest::new([ONE, ONE, Felt::new(20), Felt::new(20)]), [ONE; 4]), + // Subtree 1. Normal single leaf again. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]), // Subtree boundary. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(401)]), [ONE; 4]), + // Subtree 2. Another normal leaf. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(1024)]), [ONE; 4]), + ]; + let control = Smt::with_entries_sequential(entries.clone()).unwrap(); + let control_leaves: Vec = { + let mut entries_iter = entries.iter().cloned(); + let mut next_entry = || entries_iter.next().unwrap(); + let control_leaves = vec![ + // Subtree 0. + SmtLeaf::Single(next_entry()), + SmtLeaf::Single(next_entry()), + SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(), + // Subtree 1. + SmtLeaf::Single(next_entry()), + SmtLeaf::Single(next_entry()), + // Subtree 2. + SmtLeaf::Single(next_entry()), + ]; + assert_eq!(entries_iter.next(), None); + control_leaves + }; + let control_subtree_leaves: Vec> = { + let mut control_leaves_iter = control_leaves.iter(); + let mut next_leaf = || control_leaves_iter.next().unwrap(); + let control_subtree_leaves: Vec> = [ + // Subtree 0. + vec![next_leaf(), next_leaf(), next_leaf()], + // Subtree 1. + vec![next_leaf(), next_leaf()], + // Subtree 2. + vec![next_leaf()], + ] + .map(|subtree| subtree.into_iter().map(smtleaf_to_subtree_leaf).collect()) + .to_vec(); + assert_eq!(control_leaves_iter.next(), None); + control_subtree_leaves + }; + let subtrees: PairComputations = Smt::sorted_pairs_to_leaves(entries); + // This will check that the hashes, columns, and subtree assignments all match. 
+ assert_eq!(subtrees.leaves, control_subtree_leaves); + // Flattening and re-separating out the leaves into subtrees should have the same result. + let mut all_leaves: Vec = subtrees.leaves.clone().into_iter().flatten().collect(); + let re_grouped: Vec> = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect(); + assert_eq!(subtrees.leaves, re_grouped); + // Then finally we might as well check the computed leaf nodes too. + let control_leaves: BTreeMap = control + .leaves() + .map(|(index, value)| (index.index.value(), value.clone())) + .collect(); + for (column, test_leaf) in subtrees.nodes { + if test_leaf.is_empty() { + continue; + } + let control_leaf = control_leaves + .get(&column) + .unwrap_or_else(|| panic!("no leaf node found for column {column}")); + assert_eq!(control_leaf, &test_leaf); + } +} +// Helper for the below tests. +fn generate_entries(pair_count: u64) -> Vec<(RpoDigest, Word)> { + (0..pair_count) + .map(|i| { + let leaf_index = ((i as f64 / pair_count as f64) * (pair_count as f64)) as u64; + let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]); + let value = [ONE, ONE, ONE, Felt::new(i)]; + (key, value) + }) + .collect() +} +fn generate_updates(entries: Vec<(RpoDigest, Word)>, updates: usize) -> Vec<(RpoDigest, Word)> { + const REMOVAL_PROBABILITY: f64 = 0.2; + let mut rng = thread_rng(); + // Assertion to ensure input keys are unique + assert!( + entries.iter().map(|(key, _)| key).collect::>().len() == entries.len(), + "Input entries contain duplicate keys!" 
+ ); + let mut sorted_entries: Vec<(RpoDigest, Word)> = entries + .into_iter() + .choose_multiple(&mut rng, updates) + .into_iter() + .map(|(key, _)| { + let value = if rng.gen_bool(REMOVAL_PROBABILITY) { + EMPTY_WORD + } else { + [ONE, ONE, ONE, Felt::new(rng.gen())] + }; + (key, value) + }) + .collect(); + sorted_entries.sort_by_key(|(key, _)| Smt::key_to_leaf_index(key).value()); + sorted_entries +} +#[test] +fn test_single_subtree() { + // A single subtree's worth of leaves. + const PAIR_COUNT: u64 = COLS_PER_SUBTREE; + let entries = generate_entries(PAIR_COUNT); + let control = Smt::with_entries_sequential(entries.clone()).unwrap(); + // `entries` should already be sorted by nature of how we constructed it. + let leaves = Smt::sorted_pairs_to_leaves(entries).leaves; + let leaves = leaves.into_iter().next().unwrap(); + let (first_subtree, subtree_root) = build_subtree(leaves, SMT_DEPTH, SMT_DEPTH); + assert!(!first_subtree.is_empty()); + // The inner nodes computed from that subtree should match the nodes in our control tree. + for (index, node) in first_subtree.into_iter() { + let control = control.get_inner_node(index); + assert_eq!( + control, node, + "subtree-computed node at index {index:?} does not match control", + ); + } + // The root returned should also match the equivalent node in the control tree. + let control_root_index = + NodeIndex::new(SMT_DEPTH - SUBTREE_DEPTH, subtree_root.col).expect("Valid root index"); + let control_root_node = control.get_inner_node(control_root_index); + let control_hash = control_root_node.hash(); + assert_eq!( + control_hash, subtree_root.hash, + "Subtree-computed root at index {control_root_index:?} does not match control" + ); +} +// Test that not just can we compute a subtree correctly, but we can feed the results of one +// subtree into computing another. In other words, test that `build_subtree()` is correctly +// composable. +#[test] +fn test_two_subtrees() { + // Two subtrees' worth of leaves. 
+ const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 2; + let entries = generate_entries(PAIR_COUNT); + let control = Smt::with_entries_sequential(entries.clone()).unwrap(); + let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries); + // With two subtrees' worth of leaves, we should have exactly two subtrees. + let [first, second]: [Vec<_>; 2] = leaves.try_into().unwrap(); + assert_eq!(first.len() as u64, PAIR_COUNT / 2); + assert_eq!(first.len(), second.len()); + let mut current_depth = SMT_DEPTH; + let mut next_leaves: Vec = Default::default(); + let (first_nodes, first_root) = build_subtree(first, SMT_DEPTH, current_depth); + next_leaves.push(first_root); + let (second_nodes, second_root) = build_subtree(second, SMT_DEPTH, current_depth); + next_leaves.push(second_root); + // All new inner nodes + the new subtree-leaves should be 512, for one depth-cycle. + let total_computed = first_nodes.len() + second_nodes.len() + next_leaves.len(); + assert_eq!(total_computed as u64, PAIR_COUNT); + // Verify the computed nodes of both subtrees. 
+ let computed_nodes = first_nodes.clone().into_iter().chain(second_nodes); + for (index, test_node) in computed_nodes { + let control_node = control.get_inner_node(index); + assert_eq!( + control_node, test_node, + "subtree-computed node at index {index:?} does not match control", + ); + } + current_depth -= SUBTREE_DEPTH; + let (nodes, root_leaf) = build_subtree(next_leaves, SMT_DEPTH, current_depth); + assert_eq!(nodes.len(), SUBTREE_DEPTH as usize); + assert_eq!(root_leaf.col, 0); + for (index, test_node) in nodes { + let control_node = control.get_inner_node(index); + assert_eq!( + control_node, test_node, + "subtree-computed node at index {index:?} does not match control", + ); + } + let index = NodeIndex::new(current_depth - SUBTREE_DEPTH, root_leaf.col).unwrap(); + let control_root = control.get_inner_node(index).hash(); + assert_eq!(control_root, root_leaf.hash, "Root mismatch"); +} +#[test] +fn test_singlethreaded_subtrees() { + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + let entries = generate_entries(PAIR_COUNT); + let control = Smt::with_entries_sequential(entries.clone()).unwrap(); + let mut accumulated_nodes: BTreeMap = Default::default(); + let PairComputations { + leaves: mut leaf_subtrees, + nodes: test_leaves, + } = Smt::sorted_pairs_to_leaves(entries); + for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + // There's no flat_map_unzip(), so this is the best we can do. + let (nodes, mut subtree_roots): (Vec>, Vec) = leaf_subtrees + .into_iter() + .enumerate() + .map(|(i, subtree)| { + // Pre-assertions. + assert!( + subtree.is_sorted(), + "subtree {i} at bottom-depth {current_depth} is not sorted", + ); + assert!( + !subtree.is_empty(), + "subtree {i} at bottom-depth {current_depth} is empty!", + ); + // Do actual things. + let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth); + // Post-assertions. 
+ for (&index, test_node) in nodes.iter() { + let control_node = control.get_inner_node(index); + assert_eq!( + test_node, &control_node, + "depth {} subtree {}: test node does not match control at index {:?}", + current_depth, i, index, + ); + } + (nodes, subtree_root) + }) + .unzip(); + // Update state between each depth iteration. + leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); + accumulated_nodes.extend(nodes.into_iter().flatten()); + assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); + } + // Make sure the true leaves match, first checking length and then checking each individual + // leaf. + let control_leaves: BTreeMap<_, _> = control.leaves().collect(); + let control_leaves_len = control_leaves.len(); + let test_leaves_len = test_leaves.len(); + assert_eq!(test_leaves_len, control_leaves_len); + for (col, ref test_leaf) in test_leaves { + let index = LeafIndex::new_max_depth(col); + let &control_leaf = control_leaves.get(&index).unwrap(); + assert_eq!(test_leaf, control_leaf, "test leaf at column {col} does not match control"); + } + // Make sure the inner nodes match, checking length first and then each individual leaf. + let control_nodes_len = control.inner_nodes().count(); + let test_nodes_len = accumulated_nodes.len(); + assert_eq!(test_nodes_len, control_nodes_len); + for (index, test_node) in accumulated_nodes.clone() { + let control_node = control.get_inner_node(index); + assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); + } + // After the last iteration of the above for loop, we should have the new root node actually + // in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from + // `build_subtree()`. So let's check both! + let control_root = control.get_inner_node(NodeIndex::root()); + // That for loop should have left us with only one leaf subtree... 
+ let [leaf_subtree]: [Vec<_>; 1] = leaf_subtrees.try_into().unwrap(); + // which itself contains only one 'leaf'... + let [root_leaf]: [SubtreeLeaf; 1] = leaf_subtree.try_into().unwrap(); + // which matches the expected root. + assert_eq!(control.root(), root_leaf.hash); + // Likewise `accumulated_nodes` should contain a node at the root index... + assert!(accumulated_nodes.contains_key(&NodeIndex::root())); + // and it should match our actual root. + let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap(); + assert_eq!(control_root, *test_root); + // And of course the root we got from each place should match. + assert_eq!(control.root(), root_leaf.hash); +} +/// The parallel version of `test_singlethreaded_subtree()`. +#[test] +fn test_multithreaded_subtrees() { + use rayon::prelude::*; + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + let entries = generate_entries(PAIR_COUNT); + let control = Smt::with_entries_sequential(entries.clone()).unwrap(); + let mut accumulated_nodes: BTreeMap = Default::default(); + let PairComputations { + leaves: mut leaf_subtrees, + nodes: test_leaves, + } = Smt::sorted_pairs_to_leaves(entries); + for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + let (nodes, mut subtree_roots): (Vec>, Vec) = leaf_subtrees + .into_par_iter() + .enumerate() + .map(|(i, subtree)| { + // Pre-assertions. + assert!( + subtree.is_sorted(), + "subtree {i} at bottom-depth {current_depth} is not sorted", + ); + assert!( + !subtree.is_empty(), + "subtree {i} at bottom-depth {current_depth} is empty!", + ); + let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth); + // Post-assertions. 
+ for (&index, test_node) in nodes.iter() { + let control_node = control.get_inner_node(index); + assert_eq!( + test_node, &control_node, + "depth {} subtree {}: test node does not match control at index {:?}", + current_depth, i, index, + ); + } + (nodes, subtree_root) + }) + .unzip(); + leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); + accumulated_nodes.extend(nodes.into_iter().flatten()); + assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); + } + // Make sure the true leaves match, checking length first and then each individual leaf. + let control_leaves: BTreeMap<_, _> = control.leaves().collect(); + let control_leaves_len = control_leaves.len(); + let test_leaves_len = test_leaves.len(); + assert_eq!(test_leaves_len, control_leaves_len); + for (col, ref test_leaf) in test_leaves { + let index = LeafIndex::new_max_depth(col); + let &control_leaf = control_leaves.get(&index).unwrap(); + assert_eq!(test_leaf, control_leaf); + } + // Make sure the inner nodes match, checking length first and then each individual leaf. + let control_nodes_len = control.inner_nodes().count(); + let test_nodes_len = accumulated_nodes.len(); + assert_eq!(test_nodes_len, control_nodes_len); + for (index, test_node) in accumulated_nodes.clone() { + let control_node = control.get_inner_node(index); + assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); + } + // After the last iteration of the above for loop, we should have the new root node actually + // in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from + // `build_subtree()`. So let's check both! + let control_root = control.get_inner_node(NodeIndex::root()); + // That for loop should have left us with only one leaf subtree... + let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap(); + // which itself contains only one 'leaf'... 
+ let [root_leaf]: [_; 1] = leaf_subtree.try_into().unwrap(); + // which matches the expected root. + assert_eq!(control.root(), root_leaf.hash); + // Likewise `accumulated_nodes` should contain a node at the root index... + assert!(accumulated_nodes.contains_key(&NodeIndex::root())); + // and it should match our actual root. + let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap(); + assert_eq!(control_root, *test_root); + // And of course the root we got from each place should match. + assert_eq!(control.root(), root_leaf.hash); +} +#[test] +fn test_with_entries_concurrent() { + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + let entries = generate_entries(PAIR_COUNT); + let control = Smt::with_entries_sequential(entries.clone()).unwrap(); + let smt = Smt::with_entries(entries.clone()).unwrap(); + assert_eq!(smt.root(), control.root()); + assert_eq!(smt, control); +} +/// Concurrent mutations +#[test] +fn test_singlethreaded_subtree_mutations() { + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + let entries = generate_entries(PAIR_COUNT); + let updates = generate_updates(entries.clone(), 1000); + let tree = Smt::with_entries_sequential(entries.clone()).unwrap(); + let control = tree.compute_mutations_sequential(updates.clone()); + let mut node_mutations = NodeMutations::default(); + let (mut subtree_leaves, new_pairs) = tree.sorted_pairs_to_mutated_subtree_leaves(updates); + for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + // There's no flat_map_unzip(), so this is the best we can do. + let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves + .into_iter() + .enumerate() + .map(|(i, subtree)| { + // Pre-assertions. + assert!( + subtree.is_sorted(), + "subtree {i} at bottom-depth {current_depth} is not sorted", + ); + assert!( + !subtree.is_empty(), + "subtree {i} at bottom-depth {current_depth} is empty!", + ); + // Calculate the mutations for this subtree. 
+ let (mutations_per_subtree, subtree_root) = + tree.build_subtree_mutations(subtree, SMT_DEPTH, current_depth); + // Check that the mutations match the control tree. + for (&index, mutation) in mutations_per_subtree.iter() { + let control_mutation = control.node_mutations().get(&index).unwrap(); + assert_eq!( + control_mutation, mutation, + "depth {} subtree {}: mutation does not match control at index {:?}", + current_depth, i, index, + ); + } + (mutations_per_subtree, subtree_root) + }) + .unzip(); + subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); + node_mutations.extend(mutations_per_subtree.into_iter().flatten()); + assert!(!subtree_leaves.is_empty(), "on depth {current_depth}"); + } + let [subtree]: [Vec<_>; 1] = subtree_leaves.try_into().unwrap(); + let [root_leaf]: [SubtreeLeaf; 1] = subtree.try_into().unwrap(); + // Check that the new root matches the control. + assert_eq!(control.new_root, root_leaf.hash); + // Check that the node mutations match the control. 
+ assert_eq!(control.node_mutations().len(), node_mutations.len()); + for (&index, mutation) in control.node_mutations().iter() { + let test_mutation = node_mutations.get(&index).unwrap(); + assert_eq!(test_mutation, mutation); + } + // Check that the new pairs match the control + assert_eq!(control.new_pairs.len(), new_pairs.len()); + for (&key, &value) in control.new_pairs.iter() { + let test_value = new_pairs.get(&key).unwrap(); + assert_eq!(test_value, &value); + } +} +#[test] +fn test_compute_mutations_parallel() { + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + let entries = generate_entries(PAIR_COUNT); + let tree = Smt::with_entries(entries.clone()).unwrap(); + let updates = generate_updates(entries, 1000); + let control = tree.compute_mutations_sequential(updates.clone()); + let mutations = tree.compute_mutations(updates); + assert_eq!(mutations.root(), control.root()); + assert_eq!(mutations.old_root(), control.old_root()); + assert_eq!(mutations.node_mutations(), control.node_mutations()); + assert_eq!(mutations.new_pairs(), control.new_pairs()); +} diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 04ed7124..35ed15b8 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -1,12 +1,8 @@ -use alloc::{collections::BTreeSet, string::ToString, vec::Vec}; -use core::mem; - -use num::Integer; +use alloc::{string::ToString, vec::Vec}; use super::{ EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex, MerkleError, - MerklePath, MutationSet, NodeIndex, NodeMutation, NodeMutations, Rpo256, RpoDigest, - SparseMerkleTree, UnorderedMap, Word, EMPTY_WORD, + MerklePath, MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, }; mod error; @@ -31,7 +27,6 @@ pub const SMT_DEPTH: u8 = 64; // ================================================================================================ type Leaves = super::Leaves; -type MutatedSubtreeLeaves = Vec>; /// Sparse Merkle tree mapping 
256-bit keys to 256-bit values. Both keys and values are represented /// by 4 field elements. @@ -101,9 +96,12 @@ impl Smt { /// /// # Errors /// Returns an error if the provided entries contain multiple values for the same key. - pub fn with_entries_sequential( + #[cfg(any(not(feature = "concurrent"), test))] + fn with_entries_sequential( entries: impl IntoIterator, ) -> Result { + use alloc::collections::BTreeSet; + // create an empty tree let mut tree = Self::new(); @@ -321,379 +319,11 @@ impl Smt { // Concurrent implementation #[cfg(feature = "concurrent")] -impl Smt { - /// Parallel implementation of [`Smt::with_entries()`]. - /// - /// This method constructs a new sparse Merkle tree concurrently by processing subtrees in - /// parallel, working from the bottom up. The process works as follows: - /// - /// 1. First, the input key-value pairs are sorted and grouped into subtrees based on their leaf - /// indices. Each subtree covers a range of 256 (2^8) possible leaf positions. - /// - /// 2. The subtrees are then processed in parallel: - /// - For each subtree, compute the inner nodes from depth D down to depth D-8. - /// - Each subtree computation yields a new subtree root and its associated inner nodes. - /// - /// 3. These subtree roots are recursively merged to become the "leaves" for the next iteration, - /// which processes the next 8 levels up. This continues until the final root of the tree is - /// computed at depth 0. 
- pub fn with_entries_concurrent( - entries: impl IntoIterator, - ) -> Result { - let mut seen_keys = BTreeSet::new(); - let entries: Vec<_> = entries - .into_iter() - .map(|(key, value)| { - if seen_keys.insert(key) { - Ok((key, value)) - } else { - Err(MerkleError::DuplicateValuesForIndex( - LeafIndex::::from(key).value(), - )) - } - }) - .collect::>()?; - if entries.is_empty() { - return Ok(Self::default()); - } - let (inner_nodes, leaves) = Self::build_subtrees(entries); - let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash(); - >::from_raw_parts(inner_nodes, leaves, root) - } - - /// Parallel implementation of [`Smt::compute_mutations()`]. - /// - /// This method computes mutations by recursively processing subtrees in parallel, working from - /// the bottom up. The process works as follows: - /// - /// 1. First, the input key-value pairs are sorted and grouped into subtrees based on their leaf - /// indices. Each subtree covers a range of 256 (2^8) possible leaf positions. - /// - /// 2. The subtrees containing modifications are then processed in parallel: - /// - For each modified subtree, compute node mutations from depth D up to depth D-8 - /// - Each subtree computation yields a new root at depth D-8 and its associated mutations - /// - /// 3. These subtree roots become the "leaves" for the next iteration, which processes the next - /// 8 levels up. This continues until reaching the tree's root at depth 0. 
- pub fn compute_mutations_concurrent( - &self, - kv_pairs: impl IntoIterator, - ) -> MutationSet - where - Self: Sized + Sync, - { - use rayon::prelude::*; - - // Collect and sort key-value pairs by their corresponding leaf index - let mut sorted_kv_pairs: Vec<_> = kv_pairs.into_iter().collect(); - sorted_kv_pairs.par_sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value()); - - // Convert sorted pairs into mutated leaves and capture any new pairs - let (mut subtree_leaves, new_pairs) = - self.sorted_pairs_to_mutated_subtree_leaves(sorted_kv_pairs); - let mut node_mutations = NodeMutations::default(); - - // Process each depth level in reverse, stepping by the subtree depth - for depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { - // Parallel processing of each subtree to generate mutations and roots - let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves - .into_par_iter() - .map(|subtree| { - debug_assert!(subtree.is_sorted() && !subtree.is_empty()); - self.build_subtree_mutations(subtree, SMT_DEPTH, depth) - }) - .unzip(); - - // Prepare leaves for the next depth level - subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); - - // Aggregate all node mutations - node_mutations.extend(mutations_per_subtree.into_iter().flatten()); - - debug_assert!(!subtree_leaves.is_empty()); - } - - // Finalize the mutation set with updated roots and mutations - MutationSet { - old_root: self.root(), - new_root: subtree_leaves[0][0].hash, - node_mutations, - new_pairs, - } - } - - /// Performs the initial transforms for constructing a [`SparseMerkleTree`] by composing - /// subtrees. In other words, this function takes the key-value inputs to the tree, and produces - /// the inputs to feed into [`build_subtree()`]. - /// - /// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. 
If - /// `pairs` is not correctly sorted, the returned computations will be incorrect. - /// - /// # Panics - /// With debug assertions on, this function panics if it detects that `pairs` is not correctly - /// sorted. Without debug assertions, the returned computations will be incorrect. - fn sorted_pairs_to_leaves(pairs: Vec<(RpoDigest, Word)>) -> PairComputations { - Self::process_sorted_pairs_to_leaves(pairs, Self::pairs_to_leaf) - } - - /// Computes leaves from a set of key-value pairs and current leaf values. - /// Derived from `sorted_pairs_to_leaves` - fn sorted_pairs_to_mutated_subtree_leaves( - &self, - pairs: Vec<(RpoDigest, Word)>, - ) -> (MutatedSubtreeLeaves, UnorderedMap) { - // Map to track new key-value pairs for mutated leaves - let mut new_pairs = UnorderedMap::new(); - - let accumulator = Self::process_sorted_pairs_to_leaves(pairs, |leaf_pairs| { - let mut leaf = self.get_leaf(&leaf_pairs[0].0); - - for (key, value) in leaf_pairs { - // Check if the value has changed - let old_value = - new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key)); - - // Skip if the value hasn't changed - if value == old_value { - continue; - } - - // Otherwise, update the leaf and track the new key-value pair - leaf = self.construct_prospective_leaf(leaf, &key, &value); - new_pairs.insert(key, value); - } - - leaf - }); - (accumulator.leaves, new_pairs) - } - - /// Computes the node mutations and the root of a subtree - fn build_subtree_mutations( - &self, - mut leaves: Vec, - tree_depth: u8, - bottom_depth: u8, - ) -> (NodeMutations, SubtreeLeaf) - where - Self: Sized, - { - debug_assert!(bottom_depth <= tree_depth); - debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH)); - debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32)); - - let subtree_root_depth = bottom_depth - SUBTREE_DEPTH; - let mut node_mutations: NodeMutations = Default::default(); - let mut next_leaves: Vec = Vec::with_capacity(leaves.len() / 2); - - for 
current_depth in (subtree_root_depth..bottom_depth).rev() { - debug_assert!(current_depth <= bottom_depth); - - let next_depth = current_depth + 1; - let mut iter = leaves.drain(..).peekable(); - - while let Some(first_leaf) = iter.next() { - // This constructs a valid index because next_depth will never exceed the depth of - // the tree. - let parent_index = NodeIndex::new_unchecked(next_depth, first_leaf.col).parent(); - let parent_node = self.get_inner_node(parent_index); - let combined_node = Self::fetch_sibling_pair(&mut iter, first_leaf, parent_node); - let combined_hash = combined_node.hash(); - - let &empty_hash = EmptySubtreeRoots::entry(tree_depth, current_depth); - - // Add the parent node even if it is empty for proper upward updates - next_leaves.push(SubtreeLeaf { - col: parent_index.value(), - hash: combined_hash, - }); - - node_mutations.insert( - parent_index, - if combined_hash != empty_hash { - NodeMutation::Addition(combined_node) - } else { - NodeMutation::Removal - }, - ); - } - drop(iter); - leaves = mem::take(&mut next_leaves); - } - - debug_assert_eq!(leaves.len(), 1); - let root_leaf = leaves.pop().unwrap(); - (node_mutations, root_leaf) - } - - /// Constructs an `InnerNode` representing the sibling pair of which `first_leaf` is a part: - /// - If `first_leaf` is a right child, the left child is copied from the `parent_node`. - /// - If `first_leaf` is a left child, the right child is taken from `iter` if it was also - /// mutated or copied from the `parent_node`. - /// - /// Returns the `InnerNode` containing the hashes of the sibling pair. 
- fn fetch_sibling_pair( - iter: &mut core::iter::Peekable>, - first_leaf: SubtreeLeaf, - parent_node: InnerNode, - ) -> InnerNode { - let is_right_node = first_leaf.col.is_odd(); - - if is_right_node { - let left_leaf = SubtreeLeaf { - col: first_leaf.col - 1, - hash: parent_node.left, - }; - InnerNode { - left: left_leaf.hash, - right: first_leaf.hash, - } - } else { - let right_col = first_leaf.col + 1; - let right_leaf = match iter.peek().copied() { - Some(SubtreeLeaf { col, .. }) if col == right_col => iter.next().unwrap(), - _ => SubtreeLeaf { col: right_col, hash: parent_node.right }, - }; - InnerNode { - left: first_leaf.hash, - right: right_leaf.hash, - } - } - } - - /// Processes sorted key-value pairs to compute leaves for a subtree. - /// - /// This function groups key-value pairs by their corresponding column index and processes each - /// group to construct leaves. The actual construction of the leaf is delegated to the - /// `process_leaf` callback, allowing flexibility for different use cases (e.g., creating - /// new leaves or mutating existing ones). - /// - /// # Parameters - /// - `pairs`: A vector of sorted key-value pairs. The pairs *must* be sorted by leaf index - /// column (not simply by key). If the input is not sorted correctly, the function will - /// produce incorrect results and may panic in debug mode. - /// - `process_leaf`: A callback function used to process each group of key-value pairs - /// corresponding to the same column index. The callback takes a vector of key-value pairs for - /// a single column and returns the constructed leaf for that column. - /// - /// # Returns - /// A `PairComputations` containing: - /// - `nodes`: A mapping of column indices to the constructed leaves. - /// - `leaves`: A collection of `SubtreeLeaf` structures representing the processed leaves. Each - /// `SubtreeLeaf` includes the column index and the hash of the corresponding leaf. 
- /// - /// # Panics - /// This function will panic in debug mode if the input `pairs` are not sorted by column index. - fn process_sorted_pairs_to_leaves( - pairs: Vec<(RpoDigest, Word)>, - mut process_leaf: F, - ) -> PairComputations - where - F: FnMut(Vec<(RpoDigest, Word)>) -> SmtLeaf, - { - use rayon::prelude::*; - debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value())); - - let mut accumulator: PairComputations = Default::default(); - - // As we iterate, we'll keep track of the kv-pairs we've seen so far that correspond to a - // single leaf. When we see a pair that's in a different leaf, we'll swap these pairs - // out and store them in our accumulated leaves. - let mut current_leaf_buffer: Vec<(RpoDigest, Word)> = Default::default(); - - let mut iter = pairs.into_iter().peekable(); - while let Some((key, value)) = iter.next() { - let col = Self::key_to_leaf_index(&key).index.value(); - let peeked_col = iter.peek().map(|(key, _v)| { - let index = Self::key_to_leaf_index(key); - let next_col = index.index.value(); - // We panic if `pairs` is not sorted by column. - debug_assert!(next_col >= col); - next_col - }); - current_leaf_buffer.push((key, value)); - - // If the next pair is the same column as this one, then we're done after adding this - // pair to the buffer. - if peeked_col == Some(col) { - continue; - } - - // Otherwise, the next pair is a different column, or there is no next pair. Either way - // it's time to swap out our buffer. 
- let leaf_pairs = mem::take(&mut current_leaf_buffer); - let leaf = process_leaf(leaf_pairs); - - accumulator.nodes.insert(col, leaf); - - debug_assert!(current_leaf_buffer.is_empty()); - } - - // Compute the leaves from the nodes concurrently - let mut accumulated_leaves: Vec = accumulator - .nodes - .clone() - .into_par_iter() - .map(|(col, leaf)| SubtreeLeaf { col, hash: Self::hash_leaf(&leaf) }) - .collect(); - - // Sort the leaves by column - accumulated_leaves.par_sort_by_key(|leaf| leaf.col); - - // TODO: determine is there is any notable performance difference between computing - // subtree boundaries after the fact as an iterator adapter (like this), versus computing - // subtree boundaries as we go. Either way this function is only used at the beginning of a - // parallel construction, so it should not be a critical path. - accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect(); - accumulator - } - - /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs. - /// - /// `entries` need not be sorted. This function will sort them. - fn build_subtrees(mut entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) { - entries.sort_by_key(|item| { - let index = Self::key_to_leaf_index(&item.0); - index.value() - }); - Self::build_subtrees_from_sorted_entries(entries) - } - - /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs. - /// - /// This function is mostly an implementation detail of - /// [`Smt::with_entries_concurrent()`]. 
- fn build_subtrees_from_sorted_entries(entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) { - use rayon::prelude::*; - - let mut accumulated_nodes: InnerNodes = Default::default(); - - let PairComputations { - leaves: mut leaf_subtrees, - nodes: initial_leaves, - } = Self::sorted_pairs_to_leaves(entries); - - for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { - let (nodes, mut subtree_roots): (Vec>, Vec) = - leaf_subtrees - .into_par_iter() - .map(|subtree| { - debug_assert!(subtree.is_sorted()); - debug_assert!(!subtree.is_empty()); - let (nodes, subtree_root) = - build_subtree(subtree, SMT_DEPTH, current_depth); - (nodes, subtree_root) - }) - .unzip(); - - leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); - accumulated_nodes.extend(nodes.into_iter().flatten()); - - debug_assert!(!leaf_subtrees.is_empty()); - } - (accumulated_nodes, initial_leaves) - } -} +mod concurrent; +#[cfg(feature = "internal")] +pub use concurrent::build_subtree_for_bench; +#[cfg(feature = "internal")] +pub use concurrent::SubtreeLeaf; impl SparseMerkleTree for Smt { type Key = RpoDigest; @@ -914,194 +544,3 @@ fn test_smt_serialization_deserialization() { assert_eq!(smt, Smt::read_from_bytes(&bytes).unwrap()); assert_eq!(bytes.len(), smt.get_size_hint()); } - -// SUBTREES -// ================================================================================================ - -/// A subtree is of depth 8. -const SUBTREE_DEPTH: u8 = 8; - -/// A depth-8 subtree contains 256 "columns" that can possibly be occupied. -const COLS_PER_SUBTREE: u64 = u64::pow(2, SUBTREE_DEPTH as u32); - -/// Helper struct for organizing the data we care about when computing Merkle subtrees. -/// -/// Note that these represet "conceptual" leaves of some subtree, not necessarily -/// the leaf type for the sparse Merkle tree. 
-#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] -pub struct SubtreeLeaf { - /// The 'value' field of [`NodeIndex`]. When computing a subtree, the depth is already known. - pub col: u64, - /// The hash of the node this `SubtreeLeaf` represents. - pub hash: RpoDigest, -} - -/// Helper struct to organize the return value of [`Smt::sorted_pairs_to_leaves()`]. -#[derive(Debug, Clone)] -pub(crate) struct PairComputations { - /// Literal leaves to be added to the sparse Merkle tree's internal mapping. - pub nodes: UnorderedMap, - /// "Conceptual" leaves that will be used for computations. - pub leaves: Vec>, -} - -// Derive requires `L` to impl Default, even though we don't actually need that. -impl Default for PairComputations { - fn default() -> Self { - Self { - nodes: Default::default(), - leaves: Default::default(), - } - } -} - -#[derive(Debug)] -struct SubtreeLeavesIter<'s> { - leaves: core::iter::Peekable>, -} -impl<'s> SubtreeLeavesIter<'s> { - fn from_leaves(leaves: &'s mut Vec) -> Self { - // TODO: determine if there is any notable performance difference between taking a Vec, - // which many need flattening first, vs storing a `Box>`. - // The latter may have self-referential properties that are impossible to express in purely - // safe Rust Rust. - Self { leaves: leaves.drain(..).peekable() } - } -} -impl Iterator for SubtreeLeavesIter<'_> { - type Item = Vec; - - /// Each `next()` collects an entire subtree. 
- fn next(&mut self) -> Option> { - let mut subtree: Vec = Default::default(); - - let mut last_subtree_col = 0; - - while let Some(leaf) = self.leaves.peek() { - last_subtree_col = u64::max(1, last_subtree_col); - let is_exact_multiple = Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE); - let next_subtree_col = if is_exact_multiple { - u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE) - } else { - last_subtree_col.next_multiple_of(COLS_PER_SUBTREE) - }; - - last_subtree_col = leaf.col; - if leaf.col < next_subtree_col { - subtree.push(self.leaves.next().unwrap()); - } else if subtree.is_empty() { - continue; - } else { - break; - } - } - - if subtree.is_empty() { - debug_assert!(self.leaves.peek().is_none()); - return None; - } - - Some(subtree) - } -} - -// HELPER FUNCTIONS -// ================================================================================================ - -/// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and -/// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and -/// `leaves` must not contain more than one depth-8 subtree's worth of leaves. -/// -/// This function will then calculate the inner nodes above each leaf for 8 layers, as well as -/// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into -/// itself. -/// -/// # Panics -/// With debug assertions on, this function panics under invalid inputs: if `leaves` contains -/// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to -/// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified -/// maximum depth (`DEPTH`), or if `leaves` is not sorted. 
-#[cfg(feature = "concurrent")] -fn build_subtree( - mut leaves: Vec, - tree_depth: u8, - bottom_depth: u8, -) -> (UnorderedMap, SubtreeLeaf) { - debug_assert!(bottom_depth <= tree_depth); - debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH)); - debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32)); - let subtree_root = bottom_depth - SUBTREE_DEPTH; - let mut inner_nodes: UnorderedMap = Default::default(); - let mut next_leaves: Vec = Vec::with_capacity(leaves.len() / 2); - for next_depth in (subtree_root..bottom_depth).rev() { - debug_assert!(next_depth <= bottom_depth); - // `next_depth` is the stuff we're making. - // `current_depth` is the stuff we have. - let current_depth = next_depth + 1; - let mut iter = leaves.drain(..).peekable(); - while let Some(first) = iter.next() { - // On non-continuous iterations, including the first iteration, `first_column` may - // be a left or right node. On subsequent continuous iterations, we will always call - // `iter.next()` twice. - // On non-continuous iterations (including the very first iteration), this column - // could be either on the left or the right. If the next iteration is not - // discontinuous with our right node, then the next iteration's - let is_right = first.col.is_odd(); - let (left, right) = if is_right { - // Discontinuous iteration: we have no left node, so it must be empty. - let left = SubtreeLeaf { - col: first.col - 1, - hash: *EmptySubtreeRoots::entry(tree_depth, current_depth), - }; - let right = first; - (left, right) - } else { - let left = first; - let right_col = first.col + 1; - let right = match iter.peek().copied() { - Some(SubtreeLeaf { col, .. }) if col == right_col => { - // Our inputs must be sorted. - debug_assert!(left.col <= col); - // The next leaf in the iterator is our sibling. Use it and consume it! - iter.next().unwrap() - }, - // Otherwise, the leaves don't contain our sibling, so our sibling must be - // empty. 
- _ => SubtreeLeaf { - col: right_col, - hash: *EmptySubtreeRoots::entry(tree_depth, current_depth), - }, - }; - (left, right) - }; - let index = NodeIndex::new_unchecked(current_depth, left.col).parent(); - let node = InnerNode { left: left.hash, right: right.hash }; - let hash = node.hash(); - let &equivalent_empty_hash = EmptySubtreeRoots::entry(tree_depth, next_depth); - // If this hash is empty, then it doesn't become a new inner node, nor does it count - // as a leaf for the next depth. - if hash != equivalent_empty_hash { - inner_nodes.insert(index, node); - next_leaves.push(SubtreeLeaf { col: index.value(), hash }); - } - } - // Stop borrowing `leaves`, so we can swap it. - // The iterator is empty at this point anyway. - drop(iter); - // After each depth, consider the stuff we just made the new "leaves", and empty the - // other collection. - mem::swap(&mut leaves, &mut next_leaves); - } - debug_assert_eq!(leaves.len(), 1); - let root = leaves.pop().unwrap(); - (inner_nodes, root) -} - -#[cfg(feature = "internal")] -pub fn build_subtree_for_bench( - leaves: Vec, - tree_depth: u8, - bottom_depth: u8, -) -> (UnorderedMap, SubtreeLeaf) { - build_subtree(leaves, tree_depth, bottom_depth) -} diff --git a/src/merkle/smt/full/tests.rs b/src/merkle/smt/full/tests.rs index 0fe14076..787c01a8 100644 --- a/src/merkle/smt/full/tests.rs +++ b/src/merkle/smt/full/tests.rs @@ -722,557 +722,3 @@ fn apply_mutations( reversion } - -// CONCURRENT TESTS -// -------------------------------------------------------------------------------------------- - -#[cfg(feature = "concurrent")] -mod concurrent_tests { - use alloc::{ - collections::{BTreeMap, BTreeSet}, - vec::Vec, - }; - - use rand::{prelude::IteratorRandom, thread_rng, Rng}; - - use super::*; - use crate::{ - merkle::smt::{ - full::{ - build_subtree, PairComputations, SubtreeLeaf, SubtreeLeavesIter, COLS_PER_SUBTREE, - SUBTREE_DEPTH, - }, - InnerNode, NodeMutations, SparseMerkleTree, UnorderedMap, - }, - Word, ONE, - }; 
- - fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf { - SubtreeLeaf { - col: leaf.index().index.value(), - hash: leaf.hash(), - } - } - - #[test] - fn test_sorted_pairs_to_leaves() { - let entries: Vec<(RpoDigest, Word)> = vec![ - // Subtree 0. - (RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]), - (RpoDigest::new([ONE, ONE, ONE, Felt::new(17)]), [ONE; 4]), - // Leaf index collision. - (RpoDigest::new([ONE, ONE, Felt::new(10), Felt::new(20)]), [ONE; 4]), - (RpoDigest::new([ONE, ONE, Felt::new(20), Felt::new(20)]), [ONE; 4]), - // Subtree 1. Normal single leaf again. - (RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]), // Subtree boundary. - (RpoDigest::new([ONE, ONE, ONE, Felt::new(401)]), [ONE; 4]), - // Subtree 2. Another normal leaf. - (RpoDigest::new([ONE, ONE, ONE, Felt::new(1024)]), [ONE; 4]), - ]; - - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - - let control_leaves: Vec = { - let mut entries_iter = entries.iter().cloned(); - let mut next_entry = || entries_iter.next().unwrap(); - let control_leaves = vec![ - // Subtree 0. - SmtLeaf::Single(next_entry()), - SmtLeaf::Single(next_entry()), - SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(), - // Subtree 1. - SmtLeaf::Single(next_entry()), - SmtLeaf::Single(next_entry()), - // Subtree 2. - SmtLeaf::Single(next_entry()), - ]; - assert_eq!(entries_iter.next(), None); - control_leaves - }; - - let control_subtree_leaves: Vec> = { - let mut control_leaves_iter = control_leaves.iter(); - let mut next_leaf = || control_leaves_iter.next().unwrap(); - - let control_subtree_leaves: Vec> = [ - // Subtree 0. - vec![next_leaf(), next_leaf(), next_leaf()], - // Subtree 1. - vec![next_leaf(), next_leaf()], - // Subtree 2. 
- vec![next_leaf()], - ] - .map(|subtree| subtree.into_iter().map(smtleaf_to_subtree_leaf).collect()) - .to_vec(); - assert_eq!(control_leaves_iter.next(), None); - control_subtree_leaves - }; - - let subtrees: PairComputations = Smt::sorted_pairs_to_leaves(entries); - // This will check that the hashes, columns, and subtree assignments all match. - assert_eq!(subtrees.leaves, control_subtree_leaves); - - // Flattening and re-separating out the leaves into subtrees should have the same result. - let mut all_leaves: Vec = - subtrees.leaves.clone().into_iter().flatten().collect(); - let re_grouped: Vec> = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect(); - assert_eq!(subtrees.leaves, re_grouped); - - // Then finally we might as well check the computed leaf nodes too. - let control_leaves: BTreeMap = control - .leaves() - .map(|(index, value)| (index.index.value(), value.clone())) - .collect(); - - for (column, test_leaf) in subtrees.nodes { - if test_leaf.is_empty() { - continue; - } - let control_leaf = control_leaves - .get(&column) - .unwrap_or_else(|| panic!("no leaf node found for column {column}")); - assert_eq!(control_leaf, &test_leaf); - } - } - - // Helper for the below tests. - fn generate_entries(pair_count: u64) -> Vec<(RpoDigest, Word)> { - (0..pair_count) - .map(|i| { - let leaf_index = ((i as f64 / pair_count as f64) * (pair_count as f64)) as u64; - let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]); - let value = [ONE, ONE, ONE, Felt::new(i)]; - (key, value) - }) - .collect() - } - - fn generate_updates(entries: Vec<(RpoDigest, Word)>, updates: usize) -> Vec<(RpoDigest, Word)> { - const REMOVAL_PROBABILITY: f64 = 0.2; - let mut rng = thread_rng(); - - // Assertion to ensure input keys are unique - assert!( - entries.iter().map(|(key, _)| key).collect::>().len() == entries.len(), - "Input entries contain duplicate keys!" 
- ); - - let mut sorted_entries: Vec<(RpoDigest, Word)> = entries - .into_iter() - .choose_multiple(&mut rng, updates) - .into_iter() - .map(|(key, _)| { - let value = if rng.gen_bool(REMOVAL_PROBABILITY) { - EMPTY_WORD - } else { - [ONE, ONE, ONE, Felt::new(rng.gen())] - }; - - (key, value) - }) - .collect(); - sorted_entries.sort_by_key(|(key, _)| Smt::key_to_leaf_index(key).value()); - sorted_entries - } - - #[test] - fn test_single_subtree() { - // A single subtree's worth of leaves. - const PAIR_COUNT: u64 = COLS_PER_SUBTREE; - - let entries = generate_entries(PAIR_COUNT); - - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - - // `entries` should already be sorted by nature of how we constructed it. - let leaves = Smt::sorted_pairs_to_leaves(entries).leaves; - let leaves = leaves.into_iter().next().unwrap(); - - let (first_subtree, subtree_root) = build_subtree(leaves, SMT_DEPTH, SMT_DEPTH); - assert!(!first_subtree.is_empty()); - - // The inner nodes computed from that subtree should match the nodes in our control tree. - for (index, node) in first_subtree.into_iter() { - let control = control.get_inner_node(index); - assert_eq!( - control, node, - "subtree-computed node at index {index:?} does not match control", - ); - } - - // The root returned should also match the equivalent node in the control tree. - let control_root_index = - NodeIndex::new(SMT_DEPTH - SUBTREE_DEPTH, subtree_root.col).expect("Valid root index"); - let control_root_node = control.get_inner_node(control_root_index); - let control_hash = control_root_node.hash(); - assert_eq!( - control_hash, subtree_root.hash, - "Subtree-computed root at index {control_root_index:?} does not match control" - ); - } - - // Test that not just can we compute a subtree correctly, but we can feed the results of one - // subtree into computing another. In other words, test that `build_subtree()` is correctly - // composable. 
- #[test] - fn test_two_subtrees() { - // Two subtrees' worth of leaves. - const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 2; - - let entries = generate_entries(PAIR_COUNT); - - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - - let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries); - // With two subtrees' worth of leaves, we should have exactly two subtrees. - let [first, second]: [Vec<_>; 2] = leaves.try_into().unwrap(); - assert_eq!(first.len() as u64, PAIR_COUNT / 2); - assert_eq!(first.len(), second.len()); - - let mut current_depth = SMT_DEPTH; - let mut next_leaves: Vec = Default::default(); - - let (first_nodes, first_root) = build_subtree(first, SMT_DEPTH, current_depth); - next_leaves.push(first_root); - - let (second_nodes, second_root) = build_subtree(second, SMT_DEPTH, current_depth); - next_leaves.push(second_root); - - // All new inner nodes + the new subtree-leaves should be 512, for one depth-cycle. - let total_computed = first_nodes.len() + second_nodes.len() + next_leaves.len(); - assert_eq!(total_computed as u64, PAIR_COUNT); - - // Verify the computed nodes of both subtrees. 
- let computed_nodes = first_nodes.clone().into_iter().chain(second_nodes); - for (index, test_node) in computed_nodes { - let control_node = control.get_inner_node(index); - assert_eq!( - control_node, test_node, - "subtree-computed node at index {index:?} does not match control", - ); - } - - current_depth -= SUBTREE_DEPTH; - - let (nodes, root_leaf) = build_subtree(next_leaves, SMT_DEPTH, current_depth); - assert_eq!(nodes.len(), SUBTREE_DEPTH as usize); - assert_eq!(root_leaf.col, 0); - - for (index, test_node) in nodes { - let control_node = control.get_inner_node(index); - assert_eq!( - control_node, test_node, - "subtree-computed node at index {index:?} does not match control", - ); - } - - let index = NodeIndex::new(current_depth - SUBTREE_DEPTH, root_leaf.col).unwrap(); - let control_root = control.get_inner_node(index).hash(); - assert_eq!(control_root, root_leaf.hash, "Root mismatch"); - } - - #[test] - fn test_singlethreaded_subtrees() { - const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; - - let entries = generate_entries(PAIR_COUNT); - - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - - let mut accumulated_nodes: BTreeMap = Default::default(); - - let PairComputations { - leaves: mut leaf_subtrees, - nodes: test_leaves, - } = Smt::sorted_pairs_to_leaves(entries); - - for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { - // There's no flat_map_unzip(), so this is the best we can do. - let (nodes, mut subtree_roots): (Vec>, Vec) = - leaf_subtrees - .into_iter() - .enumerate() - .map(|(i, subtree)| { - // Pre-assertions. - assert!( - subtree.is_sorted(), - "subtree {i} at bottom-depth {current_depth} is not sorted", - ); - assert!( - !subtree.is_empty(), - "subtree {i} at bottom-depth {current_depth} is empty!", - ); - - // Do actual things. - let (nodes, subtree_root) = - build_subtree(subtree, SMT_DEPTH, current_depth); - - // Post-assertions. 
- for (&index, test_node) in nodes.iter() { - let control_node = control.get_inner_node(index); - assert_eq!( - test_node, &control_node, - "depth {} subtree {}: test node does not match control at index {:?}", - current_depth, i, index, - ); - } - - (nodes, subtree_root) - }) - .unzip(); - - // Update state between each depth iteration. - - leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); - accumulated_nodes.extend(nodes.into_iter().flatten()); - - assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); - } - - // Make sure the true leaves match, first checking length and then checking each individual - // leaf. - let control_leaves: BTreeMap<_, _> = control.leaves().collect(); - let control_leaves_len = control_leaves.len(); - let test_leaves_len = test_leaves.len(); - assert_eq!(test_leaves_len, control_leaves_len); - for (col, ref test_leaf) in test_leaves { - let index = LeafIndex::new_max_depth(col); - let &control_leaf = control_leaves.get(&index).unwrap(); - assert_eq!(test_leaf, control_leaf, "test leaf at column {col} does not match control"); - } - - // Make sure the inner nodes match, checking length first and then each individual leaf. - let control_nodes_len = control.inner_nodes().count(); - let test_nodes_len = accumulated_nodes.len(); - assert_eq!(test_nodes_len, control_nodes_len); - for (index, test_node) in accumulated_nodes.clone() { - let control_node = control.get_inner_node(index); - assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); - } - - // After the last iteration of the above for loop, we should have the new root node actually - // in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from - // `build_subtree()`. So let's check both! - - let control_root = control.get_inner_node(NodeIndex::root()); - - // That for loop should have left us with only one leaf subtree... 
- let [leaf_subtree]: [Vec<_>; 1] = leaf_subtrees.try_into().unwrap(); - // which itself contains only one 'leaf'... - let [root_leaf]: [SubtreeLeaf; 1] = leaf_subtree.try_into().unwrap(); - // which matches the expected root. - assert_eq!(control.root(), root_leaf.hash); - - // Likewise `accumulated_nodes` should contain a node at the root index... - assert!(accumulated_nodes.contains_key(&NodeIndex::root())); - // and it should match our actual root. - let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap(); - assert_eq!(control_root, *test_root); - // And of course the root we got from each place should match. - assert_eq!(control.root(), root_leaf.hash); - } - - /// The parallel version of `test_singlethreaded_subtree()`. - #[test] - fn test_multithreaded_subtrees() { - use rayon::prelude::*; - - const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; - - let entries = generate_entries(PAIR_COUNT); - - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - - let mut accumulated_nodes: BTreeMap = Default::default(); - - let PairComputations { - leaves: mut leaf_subtrees, - nodes: test_leaves, - } = Smt::sorted_pairs_to_leaves(entries); - - for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { - let (nodes, mut subtree_roots): (Vec>, Vec) = - leaf_subtrees - .into_par_iter() - .enumerate() - .map(|(i, subtree)| { - // Pre-assertions. - assert!( - subtree.is_sorted(), - "subtree {i} at bottom-depth {current_depth} is not sorted", - ); - assert!( - !subtree.is_empty(), - "subtree {i} at bottom-depth {current_depth} is empty!", - ); - - let (nodes, subtree_root) = - build_subtree(subtree, SMT_DEPTH, current_depth); - - // Post-assertions. 
- for (&index, test_node) in nodes.iter() { - let control_node = control.get_inner_node(index); - assert_eq!( - test_node, &control_node, - "depth {} subtree {}: test node does not match control at index {:?}", - current_depth, i, index, - ); - } - - (nodes, subtree_root) - }) - .unzip(); - - leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); - accumulated_nodes.extend(nodes.into_iter().flatten()); - - assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); - } - - // Make sure the true leaves match, checking length first and then each individual leaf. - let control_leaves: BTreeMap<_, _> = control.leaves().collect(); - let control_leaves_len = control_leaves.len(); - let test_leaves_len = test_leaves.len(); - assert_eq!(test_leaves_len, control_leaves_len); - for (col, ref test_leaf) in test_leaves { - let index = LeafIndex::new_max_depth(col); - let &control_leaf = control_leaves.get(&index).unwrap(); - assert_eq!(test_leaf, control_leaf); - } - - // Make sure the inner nodes match, checking length first and then each individual leaf. - let control_nodes_len = control.inner_nodes().count(); - let test_nodes_len = accumulated_nodes.len(); - assert_eq!(test_nodes_len, control_nodes_len); - for (index, test_node) in accumulated_nodes.clone() { - let control_node = control.get_inner_node(index); - assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); - } - - // After the last iteration of the above for loop, we should have the new root node actually - // in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from - // `build_subtree()`. So let's check both! - - let control_root = control.get_inner_node(NodeIndex::root()); - - // That for loop should have left us with only one leaf subtree... - let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap(); - // which itself contains only one 'leaf'... 
- let [root_leaf]: [_; 1] = leaf_subtree.try_into().unwrap(); - // which matches the expected root. - assert_eq!(control.root(), root_leaf.hash); - - // Likewise `accumulated_nodes` should contain a node at the root index... - assert!(accumulated_nodes.contains_key(&NodeIndex::root())); - // and it should match our actual root. - let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap(); - assert_eq!(control_root, *test_root); - // And of course the root we got from each place should match. - assert_eq!(control.root(), root_leaf.hash); - } - - #[test] - fn test_with_entries_concurrent() { - const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; - - let entries = generate_entries(PAIR_COUNT); - - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - - let smt = Smt::with_entries(entries.clone()).unwrap(); - assert_eq!(smt.root(), control.root()); - assert_eq!(smt, control); - } - - /// Concurrent mutations - #[test] - fn test_singlethreaded_subtree_mutations() { - const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; - - let entries = generate_entries(PAIR_COUNT); - let updates = generate_updates(entries.clone(), 1000); - - let tree = Smt::with_entries_sequential(entries.clone()).unwrap(); - let control = tree.compute_mutations_sequential(updates.clone()); - - let mut node_mutations = NodeMutations::default(); - - let (mut subtree_leaves, new_pairs) = tree.sorted_pairs_to_mutated_subtree_leaves(updates); - - for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { - // There's no flat_map_unzip(), so this is the best we can do. - let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves - .into_iter() - .enumerate() - .map(|(i, subtree)| { - // Pre-assertions. 
- assert!( - subtree.is_sorted(), - "subtree {i} at bottom-depth {current_depth} is not sorted", - ); - assert!( - !subtree.is_empty(), - "subtree {i} at bottom-depth {current_depth} is empty!", - ); - - // Calculate the mutations for this subtree. - let (mutations_per_subtree, subtree_root) = - tree.build_subtree_mutations(subtree, SMT_DEPTH, current_depth); - - // Check that the mutations match the control tree. - for (&index, mutation) in mutations_per_subtree.iter() { - let control_mutation = control.node_mutations().get(&index).unwrap(); - assert_eq!( - control_mutation, mutation, - "depth {} subtree {}: mutation does not match control at index {:?}", - current_depth, i, index, - ); - } - - (mutations_per_subtree, subtree_root) - }) - .unzip(); - - subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); - node_mutations.extend(mutations_per_subtree.into_iter().flatten()); - - assert!(!subtree_leaves.is_empty(), "on depth {current_depth}"); - } - - let [subtree]: [Vec<_>; 1] = subtree_leaves.try_into().unwrap(); - let [root_leaf]: [SubtreeLeaf; 1] = subtree.try_into().unwrap(); - // Check that the new root matches the control. - assert_eq!(control.new_root, root_leaf.hash); - - // Check that the node mutations match the control. 
- assert_eq!(control.node_mutations().len(), node_mutations.len()); - for (&index, mutation) in control.node_mutations().iter() { - let test_mutation = node_mutations.get(&index).unwrap(); - assert_eq!(test_mutation, mutation); - } - // Check that the new pairs match the control - assert_eq!(control.new_pairs.len(), new_pairs.len()); - for (&key, &value) in control.new_pairs.iter() { - let test_value = new_pairs.get(&key).unwrap(); - assert_eq!(test_value, &value); - } - } - - #[test] - fn test_compute_mutations_parallel() { - const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; - - let entries = generate_entries(PAIR_COUNT); - let tree = Smt::with_entries(entries.clone()).unwrap(); - - let updates = generate_updates(entries, 1000); - - let control = tree.compute_mutations_sequential(updates.clone()); - let mutations = tree.compute_mutations(updates); - - assert_eq!(mutations.root(), control.root()); - assert_eq!(mutations.old_root(), control.old_root()); - assert_eq!(mutations.node_mutations(), control.node_mutations()); - assert_eq!(mutations.new_pairs(), control.new_pairs()); - } -} From 97276525e135dbe283ee6b118068360b57964f06 Mon Sep 17 00:00:00 2001 From: krushimir Date: Thu, 6 Feb 2025 17:05:20 +0100 Subject: [PATCH 13/13] chore: addressing comments --- src/merkle/mod.rs | 4 +--- src/merkle/smt/full/mod.rs | 14 ++++++-------- src/merkle/smt/mod.rs | 4 +--- 3 files changed, 8 insertions(+), 14 deletions(-) diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index 07aeded0..54ced4dc 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -22,9 +22,7 @@ pub use path::{MerklePath, RootPath, ValuePath}; mod smt; #[cfg(feature = "internal")] -pub use smt::build_subtree_for_bench; -#[cfg(feature = "internal")] -pub use smt::SubtreeLeaf; +pub use smt::{build_subtree_for_bench, SubtreeLeaf}; pub use smt::{ InnerNode, LeafIndex, MutationSet, NodeMutation, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH, diff --git 
a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 35ed15b8..70f69a58 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -15,6 +15,12 @@ mod proof; pub use proof::SmtProof; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; +// Concurrent implementation +#[cfg(feature = "concurrent")] +mod concurrent; +#[cfg(feature = "internal")] +pub use concurrent::{build_subtree_for_bench, SubtreeLeaf}; + #[cfg(test)] mod tests; @@ -317,14 +323,6 @@ impl Smt { } } -// Concurrent implementation -#[cfg(feature = "concurrent")] -mod concurrent; -#[cfg(feature = "internal")] -pub use concurrent::build_subtree_for_bench; -#[cfg(feature = "internal")] -pub use concurrent::SubtreeLeaf; - impl SparseMerkleTree for Smt { type Key = RpoDigest; type Value = Word; diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 580e08a1..aae0ea96 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -11,9 +11,7 @@ use crate::{ mod full; #[cfg(feature = "internal")] -pub use full::build_subtree_for_bench; -#[cfg(feature = "internal")] -pub use full::SubtreeLeaf; +pub use full::{build_subtree_for_bench, SubtreeLeaf}; pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH}; mod simple;