Skip to content

Commit

Permalink
Merge pull request #295 from JuliaStats/dh/trunc2
Browse files Browse the repository at this point in the history
Reimplement Truncated and Truncated{Normal}. Fix #292
  • Loading branch information
lindahua committed Nov 4, 2014
2 parents b943947 + 8637e3a commit eec38f3
Show file tree
Hide file tree
Showing 10 changed files with 213 additions and 192 deletions.
4 changes: 2 additions & 2 deletions src/Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,10 @@ export
Poisson,
Rayleigh,
Skellam,
TDist,
SymTriangularDist,
TDist,
Truncated,
TruncatedNormal,
Uniform,
VonMises,
VonMisesFisher,
Expand Down Expand Up @@ -222,7 +223,6 @@ include("matrixvariates.jl")

# others
include("truncate.jl")
include(joinpath("univariate", "truncated", "normal.jl"))
include("conjugates.jl")
include("qq.jl")
include("estimators.jl")
Expand Down
7 changes: 0 additions & 7 deletions src/testutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -434,12 +434,5 @@ function test_stats(d::ContinuousUnivariateDistribution, xs::AbstractVector{Floa
@test_approx_eq_eps var(d) xvar 5.0 * vd * (kd + 2) / sqrt(n)
end
end

# test entropy
if applicable(entropy, d)
xentropy = mean(-logpdf(d, xs))
dentropy = entropy(d)
@test_approx_eq_eps dentropy xentropy 1.0e-2 * (abs(dentropy) + 1.0e-3)
end
end

112 changes: 56 additions & 56 deletions src/truncate.jl
Original file line number Diff line number Diff line change
@@ -1,102 +1,102 @@

immutable Truncated{D<:UnivariateDistribution,S<:ValueSupport} <: Distribution{Univariate,S}
untruncated::D
lower::Float64
upper::Float64
nc::Float64
function Truncated{T<:UnivariateDistribution}(d::T, l::Real, u::Real, nc::Real)
if l >= u
error("upper must be > lower")
end
new(d, float64(l), float64(u), float64(nc))
end
end
lower::Float64 # lower bound
upper::Float64 # upper bound
lcdf::Float64 # cdf of lower bound
ucdf::Float64 # cdf of upper bound

function Truncated{S<:ValueSupport}(d::UnivariateDistribution{S}, l::Real, u::Real, nc::Real)
Truncated{typeof(d),S}(d,l,u,nc)
tp::Float64 # the probability of the truncated part, i.e. ucdf - lcdf
logtp::Float64 # log(tp), i.e. log(ucdf - lcdf)
end

function Truncated{S<:ValueSupport}(d::UnivariateDistribution{S}, l::Real, u::Real)
Truncated{typeof(d),S}(d,l,u, cdf(d, u) - cdf(d, l))
### Constructors

function Truncated{S<:ValueSupport}(d::UnivariateDistribution{S}, l::Float64, u::Float64)
l < u || error("lower bound should be less than upper bound.")
lcdf = isinf(l) ? 0.0 : cdf(d, l)
ucdf = isinf(u) ? 1.0 : cdf(d, u)
tp = ucdf - lcdf
Truncated{typeof(d),S}(d, l, u, lcdf, ucdf, tp, log(tp))
end

insupport(d::Truncated, x::Real) =
x >= d.lower && x <= d.upper && insupport(d.untruncated, x)
Truncated(d::UnivariateDistribution, l::Real, u::Real) = Truncated(d, float64(l), float64(u))


### range and support

islowerbounded(d::Truncated) = islowerbounded(d.untruncated) || isfinite(d.lower)
isupperbounded(d::Truncated) = isupperbounded(d.untruncated) || isfinite(d.upper)

minimum(d::Truncated) = max(minimum(d.untruncated), d.lower)
maximum(d::Truncated) = min(maximum(d.untruncated), d.upper)

function pdf(d::Truncated, x::Real)
if !insupport(d, x)
return 0.0
else
return pdf(d.untruncated, x) / d.nc
end
end
insupport(d::Truncated, x::Real) =
d.lower <= x <= d.upper && insupport(d.untruncated, x)

function logpdf(d::Truncated, x::Real)
if !insupport(d, x)
return -Inf
else
return logpdf(d.untruncated, x) - log(d.nc)
end
end

function cdf(d::Truncated, x::Real)
if x < d.lower
return 0.0
elseif x > d.upper
return 1.0
else
return (cdf(d.untruncated, x) - cdf(d.untruncated, d.lower)) / d.nc
end
end
### evaluation

pdf(d::Truncated, x::Real) = d.lower <= x <= d.upper ? pdf(d.untruncated, x) / d.tp : 0.0

logpdf(d::Truncated, x::Real) = d.lower <= x <= d.upper ? logpdf(d.untruncated, x) - d.logtp : -Inf

cdf(d::Truncated, x::Real) = x <= d.lower ? 0.0 :
x >= d.upper ? 1.0 :
(cdf(d.untruncated, x) - d.lcdf) / d.tp

logcdf(d::Truncated, x::Real) = x <= d.lower ? -Inf :
x >= d.upper ? 0.0 :
log(cdf(d.untruncated, x) - d.lcdf) - d.logtp

ccdf(d::Truncated, x::Real) = x <= d.lower ? 1.0 :
x >= d.upper ? 0.0 :
(d.ucdf - cdf(d.untruncated, x)) / d.tp

logccdf(d::Truncated, x::Real) = x <= d.lower ? 0.0 :
x >= d.upper ? -Inf :
log(d.ucdf - cdf(d.untruncated, x)) - d.logtp

function quantile(d::Truncated, p::Real)
top = cdf(d.untruncated, d.upper)
bottom = cdf(d.untruncated, d.lower)
return quantile(d.untruncated, bottom + p * (top - bottom))
end

median(d::Truncated) = quantile(d, 0.5)
quantile(d::Truncated, p::Real) = quantile(d.untruncated, d.lcdf + p * d.tp)


## random number generation

function rand(d::Truncated)
if d.nc > 0.25
d0 = d.untruncated
if d.tp > 0.25
while true
r = rand(d.untruncated)
r = rand(d0)
if d.lower <= r <= d.upper
return r
end
end
else
return quantile(d.untruncated, cdf(d.untruncated, d.lower) + rand() * d.nc)
return quantile(d0, d.lcdf + rand() * d.tp)
end
end

# from fallbacks
function rand{D<:ContinuousUnivariateDistribution}(d::Truncated{D}, dims::Dims)
return rand!(d, Array(Float64, dims))
end

function rand{D<:DiscreteUnivariateDistribution}(d::Truncated{D}, dims::Dims)
return rand!(d, Array(Int, dims))
end
## show

function show(io::IO, d::Truncated)
print(io, "Truncated(")
d0 = d.untruncated
uml, namevals = _use_multline_show(d0)
uml ? show_multline(io, d0, namevals) :
show_oneline(io, d0, namevals)
print(io, ", range = ($(d.lower), $(d.upper)), prop = $(d.nc) )")
print(io, ", range=($(d.lower), $(d.upper)), tp=$(d.tp))")
uml && println(io)
end

_use_multline_show(d::Truncated) = _use_multline_show(d.untruncated)


### specialized truncated distributions

include(joinpath("truncated", "normal.jl"))




128 changes: 128 additions & 0 deletions src/truncated/normal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Truncated normal distribution

TruncatedNormal(mu::Float64, sigma::Float64, a::Float64, b::Float64) =
Truncated(Normal(mu, sigma), a, b)

TruncatedNormal(mu::Real, sigma::Real, a::Real, b::Real) =
TruncatedNormal(float64(mu), float64(sigma), float64(a), float64(b))

### statistics

minimum(d::Truncated{Normal}) = d.lower
maximum(d::Truncated{Normal}) = d.upper


function mode(d::Truncated{Normal})
μ = mean(d.untruncated)
d.upper < mu ? d.upper :
d.lower > mu ? d.lower : μ
end

modes(d::Truncated{Normal}) = [mode(d)]


function mean(d::Truncated{Normal})
d0 = d.untruncated
μ = mean(d0)
σ = std(d0)
a = (d.lower - μ) / σ
b = (d.upper - μ) / σ
μ + ((φ(a) - φ(b)) / d.tp) * σ
end

function var(d::Truncated{Normal})
d0 = d.untruncated
μ = mean(d0)
σ = std(d0)
a = (d.lower - μ) / σ
b = (d.upper - μ) / σ
z = d.tp
φa = φ(a)
φb = φ(b)
aφa = isinf(a) ? 0.0 : a * φa
bφb = isinf(b) ? 0.0 : b * φb
t1 = (aφa - bφb) / z
t2 = abs2((φa - φb) / z)
abs2(σ) * (1 + t1 - t2)
end

function entropy(d::Truncated{Normal})
d0 = d.untruncated
z = d.tp
μ = mean(d0)
σ = std(d0)
a = (d.lower - μ) / σ
b = (d.upper - μ) / σ
aφa = isinf(a) ? 0.0 : a * φ(a)
bφb = isinf(b) ? 0.0 : b * φ(b)
0.5 * (log2π + 1.) + log* z) + (aφa - bφb) / (2.0 * z)
end


### sampling

## Benchmarks doesn't seem to show that this specialized
## sampler is faster than the generic quantile-based method

# function rand(d::Truncated{Normal})
# d0 = d.untruncated
# μ = mean(d0)
# σ = std(d0)
# a = (d.lower - μ) / σ
# b = (d.upper - μ) / σ
# z = randnt(a, b, d.tp)
# return μ + σ * z
# end

# Rejection sampler based on algorithm from Robert (1995)
#
# - Available at http://arxiv.org/abs/0907.4010
#
# function randnt(lb::Float64, ub::Float64, tp::Float64)
# r::Float64
# if tp > 0.3 # has considerable chance of falling in [lb, ub]
# r = randn()
# while r < lb || r > ub
# r = randn()
# end
# return r

# else
# span = ub - lb
# if lb > 0 && span > 2.0 / (lb + sqrt(lb^2 + 4.0)) * exp((lb^2 - lb * sqrt(lb^2 + 4.0)) / 4.0)
# a = (lb + sqrt(lb^2 + 4.0))/2.0
# while true
# r = rand(Exponential(1.0 / a)) + lb
# u = rand()
# if u < exp(-0.5 * (r - a)^2) && r < ub
# return r
# end
# end
# elseif ub < 0 && ub - lb > 2.0 / (-ub + sqrt(ub^2 + 4.0)) * exp((ub^2 + ub * sqrt(ub^2 + 4.0)) / 4.0)
# a = (-ub + sqrt(ub^2 + 4.0)) / 2.0
# while true
# r = rand(Exponential(1.0 / a)) - ub
# u = rand()
# if u < exp(-0.5 * (r - a)^2) && r < -lb
# return -r
# end
# end
# else
# while true
# r = lb + rand() * (ub - lb)
# u = rand()
# if lb > 0
# rho = exp((lb^2 - r^2) * 0.5)
# elseif ub < 0
# rho = exp((ub^2 - r^2) * 0.5)
# else
# rho = exp(-r^2 * 0.5)
# end
# if u < rho
# return r
# end
# end
# end
# end
# end

Loading

0 comments on commit eec38f3

Please sign in to comment.