Skip to content

Commit

Permalink
implements weighted shuffle using N-ary tree
Browse files Browse the repository at this point in the history
  • Loading branch information
behzadnouri committed Mar 23, 2024
1 parent b6d2237 commit f0b4ac2
Showing 1 changed file with 51 additions and 31 deletions.
82 changes: 51 additions & 31 deletions gossip/src/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ use {
std::ops::{AddAssign, Sub, SubAssign},
};

const SHIFT: usize = 4;
const N: usize = 1 << SHIFT; // Number of children of each node.
const MASK: usize = N - 1;

/// Implements an iterator where indices are shuffled according to their
/// weights:
/// - Returned indices are unique in the range [0, weights.len()).
Expand All @@ -34,7 +38,7 @@ where
/// they are treated as zero.
pub fn new(name: &'static str, weights: &[T]) -> Self {
let zero = <T as Default>::default();
let mut tree = vec![zero; get_tree_size(weights.len())];
let mut tree = vec![[zero; N - 1]; get_tree_size(weights.len())];
let mut sum = zero;
let mut zeros = Vec::default();
let mut num_negative = 0;
Expand All @@ -61,10 +65,10 @@ where
};
let mut index = tree.len() + k;
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
let offset = index & MASK;
index = (index - 1) >> SHIFT;
if offset > 0 {
tree[index] += weight;
tree[index][offset - 1] += weight;
}
}
}
Expand All @@ -84,17 +88,17 @@ where

impl<T> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub<Output = T>,
T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub<Output = T> + std::iter::Sum<T>,
{
// Removes given weight at index k.
fn remove(&mut self, k: usize, weight: T) {
self.weight -= weight;
let mut index = self.tree.len() + k;
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
let offset = index & MASK;
index = (index - 1) >> SHIFT;
if offset > 0 {
self.tree[index] -= weight;
self.tree[index][offset - 1] -= weight;
}
}
}
Expand All @@ -107,15 +111,18 @@ where
debug_assert!(val < self.weight);
let mut index = 0;
let mut weight = self.weight;
while index < self.tree.len() {
if val < self.tree[index] {
weight = self.tree[index];
index = (index << 1) + 1;
} else {
weight -= self.tree[index];
val -= self.tree[index];
index = (index << 1) + 2;
'outer: while index < self.tree.len() {
for (j, &node) in self.tree[index].iter().enumerate() {
if val < node {
weight = node;
index = (index << SHIFT) + j + 1;
continue 'outer;
} else {
weight -= node;
val -= node;
}
}
index = (index << SHIFT) + N;
}
(index - self.tree.len(), weight)
}
Expand All @@ -124,17 +131,17 @@ where
let mut index = self.tree.len() + k;
let mut weight = <T as Default>::default(); // zero
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
let offset = index & MASK;
index = (index - 1) >> SHIFT;
if offset > 0 {
if self.tree[index] != weight {
self.remove(k, self.tree[index] - weight);
if self.tree[index][offset - 1] != weight {
self.remove(k, self.tree[index][offset - 1] - weight);
} else {
self.remove_zero(k);
}
return;
}
weight += self.tree[index];
weight += self.tree[index].iter().copied().sum();
}
if self.weight != weight {
self.remove(k, self.weight - weight);
Expand All @@ -152,7 +159,14 @@ where

impl<T> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub<Output = T>,
T: Copy
+ Default
+ PartialOrd
+ AddAssign
+ SampleUniform
+ SubAssign
+ Sub<Output = T>
+ std::iter::Sum<T>,
{
// Equivalent to weighted_shuffle.shuffle(&mut rng).next()
pub fn first<R: Rng>(&self, rng: &mut R) -> Option<usize> {
Expand All @@ -172,7 +186,14 @@ where

impl<'a, T: 'a> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub<Output = T>,
T: Copy
+ Default
+ PartialOrd
+ AddAssign
+ SampleUniform
+ SubAssign
+ Sub<Output = T>
+ std::iter::Sum<T>,
{
pub fn shuffle<R: Rng>(mut self, rng: &'a mut R) -> impl Iterator<Item = usize> + 'a {
std::iter::from_fn(move || {
Expand All @@ -196,14 +217,13 @@ where
// Maps number of items to the "internal" size of the N-ary tree "implicitly"
// holding those items on the leaves.
fn get_tree_size(count: usize) -> usize {
let shift = usize::BITS
- count.leading_zeros()
- if count.is_power_of_two() && count != 1 {
1
} else {
0
};
(1usize << shift) - 1
let mut out = 0;
let mut k = 1;
while k < count {
out += k;
k *= N;
}
return out;
}

#[cfg(test)]
Expand Down

0 comments on commit f0b4ac2

Please sign in to comment.