From 3c3b618ac8ce0d34c84b57052cb710e51d4326b0 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Thu, 6 Jul 2023 16:50:39 +0200 Subject: [PATCH 1/7] Add half-normal --- src/ReactiveMP.jl | 1 + src/distributions/gamma_shape_rate.jl | 5 +++++ src/nodes/half_normal.jl | 10 ++++++++++ src/rules/half_normal/out.jl | 3 +++ src/rules/prototypes.jl | 2 ++ 5 files changed, 21 insertions(+) create mode 100644 src/nodes/half_normal.jl create mode 100644 src/rules/half_normal/out.jl diff --git a/src/ReactiveMP.jl b/src/ReactiveMP.jl index d92e67640..ab389d85e 100644 --- a/src/ReactiveMP.jl +++ b/src/ReactiveMP.jl @@ -159,6 +159,7 @@ include("nodes/bifm.jl") include("nodes/bifm_helper.jl") include("nodes/probit.jl") include("nodes/poisson.jl") +include("nodes/half_normal.jl") include("nodes/flow/flow.jl") include("nodes/delta/delta.jl") diff --git a/src/distributions/gamma_shape_rate.jl b/src/distributions/gamma_shape_rate.jl index 3a1bd7489..b7051135b 100644 --- a/src/distributions/gamma_shape_rate.jl +++ b/src/distributions/gamma_shape_rate.jl @@ -63,6 +63,11 @@ function prod(::ProdAnalytical, left::GammaShapeRate, right::GammaShapeRate) return GammaShapeRate(shape(left) + shape(right) - one(T), rate(left) + rate(right)) end +function prod(::ProdAnalytical, left::Truncated{<:Normal}, right::GammaDistributionsFamily) + @assert (left.lower ≈ zero(left.lower) && isinf(left.upper)) "Truncated{Normal} * GammaDistributionsFamily only implemented for Truncated{Normal}(0, Inf)" + return vague(GammaShapeRate) +end + Distributions.pdf(dist::GammaShapeRate, x::Real) = exp(logpdf(dist, x)) Distributions.logpdf(dist::GammaShapeRate, x::Real) = shape(dist) * log(rate(dist)) - loggamma(shape(dist)) + (shape(dist) - 1) * log(x) - rate(dist) * x diff --git a/src/nodes/half_normal.jl b/src/nodes/half_normal.jl new file mode 100644 index 000000000..d35337f18 --- /dev/null +++ b/src/nodes/half_normal.jl @@ -0,0 +1,10 @@ +export HalfNormal + +struct HalfNormal end + +@node HalfNormal Stochastic [out, (v, aliases = [var, σ²]),] + +@average_energy HalfNormal (q_out::Any, q_v::Any) = begin + out_mean, out_var = mean_var(q_out) + return log2π + mean(log, q_v) + mean(inv, q_v) * (out_var + out_mean^2) / 2 +end \ No newline at end of file diff --git a/src/rules/half_normal/out.jl b/src/rules/half_normal/out.jl new file mode 100644 index 000000000..c9d55ca30 --- /dev/null +++ b/src/rules/half_normal/out.jl @@ -0,0 +1,3 @@ +@rule HalfNormal(:out, Marginalisation) (q_v::PointMass,) = begin + return Truncated(Normal(0.0, mean(q_v)), 0.0, Inf) +end \ No newline at end of file diff --git a/src/rules/prototypes.jl b/src/rules/prototypes.jl index 52933ecba..00d28e64a 100644 --- a/src/rules/prototypes.jl +++ b/src/rules/prototypes.jl @@ -171,3 +171,5 @@ include("delta/unscented/marginals.jl") include("delta/cvi/in.jl") include("delta/cvi/out.jl") include("delta/cvi/marginals.jl") + +include("half_normal/out.jl") \ No newline at end of file From df6db3622d6fdfa3c597144b6818e2fa1f9bc7ac Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Thu, 6 Jul 2023 17:07:14 +0200 Subject: [PATCH 2/7] Update prod --- src/distributions/gamma_shape_rate.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distributions/gamma_shape_rate.jl b/src/distributions/gamma_shape_rate.jl index b7051135b..c7bbc6027 100644 --- a/src/distributions/gamma_shape_rate.jl +++ b/src/distributions/gamma_shape_rate.jl @@ -63,8 +63,8 @@ function prod(::ProdAnalytical, left::GammaShapeRate, right::GammaShapeRate) return GammaShapeRate(shape(left) + shape(right) - one(T), rate(left) + rate(right)) end -function prod(::ProdAnalytical, left::Truncated{<:Normal}, right::GammaDistributionsFamily) - @assert (left.lower ≈ zero(left.lower) && isinf(left.upper)) "Truncated{Normal} * GammaDistributionsFamily only implemented for Truncated{Normal}(0, Inf)" +function prod(::ProdAnalytical, left::Truncated{<:Normal}, right::GammaShapeRate) + @assert (left.lower ≈ zero(left.lower) && isinf(left.upper)) "Truncated{Normal} * GammaShapeRate only implemented for Truncated{Normal}(0, Inf)" return vague(GammaShapeRate) end From 53e1e58b40e60dc3aa171738a7b84d3db4684bdd Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Fri, 7 Jul 2023 13:17:10 +0200 Subject: [PATCH 3/7] add methods --- src/distributions/gamma.jl | 22 ++++++++++++++++++++++ src/distributions/gamma_inverse.jl | 12 ++++++++++++ src/distributions/gamma_shape_rate.jl | 5 ----- src/rules/half_normal/out.jl | 2 +- 4 files changed, 35 insertions(+), 6 deletions(-) diff --git a/src/distributions/gamma.jl b/src/distributions/gamma.jl index 1824c9996..e737526dd 100644 --- a/src/distributions/gamma.jl +++ b/src/distributions/gamma.jl @@ -78,6 +78,28 @@ function compute_logscale(new_dist::GammaDistributionsFamily, left_dist::GammaDi return loggamma(ay) - loggamma(ax) - loggamma(az) + ax * log(bx) + az * log(bz) - ay * log(by) end +prod_analytical_rule(::Type{<:Truncated{<:Normal}}, ::Type{<:GammaDistributionsFamily}) = ProdAnalyticalRuleAvailable() +prod_analytical_rule(::Type{<:GammaDistributionsFamily}, ::Type{<:Truncated{<:Normal}}) = ProdAnalyticalRuleAvailable() + +prod(::ProdAnalytical, left::GammaDistributionsFamily, right::Truncated{<:Normal}) = prod(ProdAnalytical(), right, left) + +function prod(::ProdAnalytical, left::Truncated{<:Normal}, right::GammaDistributionsFamily) + @assert (left.lower ≈ zero(left.lower) && isinf(left.upper)) "Truncated{Normal} * Gamma only implemented for Truncated{Normal}(0, Inf)" + + samples = rand(MersenneTwister(123), left, 1000) + zeronum = zero(eltype(samples)) + + sx, xlogx, tw = mapreduce(.+, samples; init = (zeronum, zeronum, zeronum)) do sample + w = pdf(right, sample) + return (w * sample, w * log(sample), w) + end + + statistics = Distributions.GammaStats(sx, xlogx, tw) + fit = Distributions.fit_mle(Gamma, statistics, alpha0 = shape(right)) + + return convert(typeof(right), fit) +end + ## Friendly functions function logpdf_sample_friendly(dist::GammaDistributionsFamily) diff --git a/src/distributions/gamma_inverse.jl b/src/distributions/gamma_inverse.jl index f52aa0923..5741036b9 100644 --- a/src/distributions/gamma_inverse.jl +++ b/src/distributions/gamma_inverse.jl @@ -26,3 +26,15 @@ function mean(::typeof(inv), dist::GammaInverse) θ = scale(dist) return α / θ end + +prod_analytical_rule(::Type{<:Truncated{<:Normal}}, ::Type{<:GammaInverse}) = ProdAnalyticalRuleAvailable() +prod_analytical_rule(::Type{<:GammaInverse}, ::Type{<:Truncated{<:Normal}}) = ProdAnalyticalRuleAvailable() + +prod(::ProdAnalytical, left::GammaInverse, right::Truncated{<:Normal}) = prod(ProdAnalytical(), right, left) + +function prod(::ProdAnalytical, left::Truncated{<:Normal}, right::GammaInverse) + @assert (left.lower ≈ zero(left.lower) && isinf(left.upper)) "Truncated{Normal} * GammaInverse only implemented for Truncated{Normal}(0, Inf)" + α, θ = shape(right), scale(right) + Γ = prod(ProdAnalytical(), left, Gamma(α, inv(θ))) + return InverseGamma(shape(Γ), inv(scale(Γ))) +end \ No newline at end of file diff --git a/src/distributions/gamma_shape_rate.jl b/src/distributions/gamma_shape_rate.jl index c7bbc6027..3a1bd7489 100644 --- a/src/distributions/gamma_shape_rate.jl +++ b/src/distributions/gamma_shape_rate.jl @@ -63,11 +63,6 @@ function prod(::ProdAnalytical, left::GammaShapeRate, right::GammaShapeRate) return GammaShapeRate(shape(left) + shape(right) - one(T), rate(left) + rate(right)) end -function prod(::ProdAnalytical, left::Truncated{<:Normal}, right::GammaShapeRate) - @assert (left.lower ≈ zero(left.lower) && isinf(left.upper)) "Truncated{Normal} * GammaShapeRate only implemented for Truncated{Normal}(0, Inf)" - return vague(GammaShapeRate) -end - Distributions.pdf(dist::GammaShapeRate, x::Real) = exp(logpdf(dist, x)) Distributions.logpdf(dist::GammaShapeRate, x::Real) = shape(dist) * log(rate(dist)) - loggamma(shape(dist)) + (shape(dist) - 1) * log(x) - rate(dist) * x diff --git a/src/rules/half_normal/out.jl b/src/rules/half_normal/out.jl index c9d55ca30..c4047e8a5 100644 --- a/src/rules/half_normal/out.jl +++ b/src/rules/half_normal/out.jl @@ -1,3 +1,3 @@ @rule HalfNormal(:out, Marginalisation) (q_v::PointMass,) = begin - return Truncated(Normal(0.0, mean(q_v)), 0.0, Inf) + return Truncated(Normal(0.0, sqrt(mean(q_v))), 0.0, Inf) end \ No newline at end of file From 3ee213701f197b20c5dcd3d5a769b868c01e0031 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 7 Jul 2023 16:44:12 +0200 Subject: [PATCH 4/7] Make format --- src/distributions/gamma.jl | 2 +- src/distributions/gamma_inverse.jl | 2 +- src/nodes/half_normal.jl | 8 ++--- src/rules/half_normal/out.jl | 4 +-- src/rules/prototypes.jl | 2 +- test/nodes/test_gamma_inverse.jl | 6 ++-- test/nodes/test_half_normal.jl | 51 ++++++++++++++++++++++++++++++ test/rules/half_normal/test_out.jl | 17 ++++++++++ 8 files changed, 80 insertions(+), 12 deletions(-) create mode 100644 test/nodes/test_half_normal.jl create mode 100644 test/rules/half_normal/test_out.jl diff --git a/src/distributions/gamma.jl b/src/distributions/gamma.jl index e737526dd..3ba0b5c6c 100644 --- a/src/distributions/gamma.jl +++ b/src/distributions/gamma.jl @@ -88,7 +88,7 @@ function prod(::ProdAnalytical, left::Truncated{<:Normal}, right::GammaDistribut samples = rand(MersenneTwister(123), left, 1000) zeronum = zero(eltype(samples)) - + sx, xlogx, tw = mapreduce(.+, samples; init = (zeronum, zeronum, zeronum)) do sample w = pdf(right, sample) return (w * sample, w * log(sample), w) diff --git a/src/distributions/gamma_inverse.jl b/src/distributions/gamma_inverse.jl index 5741036b9..c1099bcd1 100644 --- a/src/distributions/gamma_inverse.jl +++ b/src/distributions/gamma_inverse.jl @@ -37,4 +37,4 @@ function prod(::ProdAnalytical, left::Truncated{<:Normal}, right::GammaInverse) α, θ = shape(right), scale(right) Γ = prod(ProdAnalytical(), left, Gamma(α, inv(θ))) return InverseGamma(shape(Γ), inv(scale(Γ))) -end \ No newline at end of file +end diff --git a/src/nodes/half_normal.jl b/src/nodes/half_normal.jl index d35337f18..3e6c5d588 100644 --- a/src/nodes/half_normal.jl +++ b/src/nodes/half_normal.jl @@ -2,9 +2,9 @@ export HalfNormal struct HalfNormal end -@node HalfNormal Stochastic [out, (v, aliases = [var, σ²]),] +@node HalfNormal Stochastic [out, (v, aliases = [var, σ²])] -@average_energy HalfNormal (q_out::Any, q_v::Any) = begin +@average_energy HalfNormal (q_out::Any, q_v::Any) = begin out_mean, out_var = mean_var(q_out) - return log2π + mean(log, q_v) + mean(inv, q_v) * (out_var + out_mean^2) / 2 -end \ No newline at end of file + return (log(π / 2) + mean(log, q_v) + mean(inv, q_v) * (out_mean^2 + out_var)) / 2 +end diff --git a/src/rules/half_normal/out.jl b/src/rules/half_normal/out.jl index c4047e8a5..6efcef080 100644 --- a/src/rules/half_normal/out.jl +++ b/src/rules/half_normal/out.jl @@ -1,3 +1,3 @@ -@rule HalfNormal(:out, Marginalisation) (q_v::PointMass,) = begin +@rule HalfNormal(:out, Marginalisation) (q_v::PointMass,) = begin return Truncated(Normal(0.0, sqrt(mean(q_v))), 0.0, Inf) -end \ No newline at end of file +end diff --git a/src/rules/prototypes.jl b/src/rules/prototypes.jl index 00d28e64a..d4d20ddd4 100644 --- a/src/rules/prototypes.jl +++ b/src/rules/prototypes.jl @@ -172,4 +172,4 @@ include("delta/cvi/in.jl") include("delta/cvi/out.jl") include("delta/cvi/marginals.jl") -include("half_normal/out.jl") \ No newline at end of file +include("half_normal/out.jl") diff --git a/test/nodes/test_gamma_inverse.jl b/test/nodes/test_gamma_inverse.jl index 68caf8e26..9462e976a 100644 --- a/test/nodes/test_gamma_inverse.jl +++ b/test/nodes/test_gamma_inverse.jl @@ -1,4 +1,4 @@ -module InverseWishartNodeTest +module InverseGammaNodeTest using Test using ReactiveMP @@ -6,7 +6,7 @@ using Random import ReactiveMP: make_node -@testset "InverseWishartNode" begin +@testset "InverseGammaNode" begin @testset "Creation" begin node = make_node(GammaInverse) @test functionalform(node) === GammaInverse @@ -16,7 +16,7 @@ import ReactiveMP: make_node @test localmarginalnames(node) === (:out_α_θ,) @test metadata(node) === nothing - node = make_node(InverseWishart, FactorNodeCreationOptions(nothing, 1, nothing)) + node = make_node(InverseGamma, FactorNodeCreationOptions(nothing, 1, nothing)) @test metadata(node) === 1 end diff --git a/test/nodes/test_half_normal.jl b/test/nodes/test_half_normal.jl new file mode 100644 index 000000000..9a16e6e9d --- /dev/null +++ b/test/nodes/test_half_normal.jl @@ -0,0 +1,51 @@ +module HalfNormalNodeTest + +using Test +using ReactiveMP +using Random + +import ReactiveMP: make_node + +@testset "HalfNormalNode" begin + @testset "Creation" begin + node = make_node(HalfNormal) + @test functionalform(node) === HalfNormal + @test sdtype(node) === Stochastic() + @test name.(interfaces(node)) === (:out, :v) + @test factorisation(node) === ((1, 2),) + @test localmarginalnames(node) === (:out_v,) + @test metadata(node) === nothing + + node = make_node(HalfNormal, FactorNodeCreationOptions(nothing, 1, nothing)) + @test metadata(node) === 1 + end + + @testset "AverageEnergy" begin + begin + q_out = GammaShapeRate(2.0, 1.0) + q_v = PointMass(2.0) + + marginals = (Marginal(q_out, false, false, nothing), Marginal(q_v, false, false, nothing)) + + @test score(AverageEnergy(), HalfNormal, Val{(:out, :v)}(), marginals, nothing) ≈ 2.072364942925 + end + begin + q_out = GammaShapeScale(2.0, 1.0) + q_v = PointMass(2.0) + + marginals = (Marginal(q_out, false, false, nothing), Marginal(q_v, false, false, nothing)) + + @test score(AverageEnergy(), HalfNormal, Val{(:out, :v)}(), marginals, nothing) ≈ 2.072364942925 + end + + begin + q_out = GammaInverse(3.0, 1.0) + q_v = PointMass(2.0) + + marginals = (Marginal(q_out, false, false, nothing), Marginal(q_v, false, false, nothing)) + + @test score(AverageEnergy(), HalfNormal, Val{(:out, :v)}(), marginals, nothing) ≈ 0.6973649429247 + end + end # testset: AverageEnergy +end # testset +end # module diff --git a/test/rules/half_normal/test_out.jl b/test/rules/half_normal/test_out.jl new file mode 100644 index 000000000..714c0dd95 --- /dev/null +++ b/test/rules/half_normal/test_out.jl @@ -0,0 +1,17 @@ +module RulesHalfNormalOutTest + +using Test +using ReactiveMP +using Random +using Distributions + +import ReactiveMP: @test_rules + +@testset "rules:HalfNormal:out" begin + @testset "Variational Message Passing: (q_v::Any)" begin + @test_rules [check_type_promotion = false] HalfNormal(:out, Marginalisation) [ + (input = (q_v = PointMass(1.0),), output = Truncated(Normal(0.0, 1.0), 0.0, Inf)), (input = (q_v = PointMass(100),), output = Truncated(Normal(0.0, 10.0), 0.0, Inf)) + ] + end +end # testset +end # module From 8efcf7c947160865b72be862dbec3b22370efb94 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Fri, 7 Jul 2023 17:11:24 +0200 Subject: [PATCH 5/7] Fix gamma inverse test --- test/nodes/test_gamma_inverse.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/nodes/test_gamma_inverse.jl b/test/nodes/test_gamma_inverse.jl index 9462e976a..2cb37f29b 100644 --- a/test/nodes/test_gamma_inverse.jl +++ b/test/nodes/test_gamma_inverse.jl @@ -16,7 +16,7 @@ import ReactiveMP: make_node @test localmarginalnames(node) === (:out_α_θ,) @test metadata(node) === nothing - node = make_node(InverseGamma, FactorNodeCreationOptions(nothing, 1, nothing)) + node = make_node(GammaInverse, FactorNodeCreationOptions(nothing, 1, nothing)) @test metadata(node) === 1 end From 5728106eda5c07e6208165bcb2846bdbf3cb4404 Mon Sep 17 00:00:00 2001 From: Albert Podusenko Date: Mon, 10 Jul 2023 13:53:01 +0200 Subject: [PATCH 6/7] Fix tests --- src/distributions/normal.jl | 5 +++++ src/rules/half_normal/out.jl | 3 ++- test/rules/half_normal/test_out.jl | 6 ++++-- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/distributions/normal.jl b/src/distributions/normal.jl index 0feeaaa21..baf4e3ad8 100644 --- a/src/distributions/normal.jl +++ b/src/distributions/normal.jl @@ -184,6 +184,11 @@ Base.isapprox(left::JointNormal, right::JointNormal; kwargs...) = isapprox(left. """An alias for the [`JointNormal`](@ref).""" const JointGaussian = JointNormal +# Half-Normal related +function convert_paramfloattype(::Type{T}, distribution::Truncated{<:Normal}) where {T} + return Truncated(convert_paramfloattype(T, distribution.untruncated), convert(T, distribution.lower), convert(T, distribution.upper)) +end + # Variate forms promotion promote_variate_type(::Type{Univariate}, ::Type{F}) where {F <: UnivariateNormalDistributionsFamily} = F diff --git a/src/rules/half_normal/out.jl b/src/rules/half_normal/out.jl index 6efcef080..d25e2c02f 100644 --- a/src/rules/half_normal/out.jl +++ b/src/rules/half_normal/out.jl @@ -1,3 +1,4 @@ @rule HalfNormal(:out, Marginalisation) (q_v::PointMass,) = begin - return Truncated(Normal(0.0, sqrt(mean(q_v))), 0.0, Inf) + mean_v = mean(q_v) + return Truncated(Normal(zero(eltype(q_v)), sqrt(mean_v)), zero(eltype(q_v)), typemax(float(mean_v))) end diff --git a/test/rules/half_normal/test_out.jl b/test/rules/half_normal/test_out.jl index 714c0dd95..48f883a25 100644 --- a/test/rules/half_normal/test_out.jl +++ b/test/rules/half_normal/test_out.jl @@ -9,8 +9,10 @@ import ReactiveMP: @test_rules @testset "rules:HalfNormal:out" begin @testset "Variational Message Passing: (q_v::Any)" begin - @test_rules [check_type_promotion = false] HalfNormal(:out, Marginalisation) [ - (input = (q_v = PointMass(1.0),), output = Truncated(Normal(0.0, 1.0), 0.0, Inf)), (input = (q_v = PointMass(100),), output = Truncated(Normal(0.0, 10.0), 0.0, Inf)) + @test_rules [check_type_promotion = true] HalfNormal(:out, Marginalisation) [ + (input = (q_v = PointMass(1.0),), output = Truncated(Normal(0.0, 1.0), 0.0, Inf)), + (input = (q_v = PointMass(100.0),), output = Truncated(Normal(0.0, 10.0), 0.0, Inf)), + (input = (q_v = PointMass(100),), output = Truncated(Normal(0.0, 10.0), 0.0, Inf)) ] end end # testset From a234d1ba08802aa9cb38985a7b456167eb5443ce Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Tue, 11 Jul 2023 17:02:01 +0200 Subject: [PATCH 7/7] add more tests --- test/distributions/test_gamma.jl | 11 +++++++++++ test/distributions/test_gamma_inverse.jl | 7 +++++++ 2 files changed, 18 insertions(+) diff --git a/test/distributions/test_gamma.jl b/test/distributions/test_gamma.jl index d579defbc..a54ff9400 100644 --- a/test/distributions/test_gamma.jl +++ b/test/distributions/test_gamma.jl @@ -227,6 +227,17 @@ import ReactiveMP: xtlog @test prod(ProdAnalytical(), GammaShapeRate(1, 2), GammaShapeScale(1, 2)) == GammaShapeRate(1, 5 / 2) @test prod(ProdAnalytical(), GammaShapeRate(2, 2), GammaShapeScale(1, 2)) == GammaShapeRate(2, 5 / 2) @test prod(ProdAnalytical(), GammaShapeRate(2, 2), GammaShapeScale(2, 2)) == GammaShapeRate(3, 5 / 2) + + @test_throws AssertionError prod(ProdAnalytical(), GammaShapeRate(1, 1), Truncated(Normal(0.0, 1.0), -1.0, 1.0)) + @test_throws AssertionError prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), -1.0, 1.0), GammaShapeRate(1, 1)) + @test_throws AssertionError prod(ProdAnalytical(), GammaShapeScale(1, 1), Truncated(Normal(0.0, 1.0), -1.0, 1.0)) + @test_throws AssertionError prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), -1.0, 1.0), GammaShapeScale(1, 1)) + + # TODO: these tests should check also check the actual result + @test prod(ProdAnalytical(), GammaShapeRate(1, 1), Truncated(Normal(0.0, 1.0), 0.0, Inf)) isa GammaShapeRate + @test prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), 0.0, Inf), GammaShapeRate(1, 1)) isa GammaShapeRate + @test prod(ProdAnalytical(), GammaShapeScale(1, 1), Truncated(Normal(0.0, 1.0), 0.0, Inf)) isa GammaShapeScale + @test prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), 0.0, Inf), GammaShapeScale(1, 1)) isa GammaShapeScale end end diff --git a/test/distributions/test_gamma_inverse.jl b/test/distributions/test_gamma_inverse.jl index 1800c38dd..d6f8b4fba 100644 --- a/test/distributions/test_gamma_inverse.jl +++ b/test/distributions/test_gamma_inverse.jl @@ -19,6 +19,13 @@ using Random @test prod(ProdAnalytical(), GammaInverse(3.0, 2.0), GammaInverse(2.0, 1.0)) ≈ GammaInverse(6.0, 3.0) @test prod(ProdAnalytical(), GammaInverse(7.0, 1.0), GammaInverse(0.1, 4.5)) ≈ GammaInverse(8.1, 5.5) @test prod(ProdAnalytical(), GammaInverse(1.0, 3.0), GammaInverse(0.2, 0.4)) ≈ GammaInverse(2.2, 3.4) + + @test_throws AssertionError prod(ProdAnalytical(), GammaInverse(1, 1), Truncated(Normal(0.0, 1.0), -1.0, 1.0)) + @test_throws AssertionError prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), -1.0, 1.0), GammaInverse(1, 1)) + + # TODO: these tests should check also check the actual result + @test prod(ProdAnalytical(), GammaInverse(1, 1), Truncated(Normal(0.0, 1.0), 0.0, Inf)) isa GammaInverse + @test prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), 0.0, Inf), GammaInverse(1, 1)) isa GammaInverse end # log(θ) - digamma(α)