diff --git a/Cargo.toml b/Cargo.toml index 4046e77..c066f2a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,8 @@ alloy-primitives = { version = "0.8.5", default-features = false, features = [ alloy-rlp = { version = "0.3.8", default-features = false, features = [ "derive", ] } + +arrayvec = { version = "0.7", default-features = false } derive_more = { version = "1", default-features = false, features = [ "add", "add_assign", @@ -62,12 +64,18 @@ default = ["std", "alloy-primitives/default"] std = [ "alloy-primitives/std", "alloy-rlp/std", + "arrayvec/std", "derive_more/std", "nybbles/std", "tracing/std", "serde?/std", ] -serde = ["dep:serde", "alloy-primitives/serde", "nybbles/serde"] +serde = [ + "dep:serde", + "alloy-primitives/serde", + "arrayvec/serde", + "nybbles/serde", +] arbitrary = [ "std", "dep:arbitrary", diff --git a/src/hash_builder/mod.rs b/src/hash_builder/mod.rs index 776de6b..b842fb0 100644 --- a/src/hash_builder/mod.rs +++ b/src/hash_builder/mod.rs @@ -1,11 +1,11 @@ //! The implementation of the hash builder. use super::{ - nodes::{word_rlp, BranchNodeRef, ExtensionNodeRef, LeafNodeRef}, + nodes::{BranchNodeRef, ExtensionNodeRef, LeafNodeRef}, proof::ProofRetainer, BranchNodeCompact, Nibbles, TrieMask, EMPTY_ROOT_HASH, }; -use crate::{proof::ProofNodes, HashMap}; +use crate::{nodes::RlpNode, proof::ProofNodes, HashMap}; use alloy_primitives::{hex, keccak256, B256}; use alloy_rlp::EMPTY_STRING_CODE; use core::cmp; @@ -45,7 +45,7 @@ pub use value::HashBuilderValue; #[allow(missing_docs)] pub struct HashBuilder { pub key: Nibbles, - pub stack: Vec>, + pub stack: Vec, pub value: HashBuilderValue, pub groups: Vec, @@ -131,7 +131,7 @@ impl HashBuilder { if !self.key.is_empty() { self.update(&key); } else if key.is_empty() { - self.stack.push(word_rlp(&value)); + self.stack.push(RlpNode::word_rlp(&value)); } self.set_key_value(key, value); self.stored_in_database = stored_in_database; @@ -250,7 +250,7 @@ impl HashBuilder { } HashBuilderValue::Hash(hash) => { trace!(target: "trie::hash_builder", ?hash, "pushing branch node hash"); - self.stack.push(word_rlp(hash)); + self.stack.push(RlpNode::word_rlp(hash)); if self.stored_in_database { self.tree_masks[current.len() - 1] |= @@ -266,17 +266,17 @@ impl HashBuilder { if build_extensions && !short_node_key.is_empty() { self.update_masks(¤t, len_from); - let stack_last = - self.stack.pop().expect("there should be at least one stack item; qed"); + let stack_last = self.stack.pop().expect("there should be at least one stack item"); let extension_node = ExtensionNodeRef::new(&short_node_key, &stack_last); - trace!(target: "trie::hash_builder", ?extension_node, "pushing extension node"); - trace!(target: "trie::hash_builder", rlp = { - self.rlp_buf.clear(); - hex::encode(extension_node.rlp(&mut self.rlp_buf)) - }, "extension node rlp"); self.rlp_buf.clear(); - self.stack.push(extension_node.rlp(&mut self.rlp_buf)); + let rlp = extension_node.rlp(&mut self.rlp_buf); + trace!(target: "trie::hash_builder", + ?extension_node, + ?rlp, + "pushing extension node", + ); + self.stack.push(rlp); self.retain_proof_from_buf(¤t.slice(..len_from)); self.resize_masks(len_from); } @@ -325,7 +325,7 @@ impl HashBuilder { fn push_branch_node(&mut self, current: &Nibbles, len: usize) -> Vec { let state_mask = self.groups[len]; let hash_mask = self.hash_masks[len]; - let branch_node = BranchNodeRef::new(&self.stack, &state_mask); + let branch_node = BranchNodeRef::new(&self.stack, state_mask); // Avoid calculating this value if it's not needed. let children = if self.updated_branch_nodes.is_some() { branch_node.child_hashes(hash_mask).collect() @@ -345,10 +345,9 @@ impl HashBuilder { old_len = self.stack.len(), "resizing stack to prepare branch node" ); - self.stack.resize(first_child_idx, vec![]); + self.stack.resize_with(first_child_idx, Default::default); - trace!(target: "trie::hash_builder", "pushing branch node with {:?} mask from stack", state_mask); - trace!(target: "trie::hash_builder", rlp = hex::encode(&rlp), "branch node rlp"); + trace!(target: "trie::hash_builder", ?rlp, "pushing branch node with {state_mask:?} mask from stack"); self.stack.push(rlp); children } @@ -570,8 +569,8 @@ mod tests { #[test] fn manual_branch_node_ok() { let raw_input = vec![ - (hex!("646f").to_vec(), hex!("76657262").to_vec()), - (hex!("676f6f64").to_vec(), hex!("7075707079").to_vec()), + (hex!("646f").to_vec(), RlpNode::from_raw(&hex!("76657262")).unwrap()), + (hex!("676f6f64").to_vec(), RlpNode::from_raw(&hex!("7075707079")).unwrap()), ]; let expected = triehash_trie_root(raw_input.clone()); diff --git a/src/nodes/branch.rs b/src/nodes/branch.rs index cf4e874..6e8f719 100644 --- a/src/nodes/branch.rs +++ b/src/nodes/branch.rs @@ -1,4 +1,4 @@ -use super::{super::TrieMask, rlp_node, CHILD_INDEX_RANGE}; +use super::{super::TrieMask, RlpNode, CHILD_INDEX_RANGE}; use alloy_primitives::{hex, B256}; use alloy_rlp::{length_of_length, Buf, BufMut, Decodable, Encodable, Header, EMPTY_STRING_CODE}; use core::{fmt, ops::Range, slice::Iter}; @@ -14,7 +14,7 @@ use alloc::vec::Vec; #[derive(PartialEq, Eq, Clone, Default)] pub struct BranchNode { /// The collection of RLP encoded children. - pub stack: Vec>, + pub stack: Vec, /// The bitmask indicating the presence of children at the respective nibble positions pub state_mask: TrieMask, } @@ -61,7 +61,7 @@ impl Decodable for BranchNode { // Decode without advancing let Header { payload_length, .. } = Header::decode(&mut &bytes[..])?; let len = payload_length + length_of_length(payload_length); - stack.push(Vec::from(&bytes[..len])); + stack.push(RlpNode::from_raw_rlp(&bytes[..len])?); bytes.advance(len); state_mask.set_bit(index); } @@ -79,13 +79,13 @@ impl Decodable for BranchNode { impl BranchNode { /// Creates a new branch node with the given stack and state mask. - pub const fn new(stack: Vec>, state_mask: TrieMask) -> Self { + pub const fn new(stack: Vec, state_mask: TrieMask) -> Self { Self { stack, state_mask } } /// Return branch node as [BranchNodeRef]. pub fn as_ref(&self) -> BranchNodeRef<'_> { - BranchNodeRef::new(&self.stack, &self.state_mask) + BranchNodeRef::new(&self.stack, self.state_mask) } } @@ -97,10 +97,10 @@ pub struct BranchNodeRef<'a> { /// NOTE: The referenced stack might have more items than the number of children /// for this node. We should only ever access items starting from /// [BranchNodeRef::first_child_index]. - pub stack: &'a [Vec], + pub stack: &'a [RlpNode], /// Reference to bitmask indicating the presence of children at /// the respective nibble positions. - pub state_mask: &'a TrieMask, + pub state_mask: TrieMask, } impl fmt::Debug for BranchNodeRef<'_> { @@ -122,12 +122,9 @@ impl Encodable for BranchNodeRef<'_> { Header { list: true, payload_length: self.rlp_payload_length() }.encode(out); // Extend the RLP buffer with the present children - let mut stack_ptr = self.first_child_index(); - for index in CHILD_INDEX_RANGE { - if self.state_mask.is_bit_set(index) { - out.put_slice(&self.stack[stack_ptr]); - // Advance the pointer to the next child. - stack_ptr += 1; + for (_, child) in self.children() { + if let Some(child) = child { + out.put_slice(child); } else { out.put_u8(EMPTY_STRING_CODE); } @@ -145,7 +142,8 @@ impl Encodable for BranchNodeRef<'_> { impl<'a> BranchNodeRef<'a> { /// Create a new branch node from the stack of nodes. - pub const fn new(stack: &'a [Vec], state_mask: &'a TrieMask) -> Self { + #[inline] + pub const fn new(stack: &'a [RlpNode], state_mask: TrieMask) -> Self { Self { stack, state_mask } } @@ -155,34 +153,40 @@ impl<'a> BranchNodeRef<'a> { /// /// If the stack length is less than number of children specified in state mask. /// Means that the node is in inconsistent state. + #[inline] pub fn first_child_index(&self) -> usize { self.stack.len().checked_sub(self.state_mask.count_ones() as usize).unwrap() } + #[inline] + fn children(&self) -> impl Iterator)> + '_ { + BranchChildrenIter::new(self) + } + /// Given the hash mask of children, return an iterator over stack items /// that match the mask. + #[inline] pub fn child_hashes(&self, hash_mask: TrieMask) -> impl Iterator + '_ { - BranchChildrenIter::new(self) + self.children() + .filter_map(|(i, c)| c.map(|c| (i, c))) .filter(move |(index, _)| hash_mask.is_bit_set(*index)) .map(|(_, child)| B256::from_slice(&child[1..])) } - /// Returns the RLP encoding of the branch node given the state mask of children present. - pub fn rlp(&self, out: &mut Vec) -> Vec { - self.encode(out); - rlp_node(out) + /// RLP-encodes the node and returns either `rlp(node)` or `rlp(keccak(rlp(node)))`. + #[inline] + pub fn rlp(&self, rlp: &mut Vec) -> RlpNode { + self.encode(rlp); + RlpNode::from_rlp(rlp) } /// Returns the length of RLP encoded fields of branch node. + #[inline] fn rlp_payload_length(&self) -> usize { let mut payload_length = 1; - - let mut stack_ptr = self.first_child_index(); - for digit in CHILD_INDEX_RANGE { - if self.state_mask.is_bit_set(digit) { - payload_length += self.stack[stack_ptr].len(); - // Advance the pointer to the next child. - stack_ptr += 1; + for (_, child) in self.children() { + if let Some(child) = child { + payload_length += child.len(); } else { payload_length += 1; } @@ -195,8 +199,8 @@ impl<'a> BranchNodeRef<'a> { #[derive(Debug)] struct BranchChildrenIter<'a> { range: Range, - state_mask: &'a TrieMask, - stack_iter: Iter<'a, Vec>, + state_mask: TrieMask, + stack_iter: Iter<'a, RlpNode>, } impl<'a> BranchChildrenIter<'a> { @@ -211,15 +215,34 @@ impl<'a> BranchChildrenIter<'a> { } impl<'a> Iterator for BranchChildrenIter<'a> { - type Item = (u8, &'a [u8]); + type Item = (u8, Option<&'a RlpNode>); + #[inline] fn next(&mut self) -> Option { - loop { - let current = self.range.next()?; - if self.state_mask.is_bit_set(current) { - return Some((current, self.stack_iter.next()?)); - } - } + let i = self.range.next()?; + let value = if self.state_mask.is_bit_set(i) { + // SAFETY: `first_child_index` guarantees that `stack` is exactly + // `state_mask.count_ones()` long. + Some(unsafe { self.stack_iter.next().unwrap_unchecked() }) + } else { + None + }; + Some((i, value)) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + (len, Some(len)) + } +} + +impl core::iter::FusedIterator for BranchChildrenIter<'_> {} + +impl ExactSizeIterator for BranchChildrenIter<'_> { + #[inline] + fn len(&self) -> usize { + self.range.len() } } @@ -292,7 +315,7 @@ impl BranchNodeCompact { #[cfg(test)] mod tests { use super::*; - use crate::nodes::{word_rlp, ExtensionNode, LeafNode}; + use crate::nodes::{ExtensionNode, LeafNode}; use nybbles::Nibbles; #[test] @@ -302,13 +325,19 @@ mod tests { assert_eq!(BranchNode::decode(&mut &encoded[..]).unwrap(), empty); let sparse_node = BranchNode::new( - vec![word_rlp(&B256::repeat_byte(1)), word_rlp(&B256::repeat_byte(2))], + vec![ + RlpNode::word_rlp(&B256::repeat_byte(1)), + RlpNode::word_rlp(&B256::repeat_byte(2)), + ], TrieMask::new(0b1000100), ); let encoded = alloy_rlp::encode(&sparse_node); assert_eq!(BranchNode::decode(&mut &encoded[..]).unwrap(), sparse_node); - let leaf_child = LeafNode::new(Nibbles::from_nibbles(hex!("0203")), hex!("1234").to_vec()); + let leaf_child = LeafNode::new( + Nibbles::from_nibbles(hex!("0203")), + RlpNode::from_raw(&hex!("1234")).unwrap(), + ); let mut buf = vec![]; let leaf_rlp = leaf_child.as_ref().rlp(&mut buf); let branch_with_leaf = BranchNode::new(vec![leaf_rlp.clone()], TrieMask::new(0b0010)); @@ -323,7 +352,7 @@ mod tests { assert_eq!(BranchNode::decode(&mut &encoded[..]).unwrap(), branch_with_ext); let full = BranchNode::new( - core::iter::repeat(word_rlp(&B256::repeat_byte(23))).take(16).collect(), + core::iter::repeat(RlpNode::word_rlp(&B256::repeat_byte(23))).take(16).collect(), TrieMask::new(u16::MAX), ); let encoded = alloy_rlp::encode(&full); diff --git a/src/nodes/extension.rs b/src/nodes/extension.rs index d6aa958..b7177b0 100644 --- a/src/nodes/extension.rs +++ b/src/nodes/extension.rs @@ -1,4 +1,4 @@ -use super::{super::Nibbles, rlp_node, unpack_path_to_nibbles}; +use super::{super::Nibbles, unpack_path_to_nibbles, RlpNode}; use alloy_primitives::{hex, Bytes}; use alloy_rlp::{length_of_length, BufMut, Decodable, Encodable, Header}; use core::fmt; @@ -20,7 +20,7 @@ pub struct ExtensionNode { /// The key for this extension node. pub key: Nibbles, /// A pointer to the child node. - pub child: Vec, + pub child: RlpNode, } impl fmt::Debug for ExtensionNode { @@ -60,7 +60,7 @@ impl Decodable for ExtensionNode { }; let key = unpack_path_to_nibbles(first, &encoded_key[1..]); - let child = Vec::from(bytes); + let child = RlpNode::from_raw_rlp(bytes)?; Ok(Self { key, child }) } } @@ -73,7 +73,7 @@ impl ExtensionNode { pub const ODD_FLAG: u8 = 0x10; /// Creates a new extension node with the given key and a pointer to the child. - pub const fn new(key: Nibbles, child: Vec) -> Self { + pub const fn new(key: Nibbles, child: RlpNode) -> Self { Self { key, child } } @@ -118,17 +118,20 @@ impl Encodable for ExtensionNodeRef<'_> { impl<'a> ExtensionNodeRef<'a> { /// Creates a new extension node with the given key and a pointer to the child. + #[inline] pub const fn new(key: &'a Nibbles, child: &'a [u8]) -> Self { Self { key, child } } - /// RLP encodes the node and returns either RLP(Node) or RLP(keccak(RLP(node))). - pub fn rlp(&self, buf: &mut Vec) -> Vec { - self.encode(buf); - rlp_node(buf) + /// RLP-encodes the node and returns either `rlp(node)` or `rlp(keccak(rlp(node)))`. + #[inline] + pub fn rlp(&self, rlp: &mut Vec) -> RlpNode { + self.encode(rlp); + RlpNode::from_rlp(rlp) } /// Returns the length of RLP encoded fields of extension node. + #[inline] fn rlp_payload_length(&self) -> usize { let mut encoded_key_len = self.key.len() / 2 + 1; // For extension nodes the first byte cannot be greater than 0x80. @@ -149,9 +152,9 @@ mod tests { let val = hex!("76657262"); let mut child = vec![]; val.to_vec().as_slice().encode(&mut child); - let extension = ExtensionNode::new(nibble, child); + let extension = ExtensionNode::new(nibble, RlpNode::from_raw(&child).unwrap()); let rlp = extension.as_ref().rlp(&mut vec![]); - assert_eq!(rlp, hex!("c98300646f8476657262")); + assert_eq!(rlp.as_ref(), hex!("c98300646f8476657262")); assert_eq!(ExtensionNode::decode(&mut &rlp[..]).unwrap(), extension); } } diff --git a/src/nodes/leaf.rs b/src/nodes/leaf.rs index b41da51..860681b 100644 --- a/src/nodes/leaf.rs +++ b/src/nodes/leaf.rs @@ -1,4 +1,4 @@ -use super::{super::Nibbles, rlp_node, unpack_path_to_nibbles}; +use super::{super::Nibbles, unpack_path_to_nibbles, RlpNode}; use alloy_primitives::{hex, Bytes}; use alloy_rlp::{length_of_length, BufMut, Decodable, Encodable, Header}; use core::fmt; @@ -18,7 +18,7 @@ pub struct LeafNode { /// The key for this leaf node. pub key: Nibbles, /// The node value. - pub value: Vec, + pub value: RlpNode, } impl fmt::Debug for LeafNode { @@ -58,7 +58,7 @@ impl Decodable for LeafNode { }; let key = unpack_path_to_nibbles(first, &encoded_key[1..]); - let value = Bytes::decode(&mut bytes)?.to_vec(); + let value = RlpNode::decode(&mut bytes)?; Ok(Self { key, value }) } } @@ -71,7 +71,7 @@ impl LeafNode { pub const ODD_FLAG: u8 = 0x30; /// Creates a new leaf node with the given key and value. - pub const fn new(key: Nibbles, value: Vec) -> Self { + pub const fn new(key: Nibbles, value: RlpNode) -> Self { Self { key, value } } @@ -120,14 +120,15 @@ impl<'a> LeafNodeRef<'a> { Self { key, value } } - /// RLP encodes the node and returns either RLP(Node) or RLP(keccak(RLP(node))) - /// depending on if the serialized node was longer than a keccak). - pub fn rlp(&self, out: &mut Vec) -> Vec { - self.encode(out); - rlp_node(out) + /// RLP-encodes the node and returns either `rlp(node)` or `rlp(keccak(rlp(node)))`. + #[inline] + pub fn rlp(&self, rlp: &mut Vec) -> RlpNode { + self.encode(rlp); + RlpNode::from_rlp(rlp) } /// Returns the length of RLP encoded fields of leaf node. + #[inline] fn rlp_payload_length(&self) -> usize { let mut encoded_key_len = self.key.len() / 2 + 1; // For leaf nodes the first byte cannot be greater than 0x80. @@ -154,9 +155,9 @@ mod tests { fn rlp_leaf_node_roundtrip() { let nibble = Nibbles::from_nibbles_unchecked(hex!("0604060f")); let val = hex!("76657262"); - let leaf = LeafNode::new(nibble, val.to_vec()); + let leaf = LeafNode::new(nibble, RlpNode::from_raw(&val).unwrap()); let rlp = leaf.as_ref().rlp(&mut vec![]); - assert_eq!(rlp, hex!("c98320646f8476657262")); + assert_eq!(rlp.as_ref(), hex!("c98320646f8476657262")); assert_eq!(LeafNode::decode(&mut &rlp[..]).unwrap(), leaf); } } diff --git a/src/nodes/mod.rs b/src/nodes/mod.rs index ae76ca2..6133bf1 100644 --- a/src/nodes/mod.rs +++ b/src/nodes/mod.rs @@ -1,6 +1,6 @@ //! Various branch nodes produced by the hash builder. -use alloy_primitives::{keccak256, Bytes, B256}; +use alloy_primitives::B256; use alloy_rlp::{Decodable, Encodable, Header, EMPTY_STRING_CODE}; use core::ops::Range; use nybbles::Nibbles; @@ -18,6 +18,9 @@ pub use extension::{ExtensionNode, ExtensionNodeRef}; mod leaf; pub use leaf::{LeafNode, LeafNodeRef}; +mod rlp; +pub use rlp::RlpNode; + /// The range of valid child indexes. pub const CHILD_INDEX_RANGE: Range = 0..16; @@ -84,7 +87,7 @@ impl Decodable for TrieNode { )); } } else if item != [EMPTY_STRING_CODE] { - branch.stack.push(item.to_vec()); + branch.stack.push(RlpNode::from_raw_rlp(item)?); branch.state_mask.set_bit(idx as u8); } } @@ -109,10 +112,13 @@ impl Decodable for TrieNode { let key = unpack_path_to_nibbles(first, &encoded_key[1..]); let node = if key_flag == LeafNode::EVEN_FLAG || key_flag == LeafNode::ODD_FLAG { - Self::Leaf(LeafNode::new(key, Bytes::decode(&mut items.remove(0))?.to_vec())) + Self::Leaf(LeafNode::new(key, RlpNode::decode(&mut items.remove(0))?)) } else { // We don't decode value because it is expected to be RLP encoded. - Self::Extension(ExtensionNode::new(key, items.remove(0).to_vec())) + Self::Extension(ExtensionNode::new( + key, + RlpNode::from_raw_rlp(items.remove(0))?, + )) }; Ok(node) } @@ -122,32 +128,26 @@ impl Decodable for TrieNode { } impl TrieNode { - /// RLP encodes the node and returns either RLP(Node) or RLP(keccak(RLP(node))). - pub fn rlp(&self, buf: &mut Vec) -> Vec { - self.encode(buf); - rlp_node(buf) + /// RLP-encodes the node and returns either `rlp(node)` or `rlp(keccak(rlp(node)))`. + #[inline] + pub fn rlp(&self, rlp: &mut Vec) -> RlpNode { + self.encode(rlp); + RlpNode::from_rlp(rlp) } } -/// Given an RLP encoded node, returns either self as RLP(node) or RLP(keccak(RLP(node))) +/// Given an RLP-encoded node, returns it either as `rlp(node)` or `rlp(keccak(rlp(node)))`. #[inline] -pub fn rlp_node(rlp: &[u8]) -> Vec { - if rlp.len() < B256::len_bytes() { - rlp.to_vec() - } else { - word_rlp(&keccak256(rlp)) - } +#[deprecated = "use `RlpNode::from_rlp` instead"] +pub fn rlp_node(rlp: &[u8]) -> RlpNode { + RlpNode::from_rlp(rlp) } -/// Optimization for quick encoding of a 32-byte word as RLP. -// TODO: this could return [u8; 33] but Vec is needed everywhere this function is used +/// Optimization for quick RLP-encoding of a 32-byte word. #[inline] -pub fn word_rlp(word: &B256) -> Vec { - // Gets optimized to alloc + write directly into it: https://godbolt.org/z/rfWGG6ebq - let mut arr = [0; 33]; - arr[0] = EMPTY_STRING_CODE + 32; - arr[1..].copy_from_slice(word.as_slice()); - arr.to_vec() +#[deprecated = "use `RlpNode::word_rlp` instead"] +pub fn word_rlp(word: &B256) -> RlpNode { + RlpNode::word_rlp(word) } /// Unpack node path to nibbles. @@ -262,7 +262,7 @@ mod tests { fn rlp_empty_root_node() { let empty_root = TrieNode::EmptyRoot; let rlp = empty_root.rlp(&mut vec![]); - assert_eq!(rlp, hex!("80")); + assert_eq!(rlp[..], hex!("80")); assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), empty_root); } @@ -270,10 +270,10 @@ mod tests { fn rlp_zero_value_leaf_roundtrip() { let leaf = TrieNode::Leaf(LeafNode::new( Nibbles::from_nibbles_unchecked(hex!("0604060f")), - alloy_rlp::encode(alloy_primitives::U256::ZERO), + RlpNode::from_raw(&alloy_rlp::encode(alloy_primitives::U256::ZERO)).unwrap(), )); let rlp = leaf.rlp(&mut vec![]); - assert_eq!(rlp, hex!("c68320646f8180")); + assert_eq!(rlp[..], hex!("c68320646f8180")); assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), leaf); } @@ -282,10 +282,10 @@ mod tests { // leaf let leaf = TrieNode::Leaf(LeafNode::new( Nibbles::from_nibbles_unchecked(hex!("0604060f")), - hex!("76657262").to_vec(), + RlpNode::from_raw(&hex!("76657262")).unwrap(), )); let rlp = leaf.rlp(&mut vec![]); - assert_eq!(rlp, hex!("c98320646f8476657262")); + assert_eq!(rlp[..], hex!("c98320646f8476657262")); assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), leaf); // extension @@ -293,21 +293,21 @@ mod tests { hex!("76657262").to_vec().as_slice().encode(&mut child); let extension = TrieNode::Extension(ExtensionNode::new( Nibbles::from_nibbles_unchecked(hex!("0604060f")), - child, + RlpNode::from_raw(&child).unwrap(), )); let rlp = extension.rlp(&mut vec![]); - assert_eq!(rlp, hex!("c98300646f8476657262")); + assert_eq!(rlp[..], hex!("c98300646f8476657262")); assert_eq!(TrieNode::decode(&mut &rlp[..]).unwrap(), extension); // branch let branch = TrieNode::Branch(BranchNode::new( - core::iter::repeat(word_rlp(&B256::repeat_byte(23))).take(16).collect(), + core::iter::repeat(RlpNode::word_rlp(&B256::repeat_byte(23))).take(16).collect(), TrieMask::new(u16::MAX), )); let mut rlp = vec![]; let rlp_node = branch.rlp(&mut rlp); assert_eq!( - rlp_node, + rlp_node[..], hex!("a0bed74980bbe29d9c4439c10e9c451e29b306fe74bcf9795ecf0ebbd92a220513") ); assert_eq!(rlp, hex!("f90211a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a01717171717171717171717171717171717171717171717171717171717171717a0171717171717171717171717171717171717171717171717171717171717171780")); diff --git a/src/nodes/rlp.rs b/src/nodes/rlp.rs new file mode 100644 index 0000000..6f54531 --- /dev/null +++ b/src/nodes/rlp.rs @@ -0,0 +1,115 @@ +use alloy_primitives::{hex, keccak256, B256}; +use alloy_rlp::EMPTY_STRING_CODE; +use arrayvec::ArrayVec; +use core::fmt; + +const MAX: usize = 33; + +/// An RLP-encoded node. +#[derive(Clone, Default, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct RlpNode(ArrayVec); + +impl alloy_rlp::Decodable for RlpNode { + fn decode(buf: &mut &[u8]) -> alloy_rlp::Result { + let bytes = alloy_rlp::Header::decode_bytes(buf, false)?; + Self::from_raw_rlp(bytes) + } +} + +impl core::ops::Deref for RlpNode { + type Target = [u8]; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl core::ops::DerefMut for RlpNode { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl AsRef<[u8]> for RlpNode { + #[inline] + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl fmt::Debug for RlpNode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "RlpNode({})", hex::encode_prefixed(&self.0)) + } +} + +impl RlpNode { + /// Creates a new RLP-encoded node from the given data. + /// + /// Returns `None` if the data is too large (greater than 33 bytes). + #[inline] + pub fn from_raw(data: &[u8]) -> Option { + let mut arr = ArrayVec::new(); + arr.try_extend_from_slice(data).ok()?; + Some(Self(arr)) + } + + /// Creates a new RLP-encoded node from the given data. + #[inline] + pub fn from_raw_rlp(data: &[u8]) -> alloy_rlp::Result { + Self::from_raw(data).ok_or(alloy_rlp::Error::Custom("RLP node too large")) + } + + /// Given an RLP-encoded node, returns it either as `rlp(node)` or `rlp(keccak(rlp(node)))`. + #[doc(alias = "rlp_node")] + #[inline] + pub fn from_rlp(rlp: &[u8]) -> Self { + if rlp.len() < 32 { + // SAFETY: `rlp` is less than max capacity (33). + unsafe { Self::from_raw(rlp).unwrap_unchecked() } + } else { + Self::word_rlp(&keccak256(rlp)) + } + } + + /// RLP-encodes the given word and returns it as a new RLP node. + #[inline] + pub fn word_rlp(word: &B256) -> Self { + let mut arr = ArrayVec::new(); + arr.push(EMPTY_STRING_CODE + 32); + arr.try_extend_from_slice(word.as_slice()).unwrap(); + Self(arr) + } + + /// Returns the RLP-encoded node as a slice. + #[inline] + pub fn as_slice(&self) -> &[u8] { + &self.0 + } +} + +#[cfg(feature = "arbitrary")] +impl<'u> arbitrary::Arbitrary<'u> for RlpNode { + fn arbitrary(g: &mut arbitrary::Unstructured<'u>) -> arbitrary::Result { + let len = g.int_in_range(0..=MAX)?; + let mut arr = ArrayVec::new(); + arr.try_extend_from_slice(g.bytes(len)?).unwrap(); + Ok(Self(arr)) + } +} + +#[cfg(feature = "arbitrary")] +impl proptest::arbitrary::Arbitrary for RlpNode { + type Parameters = (); + type Strategy = proptest::strategy::BoxedStrategy; + + fn arbitrary_with((): Self::Parameters) -> Self::Strategy { + use proptest::prelude::*; + proptest::collection::vec(proptest::prelude::any::(), 0..=MAX) + .prop_map(|vec| Self::from_raw(&vec).unwrap()) + .boxed() + } +} diff --git a/src/proof/verify.rs b/src/proof/verify.rs index 0de830d..fcabc49 100644 --- a/src/proof/verify.rs +++ b/src/proof/verify.rs @@ -1,7 +1,7 @@ //! Proof verification logic. use crate::{ - nodes::{rlp_node, word_rlp, BranchNode, TrieNode, CHILD_INDEX_RANGE}, + nodes::{BranchNode, RlpNode, TrieNode, CHILD_INDEX_RANGE}, proof::ProofVerificationError, EMPTY_ROOT_HASH, }; @@ -42,9 +42,9 @@ where } let mut walked_path = Nibbles::default(); - let mut next_value = Some(word_rlp(&root)); + let mut next_value = Some(RlpNode::word_rlp(&root)); for node in proof { - if Some(rlp_node(node)) != next_value { + if Some(RlpNode::from_rlp(node)) != next_value { let got = Some(Bytes::copy_from_slice(node)); let expected = next_value.map(|b| Bytes::copy_from_slice(&b)); return Err(ProofVerificationError::ValueMismatch { path: walked_path, got, expected }); @@ -65,12 +65,12 @@ where } next_value = next_value.filter(|_| walked_path == key); - if next_value == value { + if next_value.as_deref() == value.as_deref() { Ok(()) } else { Err(ProofVerificationError::ValueMismatch { path: key, - got: next_value.map(Bytes::from), + got: next_value.as_deref().map(Vec::from).map(Bytes::from), expected: value.map(Bytes::from), }) } @@ -81,7 +81,7 @@ fn process_branch( mut branch: BranchNode, walked_path: &mut Nibbles, key: &Nibbles, -) -> Result>, ProofVerificationError> { +) -> Result, ProofVerificationError> { if let Some(next) = key.get(walked_path.len()) { let mut stack_ptr = branch.as_ref().first_child_index(); for index in CHILD_INDEX_RANGE { @@ -184,7 +184,7 @@ mod tests { Err(ProofVerificationError::ValueMismatch { path: Nibbles::default(), got: Some(Bytes::from(dummy_proof)), - expected: Some(Bytes::from(word_rlp(&EMPTY_ROOT_HASH))) + expected: Some(Bytes::from(RlpNode::word_rlp(&EMPTY_ROOT_HASH)[..].to_vec())) }) ); } @@ -467,18 +467,19 @@ mod tests { let mut buffer = vec![]; - let child_leaf = TrieNode::Leaf(LeafNode::new(Nibbles::from_nibbles([0xa]), vec![0x64])); + let value = RlpNode::from_raw(&[0x64]).unwrap(); + let child_leaf = TrieNode::Leaf(LeafNode::new(Nibbles::from_nibbles([0xa]), value.clone())); let child_branch = TrieNode::Branch(BranchNode::new( vec![ { buffer.clear(); - TrieNode::Leaf(LeafNode::new(Nibbles::from_nibbles([0xa]), vec![0x64])) + TrieNode::Leaf(LeafNode::new(Nibbles::from_nibbles([0xa]), value.clone())) .rlp(&mut buffer) }, { buffer.clear(); - TrieNode::Leaf(LeafNode::new(Nibbles::from_nibbles([0xb]), vec![0x64])) + TrieNode::Leaf(LeafNode::new(Nibbles::from_nibbles([0xb]), value.clone())) .rlp(&mut buffer) }, ],