-
Notifications
You must be signed in to change notification settings - Fork 430
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Performance improvements for shuffle
and partial_shuffle
#1272
Changes from all commits
ce42437
e5f2c3b
ae53b3c
485d015
c7c52a5
d2e939f
06820c2
f7b4a99
fbd7114
7bff828
bf38097
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
// Copyright 2018-2023 Developers of the Rand project. | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license | ||
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your | ||
// option. This file may not be copied, modified, or distributed | ||
// except according to those terms. | ||
use criterion::{black_box, criterion_group, criterion_main, Criterion}; | ||
use rand::prelude::*; | ||
use rand::SeedableRng; | ||
|
||
criterion_group!( | ||
name = benches; | ||
config = Criterion::default(); | ||
targets = bench | ||
); | ||
criterion_main!(benches); | ||
|
||
pub fn bench(c: &mut Criterion) { | ||
bench_rng::<rand_chacha::ChaCha12Rng>(c, "ChaCha12"); | ||
bench_rng::<rand_pcg::Pcg32>(c, "Pcg32"); | ||
bench_rng::<rand_pcg::Pcg64>(c, "Pcg64"); | ||
} | ||
|
||
fn bench_rng<Rng: RngCore + SeedableRng>(c: &mut Criterion, rng_name: &'static str) { | ||
for length in [1, 2, 3, 10, 100, 1000, 10000].map(|x| black_box(x)) { | ||
c.bench_function(format!("shuffle_{length}_{rng_name}").as_str(), |b| { | ||
let mut rng = Rng::seed_from_u64(123); | ||
let mut vec: Vec<usize> = (0..length).collect(); | ||
b.iter(|| { | ||
vec.shuffle(&mut rng); | ||
vec[0] | ||
}) | ||
}); | ||
|
||
if length >= 10 { | ||
c.bench_function( | ||
format!("partial_shuffle_{length}_{rng_name}").as_str(), | ||
|b| { | ||
let mut rng = Rng::seed_from_u64(123); | ||
let mut vec: Vec<usize> = (0..length).collect(); | ||
b.iter(|| { | ||
vec.partial_shuffle(&mut rng, length / 2); | ||
vec[0] | ||
}) | ||
}, | ||
); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
// Copyright 2018-2023 Developers of the Rand project. | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license | ||
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your | ||
// option. This file may not be copied, modified, or distributed | ||
// except according to those terms. | ||
|
||
use crate::{Rng, RngCore}; | ||
|
||
/// Similar to a Uniform distribution, | ||
/// but after returning a number in the range [0,n], n is increased by 1. | ||
pub(crate) struct IncreasingUniform<R: RngCore> { | ||
pub rng: R, | ||
n: u32, | ||
// Chunk is a random number in [0, (n + 1) * (n + 2) *..* (n + chunk_remaining) ) | ||
chunk: u32, | ||
chunk_remaining: u8, | ||
} | ||
|
||
impl<R: RngCore> IncreasingUniform<R> { | ||
/// Create a dice roller. | ||
/// The next item returned will be a random number in the range [0,n] | ||
pub fn new(rng: R, n: u32) -> Self { | ||
// If n = 0, the first number returned will always be 0 | ||
// so we don't need to generate a random number | ||
let chunk_remaining = if n == 0 { 1 } else { 0 }; | ||
Self { | ||
rng, | ||
n, | ||
chunk: 0, | ||
chunk_remaining, | ||
} | ||
} | ||
|
||
/// Returns a number in [0,n] and increments n by 1. | ||
/// Generates new random bits as needed | ||
/// Panics if `n >= u32::MAX` | ||
#[inline] | ||
pub fn next_index(&mut self) -> usize { | ||
let next_n = self.n + 1; | ||
|
||
// There's room for further optimisation here: | ||
// gen_range uses rejection sampling (or other method; see #1196) to avoid bias. | ||
// When the initial sample is biased for range 0..bound | ||
// it may still be viable to use for a smaller bound | ||
// (especially if small biases are considered acceptable). | ||
|
||
let next_chunk_remaining = self.chunk_remaining.checked_sub(1).unwrap_or_else(|| { | ||
// If the chunk is empty, generate a new chunk | ||
let (bound, remaining) = calculate_bound_u32(next_n); | ||
// bound = (n + 1) * (n + 2) *..* (n + remaining) | ||
self.chunk = self.rng.gen_range(0..bound); | ||
// Chunk is a random number in | ||
// [0, (n + 1) * (n + 2) *..* (n + remaining) ) | ||
|
||
remaining - 1 | ||
}); | ||
|
||
let result = if next_chunk_remaining == 0 { | ||
// `chunk` is a random number in the range [0..n+1) | ||
// Because `chunk_remaining` is about to be set to zero | ||
// we do not need to clear the chunk here | ||
self.chunk as usize | ||
} else { | ||
// `chunk` is a random number in a range that is a multiple of n+1 | ||
// so r will be a random number in [0..n+1) | ||
let r = self.chunk % next_n; | ||
self.chunk /= next_n; | ||
r as usize | ||
}; | ||
|
||
self.chunk_remaining = next_chunk_remaining; | ||
self.n = next_n; | ||
result | ||
} | ||
} | ||
|
||
/// Finds the longest run of consecutive integers starting at `m` whose
/// product still fits in a `u32`.
///
/// Returns `(bound, count)` such that
/// `bound = m * (m + 1) * .. * (m + count - 1)`.
/// `m` must be greater than zero (checked with a debug assertion).
#[inline]
fn calculate_bound_u32(m: u32) -> (u32, u8) {
    debug_assert!(m > 0);

    #[inline]
    const fn longest_run(start: u32) -> (u32, u8) {
        let mut bound = start;
        let mut next_factor = start + 1;

        // Keep multiplying in the next integer until doing so would overflow.
        while let Some(extended) = bound.checked_mul(next_factor) {
            bound = extended;
            next_factor += 1;
        }

        // The run is longest for `start == 1` (12 factors, product 12!),
        // so the count always fits in a `u8`.
        (bound, (next_factor - start) as u8)
    }

    // Making this value a constant instead of recalculating it
    // gives a significant (~50%) performance boost for small shuffles
    const RUN_FROM_TWO: (u32, u8) = longest_run(2);

    if m == 2 {
        RUN_FROM_TWO
    } else {
        longest_run(m)
    }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
// Copyright 2018 Developers of the Rand project. | ||
// Copyright 2018-2023 Developers of the Rand project. | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license | ||
|
@@ -29,6 +29,8 @@ mod coin_flipper; | |
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] | ||
pub mod index; | ||
|
||
mod increasing_uniform; | ||
|
||
#[cfg(feature = "alloc")] | ||
use core::ops::Index; | ||
|
||
|
@@ -42,6 +44,7 @@ use crate::distributions::WeightedError; | |
use crate::Rng; | ||
|
||
use self::coin_flipper::CoinFlipper; | ||
use self::increasing_uniform::IncreasingUniform; | ||
|
||
/// Extension trait on slices, providing random mutation and sampling methods. | ||
/// | ||
|
@@ -620,10 +623,11 @@ impl<T> SliceRandom for [T] { | |
where | ||
R: Rng + ?Sized, | ||
{ | ||
for i in (1..self.len()).rev() { | ||
// invariant: elements with index > i have been locked in place. | ||
self.swap(i, gen_index(rng, i + 1)); | ||
if self.len() <= 1 { | ||
// There is no need to shuffle an empty or single element slice | ||
return; | ||
} | ||
self.partial_shuffle(rng, self.len()); | ||
} | ||
|
||
fn partial_shuffle<R>( | ||
|
@@ -632,19 +636,30 @@ impl<T> SliceRandom for [T] { | |
where | ||
R: Rng + ?Sized, | ||
{ | ||
// This applies Durstenfeld's algorithm for the | ||
// [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) | ||
// for an unbiased permutation, but exits early after choosing `amount` | ||
// elements. | ||
|
||
let len = self.len(); | ||
let end = if amount >= len { 0 } else { len - amount }; | ||
let m = self.len().saturating_sub(amount); | ||
|
||
for i in (end..len).rev() { | ||
// invariant: elements with index > i have been locked in place. | ||
self.swap(i, gen_index(rng, i + 1)); | ||
// The algorithm below is based on Durstenfeld's algorithm for the | ||
// [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) | ||
// for an unbiased permutation. | ||
// It ensures that the last `amount` elements of the slice | ||
// are randomly selected from the whole slice. | ||
|
||
// `IncreasingUniform::next_index()` is faster than `gen_index`, | ||
// but it only works for 32-bit integers, | ||
// so we must use the slow method if the slice is longer than that. | ||
if self.len() < (u32::MAX as usize) { | ||
wainwrightmark marked this conversation as resolved.
Show resolved
Hide resolved
|
||
let mut chooser = IncreasingUniform::new(rng, m as u32); | ||
for i in m..self.len() { | ||
let index = chooser.next_index(); | ||
self.swap(i, index); | ||
} | ||
} else { | ||
for i in m..self.len() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You need to reverse the iterator (both loops). Your code can only "choose" the last element of the list with probability 1/len when it should be m/len. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ermm, I'm pretty sure I've got this right. The last element gets swapped to a random place in the list so it has a m/len probability of being in the first m elements. Earlier elements are more likely to be chosen initially but can get booted out by later ones. The The reason I don't reverse the iterator is because the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay. We previously reversed since this way the proof by induction is easier. But we can also prove this algorithm works. First, lets not use Lets say we have a list Algorithm is: for i in end..len {
elts.swap(i, rng.sample_range(0..=i));
} For any length, for Thus, we assume:
We perform the last step of the algorithm:
Thus each element has chance |
||
let index = gen_index(rng, i + 1); | ||
self.swap(i, index); | ||
} | ||
} | ||
let r = self.split_at_mut(end); | ||
let r = self.split_at_mut(m); | ||
(r.1, r.0) | ||
} | ||
} | ||
|
@@ -765,11 +780,11 @@ mod test { | |
|
||
let mut r = crate::test::rng(414); | ||
nums.shuffle(&mut r); | ||
assert_eq!(nums, [9, 5, 3, 10, 7, 12, 8, 11, 6, 4, 0, 2, 1]); | ||
assert_eq!(nums, [5, 11, 0, 8, 7, 12, 6, 4, 9, 3, 1, 2, 10]); | ||
nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; | ||
let res = nums.partial_shuffle(&mut r, 6); | ||
assert_eq!(res.0, &mut [7, 4, 8, 6, 9, 3]); | ||
assert_eq!(res.1, &mut [0, 1, 2, 12, 11, 5, 10]); | ||
assert_eq!(res.0, &mut [7, 12, 6, 8, 1, 9]); | ||
assert_eq!(res.1, &mut [0, 11, 2, 3, 4, 5, 10]); | ||
} | ||
|
||
#[derive(Clone)] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's probably also room for further optimisation here: modulus is a slow operation (see https://www.pcg-random.org/posts/bounded-rands.html).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did read that article, and it helped me find some of the optimizations I used for this. I also tried a bitmask-based method, but it turned out about 50% slower than this. Obviously, I could easily have missed something.