From 5af435810f5b51f399fe5824dadd6da79f99d205 Mon Sep 17 00:00:00 2001 From: Dahua Lin Date: Tue, 4 Nov 2014 11:25:10 +0800 Subject: [PATCH 1/5] clean up the implementations of generic methods for truncated distributions --- src/Distributions.jl | 1 - src/truncate.jl | 112 +++++++++++------------ src/{univariate => }/truncated/normal.jl | 6 +- 3 files changed, 59 insertions(+), 60 deletions(-) rename src/{univariate => }/truncated/normal.jl (99%) diff --git a/src/Distributions.jl b/src/Distributions.jl index 17b4d3104..30d49e704 100644 --- a/src/Distributions.jl +++ b/src/Distributions.jl @@ -222,7 +222,6 @@ include("matrixvariates.jl") # others include("truncate.jl") -include(joinpath("univariate", "truncated", "normal.jl")) include("conjugates.jl") include("qq.jl") include("estimators.jl") diff --git a/src/truncate.jl b/src/truncate.jl index 5e06f1b27..dfe7976b2 100644 --- a/src/truncate.jl +++ b/src/truncate.jl @@ -1,27 +1,29 @@ immutable Truncated{D<:UnivariateDistribution,S<:ValueSupport} <: Distribution{Univariate,S} untruncated::D - lower::Float64 - upper::Float64 - nc::Float64 - function Truncated{T<:UnivariateDistribution}(d::T, l::Real, u::Real, nc::Real) - if l >= u - error("upper must be > lower") - end - new(d, float64(l), float64(u), float64(nc)) - end -end + lower::Float64 # lower bound + upper::Float64 # upper bound + lcdf::Float64 # cdf of lower bound + ucdf::Float64 # cdf of upper bound -function Truncated{S<:ValueSupport}(d::UnivariateDistribution{S}, l::Real, u::Real, nc::Real) - Truncated{typeof(d),S}(d,l,u,nc) + tp::Float64 # the probability of the truncated part, i.e. ucdf - lcdf + logtp::Float64 # log(tp), i.e. 
log(ucdf - lcdf) end -function Truncated{S<:ValueSupport}(d::UnivariateDistribution{S}, l::Real, u::Real) - Truncated{typeof(d),S}(d,l,u, cdf(d, u) - cdf(d, l)) +### Constructors + +function Truncated{S<:ValueSupport}(d::UnivariateDistribution{S}, l::Float64, u::Float64) + l < u || error("lower bound should be less than upper bound.") + lcdf = isinf(l) ? 0.0 : cdf(d, l) + ucdf = isinf(u) ? 1.0 : cdf(d, u) + tp = ucdf - lcdf + Truncated{typeof(d),S}(d, l, u, lcdf, ucdf, tp, log(tp)) end -insupport(d::Truncated, x::Real) = - x >= d.lower && x <= d.upper && insupport(d.untruncated, x) +Truncated(d::UnivariateDistribution, l::Real, u::Real) = Truncated(d, float64(l), float64(u)) + + +### range and support islowerbounded(d::Truncated) = islowerbounded(d.untruncated) || isfinite(d.lower) isupperbounded(d::Truncated) = isupperbounded(d.untruncated) || isfinite(d.upper) @@ -29,61 +31,54 @@ isupperbounded(d::Truncated) = isupperbounded(d.untruncated) || isfinite(d.upper minimum(d::Truncated) = max(minimum(d.untruncated), d.lower) maximum(d::Truncated) = min(maximum(d.untruncated), d.upper) -function pdf(d::Truncated, x::Real) - if !insupport(d, x) - return 0.0 - else - return pdf(d.untruncated, x) / d.nc - end -end +insupport(d::Truncated, x::Real) = + d.lower <= x <= d.upper && insupport(d.untruncated, x) -function logpdf(d::Truncated, x::Real) - if !insupport(d, x) - return -Inf - else - return logpdf(d.untruncated, x) - log(d.nc) - end -end -function cdf(d::Truncated, x::Real) - if x < d.lower - return 0.0 - elseif x > d.upper - return 1.0 - else - return (cdf(d.untruncated, x) - cdf(d.untruncated, d.lower)) / d.nc - end -end +### evaluation + +pdf(d::Truncated, x::Real) = d.lower <= x <= d.upper ? pdf(d.untruncated, x) / d.tp : 0.0 + +logpdf(d::Truncated, x::Real) = d.lower <= x <= d.upper ? logpdf(d.untruncated, x) - d.logtp : -Inf + +cdf(d::Truncated, x::Real) = x <= d.lower ? 0.0 : + x >= d.upper ? 
1.0 : + (cdf(d.untruncated, x) - d.lcdf) / d.tp + +logcdf(d::Truncated, x::Real) = x <= d.lower ? -Inf : + x >= d.upper ? 0.0 : + log(cdf(d.untruncated, x) - d.lcdf) - d.logtp + +ccdf(d::Truncated, x::Real) = x <= d.lower ? 1.0 : + x >= d.upper ? 0.0 : + (d.ucdf - cdf(d.untruncated, x)) / d.tp + +logccdf(d::Truncated, x::Real) = x <= d.lower ? 0.0 : + x >= d.upper ? -Inf : + log(d.ucdf - cdf(d.untruncated, x)) - d.logtp -function quantile(d::Truncated, p::Real) - top = cdf(d.untruncated, d.upper) - bottom = cdf(d.untruncated, d.lower) - return quantile(d.untruncated, bottom + p * (top - bottom)) -end -median(d::Truncated) = quantile(d, 0.5) +quantile(d::Truncated, p::Real) = quantile(d.untruncated, d.lcdf + p * d.tp) + + +## random number generation function rand(d::Truncated) - if d.nc > 0.25 + d0 = d.untruncated + if d.tp > 0.25 while true - r = rand(d.untruncated) + r = rand(d0) if d.lower <= r <= d.upper return r end end else - return quantile(d.untruncated, cdf(d.untruncated, d.lower) + rand() * d.nc) + return quantile(d0, d.lcdf + p * d.tp) end end -# from fallbacks -function rand{D<:ContinuousUnivariateDistribution}(d::Truncated{D}, dims::Dims) - return rand!(d, Array(Float64, dims)) -end -function rand{D<:DiscreteUnivariateDistribution}(d::Truncated{D}, dims::Dims) - return rand!(d, Array(Int, dims)) -end +## show function show(io::IO, d::Truncated) print(io, "Truncated(") @@ -91,12 +86,17 @@ function show(io::IO, d::Truncated) uml, namevals = _use_multline_show(d0) uml ? 
show_multline(io, d0, namevals) : show_oneline(io, d0, namevals) - print(io, ", range = ($(d.lower), $(d.upper)), prop = $(d.nc) )") + print(io, ", range = ($(d.lower), $(d.upper)), prop = $(d.tp) )") uml && println(io) end _use_multline_show(d::Truncated) = _use_multline_show(d.untruncated) +### specialized truncated distributions + +include(joinpath("truncated", "normal.jl")) + + diff --git a/src/univariate/truncated/normal.jl b/src/truncated/normal.jl similarity index 99% rename from src/univariate/truncated/normal.jl rename to src/truncated/normal.jl index c1c2683c8..ef0cde75c 100644 --- a/src/univariate/truncated/normal.jl +++ b/src/truncated/normal.jl @@ -7,14 +7,14 @@ function entropy(d::Truncated{Normal}) phi_b = pdf(d.untruncated, b) * s a_phi_a = a == -Inf ? 0.0 : a * phi_a b_phi_b = b == Inf ? 0.0 : b * phi_b - z = d.nc + z = d.tp return entropy(d.untruncated) + log(z) + 0.5 * (a_phi_a - b_phi_b) / z - 0.5 * ((phi_a - phi_b) / z)^2 end function mean(d::Truncated{Normal}) delta = pdf(d.untruncated, d.lower) - pdf(d.untruncated, d.upper) - return mean(d.untruncated) + delta * var(d.untruncated) / d.nc + return mean(d.untruncated) + delta * var(d.untruncated) / d.tp end function modes(d::Truncated{Normal}) @@ -43,7 +43,7 @@ function var(d::Truncated{Normal}) phi_b = pdf(d.untruncated, b) * s a_phi_a = a == -Inf ? 0.0 : a * phi_a b_phi_b = b == Inf ? 
0.0 : b * phi_b - z = d.nc + z = d.tp return s^2 * (1 + (a_phi_a - b_phi_b) / z - ((phi_a - phi_b) / z)^2) end From 5fc11793faed178d5044bbc78900773de427fc29 Mon Sep 17 00:00:00 2001 From: Dahua Lin Date: Tue, 4 Nov 2014 12:41:29 +0800 Subject: [PATCH 2/5] reimplemented mean, var, entropy for truncated normal --- src/testutils.jl | 7 ---- src/truncated/normal.jl | 83 ++++++++++++++++++++++++----------------- test/truncate.jl | 5 ++- 3 files changed, 52 insertions(+), 43 deletions(-) diff --git a/src/testutils.jl b/src/testutils.jl index 17896d173..fdf6c8576 100644 --- a/src/testutils.jl +++ b/src/testutils.jl @@ -434,12 +434,5 @@ function test_stats(d::ContinuousUnivariateDistribution, xs::AbstractVector{Floa @test_approx_eq_eps var(d) xvar 5.0 * vd * (kd + 2) / sqrt(n) end end - - # test entropy - if applicable(entropy, d) - xentropy = mean(-logpdf(d, xs)) - dentropy = entropy(d) - @test_approx_eq_eps dentropy xentropy 1.0e-2 * (abs(dentropy) + 1.0e-3) - end end diff --git a/src/truncated/normal.jl b/src/truncated/normal.jl index ef0cde75c..f504ad3eb 100644 --- a/src/truncated/normal.jl +++ b/src/truncated/normal.jl @@ -1,33 +1,60 @@ +# Truncated normal distribution -function entropy(d::Truncated{Normal}) - s = std(d.untruncated) - a = d.lower - b = d.upper - phi_a = pdf(d.untruncated, a) * s - phi_b = pdf(d.untruncated, b) * s - a_phi_a = a == -Inf ? 0.0 : a * phi_a - b_phi_b = b == Inf ? 0.0 : b * phi_b - z = d.tp - return entropy(d.untruncated) + log(z) + - 0.5 * (a_phi_a - b_phi_b) / z - 0.5 * ((phi_a - phi_b) / z)^2 +### statistics + +minimum(d::Truncated{Normal}) = d.lower +maximum(d::Truncated{Normal}) = d.upper + + +function mode(d::Truncated{Normal}) + μ = mean(d.untruncated) + d.upper < μ ? d.upper : + d.lower > μ ? 
d.lower : μ end +modes(d::Truncated{Normal}) = [mode(d)] + + function mean(d::Truncated{Normal}) - delta = pdf(d.untruncated, d.lower) - pdf(d.untruncated, d.upper) - return mean(d.untruncated) + delta * var(d.untruncated) / d.tp + d0 = d.untruncated + μ = mean(d0) + σ = std(d0) + a = (d.lower - μ) / σ + b = (d.upper - μ) / σ + μ + ((φ(a) - φ(b)) / d.tp) * σ end -function modes(d::Truncated{Normal}) - mu = mean(d.untruncated) - if d.upper < mu - return [d.upper] - elseif d.lower > mu - return [d.lower] - else - return [mu] - end +function var(d::Truncated{Normal}) + d0 = d.untruncated + μ = mean(d0) + σ = std(d0) + a = (d.lower - μ) / σ + b = (d.upper - μ) / σ + z = d.tp + φa = φ(a) + φb = φ(b) + aφa = isinf(a) ? 0.0 : a * φa + bφb = isinf(b) ? 0.0 : b * φb + t1 = (aφa - bφb) / z + t2 = abs2((φa - φb) / z) + abs2(σ) * (1 + t1 - t2) +end + +function entropy(d::Truncated{Normal}) + d0 = d.untruncated + z = d.tp + μ = mean(d0) + σ = std(d0) + a = (d.lower - μ) / σ + b = (d.upper - μ) / σ + aφa = isinf(a) ? 0.0 : a * φ(a) + bφb = isinf(b) ? 0.0 : b * φ(b) + 0.5 * (log2π + 1.) + log(σ * z) + (aφa - bφb) / (2.0 * z) end + +### sampling + function rand(d::Truncated{Normal}) mu = mean(d.untruncated) sigma = std(d.untruncated) @@ -35,18 +62,6 @@ function rand(d::Truncated{Normal}) return mu + sigma * z end -function var(d::Truncated{Normal}) - s = std(d.untruncated) - a = d.lower - b = d.upper - phi_a = pdf(d.untruncated, a) * s - phi_b = pdf(d.untruncated, b) * s - a_phi_a = a == -Inf ? 0.0 : a * phi_a - b_phi_b = b == Inf ? 
0.0 : b * phi_b - z = d.tp - return s^2 * (1 + (a_phi_a - b_phi_b) / z - ((phi_a - phi_b) / z)^2) -end - # Rejection sampler based on algorithm from Robert (1992) # - Available at http://arxiv.org/abs/0907.4010 function randnt(lower::Real, upper::Real) diff --git a/test/truncate.jl b/test/truncate.jl index ce5528310..06057f165 100644 --- a/test/truncate.jl +++ b/test/truncate.jl @@ -5,9 +5,10 @@ n_tsamples = 10^6 for (pa0, lb, ub) in [ ((0, 1), -2, 2), - # ((3, 10), 7, 8), + ((3, 10), 7, 8), ((27, 3), 0, Inf), - # ((-5, 1), -Inf, -10) + ((-5, 1), -Inf, -10), + ((1.8, 1.2), -Inf, 0) ] d = Truncated(Normal(pa0...), lb, ub) From f9e9401e1dc335db641e562861ec7fab579bf424 Mon Sep 17 00:00:00 2001 From: Dahua Lin Date: Tue, 4 Nov 2014 12:45:51 +0800 Subject: [PATCH 3/5] introduce TruncatedNormal to directly construct a truncated normal distribution --- src/Distributions.jl | 3 ++- src/truncated/normal.jl | 6 ++++++ test/truncate.jl | 14 +++++++------- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/Distributions.jl b/src/Distributions.jl index 30d49e704..20de8af46 100644 --- a/src/Distributions.jl +++ b/src/Distributions.jl @@ -104,9 +104,10 @@ export Poisson, Rayleigh, Skellam, - TDist, SymTriangularDist, + TDist, Truncated, + TruncatedNormal, Uniform, VonMises, VonMisesFisher, diff --git a/src/truncated/normal.jl b/src/truncated/normal.jl index f504ad3eb..ef1802477 100644 --- a/src/truncated/normal.jl +++ b/src/truncated/normal.jl @@ -1,5 +1,11 @@ # Truncated normal distribution +TruncatedNormal(mu::Float64, sigma::Float64, a::Float64, b::Float64) = + Truncated(Normal(mu, sigma), a, b) + +TruncatedNormal(mu::Real, sigma::Real, a::Real, b::Real) = + TruncatedNormal(float64(mu), float64(sigma), float64(a), float64(b)) + ### statistics minimum(d::Truncated{Normal}) = d.lower diff --git a/test/truncate.jl b/test/truncate.jl index 06057f165..ffa7ac27e 100644 --- a/test/truncate.jl +++ b/test/truncate.jl @@ -3,15 +3,15 @@ using Base.Test n_tsamples = 10^6 
-for (pa0, lb, ub) in [ - ((0, 1), -2, 2), - ((3, 10), 7, 8), - ((27, 3), 0, Inf), - ((-5, 1), -Inf, -10), - ((1.8, 1.2), -Inf, 0) +for (mu, sigma, lb, ub) in [ + (0, 1, -2, 2), + (3, 10, 7, 8), + (27, 3, 0, Inf), + (-5, 1, -Inf, -10), + (1.8, 1.2, -Inf, 0) ] - d = Truncated(Normal(pa0...), lb, ub) + d = TruncatedNormal(mu, sigma, lb, ub) println(" testing $d") @test d.lower == lb From 67ec4c499b2edb2bc4859ff4b0eadec67f2fc495 Mon Sep 17 00:00:00 2001 From: Dahua Lin Date: Tue, 4 Nov 2014 13:02:46 +0800 Subject: [PATCH 4/5] incorporate the testing of truncated normal to the standard testing framework for continuous distributions --- src/truncate.jl | 2 +- test/continuous_ref.csv | 15 ++++++++++----- test/continuous_ref.txt | 5 +++++ test/prepdref.py | 11 +++++++++++ test/runtests.jl | 3 +-- test/truncate.jl | 24 ------------------------ 6 files changed, 28 insertions(+), 32 deletions(-) delete mode 100644 test/truncate.jl diff --git a/src/truncate.jl b/src/truncate.jl index dfe7976b2..b020b9cee 100644 --- a/src/truncate.jl +++ b/src/truncate.jl @@ -86,7 +86,7 @@ function show(io::IO, d::Truncated) uml, namevals = _use_multline_show(d0) uml ? 
show_multline(io, d0, namevals) : show_oneline(io, d0, namevals) - print(io, ", range = ($(d.lower), $(d.upper)), prop = $(d.tp) )") + print(io, ", range=($(d.lower), $(d.upper)), tp=$(d.tp))") uml && println(io) end diff --git a/test/continuous_ref.csv b/test/continuous_ref.csv index dc0ffe480..c594d1d7e 100644 --- a/test/continuous_ref.csv +++ b/test/continuous_ref.csv @@ -1,10 +1,10 @@ -"Arcsine()", 5.0000000000000000e-01, 1.2500000000000000e-01, -2.4156447527049044e-01, 1.4644660940672624e-01, 4.9999999999999989e-01, 8.5355339059327373e-01, -1.0500911500948222e-01, -4.5158270528945482e-01, -1.0500911500948222e-01 +"Arcsine()", 5.0000000000000000e-01, 1.2500000000000000e-01, -2.4156447527049044e-01, 1.4644660940672624e-01, 4.9999999999999989e-01, 8.5355339059327373e-01, -1.0500911500948223e-01, -4.5158270528945482e-01, -1.0500911500948223e-01 "Beta(2.0, 2.0)", 5.0000000000000000e-01, 5.0000000000000003e-02, -1.2509280256138822e-01, 3.2635182233306964e-01, 5.0000000000000000e-01, 6.7364817766693030e-01, 2.7693290334825971e-01, 4.0546510810816438e-01, 2.7693290334825971e-01 "Beta(3.0, 4.0)", 4.2857142857142855e-01, 3.0612244897959183e-02, -3.4434456222662974e-01, 2.9691654830725756e-01, 4.2140719069071314e-01, 5.5319825174395865e-01, 6.0889717985308245e-01, 7.2456419155863738e-01, 4.9334598643820282e-01 "Beta(17.0, 13.0)", 5.6666666666666665e-01, 7.9211469534050186e-03, -1.0018500631798657e+00, 5.0619868878674412e-01, 5.6816746281971997e-01, 6.2873506375525445e-01, 1.2372834658662732e+00, 1.4759302328255508e+00, 1.2831670640930035e+00 -"BetaPrime(3.0, 3.0)", 1.5000000000000000e+00, 3.7500000000000000e+00, 1.2988026183378452e+00, 5.6112466085697021e-01, 1.0000000000000000e+00, 1.7821351827110390e+00, -4.2686597864526687e-01, -7.5768570169751603e-01, -1.5824903512024453e+00 -"BetaPrime(3.0, 5.0)", 7.5000000000000000e-01, 4.3750000000000000e-01, 5.8889679270198236e-01, 3.3882066973930408e-01, 5.7261408676656855e-01, 9.4589247295867829e-01, 1.5507869745509506e-01, 
-8.3040285575946449e-02, -7.8305822708285078e-01 -"BetaPrime(5.0, 3.0)", 2.5000000000000000e+00, 8.7500000000000000e+00, 1.7555634593678637e+00, 1.0572026193126152e+00, 1.7463768759982674e+00, 2.9514137988376672e+00, -8.9431098976115475e-01, -1.1981268560001972e+00, -2.0094899213240280e+00 +"BetaPrime(3.0, 3.0)", 1.5000000000000000e+00, 3.7500000000000000e+00, 1.2988026183378452e+00, 5.6112466085697021e-01, 9.9999999999999989e-01, 1.7821351827110392e+00, -4.2686597864526687e-01, -7.5768570169751603e-01, -1.5824903512024462e+00 +"BetaPrime(3.0, 5.0)", 7.5000000000000000e-01, 4.3750000000000000e-01, 5.8889679270198236e-01, 3.3882066973930408e-01, 5.7261408676656844e-01, 9.4589247295867784e-01, 1.5507869745509506e-01, -8.3040285575946449e-02, -7.8305822708284989e-01 +"BetaPrime(5.0, 3.0)", 2.5000000000000000e+00, 8.7500000000000000e+00, 1.7555634593678637e+00, 1.0572026193126152e+00, 1.7463768759982674e+00, 2.9514137988376676e+00, -8.9431098976115475e-01, -1.1981268560001972e+00, -2.0094899213240280e+00 "Cauchy(0.0, 1.0)", inf, inf, 2.5310242469692907e+00, -9.9999999999999989e-01, 0.0000000000000000e+00, 9.9999999999999989e-01, -1.8378770664093453e+00, -1.1447298858494002e+00, -1.8378770664093453e+00 "Cauchy(10.0, 1.0)", inf, inf, 2.5310242469692907e+00, 9.0000000000000000e+00, 1.0000000000000000e+01, 1.1000000000000000e+01, -1.8378770664093453e+00, -1.1447298858494002e+00, -1.8378770664093453e+00 "Cauchy(2.0, 10.0)", inf, inf, 4.8336093399633366e+00, -7.9999999999999982e+00, 2.0000000000000000e+00, 1.1999999999999998e+01, -4.1404621594033912e+00, -3.4473149788434458e+00, -4.1404621594033912e+00 @@ -42,7 +42,7 @@ "NormalCanon(-1.0, 2.5)", -4.0000000000000002e-01, 4.0000000000000002e-01, 9.6079316726759512e-01, -8.2658477381152395e-01, -4.0000000000000002e-01, 2.6584773811523965e-02, -6.8826137882738148e-01, -4.6079316726759517e-01, -6.8826137882738148e-01 "NormalCanon(2.0, 0.8)", 2.5000000000000000e+00, 1.2500000000000002e+00, 1.5305103088617775e+00, 
1.7458975342173546e+00, 2.5000000000000000e+00, 3.2541024657826454e+00, -1.2579785204215639e+00, -1.0305103088617775e+00, -1.2579785204215639e+00 "Pareto(1.0, 1.0)", inf, inf, 2.0000000000000000e+00, 1.3333333333333333e+00, 2.0000000000000000e+00, 4.0000000000000000e+00, -5.7536414490356169e-01, -1.3862943611198906e+00, -2.7725887222397811e+00 -"Pareto(2.0, 1.0)", 2.0000000000000000e+00, inf, 8.0685281944005471e-01, 1.1547005383792515e+00, 1.4142135623730951e+00, 2.0000000000000000e+00, 2.6162407188227410e-01, -3.4657359027997292e-01, -1.3862943611198906e+00 +"Pareto(2.0, 1.0)", 2.0000000000000000e+00, inf, 8.0685281944005471e-01, 1.1547005383792515e+00, 1.4142135623730951e+00, 2.0000000000000000e+00, 2.6162407188227416e-01, -3.4657359027997292e-01, -1.3862943611198906e+00 "Pareto(3.0, 2.0)", 3.0000000000000000e+00, 3.0000000000000000e+00, 9.2786822522516876e-01, 2.2012848325964178e+00, 2.5198420997897464e+00, 3.1748021039363987e+00, 2.1889011505789702e-02, -5.1873113263842940e-01, -1.4429273733850225e+00 "Rayleigh(1.0)", 1.2533141373155001e+00, 4.2920367320510344e-01, 9.4203424217079368e-01, 7.5852761644093214e-01, 1.1774100225154747e+00, 1.6651092223153954e+00, -5.6405814402542731e-01, -5.2983005057080479e-01, -8.7640364085077727e-01 "Rayleigh(3.0)", 3.7599424119465006e+00, 3.8628330588459310e+00, 2.0406465308389032e+00, 2.2755828493227965e+00, 3.5322300675464238e+00, 4.9953276669461859e+00, -1.6626704326935371e+00, -1.6284423392389145e+00, -1.9750159295188872e+00 @@ -50,6 +50,11 @@ "TDist(1.2)", 0.0000000000000000e+00, inf, 2.3401827213118382e+00, -9.3358614772330029e-01, 7.9820775349995142e-17, 9.3358614772330029e-01, -1.7122227129873395e+00, -1.1116320206507522e+00, -1.7122227129873395e+00 "TDist(5.0)", 0.0000000000000000e+00, 1.6666666666666667e+00, 1.6275026724163131e+00, -7.2668684379793969e-01, 6.9760036230033171e-17, 7.2668684379793969e-01, -1.2698241446815288e+00, -9.6861958905472412e-01, -1.2698241446815288e+00 "TDist(28.0)", 0.0000000000000000e+00, 
1.0769230769230769e+00, 1.4549639186397185e+00, -6.8335284452749046e-01, 6.6975103488125950e-17, 6.8335284111072625e-01, -1.1676951605678694e+00, -9.2786520944663353e-01, -1.1676951581892965e+00 +"TruncatedNormal(0, 1, -2, 2)", 0.0000000000000000e+00, 7.7374130354992321e-01, 1.2592412726872442e+00, -6.3911191087127295e-01, 0.0000000000000000e+00, 6.3911191087127306e-01, -1.0766026382210476e+00, -8.7237062091228246e-01, -1.0766026382210476e+00 +"TruncatedNormal(3, 10, 7, 8)", 7.4962513762870771e+00, 8.3297130072637615e-02, -8.4383927631215272e-05, 7.2458921432975716e+00, 7.4943789877185045e+00, 7.7456727334081119e+00, 1.1444251309232101e-02, 5.8503934558418180e-04, -1.1024796691253780e-02 +"TruncatedNormal(27, 3, 0, Inf)", 2.7000000000000000e+01, 9.0000000000000000e+00, 2.5175508218727831e+00, 2.4976530749411754e+01, 2.7000000000000000e+01, 2.9023469250588246e+01, -2.2450190334325688e+00, -2.0175508218727822e+00, -2.2450190334325688e+00 +"TruncatedNormal(-5, 1, -Inf, -10)", -1.0186503967125853e+01, 3.2696434617054848e-02, -6.7979994296945212e-01, -1.0260933604821162e+01, -1.0132018332044298e+01, -1.0055183405359138e+01, 3.0734866361576252e-01, 9.7725378056468415e-01, 1.3686202298748498e+00 +"TruncatedNormal(1.8, 1.2, -Inf, 0)", -5.2641259994705192e-01, 2.1534709471229191e-01, 3.4932356414164489e-01, -7.5263810010822607e-01, -3.9956263907172240e-01, -1.7259902445127429e-01, -6.5780224130117315e-01, -7.5203120841459775e-02, 2.5359163330230439e-01 "Uniform(0.0, 1.0)", 5.0000000000000000e-01, 8.3333333333333329e-02, 0.0000000000000000e+00, 2.5000000000000000e-01, 5.0000000000000000e-01, 7.5000000000000000e-01, 0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00 "Uniform(3.0, 17.0)", 1.0000000000000000e+01, 1.6333333333333332e+01, 2.6390573296152584e+00, 6.5000000000000000e+00, 1.0000000000000000e+01, 1.3500000000000000e+01, -2.6390573296152584e+00, -2.6390573296152584e+00, -2.6390573296152584e+00 "Uniform(3.0, 3.1)", 3.0499999999999998e+00, 
8.3333333333333480e-04, -2.3025850929940450e+00, 3.0249999999999999e+00, 3.0499999999999998e+00, 3.0750000000000002e+00, 2.3025850929940450e+00, 2.3025850929940450e+00, 2.3025850929940450e+00 diff --git a/test/continuous_ref.txt b/test/continuous_ref.txt index 842088249..8f5c956fd 100644 --- a/test/continuous_ref.txt +++ b/test/continuous_ref.txt @@ -50,6 +50,11 @@ Rayleigh(8.0) TDist(1.2) TDist(5.0) TDist(28.0) +TruncatedNormal(0, 1, -2, 2) +TruncatedNormal(3, 10, 7, 8) +TruncatedNormal(27, 3, 0, Inf) +TruncatedNormal(-5, 1, -Inf, -10) +TruncatedNormal(1.8, 1.2, -Inf, 0) Uniform(0.0, 1.0) Uniform(3.0, 17.0) Uniform(3.0, 3.1) diff --git a/test/prepdref.py b/test/prepdref.py index b94234ef7..d70ed1a98 100644 --- a/test/prepdref.py +++ b/test/prepdref.py @@ -28,6 +28,8 @@ def read_distr_list(filename): lst = [] for line in lines: s = line.strip() + if s.startswith("#"): + continue name, args = parse_distr(s) lst.append((s, name, args)) @@ -137,6 +139,15 @@ def to_scipy_dist(name, args): assert len(args) == 1 return t(args[0]) + elif name == "TruncatedNormal": + assert len(args) == 4 + mu, sig, a, b = args + za = (a - mu) / sig + zb = (b - mu) / sig + za = max(za, -1000.0) + zb = min(zb, 1000.0) + return truncnorm(za, zb, loc=mu, scale=sig) + elif name == "Uniform": assert len(args) == 2 return uniform(args[0], args[1] - args[0]) diff --git a/test/runtests.jl b/test/runtests.jl index 95ffdd24a..e956afd37 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,8 +2,7 @@ tests = [ "types", "samplers", "discrete", - "continuous", - "truncate", + "continuous", "fit", "multinomial", "dirichlet", diff --git a/test/truncate.jl b/test/truncate.jl deleted file mode 100644 index ffa7ac27e..000000000 --- a/test/truncate.jl +++ /dev/null @@ -1,24 +0,0 @@ -using Distributions -using Base.Test - -n_tsamples = 10^6 - -for (mu, sigma, lb, ub) in [ - (0, 1, -2, 2), - (3, 10, 7, 8), - (27, 3, 0, Inf), - (-5, 1, -Inf, -10), - (1.8, 1.2, -Inf, 0) - ] - - d = TruncatedNormal(mu, sigma, 
lb, ub) - println(" testing $d") - - @test d.lower == lb - @test d.upper == ub - @test minimum(d) == lb - @test maximum(d) == ub - - test_distr(d, n_tsamples) -end - From 8637e3abf8af5535f1470719dc11c1655fef7d4e Mon Sep 17 00:00:00 2001 From: Dahua Lin Date: Tue, 4 Nov 2014 13:34:15 +0800 Subject: [PATCH 5/5] fix sampling for truncated normal --- src/truncate.jl | 2 +- src/truncated/normal.jl | 116 ++++++++++++++++++++++------------------ 2 files changed, 64 insertions(+), 54 deletions(-) diff --git a/src/truncate.jl b/src/truncate.jl index b020b9cee..6820832e4 100644 --- a/src/truncate.jl +++ b/src/truncate.jl @@ -73,7 +73,7 @@ function rand(d::Truncated) end end else - return quantile(d0, d.lcdf + p * d.tp) + return quantile(d0, d.lcdf + rand() * d.tp) end end diff --git a/src/truncated/normal.jl b/src/truncated/normal.jl index ef1802477..ebf5f0912 100644 --- a/src/truncated/normal.jl +++ b/src/truncated/normal.jl @@ -61,58 +61,68 @@ end ### sampling -function rand(d::Truncated{Normal}) - mu = mean(d.untruncated) - sigma = std(d.untruncated) - z = randnt((d.lower - mu) / sigma, (d.upper - mu) / sigma) - return mu + sigma * z -end +## Benchmarks don't seem to show that this specialized +## sampler is faster than the generic quantile-based method + +# function rand(d::Truncated{Normal}) +# d0 = d.untruncated +# μ = mean(d0) +# σ = std(d0) +# a = (d.lower - μ) / σ +# b = (d.upper - μ) / σ +# z = randnt(a, b, d.tp) +# return μ + σ * z +# end -# Rejection sampler based on algorithm from Robert (1992) +# Rejection sampler based on algorithm from Robert (1995) +# # - Available at http://arxiv.org/abs/0907.4010 -function randnt(lower::Real, upper::Real) - if (lower <= 0 && upper == Inf) || - (upper >= 0 && lower == Inf) || - (lower <= 0 && upper >= 0 && upper - lower > sqrt2π) - while true - r = randn() - if r > lower && r < upper - return r - end - end - elseif lower > 0 && upper - lower > 2.0 / (lower + sqrt(lower^2 + 
4.0)) / 4.0) - a = (lower + sqrt(lower^2 + 4.0))/2.0 - while true - r = rand(Exponential(1.0 / a)) + lower - u = rand() - if u < exp(-0.5 * (r - a)^2) && r < upper - return r - end - end - elseif upper < 0 && upper - lower > 2.0 / (-upper + sqrt(upper^2 + 4.0)) * exp((upper^2 + upper * sqrt(upper^2 + 4.0)) / 4.0) - a = (-upper + sqrt(upper^2 + 4.0)) / 2.0 - while true - r = rand(Exponential(1.0 / a)) - upper - u = rand() - if u < exp(-0.5 * (r - a)^2) && r < -lower - return -r - end - end - else - while true - r = lower + rand() * (upper - lower) - u = rand() - if lower > 0 - rho = exp((lower^2 - r^2) * 0.5) - elseif upper < 0 - rho = exp((upper^2 - r^2) * 0.5) - else - rho = exp(-r^2 * 0.5) - end - if u < rho - return r - end - end - end - return 0.0 -end +# +# function randnt(lb::Float64, ub::Float64, tp::Float64) +# r::Float64 +# if tp > 0.3 # has considerable chance of falling in [lb, ub] +# r = randn() +# while r < lb || r > ub +# r = randn() +# end +# return r + +# else +# span = ub - lb +# if lb > 0 && span > 2.0 / (lb + sqrt(lb^2 + 4.0)) * exp((lb^2 - lb * sqrt(lb^2 + 4.0)) / 4.0) +# a = (lb + sqrt(lb^2 + 4.0))/2.0 +# while true +# r = rand(Exponential(1.0 / a)) + lb +# u = rand() +# if u < exp(-0.5 * (r - a)^2) && r < ub +# return r +# end +# end +# elseif ub < 0 && ub - lb > 2.0 / (-ub + sqrt(ub^2 + 4.0)) * exp((ub^2 + ub * sqrt(ub^2 + 4.0)) / 4.0) +# a = (-ub + sqrt(ub^2 + 4.0)) / 2.0 +# while true +# r = rand(Exponential(1.0 / a)) - ub +# u = rand() +# if u < exp(-0.5 * (r - a)^2) && r < -lb +# return -r +# end +# end +# else +# while true +# r = lb + rand() * (ub - lb) +# u = rand() +# if lb > 0 +# rho = exp((lb^2 - r^2) * 0.5) +# elseif ub < 0 +# rho = exp((ub^2 - r^2) * 0.5) +# else +# rho = exp(-r^2 * 0.5) +# end +# if u < rho +# return r +# end +# end +# end +# end +# end +