Skip to content

Commit

Permalink
fix: various tree fixes and unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
distractedm1nd committed Apr 29, 2024
1 parent 83b9abf commit f48f676
Show file tree
Hide file tree
Showing 2 changed files with 281 additions and 24 deletions.
117 changes: 117 additions & 0 deletions src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,3 +284,120 @@ impl Node {
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;

#[test]
fn test_leaf_node_creation() {
let active = true;
let is_left = false;
let label = "test_label".to_string();
let value = "test_value".to_string();
let next = "test_next".to_string();
let leaf = LeafNode::new(active, is_left, label.clone(), value.clone(), next.clone());

assert_eq!(leaf.active, active);
assert_eq!(leaf.is_left_sibling, is_left);
assert_eq!(leaf.label, label);
assert_eq!(leaf.value, value);
assert_eq!(leaf.next, next);
assert!(!leaf.hash.is_empty());
}

#[test]
fn test_inner_node_creation() {
let left_node = Node::default();
let right_node = Node::default();
let index = 0;
let inner_node = InnerNode::new(left_node.clone(), right_node.clone(), index);

let left_pointer = Arc::try_unwrap(inner_node.left).unwrap();
let right_pointer = Arc::try_unwrap(inner_node.right).unwrap();
assert_eq!(left_pointer.get_hash(), left_node.get_hash());
assert_eq!(right_pointer.get_hash(), right_pointer.get_hash());
assert_eq!(inner_node.is_left_sibling, true); // index 0 makes it left
assert!(!inner_node.hash.is_empty());
}

#[test]
fn test_node_default() {
let node = Node::default();
match node {
Node::Leaf(leaf) => {
assert_eq!(leaf.active, false);
assert_eq!(leaf.is_left_sibling, false);
assert_eq!(leaf.label, Node::EMPTY_HASH);
assert_eq!(leaf.value, Node::EMPTY_HASH);
assert_eq!(leaf.next, Node::TAIL);
}
_ => panic!("Default node is not a LeafNode"),
}
}

#[test]
fn test_node_is_active() {
let leaf_node = Node::new_leaf(
true,
false,
"label".to_string(),
"value".to_string(),
"next".to_string(),
);
assert!(leaf_node.is_active());

let inner_node = Node::new_inner(Node::default(), Node::default(), 1);
assert!(inner_node.is_active());
}

#[test]
fn test_node_set_left_sibling() {
let mut node = Node::default();
assert_eq!(node.is_left_sibling(), false);

node.set_left_sibling_value(true);
assert_eq!(node.is_left_sibling(), true);
}

#[test]
fn test_leaf_node_activation() {
let mut node = Node::default();
if let Node::Leaf(ref leaf) = node {
assert_eq!(leaf.active, false);
}

node.set_node_active();
if let Node::Leaf(leaf) = node {
assert!(leaf.active);
} else {
panic!("Node is not a LeafNode");
}
}

#[test]
fn test_update_next_pointer() {
let mut existing_node = Node::new_leaf(
false,
false,
"label1".to_string(),
"value1".to_string(),
"next1".to_string(),
);
let new_node = Node::new_leaf(
true,
true,
"label2".to_string(),
"value2".to_string(),
"next2".to_string(),
);

Node::update_next_pointer(&mut existing_node, &new_node);
if let Node::Leaf(leaf) = existing_node {
assert_eq!(leaf.next, "label2");
} else {
panic!("Existing node is not a LeafNode");
}
}
}
188 changes: 164 additions & 24 deletions src/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub struct NonMembershipProof {
pub merkle_proof: MerkleProof,
// Path from the leaf to the root.
pub missing_node: LeafNode,
pub missing_node_index: usize,
}

