Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix a bug with MemForest due to off-by-one error #62

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 34 additions & 26 deletions src/accumulator/mem_forest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ impl<Hash: AccumulatorHash> MemForest<Hash> {
let mut positions = Vec::new();
for target in targets {
let node = self.map.get(target).ok_or("Could not find node")?;
let position = self.get_pos(node);
let position = self.get_pos(node)?;

positions.push(position);
}
let needed = get_proof_positions(&positions, self.leaves, tree_rows(self.leaves));
Expand Down Expand Up @@ -414,21 +415,20 @@ impl<Hash: AccumulatorHash> MemForest<Hash> {
}

fn del(&mut self, targets: &[Hash]) -> Result<(), String> {
let mut pos = targets
.iter()
.flat_map(|target| self.map.get(target))
.flat_map(|target| target.upgrade())
.map(|target| {
(
self.get_pos(self.map.get(&target.data.get()).unwrap()),
target.data.get(),
)
})
.collect::<Vec<_>>();
let mut nodes = Vec::new();

pos.sort();
let (_, targets): (Vec<u64>, Vec<Hash>) = pos.into_iter().unzip();
for target in targets {
let node_ref = self.map.get(target).ok_or("Could not find node")?;
let pos = self.get_pos(node_ref)?;

let node = node_ref.upgrade().ok_or("Could not upgrade node")?;

nodes.push((pos, node.get_data()));
}

nodes.sort_by(|a, b| a.0.cmp(&b.0));

for (_, target) in nodes {
match self.map.remove(&target) {
Some(target) => {
self.del_single(&target.upgrade().unwrap());
Expand All @@ -450,18 +450,21 @@ impl<Hash: AccumulatorHash> MemForest<Hash> {
proof.verify(del_hashes, &roots, self.leaves)
}

fn get_pos(&self, node: &Weak<Node<Hash>>) -> u64 {
fn get_pos(&self, node: &Weak<Node<Hash>>) -> Result<u64, String> {
// This indicates whether the node is a left or right child at each level
// When we go down the tree, we can use the indicator to know which
// child to take.
let mut left_child_indicator = 0_u64;
let mut rows_to_top = 0;
let mut node = node.upgrade().unwrap();
let mut node = node
.upgrade()
.ok_or("Could not upgrade node. Is this reference valid?")?;

while let Some(parent) = node.parent.clone().into_inner() {
let parent_left = parent
.upgrade()
.and_then(|parent| parent.left.clone().into_inner())
.unwrap()
.ok_or("Could not upgrade parent")?
.clone();

// If the current node is a left child, we left-shift the indicator
Expand All @@ -475,22 +478,26 @@ impl<Hash: AccumulatorHash> MemForest<Hash> {
left_child_indicator |= 1;
}
rows_to_top += 1;
node = parent.upgrade().unwrap();
node = parent.upgrade().ok_or("could not upgrade parent")?;
}
let mut root_idx = self.roots.len() - 1;
let forest_rows = tree_rows(self.leaves);
let mut root_row = 0;
let mut root_row = None;
// Find the root of the tree that the node belongs to
for row in 0..forest_rows {
for row in 0..=forest_rows {
if is_root_populated(row, self.leaves) {
let root = &self.roots[root_idx];
if root.get_data() == node.get_data() {
root_row = row;
root_row = Some(row);
break;
}
root_idx -= 1;
}
}

let root_row = root_row.ok_or(format!(
"Could not find the root position for row {root_idx}"
))?;
let mut pos = root_position(self.leaves, root_row, forest_rows);
for _ in 0..rows_to_top {
// If LSB is 0, go left, otherwise go right
Expand All @@ -505,7 +512,8 @@ impl<Hash: AccumulatorHash> MemForest<Hash> {
}
left_child_indicator >>= 1;
}
pos

Ok(pos)
}

fn del_single(&mut self, node: &Node<Hash>) -> Option<()> {
Expand Down Expand Up @@ -950,7 +958,7 @@ mod test {
($p:ident, $pos:literal) => {
assert_eq!(
$p.get_pos(&Rc::downgrade(&$p.grab_node($pos).unwrap().0)),
$pos
Ok($pos)
);
};
}
Expand All @@ -971,18 +979,18 @@ mod test {
test_get_pos!(p, 11);
test_get_pos!(p, 12);

assert_eq!(p.get_pos(&Rc::downgrade(&p.get_roots()[0])), 28);
assert_eq!(p.get_pos(&Rc::downgrade(&p.get_roots()[0])), Ok(28));
assert_eq!(
p.get_pos(&Rc::downgrade(
p.get_roots()[0].left.borrow().as_ref().unwrap()
)),
24
Ok(24)
);
assert_eq!(
p.get_pos(&Rc::downgrade(
p.get_roots()[0].right.borrow().as_ref().unwrap()
)),
25
Ok(25)
);
}

Expand Down
Loading