Skip to content

Commit

Permalink
WIP(smt): add simple benchmark for single subtree computation
Browse files Browse the repository at this point in the history
  • Loading branch information
Qyriad committed Oct 16, 2024
1 parent a279ade commit 56087c7
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 8 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ harness = false
name = "store"
harness = false

[[bench]]
name = "subtree"
harness = false

[features]
default = ["std", "async"]
executable = ["dep:clap", "dep:rand-utils", "std"]
Expand Down
66 changes: 66 additions & 0 deletions benches/subtree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use std::{collections::BTreeMap, sync::Arc, time::Duration};

use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{
hash::rpo::RpoDigest,
merkle::{NodeIndex, NodeSubtreeComputer, Smt, SparseMerkleTree},
Felt, Word, ONE,
};

/// Argument passed to `NodeIndex::parent_n` when deriving the benchmarked subtree root from a
/// leaf index.
///
/// NOTE(review): presumably the number of tree levels to ascend from the leaf — confirm against
/// the definition of `parent_n`.
const SUBTREE_INTERVAL: u8 = 8;

/// Builds the fixture consumed by one `subtree8` benchmark iteration.
///
/// Generates `tree_size` key/value pairs, builds a control [`Smt`] containing all of them, and
/// selects the node returned by `parent_n(SUBTREE_INTERVAL)` on the leaf index of the first
/// (smallest) generated key as the subtree root to benchmark.
///
/// Returns, in order:
/// 1. a fresh, empty [`Smt`] for the benchmark to compute against,
/// 2. the index of the subtree root,
/// 3. the generated entries, shared behind an [`Arc`],
/// 4. the hash the control tree holds for that subtree root.
///
/// # Panics
///
/// Panics if the control tree cannot be built from the generated entries, or if `tree_size` is
/// zero (there is then no key to derive a subtree root from).
fn setup_subtree8(tree_size: u64) -> (Smt, NodeIndex, Arc<BTreeMap<RpoDigest, Word>>, RpoDigest) {
    // A range is already an iterator, so no `.into_iter()` is needed.
    let entries: BTreeMap<RpoDigest, Word> = (0..tree_size)
        .map(|i| {
            let leaf_index = u64::MAX / (i + 1);
            let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]);
            let value = [ONE, ONE, ONE, Felt::new(i)];
            (key, value)
        })
        .collect();

    // The control tree is built the ordinary way; its stored hash is the expected result.
    let control = Smt::with_entries(entries.clone())
        .expect("generated benchmark entries must form a valid tree");

    // Only the first key is needed to pick a subtree, so look it up directly.
    let first_key = entries
        .keys()
        .next()
        .expect("tree_size must be non-zero to select a subtree");
    let subtree = NodeIndex::from(Smt::key_to_leaf_index(first_key)).parent_n(SUBTREE_INTERVAL);

    let control_hash = control.get_inner_node(subtree).hash();
    (Smt::new(), subtree, Arc::new(entries), control_hash)
}

/// Benchmark routine: computes the hash of one subtree and checks it against the control hash.
///
/// Takes the tuple produced by `setup_subtree8`: the tree to compute against, the subtree root
/// index, the new key/value pairs, and the expected hash. A [`NodeSubtreeComputer`] is created
/// over the tree with a default (empty) set of existing mutations and the given pairs.
///
/// # Panics
///
/// Panics if the computed hash differs from the expected one.
fn bench_subtree8(fixture: (Smt, NodeIndex, Arc<BTreeMap<RpoDigest, Word>>, RpoDigest)) {
    let (tree, subtree_root, new_pairs, control_hash) = fixture;

    let mut computer = NodeSubtreeComputer::with_smt(&tree, Default::default(), new_pairs);
    let computed_hash = computer.get_or_make_hash(subtree_root);

    assert_eq!(control_hash, computed_hash);
}

