Skip to content

Commit 63bde31

Browse files
authored
Merge pull request #411 from vks/bernoulli
Implement Bernoulli distribution
2 parents c0e3e0c + c440d3e commit 63bde31

File tree

6 files changed

+193
-24
lines changed

6 files changed

+193
-24
lines changed

benches/misc.rs

+31-5
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@ use rand::prelude::*;
1111
use rand::seq::*;
1212

1313
#[bench]
14-
fn misc_gen_bool(b: &mut Bencher) {
14+
fn misc_gen_bool_const(b: &mut Bencher) {
1515
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
1616
b.iter(|| {
17+
// Can be evaluated at compile time.
1718
let mut accum = true;
1819
for _ in 0..::RAND_BENCH_N {
1920
accum ^= rng.gen_bool(0.18);
2021
}
21-
black_box(accum);
22+
accum
2223
})
2324
}
2425

@@ -27,12 +28,37 @@ fn misc_gen_bool_var(b: &mut Bencher) {
2728
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
2829
b.iter(|| {
2930
let mut p = 0.18;
31+
black_box(&mut p); // Avoid constant folding.
32+
for _ in 0..::RAND_BENCH_N {
33+
black_box(rng.gen_bool(p));
34+
}
35+
})
36+
}
37+
38+
#[bench]
39+
fn misc_bernoulli_const(b: &mut Bencher) {
40+
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
41+
let d = rand::distributions::Bernoulli::new(0.18);
42+
b.iter(|| {
43+
// Can be evaluated at compile time.
3044
let mut accum = true;
3145
for _ in 0..::RAND_BENCH_N {
32-
accum ^= rng.gen_bool(p);
33-
p += 0.0001;
46+
accum ^= rng.sample(d);
47+
}
48+
accum
49+
})
50+
}
51+
52+
#[bench]
53+
fn misc_bernoulli_var(b: &mut Bencher) {
54+
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
55+
b.iter(|| {
56+
let mut p = 0.18;
57+
black_box(&mut p); // Avoid constant folding.
58+
let d = rand::distributions::Bernoulli::new(p);
59+
for _ in 0..::RAND_BENCH_N {
60+
black_box(rng.sample(d));
3461
}
35-
black_box(accum);
3662
})
3763
}
3864

src/distributions/bernoulli.rs

+120
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
// Copyright 2018 The Rust Project Developers. See the COPYRIGHT
2+
// file at the top-level directory of this distribution and at
3+
// https://rust-lang.org/COPYRIGHT.
4+
//
5+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6+
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7+
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8+
// option. This file may not be copied, modified, or distributed
9+
// except according to those terms.
10+
//! The Bernoulli distribution.
11+
12+
use Rng;
13+
use distributions::Distribution;
14+
15+
/// The Bernoulli distribution.
16+
///
17+
/// This is a special case of the Binomial distribution where `n = 1`.
18+
///
19+
/// # Example
20+
///
21+
/// ```rust
22+
/// use rand::distributions::{Bernoulli, Distribution};
23+
///
24+
/// let d = Bernoulli::new(0.3);
25+
/// let v = d.sample(&mut rand::thread_rng());
26+
/// println!("{} is from a Bernoulli distribution", v);
27+
/// ```
28+
///
29+
/// # Precision
30+
///
31+
/// This `Bernoulli` distribution uses 64 bits from the RNG (a `u64`),
32+
/// so only probabilities that are multiples of 2<sup>-64</sup> can be
33+
/// represented.
34+
#[derive(Clone, Copy, Debug)]
35+
pub struct Bernoulli {
36+
/// Probability of success, relative to the maximal integer.
37+
p_int: u64,
38+
}
39+
40+
impl Bernoulli {
41+
/// Construct a new `Bernoulli` with the given probability of success `p`.
42+
///
43+
/// # Panics
44+
///
45+
/// If `p < 0` or `p > 1`.
46+
///
47+
/// # Precision
48+
///
49+
/// For `p = 1.0`, the resulting distribution will always generate true.
50+
/// For `p = 0.0`, the resulting distribution will always generate false.
51+
///
52+
/// This method is accurate for any input `p` in the range `[0, 1]` which is
53+
/// a multiple of 2<sup>-64</sup>. (Note that not all multiples of
54+
/// 2<sup>-64</sup> in `[0, 1]` can be represented as a `f64`.)
55+
#[inline]
56+
pub fn new(p: f64) -> Bernoulli {
57+
assert!((p >= 0.0) & (p <= 1.0), "Bernoulli::new not called with 0 <= p <= 1");
58+
// Technically, this should be 2^64 or `u64::MAX + 1` because we compare
59+
// using `<` when sampling. However, `u64::MAX` rounds to an `f64`
60+
// larger than `u64::MAX` anyway.
61+
const MAX_P_INT: f64 = ::core::u64::MAX as f64;
62+
let p_int = if p < 1.0 {
63+
(p * MAX_P_INT) as u64
64+
} else {
65+
// Avoid overflow: `MAX_P_INT` cannot be represented as u64.
66+
::core::u64::MAX
67+
};
68+
Bernoulli { p_int }
69+
}
70+
}
71+
72+
impl Distribution<bool> for Bernoulli {
73+
#[inline]
74+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool {
75+
// Make sure to always return true for p = 1.0.
76+
if self.p_int == ::core::u64::MAX {
77+
return true;
78+
}
79+
let r: u64 = rng.gen();
80+
r < self.p_int
81+
}
82+
}
83+
84+
#[cfg(test)]
85+
mod test {
86+
use Rng;
87+
use distributions::Distribution;
88+
use super::Bernoulli;
89+
90+
#[test]
91+
fn test_trivial() {
92+
let mut r = ::test::rng(1);
93+
let always_false = Bernoulli::new(0.0);
94+
let always_true = Bernoulli::new(1.0);
95+
for _ in 0..5 {
96+
assert_eq!(r.sample::<bool, _>(&always_false), false);
97+
assert_eq!(r.sample::<bool, _>(&always_true), true);
98+
assert_eq!(Distribution::<bool>::sample(&always_false, &mut r), false);
99+
assert_eq!(Distribution::<bool>::sample(&always_true, &mut r), true);
100+
}
101+
}
102+
103+
#[test]
104+
fn test_average() {
105+
const P: f64 = 0.3;
106+
let d = Bernoulli::new(P);
107+
const N: u32 = 10_000_000;
108+
109+
let mut sum: u32 = 0;
110+
let mut rng = ::test::rng(2);
111+
for _ in 0..N {
112+
if d.sample(&mut rng) {
113+
sum += 1;
114+
}
115+
}
116+
let avg = (sum as f64) / (N as f64);
117+
118+
assert!((avg - P).abs() < 1e-3);
119+
}
120+
}

src/distributions/binomial.rs

+8-4
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,17 @@ use std::f64::consts::PI;
3131
/// ```
3232
#[derive(Clone, Copy, Debug)]
3333
pub struct Binomial {
34-
n: u64, // number of trials
35-
p: f64, // probability of success
34+
/// Number of trials.
35+
n: u64,
36+
/// Probability of success.
37+
p: f64,
3638
}
3739

3840
impl Binomial {
39-
/// Construct a new `Binomial` with the given shape parameters
40-
/// `n`, `p`. Panics if `p <= 0` or `p >= 1`.
41+
/// Construct a new `Binomial` with the given shape parameters `n` (number
42+
/// of trials) and `p` (probability of success).
43+
///
44+
/// Panics if `p <= 0` or `p >= 1`.
4145
pub fn new(n: u64, p: f64) -> Binomial {
4246
assert!(p > 0.0, "Binomial::new called with p <= 0");
4347
assert!(p < 1.0, "Binomial::new called with p >= 1");

src/distributions/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ pub use self::uniform::Uniform as Range;
178178
#[doc(inline)] pub use self::poisson::Poisson;
179179
#[cfg(feature = "std")]
180180
#[doc(inline)] pub use self::binomial::Binomial;
181+
#[doc(inline)] pub use self::bernoulli::Bernoulli;
181182

182183
pub mod uniform;
183184
#[cfg(feature="std")]
@@ -190,6 +191,7 @@ pub mod uniform;
190191
#[doc(hidden)] pub mod poisson;
191192
#[cfg(feature = "std")]
192193
#[doc(hidden)] pub mod binomial;
194+
#[doc(hidden)] pub mod bernoulli;
193195

194196
mod float;
195197
mod integer;

src/lib.rs

+9-15
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,6 @@ pub trait Rng: RngCore {
318318
/// println!("{}", x);
319319
/// println!("{:?}", rng.gen::<(f64, bool)>());
320320
/// ```
321-
#[inline(always)]
322321
fn gen<T>(&mut self) -> T where Standard: Distribution<T> {
323322
Standard.sample(self)
324323
}
@@ -474,6 +473,8 @@ pub trait Rng: RngCore {
474473

475474
/// Return a bool with a probability `p` of being true.
476475
///
476+
/// This is a wrapper around [`distributions::Bernoulli`].
477+
///
477478
/// # Example
478479
///
479480
/// ```rust
@@ -483,20 +484,15 @@ pub trait Rng: RngCore {
483484
/// println!("{}", rng.gen_bool(1.0 / 3.0));
484485
/// ```
485486
///
486-
/// # Accuracy note
487+
/// # Panics
488+
///
489+
/// If `p` < 0 or `p` > 1.
487490
///
488-
/// `gen_bool` uses 32 bits of the RNG, so if you use it to generate close
489-
/// to or more than `2^32` results, a tiny bias may become noticeable.
490-
/// A notable consequence of the method used here is that the worst case is
491-
/// `rng.gen_bool(0.0)`: it has a chance of 1 in `2^32` of being true, while
492-
/// it should always be false. But using `gen_bool` to consume *many* values
493-
/// from an RNG just to consistently generate `false` does not match with
494-
/// the intent of this method.
491+
/// [`distributions::Bernoulli`]: distributions/bernoulli/struct.Bernoulli.html
492+
#[inline]
495493
fn gen_bool(&mut self, p: f64) -> bool {
496-
assert!(p >= 0.0 && p <= 1.0);
497-
// If `p` is constant, this will be evaluated at compile-time.
498-
let p_int = (p * f64::from(core::u32::MAX)) as u32;
499-
self.gen::<u32>() <= p_int
494+
let d = distributions::Bernoulli::new(p);
495+
self.sample(d)
500496
}
501497

502498
/// Return a random element from `values`.
@@ -897,7 +893,6 @@ pub fn weak_rng() -> XorShiftRng {
897893
/// [`thread_rng`]: fn.thread_rng.html
898894
/// [`Standard`]: distributions/struct.Standard.html
899895
#[cfg(feature="std")]
900-
#[inline]
901896
pub fn random<T>() -> T where Standard: Distribution<T> {
902897
thread_rng().gen()
903898
}
@@ -918,7 +913,6 @@ pub fn random<T>() -> T where Standard: Distribution<T> {
918913
/// println!("{:?}", sample);
919914
/// ```
920915
#[cfg(feature="std")]
921-
#[inline(always)]
922916
#[deprecated(since="0.4.0", note="renamed to seq::sample_iter")]
923917
pub fn sample<T, I, R>(rng: &mut R, iterable: I, amount: usize) -> Vec<T>
924918
where I: IntoIterator<Item=T>,

tests/bool.rs

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#![no_std]
2+
3+
extern crate rand;
4+
5+
use rand::SeedableRng;
6+
use rand::rngs::SmallRng;
7+
use rand::distributions::{Distribution, Bernoulli};
8+
9+
/// This test should make sure that we don't accidentally have undefined
10+
/// behavior for large probabilities due to
11+
/// https://github.com/rust-lang/rust/issues/10184.
12+
/// Expressions like `1.0*(u64::MAX as f64) as u64` have to be avoided.
13+
#[test]
14+
fn large_probability() {
15+
let p = 1. - ::core::f64::EPSILON / 2.;
16+
assert!(p < 1.);
17+
let d = Bernoulli::new(p);
18+
let mut rng = SmallRng::from_seed(
19+
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
20+
for _ in 0..10 {
21+
assert!(d.sample(&mut rng), "extremely unlikely to fail by accident");
22+
}
23+
}

0 commit comments

Comments
 (0)