Skip to content

Commit

Permalink
Add minimum, maximum, extrema for AbstractMvNormal and `Produ…
Browse files Browse the repository at this point in the history
…ct` (#1319)

* adds minimum, maximum, extrema to AbstractMvNormal and Product distributions

* add docstrings and update docs

* update docstrrings

* add isless tests for extrema

* clean-up, remove broadcast in favor of map

* move/update minimum/maximum fallback def/docstring

Co-authored-by: Adam R Gerlach <adam.gerlach.1@afresearchlab.com>
  • Loading branch information
agerlach and Adam R Gerlach authored Oct 9, 2021
1 parent 2a7a651 commit 625b722
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 22 deletions.
4 changes: 4 additions & 0 deletions docs/src/multivariate.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,12 @@ invcov(::Distributions.AbstractMvNormal)
logdetcov(::Distributions.AbstractMvNormal)
sqmahal(::Distributions.AbstractMvNormal, ::AbstractArray)
rand(::AbstractRNG, ::Distributions.AbstractMvNormal)
minimum(::Distributions.AbstractMvNormal)
maximum(::Distributions.AbstractMvNormal)
extrema(::Distributions.AbstractMvNormal)
```


### MvLogNormal

In addition to the methods listed in the common interface above, we also provide the following methods:
Expand Down
20 changes: 20 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,26 @@ value_support(::Type{<:Distribution{VF,VS}}) where {VF,VS} = VS
# to be decided: how to handle multivariate/matrixvariate distributions?
Broadcast.broadcastable(d::UnivariateDistribution) = Ref(d)

"""
minimum(d::Distribution)
Return the minimum of the support of `d`.
"""
minimum(d::Distribution)

"""
maximum(d::Distribution)
Return the maximum of the support of `d`.
"""
maximum(d::Distribution)

"""
extrema(d::Distribution)
Return the minimum and maximum of the support of `d` as a 2-tuple.
"""
Base.extrema(d::Distribution) = minimum(d), maximum(d)

## TODO: the following types need to be improved
abstract type SufficientStats end
Expand Down
2 changes: 2 additions & 0 deletions src/multivariate/mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ abstract type AbstractMvNormal <: ContinuousMultivariateDistribution end
insupport(d::AbstractMvNormal, x::AbstractVector) =
length(d) == length(x) && all(isfinite, x)

minimum(d::AbstractMvNormal) = fill(eltype(d)(-Inf), length(d))
maximum(d::AbstractMvNormal) = fill(eltype(d)(Inf), length(d))
mode(d::AbstractMvNormal) = mean(d)
modes(d::AbstractMvNormal) = [mean(d)]

Expand Down
2 changes: 2 additions & 0 deletions src/multivariate/product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ var(d::Product) = var.(d.v)
cov(d::Product) = Diagonal(var(d))
entropy(d::Product) = sum(entropy, d.v)
insupport(d::Product, x::AbstractVector) = all(insupport.(d.v, x))
minimum(d::Product) = map(minimum, d.v)
maximum(d::Product) = map(maximum, d.v)

"""
product_distribution(dists::AbstractVector{<:UnivariateDistribution})
Expand Down
21 changes: 0 additions & 21 deletions src/univariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,27 +77,6 @@ Get the degrees of freedom.
"""
dof(d::UnivariateDistribution)

"""
minimum(d::UnivariateDistribution)
Return the minimum of the support of `d`.
"""
minimum(d::UnivariateDistribution)

"""
maximum(d::UnivariateDistribution)
Return the maximum of the support of `d`.
"""
maximum(d::UnivariateDistribution)

"""
extrema(d::UnivariateDistribution)
Return the minimum and maximum of the support of `d` as a 2-tuple.
"""
extrema(d::UnivariateDistribution) = (minimum(d), maximum(d))

"""
insupport(d::UnivariateDistribution, x::Any)
Expand Down
4 changes: 4 additions & 0 deletions test/mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ function test_mvnormal(g::AbstractMvNormal, n_tsamples::Int=10^6,
vs = diag(Σ)
@test g == typeof(g)(params(g)...)
@test g == deepcopy(g)
@test minimum(g) == fill(-Inf, d)
@test maximum(g) == fill(Inf, d)
@test extrema(g) == (minimum(g), maximum(g))
@test isless(extrema(g)...)

# test sampling for AbstractMatrix (here, a SubArray):
if ismissing(rng)
Expand Down
6 changes: 5 additions & 1 deletion test/product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ end
N = 11
# Construct independent distributions and `Product` distribution from these.
ubound = rand(N)
ds = Uniform.(0.0, ubound)
ds = Uniform.(-ubound, ubound)
x = rand.(ds)
d_product = product_distribution(ds)
@test d_product isa Product
Expand All @@ -43,6 +43,10 @@ end
@test entropy(d_product) == sum(entropy.(ds))
@test insupport(d_product, ubound) == true
@test insupport(d_product, ubound .+ 1) == false
@test minimum(d_product) == -ubound
@test maximum(d_product) == ubound
@test extrema(d_product) == (-ubound, ubound)
@test isless(extrema(d_product)...)

y = rand(d_product)
@test y isa typeof(x)
Expand Down

0 comments on commit 625b722

Please sign in to comment.