Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v1.18: implements weighted shuffle using binary tree (backport of #185) #425

Merged
merged 1 commit into from
Mar 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 128 additions & 58 deletions gossip/src/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,52 +18,54 @@ use {
/// non-zero weighted indices.
#[derive(Clone)]
pub struct WeightedShuffle<T> {
arr: Vec<T>, // Underlying array implementing binary indexed tree.
sum: T, // Current sum of weights, excluding already selected indices.
// Underlying array implementing binary tree.
// tree[i] is the sum of weights in the left sub-tree of node i.
tree: Vec<T>,
// Current sum of all weights, excluding already sampled ones.
weight: T,
zeros: Vec<usize>, // Indices of zero weighted entries.
}

// The implementation uses binary indexed tree:
// https://en.wikipedia.org/wiki/Fenwick_tree
// to maintain cumulative sum of weights excluding already selected indices
// over self.arr.
impl<T> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + CheckedAdd,
{
/// If weights are negative or overflow the total sum
/// they are treated as zero.
pub fn new(name: &'static str, weights: &[T]) -> Self {
let size = weights.len() + 1;
let zero = <T as Default>::default();
let mut arr = vec![zero; size];
let mut tree = vec![zero; get_tree_size(weights.len())];
let mut sum = zero;
let mut zeros = Vec::default();
let mut num_negative = 0;
let mut num_overflow = 0;
for (mut k, &weight) in (1usize..).zip(weights) {
for (k, &weight) in weights.iter().enumerate() {
#[allow(clippy::neg_cmp_op_on_partial_ord)]
// weight < zero does not work for NaNs.
if !(weight >= zero) {
zeros.push(k - 1);
zeros.push(k);
num_negative += 1;
continue;
}
if weight == zero {
zeros.push(k - 1);
zeros.push(k);
continue;
}
sum = match sum.checked_add(&weight) {
Some(val) => val,
None => {
zeros.push(k - 1);
zeros.push(k);
num_overflow += 1;
continue;
}
};
while k < size {
arr[k] += weight;
k += k & k.wrapping_neg();
let mut index = tree.len() + k;
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
if offset > 0 {
tree[index] += weight;
}
}
}
if num_negative > 0 {
Expand All @@ -72,62 +74,77 @@ where
if num_overflow > 0 {
datapoint_error!("weighted-shuffle-overflow", (name, num_overflow, i64));
}
Self { arr, sum, zeros }
Self {
tree,
weight: sum,
zeros,
}
}
}

impl<T> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub<Output = T>,
{
// Returns cumulative sum of current weights upto index k (inclusive).
fn cumsum(&self, mut k: usize) -> T {
let mut out = <T as Default>::default();
while k != 0 {
out += self.arr[k];
k ^= k & k.wrapping_neg();
}
out
}

// Removes given weight at index k.
fn remove(&mut self, mut k: usize, weight: T) {
self.sum -= weight;
let size = self.arr.len();
while k < size {
self.arr[k] -= weight;
k += k & k.wrapping_neg();
fn remove(&mut self, k: usize, weight: T) {
self.weight -= weight;
let mut index = self.tree.len() + k;
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
if offset > 0 {
self.tree[index] -= weight;
}
}
}

// Returns smallest index such that self.cumsum(k) > val,
// Returns smallest index such that cumsum of weights[..=k] > val,
// along with its respective weight.
fn search(&self, val: T) -> (/*index:*/ usize, /*weight:*/ T) {
fn search(&self, mut val: T) -> (/*index:*/ usize, /*weight:*/ T) {
let zero = <T as Default>::default();
debug_assert!(val >= zero);
debug_assert!(val < self.sum);
let mut lo = (/*index:*/ 0, /*cumsum:*/ zero);
let mut hi = (self.arr.len() - 1, self.sum);
while lo.0 + 1 < hi.0 {
let k = lo.0 + (hi.0 - lo.0) / 2;
let sum = self.cumsum(k);
if sum <= val {
lo = (k, sum);
debug_assert!(val < self.weight);
let mut index = 0;
let mut weight = self.weight;
while index < self.tree.len() {
if val < self.tree[index] {
weight = self.tree[index];
index = (index << 1) + 1;
} else {
hi = (k, sum);
weight -= self.tree[index];
val -= self.tree[index];
index = (index << 1) + 2;
}
}
debug_assert!(lo.1 <= val);
debug_assert!(hi.1 > val);
(hi.0, hi.1 - lo.1)
(index - self.tree.len(), weight)
}

pub fn remove_index(&mut self, index: usize) {
let zero = <T as Default>::default();
let weight = self.cumsum(index + 1) - self.cumsum(index);
if weight != zero {
self.remove(index + 1, weight);
} else if let Some(index) = self.zeros.iter().position(|ix| *ix == index) {
pub fn remove_index(&mut self, k: usize) {
let mut index = self.tree.len() + k;
let mut weight = <T as Default>::default(); // zero
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
if offset > 0 {
if self.tree[index] != weight {
self.remove(k, self.tree[index] - weight);
} else {
self.remove_zero(k);
}
return;
}
weight += self.tree[index];
}
if self.weight != weight {
self.remove(k, self.weight - weight);
} else {
self.remove_zero(k);
}
}

fn remove_zero(&mut self, k: usize) {
if let Some(index) = self.zeros.iter().position(|&ix| ix == k) {
self.zeros.remove(index);
}
}
Expand All @@ -140,10 +157,10 @@ where
// Equivalent to weighted_shuffle.shuffle(&mut rng).next()
pub fn first<R: Rng>(&self, rng: &mut R) -> Option<usize> {
let zero = <T as Default>::default();
if self.sum > zero {
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, rng);
if self.weight > zero {
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.weight, rng);
let (index, _weight) = WeightedShuffle::search(self, sample);
return Some(index - 1);
return Some(index);
}
if self.zeros.is_empty() {
return None;
Expand All @@ -160,11 +177,11 @@ where
pub fn shuffle<R: Rng>(mut self, rng: &'a mut R) -> impl Iterator<Item = usize> + 'a {
std::iter::from_fn(move || {
let zero = <T as Default>::default();
if self.sum > zero {
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, rng);
if self.weight > zero {
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.weight, rng);
let (index, weight) = WeightedShuffle::search(&self, sample);
self.remove(index, weight);
return Some(index - 1);
return Some(index);
}
if self.zeros.is_empty() {
return None;
Expand All @@ -176,6 +193,19 @@ where
}
}

// Maps number of items to the "internal" size of the binary tree "implicitly"
// holding those items on the leaves.
fn get_tree_size(count: usize) -> usize {
let shift = usize::BITS
- count.leading_zeros()
- if count.is_power_of_two() && count != 1 {
1
} else {
0
};
(1usize << shift) - 1
}

#[cfg(test)]
mod tests {
use {
Expand Down Expand Up @@ -218,6 +248,30 @@ mod tests {
shuffle
}

#[test]
fn test_get_tree_size() {
assert_eq!(get_tree_size(0), 0);
assert_eq!(get_tree_size(1), 1);
assert_eq!(get_tree_size(2), 1);
assert_eq!(get_tree_size(3), 3);
assert_eq!(get_tree_size(4), 3);
for count in 5..9 {
assert_eq!(get_tree_size(count), 7);
}
for count in 9..17 {
assert_eq!(get_tree_size(count), 15);
}
for count in 17..33 {
assert_eq!(get_tree_size(count), 31);
}
assert_eq!(get_tree_size((1 << 16) - 1), (1 << 16) - 1);
assert_eq!(get_tree_size(1 << 16), (1 << 16) - 1);
assert_eq!(get_tree_size((1 << 16) + 1), (1 << 17) - 1);
assert_eq!(get_tree_size((1 << 17) - 1), (1 << 17) - 1);
assert_eq!(get_tree_size(1 << 17), (1 << 17) - 1);
assert_eq!(get_tree_size((1 << 17) + 1), (1 << 18) - 1);
}

// Asserts that empty weights will return empty shuffle.
#[test]
fn test_weighted_shuffle_empty_weights() {
Expand Down Expand Up @@ -357,4 +411,20 @@ mod tests {
assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0]));
}
}

#[test]
fn test_weighted_shuffle_paranoid() {
let mut rng = rand::thread_rng();
for size in 0..1351 {
let weights: Vec<_> = repeat_with(|| rng.gen_range(0..1000)).take(size).collect();
let seed = rng.gen::<[u8; 32]>();
let mut rng = ChaChaRng::from_seed(seed);
let shuffle_slow = weighted_shuffle_slow(&mut rng.clone(), weights.clone());
let shuffle = WeightedShuffle::new("", &weights);
if size > 0 {
assert_eq!(shuffle.first(&mut rng.clone()), Some(shuffle_slow[0]));
}
assert_eq!(shuffle.shuffle(&mut rng).collect::<Vec<_>>(), shuffle_slow);
}
}
}
Loading