/// Registers the `subtree8` benchmark group with Criterion.
///
/// The group uses a 360 second measurement time and 30 samples. One benchmark is registered per
/// tree size; every batch gets a fresh fixture from `setup_subtree8`, which is then consumed by
/// `bench_subtree8`.
fn smt_subtree8(c: &mut Criterion) {
    /// Number of entries in the tree for each registered benchmark.
    const TREE_SIZES: [u64; 5] = [32, 128, 512, 1024, 8192];

    let mut group = c.benchmark_group("subtree8");
    group.measurement_time(Duration::from_secs(360)).sample_size(30);

    for tree_size in TREE_SIZES.iter().copied() {
        // Throughput reporting is currently disabled:
        //group.throughput(Throughput::Elements(tree_size));
        group.bench_with_input(
            BenchmarkId::from_parameter(tree_size),
            &tree_size,
            |bencher, &size| {
                bencher.iter_batched(|| setup_subtree8(size), bench_subtree8, BatchSize::SmallInput);
            },
        );
    }

    group.finish();
}

// Define the `subtree_group` benchmark group containing `smt_subtree8`, and hand it to
// `criterion_main!`. NOTE(review): presumably this supplies the binary's `main`, since the
// `subtree` bench target sets `harness = false` in Cargo.toml — confirm against Criterion docs.
criterion_group!(subtree_group, smt_subtree8);
criterion_main!(subtree_group);
4 changes: 2 additions & 2 deletions src/merkle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ pub use path::{MerklePath, RootPath, ValuePath};

mod smt;
pub use smt::{
LeafIndex, MutationSet, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError,
SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
InnerNode, LeafIndex, MutationSet, NodeSubtreeComputer, SimpleSmt, Smt, SmtLeaf, SmtLeafError,
SmtProof, SmtProofError, SparseMerkleTree, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
};

mod mmr;
Expand Down
161 changes: 160 additions & 1 deletion src/merkle/smt/full/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#[cfg(feature = "async")]
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};

use alloc::{
collections::{BTreeMap, BTreeSet},
Expand All @@ -12,6 +12,9 @@ use super::{
MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
};

#[cfg(feature = "async")]
use super::NodeMutation;

mod error;
pub use error::{SmtLeafError, SmtProofError};

Expand Down Expand Up @@ -297,6 +300,27 @@ impl Smt {
None
}
}

/// Returns the leaf that results from applying the pair (`key`, `value`) to `existing_leaf`.
///
/// * If `existing_leaf` is [`SmtLeaf::Empty`], the result is `SmtLeaf::new_single(*key, *value)`,
///   whatever `value` is.
/// * Otherwise, a `value` equal to [`EMPTY_WORD`] removes `key` from the leaf, and any other
///   `value` is inserted under `key`; the updated leaf is returned.
///
/// In debug builds this asserts that `key` maps to the index of `existing_leaf`.
fn construct_prospective_leaf(
    mut existing_leaf: SmtLeaf,
    key: &RpoDigest,
    value: &Word,
) -> SmtLeaf {
    debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key));

    // An empty leaf has nothing to update: it simply becomes a leaf holding this pair.
    if matches!(existing_leaf, SmtLeaf::Empty(_)) {
        return SmtLeaf::new_single(*key, *value);
    }

    // Writing the empty word is how a key is deleted from a populated leaf.
    let is_removal = *value == EMPTY_WORD;
    if is_removal {
        existing_leaf.remove(*key);
    } else {
        existing_leaf.insert(*key, *value);
    }

    existing_leaf
}
}

impl SparseMerkleTree<SMT_DEPTH> for Smt {
Expand Down Expand Up @@ -399,6 +423,141 @@ impl Default for Smt {
}
}

/// A [`NodeMutation`] paired with the hash of the node it produces, so the hash does not have to
/// be recomputed when the mutation is consulted again.
#[cfg(feature = "async")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComputedNodeMutation {
    /// The mutation to apply at the node.
    pub mutation: NodeMutation,
    /// The already-computed hash stored alongside `mutation`.
    pub hash: RpoDigest,
}

