Skip to content

Commit

Permalink
Merge pull request #361 from pitdicker/sample_iter
Browse files Browse the repository at this point in the history
Add an iterator to `Distribution`
  • Loading branch information
dhardy authored Apr 1, 2018
2 parents fbf9572 + 4eb0831 commit dba18b8
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 32 deletions.
16 changes: 16 additions & 0 deletions benches/distributions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,19 @@ gen_range_int!(gen_range_i32, i32, -200_000_000i32, 800_000_000);
gen_range_int!(gen_range_i64, i64, 3i64, 123_456_789_123);
#[cfg(feature = "i128_support")]
gen_range_int!(gen_range_i128, i128, -12345678901234i128, 123_456_789_123_456_789);

#[bench]
fn dist_iter(b: &mut Bencher) {
let mut rng = XorShiftRng::new();
let distr = Normal::new(-2.71828, 3.14159);
let mut iter = distr.sample_iter(&mut rng);

b.iter(|| {
let mut accum = 0.0;
for _ in 0..::RAND_BENCH_N {
accum += iter.next().unwrap();
}
black_box(accum);
});
b.bytes = size_of::<f64>() as u64 * ::RAND_BENCH_N;
}
44 changes: 44 additions & 0 deletions benches/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,47 @@ macro_rules! sample_indices {
sample_indices!(misc_sample_indices_10_of_1k, 10, 1000);
sample_indices!(misc_sample_indices_50_of_1k, 50, 1000);
sample_indices!(misc_sample_indices_100_of_1k, 100, 1000);

#[bench]
fn gen_1k_iter_repeat(b: &mut Bencher) {
use std::iter;
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
b.iter(|| {
let v: Vec<u64> = iter::repeat(()).map(|()| rng.gen()).take(128).collect();
black_box(v);
});
b.bytes = 1024;
}

#[bench]
#[allow(deprecated)]
fn gen_1k_gen_iter(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
b.iter(|| {
let v: Vec<u64> = rng.gen_iter().take(128).collect();
black_box(v);
});
b.bytes = 1024;
}

#[bench]
fn gen_1k_sample_iter(b: &mut Bencher) {
use rand::distributions::{Distribution, Uniform};
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
b.iter(|| {
let v: Vec<u64> = Uniform.sample_iter(&mut rng).take(128).collect();
black_box(v);
});
b.bytes = 1024;
}

#[bench]
fn gen_1k_fill(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
let mut buf = [0u64; 128];
b.iter(|| {
rng.fill(&mut buf[..]);
black_box(buf);
});
b.bytes = 1024;
}
73 changes: 71 additions & 2 deletions src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,68 @@ mod impls {

/// Types (distributions) that can be used to create a random instance of `T`.
pub trait Distribution<T> {
/// Generate a random value of `T`, using `rng` as the
/// source of randomness.
/// Generate a random value of `T`, using `rng` as the source of randomness.
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T;

/// Create an iterator that generates random values of `T`, using `rng` as
/// the source of randomness.
///
/// # Example
///
/// ```rust
/// use rand::thread_rng;
/// use rand::distributions::{Distribution, Alphanumeric, Range, Uniform};
///
/// let mut rng = thread_rng();
///
/// // Vec of 16 x f32:
/// let v: Vec<f32> = Uniform.sample_iter(&mut rng).take(16).collect();
///
/// // String:
/// let s: String = Alphanumeric.sample_iter(&mut rng).take(7).collect();
///
/// // Dice-rolling:
/// let die_range = Range::new_inclusive(1, 6);
/// let mut roll_die = die_range.sample_iter(&mut rng);
/// while roll_die.next().unwrap() != 6 {
/// println!("Not a 6; rolling again!");
/// }
/// ```
fn sample_iter<'a, R: Rng>(&'a self, rng: &'a mut R)
-> DistIter<'a, Self, R, T> where Self: Sized
{
DistIter {
distr: self,
rng: rng,
phantom: ::core::marker::PhantomData,
}
}
}

/// An iterator that generates random values of `T` with distribution `D`,
/// using `R` as the source of randomness.
///
/// This `struct` is created by the [`sample_iter`] method on [`Distribution`].
/// See its documentation for more.
///
/// [`Distribution`]: trait.Distribution.html
/// [`sample_iter`]: trait.Distribution.html#method.sample_iter
#[derive(Debug)]
pub struct DistIter<'a, D, R, T> where D: Distribution<T> + 'a, R: Rng + 'a {
distr: &'a D,
rng: &'a mut R,
phantom: ::core::marker::PhantomData<T>,
}

impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T>
where D: Distribution<T>, R: Rng + 'a
{
type Item = T;

#[inline(always)]
fn next(&mut self) -> Option<T> {
Some(self.distr.sample(self.rng))
}
}

impl<'a, T, D: Distribution<T>> Distribution<T> for &'a D {
Expand Down Expand Up @@ -519,4 +578,14 @@ mod tests {
let sampler = Exp::new(1.0);
sampler.ind_sample(&mut ::test::rng(235));
}

#[cfg(feature="std")]
#[test]
fn test_distributions_iter() {
use distributions::Normal;
let mut rng = ::test::rng(210);
let distr = Normal::new(10.0, 10.0);
let results: Vec<_> = distr.sample_iter(&mut rng).take(100).collect();
println!("{:?}", results);
}
}
66 changes: 36 additions & 30 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,33 +300,6 @@ pub trait Rand : Sized {
/// }
/// ```
///
/// # Iteration
///
/// Iteration over an `Rng` can be achieved using `iter::repeat` as follows:
///
/// ```rust
/// use std::iter;
/// use rand::{Rng, thread_rng};
/// use rand::distributions::{Alphanumeric, Range};
///
/// let mut rng = thread_rng();
///
/// // Vec of 16 x f32:
/// let v: Vec<f32> = iter::repeat(()).map(|()| rng.gen()).take(16).collect();
///
/// // String:
/// let s: String = iter::repeat(())
/// .map(|()| rng.sample(Alphanumeric))
/// .take(7).collect();
///
/// // Dice-rolling:
/// let die_range = Range::new_inclusive(1, 6);
/// let mut roll_die = iter::repeat(()).map(|()| rng.sample(die_range));
/// while roll_die.next().unwrap() != 6 {
/// println!("Not a 6; rolling again!");
/// }
/// ```
///
/// [`RngCore`]: https://docs.rs/rand_core/0.1/rand_core/trait.RngCore.html
pub trait Rng: RngCore {
/// Fill `dest` entirely with random bytes (uniform value distribution),
Expand Down Expand Up @@ -408,6 +381,39 @@ pub trait Rng: RngCore {
fn sample<T, D: Distribution<T>>(&mut self, distr: D) -> T {
distr.sample(self)
}

/// Create an iterator that generates values using the given distribution.
///
/// # Example
///
/// ```rust
/// use rand::{thread_rng, Rng};
/// use rand::distributions::{Alphanumeric, Range, Uniform};
///
/// let mut rng = thread_rng();
///
/// // Vec of 16 x f32:
/// let v: Vec<f32> = thread_rng().sample_iter(&Uniform).take(16).collect();
///
/// // String:
/// let s: String = rng.sample_iter(&Alphanumeric).take(7).collect();
///
/// // Combined values
/// println!("{:?}", thread_rng().sample_iter(&Uniform).take(5)
/// .collect::<Vec<(f64, bool)>>());
///
/// // Dice-rolling:
/// let die_range = Range::new_inclusive(1, 6);
/// let mut roll_die = rng.sample_iter(&die_range);
/// while roll_die.next().unwrap() != 6 {
/// println!("Not a 6; rolling again!");
/// }
/// ```
fn sample_iter<'a, T, D: Distribution<T>>(&'a mut self, distr: &'a D)
-> distributions::DistIter<'a, D, Self, T> where Self: Sized
{
distr.sample_iter(self)
}

/// Return a random value supporting the [`Uniform`] distribution.
///
Expand Down Expand Up @@ -443,7 +449,7 @@ pub trait Rng: RngCore {
/// .collect::<Vec<(f64, bool)>>());
/// ```
#[allow(deprecated)]
#[deprecated(since="0.5.0", note="use iter::repeat instead")]
#[deprecated(since="0.5.0", note="use Rng::sample_iter(&Uniform) instead")]
fn gen_iter<T>(&mut self) -> Generator<T, &mut Self> where Uniform: Distribution<T> {
Generator { rng: self, _marker: marker::PhantomData }
}
Expand Down Expand Up @@ -528,7 +534,7 @@ pub trait Rng: RngCore {
/// println!("{}", s);
/// ```
#[allow(deprecated)]
#[deprecated(since="0.5.0", note="use distributions::Alphanumeric instead")]
#[deprecated(since="0.5.0", note="use sample_iter(&Alphanumeric) instead")]
fn gen_ascii_chars(&mut self) -> AsciiGenerator<&mut Self> {
AsciiGenerator { rng: self }
}
Expand Down Expand Up @@ -694,7 +700,7 @@ impl_as_byte_slice_arrays!(!div 4096, N,N,N,N,N,N,N,);
/// [`Rng`]: trait.Rng.html
#[derive(Debug)]
#[allow(deprecated)]
#[deprecated(since="0.5.0", note="use iter::repeat instead")]
#[deprecated(since="0.5.0", note="use Rng::sample_iter instead")]
pub struct Generator<T, R: RngCore> {
rng: R,
_marker: marker::PhantomData<fn() -> T>,
Expand Down

0 comments on commit dba18b8

Please sign in to comment.