Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplement Truncated and Truncated{Normal}. Fix #292 #295

Merged
merged 5 commits into from
Nov 4, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,10 @@ export
Poisson,
Rayleigh,
Skellam,
TDist,
SymTriangularDist,
TDist,
Truncated,
TruncatedNormal,
Uniform,
VonMises,
VonMisesFisher,
Expand Down Expand Up @@ -222,7 +223,6 @@ include("matrixvariates.jl")

# others
include("truncate.jl")
include(joinpath("univariate", "truncated", "normal.jl"))
include("conjugates.jl")
include("qq.jl")
include("estimators.jl")
Expand Down
7 changes: 0 additions & 7 deletions src/testutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -434,12 +434,5 @@ function test_stats(d::ContinuousUnivariateDistribution, xs::AbstractVector{Floa
@test_approx_eq_eps var(d) xvar 5.0 * vd * (kd + 2) / sqrt(n)
end
end

# test entropy
if applicable(entropy, d)
xentropy = mean(-logpdf(d, xs))
dentropy = entropy(d)
@test_approx_eq_eps dentropy xentropy 1.0e-2 * (abs(dentropy) + 1.0e-3)
end
end

112 changes: 56 additions & 56 deletions src/truncate.jl
Original file line number Diff line number Diff line change
@@ -1,102 +1,102 @@

immutable Truncated{D<:UnivariateDistribution,S<:ValueSupport} <: Distribution{Univariate,S}
untruncated::D
lower::Float64
upper::Float64
nc::Float64
function Truncated{T<:UnivariateDistribution}(d::T, l::Real, u::Real, nc::Real)
if l >= u
error("upper must be > lower")
end
new(d, float64(l), float64(u), float64(nc))
end
end
lower::Float64 # lower bound
upper::Float64 # upper bound
lcdf::Float64 # cdf of lower bound
ucdf::Float64 # cdf of upper bound

function Truncated{S<:ValueSupport}(d::UnivariateDistribution{S}, l::Real, u::Real, nc::Real)
Truncated{typeof(d),S}(d,l,u,nc)
tp::Float64 # the probability of the truncated part, i.e. ucdf - lcdf
logtp::Float64 # log(tp), i.e. log(ucdf - lcdf)
end

function Truncated{S<:ValueSupport}(d::UnivariateDistribution{S}, l::Real, u::Real)
Truncated{typeof(d),S}(d,l,u, cdf(d, u) - cdf(d, l))
### Constructors

function Truncated{S<:ValueSupport}(d::UnivariateDistribution{S}, l::Float64, u::Float64)
l < u || error("lower bound should be less than upper bound.")
lcdf = isinf(l) ? 0.0 : cdf(d, l)
ucdf = isinf(u) ? 1.0 : cdf(d, u)
tp = ucdf - lcdf
Truncated{typeof(d),S}(d, l, u, lcdf, ucdf, tp, log(tp))
end

insupport(d::Truncated, x::Real) =
x >= d.lower && x <= d.upper && insupport(d.untruncated, x)
Truncated(d::UnivariateDistribution, l::Real, u::Real) = Truncated(d, float64(l), float64(u))


### range and support

islowerbounded(d::Truncated) = islowerbounded(d.untruncated) || isfinite(d.lower)
isupperbounded(d::Truncated) = isupperbounded(d.untruncated) || isfinite(d.upper)

minimum(d::Truncated) = max(minimum(d.untruncated), d.lower)
maximum(d::Truncated) = min(maximum(d.untruncated), d.upper)

function pdf(d::Truncated, x::Real)
if !insupport(d, x)
return 0.0
else
return pdf(d.untruncated, x) / d.nc
end
end
insupport(d::Truncated, x::Real) =
d.lower <= x <= d.upper && insupport(d.untruncated, x)

function logpdf(d::Truncated, x::Real)
if !insupport(d, x)
return -Inf
else
return logpdf(d.untruncated, x) - log(d.nc)
end
end

function cdf(d::Truncated, x::Real)
if x < d.lower
return 0.0
elseif x > d.upper
return 1.0
else
return (cdf(d.untruncated, x) - cdf(d.untruncated, d.lower)) / d.nc
end
end
### evaluation

pdf(d::Truncated, x::Real) = d.lower <= x <= d.upper ? pdf(d.untruncated, x) / d.tp : 0.0

logpdf(d::Truncated, x::Real) = d.lower <= x <= d.upper ? logpdf(d.untruncated, x) - d.logtp : -Inf

cdf(d::Truncated, x::Real) = x <= d.lower ? 0.0 :
x >= d.upper ? 1.0 :
(cdf(d.untruncated, x) - d.lcdf) / d.tp

logcdf(d::Truncated, x::Real) = x <= d.lower ? -Inf :
x >= d.upper ? 0.0 :
log(cdf(d.untruncated, x) - d.lcdf) - d.logtp

ccdf(d::Truncated, x::Real) = x <= d.lower ? 1.0 :
x >= d.upper ? 0.0 :
(d.ucdf - cdf(d.untruncated, x)) / d.tp

logccdf(d::Truncated, x::Real) = x <= d.lower ? 0.0 :
x >= d.upper ? -Inf :
log(d.ucdf - cdf(d.untruncated, x)) - d.logtp

function quantile(d::Truncated, p::Real)
top = cdf(d.untruncated, d.upper)
bottom = cdf(d.untruncated, d.lower)
return quantile(d.untruncated, bottom + p * (top - bottom))
end

median(d::Truncated) = quantile(d, 0.5)
quantile(d::Truncated, p::Real) = quantile(d.untruncated, d.lcdf + p * d.tp)


## random number generation

function rand(d::Truncated)
if d.nc > 0.25
d0 = d.untruncated
if d.tp > 0.25
while true
r = rand(d.untruncated)
r = rand(d0)
if d.lower <= r <= d.upper
return r
end
end
else
return quantile(d.untruncated, cdf(d.untruncated, d.lower) + rand() * d.nc)
return quantile(d0, d.lcdf + rand() * d.tp)
end
end

# from fallbacks
function rand{D<:ContinuousUnivariateDistribution}(d::Truncated{D}, dims::Dims)
return rand!(d, Array(Float64, dims))
end

function rand{D<:DiscreteUnivariateDistribution}(d::Truncated{D}, dims::Dims)
return rand!(d, Array(Int, dims))
end
## show

function show(io::IO, d::Truncated)
print(io, "Truncated(")
d0 = d.untruncated
uml, namevals = _use_multline_show(d0)
uml ? show_multline(io, d0, namevals) :
show_oneline(io, d0, namevals)
print(io, ", range = ($(d.lower), $(d.upper)), prop = $(d.nc) )")
print(io, ", range=($(d.lower), $(d.upper)), tp=$(d.tp))")
uml && println(io)
end

_use_multline_show(d::Truncated) = _use_multline_show(d.untruncated)


### specialized truncated distributions

include(joinpath("truncated", "normal.jl"))




128 changes: 128 additions & 0 deletions src/truncated/normal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Truncated normal distribution

TruncatedNormal(mu::Float64, sigma::Float64, a::Float64, b::Float64) =
Truncated(Normal(mu, sigma), a, b)

TruncatedNormal(mu::Real, sigma::Real, a::Real, b::Real) =
TruncatedNormal(float64(mu), float64(sigma), float64(a), float64(b))

### statistics

minimum(d::Truncated{Normal}) = d.lower
maximum(d::Truncated{Normal}) = d.upper


function mode(d::Truncated{Normal})
μ = mean(d.untruncated)
d.upper < mu ? d.upper :
d.lower > mu ? d.lower : μ
end

modes(d::Truncated{Normal}) = [mode(d)]


function mean(d::Truncated{Normal})
d0 = d.untruncated
μ = mean(d0)
σ = std(d0)
a = (d.lower - μ) / σ
b = (d.upper - μ) / σ
μ + ((φ(a) - φ(b)) / d.tp) * σ
end

function var(d::Truncated{Normal})
d0 = d.untruncated
μ = mean(d0)
σ = std(d0)
a = (d.lower - μ) / σ
b = (d.upper - μ) / σ
z = d.tp
φa = φ(a)
φb = φ(b)
aφa = isinf(a) ? 0.0 : a * φa
bφb = isinf(b) ? 0.0 : b * φb
t1 = (aφa - bφb) / z
t2 = abs2((φa - φb) / z)
abs2(σ) * (1 + t1 - t2)
end

function entropy(d::Truncated{Normal})
d0 = d.untruncated
z = d.tp
μ = mean(d0)
σ = std(d0)
a = (d.lower - μ) / σ
b = (d.upper - μ) / σ
aφa = isinf(a) ? 0.0 : a * φ(a)
bφb = isinf(b) ? 0.0 : b * φ(b)
0.5 * (log2π + 1.) + log(σ * z) + (aφa - bφb) / (2.0 * z)
end


### sampling

## Benchmarks doesn't seem to show that this specialized
## sampler is faster than the generic quantile-based method

# function rand(d::Truncated{Normal})
# d0 = d.untruncated
# μ = mean(d0)
# σ = std(d0)
# a = (d.lower - μ) / σ
# b = (d.upper - μ) / σ
# z = randnt(a, b, d.tp)
# return μ + σ * z
# end

# Rejection sampler based on algorithm from Robert (1995)
#
# - Available at http://arxiv.org/abs/0907.4010
#
# function randnt(lb::Float64, ub::Float64, tp::Float64)
# r::Float64
# if tp > 0.3 # has considerable chance of falling in [lb, ub]
# r = randn()
# while r < lb || r > ub
# r = randn()
# end
# return r

# else
# span = ub - lb
# if lb > 0 && span > 2.0 / (lb + sqrt(lb^2 + 4.0)) * exp((lb^2 - lb * sqrt(lb^2 + 4.0)) / 4.0)
# a = (lb + sqrt(lb^2 + 4.0))/2.0
# while true
# r = rand(Exponential(1.0 / a)) + lb
# u = rand()
# if u < exp(-0.5 * (r - a)^2) && r < ub
# return r
# end
# end
# elseif ub < 0 && ub - lb > 2.0 / (-ub + sqrt(ub^2 + 4.0)) * exp((ub^2 + ub * sqrt(ub^2 + 4.0)) / 4.0)
# a = (-ub + sqrt(ub^2 + 4.0)) / 2.0
# while true
# r = rand(Exponential(1.0 / a)) - ub
# u = rand()
# if u < exp(-0.5 * (r - a)^2) && r < -lb
# return -r
# end
# end
# else
# while true
# r = lb + rand() * (ub - lb)
# u = rand()
# if lb > 0
# rho = exp((lb^2 - r^2) * 0.5)
# elseif ub < 0
# rho = exp((ub^2 - r^2) * 0.5)
# else
# rho = exp(-r^2 * 0.5)
# end
# if u < rho
# return r
# end
# end
# end
# end
# end

Loading