// `UpdateProof` contains the old `MerkleProof` and the new `MerkleProof` after the update operation
Expand Down Expand Up @@ -53,7 +54,7 @@ impl NonMembershipProof {
pub fn verify(&self) -> bool {
if let Some(Node::Leaf(leaf)) = self.merkle_proof.path.first() {
if self.merkle_proof.verify()
&& self.missing_node.label < leaf.label
&& self.missing_node.label > leaf.label
&& self.missing_node.label < leaf.next
{
return true;
Expand Down Expand Up @@ -109,9 +110,9 @@ impl MerkleProof {

for node in path.iter().skip(1) {
let hash = if node.is_left_sibling() {
format!("{} || {}", node.get_hash(), current_hash)
format!("{}{}", node.get_hash(), current_hash)
} else {
format!("{} || {}", current_hash, node.get_hash())
format!("{}{}", current_hash, node.get_hash())
};
current_hash = sha256(&hash);
}
Expand Down Expand Up @@ -224,9 +225,10 @@ impl IndexedMerkleTree {
///
/// This is done when first initializing the tree, as well as when nodes are updated.
fn rebuild_tree_from_leaves(&mut self) {
let leafcount = (self.nodes.len() + 1) / 2;
// let leafcount = (self.nodes.len() + 1) / 2;
// Will always be truncated so the default value doesnt matter
self.nodes.resize(leafcount, Node::default());
self.nodes.retain(|node| matches!(node, Node::Leaf(_)));
// self.nodes.resize(leafcount, Node::default());
self.rehash_inner_nodes(&self.nodes.clone());
}

Expand All @@ -245,7 +247,6 @@ impl IndexedMerkleTree {
///
/// * `Result<(), MerkleTreeError>` - A result indicating the success or failure of the operation.
fn calculate_root(&mut self) -> Result<(), MerkleTreeError> {
// self.rebuild_tree_from_leaves();
self.rebuild_tree_from_leaves();

// set root not as left sibling
Expand Down Expand Up @@ -338,12 +339,14 @@ impl IndexedMerkleTree {
/// effectively doubling its size. This is necessary when no inactive node is available for
/// the insertion of a new node. Each new node is marked inactive and initialized with
/// default values.
pub fn double_tree_size(&mut self) {
let current_size = self.nodes.len();
self.nodes.resize(current_size * 2 + 1, Node::default());
pub fn double_tree_size(&mut self) -> Result<(), MerkleTreeError> {
let current_size = (self.nodes.len() + 1) / 2;
self.nodes = self.nodes[0..current_size].to_vec();
self.nodes
.extend(std::iter::repeat_with(|| Node::default()).take(current_size));
// update sibling status
let new_nodes = set_left_sibling_status_for_nodes(self.nodes.clone());
self.nodes = new_nodes;
self.nodes = set_left_sibling_status_for_nodes(self.nodes.clone());
self.calculate_root()
}

/// Generates a membership proof for a node at a given index in the indexed merkle tree.
Expand Down Expand Up @@ -435,9 +438,15 @@ impl IndexedMerkleTree {
}
}

// Verify that the node itself does not exist by searching through the tree's nodes.
if self.find_leaf_by_label(&given_node_as_leaf.label).is_some() {
return Err(MerkleTreeError::MerkleProofError);
}

match found_index {
Some(index) => Ok(NonMembershipProof {
merkle_proof: self.generate_membership_proof(found_index.unwrap())?,
merkle_proof: self.generate_membership_proof(index)?,
missing_node_index: index,
missing_node: given_node_as_leaf.clone(),
}),
None => Err(MerkleTreeError::MerkleProofError),
Expand Down Expand Up @@ -499,17 +508,9 @@ impl IndexedMerkleTree {
return Err(MerkleTreeError::MerkleProofError);
}

let old_index = self
.find_node_index(non_membership_proof.merkle_proof.path.first().unwrap())
.ok_or(MerkleTreeError::MerkleProofError)?;
let old_index = non_membership_proof.missing_node_index;

// generate first update proof, changing only the next pointer from the old node
let mut new_old_node = self.nodes[old_index].clone();
Node::update_next_pointer(&mut new_old_node, new_node);
new_old_node.generate_hash();
let first_proof = self.update_node(old_index, new_old_node.clone())?;

// we checked if the found index in the non-membership is from an incative node, if not we have to search for another inactive node to update and if we cant find one, we have to double the tree
// check for an inactive node to use for the update, otherwise double the tree size
let mut new_index = None;
for (i, node) in self.nodes.iter().enumerate() {
if !node.is_active() {
Expand All @@ -522,7 +523,7 @@ impl IndexedMerkleTree {
Some(index) => index,
None => {
// double the tree
self.double_tree_size();
self.double_tree_size()?;
// take the first inactive node
self.nodes
.iter_mut()
Expand All @@ -533,7 +534,12 @@ impl IndexedMerkleTree {
}
};

// generate second update proof
// generate first update proof, changing only the next pointer from the old node
let mut new_old_node = self.nodes[old_index].clone();
Node::update_next_pointer(&mut new_old_node, new_node);
new_old_node.generate_hash();

let first_proof = self.update_node(old_index, new_old_node.clone())?;
let second_proof = self.update_node(new_index, new_node.clone())?;

Ok(InsertProof {
Expand Down Expand Up @@ -613,10 +619,144 @@ pub fn resort_nodes_by_input_order(
mod tests {
use super::*;

const TREE_SIZE: usize = 4;

fn test_node() -> Node {
Node::new_leaf(
true,
true,
sha256("test_label"),
sha256("test_value"),
Node::TAIL.to_string(),
)
}

fn create_random_nodes(count: usize) -> Vec<Node> {
(0..count)
.map(|i| {
Node::new_leaf(
true,
true,
sha256(&format!("test_label_{}", i)),
sha256(&format!("test_value_{}", i)),
Node::TAIL.to_string(),
)
})
.collect()
}

#[test]
fn test_new_with_size() {
let n = 4;
let tree = IndexedMerkleTree::new_with_size(n).unwrap();
assert_eq!(tree.nodes.len(), 2 * n - 1);
}

#[test]
fn test_membership_proofs() {
let mut tree = IndexedMerkleTree::new_with_size(TREE_SIZE).unwrap();
let node = test_node();

let mut non_membership_proof = tree.generate_non_membership_proof(&node);
assert!(non_membership_proof.is_ok());
assert_eq!(non_membership_proof.unwrap().verify(), true);

tree.insert_node(&node).unwrap();

let membership_proof = tree.generate_membership_proof(1);
assert!(membership_proof.is_ok());
let proof = membership_proof.unwrap();
assert_eq!(proof.clone().path.len(), 3);
assert_eq!(proof.verify(), true);

non_membership_proof = tree.generate_non_membership_proof(&node);
assert!(non_membership_proof.is_err());
}

#[test]
fn test_find_node_index() {
let mut tree = IndexedMerkleTree::new_with_size(TREE_SIZE).unwrap();
let node = test_node();

let mut index = tree.find_node_index(&node);
assert_eq!(index, None);

tree.insert_node(&node).unwrap();

index = tree.find_node_index(&node);
assert_eq!(index, Some(1));
}

#[test]
fn test_insert_node() {
let mut tree = IndexedMerkleTree::new_with_size(TREE_SIZE).unwrap();
let node = test_node();

let result = tree.insert_node(&node);
assert!(result.is_ok());
assert_eq!(result.unwrap().verify(), true);
}

#[test]
fn test_update_node() {
let mut tree = IndexedMerkleTree::new_with_size(TREE_SIZE).unwrap();
let node = test_node();

tree.insert_node(&node).unwrap();

let new_node = Node::new_leaf(
true,
true,
sha256("new_label"),
sha256("new_value"),
Node::TAIL.to_string(),
);

let result = tree.update_node(1, new_node);
assert!(result.is_ok());
assert_eq!(result.unwrap().verify(), true);
}

#[test]
fn test_find_leaf_by_label() {
let mut tree = IndexedMerkleTree::new_with_size(TREE_SIZE).unwrap();
let node = test_node();

let found_node = tree.find_leaf_by_label(&sha256("test_label"));
assert!(found_node.is_none());

tree.insert_node(&node).unwrap();

let found_node = tree.find_leaf_by_label(&sha256("test_label"));
assert!(found_node.is_some())
}

#[test]
fn test_double_tree_size() {
let mut tree = IndexedMerkleTree::new_with_size(TREE_SIZE).unwrap();
let node = test_node();

tree.insert_node(&node).unwrap();

assert_eq!(tree.nodes.len(), TREE_SIZE * 2 - 1);
let res = tree.double_tree_size();
assert!(res.is_ok());
assert_eq!(tree.nodes.len(), (TREE_SIZE * 2 - 1) * 2 + 1);

let found_node = tree.find_leaf_by_label(&sha256("test_label"));
assert!(found_node.is_some());
}

#[test]
fn test_insert_node_doubles_tree_size() {
let mut tree = IndexedMerkleTree::new_with_size(TREE_SIZE).unwrap();
let nodes = create_random_nodes(TREE_SIZE);

nodes.iter().for_each(|node| {
tree.insert_node(node).unwrap();
});

let node_count = TREE_SIZE * 2 - 1;
assert_eq!(tree.nodes.len(), node_count * 2 + 1);
}
}

0 comments on commit f48f676

Please sign in to comment.