Skip to content

Commit

Permalink
Add tests for log-normal, chi squared, F, student T.
Browse files Browse the repository at this point in the history
  • Loading branch information
huonw committed Dec 7, 2013
1 parent df16d31 commit 0424b8a
Showing 1 changed file with 87 additions and 16 deletions.
103 changes: 87 additions & 16 deletions std_dists.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
/// Distributional tests for distributions in the standard lib
use std::num;
use std::{num, vec, cmp};
use std::rand::{Rng, StdRng};
use std::rand::distributions::{Sample, Exp1, StandardNormal, Gamma};
use std::rand::distributions::Sample;
use std::rand::distributions::{ChiSquared, Exp, FisherF, Gamma, LogNormal, Normal, StudentT};

#[cfg(test)]
use std::rand::distributions::{RandSample};
Expand Down Expand Up @@ -118,16 +119,6 @@ pub fn ks_test_dist<S: Sample<f64>>(name: &str,
}
}

struct DirectExpSample;
impl Sample<f64> for DirectExpSample {
fn sample<R: Rng>(&mut self, r: &mut R) -> f64 { *r.gen::<Exp1>() }
}
struct DirectNormSample;
impl Sample<f64> for DirectNormSample {
fn sample<R: Rng>(&mut self, r: &mut R) -> f64 { *r.gen::<StandardNormal>() }
}


#[test]
fn t_test_unif() {
let mut moments = [0., .. NUM_MOMENTS];
Expand All @@ -149,7 +140,7 @@ fn t_test_exp() {
prod *= i as f64 + 1.;
*m = prod
}
t_test_mean_var("Exp(1)", DirectExpSample,
t_test_mean_var("Exp(1)", Exp::new(1.0),
moments);
}
#[test]
Expand All @@ -165,7 +156,7 @@ fn t_test_norm() {
*m = prod;
}
}
t_test_mean_var("N(0, 1)", DirectNormSample,
t_test_mean_var("N(0, 1)", Normal::new(0.0, 1.0),
moments);
}

Expand Down Expand Up @@ -196,19 +187,99 @@ fn t_test_gamma_large() { test_gamma(2.5, 4.) }
#[test]
fn t_test_gamma_very_large() { test_gamma(1000., 5.) }

#[test]
fn t_test_t() {
static DOF: uint = 100;

// k-th moments are only defined for k < dof
let mut moments = vec::from_elem(cmp::min(NUM_MOMENTS, DOF - 1), 0.0);
let mut current_moment = 1.;
for (i, m) in moments.mut_iter().enumerate() {
// k even:
// E[T^k] = dof^{k/2} [(2 - 1) / (dof - 2)] * .. * [(2(k/2) - 1)/(dof - 2(k/2))]
// k odd: E[T^k] = 0
let k = i + 1;

if k % 2 == 0 {
current_moment *= DOF as f64 * (k as f64 - 1.0) / (DOF as f64 - k as f64);
*m = current_moment;
}
}
t_test_mean_var(format!("StudentT({})", DOF),
StudentT::new(DOF as f64),
moments)
}

#[test]
fn t_test_log_normal() {
let mut moments = [0.0, .. NUM_MOMENTS];
for (i, m) in moments.mut_iter().enumerate() {
let k = (i + 1) as f64;
*m = num::exp(0.5 * k * k);
}
t_test_mean_var("ln N(0, 1)",
LogNormal::new(0.0, 1.0),
moments)
}

fn test_chi_squared(dof: f64) {
let mut moments = [0.0, .. NUM_MOMENTS];
for (i, m) in moments.mut_iter().enumerate() {
let k = (i + 1) as f64;
let log_frac = (k + dof * 0.5).lgamma().n1() - (dof * 0.5).lgamma().n1();
*m = 2f64.pow(&k) * num::exp(log_frac)
}
t_test_mean_var(format!("χ²({})", dof),
ChiSquared::new(dof),
moments)
}
#[test]
fn t_test_chi_squared_one() {
test_chi_squared(1.0)
}
#[test]
fn t_test_chi_squared_large() {
test_chi_squared(100.0)
}

#[test]
fn test_f() {
static D1: uint = 10;
static D2: uint = 20;
let mut moments = vec::from_elem(cmp::min(NUM_MOMENTS, (D2 - 1) / 2), 0.0);

let ratio = D2 as f64 / D1 as f64;
for (i, m) in moments.mut_iter().enumerate() {
let k = (i + 1) as f64;
let log_frac_1 = (D1 as f64 * 0.5 + k).lgamma().n1() - (D1 as f64 * 0.5).lgamma().n1();
let log_frac_2 = (D2 as f64 * 0.5 - k).lgamma().n1() - (D2 as f64 * 0.5).lgamma().n1();
*m = ratio.pow(&k) * num::exp(log_frac_1 + log_frac_2);
}
t_test_mean_var(format!("F({}, {})", D1, D2),
FisherF::new(D1 as f64, D2 as f64),
moments)
}
#[test]
fn ks_test_unif() {
ks_test_dist("U(0, 1)", RandSample::<f64>, ::unif_cdf)
}

#[test]
fn ks_test_exp() {
ks_test_dist("Exp(1)", DirectExpSample, ::exp_cdf)
ks_test_dist("Exp(1)", Exp::new(1.0), ::exp_cdf)
}

#[test]
fn ks_test_norm() {
ks_test_dist("N(0, 1)", DirectNormSample, ::normal_cdf)
ks_test_dist("N(0, 1)", Normal::new(0.0, 1.0), ::normal_cdf)
}

#[test]
fn ks_test_log_normal() {
fn cdf(x: f64) -> f64 {
::normal_cdf(x.ln())
}
ks_test_dist("ln N(0, 1)", LogNormal::new(0.0, 1.0), cdf)
}

// Don't have the infrastructure (specifically, the CDF is awkward to
Expand Down

0 comments on commit 0424b8a

Please sign in to comment.