-
Notifications
You must be signed in to change notification settings - Fork 419
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #295 from JuliaStats/dh/trunc2
Reimplement Truncated and Truncated{Normal}. Fix #292
- Loading branch information
Showing
10 changed files
with
213 additions
and
192 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,102 +1,102 @@ | ||
|
||
# A univariate distribution restricted to the interval [lower, upper].
# The cdf of the untruncated distribution at both bounds, and the probability
# mass between them, are stored as fields alongside the bounds.
immutable Truncated{D<:UnivariateDistribution,S<:ValueSupport} <: Distribution{Univariate,S}
    untruncated::D      # the original (untruncated) distribution
    lower::Float64      # lower bound
    upper::Float64      # upper bound
    lcdf::Float64       # cdf of lower bound
    ucdf::Float64       # cdf of upper bound

    tp::Float64         # the probability of the truncated part, i.e. ucdf - lcdf
    logtp::Float64      # log(tp), i.e. log(ucdf - lcdf)
end
|
||
### Constructors

# Build a truncated version of `d` on [l, u].
# Throws an error unless l < u (this also rejects NaN bounds, since the
# comparison is then false).
function Truncated{S<:ValueSupport}(d::UnivariateDistribution{S}, l::Float64, u::Float64)
    l < u || error("lower bound should be less than upper bound.")
    # For infinite bounds the cdf is fixed at 0.0 / 1.0 without calling cdf(d, ...).
    lcdf = isinf(l) ? 0.0 : cdf(d, l)
    ucdf = isinf(u) ? 1.0 : cdf(d, u)
    tp = ucdf - lcdf
    Truncated{typeof(d),S}(d, l, u, lcdf, ucdf, tp, log(tp))
end
|
||
# x is supported when it lies within [lower, upper] and is also in the
# support of the untruncated distribution.
function insupport(d::Truncated, x::Real)
    d.lower <= x && d.upper >= x && insupport(d.untruncated, x)
end
# Convenience constructor: convert arbitrary real bounds to Float64 and
# forward to the Float64 method.
function Truncated(d::UnivariateDistribution, l::Real, u::Real)
    lo = float64(l)
    hi = float64(u)
    Truncated(d, lo, hi)
end
|
||
|
||
### range and support | ||
|
||
# Bounded below if the untruncated distribution is, or if the lower
# truncation point is finite.
function islowerbounded(d::Truncated)
    islowerbounded(d.untruncated) ? true : isfinite(d.lower)
end

# Bounded above if the untruncated distribution is, or if the upper
# truncation point is finite.
function isupperbounded(d::Truncated)
    isupperbounded(d.untruncated) ? true : isfinite(d.upper)
end
|
||
# Smallest attainable value: the larger of the untruncated minimum and the
# lower truncation point.
function minimum(d::Truncated)
    base_min = minimum(d.untruncated)
    max(base_min, d.lower)
end

# Largest attainable value: the smaller of the untruncated maximum and the
# upper truncation point.
function maximum(d::Truncated)
    base_max = maximum(d.untruncated)
    min(base_max, d.upper)
end
|
||
# Density of the truncated distribution: the untruncated density rescaled by
# the normalizing constant `d.nc`, and 0.0 outside the support.
pdf(d::Truncated, x::Real) =
    insupport(d, x) ? pdf(d.untruncated, x) / d.nc : 0.0
# x is supported when it lies within [lower, upper] and is also in the
# support of the untruncated distribution.
function insupport(d::Truncated, x::Real)
    inrange = d.lower <= x && x <= d.upper
    return inrange && insupport(d.untruncated, x)
end
|
||
# Log-density of the truncated distribution: the untruncated log-density
# minus log(d.nc), and -Inf outside the support.
logpdf(d::Truncated, x::Real) =
    insupport(d, x) ? logpdf(d.untruncated, x) - log(d.nc) : -Inf
|
||
# Cdf of the truncated distribution: 0.0 below the lower bound, 1.0 above the
# upper bound, and the rescaled untruncated cdf in between.
function cdf(d::Truncated, x::Real)
    x < d.lower && return 0.0
    x > d.upper && return 1.0
    d0 = d.untruncated
    return (cdf(d0, x) - cdf(d0, d.lower)) / d.nc
end
### evaluation | ||
|
||
# Density: the untruncated density divided by the retained probability `tp`
# inside [lower, upper], and 0.0 outside.
function pdf(d::Truncated, x::Real)
    if d.lower <= x <= d.upper
        return pdf(d.untruncated, x) / d.tp
    else
        return 0.0
    end
end
|
||
# Log-density: the untruncated log-density minus log(tp) inside
# [lower, upper], and -Inf outside.
function logpdf(d::Truncated, x::Real)
    if d.lower <= x <= d.upper
        return logpdf(d.untruncated, x) - d.logtp
    else
        return -Inf
    end
end
|
||
# Cdf: 0.0 at or below the lower bound, 1.0 at or above the upper bound,
# otherwise the untruncated cdf shifted by `lcdf` and rescaled by `tp`.
function cdf(d::Truncated, x::Real)
    x <= d.lower && return 0.0
    x >= d.upper && return 1.0
    return (cdf(d.untruncated, x) - d.lcdf) / d.tp
