Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SkewNormal distribution implementation #1174

Merged
merged 8 commits into from
Sep 10, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rand_distr/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
- New `Zeta` and `Zipf` distributions (#1136)
- New `SkewNormal` distribution (#1149)

## [0.4.1] - 2021-06-15
- Empirically test PDF of normal distribution (#1121)
Expand Down
2 changes: 2 additions & 0 deletions rand_distr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@ rand_pcg = { version = "0.3.0", path = "../rand_pcg" }
rand = { path = "..", version = "0.8.0", default-features = false, features = ["std_rng", "std", "small_rng"] }
# Histogram implementation for testing uniformity
average = { version = "0.13", features = [ "std" ] }
# Special functions for testing distributions
statrs = "0.15.0"
7 changes: 7 additions & 0 deletions rand_distr/benches/src/distributions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ fn bench(c: &mut Criterion<CyclesPerByte>) {
});
}

{
let mut g = c.benchmark_group("skew_normal");
distr_float!(g, "shape_zero", f64, SkewNormal::new(0.0, 1.0, 0.0).unwrap());
distr_float!(g, "shape_positive", f64, SkewNormal::new(0.0, 1.0, 100.0).unwrap());
distr_float!(g, "shape_negative", f64, SkewNormal::new(0.0, 1.0, -100.0).unwrap());
}

{
let mut g = c.benchmark_group("gamma");
distr_float!(g, "gamma_large_shape", f64, Gamma::new(10., 1.0).unwrap());
Expand Down
11 changes: 8 additions & 3 deletions rand_distr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
//! - Related to real-valued quantities that grow linearly
//! (e.g. errors, offsets):
//! - [`Normal`] distribution, and [`StandardNormal`] as a primitive
//! - [`SkewNormal`] distribution
//! - [`Cauchy`] distribution
//! - Related to Bernoulli trials (yes/no events, with a given probability):
//! - [`Binomial`] distribution
Expand Down Expand Up @@ -107,19 +108,22 @@ pub use self::gamma::{
pub use self::geometric::{Error as GeoError, Geometric, StandardGeometric};
pub use self::gumbel::{Error as GumbelError, Gumbel};
pub use self::hypergeometric::{Error as HyperGeoError, Hypergeometric};
pub use self::inverse_gaussian::{InverseGaussian, Error as InverseGaussianError};
pub use self::inverse_gaussian::{Error as InverseGaussianError, InverseGaussian};
pub use self::normal::{Error as NormalError, LogNormal, Normal, StandardNormal};
pub use self::normal_inverse_gaussian::{NormalInverseGaussian, Error as NormalInverseGaussianError};
pub use self::normal_inverse_gaussian::{
Error as NormalInverseGaussianError, NormalInverseGaussian,
};
pub use self::pareto::{Error as ParetoError, Pareto};
pub use self::pert::{Pert, PertError};
pub use self::poisson::{Error as PoissonError, Poisson};
pub use self::skew_normal::{Error as SkewNormalError, SkewNormal};
pub use self::triangular::{Triangular, TriangularError};
pub use self::unit_ball::UnitBall;
pub use self::unit_circle::UnitCircle;
pub use self::unit_disc::UnitDisc;
pub use self::unit_sphere::UnitSphere;
pub use self::weibull::{Error as WeibullError, Weibull};
pub use self::zipf::{ZetaError, Zeta, ZipfError, Zipf};
pub use self::zipf::{Zeta, ZetaError, Zipf, ZipfError};
#[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub use rand::distributions::{WeightedError, WeightedIndex};
Expand Down Expand Up @@ -196,6 +200,7 @@ mod normal_inverse_gaussian;
mod pareto;
mod pert;
mod poisson;
mod skew_normal;
mod triangular;
mod unit_ball;
mod unit_circle;
Expand Down
218 changes: 218 additions & 0 deletions rand_distr/src/skew_normal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
// Copyright 2021 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! The Skew Normal distribution.

use crate::{Distribution, StandardNormal};
use core::fmt;
use num_traits::Float;
use rand::Rng;

/// The [skew normal distribution] `SN(location, scale, shape)`.
///
/// The skew normal distribution is a generalization of the
/// [`Normal`] distribution to allow for non-zero skewness.
///
/// It has the density function
/// `f(x) = 2 / scale * phi((x - location) / scale) * Phi(alpha * (x - location) / scale)`
/// where `phi` and `Phi` are the density and distribution of a standard normal variable.
saona-raimundo marked this conversation as resolved.
Show resolved Hide resolved
///
/// # Example
///
/// ```
/// use rand_distr::{SkewNormal, Distribution};
///
/// // location 2, scale 3, shape 1
/// let skew_normal = SkewNormal::new(2.0, 3.0, 1.0).unwrap();
/// let v = skew_normal.sample(&mut rand::thread_rng());
/// println!("{} is from a SN(2, 3, 1) distribution", v)
/// ```
///
/// # Implementation details
///
/// We are using the algorithm from [A Method to Simulate the Skew Normal Distribution].
///
/// [`skew normal distribution`]: https://en.wikipedia.org/wiki/Skew_normal_distribution
/// [A Method to Simulate the Skew Normal Distribution]:
/// Ghorbanzadeh, D. , Jaupi, L. and Durand, P. (2014)
/// [A Method to Simulate the Skew Normal Distribution](https://dx.doi.org/10.4236/am.2014.513201).
/// Applied Mathematics, 5, 2073-2076.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct SkewNormal<F>
where
F: Float,
StandardNormal: Distribution<F>,
{
location: F,
scale: F,
shape: F,
}

/// Error type returned from `SkewNormal::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// The scale parameter is not finite or not positive.
BadScale,
/// The shape parameter is not finite.
BadShape,
saona-raimundo marked this conversation as resolved.
Show resolved Hide resolved
}

impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
Error::BadScale => "scale parameter is non-finite in skew normal distribution",
saona-raimundo marked this conversation as resolved.
Show resolved Hide resolved
Error::BadShape => "shape parameter is non-finite in skew normal distribution",
})
}
}

