Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an iterator to Distribution #361

Merged
merged 3 commits into from
Apr 1, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions benches/distributions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,19 @@ gen_range_int!(gen_range_i32, i32, -200_000_000i32, 800_000_000);
gen_range_int!(gen_range_i64, i64, 3i64, 123_456_789_123);
#[cfg(feature = "i128_support")]
gen_range_int!(gen_range_i128, i128, -12345678901234i128, 123_456_789_123_456_789);

#[bench]
fn dist_iter(b: &mut Bencher) {
let mut rng = XorShiftRng::new();
let distr = Normal::new(-2.71828, 3.14159);
let mut iter = distr.sample_iter(&mut rng);

b.iter(|| {
let mut accum = 0.0;
for _ in 0..::RAND_BENCH_N {
accum += iter.next().unwrap();
}
black_box(accum);
});
b.bytes = size_of::<f64>() as u64 * ::RAND_BENCH_N;
}
44 changes: 44 additions & 0 deletions benches/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,47 @@ macro_rules! sample_indices {
sample_indices!(misc_sample_indices_10_of_1k, 10, 1000);
sample_indices!(misc_sample_indices_50_of_1k, 50, 1000);
sample_indices!(misc_sample_indices_100_of_1k, 100, 1000);

#[bench]
fn gen_1k_iter_repeat(b: &mut Bencher) {
use std::iter;
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
b.iter(|| {
let v: Vec<u64> = iter::repeat(()).map(|()| rng.gen()).take(128).collect();
black_box(v);
});
b.bytes = 1024;
}

#[bench]
#[allow(deprecated)]
fn gen_1k_gen_iter(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
b.iter(|| {
let v: Vec<u64> = rng.gen_iter().take(128).collect();
black_box(v);
});
b.bytes = 1024;
}

#[bench]
fn gen_1k_sample_iter(b: &mut Bencher) {
use rand::distributions::{Distribution, Uniform};
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
b.iter(|| {
let v: Vec<u64> = Uniform.sample_iter(&mut rng).take(128).collect();
black_box(v);
});
b.bytes = 1024;
}

#[bench]
fn gen_1k_fill(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
let mut buf = [0u64; 128];
b.iter(|| {
rng.fill(&mut buf[..]);
black_box(buf);
});
b.bytes = 1024;
}
73 changes: 71 additions & 2 deletions src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,68 @@ mod impls {

/// Types (distributions) that can be used to create a random instance of `T`.
pub trait Distribution<T> {
/// Generate a random value of `T`, using `rng` as the
/// source of randomness.
/// Generate a random value of `T`, using `rng` as the source of randomness.
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T;

/// Create an iterator that generates random values of `T`, using `rng` as
/// the source of randomness.
///
/// # Example
///
/// ```rust
/// use rand::thread_rng;
/// use rand::distributions::{Distribution, Alphanumeric, Range, Uniform};
///
/// let mut rng = thread_rng();
///
/// // Vec of 16 x f32:
/// let v: Vec<f32> = Uniform.sample_iter(&mut rng).take(16).collect();
///
/// // String:
/// let s: String = Alphanumeric.sample_iter(&mut rng).take(7).collect();
///
/// // Dice-rolling:
/// let die_range = Range::new_inclusive(1, 6);
/// let mut roll_die = die_range.sample_iter(&mut rng);
/// while roll_die.next().unwrap() != 6 {
/// println!("Not a 6; rolling again!");
/// }
/// ```
fn sample_iter<'a, R: Rng>(&'a self, rng: &'a mut R)
-> DistIter<'a, Self, R, T> where Self: Sized
{
DistIter {
distr: self,
rng: rng,
phantom: ::core::marker::PhantomData,
}
}
}

/// An iterator that generates random values of `T` with distribution `D`,
/// using `R` as the source of randomness.
///
/// This `struct` is created by the [`sample_iter`] method on [`Distribution`].
/// See its documentation for more.
///
/// [`Distribution`]: trait.Distribution.html
/// [`sample_iter`]: trait.Distribution.html#method.sample_iter
#[derive(Debug)]
pub struct DistIter<'a, D, R, T> where D: Distribution<T> + 'a, R: Rng + 'a {
distr: &'a D,
rng: &'a mut R,
phantom: ::core::marker::PhantomData<T>,
}

impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T>
where D: Distribution<T>, R: Rng + 'a
{
type Item = T;

#[inline(always)]
fn next(&mut self) -> Option<T> {
Some(self.distr.sample(self.rng))
}
}

impl<'a, T, D: Distribution<T>> Distribution<T> for &'a D {
Expand Down Expand Up @@ -519,4 +578,14 @@ mod tests {
let sampler = Exp::new(1.0);
sampler.ind_sample(&mut ::test::rng(235));
}

#[cfg(feature="std")]
#[test]
fn test_distributions_iter() {
use distributions::Normal;
let mut rng = ::test::rng(210);
let distr = Normal::new(10.0, 10.0);
let results: Vec<_> = distr.sample_iter(&mut rng).take(100).collect();
println!("{:?}", results);
}
}
66 changes: 36 additions & 30 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,33 +299,6 @@ pub trait Rand : Sized {
/// }
/// ```
///
/// # Iteration
///
/// Iteration over an `Rng` can be achieved using `iter::repeat` as follows:
///
/// ```rust
/// use std::iter;
/// use rand::{Rng, thread_rng};
/// use rand::distributions::{Alphanumeric, Range};
///
/// let mut rng = thread_rng();
///
/// // Vec of 16 x f32:
/// let v: Vec<f32> = iter::repeat(()).map(|()| rng.gen()).take(16).collect();
///
/// // String:
/// let s: String = iter::repeat(())
/// .map(|()| rng.sample(Alphanumeric))
/// .take(7).collect();
///
/// // Dice-rolling:
/// let die_range = Range::new_inclusive(1, 6);
/// let mut roll_die = iter::repeat(()).map(|()| rng.sample(die_range));
/// while roll_die.next().unwrap() != 6 {
/// println!("Not a 6; rolling again!");
/// }
/// ```
///
/// [`RngCore`]: https://docs.rs/rand_core/0.1/rand_core/trait.RngCore.html
pub trait Rng: RngCore {
/// Fill `dest` entirely with random bytes (uniform value distribution),
Expand Down Expand Up @@ -407,6 +380,39 @@ pub trait Rng: RngCore {
fn sample<T, D: Distribution<T>>(&mut self, distr: D) -> T {
distr.sample(self)
}

/// Create an iterator that generates values using the given distribution.
///
/// # Example
///
/// ```rust
/// use rand::{thread_rng, Rng};
/// use rand::distributions::{Alphanumeric, Range, Uniform};
///
/// let mut rng = thread_rng();
///
/// // Vec of 16 x f32:
/// let v: Vec<f32> = thread_rng().sample_iter(&Uniform).take(16).collect();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, but lets not promote bad styles in examples by creating a local rng handle then not using it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to show that it was also possible to use thread_rng() directly. But yes, bad style. Will change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think most people will figure that out.

///
/// // String:
/// let s: String = rng.sample_iter(&Alphanumeric).take(7).collect();
///
/// // Combined values
/// println!("{:?}", thread_rng().sample_iter(&Uniform).take(5)
/// .collect::<Vec<(f64, bool)>>());
///
/// // Dice-rolling:
/// let die_range = Range::new_inclusive(1, 6);
/// let mut roll_die = rng.sample_iter(&die_range);
/// while roll_die.next().unwrap() != 6 {
/// println!("Not a 6; rolling again!");
/// }
/// ```
fn sample_iter<'a, T, D: Distribution<T>>(&'a mut self, distr: &'a D)
-> distributions::DistIter<'a, D, Self, T> where Self: Sized
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting that this takes the distribution by reference while sample takes it by value. I'm not saying this is wrong. If we have both, then distr.iter(&mut rng) is still an option when passing by reference is required, so this could take the distribution by value (which would consume die_range above but only if the & is removed — we also implement Distribution for the references).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure we can have both, because the iterator expects an reference instead of something owned. Or do we want to use some sort of Cow or MaybeOwned trick?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it will work if you just change to distr: D here; essentially Distribution::iter receives a reference from this function stack.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

error[E0597]: `distr` does not live long enough
   --> src/lib.rs:414:9
    |
414 |         distr.sample_iter(self)
    |         ^^^^^ borrowed value does not live long enough
415 |     }
    |     - borrowed value only lives until here
    |
note: borrowed value must be valid for the lifetime 'a as defined on the method body at 411:5...
   --> src/lib.rs:411:5
    |
411 | /     fn sample_iter<'a, T, D: Distribution<T>>(&'a mut self, distr: D)
412 | |         -> distributions::DistIter<'a, D, Self, T> where Self: Sized
413 | |     {
414 | |         distr.sample_iter(self)
415 | |     }
    | |_____^

error: aborting due to previous error

For more information about this error, try `rustc --explain E0597`.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, of course the iterator is returned from sample_iter!

{
distr.sample_iter(self)
}

/// Return a random value supporting the [`Uniform`] distribution.
///
Expand Down Expand Up @@ -442,7 +448,7 @@ pub trait Rng: RngCore {
/// .collect::<Vec<(f64, bool)>>());
/// ```
#[allow(deprecated)]
#[deprecated(since="0.5.0", note="use iter::repeat instead")]
#[deprecated(since="0.5.0", note="use Rng::sample_iter(&Uniform) instead")]
fn gen_iter<T>(&mut self) -> Generator<T, &mut Self> where Uniform: Distribution<T> {
Generator { rng: self, _marker: marker::PhantomData }
}
Expand Down Expand Up @@ -527,7 +533,7 @@ pub trait Rng: RngCore {
/// println!("{}", s);
/// ```
#[allow(deprecated)]
#[deprecated(since="0.5.0", note="use distributions::Alphanumeric instead")]
#[deprecated(since="0.5.0", note="use sample_iter(&Alphanumeric) instead")]
fn gen_ascii_chars(&mut self) -> AsciiGenerator<&mut Self> {
AsciiGenerator { rng: self }
}
Expand Down Expand Up @@ -678,7 +684,7 @@ impl_as_byte_slice_arrays!(32, N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N
/// [`Rng`]: trait.Rng.html
#[derive(Debug)]
#[allow(deprecated)]
#[deprecated(since="0.5.0", note="use iter::repeat instead")]
#[deprecated(since="0.5.0", note="use Rng::sample_iter instead")]
pub struct Generator<T, R: RngCore> {
rng: R,
_marker: marker::PhantomData<fn() -> T>,
Expand Down