end
|
||
# Log of the cdf: -Inf at or below the lower bound, 0.0 at or above the upper
# bound, otherwise log(cdf(untruncated, x) - lcdf) - logtp.
function logcdf(d::Truncated, x::Real)
    x <= d.lower && return -Inf
    x >= d.upper && return 0.0
    return log(cdf(d.untruncated, x) - d.lcdf) - d.logtp
end
|
||
# Complementary cdf: 1.0 at or below the lower bound, 0.0 at or above the
# upper bound, otherwise (ucdf - cdf(untruncated, x)) / tp.
function ccdf(d::Truncated, x::Real)
    x <= d.lower && return 1.0
    x >= d.upper && return 0.0
    return (d.ucdf - cdf(d.untruncated, x)) / d.tp
end
|
||
# Log of the complementary cdf: 0.0 at or below the lower bound, -Inf at or
# above the upper bound, otherwise log(ucdf - cdf(untruncated, x)) - logtp.
function logccdf(d::Truncated, x::Real)
    x <= d.lower && return 0.0
    x >= d.upper && return -Inf
    return log(d.ucdf - cdf(d.untruncated, x)) - d.logtp
end
|
||
# Quantile of the truncated distribution: map p linearly onto the cdf range
# spanned by the two bounds, then invert with the untruncated quantile.
function quantile(d::Truncated, p::Real)
    d0 = d.untruncated
    hi = cdf(d0, d.upper)
    lo = cdf(d0, d.lower)
    quantile(d0, lo + p * (hi - lo))
end
|
||
# The median is the 0.5 quantile.
function median(d::Truncated)
    quantile(d, 0.5)
end
# Quantile: shift/scale p into the untruncated cdf range [lcdf, lcdf + tp]
# and invert with the untruncated quantile function.
function quantile(d::Truncated, p::Real)
    q = d.lcdf + p * d.tp
    quantile(d.untruncated, q)
end
|
||
|
||
## random number generation | ||
|
||
# Draw one sample from the truncated distribution.
# When more than a quarter of the untruncated mass is retained (tp > 0.25),
# draw from the untruncated distribution until a draw lands in [lower, upper];
# otherwise invert a uniform draw through the untruncated quantile function.
function rand(d::Truncated)
    d0 = d.untruncated
    if d.tp > 0.25
        while true
            r = rand(d0)
            if d.lower <= r <= d.upper
                return r
            end
        end
    else
        return quantile(d0, d.lcdf + rand() * d.tp)
    end
end
|
||
# from fallbacks | ||
# Fill a freshly allocated Float64 array of size `dims` with samples.
function rand{D<:ContinuousUnivariateDistribution}(d::Truncated{D}, dims::Dims)
    out = Array(Float64, dims)
    return rand!(d, out)
end
|
||
# Fill a freshly allocated Int array of size `dims` with samples.
function rand{D<:DiscreteUnivariateDistribution}(d::Truncated{D}, dims::Dims)
    out = Array(Int, dims)
    return rand!(d, out)
end
## show | ||
|
||
# Print the truncated distribution as
#   Truncated(<untruncated>, range=(lower, upper), tp=<tp>)
# using the multi-line or one-line form chosen by `_use_multline_show`
# for the untruncated distribution.
function show(io::IO, d::Truncated)
    print(io, "Truncated(")
    d0 = d.untruncated
    uml, namevals = _use_multline_show(d0)
    uml ? show_multline(io, d0, namevals) :
          show_oneline(io, d0, namevals)
    print(io, ", range=($(d.lower), $(d.upper)), tp=$(d.tp))")
    # After a multi-line body, finish with a newline.
    uml && println(io)
end
|
||
# Defer the one-line vs multi-line display decision to the untruncated distribution.
function _use_multline_show(d::Truncated)
    _use_multline_show(d.untruncated)
end
|
||
|
||
### specialized truncated distributions | ||
|
||
include(joinpath("truncated", "normal.jl")) | ||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# Truncated normal distribution | ||
|
||
# Normal(mu, sigma) truncated to the interval [a, b].
function TruncatedNormal(mu::Float64, sigma::Float64, a::Float64, b::Float64)
    base = Normal(mu, sigma)
    Truncated(base, a, b)
end
|
||
# Accept arbitrary real arguments by converting all four to Float64 first.
function TruncatedNormal(mu::Real, sigma::Real, a::Real, b::Real)
    m = float64(mu)
    s = float64(sigma)
    lo = float64(a)
    hi = float64(b)
    TruncatedNormal(m, s, lo, hi)
end
|
||
### statistics | ||
|
||
# For a truncated normal the extrema are the truncation bounds themselves.
function minimum(d::Truncated{Normal})
    d.lower
end

function maximum(d::Truncated{Normal})
    d.upper
end
|
||
|
||
# Mode of a truncated normal: the mean of the untruncated normal clamped to
# [lower, upper].
# The comparisons use `μ`, the variable actually bound here; the previous code
# compared against the undefined name `mu` and raised an UndefVarError.
function mode(d::Truncated{Normal})
    μ = mean(d.untruncated)
    d.upper < μ ? d.upper :
    d.lower > μ ? d.lower : μ