/// State for computing node hashes of a subtree from an existing tree plus a set of new key/value
/// pairs, recording the resulting node mutations along the way.
///
/// See `get_or_make_hash` for the entry point.
#[cfg(feature = "async")]
pub struct NodeSubtreeComputer {
/// Inner nodes of the tree this computer was created from (shared via `Arc`).
inner_nodes: Arc<BTreeMap<NodeIndex, InnerNode>>,
/// Leaves of the tree this computer was created from, looked up by leaf index value.
leaves: Arc<BTreeMap<u64, SmtLeaf>>,
/// Mutations supplied at construction time; consulted by `is_index_dirty` and `get_clean_hash`.
existing_mutations: Arc<HashMap<NodeIndex, ComputedNodeMutation>>,
/// Mutations produced by `get_or_make_hash` on this computer.
new_mutations: HashMap<NodeIndex, ComputedNodeMutation>,
/// Key/value pairs whose effect on the tree is being computed.
new_pairs: Arc<BTreeMap<RpoDigest, Word>>,
/// Memoized results of `is_index_dirty`: `true` for indices known to be dirty, `false` for
/// indices known to be clean.
dirtied_indices: HashMap<NodeIndex, bool>,
/// Leaf hashes already computed by `get_or_make_hash`.
cached_leaf_hashes: HashMap<LeafIndex<SMT_DEPTH>, RpoDigest>,
}

#[cfg(feature = "async")]
impl NodeSubtreeComputer {
    /// Creates a computer over the nodes and leaves of `smt`.
    ///
    /// The tree's inner nodes and leaves are shared with `smt` by cloning their `Arc`s.
    /// `existing_mutations` are mutations that were already computed elsewhere, and `new_pairs`
    /// are the key/value pairs whose effect is to be computed. All caches and the set of newly
    /// computed mutations start out empty.
    pub fn with_smt(
        smt: &Smt,
        existing_mutations: Arc<HashMap<NodeIndex, ComputedNodeMutation>>,
        new_pairs: Arc<BTreeMap<RpoDigest, Word>>,
    ) -> Self {
        Self {
            inner_nodes: Arc::clone(&smt.inner_nodes),
            leaves: Arc::clone(&smt.leaves),
            existing_mutations,
            new_mutations: Default::default(),
            new_pairs,
            dirtied_indices: Default::default(),
            cached_leaf_hashes: Default::default(),
        }
    }

    /// Returns whether the node at `index_to_check` must be recomputed.
    ///
    /// An index is dirty if `index_to_check.contains(..)` holds for the index of any known
    /// existing mutation or for the leaf index of any new pair. The answer is memoized in
    /// `dirtied_indices`, which is why this takes `&mut self`.
    pub(crate) fn is_index_dirty(&mut self, index_to_check: NodeIndex) -> bool {
        if let Some(&cached) = self.dirtied_indices.get(&index_to_check) {
            return cached;
        }

        // An index is dirty if there is a new pair at it, a known existing mutation at it, or it
        // is an ancestor of one of those.
        let is_dirty = self
            .existing_mutations
            .keys()
            .copied()
            .chain(self.new_pairs.keys().map(|key| Smt::key_to_leaf_index(key).index))
            .any(|dirtied_index| index_to_check.contains(dirtied_index));

        // This scans every mutation and every new pair, so cache the result.
        self.dirtied_indices.insert(index_to_check, is_dirty);
        is_dirty
    }

    /// Returns the leaf at `index` as it would look after applying every new pair that maps to
    /// that leaf index.
    ///
    /// Starts from a clone of the leaf currently stored at `index` (or an empty leaf if there is
    /// none) and folds the matching new pairs into it with `Smt::construct_prospective_leaf`.
    pub(crate) fn get_effective_leaf(&self, index: LeafIndex<SMT_DEPTH>) -> SmtLeaf {
        let existing_leaf = self
            .leaves
            .get(&index.index.value())
            .cloned()
            .unwrap_or_else(|| SmtLeaf::new_empty(index));

        // The accumulator is owned, so it is moved into each step rather than cloned.
        self.new_pairs
            .iter()
            .filter(|&(new_key, _)| Smt::key_to_leaf_index(new_key) == index)
            .fold(existing_leaf, |leaf, (key, value)| {
                Smt::construct_prospective_leaf(leaf, key, value)
            })
    }

    /// Returns the hash already known for `index` without computing anything.
    ///
    /// Looks in `existing_mutations` first, then in the tree's stored inner nodes. Returns `None`
    /// if neither has an entry for `index`.
    ///
    /// Does NOT check `new_mutations`.
    pub(crate) fn get_clean_hash(&self, index: NodeIndex) -> Option<RpoDigest> {
        self.existing_mutations
            .get(&index)
            .map(|computed| computed.hash)
            .or_else(|| self.inner_nodes.get(&index).map(|inner_node| inner_node.hash()))
    }

