From 625b72237a342c8d3bf60ec05541f8cb4a78faff Mon Sep 17 00:00:00 2001 From: agerlach <599421+agerlach@users.noreply.github.com> Date: Sat, 9 Oct 2021 15:18:06 -0400 Subject: [PATCH] Add `minimum`, `maximum`, `extrema` for `AbstractMvNormal` and `Product` (#1319) * adds minimum, maximum, extrema to AbstractMvNormal and Product distributions * add docstrings and update docs * update docstrrings * add isless tests for extrema * clean-up, remove broadcast in favor of map * move/update minimum/maximum fallback def/docstring Co-authored-by: Adam R Gerlach --- docs/src/multivariate.md | 4 ++++ src/common.jl | 20 ++++++++++++++++++++ src/multivariate/mvnormal.jl | 2 ++ src/multivariate/product.jl | 2 ++ src/univariates.jl | 21 --------------------- test/mvnormal.jl | 4 ++++ test/product.jl | 6 +++++- 7 files changed, 37 insertions(+), 22 deletions(-) diff --git a/docs/src/multivariate.md b/docs/src/multivariate.md index c1edc18cb4..be4f5427a6 100644 --- a/docs/src/multivariate.md +++ b/docs/src/multivariate.md @@ -71,8 +71,12 @@ invcov(::Distributions.AbstractMvNormal) logdetcov(::Distributions.AbstractMvNormal) sqmahal(::Distributions.AbstractMvNormal, ::AbstractArray) rand(::AbstractRNG, ::Distributions.AbstractMvNormal) +minimum(::Distributions.AbstractMvNormal) +maximum(::Distributions.AbstractMvNormal) +extrema(::Distributions.AbstractMvNormal) ``` + ### MvLogNormal In addition to the methods listed in the common interface above, we also provide the following methods: diff --git a/src/common.jl b/src/common.jl index 36c19a5bc3..f2f1798354 100644 --- a/src/common.jl +++ b/src/common.jl @@ -138,6 +138,26 @@ value_support(::Type{<:Distribution{VF,VS}}) where {VF,VS} = VS # to be decided: how to handle multivariate/matrixvariate distributions? Broadcast.broadcastable(d::UnivariateDistribution) = Ref(d) +""" + minimum(d::Distribution) + +Return the minimum of the support of `d`. +""" +minimum(d::Distribution) + +""" + maximum(d::Distribution) + +Return the maximum of the support of `d`. +""" +maximum(d::Distribution) + +""" + extrema(d::Distribution) + +Return the minimum and maximum of the support of `d` as a 2-tuple. +""" +Base.extrema(d::Distribution) = minimum(d), maximum(d) ## TODO: the following types need to be improved abstract type SufficientStats end diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 6e6c8cad0a..3827af61ac 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -80,6 +80,8 @@ abstract type AbstractMvNormal <: ContinuousMultivariateDistribution end insupport(d::AbstractMvNormal, x::AbstractVector) = length(d) == length(x) && all(isfinite, x) +minimum(d::AbstractMvNormal) = fill(eltype(d)(-Inf), length(d)) +maximum(d::AbstractMvNormal) = fill(eltype(d)(Inf), length(d)) mode(d::AbstractMvNormal) = mean(d) modes(d::AbstractMvNormal) = [mean(d)] diff --git a/src/multivariate/product.jl b/src/multivariate/product.jl index 454ffada3d..cee21c55d7 100644 --- a/src/multivariate/product.jl +++ b/src/multivariate/product.jl @@ -40,6 +40,8 @@ var(d::Product) = var.(d.v) cov(d::Product) = Diagonal(var(d)) entropy(d::Product) = sum(entropy, d.v) insupport(d::Product, x::AbstractVector) = all(insupport.(d.v, x)) +minimum(d::Product) = map(minimum, d.v) +maximum(d::Product) = map(maximum, d.v) """ product_distribution(dists::AbstractVector{<:UnivariateDistribution}) diff --git a/src/univariates.jl b/src/univariates.jl index 73c6eddff9..c31bc9a464 100644 --- a/src/univariates.jl +++ b/src/univariates.jl @@ -77,27 +77,6 @@ Get the degrees of freedom. """ dof(d::UnivariateDistribution) -""" - minimum(d::UnivariateDistribution) - -Return the minimum of the support of `d`. -""" -minimum(d::UnivariateDistribution) - -""" - maximum(d::UnivariateDistribution) - -Return the maximum of the support of `d`. -""" -maximum(d::UnivariateDistribution) - -""" - extrema(d::UnivariateDistribution) - -Return the minimum and maximum of the support of `d` as a 2-tuple. -""" -extrema(d::UnivariateDistribution) = (minimum(d), maximum(d)) - """ insupport(d::UnivariateDistribution, x::Any) diff --git a/test/mvnormal.jl b/test/mvnormal.jl index bb24c9e213..934ec29115 100644 --- a/test/mvnormal.jl +++ b/test/mvnormal.jl @@ -30,6 +30,10 @@ function test_mvnormal(g::AbstractMvNormal, n_tsamples::Int=10^6, vs = diag(Σ) @test g == typeof(g)(params(g)...) @test g == deepcopy(g) + @test minimum(g) == fill(-Inf, d) + @test maximum(g) == fill(Inf, d) + @test extrema(g) == (minimum(g), maximum(g)) + @test isless(extrema(g)...) # test sampling for AbstractMatrix (here, a SubArray): if ismissing(rng) diff --git a/test/product.jl b/test/product.jl index 3dc0e8de36..44c3b83c4f 100644 --- a/test/product.jl +++ b/test/product.jl @@ -29,7 +29,7 @@ end N = 11 # Construct independent distributions and `Product` distribution from these. ubound = rand(N) - ds = Uniform.(0.0, ubound) + ds = Uniform.(-ubound, ubound) x = rand.(ds) d_product = product_distribution(ds) @test d_product isa Product @@ -43,6 +43,10 @@ end @test entropy(d_product) == sum(entropy.(ds)) @test insupport(d_product, ubound) == true @test insupport(d_product, ubound .+ 1) == false + @test minimum(d_product) == -ubound + @test maximum(d_product) == ubound + @test extrema(d_product) == (-ubound, ubound) + @test isless(extrema(d_product)...) y = rand(d_product) @test y isa typeof(x)