diff --git a/Cargo.lock b/Cargo.lock index e10bdb20..ad9cff71 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,21 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "addr2line" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + [[package]] name = "aho-corasick" version = "1.1.3" @@ -84,6 +99,21 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +[[package]] +name = "backtrace" +version = "0.3.73" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + [[package]] name = "bit-set" version = "0.5.3" @@ -261,7 +291,7 @@ dependencies = [ "clap", "criterion-plot", "is-terminal", - "itertools", + "itertools 0.10.5", "num-traits", "once_cell", "oorandom", @@ -282,7 +312,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" dependencies = [ "cast", - "itertools", + "itertools 0.10.5", ] [[package]] @@ -364,6 +394,95 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" + +[[package]] +name = 
"futures-task" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" + +[[package]] +name = "futures-util" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -387,6 +506,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "gimli" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" + [[package]] name = "glob" version = "0.3.1" @@ -447,6 +572,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -524,24 +658,37 @@ dependencies = [ "cc", "clap", "criterion", + "futures", "getrandom", "glob", "hex", + "itertools 0.13.0", "num", "num-complex", "proptest", "rand", "rand_chacha", "rand_core", + "rayon", "seq-macro", "serde", "sha3", + "tokio", "winter-crypto", "winter-math", "winter-rand-utils", "winter-utils", ] +[[package]] +name = "miniz_oxide" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" +dependencies = [ + "adler", +] + [[package]] name = "num" version = "0.4.3" @@ -616,6 +763,15 @@ dependencies = [ "libm", ] +[[package]] +name = "object" +version = "0.36.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "084f1a5821ac4c651660a94a7153d27ac9d8a53736203f58b31945ded098070a" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.19.0" @@ -628,6 +784,18 @@ version = "11.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +[[package]] +name = "pin-project-lite" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "plotters" version = "0.3.6" @@ -797,6 +965,12 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + [[package]] name = "rustix" version = "0.38.34" @@ -885,6 +1059,15 @@ dependencies = [ "keccak", ] +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + [[package]] name = 
"strsim" version = "0.11.1" @@ -925,6 +1108,28 @@ dependencies = [ "serde_json", ] +[[package]] +name = "tokio" +version = "1.40.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" +dependencies = [ + "backtrace", + "pin-project-lite", + "tokio-macros", +] + +[[package]] +name = "tokio-macros" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "typenum" version = "1.17.0" diff --git a/Cargo.toml b/Cargo.toml index 2616341c..07b6fc4b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ name = "store" harness = false [features] -default = ["std"] +default = ["std", "async"] executable = ["dep:clap", "dep:rand-utils", "std"] serde = ["dep:serde", "serde?/alloc", "winter-math/serde"] std = [ @@ -44,20 +44,25 @@ std = [ "winter-math/std", "winter-utils/std", ] +async = ["std", "dep:tokio", "dep:rayon", "dep:futures", "serde?/rc"] [dependencies] blake3 = { version = "1.5", default-features = false } clap = { version = "4.5", optional = true, features = ["derive"] } +futures = { version = "0.3.30", optional = true } num = { version = "0.4", default-features = false, features = ["alloc", "libm"] } num-complex = { version = "0.4", default-features = false } rand = { version = "0.8", default-features = false } rand_core = { version = "0.6", default-features = false } rand-utils = { version = "0.9", package = "winter-rand-utils", optional = true } +rayon = { version = "1.10.0", optional = true } serde = { version = "1.0", default-features = false, optional = true, features = ["derive"] } sha3 = { version = "0.10", default-features = false } +tokio = { version = "1.40", features = ["rt-multi-thread", "macros", "sync"], optional = true } winter-crypto = { version = "0.9", default-features = false } winter-math = { version = "0.9", default-features = false } winter-utils = { version = "0.9", default-features = false } +itertools = { version = "0.13.0", default-features = false, features = ["use_alloc"] } [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } diff --git a/src/main.rs b/src/main.rs index 776ccc21..0f8332e4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,10 +16,51 @@ pub struct BenchmarkCmd { size: u64, } +#[cfg(not(feature = "async"))] fn main() { benchmark_smt(); } +#[cfg(feature = "async")] +#[tokio::main(flavor = "multi_thread")] +async fn main() { + // FIXME: very incomplete + + let args = BenchmarkCmd::parse(); + let tree_size = args.size; + + let mut entries = Vec::new(); + for i in 0..tree_size { + //let key = rand_value::(); + let leaf_index = u64::MAX / (i + 1); + let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]); + let value = [ONE, ONE, ONE, Felt::new(i)]; + entries.push((key, value)); + } + + let control = Smt::with_entries(entries.clone()).unwrap(); + + let mut tree = Smt::new(); + println!("Running a parallel construction benchmark:"); + let now = Instant::now(); + let mutations = tree.compute_mutations_parallel(entries).await; + assert_eq!(mutations.root(), control.root()); + tree.apply_mutations(mutations.clone()).unwrap(); + let elapsed = now.elapsed(); + assert_eq!(tree.root(), mutations.root(), "mutation did not apply the right root?"); + assert_eq!(control.root(), mutations.root(), "mutation root hash did not match control"); + 
    assert_eq!(tree.root(), control.root(), "applied root hash did not match control");
+    std::eprintln!("\nassertion checks complete");
+
+    println!(
+        "Constructed an SMT in parallel with {} key-value pairs in {:.3} seconds",
+        tree_size,
+        elapsed.as_secs_f32(),
+    );
+
+    //benchmark_smt();
+}
+
 /// Run a benchmark for [`Smt`].
 pub fn benchmark_smt() {
     let args = BenchmarkCmd::parse();
diff --git a/src/merkle/index.rs b/src/merkle/index.rs
index 104ceb44..03e9df5f 100644
--- a/src/merkle/index.rs
+++ b/src/merkle/index.rs
@@ -1,4 +1,4 @@
-use core::fmt::Display;
+use core::{fmt::Display, num::NonZero};
 
 use super::{Felt, MerkleError, RpoDigest};
 use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
@@ -72,6 +72,50 @@ impl NodeIndex {
         Self::new(depth, value)
     }
 
+    /// Converts a scalar representation of a depth/value pair to a [`NodeIndex`].
+    ///
+    /// This is the inverse operation of [`NodeIndex::to_scalar_index()`]. As `1` represents the
+    /// root node, `index` cannot be zero.
+    ///
+    /// # Errors
+    /// Returns the same errors under the same conditions as [`NodeIndex::new()`].
+    ///
+    /// # Panics
+    /// Panics if the depth indicated by `index` does not fit in a [`u8`], or if the row-value
+    /// indicated by `index` does not fit in a [`u64`].
+    pub fn from_scalar_index(index: NonZero<u128>) -> Result<Self, MerkleError> {
+        let index = index.get() - 1;
+
+        if index == 0 {
+            return Ok(Self::root());
+        }
+
+        // `index == 1` (scalar index 2) is the root's left child.
+        if index == 1 {
+            return Ok(Self::root().left_child());
+        }
+
+        let depth = {
+            let depth = u128::ilog2(index + 1);
+            assert!(depth <= u8::MAX as u32);
+            depth as u8
+        };
+
+        let max_value_for_depth = (1 << depth) - 1;
+        assert!(
+            max_value_for_depth <= u64::MAX as u128,
+            "max_value ({max_value_for_depth}) does not fit in u64",
+        );
+
+        let value = {
+            let value = index - max_value_for_depth;
+            assert!(value <= u64::MAX as u128);
+            value as u64
+        };
+
+        Self::new(depth, value)
+    }
+
     /// Creates a new node index pointing to the root of the tree.
     pub const fn root() -> Self {
         Self { depth: 0, value: 0 }
@@ -90,6 +134,18 @@ impl NodeIndex {
         self
     }
 
+    pub const fn left_ancestor_n(mut self, n: u8) -> Self {
+        self.depth += n;
+        self.value <<= n;
+        self
+    }
+
+    pub const fn right_ancestor_n(mut self, n: u8) -> Self {
+        self.depth += n;
+        self.value = (self.value << n) + 1;
+        self
+    }
+
     /// Returns right child index of the current node.
     pub const fn right_child(mut self) -> Self {
         self.depth += 1;
@@ -97,6 +153,64 @@
         self
     }
 
+    /// Returns the parent of the current node.
+    pub const fn parent(mut self) -> Self {
+        self.depth = self.depth.saturating_sub(1);
+        self.value >>= 1;
+        self
+    }
+
+    /// Returns the `n`th parent of the current node.
+    pub fn parent_n(mut self, n: u8) -> Self {
+        debug_assert!(n <= self.depth);
+        self.depth = self.depth.saturating_sub(n);
+        self.value >>= n;
+
+        self
+    }
+
+    /// Returns `true` if and only if `other` is a descendant of the current node, or is the
+    /// current node itself.
+    pub fn contains(&self, mut other: Self) -> bool {
+        if other == *self {
+            return true;
+        }
+        if other.is_root() {
+            return false;
+        }
+        if other.depth < self.depth {
+            return false;
+        }
+
+        other = other.parent_n(other.depth() - self.depth());
+
+        loop {
+            if other == *self {
+                return true;
+            }
+
+            if other.is_root() {
+                return false;
+            }
+
+            if other.depth < self.depth {
+                return false;
+            }
+
+            other = other.parent();
+        }
+    }
+
+    /// Returns `true` if and only if the current node is a strict descendant of `other`.
+    pub fn is_descendent_of(self, other: Self) -> bool {
+        self != other && other.contains(self)
+    }
+
+    /// Returns `true` if and only if the current node is a strict ancestor of `other`.
+    pub fn is_ancestor_of(self, other: Self) -> bool {
+        self != other && self.contains(other)
+    }
+
     // PROVIDERS
     // --------------------------------------------------------------------------------------------
 
@@ -114,8 +228,8 @@ impl NodeIndex {
     /// Returns the scalar representation of the depth/value pair.
     ///
    /// It is computed as `2^depth + value`.
-    pub const fn to_scalar_index(&self) -> u64 {
-        (1 << self.depth as u64) + self.value
+    pub const fn to_scalar_index(&self) -> u128 {
+        (1 << self.depth as u64) + (self.value as u128)
     }
 
     /// Returns the depth of the current instance.
@@ -180,6 +294,27 @@ impl Deserializable for NodeIndex {
     }
 }
 
+#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
+pub struct SubtreeIndex {
+    pub root: NodeIndex,
+    pub depth: u8,
+}
+
+#[allow(dead_code)]
+impl SubtreeIndex {
+    pub const fn new(root: NodeIndex, depth: u8) -> Self {
+        Self { root, depth }
+    }
+
+    pub const fn left_bound(&self) -> NodeIndex {
+        self.root.left_ancestor_n(self.depth)
+    }
+
+    pub const fn right_bound(&self) -> NodeIndex {
+        self.root.right_ancestor_n(self.depth)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use proptest::prelude::*;
@@ -210,6 +345,21 @@ mod tests {
         assert!(NodeIndex::new(64, u64::MAX).is_ok());
     }
 
+    #[test]
+    fn test_scalar_roundtrip() {
+        // Arbitrary value that's at the bottom and not in a corner.
+        let start = NodeIndex::make(64, u64::MAX - 8);
+
+        let mut index = start;
+        while !index.is_root() {
+            let as_scalar = index.to_scalar_index();
+            let round_trip =
+                NodeIndex::from_scalar_index(NonZero::new(as_scalar).unwrap()).unwrap();
+            assert_eq!(index, round_trip, "{index:?} did not round-trip as a scalar index");
+            index.move_up();
+        }
+    }
+
     prop_compose!
{ fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex { // unwrap never panics because the range of depth is 0..u64::BITS diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 9c640021..7ce9e3d8 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -1,13 +1,25 @@ +#[cfg(feature = "async")] +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + use alloc::{ collections::{BTreeMap, BTreeSet}, string::ToString, vec::Vec, }; +#[cfg(feature = "async")] +use tokio::task::JoinSet; +#[cfg(feature = "async")] +use super::NodeMutation; use super::{ EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath, MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, }; +#[cfg(feature = "async")] +use crate::merkle::index::SubtreeIndex; mod error; pub use error::{SmtLeafError, SmtProofError}; @@ -43,8 +55,16 @@ pub const SMT_DEPTH: u8 = 64; #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct Smt { root: RpoDigest, + + #[cfg(not(feature = "async"))] leaves: BTreeMap, + #[cfg(feature = "async")] + leaves: Arc>, + + #[cfg(not(feature = "async"))] inner_nodes: BTreeMap, + #[cfg(feature = "async")] + inner_nodes: Arc>, } impl Smt { @@ -64,8 +84,8 @@ impl Smt { Self { root, - leaves: BTreeMap::new(), - inner_nodes: BTreeMap::new(), + leaves: Default::default(), + inner_nodes: Default::default(), } } @@ -101,6 +121,22 @@ impl Smt { Ok(tree) } + #[cfg(feature = "async")] + pub fn get_leaves(&self) -> Arc> { + Arc::clone(&self.leaves) + } + + #[cfg(feature = "async")] + pub async fn compute_mutations_parallel( + &self, + kv_pairs: impl IntoIterator, + ) -> MutationSet { + >::compute_mutations_parallel( + self, kv_pairs, + ) + .await + } + // PUBLIC ACCESSORS // -------------------------------------------------------------------------------------------- @@ -154,6 +190,40 @@ impl Smt { }) } + /// Gets a mutable reference to this structure's inner node mapping. + /// + /// # Panics + /// This will panic if we have violated our own invariants and try to mutate these nodes while + /// Self::compute_mutations_parallel() is still running. + fn inner_nodes_mut(&mut self) -> &mut BTreeMap { + #[cfg(feature = "async")] + { + Arc::get_mut(&mut self.inner_nodes).unwrap() + } + + #[cfg(not(feature = "async"))] + { + &mut self.inner_nodes + } + } + + /// Gets a mutable reference to this structure's inner leaf mapping. + /// + /// # Panics + /// This will panic if we have violated our own invariants and try to mutate these nodes while + /// Self::compute_mutations_parallel() is still running. 
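+    ///
+    /// Under the `async` feature the leaf map lives behind an [`Arc`], so mutable access is only
+    /// possible while no other clone of that [`Arc`] is alive; otherwise the `unwrap` below
+    /// panics.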
+ fn leaves_mut(&mut self) -> &mut BTreeMap { + #[cfg(feature = "async")] + { + Arc::get_mut(&mut self.leaves).unwrap() + } + + #[cfg(not(feature = "async"))] + { + &mut self.leaves + } + } + // STATE MUTATORS // -------------------------------------------------------------------------------------------- @@ -218,10 +288,12 @@ impl Smt { let leaf_index: LeafIndex = Self::key_to_leaf_index(&key); - match self.leaves.get_mut(&leaf_index.value()) { + let leaves = self.leaves_mut(); + + match leaves.get_mut(&leaf_index.value()) { Some(leaf) => leaf.insert(key, value), None => { - self.leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value))); + leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value))); None }, @@ -232,10 +304,12 @@ impl Smt { fn perform_remove(&mut self, key: RpoDigest) -> Option { let leaf_index: LeafIndex = Self::key_to_leaf_index(&key); - if let Some(leaf) = self.leaves.get_mut(&leaf_index.value()) { + let leaves = self.leaves_mut(); + + if let Some(leaf) = leaves.get_mut(&leaf_index.value()) { let (old_value, is_empty) = leaf.remove(key); if is_empty { - self.leaves.remove(&leaf_index.value()); + leaves.remove(&leaf_index.value()); } old_value } else { @@ -243,6 +317,27 @@ impl Smt { None } } + + fn construct_prospective_leaf( + mut existing_leaf: SmtLeaf, + key: &RpoDigest, + value: &Word, + ) -> SmtLeaf { + debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key)); + + match existing_leaf { + SmtLeaf::Empty(_) => SmtLeaf::new_single(*key, *value), + _ => { + if *value != EMPTY_WORD { + existing_leaf.insert(*key, *value); + } else { + existing_leaf.remove(*key); + } + + existing_leaf + }, + } + } } impl SparseMerkleTree for Smt { @@ -269,11 +364,11 @@ impl SparseMerkleTree for Smt { } fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) { - self.inner_nodes.insert(index, inner_node); + self.inner_nodes_mut().insert(index, inner_node); } fn remove_inner_node(&mut self, index: NodeIndex) { - let _ = self.inner_nodes.remove(&index); + let _ = self.inner_nodes_mut().remove(&index); } fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option { @@ -309,24 +404,11 @@ impl SparseMerkleTree for Smt { fn construct_prospective_leaf( &self, - mut existing_leaf: SmtLeaf, + existing_leaf: SmtLeaf, key: &RpoDigest, value: &Word, ) -> SmtLeaf { - debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key)); - - match existing_leaf { - SmtLeaf::Empty(_) => SmtLeaf::new_single(*key, *value), - _ => { - if *value != EMPTY_WORD { - existing_leaf.insert(*key, *value); - } else { - existing_leaf.remove(*key); - } - - existing_leaf - }, - } + Smt::construct_prospective_leaf(existing_leaf, key, value) } fn key_to_leaf_index(key: &RpoDigest) -> LeafIndex { @@ -339,12 +421,297 @@ impl SparseMerkleTree for Smt { } } +#[cfg(feature = "async")] +impl super::ParallelSparseMerkleTree for Smt { + // Helpers required only for the parallel version of the SMT trait. + fn get_inner_nodes(&self) -> Arc> { + Arc::clone(&self.inner_nodes) + } + + fn get_leaves(&self) -> Arc> { + Arc::clone(&self.leaves) + } + + async fn compute_mutations_parallel( + &self, + kv_pairs: I, + ) -> MutationSet + where + I: IntoIterator, + { + use std::time::Instant; + + const SUBTREE_INTERVAL: u8 = 8; + + // FIXME: check for duplicates and return MerkleError. + let kv_pairs = Arc::new(BTreeMap::from_iter(kv_pairs)); + + // The first subtrees we calculate, which include our new leaves. 
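+        // Each new key selects the leaf it hashes to (at depth `SMT_DEPTH`); walking up
+        // `SUBTREE_INTERVAL` levels gives the root of the 8-level subtree that the leaf dirties.
+        // Keys that land under the same subtree root collapse into a single `HashSet` entry.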
+ let mut subtrees: HashSet = kv_pairs + .keys() + .map(|key| { + let index_for_key = NodeIndex::from(Smt::key_to_leaf_index(key)); + index_for_key.parent_n(SUBTREE_INTERVAL) + }) + .collect(); + + // Node mutations across all tasks will be collected here. + // Every time we collect tasks we store all the new known node mutations and their hashes + // (so we don't have to recompute them every time we need them). + let mut node_mutations: Arc> = + Default::default(); + // Any leaf hashes done by tasks will be collected here, so hopefully we only hash each leaf + // once. + let mut cached_leaf_hashes: Arc, RpoDigest>> = + Default::default(); + + for subtree_depth in (0..SMT_DEPTH).step_by(SUBTREE_INTERVAL.into()).rev() { + let now = Instant::now(); + let mut tasks = JoinSet::new(); + + for subtree in subtrees.iter().copied() { + debug_assert_eq!(subtree.depth(), subtree_depth); + let mut state = NodeSubtreeState::::with_smt( + &self, + Arc::clone(&node_mutations), + Arc::clone(&kv_pairs), + SubtreeIndex::new(subtree, SUBTREE_INTERVAL as u8), + ); + // The "double spawn" here is necessary to allow tokio to run these tasks in + // parallel. + tasks.spawn(tokio::spawn(async move { + let hash = state.get_or_make_hash(subtree); + (subtree, hash, state.into_results()) + })); + } + + let task_results = tasks.join_all().await; + let elapsed = now.elapsed(); + std::eprintln!( + "joined {} tasks for depth {} in {:.3} milliseconds", + task_results.len(), + subtree_depth, + elapsed.as_secs_f64() * 1000.0, + ); + + for result in task_results { + // FIXME: .expect() error message? + let result = result.unwrap(); + let (subtree, hash, state) = result; + let NodeSubtreeResults { + new_mutations, + cached_leaf_hashes: new_leaf_hashes, + } = state; + + Arc::get_mut(&mut node_mutations).unwrap().extend(new_mutations); + Arc::get_mut(&mut cached_leaf_hashes).unwrap().extend(new_leaf_hashes); + // Make sure the final hash we calculated is in the new mutations. + assert_eq!( + node_mutations.get(&subtree).unwrap().0, + hash, + "Stored and returned hashes for subtree '{subtree:?}' differ", + ); + } + + // And advance our subtrees, unless we just did the root depth. + if subtree_depth == 0 { + continue; + } + + let subtree_count_before_advance = subtrees.len(); + subtrees = + subtrees.into_iter().map(|subtree| subtree.parent_n(SUBTREE_INTERVAL)).collect(); + // FIXME: remove. + assert!(subtrees.len() <= subtree_count_before_advance); + } + + let root = NodeIndex::root(); + let new_root = node_mutations.get(&root).unwrap().0; + + MutationSet { + old_root: self.root(), + //node_mutations: Arc::into_inner(node_mutations).unwrap().into_iter().collect(), + node_mutations: Arc::into_inner(node_mutations) + .unwrap() + .into_iter() + .map(|(key, (_hash, node))| (key, node)) + .collect(), + new_pairs: Arc::into_inner(kv_pairs).unwrap(), + new_root, + } + } +} + impl Default for Smt { fn default() -> Self { Self::new() } } +#[cfg(feature = "async")] +pub(crate) struct NodeSubtreeState { + inner_nodes: Arc>, + leaves: Arc>, + // This field has invariants! 
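+    // It memoizes `is_index_dirty()`: an entry is `true` iff some existing mutation or new
+    // key/value pair lies at or below that index, so each answer is computed at most once.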
+ dirtied_indices: HashMap, + existing_mutations: Arc>, + new_mutations: HashMap, + new_pairs: Arc>, + cached_leaf_hashes: HashMap, RpoDigest>, + indentation: u8, + subtree: SubtreeIndex, +} + +#[cfg(feature = "async")] +impl NodeSubtreeState { + pub(crate) fn new( + inner_nodes: Arc>, + existing_mutations: Arc>, + leaves: Arc>, + new_pairs: Arc>, + subtree: SubtreeIndex, + ) -> Self { + Self { + inner_nodes, + leaves, + dirtied_indices: Default::default(), + new_mutations: Default::default(), + existing_mutations, + new_pairs, + cached_leaf_hashes: Default::default(), + indentation: 0, + subtree, + } + } + + pub(crate) fn with_smt( + smt: &Smt, + existing_mutations: Arc>, + new_pairs: Arc>, + subtree: SubtreeIndex, + ) -> Self { + Self::new( + Arc::clone(&smt.inner_nodes), + existing_mutations, + Arc::clone(&smt.leaves), + new_pairs, + subtree, + ) + } + + #[inline(never)] // XXX: for profiling. + pub(crate) fn is_index_dirty(&mut self, index_to_check: NodeIndex) -> bool { + if index_to_check == self.subtree.root { + return true; + } + if let Some(cached) = self.dirtied_indices.get(&index_to_check) { + return *cached; + } + let is_dirty = self + .existing_mutations + .iter() + .map(|(index, _)| *index) + .chain(self.new_pairs.iter().map(|(key, _v)| Smt::key_to_leaf_index(key).index)) + .filter(|&dirtied_index| index_to_check.contains(dirtied_index)) + .next() + .is_some(); + self.dirtied_indices.insert(index_to_check, is_dirty); + is_dirty + } + + /// Does NOT check `new_mutations`. + #[inline(never)] // XXX: for profiling. + pub(crate) fn get_clean_hash(&self, index: NodeIndex) -> Option { + self.existing_mutations + .get(&index) + .map(|(hash, _)| *hash) + .or_else(|| self.inner_nodes.get(&index).map(|inner_node| InnerNode::hash(&inner_node))) + } + + #[inline(never)] // XXX: for profiling. + pub(crate) fn get_effective_leaf(&self, index: LeafIndex) -> SmtLeaf { + let pairs_at_index = self + .new_pairs + .iter() + .filter(|&(new_key, _)| Smt::key_to_leaf_index(new_key) == index); + + let existing_leaf = self + .leaves + .get(&index.index.value()) + .cloned() + .unwrap_or_else(|| SmtLeaf::new_empty(index)); + + pairs_at_index.fold(existing_leaf, |acc, (k, v)| { + let existing_leaf = acc.clone(); + Smt::construct_prospective_leaf(existing_leaf, k, v) + }) + } + + /// Retrieve a cached hash, or recursively compute it. + #[inline(never)] // XXX: for profiling. + pub fn get_or_make_hash(&mut self, index: NodeIndex) -> RpoDigest { + use NodeMutation::*; + + // If this is a leaf, then only do leaf stuff. + if index.depth() == SMT_DEPTH { + let index = LeafIndex::new(index.value()).unwrap(); + return match self.cached_leaf_hashes.get(&index) { + Some(cached_hash) => cached_hash.clone(), + None => { + let leaf = self.get_effective_leaf(index); + let hash = Smt::hash_leaf(&leaf); + self.cached_leaf_hashes.insert(index, hash); + hash + }, + }; + } + + // If we already computed this one earlier as a mutation, just return it. + if let Some((hash, _)) = self.new_mutations.get(&index) { + return *hash; + } + + // Otherwise, we need to know if this node is one of the nodes we're in the process of + // recomputing, or if we can safely use the node already in the Merkle tree. + if !self.is_index_dirty(index) { + return self + .get_clean_hash(index) + .unwrap_or_else(|| *EmptySubtreeRoots::entry(SMT_DEPTH, index.depth())); + } + + // If we got here, then we have to make, rather than get, this hash. + // Make sure we mark this index as now dirty. 
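+        // (For every index other than the subtree root, `is_index_dirty()` has already cached
+        // this; the subtree root returns early there without caching, so record it here.)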
+ self.dirtied_indices.insert(index, true); + + // Recurse for the left and right sides. + let left = self.get_or_make_hash(index.left_child()); + let right = self.get_or_make_hash(index.right_child()); + let node = InnerNode { left, right }; + let hash = node.hash(); + let &equivalent_empty_hash = EmptySubtreeRoots::entry(SMT_DEPTH, index.depth()); + let is_removal = hash == equivalent_empty_hash; + let new_entry = if is_removal { Removal } else { Addition(node) }; + + self.new_mutations.insert(index, (hash, new_entry)); + + hash + } + + fn into_results(self) -> NodeSubtreeResults { + NodeSubtreeResults { + new_mutations: self.new_mutations, + cached_leaf_hashes: self.cached_leaf_hashes, + } + } +} + +#[derive(Debug, Clone, PartialEq)] +#[cfg(feature = "async")] +pub(crate) struct NodeSubtreeResults { + pub(crate) new_mutations: HashMap, + pub(crate) cached_leaf_hashes: HashMap, RpoDigest>, +} + // CONVERSIONS // ================================================================================================ diff --git a/src/merkle/smt/full/tests.rs b/src/merkle/smt/full/tests.rs index 1613c8f2..6c6c4584 100644 --- a/src/merkle/smt/full/tests.rs +++ b/src/merkle/smt/full/tests.rs @@ -1,6 +1,19 @@ use alloc::vec::Vec; +#[cfg(feature = "async")] +use std::{ + collections::{BTreeMap, BTreeSet, HashMap}, + sync::Arc, + time::Instant, +}; +#[cfg(feature = "async")] +use tokio::task::JoinSet; use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH}; +#[cfg(feature = "async")] +use crate::merkle::{ + index::SubtreeIndex, + smt::{full::NodeSubtreeState, NodeMutation}, +}; use crate::{ merkle::{smt::SparseMerkleTree, EmptySubtreeRoots, MerkleStore}, utils::{Deserializable, Serializable}, @@ -568,6 +581,263 @@ fn test_multiple_smt_leaf_serialization_success() { assert_eq!(multiple_leaf, deserialized); } +#[cfg(feature = "async")] +fn setup_subtree_test(kv_count: u64) -> (Vec<(RpoDigest, Word)>, Smt) { + // FIXME: override seed. + let rand_felt = || rand_utils::rand_value::(); + + let kv_pairs: Vec<(RpoDigest, Word)> = (0..kv_count) + .into_iter() + .map(|i| { + let leaf_index = u64::MAX / (i + 1); + let key = + RpoDigest::new([rand_felt(), rand_felt(), rand_felt(), Felt::new(leaf_index)]); + let value: Word = [Felt::new(i), rand_felt(), rand_felt(), rand_felt()]; + (key, value) + }) + .collect(); + + let control_smt = Smt::with_entries(kv_pairs.clone()).unwrap(); + + (kv_pairs, control_smt) +} + +#[test] +#[cfg(feature = "async")] +fn test_single_node_subtree() { + use alloc::collections::BTreeMap; + use std::{collections::HashMap, sync::Arc}; + + use crate::merkle::smt::{full::NodeSubtreeState, NodeMutation}; + + const KV_COUNT: u64 = 2_000; + + let (kv_pairs, control_smt) = setup_subtree_test(KV_COUNT); + let new_pairs = Arc::new(BTreeMap::from_iter(kv_pairs)); + + let test_smt = Smt::new(); + + let rt = tokio::runtime::Builder::new_current_thread().build().unwrap(); + let _: () = rt.block_on(async move { + // Construct some fake node mutations based on the leaves in the control Smt. 
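+        // For each leaf's parent index, record either an `Addition` of the control tree's inner
+        // node (paired with its hash) or, if the control tree has no node there, a `Removal`
+        // paired with the empty-subtree hash for that depth.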
+ let node_mutations: HashMap = control_smt + .leaves() + .flat_map(|(index, _leaf)| { + let subtree = index.index.parent(); + let mutation = control_smt + .inner_nodes + .get(&subtree) + .cloned() + .map(|node| (node.hash(), NodeMutation::Addition(node))) + .unwrap_or_else(|| { + ( + *EmptySubtreeRoots::entry(SMT_DEPTH, subtree.depth()), + NodeMutation::Removal, + ) + }); + + vec![(subtree, mutation)] + }) + .collect(); + let node_mutations = Arc::new(node_mutations); + + let mut state = NodeSubtreeState::::new( + Arc::clone(&test_smt.inner_nodes), + Arc::clone(&node_mutations), + Arc::clone(&control_smt.leaves), + Arc::clone(&new_pairs), + SubtreeIndex::new(NodeIndex::root(), 8), + ); + + for (i, (&index, mutation)) in node_mutations.iter().enumerate() { + assert!(index.depth() <= SMT_DEPTH, "index {index:?} is invalid"); + + let control_hash = if index.depth() < SMT_DEPTH { + control_smt.get_inner_node(index).hash() + } else { + control_smt + .leaves + .get(&index.value()) + .map(Smt::hash_leaf) + .unwrap_or_else(|| *EmptySubtreeRoots::entry(SMT_DEPTH, index.depth())) + }; + let mutation_hash = mutation.0; + let test_hash = state.get_or_make_hash(index); + assert_eq!(mutation_hash, control_hash); + assert_eq!( + test_hash, control_hash, + "test_hash != control_hash for mutation {i} at {index:?}", + ); + } + }); +} + +// Test doing a node subtree from a LeafSubtreeMutationSet. +#[test] +#[cfg(feature = "async")] +fn test_node_subtree_with_leaves() { + const KV_COUNT: u64 = 2_000; + + let (kv_pairs, control_smt) = setup_subtree_test(KV_COUNT); + let new_pairs = Arc::new(BTreeMap::from_iter(kv_pairs)); + + let test_smt = Smt::new(); + + let rt = tokio::runtime::Builder::new_current_thread().build().unwrap(); + let _: () = rt.block_on(async move { + let mut task_mutations: HashMap = Default::default(); + + for leaf_index in control_smt.leaves().map(|(index, _leaf)| index) { + let subtree = SubtreeIndex::new(leaf_index.index.parent_n(8), 8); + let subtree_pairs: Vec<(RpoDigest, Word)> = control_smt + .leaves() + .flat_map(|(leaf_index, leaf)| { + if subtree.root.contains(leaf_index.index) { + leaf.entries() + } else { + vec![] + } + }) + .cloned() + .collect(); + + let mut state = NodeSubtreeState::::with_smt( + &test_smt, + Default::default(), + Arc::new(BTreeMap::from_iter(subtree_pairs)), + //Arc::clone(&new_pairs), + subtree, + ); + let test_subtree_hash = state.get_or_make_hash(subtree.root); + let control_subtree_hash = control_smt.get_inner_node(subtree.root).hash(); + assert_eq!(test_subtree_hash, control_subtree_hash); + + task_mutations.extend(state.new_mutations); + } + + let node_mutations = Arc::new(task_mutations); + let subtrees: BTreeSet = control_smt + .leaves() + .map(|(index, _leaf)| SubtreeIndex::new(index.index.parent_n(8).parent_n(8), 8)) + .collect(); + + for (i, subtree) in subtrees.into_iter().enumerate() { + let mut state = NodeSubtreeState::::with_smt( + &test_smt, + Arc::clone(&node_mutations), + Arc::clone(&new_pairs), + subtree, + ); + + let control_subtree_hash = control_smt.get_inner_node(subtree.root).hash(); + let test_subtree_hash = state.get_or_make_hash(subtree.root); + assert_eq!( + test_subtree_hash, control_subtree_hash, + "test subtree hash does not match control hash for subtree {i} '{subtree:?}'", + ); + } + }); +} + +#[test] +#[cfg(feature = "async")] +fn test_node_subtrees_parallel() { + const KV_COUNT: u64 = 2_000; + + let (kv_pairs, control_smt) = setup_subtree_test(KV_COUNT); + let new_pairs = Arc::new(BTreeMap::from_iter(kv_pairs)); + + let 
test_smt = Smt::new(); + + let rt = tokio::runtime::Builder::new_multi_thread().build().unwrap(); + //let rt = tokio::runtime::Builder::new_current_thread().build().unwrap(); + let _: () = rt.block_on(async move { + let mut current_subtree_depth = SMT_DEPTH; + + let subtrees: BTreeSet = new_pairs + .keys() + .map(|key| SubtreeIndex::new(Smt::key_to_leaf_index(key).index.parent_n(8), 8)) + .collect(); + current_subtree_depth -= 8; + + let mut node_mutations: Arc> = + Default::default(); + let mut tasks = JoinSet::new(); + + // FIXME + let mut now = Instant::now(); + for subtree in subtrees.iter().copied() { + let mut state = NodeSubtreeState::::with_smt( + &test_smt, + Arc::clone(&node_mutations), + Arc::clone(&new_pairs), + subtree, + ); + tasks.spawn(tokio::spawn(async move { + let hash = state.get_or_make_hash(subtree.root); + (subtree, hash, state) + })); + } + + let mut cached_leaf_hashes: HashMap, RpoDigest> = Default::default(); + let mut subtrees = subtrees; + let mut tasks = Some(tasks); + while current_subtree_depth > 0 { + std::eprintln!( + "joining {} tasks for depth {current_subtree_depth}", + tasks.as_ref().unwrap().len(), + ); + let mut tasks_mutations: HashMap = + Default::default(); + let results = tasks.take().unwrap().join_all().await; + let elapsed = now.elapsed(); + std::eprintln!(" joined in {:.3} milliseconds", elapsed.as_secs_f64() * 1000.0); + for result in results { + let (subtree, test_hash, state) = result.unwrap(); + let control_hash = control_smt.get_inner_node(subtree.root).hash(); + assert_eq!(test_hash, control_hash); + + tasks_mutations.extend(state.new_mutations); + cached_leaf_hashes.extend(state.cached_leaf_hashes); + } + Arc::get_mut(&mut node_mutations).unwrap().extend(tasks_mutations); + + // Move all our subtrees up. + current_subtree_depth -= 8; + subtrees = subtrees + .into_iter() + .map(|subtree| { + let subtree = SubtreeIndex::new(subtree.root.parent_n(8), 8); + assert_eq!(subtree.root.depth(), current_subtree_depth); + subtree + }) + .collect(); + + // And spawn our new tasks. + //std::eprintln!("spawning tasks for depth {current_subtree_depth}"); + let tasks = tasks.insert(JoinSet::new()); + // FIXME + now = Instant::now(); + for subtree in subtrees.iter().copied() { + let mut state = NodeSubtreeState::::with_smt( + &test_smt, + Arc::clone(&node_mutations), + Arc::clone(&new_pairs), + subtree, + ); + state.cached_leaf_hashes = cached_leaf_hashes.clone(); + tasks.spawn(tokio::spawn(async move { + let hash = state.get_or_make_hash(subtree.root); + (subtree, hash, state) + })); + } + } + + assert!(tasks.is_some()); + assert_eq!(tasks.as_ref().unwrap().len(), 1); + }); +} + // HELPERS // -------------------------------------------------------------------------------------------- diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 0b7ceb95..a03bab08 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -12,6 +12,11 @@ pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH}; mod simple; pub use simple::SimpleSmt; +#[cfg(feature = "async")] +mod parallel; +#[cfg(feature = "async")] +pub(crate) use parallel::ParallelSparseMerkleTree; + // CONSTANTS // ================================================================================================ @@ -379,6 +384,48 @@ impl LeafIndex { pub fn value(&self) -> u64 { self.index.value() } + + /// Lowest common ancestor — finds the lowest (highest depth) [`NodeIndex`] that is an ancestor + /// of both `self` and `rhs`. 
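+    ///
+    /// For example, in a depth-8 tree the leaves at positions 12 (scalar index 268) and 14
+    /// (scalar index 270) first coincide after two right-shifts (both become 67), so their
+    /// lowest common ancestor is the node at depth 6, value 3.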
+    ///
+    /// The general case algorithm is `O(n)`; however, leaf indexes are always at the same depth,
+    /// and we only need to find the depth of the lowest common ancestor (since we can trivially
+    /// get its horizontal position based on either child's position), so we can reduce this to
+    /// `O(log n)`.
+    pub fn lca(&self, other: &Self) -> NodeIndex {
+        let mut self_scalar = self.index.to_scalar_index();
+        let mut other_scalar = other.index.to_scalar_index();
+
+        while self_scalar != other_scalar {
+            self_scalar >>= 1;
+            other_scalar >>= 1;
+        }
+
+        // Once we've shifted them enough to be equal, we've found a scalar index with the depth of
+        // the lowest common ancestor. Time to convert that scalar index to a depth, and apply that
+        // depth to either of our `NodeIndex`s to get the full position of that ancestor.
+
+        // In general, we can get the depth of a binary tree's scalar index by taking its binary
+        // logarithm. The logarithm is undefined for 0, so we trivially special case a zero
+        // scalar, even though two same-depth indices meet at the root's scalar of `1` at latest.
+        if self_scalar == 0 {
+            return NodeIndex::root();
+        }
+
+        let depth = {
+            let depth = u128::ilog2(self_scalar);
+            // The scalar index can never exceed `2^64 + (2^64 - 1)` (depth 64 with the maximum
+            // `u64` row-value), and the binary logarithm of that is 64, which fits in a `u8`. In
+            // other words, this assert should only be possible to fail if `to_scalar_index()` is
+            // wildly incorrect.
+            debug_assert!(depth <= u8::MAX as u32);
+            depth as u8
+        };
+
+        let mut lca = self.index;
+        lca.move_up_to(depth);
+        lca
+    }
 }
 
 impl LeafIndex<SMT_DEPTH> {
@@ -424,6 +471,26 @@ pub(crate) enum NodeMutation {
     Addition(InnerNode),
 }
 
+impl NodeMutation {
+    #[allow(dead_code)]
+    pub fn into_inner_node(self, tree_depth: u8, node_depth: u8) -> InnerNode {
+        use NodeMutation::*;
+        match self {
+            Addition(node) => node,
+            Removal => EmptySubtreeRoots::get_inner_node(tree_depth, node_depth),
+        }
+    }
+
+    #[allow(dead_code)]
+    pub fn as_hash(&self, tree_depth: u8, node_depth: u8) -> RpoDigest {
+        use NodeMutation::*;
+        match self {
+            Addition(node) => node.hash(),
+            Removal => *EmptySubtreeRoots::entry(tree_depth, node_depth),
+        }
+    }
+}
+
 /// Represents a group of prospective mutations to a `SparseMerkleTree`, created by
 /// `SparseMerkleTree::compute_mutations()`, and that can be applied with
 /// `SparseMerkleTree::apply_mutations()`.
@@ -456,3 +523,39 @@ impl MutationSet {
         self.new_root
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use proptest::prelude::*;
+
+    use crate::merkle::{LeafIndex, NodeIndex, SMT_DEPTH};
+
+    prop_compose! {
+        fn leaf_index()(value in 0..2u64.pow(u64::BITS - 1)) -> LeafIndex<SMT_DEPTH> {
+            LeafIndex::new(value).unwrap()
+        }
+    }
+
+    proptest! {
+        /// Tests that the O(log n) algorithm has the same results as the naïve version.
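+        ///
+        /// The control result walks both indices up one level at a time until they coincide.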
+ #[test] + fn test_leaf_lca(left in leaf_index(), right in leaf_index()) { + let control: NodeIndex = { + let mut left = left.index; + let mut right = right.index; + + loop { + if left == right { + break left; + } + left.move_up(); + right.move_up(); + } + }; + + let actual: NodeIndex = left.lca(&right); + + assert_eq!(actual, control); + } + } +} diff --git a/src/merkle/smt/parallel.rs b/src/merkle/smt/parallel.rs new file mode 100644 index 00000000..6d1e41ec --- /dev/null +++ b/src/merkle/smt/parallel.rs @@ -0,0 +1,62 @@ +use std::{ + cmp::Ordering, + collections::BTreeMap, + sync::{Arc, LazyLock}, + thread, +}; + +use crate::merkle::smt::{InnerNode, MutationSet, NodeIndex, SparseMerkleTree}; + +static TASK_COUNT: LazyLock = LazyLock::new(|| { + // FIXME: error handling? + thread::available_parallelism().unwrap().get() +}); + +#[allow(dead_code)] +pub(crate) trait ParallelSparseMerkleTree +where + // Note: these type bounds need to be specified this way or we'll have to duplicate them + // everywhere. + // https://github.com/rust-lang/rust/issues/130805. + Self: SparseMerkleTree< + DEPTH, + Key: Send + Sync + 'static, + Value: Send + Sync + 'static, + Leaf: Send + Sync + 'static, + >, +{ + /// Shortcut for [`ParallelSparseMerkleTree::compute_mutations_parallel_n()`] with an + /// automatically determined number of tasks. + /// + /// Currently, the default number of tasks is the return value of + /// [`std::thread::available_parallelism()`], but this may be subject to change in the future. + async fn compute_mutations_parallel( + &self, + kv_pairs: I, + ) -> MutationSet + where + I: IntoIterator, + { + self.compute_mutations_parallel_n(kv_pairs, *TASK_COUNT).await + } + + async fn compute_mutations_parallel_n( + &self, + _kv_pairs: I, + _tasks: usize, + ) -> MutationSet + where + I: IntoIterator, + { + todo!(); + } + + fn get_inner_nodes(&self) -> Arc>; + fn get_leaves(&self) -> Arc>; + fn get_leaf_value(_leaf: &Self::Leaf, _key: &Self::Key) -> Option { + todo!(); + } + fn cmp_keys(_lhs: &Self::Key, _rhs: &Self::Key) -> Ordering { + todo!(); + } +} diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index 17444300..49231821 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -1,4 +1,6 @@ use alloc::collections::{BTreeMap, BTreeSet}; +#[cfg(feature = "async")] +use std::sync::Arc; use super::{ super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, @@ -20,7 +22,10 @@ mod tests; pub struct SimpleSmt { root: RpoDigest, leaves: BTreeMap, + #[cfg(not(feature = "async"))] inner_nodes: BTreeMap, + #[cfg(feature = "async")] + inner_nodes: Arc>, } impl SimpleSmt { @@ -52,7 +57,7 @@ impl SimpleSmt { Ok(Self { root, leaves: BTreeMap::new(), - inner_nodes: BTreeMap::new(), + inner_nodes: Default::default(), }) } @@ -175,6 +180,23 @@ impl SimpleSmt { }) } + /// Gets a mutable reference to this structure's inner node mapping. + /// + /// # Panics + /// This will panic if we have violated our own invariants and try to mutate these nodes while + /// Self::compute_mutations_parallel() is still running. 
+ fn inner_nodes_mut(&mut self) -> &mut BTreeMap { + #[cfg(feature = "async")] + { + Arc::get_mut(&mut self.inner_nodes).unwrap() + } + + #[cfg(not(feature = "async"))] + { + &mut self.inner_nodes + } + } + // STATE MUTATORS // -------------------------------------------------------------------------------------------- @@ -271,7 +293,16 @@ impl SimpleSmt { // add subtree's branch nodes (which includes the root) // -------------- - for (branch_idx, branch_node) in subtree.inner_nodes { + let subtree_nodes; + #[cfg(feature = "async")] + { + subtree_nodes = Arc::into_inner(subtree.inner_nodes).unwrap(); + } + #[cfg(not(feature = "async"))] + { + subtree_nodes = subtree.inner_nodes + } + for (branch_idx, branch_node) in subtree_nodes { let new_branch_idx = { let new_depth = subtree_root_insertion_depth + branch_idx.depth(); let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into()) @@ -280,7 +311,7 @@ impl SimpleSmt { NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid") }; - self.inner_nodes.insert(new_branch_idx, branch_node); + self.inner_nodes_mut().insert(new_branch_idx, branch_node); } // recompute nodes starting from subtree root @@ -315,11 +346,11 @@ impl SparseMerkleTree for SimpleSmt { } fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) { - self.inner_nodes.insert(index, inner_node); + self.inner_nodes_mut().insert(index, inner_node); } fn remove_inner_node(&mut self, index: NodeIndex) { - let _ = self.inner_nodes.remove(&index); + let _ = self.inner_nodes_mut().remove(&index); } fn insert_value(&mut self, key: LeafIndex, value: Word) -> Option { @@ -364,3 +395,19 @@ impl SparseMerkleTree for SimpleSmt { (path, leaf).into() } } + +//#[cfg(feature = "async")] +////impl super::ParallelSparseMerkleTree, Word, Word> +//impl super::ParallelSparseMerkleTree for SimpleSmt { +// fn get_inner_nodes(&self) -> Arc> { +// Arc::clone(&self.inner_nodes) +// } +// +// fn get_leaf_value(leaf: &Word, _key: &LeafIndex) -> Option { +// Some(*leaf) +// } +// +// fn cmp_keys(lhs: &LeafIndex, rhs: &LeafIndex) -> Ordering { +// LeafIndex::cmp(lhs, rhs) +// } +//}