Skip to content

Commit

Permalink
Merge pull request #411 from vks/bernoulli
Browse files Browse the repository at this point in the history
Implement Bernoulli distribution
  • Loading branch information
pitdicker authored May 15, 2018
2 parents c0e3e0c + c440d3e commit 63bde31
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 24 deletions.
36 changes: 31 additions & 5 deletions benches/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ use rand::prelude::*;
use rand::seq::*;

#[bench]
fn misc_gen_bool(b: &mut Bencher) {
fn misc_gen_bool_const(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
b.iter(|| {
// Can be evaluated at compile time.
let mut accum = true;
for _ in 0..::RAND_BENCH_N {
accum ^= rng.gen_bool(0.18);
}
black_box(accum);
accum
})
}

Expand All @@ -27,12 +28,37 @@ fn misc_gen_bool_var(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
b.iter(|| {
let mut p = 0.18;
black_box(&mut p); // Avoid constant folding.
for _ in 0..::RAND_BENCH_N {
black_box(rng.gen_bool(p));
}
})
}

#[bench]
fn misc_bernoulli_const(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
let d = rand::distributions::Bernoulli::new(0.18);
b.iter(|| {
// Can be evaluated at compile time.
let mut accum = true;
for _ in 0..::RAND_BENCH_N {
accum ^= rng.gen_bool(p);
p += 0.0001;
accum ^= rng.sample(d);
}
accum
})
}

#[bench]
fn misc_bernoulli_var(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
b.iter(|| {
let mut p = 0.18;
black_box(&mut p); // Avoid constant folding.
let d = rand::distributions::Bernoulli::new(p);
for _ in 0..::RAND_BENCH_N {
black_box(rng.sample(d));
}
black_box(accum);
})
}

Expand Down
120 changes: 120 additions & 0 deletions src/distributions/bernoulli.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Copyright 2018 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// https://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//! The Bernoulli distribution.

use Rng;
use distributions::Distribution;

/// The Bernoulli distribution.
///
/// This is a special case of the Binomial distribution where `n = 1`.
///
/// # Example
///
/// ```rust
/// use rand::distributions::{Bernoulli, Distribution};
///
/// let d = Bernoulli::new(0.3);
/// let v = d.sample(&mut rand::thread_rng());
/// println!("{} is from a Bernoulli distribution", v);
/// ```
///
/// # Precision
///
/// This `Bernoulli` distribution uses 64 bits from the RNG (a `u64`),
/// so only probabilities that are multiples of 2<sup>-64</sup> can be
/// represented.
#[derive(Clone, Copy, Debug)]
pub struct Bernoulli {
    /// Probability of success, relative to the maximal integer.
    ///
    /// `0` encodes `p = 0.0` and `u64::MAX` encodes `p = 1.0`.
    p_int: u64,
}

impl Bernoulli {
    /// Construct a new `Bernoulli` with the given probability of success `p`.
    ///
    /// # Panics
    ///
    /// If `p < 0` or `p > 1`. A NaN `p` also panics, because both range
    /// comparisons evaluate to false for NaN.
    ///
    /// # Precision
    ///
    /// For `p = 1.0`, the resulting distribution will always generate true.
    /// For `p = 0.0`, the resulting distribution will always generate false.
    ///
    /// This method is accurate for any input `p` in the range `[0, 1]` which is
    /// a multiple of 2<sup>-64</sup>. (Note that not all multiples of
    /// 2<sup>-64</sup> in `[0, 1]` can be represented as a `f64`.)
    #[inline]
    pub fn new(p: f64) -> Bernoulli {
        // Non-short-circuiting `&` is deliberate: both comparisons are cheap
        // and side-effect free, so a single combined check is sufficient.
        assert!((p >= 0.0) & (p <= 1.0), "Bernoulli::new not called with 0 <= p <= 1");
        // Technically, this should be 2^64 or `u64::MAX + 1` because we compare
        // using `<` when sampling. However, `u64::MAX` rounds to an `f64`
        // larger than `u64::MAX` anyway.
        const MAX_P_INT: f64 = ::core::u64::MAX as f64;
        let p_int = if p < 1.0 {
            // For every `p < 1.0` the product stays strictly below 2^64, so
            // the float-to-integer cast is always in range.
            (p * MAX_P_INT) as u64
        } else {
            // Avoid overflow: `MAX_P_INT` cannot be represented as u64.
            ::core::u64::MAX
        };
        Bernoulli { p_int }
    }
}

impl Distribution<bool> for Bernoulli {
    /// Draw one `bool` that is `true` with the probability this
    /// distribution was constructed with.
    #[inline]
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool {
        // `p = 1.0` is stored as `u64::MAX`. The strict `<` comparison would
        // reject a draw equal to `u64::MAX`, so that case is answered first
        // and, thanks to short-circuiting, without consuming any RNG output.
        self.p_int == ::core::u64::MAX || rng.gen::<u64>() < self.p_int
    }
}

#[cfg(test)]
mod test {
    use Rng;
    use distributions::Distribution;
    use super::Bernoulli;

    /// The degenerate probabilities 0.0 and 1.0 must be deterministic, both
    /// through `Rng::sample` and through `Distribution::sample`.
    #[test]
    fn test_trivial() {
        let mut r = ::test::rng(1);
        let always_false = Bernoulli::new(0.0);
        let always_true = Bernoulli::new(1.0);
        let mut round = 0;
        while round < 5 {
            assert_eq!(r.sample::<bool, _>(&always_false), false);
            assert_eq!(r.sample::<bool, _>(&always_true), true);
            assert_eq!(Distribution::<bool>::sample(&always_false, &mut r), false);
            assert_eq!(Distribution::<bool>::sample(&always_true, &mut r), true);
            round += 1;
        }
    }

    /// The empirical success rate over many draws must be close to the
    /// requested probability.
    #[test]
    fn test_average() {
        const P: f64 = 0.3;
        const N: u32 = 10_000_000;

        let d = Bernoulli::new(P);
        let mut rng = ::test::rng(2);

        // Count the successful draws out of `N` samples.
        let hits = (0..N).filter(|_| d.sample(&mut rng)).count();
        let avg = (hits as f64) / (N as f64);

        assert!((avg - P).abs() < 1e-3);
    }
}
12 changes: 8 additions & 4 deletions src/distributions/binomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,17 @@ use std::f64::consts::PI;
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Binomial {
n: u64, // number of trials
p: f64, // probability of success
/// Number of trials.
n: u64,
/// Probability of success.
p: f64,
}

impl Binomial {
/// Construct a new `Binomial` with the given shape parameters
/// `n`, `p`. Panics if `p <= 0` or `p >= 1`.
/// Construct a new `Binomial` with the given shape parameters `n` (number
/// of trials) and `p` (probability of success).
///
/// Panics if `p <= 0` or `p >= 1`.
pub fn new(n: u64, p: f64) -> Binomial {
assert!(p > 0.0, "Binomial::new called with p <= 0");
assert!(p < 1.0, "Binomial::new called with p >= 1");
Expand Down
2 changes: 2 additions & 0 deletions src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ pub use self::uniform::Uniform as Range;
#[doc(inline)] pub use self::poisson::Poisson;
#[cfg(feature = "std")]
#[doc(inline)] pub use self::binomial::Binomial;
#[doc(inline)] pub use self::bernoulli::Bernoulli;

pub mod uniform;
#[cfg(feature="std")]
Expand All @@ -190,6 +191,7 @@ pub mod uniform;
#[doc(hidden)] pub mod poisson;
#[cfg(feature = "std")]
#[doc(hidden)] pub mod binomial;
#[doc(hidden)] pub mod bernoulli;

mod float;
mod integer;
Expand Down
24 changes: 9 additions & 15 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,6 @@ pub trait Rng: RngCore {
/// println!("{}", x);
/// println!("{:?}", rng.gen::<(f64, bool)>());
/// ```
#[inline(always)]
fn gen<T>(&mut self) -> T where Standard: Distribution<T> {
Standard.sample(self)
}
Expand Down Expand Up @@ -474,6 +473,8 @@ pub trait Rng: RngCore {

/// Return a bool with a probability `p` of being true.
///
/// This is a wrapper around [`distributions::Bernoulli`].
///
/// # Example
///
/// ```rust
Expand All @@ -483,20 +484,15 @@ pub trait Rng: RngCore {
/// println!("{}", rng.gen_bool(1.0 / 3.0));
/// ```
///
/// # Accuracy note
/// # Panics
///
/// If `p` < 0 or `p` > 1.
///
/// `gen_bool` uses 32 bits of the RNG, so if you use it to generate close
/// to or more than `2^32` results, a tiny bias may become noticable.
/// A notable consequence of the method used here is that the worst case is
/// `rng.gen_bool(0.0)`: it has a chance of 1 in `2^32` of being true, while
/// it should always be false. But using `gen_bool` to consume *many* values
/// from an RNG just to consistently generate `false` does not match with
/// the intent of this method.
/// [`distributions::Bernoulli`]: distributions/bernoulli/struct.Bernoulli.html
#[inline]
fn gen_bool(&mut self, p: f64) -> bool {
assert!(p >= 0.0 && p <= 1.0);
// If `p` is constant, this will be evaluated at compile-time.
let p_int = (p * f64::from(core::u32::MAX)) as u32;
self.gen::<u32>() <= p_int
let d = distributions::Bernoulli::new(p);
self.sample(d)
}

/// Return a random element from `values`.
Expand Down Expand Up @@ -897,7 +893,6 @@ pub fn weak_rng() -> XorShiftRng {
/// [`thread_rng`]: fn.thread_rng.html
/// [`Standard`]: distributions/struct.Standard.html
#[cfg(feature="std")]
#[inline]
pub fn random<T>() -> T where Standard: Distribution<T> {
thread_rng().gen()
}
Expand All @@ -918,7 +913,6 @@ pub fn random<T>() -> T where Standard: Distribution<T> {
/// println!("{:?}", sample);
/// ```
#[cfg(feature="std")]
#[inline(always)]
#[deprecated(since="0.4.0", note="renamed to seq::sample_iter")]
pub fn sample<T, I, R>(rng: &mut R, iterable: I, amount: usize) -> Vec<T>
where I: IntoIterator<Item=T>,
Expand Down
23 changes: 23 additions & 0 deletions tests/bool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#![no_std]

extern crate rand;

use rand::SeedableRng;
use rand::rngs::SmallRng;
use rand::distributions::{Distribution, Bernoulli};

/// Guards against accidentally hitting undefined behavior for large
/// probabilities, caused by out-of-range float-to-integer casts; see
/// https://github.com/rust-lang/rust/issues/10184.
/// Expressions like `1.0*(u64::MAX as f64) as u64` have to be avoided.
#[test]
fn large_probability() {
    // A probability just below 1.0, verified by the assertion that follows.
    let p = 1. - ::core::f64::EPSILON / 2.;
    assert!(p < 1.);

    let seed = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
    let d = Bernoulli::new(p);
    let mut rng = SmallRng::from_seed(seed);

    let mut trial = 0;
    while trial < 10 {
        assert!(d.sample(&mut rng), "extremely unlikely to fail by accident");
        trial += 1;
    }
}

0 comments on commit 63bde31

Please sign in to comment.