diff --git a/Cargo.toml b/Cargo.toml index 6b42c2b..3342942 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,10 @@ [package] name = "polkadot-ckb-merkle-mountain-range" -version = "0.6.0" +version = "0.8.0" authors = [ - "Nervos Core Dev ", - "Parity Technologies ", - "Robert Hambrock " + "Nervos Core Dev ", + "Parity Technologies ", + "Robert Hambrock " ] edition = "2018" @@ -18,7 +18,7 @@ std = [] [dependencies] cfg-if = "1.0" -itertools = {version = "0.10.5", default-features = false, features = ["use_alloc"]} +itertools = { version = "0.10.5", default-features = false, features = ["use_alloc"] } [dev-dependencies] faster-hex = "0.8.0" diff --git a/src/ancestry_proof.rs b/src/ancestry_proof.rs index d9e510a..b2ad403 100644 --- a/src/ancestry_proof.rs +++ b/src/ancestry_proof.rs @@ -1,10 +1,11 @@ -use crate::collections::VecDeque; use crate::helper::{ - get_peak_map, get_peaks, is_descendant_pos, leaf_index_to_pos, parent_offset, - pos_height_in_tree, sibling_offset, + get_peak_map, get_peaks, leaf_index_to_pos, parent_offset, pos_height_in_tree, sibling_offset, }; -use crate::mmr::{bagging_peaks_hashes, take_while_vec}; +pub use crate::mmr::bagging_peaks_hashes; +use crate::mmr::take_while_vec; +use crate::util::BTreeMapExt; use crate::vec::Vec; +use crate::BTreeMap; use crate::{Error, Merge, Result}; use core::fmt::Debug; use core::marker::PhantomData; @@ -31,15 +32,12 @@ impl> AncestryProof { if current_leaves_count <= self.prev_peaks.len() as u64 { return Err(Error::CorruptedProof); } - // Test if previous root is correct. - let prev_peaks_positions = { - let prev_peaks_positions = get_peaks(self.prev_mmr_size); - if prev_peaks_positions.len() != self.prev_peaks.len() { - return Err(Error::CorruptedProof); - } - prev_peaks_positions - }; + // Test if previous root is correct. + let prev_peaks_positions = get_peaks(self.prev_mmr_size); + if prev_peaks_positions.len() != self.prev_peaks.len() { + return Err(Error::CorruptedProof); + } let calculated_prev_root = bagging_peaks_hashes::(self.prev_peaks.clone())?; if calculated_prev_root != prev_root { return Ok(false); @@ -74,8 +72,8 @@ impl> NodeMerkleProof { &self.proof } - pub fn calculate_root(&self, leaves: Vec<(u64, T)>) -> Result { - calculate_root::<_, M, _>(leaves, self.mmr_size, self.proof.iter()) + pub fn calculate_root(&self, nodes: Vec<(u64, T)>) -> Result { + calculate_root::<_, M, _>(nodes, self.mmr_size, self.proof.iter()) } /// from merkle proof of leaf n to calculate merkle root of n + 1 leaves. @@ -140,45 +138,27 @@ fn calculate_peak_root< 'a, T: 'a + PartialEq, M: Merge, - // I: Iterator + // I: Iterator, >( nodes: Vec<(u64, T)>, peak_pos: u64, // proof_iter: &mut I, ) -> Result { - debug_assert!(!nodes.is_empty(), "can't be empty"); - // (position, hash, height) - - let mut queue: VecDeque<_> = nodes - .into_iter() - .map(|(pos, item)| (pos, item, pos_height_in_tree(pos))) - .collect(); - - let mut sibs_processed_from_back = Vec::new(); + let mut queue = BTreeMap::new(); + for (pos, item) in nodes.into_iter() { + if !queue.checked_insert((pos_height_in_tree(pos), pos), item) { + return Err(Error::CorruptedProof); + } + } - // calculate tree root from each items - while let Some((pos, item, height)) = queue.pop_front() { + while let Some(((height, pos), item)) = queue.pop_first() { if pos == peak_pos { if queue.is_empty() { // return root once queue is consumed return Ok(item); - } - if queue - .iter() - .any(|entry| entry.0 == peak_pos && entry.1 != item) - { + } else { return Err(Error::CorruptedProof); } - if queue - .iter() - .all(|entry| entry.0 == peak_pos && &entry.1 == &item && entry.2 == height) - { - // return root if remaining queue consists only of duplicate root entries - return Ok(item); - } - // if queue not empty, push peak back to the end - queue.push_back((pos, item, height)); - continue; } // calculate sibling let next_height = pos_height_in_tree(pos + 1); @@ -186,59 +166,43 @@ fn calculate_peak_root< let sibling_offset = sibling_offset(height); if next_height > height { // implies pos is right sibling - let (sib_pos, parent_pos) = (pos - sibling_offset, pos + 1); - let parent_item = if Some(&sib_pos) == queue.front().map(|(pos, _, _)| pos) { - let sibling_item = queue.pop_front().map(|(_, item, _)| item).unwrap(); - M::merge(&sibling_item, &item)? - } else if Some(&sib_pos) == queue.back().map(|(pos, _, _)| pos) { - let sibling_item = queue.pop_back().map(|(_, item, _)| item).unwrap(); - M::merge(&sibling_item, &item)? - } - // handle special if next queue item is descendant of sibling - else if let Some(&(front_pos, ..)) = queue.front() { - if height > 0 && is_descendant_pos(sib_pos, front_pos) { - queue.push_back((pos, item, height)); - continue; + let sib_pos = pos - sibling_offset; + let parent_pos = pos + 1; + let parent_item = + if Some(&sib_pos) == queue.first_key_value().map(|((_, pos), _)| pos) { + let sibling_item = queue.pop_first().map(|((_, _), item)| item).unwrap(); + M::merge(&sibling_item, &item)? } else { return Err(Error::CorruptedProof); - } - } else { - return Err(Error::CorruptedProof); - }; + // Old `mmr.rs` code. It's not needed anymore since now we merge the `proof_iter` + // items with the nodes. + // let sibling_item = &proof_iter.next().ok_or(Error::CorruptedProof)?.1; + // M::merge(sibling_item, &item)? + }; (parent_pos, parent_item) } else { // pos is left sibling - let (sib_pos, parent_pos) = (pos + sibling_offset, pos + parent_offset(height)); - let parent_item = if Some(&sib_pos) == queue.front().map(|(pos, _, _)| pos) { - let sibling_item = queue.pop_front().map(|(_, item, _)| item).unwrap(); - M::merge(&item, &sibling_item)? - } else if Some(&sib_pos) == queue.back().map(|(pos, _, _)| pos) { - let sibling_item = queue.pop_back().map(|(_, item, _)| item).unwrap(); - let parent = M::merge(&item, &sibling_item)?; - sibs_processed_from_back.push((sib_pos, sibling_item, height)); - parent - } else if let Some(&(front_pos, ..)) = queue.front() { - if height > 0 && is_descendant_pos(sib_pos, front_pos) { - queue.push_back((pos, item, height)); - continue; + let sib_pos = pos + sibling_offset; + let parent_pos = pos + parent_offset(height); + let parent_item = + if Some(&sib_pos) == queue.first_key_value().map(|((_, pos), _)| pos) { + let sibling_item = queue.pop_first().map(|((_, _), item)| item).unwrap(); + M::merge(&item, &sibling_item)? } else { return Err(Error::CorruptedProof); - } - } else { - return Err(Error::CorruptedProof); - }; + // Old `mmr.rs` code. It's not needed anymore since now we merge the `proof_iter` + // items with the nodes. + // let sibling_item = &proof_iter.next().ok_or(Error::CorruptedProof)?.1; + // M::merge(&item, sibling_item)? + }; (parent_pos, parent_item) } }; if parent_pos <= peak_pos { - let parent = (parent_pos, parent_item, height + 1); - if peak_pos == parent_pos - || queue.front() != Some(&parent) - && !sibs_processed_from_back.iter().any(|item| item == &parent) - { - queue.push_front(parent) - }; + if !queue.checked_insert((height + 1, parent_pos), parent_item) { + return Err(Error::CorruptedProof); + } } else { return Err(Error::CorruptedProof); } @@ -266,7 +230,6 @@ fn calculate_peaks_hashes< .into_iter() .chain(proof_iter.cloned()) .sorted_by_key(|(pos, _)| *pos) - .dedup_by(|a, b| a.0 == b.0) .collect(); let peaks = get_peaks(mmr_size); @@ -293,11 +256,13 @@ fn calculate_peaks_hashes< return Err(Error::CorruptedProof); } - // check rhs peaks + // Old `mmr.rs` code. It's not needed anymore since now we merge the `proof_iter` + // items with the nodes. + // // check rhs peaks // if let Some((_, rhs_peaks_hashes)) = proof_iter.next() { // peaks_hashes.push(rhs_peaks_hashes.clone()); // } - // ensure nothing left in proof_iter + // // ensure nothing left in proof_iter // if proof_iter.next().is_some() { // return Err(Error::CorruptedProof); // } @@ -321,3 +286,41 @@ fn calculate_root< let peaks_hashes = calculate_peaks_hashes::<_, M, _>(nodes, mmr_size, proof_iter)?; bagging_peaks_hashes::<_, M>(peaks_hashes) } + +pub fn expected_ancestry_proof_size(prev_mmr_size: u64, mmr_size: u64) -> usize { + let mut expected_proof_size: usize = 0; + let mut prev_peaks = get_peaks(prev_mmr_size); + let peaks = get_peaks(mmr_size); + + for (peak_idx, peak) in peaks.iter().enumerate() { + let mut local_prev_peaks: Vec = take_while_vec(&mut prev_peaks, |pos| *pos <= *peak); + + // Pop lowest local prev peak under the current peak + match local_prev_peaks.pop() { + Some(mut node) => { + let mut height = pos_height_in_tree(node); + while node < *peak { + if pos_height_in_tree(node + 1) > height { + // Node is right sibling + node += 1; + height += 1; + } else { + // Node is left sibling + expected_proof_size += 1; + node += parent_offset(height); + height += 1; + } + } + } + None => { + if peak_idx <= peaks.len() { + // Account for rhs bagging peaks + expected_proof_size += 1; + } + break; + } + } + } + + expected_proof_size +} diff --git a/src/helper.rs b/src/helper.rs index 50ca5a0..8a24e68 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -75,16 +75,6 @@ pub fn get_peak_map(mmr_size: u64) -> u64 { peak_map } -/// Returns whether `descendant_contender` is a descendant of `ancestor_contender` in a tree of the MMR. -pub fn is_descendant_pos(ancestor_contender: u64, descendant_contender: u64) -> bool { - // NOTE: "ancestry" here refers to the hierarchy within an MMR tree, not temporal hierarchy. - // the descendant needs to have been added to the mmr prior to the ancestor - descendant_contender <= ancestor_contender - // the descendant needs to be within the cone of positions descendant from the ancestor - && descendant_contender - >= (ancestor_contender + 1 - sibling_offset(pos_height_in_tree(ancestor_contender))) -} - /// Returns the pos of the peaks in the mmr. /// for example, for a mmr with 11 leaves, the mmr_size is 19, it will return [14, 17, 18]. /// 14 diff --git a/src/lib.rs b/src/lib.rs index f5e94db..55eb27b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,11 +23,17 @@ cfg_if::cfg_if! { use std::collections; use std::vec; use std::string; + use std::collections::BTreeSet; + use std::collections::BTreeMap; + use std::collections::btree_map::Entry as BTreeMapEntry; } else { extern crate alloc; use alloc::borrow; use alloc::collections; use alloc::vec; use alloc::string; + use alloc::collections::btree_set::BTreeSet; + use alloc::collections::btree_map::BTreeMap; + use alloc::collections::btree_map::Entry as BTreeMapEntry; } } diff --git a/src/mmr.rs b/src/mmr.rs index 7e65dbc..3ca0161 100644 --- a/src/mmr.rs +++ b/src/mmr.rs @@ -12,13 +12,12 @@ use crate::helper::{ pos_height_in_tree, sibling_offset, }; use crate::mmr_store::{MMRBatch, MMRStoreReadOps, MMRStoreWriteOps}; -use crate::util::VeqDequeExt; use crate::vec; use crate::vec::Vec; +use crate::BTreeSet; use crate::{Error, Merge, Result}; use core::fmt::Debug; use core::marker::PhantomData; - #[allow(clippy::upper_case_acronyms)] pub struct MMR { mmr_size: u64, @@ -241,13 +240,13 @@ impl, S: MMRStoreReadOps> MMR = VecDeque::new(); + let mut queue = BTreeSet::new(); for value in pos_list.iter().map(|pos| (pos_height_in_tree(*pos), *pos)) { - queue.insert_sorted(value); + queue.insert(value); } // Generate sub-tree merkle proof for positions - while let Some((height, pos)) = queue.pop_front() { + while let Some((height, pos)) = queue.pop_first() { debug_assert!(pos <= peak_pos); if pos == peak_pos { if queue.is_empty() { @@ -270,9 +269,9 @@ impl, S: MMRStoreReadOps> MMR, S: MMRStoreReadOps> MMR, I: Iterator>(mut peaks_hashes: Vec) -> Result { +pub fn bagging_peaks_hashes>(mut peaks_hashes: Vec) -> Result { // bagging peaks // bagging from right to left via hash(right, left). while peaks_hashes.len() > 1 { diff --git a/src/tests/test_ancestry.rs b/src/tests/test_ancestry.rs index fc016df..db4ffd3 100644 --- a/src/tests/test_ancestry.rs +++ b/src/tests/test_ancestry.rs @@ -1,4 +1,5 @@ use super::{MergeNumberHash, NumberHash}; +use crate::ancestry_proof::expected_ancestry_proof_size; use crate::leaf_index_to_mmr_size; use crate::util::{MemMMR, MemStore}; @@ -7,7 +8,7 @@ fn test_ancestry() { let store = MemStore::default(); let mut mmr = MemMMR::<_, MergeNumberHash>::new(0, &store); - let mmr_size = 300; + let mmr_size = 5000; let mut prev_roots = Vec::new(); for i in 0..mmr_size { mmr.push(NumberHash::from(i)).unwrap(); @@ -21,5 +22,9 @@ fn test_ancestry() { assert!(ancestry_proof .verify_ancestor(root.clone(), prev_roots[i as usize].clone()) .unwrap()); + assert_eq!( + expected_ancestry_proof_size(ancestry_proof.prev_mmr_size, mmr.mmr_size()), + ancestry_proof.prev_peaks_proof.proof_items().len() + ); } } diff --git a/src/tests/test_node_mmr.rs b/src/tests/test_node_mmr.rs index 8d31593..556f833 100644 --- a/src/tests/test_node_mmr.rs +++ b/src/tests/test_node_mmr.rs @@ -24,22 +24,15 @@ fn test_mmr(count: u32, proof_elem: Vec) { .collect(), ) .expect("gen proof"); - assert!(proof - .proof_items() - .iter() - .zip(proof.proof_items().iter().skip(1)) - .all(|((pos_a, _), (pos_b, _))| pos_a < pos_b)); mmr.commit().expect("commit changes"); - let result = proof - .verify( - root, - proof_elem - .iter() - .map(|elem| (positions[*elem as usize], NumberHash::from(*elem))) - .collect(), - ) - .unwrap(); - assert!(result); + let result = proof.verify( + root, + proof_elem + .iter() + .map(|elem| (positions[*elem as usize], NumberHash::from(*elem))) + .collect(), + ); + assert_eq!(result, Ok(true)); } fn test_gen_new_root_from_proof(count: u32) { @@ -218,17 +211,21 @@ fn test_invalid_proof_verification( }); let tampered_proof: Option> = - if let Some(tampered_proof_positions) = handrolled_tampered_proof_positions { - Some(NodeMerkleProof::new( + handrolled_tampered_proof_positions.map(|tampered_proof_positions| { + NodeMerkleProof::new( mmr.mmr_size(), tampered_proof_positions .iter() .map(|pos| (*pos, mmr.batch().get_elem(*pos).unwrap().unwrap())) .collect(), - )) - } else { - None - }; + ) + }); + // if proof items have been tampered with, the proof verification fails + if let Some(tampered_proof) = tampered_proof { + let tampered_proof_result = + tampered_proof.verify(root.clone(), tampered_entries_to_verify.clone()); + assert!(tampered_proof_result.is_err() || !tampered_proof_result.unwrap()); + } // test with the proof generated by the library itself, or, if provided, a handrolled proof let proof = if let Some(proof_positions) = handrolled_proof_positions { @@ -242,21 +239,11 @@ fn test_invalid_proof_verification( } else { mmr.gen_node_proof(positions_to_verify.clone()).unwrap() }; - - // if proof items have been tampered with, the proof verification fails - if let Some(tampered_proof) = tampered_proof { - let tampered_proof_result = - tampered_proof.verify(root.clone(), tampered_entries_to_verify.clone()); - assert!(tampered_proof_result.is_err() || !tampered_proof_result.unwrap()); - } - + // verification of the correct nodes passes + assert_eq!(proof.verify(root.clone(), entries_to_verify), Ok(true)); // if any nodes to be verified aren't members of the mmr, the proof verification fails - let tampered_entries_result = proof.verify(root.clone(), tampered_entries_to_verify.clone()); + let tampered_entries_result = proof.verify(root, tampered_entries_to_verify.clone()); assert!(tampered_entries_result.is_err() || !tampered_entries_result.unwrap()); - - let proof_verification = proof.verify(root, entries_to_verify); - // verification of the correct nodes passes - assert!(proof_verification.unwrap()); } #[test] diff --git a/src/util.rs b/src/util.rs index ab947d1..26ea82d 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,4 +1,4 @@ -use crate::collections::{BTreeMap, VecDeque}; +use crate::collections::BTreeMap; use crate::{vec::Vec, MMRStoreReadOps, MMRStoreWriteOps, Result, MMR}; use core::cell::RefCell; @@ -35,17 +35,26 @@ impl MMRStoreWriteOps for &MemStore { pub type MemMMR<'a, T, M> = MMR>; -pub trait VeqDequeExt { - fn insert_sorted(&mut self, value: T); +pub trait BTreeMapExt { + fn checked_insert(&mut self, key: K, value: V) -> bool; } -impl VeqDequeExt for VecDeque { - fn insert_sorted(&mut self, value: T) { - match self.binary_search(&value) { - Ok(_pos) => { - // element already in vector @ `pos` +impl BTreeMapExt for BTreeMap { + fn checked_insert(&mut self, key: K, value: V) -> bool { + use crate::BTreeMapEntry; + + let entry = self.entry(key); + match entry { + BTreeMapEntry::Vacant(slot) => { + slot.insert(value); + } + BTreeMapEntry::Occupied(old_value) => { + if old_value.get() != &value { + return false; + } } - Err(pos) => self.insert(pos, value), } + + true } }