Skip to content

Commit

Permalink
Allow specifying non-existant bounds to truncated using nothing (#1489)
Browse files Browse the repository at this point in the history
* Add kwarg method

* Add fallbacks to Truncated

* Document the method generically

* Expand truncated docstring

* Raise deprecation warning if infinite bounds provided

* Improve warning

* Fix warning for older Julia versions

* Remove deprecated method calls

* Use new syntax in show method

* Remove last call to deprecated function

* Add tests for variants of truncated

* Add missing tests

* Add LogUniform tests

* Add Uniform and LogUniform overloads

* Improve documentation

* Apply suggestions from code review

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* Fix typo

* Test deprecation warnings

* Test show methods

* Use `@test_deprecated`

* Update src/truncate.jl

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* Don't force deprecation warnings

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
  • Loading branch information
sethaxen and devmotion authored Jan 29, 2022
1 parent e188865 commit 02bcbf8
Show file tree
Hide file tree
Showing 10 changed files with 119 additions and 26 deletions.
78 changes: 71 additions & 7 deletions src/truncate.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,78 @@
"""
truncated(d::UnivariateDistribution, l::Real, u::Real)
truncated(d0::UnivariateDistribution; [lower::Real], [upper::Real])
truncated(d0::UnivariateDistribution, lower::Real, upper::Real)
Truncate a univariate distribution `d` to the interval `[l, u]`.
A _truncated distribution_ `d` of a distribution `d0` to the interval
``[l, u]=```[lower, upper]` has the probability density (mass) function:
The lower bound `l` can be finite or `-Inf` and the upper bound `u` can be finite or
`Inf`. The function throws an error if `l > u`.
```math
f(x; d_0, l, u) = \\frac{f_{d_0}(x)}{P_{Z \\sim d_0}(l \\le Z \\le u)}, \\quad x \\in [l, u],
```
where ``f_{d_0}(x)`` is the probability density (mass) function of ``d_0``.
The function throws an error if ``l > u``.
```julia
truncated(d0; lower=l) # d0 left-truncated to the interval [l, Inf)
truncated(d0; upper=u) # d0 right-truncated to the interval (-Inf, u]
truncated(d0; lower=l, upper=u) # d0 truncated to the interval [l, u]
truncated(d0, l, u) # d0 truncated to the interval [l, u]
```
The function falls back to constructing a [`Truncated`](@ref) wrapper.
# Implementation
To implement a specialized truncated form for distributions of type `D`, the method
`truncated(d::D, l::T, u::T) where {T <: Real}` should be implemented.
To implement a specialized truncated form for distributions of type `D`, one or more of the
following methods should be implemented:
- `truncated(d0::D, l::T, u::T) where {T <: Real}`: interval-truncated
- `truncated(d0::D, ::Nothing, u::Real)`: right-truncated
- `truncated(d0::D, l::Real, u::Nothing)`: left-truncated
"""
function truncated end
function truncated(d::UnivariateDistribution, l::Real, u::Real)
return truncated(d, promote(l, u)...)
end
function truncated(
d::UnivariateDistribution;
lower::Union{Real,Nothing}=nothing,
upper::Union{Real,Nothing}=nothing,
)
return truncated(d, lower, upper)
end
function truncated(d::UnivariateDistribution, ::Nothing, u::Real)
# (log)ucdf = (log)tp = (log) P(X ≤ u) where X ~ d
logucdf = logtp = logcdf(d, u)
ucdf = tp = exp(logucdf)

Truncated(d, promote(oftype(float(u), -Inf), u, zero(ucdf), ucdf, tp, logtp)...)
end
function truncated(d::UnivariateDistribution, l::Real, ::Nothing)
# (log)lcdf = (log) P(X < l) where X ~ d
loglcdf = if value_support(typeof(d)) === Discrete
logsubexp(logcdf(d, l), logpdf(d, l))
else
logcdf(d, l)
end
lcdf = exp(loglcdf)

# (log)tp = (log) P(l ≤ X) where X ∼ d
logtp = log1mexp(loglcdf)
tp = exp(logtp)

Truncated(d, promote(l, oftype(float(l), Inf), lcdf, one(lcdf), tp, logtp)...)
end
truncated(d::UnivariateDistribution, ::Nothing, ::Nothing) = d
function truncated(d::UnivariateDistribution, l::T, u::T) where {T <: Real}
l <= u || error("the lower bound must be less or equal than the upper bound")
l == -Inf && Base.depwarn(
"`truncated(d, -Inf, u)` is deprecated. Please use `truncated(d; upper=u)` instead.",
:truncated,
)
u == Inf && Base.depwarn(
"`truncated(d, l, Inf)` is deprecated. Please use `truncated(d; lower=l)` instead.",
:truncated,
)

# (log)lcdf = (log) P(X < l) where X ~ d
loglcdf = if value_support(typeof(d)) === Discrete
Expand Down Expand Up @@ -160,7 +214,17 @@ function show(io::IO, d::Truncated)
uml, namevals = _use_multline_show(d0)
uml ? show_multline(io, d0, namevals) :
show_oneline(io, d0, namevals)
print(io, ", range=($(d.lower), $(d.upper)))")
if d.lower > -Inf
if d.upper < Inf
print(io, "; lower=$(d.lower), upper=$(d.upper))")
else
print(io, "; lower=$(d.lower))")
end
elseif d.upper < Inf
print(io, "; upper=$(d.upper))")
else
print(io, ")")
end
uml && println(io)
end

Expand Down
2 changes: 2 additions & 0 deletions src/truncated/loguniform.jl
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
truncated(d::LogUniform, lo::T, hi::T) where {T<:Real} = LogUniform(max(d.a, lo), min(d.b, hi))
truncated(d::LogUniform, lo::Real, ::Nothing) = LogUniform(max(d.a, lo), d.b)
truncated(d::LogUniform, ::Nothing, hi::Real) = LogUniform(d.a, min(d.b, hi))
2 changes: 2 additions & 0 deletions src/truncated/uniform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
#####

truncated(d::Uniform, l::T, u::T) where {T <: Real} = Uniform(max(l, d.a), min(u, d.b))
truncated(d::Uniform, l::Real, ::Nothing) = Uniform(max(l, d.a), d.b)
truncated(d::Uniform, ::Nothing, u::Real) = Uniform(d.a, min(u, d.b))
3 changes: 3 additions & 0 deletions test/loguniform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ import Random
@test @inferred(maximum(d)) === 10
@test partype(d) === Int
@test truncated(d, 2, 14) === LogUniform(2,10)
@test truncated(d, 0, 8) === LogUniform(1, 8)
@test truncated(d; upper=8) === LogUniform(1, 8)
@test truncated(d; lower=3) === LogUniform(3, 10)

# numbers obtained by calling scipy.stats.loguniform
@test @inferred(std(d) ) 2.49399867607628
Expand Down
6 changes: 3 additions & 3 deletions test/ref/continuous_test.ref.json
Original file line number Diff line number Diff line change
Expand Up @@ -4702,7 +4702,7 @@
]
},
{
"expr": "truncated(Normal(27, 3), 0, Inf)",
"expr": "truncated(Normal(27, 3); lower=0)",
"dtype": "TruncatedNormal",
"minimum": 0,
"maximum": "inf",
Expand Down Expand Up @@ -4732,7 +4732,7 @@
]
},
{
"expr": "truncated(Normal(-5, 1), -Inf, -10)",
"expr": "truncated(Normal(-5, 1); upper=-10)",
"dtype": "TruncatedNormal",
"minimum": "-inf",
"maximum": -10,
Expand Down Expand Up @@ -4762,7 +4762,7 @@
]
},
{
"expr": "truncated(Normal(1.8, 1.2), -Inf, 0)",
"expr": "truncated(Normal(1.8, 1.2); upper=0)",
"dtype": "TruncatedNormal",
"minimum": "-inf",
"maximum": 0,
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const tests = [
"truncate",
"truncnormal",
"truncated_exponential",
"truncated_uniform",
"normal",
"laplace",
"cauchy",
Expand Down
23 changes: 22 additions & 1 deletion test/truncate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,21 @@ function verify_and_test(d::UnivariateDistribution, dct::Dict, n_tsamples::Int)
end
end

# default methods
for (μ, lower, upper) in [(0, -1, 1), (1, 2, 4)]
d = truncated(Normal(μ, 1), lower, upper)
@test d.untruncated === Normal(μ, 1)
@test d.lower == lower
@test d.upper == upper
@test truncated(Normal(μ, 1); lower=lower, upper=upper) === d
end
@test truncated(Normal(); lower=1) == Distributions.Truncated(Normal(), 1.0, Inf)
@test truncated(Normal(); lower=-2) == Distributions.Truncated(Normal(), -2.0, Inf)
@test truncated(Normal(); upper=1) == Distributions.Truncated(Normal(), -Inf, 1.0)
@test truncated(Normal(); upper=2) == Distributions.Truncated(Normal(), -Inf, 2.0)
@test truncated(Normal()) === Normal()
@test_deprecated truncated(Normal(), -Inf, 2)
@test_deprecated truncated(Normal(), 2, Inf)

## main

Expand Down Expand Up @@ -152,7 +167,7 @@ at = [0.0, 1.0, 0.0, 1.0]
@testset "#1328" begin
dist = Poisson(2.0)
dist_zeroinflated = MixtureModel([Dirac(0.0), dist], [0.4, 0.6])
dist_zerotruncated = truncated(dist, 1, Inf)
dist_zerotruncated = truncated(dist; lower=1)
dist_zeromodified = MixtureModel([Dirac(0.0), dist_zerotruncated], [0.4, 0.6])

@test logsumexp(logpdf(dist, x) for x in 0:1000) 0 atol=1e-15
Expand All @@ -161,3 +176,9 @@ at = [0.0, 1.0, 0.0, 1.0]
@test logsumexp(logpdf(dist_zeromodified, x) for x in 0:1000) 0 atol=1e-15
end
end

@testset "show" begin
@test sprint(show, "text/plain", truncated(Normal(); lower=2.0)) == "Truncated($(Normal()); lower=2.0)"
@test sprint(show, "text/plain", truncated(Normal(); upper=3.0)) == "Truncated($(Normal()); upper=3.0)"
@test sprint(show, "text/plain", truncated(Normal(), 2.0, 3.0)) == "Truncated($(Normal()); lower=2.0, upper=3.0)"
end
8 changes: 4 additions & 4 deletions test/truncated_exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ using Distributions, Random, Test
d = Exponential(1.5)
l = 1.2
r = 2.7
@test mean(d) mean(truncated(d, -3.0, Inf)) # non-binding truncation
@test mean(truncated(d, l, Inf)) mean(d) + l
@test mean(d) mean(truncated(d; lower=-3.0)) # non-binding truncation
@test mean(truncated(d; lower=l)) mean(d) + l
# test values below calculated using symbolic integration in Maxima
@test mean(truncated(d, 0, r)) 0.9653092084094841
@test mean(truncated(d, l, r)) 1.82703493969601

# all the fun corner cases and numerical quirks
@test mean(truncated(Exponential(1.0), -Inf, 0)) == 0 # degenerate
@test mean(truncated(Exponential(1.0), -Inf, 0+eps())) 0 atol = eps() # near-degenerate
@test mean(truncated(Exponential(1.0); upper=0)) == 0 # degenerate
@test mean(truncated(Exponential(1.0); upper=0+eps())) 0 atol = eps() # near-degenerate
@test mean(truncated(Exponential(1.0), 1.0, 1.0+eps())) 1.0 # near-degenerate
@test mean(truncated(Exponential(1e308), 1.0, 1.0+eps())) 1.0 # near-degenerate
end
8 changes: 4 additions & 4 deletions test/truncated_uniform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ using Distributions, Test
@testset "truncated uniform" begin
# just test equivalence of truncation results
u = Uniform(1, 2)
@test truncated(u, -Inf, Inf) == u
@test truncated(u, 0, Inf) == u
@test truncated(u, -Inf, 2.1) == u
@test truncated(u, 1.1, Inf) == Uniform(1.1, 2)
@test truncated(u) === u
@test truncated(u; lower=0) == u
@test truncated(u; upper=2.1) == u
@test truncated(u; lower=1.1) == Uniform(1.1, 2)
@test truncated(u, 1.1, 2.1) == Uniform(1.1, 2)
@test truncated(u, 1.1, 1.9) == Uniform(1.1, 1.9)
end
14 changes: 7 additions & 7 deletions test/truncnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ rng = MersenneTwister(123)
@test mean(truncated(Normal(0,1),100,115)) 100.00999800099926070518490239457545847490332879043
@test mean(truncated(Normal(-2,3),50,70)) 50.171943499898757645751683644632860837133138152489
@test mean(truncated(Normal(0,2),-100,0)) -1.59576912160573071175978423973752747390343452465973
@test mean(truncated(Normal(0,1),-Inf,Inf)) == 0
@test mean(truncated(Normal(0,1),0,+Inf)) +√(2/π)
@test mean(truncated(Normal(0,1),-Inf,0)) -√(2/π)
@test mean(truncated(Normal(0,1))) == 0
@test mean(truncated(Normal(0,1); lower=0)) +√(2/π)
@test mean(truncated(Normal(0,1); upper=0)) -√(2/π)
@test var(truncated(Normal(0,1),50,70)) 0.00039904318680389954790992722653605933053648912703600
@test var(truncated(Normal(-2,3),50,70)) 0.029373438107168350377591231295634273607812172191712
@test var(truncated(Normal(0,1),-Inf,Inf)) == 1
@test var(truncated(Normal(0,1),0,+Inf)) 1 - 2/π
@test var(truncated(Normal(0,1),-Inf,0)) 1 - 2/π
@test var(truncated(Normal(-2,3); lower=50, upper=70)) 0.029373438107168350377591231295634273607812172191712
@test var(truncated(Normal(0,1))) == 1
@test var(truncated(Normal(0,1); lower=0)) 1 - 2/π
@test var(truncated(Normal(0,1); upper=0)) 1 - 2/π
# https://github.com/JuliaStats/Distributions.jl/issues/827
@test mean(truncated(Normal(1000000,1),0,1000)) 999.999998998998999001005011019018990904720462367106
@test var(truncated(Normal(),999000,1e6)) 0
Expand Down

0 comments on commit 02bcbf8

Please sign in to comment.