    /// Retrieves a cached hash for the node at `index`, or recursively computes it.
    ///
    /// * At leaf depth ([`SMT_DEPTH`]), returns the cached leaf hash or hashes the effective leaf
    ///   and caches the result.
    /// * If a mutation for `index` was already computed by this computer, returns its hash.
    /// * If `index` is not dirty, returns the clean hash, falling back to the empty-subtree root
    ///   for that depth.
    /// * Otherwise hashes both children recursively, records the resulting mutation in
    ///   `new_mutations` (a removal if the hash equals the empty-subtree root for that depth,
    ///   an addition otherwise) and returns the hash.
    ///
    /// # Panics
    ///
    /// Panics if a [`LeafIndex`] cannot be constructed for an index at leaf depth.
    pub fn get_or_make_hash(&mut self, index: NodeIndex) -> RpoDigest {
        // If this is a leaf, then only do leaf stuff.
        if index.depth() == SMT_DEPTH {
            let leaf_index: LeafIndex<SMT_DEPTH> = LeafIndex::new(index.value()).unwrap();
            if let Some(&cached_hash) = self.cached_leaf_hashes.get(&leaf_index) {
                return cached_hash;
            }

            let leaf = self.get_effective_leaf(leaf_index);
            let hash = Smt::hash_leaf(&leaf);
            self.cached_leaf_hashes.insert(leaf_index, hash);
            return hash;
        }

        // If we already computed this one earlier as a mutation, just return it.
        if let Some(computed) = self.new_mutations.get(&index) {
            return computed.hash;
        }

        // Otherwise, we need to know if this node is one of the nodes we're in the process of
        // recomputing, or if we can safely use the node already in the Merkle tree.
        if !self.is_index_dirty(index) {
            return self
                .get_clean_hash(index)
                .unwrap_or_else(|| *EmptySubtreeRoots::entry(SMT_DEPTH, index.depth()));
        }

        // If we got here, then we have to make, rather than get, this hash. `is_index_dirty` has
        // already recorded this index as dirty in `dirtied_indices`.

        // Recurse for the left and right sides.
        let left = self.get_or_make_hash(index.left_child());
        let right = self.get_or_make_hash(index.right_child());
        let node = InnerNode { left, right };
        let hash = node.hash();

        // A node that hashes to the empty-subtree root for its depth is recorded as a removal.
        let is_removal = hash == *EmptySubtreeRoots::entry(SMT_DEPTH, index.depth());
        let mutation = if is_removal {
            NodeMutation::Removal
        } else {
            NodeMutation::Addition(node)
        };

        self.new_mutations.insert(index, ComputedNodeMutation { mutation, hash });

        hash
    }
}

// CONVERSIONS
// ================================================================================================

Expand Down
11 changes: 6 additions & 5 deletions src/merkle/smt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ use crate::{
};

mod full;
pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH};
pub use full::{
NodeSubtreeComputer, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH,
};

mod simple;
pub use simple::SimpleSmt;
Expand Down Expand Up @@ -43,7 +45,7 @@ pub const SMT_MAX_DEPTH: u8 = 64;
/// must accommodate all keys that map to the same leaf.
///
/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
pub trait SparseMerkleTree<const DEPTH: u8> {
/// The type for a key
type Key: Clone + Ord;
/// The type for a value
Expand Down Expand Up @@ -346,7 +348,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {

#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct InnerNode {
pub struct InnerNode {
pub left: RpoDigest,
pub right: RpoDigest,
}
Expand Down Expand Up @@ -459,7 +461,7 @@ impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
/// [`MutationSet`] stores this type in relation to a [`NodeIndex`] to keep track of what changes
/// need to occur at which node indices.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum NodeMutation {
pub enum NodeMutation {
/// Corresponds to [`SparseMerkleTree::remove_inner_node()`].
Removal,
/// Corresponds to [`SparseMerkleTree::insert_inner_node()`].
Expand Down Expand Up @@ -499,7 +501,6 @@ impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
}
}


#[cfg(test)]
mod tests {
use proptest::prelude::*;
Expand Down

0 comments on commit 56087c7

Please sign in to comment.