#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}

impl<F> SkewNormal<F>
where
F: Float,
StandardNormal: Distribution<F>,
{
/// Construct, from location, scale and shape.
///
/// Parameters:
///
/// - location (unrestricted)
saona-raimundo marked this conversation as resolved.
Show resolved Hide resolved
/// - scale (must be finite and positive)
saona-raimundo marked this conversation as resolved.
Show resolved Hide resolved
/// - shape (must be finite)
#[inline]
pub fn new(location: F, scale: F, shape: F) -> Result<SkewNormal<F>, Error> {
if !scale.is_finite() || !(scale > F::zero()) {
return Err(Error::BadScale);
}
if !shape.is_finite() {
return Err(Error::BadShape);
}
Ok(SkewNormal {
location,
scale,
shape,
})
}

/// Returns the location of the distribution.
pub fn location(&self) -> F {
self.location
}

/// Returns the scale of the distribution.
pub fn scale(&self) -> F {
self.scale
}

/// Returns the shape of the distribution.
pub fn shape(&self) -> F {
self.shape
}
}

impl<F> Distribution<F> for SkewNormal<F>
where
F: Float,
StandardNormal: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
let linear_map = |x: F| -> F { x * self.scale + self.location };
let u_1: F = rng.sample(StandardNormal);
if self.shape == F::zero() {
linear_map(u_1)
} else {
let u_2 = rng.sample(StandardNormal);
let (u, v) = (u_1.max(u_2), u_1.min(u_2));
if self.shape == -F::one() {
linear_map(v)
} else if self.shape == F::one() {
linear_map(u)
} else {
let normalized = ((F::one() + self.shape) * u + (F::one() - self.shape) * v)
/ ((F::one() + self.shape * self.shape).sqrt()
* F::from(core::f64::consts::SQRT_2).unwrap());
linear_map(normalized)
}
}
}
}

#[cfg(test)]
mod tests {
use super::*;

fn test_samples<F: Float + core::fmt::Debug, D: Distribution<F>>(
distr: D, zero: F, expected: &[F],
) {
let mut rng = crate::test::rng(213);
let mut buf = [zero; 4];
for x in &mut buf {
*x = rng.sample(&distr);
}
assert_eq!(buf, expected);
}

#[test]
#[should_panic]
fn invalid_scale_nan() {
SkewNormal::new(0.0, core::f64::NAN, 0.0).unwrap();
}

#[test]
#[should_panic]
fn invalid_scale_zero() {
SkewNormal::new(0.0, 0.0, 0.0).unwrap();
}

#[test]
#[should_panic]
fn invalid_scale_negative() {
SkewNormal::new(0.0, -1.0, 0.0).unwrap();
}

#[test]
#[should_panic]
fn invalid_scale_infinite() {
SkewNormal::new(0.0, core::f64::INFINITY, 0.0).unwrap();
}

#[test]
#[should_panic]
fn invalid_shape_nan() {
SkewNormal::new(0.0, 1.0, core::f64::NAN).unwrap();
}

#[test]
#[should_panic]
fn invalid_shape_infinite() {
SkewNormal::new(0.0, 1.0, core::f64::INFINITY).unwrap();
}

#[test]
fn skew_normal_value_stability() {
test_samples(
SkewNormal::new(0.0, 1.0, 0.0).unwrap(),
0f32,
&[-0.11844189, 0.781378, 0.06563994, -1.1932899],
);
test_samples(
SkewNormal::new(0.0, 1.0, 0.0).unwrap(),
0f64,
&[
-0.11844188827977231,
0.7813779637772346,
0.06563993969580051,
-1.1932899004186373,
],
);
}
}
Loading