end
|
||
# Return the single mode wrapped in a one-element array.
function modes(d::Truncated{Normal})
    m = mode(d)
    [m]
end
|
||
|
||
# Mean of a truncated normal:
#   m + s * (φ(α) - φ(β)) / tp
# where m and s are the mean and standard deviation of the untruncated normal
# and α, β are the standardized lower/upper bounds.
function mean(d::Truncated{Normal})
    base = d.untruncated
    m = mean(base)
    s = std(base)
    α = (d.lower - m) / s
    β = (d.upper - m) / s
    shift = (φ(α) - φ(β)) / d.tp
    m + shift * s
end
|
||
# Variance of a truncated normal:
#   s^2 * (1 + (αφ(α) - βφ(β)) / tp - ((φ(α) - φ(β)) / tp)^2)
# where s is the standard deviation of the untruncated normal and α, β are
# the standardized lower/upper bounds.
function var(d::Truncated{Normal})
    base = d.untruncated
    m = mean(base)
    s = std(base)
    α = (d.lower - m) / s
    β = (d.upper - m) / s
    tp = d.tp
    pα = φ(α)
    pβ = φ(β)
    # For an infinite bound the product bound * φ(bound) is set to 0.0 directly.
    wα = isinf(α) ? 0.0 : α * pα
    wβ = isinf(β) ? 0.0 : β * pβ
    lin = (wα - wβ) / tp
    quad = abs2((pα - pβ) / tp)
    abs2(s) * (1 + lin - quad)
end
|
||
# Entropy of a truncated normal:
#   0.5 * (log(2π) + 1) + log(s * tp) + (αφ(α) - βφ(β)) / (2 * tp)
# where s is the standard deviation of the untruncated normal and α, β are
# the standardized lower/upper bounds.
function entropy(d::Truncated{Normal})
    base = d.untruncated
    tp = d.tp
    m = mean(base)
    s = std(base)
    α = (d.lower - m) / s
    β = (d.upper - m) / s
    # For an infinite bound the product bound * φ(bound) is set to 0.0 directly.
    wα = isinf(α) ? 0.0 : α * φ(α)
    wβ = isinf(β) ? 0.0 : β * φ(β)
    const_term = 0.5 * (log2π + 1.)
    scale_term = log(s * tp)
    bound_term = (wα - wβ) / (2.0 * tp)
    const_term + scale_term + bound_term
end
|
||
|
||
### sampling | ||
|
||
## Benchmarks don't seem to show that this specialized
## sampler is faster than the generic quantile-based method
|
||
# function rand(d::Truncated{Normal}) | ||
# d0 = d.untruncated | ||
# μ = mean(d0) | ||
# σ = std(d0) | ||
# a = (d.lower - μ) / σ | ||
# b = (d.upper - μ) / σ | ||
# z = randnt(a, b, d.tp) | ||
# return μ + σ * z | ||
# end | ||
|
||
# Rejection sampler based on algorithm from Robert (1995) | ||
# | ||
# - Available at http://arxiv.org/abs/0907.4010 | ||
# | ||
# function randnt(lb::Float64, ub::Float64, tp::Float64) | ||
# r::Float64 | ||
# if tp > 0.3 # has considerable chance of falling in [lb, ub] | ||
# r = randn() | ||
# while r < lb || r > ub | ||
# r = randn() | ||
# end | ||
# return r | ||
|
||
# else | ||
# span = ub - lb | ||
# if lb > 0 && span > 2.0 / (lb + sqrt(lb^2 + 4.0)) * exp((lb^2 - lb * sqrt(lb^2 + 4.0)) / 4.0) | ||
# a = (lb + sqrt(lb^2 + 4.0))/2.0 | ||
# while true | ||
# r = rand(Exponential(1.0 / a)) + lb | ||
# u = rand() | ||
# if u < exp(-0.5 * (r - a)^2) && r < ub | ||
# return r | ||
# end | ||
# end | ||
# elseif ub < 0 && ub - lb > 2.0 / (-ub + sqrt(ub^2 + 4.0)) * exp((ub^2 + ub * sqrt(ub^2 + 4.0)) / 4.0) | ||
# a = (-ub + sqrt(ub^2 + 4.0)) / 2.0 | ||
# while true | ||
# r = rand(Exponential(1.0 / a)) - ub | ||
# u = rand() | ||
# if u < exp(-0.5 * (r - a)^2) && r < -lb | ||
# return -r | ||
# end | ||
# end | ||
# else | ||
# while true | ||
# r = lb + rand() * (ub - lb) | ||
# u = rand() | ||
# if lb > 0 | ||
# rho = exp((lb^2 - r^2) * 0.5) | ||
# elseif ub < 0 | ||
# rho = exp((ub^2 - r^2) * 0.5) | ||
# else | ||
# rho = exp(-r^2 * 0.5) | ||
# end | ||
# if u < rho | ||
# return r | ||
# end | ||
# end | ||
# end | ||
# end | ||
# end | ||
|
Oops, something went wrong.