From 254ab50630c997a076c5bf5538dd90ef12256285 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Wed, 26 Jun 2024 10:08:08 +0200 Subject: [PATCH 1/2] Add rules and tests for `GammaShapeRate` node --- src/nodes/predefined/gamma_mixture.jl | 2 ++ src/rules/gamma_shape_rate/a.jl | 6 +++-- src/rules/gamma_shape_rate/b.jl | 3 +++ src/rules/gamma_shape_rate/marginals.jl | 2 +- src/rules/gamma_shape_rate/out.jl | 2 +- src/rules/predefined.jl | 1 + test/rules/gamma_shape_rate/a_tests.jl | 12 ++++++++++ test/rules/gamma_shape_rate/b_tests.jl | 14 ++++++++++++ .../rules/gamma_shape_rate/marginals_tests.jl | 22 +++++++++++++++++++ test/rules/gamma_shape_rate/out_tests.jl | 21 ++++++++++++++++++ 10 files changed, 81 insertions(+), 4 deletions(-) create mode 100644 src/rules/gamma_shape_rate/b.jl create mode 100644 test/rules/gamma_shape_rate/a_tests.jl create mode 100644 test/rules/gamma_shape_rate/b_tests.jl create mode 100644 test/rules/gamma_shape_rate/marginals_tests.jl create mode 100644 test/rules/gamma_shape_rate/out_tests.jl diff --git a/src/nodes/predefined/gamma_mixture.jl b/src/nodes/predefined/gamma_mixture.jl index 304e18177..52df3b38c 100644 --- a/src/nodes/predefined/gamma_mixture.jl +++ b/src/nodes/predefined/gamma_mixture.jl @@ -183,6 +183,8 @@ struct GammaShapeLikelihood{T <: Real} <: ContinuousUnivariateDistribution γ::T # p * β end +Distributions.params(distribution::GammaShapeLikelihood) = (distribution.p, distribution.γ) + Distributions.@distr_support GammaShapeLikelihood 0.0 Inf BayesBase.support(dist::GammaShapeLikelihood) = Distributions.RealInterval(0.0, Inf) diff --git a/src/rules/gamma_shape_rate/a.jl b/src/rules/gamma_shape_rate/a.jl index 20d296d98..765b46a08 100644 --- a/src/rules/gamma_shape_rate/a.jl +++ b/src/rules/gamma_shape_rate/a.jl @@ -1,5 +1,7 @@ import DomainSets -@rule GammaShapeRate(:α, Marginalisation) (q_out::Gamma, q_β::Gamma) = begin - return ContinuousUnivariateLogPdf(DomainSets.HalfLine(), (α) -> α * mean(log, q_β) + (α - 1) * mean(log, q_out) - loggamma(α)) +@rule GammaShapeRate(:α, Marginalisation) (q_out::Any, q_β::GammaDistributionsFamily) = begin + γ = mean(log, q_β) + mean(log, q_out) + params = promote(1, γ) + return GammaShapeLikelihood(params...) end diff --git a/src/rules/gamma_shape_rate/b.jl b/src/rules/gamma_shape_rate/b.jl new file mode 100644 index 000000000..3957a1bcd --- /dev/null +++ b/src/rules/gamma_shape_rate/b.jl @@ -0,0 +1,3 @@ +export rule + +@rule GammaShapeRate(:β, Marginalisation) (q_out::Any, q_α::Any) = GammaShapeRate(1 + mean(q_α), mean(q_out)) diff --git a/src/rules/gamma_shape_rate/marginals.jl b/src/rules/gamma_shape_rate/marginals.jl index 1b454b1bf..4c113fdc5 100644 --- a/src/rules/gamma_shape_rate/marginals.jl +++ b/src/rules/gamma_shape_rate/marginals.jl @@ -1,5 +1,5 @@ export marginalrule -@marginalrule GammaShapeRate(:out_α_β) (m_out::Gamma, m_α::PointMass, m_β::PointMass) = begin +@marginalrule GammaShapeRate(:out_α_β) (m_out::GammaDistributionsFamily, m_α::PointMass, m_β::PointMass) = begin return (out = prod(ClosedProd(), GammaShapeRate(mean(m_α), mean(m_β)), m_out), α = m_α, β = m_β) end diff --git a/src/rules/gamma_shape_rate/out.jl b/src/rules/gamma_shape_rate/out.jl index 38cde3463..807c375d7 100644 --- a/src/rules/gamma_shape_rate/out.jl +++ b/src/rules/gamma_shape_rate/out.jl @@ -2,4 +2,4 @@ export rule @rule GammaShapeRate(:out, Marginalisation) (m_α::PointMass, m_β::PointMass) = GammaShapeRate(mean(m_α), mean(m_β)) -@rule GammaShapeRate(:out, Marginalisation) (q_α::PointMass, q_β::PointMass) = GammaShapeRate(mean(q_α), mean(q_β)) +@rule GammaShapeRate(:out, Marginalisation) (q_α::Any, q_β::Any) = GammaShapeRate(mean(q_α), mean(q_β)) diff --git a/src/rules/predefined.jl b/src/rules/predefined.jl index 3e9b9f1ce..eeccdfe3a 100644 --- a/src/rules/predefined.jl +++ b/src/rules/predefined.jl @@ -27,6 +27,7 @@ include("gamma_inverse/marginals.jl") include("gamma_shape_rate/out.jl") include("gamma_shape_rate/marginals.jl") include("gamma_shape_rate/a.jl") +include("gamma_shape_rate/b.jl") include("beta/out.jl") include("beta/marginals.jl") diff --git a/test/rules/gamma_shape_rate/a_tests.jl b/test/rules/gamma_shape_rate/a_tests.jl new file mode 100644 index 000000000..f9d187556 --- /dev/null +++ b/test/rules/gamma_shape_rate/a_tests.jl @@ -0,0 +1,12 @@ +@testitem "rules:GammaShapeRate:α" begin + using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions + + import ReactiveMP: @test_rules, GammaShapeLikelihood + + @testset "Variational Message Passing: (q_out::Any, q_β::GammaDistributionsFamily)" begin + @test_rules [check_type_promotion = true] GammaShapeRate(:α, Marginalisation) [ + (input = (q_out = GammaShapeRate(1.0, 1.0), q_β = GammaShapeRate(1.0, 1.0)), output = GammaShapeLikelihood(1.0, 2.0 * -0.5772156649015315)), + (input = (q_out = PointMass(1.0), q_β = GammaShapeRate(1.0, 1.0)), output = GammaShapeLikelihood(1.0, -0.5772156649015315)) + ] + end +end diff --git a/test/rules/gamma_shape_rate/b_tests.jl b/test/rules/gamma_shape_rate/b_tests.jl new file mode 100644 index 000000000..7a0064492 --- /dev/null +++ b/test/rules/gamma_shape_rate/b_tests.jl @@ -0,0 +1,14 @@ +@testitem "rules:GammaShapeRate:β" begin + using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions + + import ReactiveMP: @test_rules, GammaShapeLikelihood + + @testset "Variational Message Passing: (q_out::Any, q_α::Any)" begin + @test_rules [check_type_promotion = true] GammaShapeRate(:β, Marginalisation) [ + (input = (q_out = GammaShapeRate(1.0, 1.0), q_α = GammaShapeRate(1.0, 1.0)), output = GammaShapeRate(2.0, 1.0)), + (input = (q_out = PointMass(1.0), q_α = GammaShapeRate(1.0, 1.0)), output = GammaShapeRate(2.0, 1.0)), + (input = (q_out = GammaShapeScale(1.0, 1.0), q_α = PointMass(10.0)), output = GammaShapeRate(11.0, 1.0)), + (input = (q_out = GammaShapeScale(1.0, 10.0), q_α = GammaShapeRate(1.0, 1.0)), output = GammaShapeRate(2, 10)) + ] + end +end diff --git a/test/rules/gamma_shape_rate/marginals_tests.jl b/test/rules/gamma_shape_rate/marginals_tests.jl new file mode 100644 index 000000000..34a80ac8a --- /dev/null +++ b/test/rules/gamma_shape_rate/marginals_tests.jl @@ -0,0 +1,22 @@ +@testitem "marginalrules:GammaShapeRate" begin + using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions + + import ReactiveMP: @test_marginalrules + + @testset "out_α_β: (m_out::GammaDistributionsFamily, m_α::PointMass, m_β::PointMass)" begin + @test_marginalrules [check_type_promotion = false] GammaShapeRate(:out_α_β) [ + ( + input = (m_out = GammaShapeRate(1.0, 2.0), m_α = PointMass(1.0), m_β = PointMass(2.0)), + output = (out = GammaShapeRate(1.0, 4.0), α = PointMass(1.0), β = PointMass(2.0)) + ), + ( + input = (m_out = GammaShapeScale(2.0, 2.0), m_α = PointMass(2.0), m_β = PointMass(3.0)), + output = (out = GammaShapeRate(3.0, 3.5), α = PointMass(2.0), β = PointMass(3.0)) + ), + ( + input = (m_out = GammaShapeRate(2.0, 3.0), m_α = PointMass(1.0), m_β = PointMass(3.0)), + output = (out = GammaShapeRate(2.0, 6.0), α = PointMass(1.0), β = PointMass(3.0)) + ) + ] + end +end diff --git a/test/rules/gamma_shape_rate/out_tests.jl b/test/rules/gamma_shape_rate/out_tests.jl new file mode 100644 index 000000000..0755ded50 --- /dev/null +++ b/test/rules/gamma_shape_rate/out_tests.jl @@ -0,0 +1,21 @@ + +@testitem "rules:GammaShapeRate:out" begin + using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions + + import ReactiveMP: @test_rules + + @testset "Belief Propagation: (m_α::Any, m_θ::Any)" begin + @test_rules [check_type_promotion = true] GammaShapeRate(:out, Marginalisation) [ + (input = (m_α = PointMass(1.0), m_β = PointMass(2.0)), output = GammaShapeRate(1.0, 2.0)), + (input = (m_α = PointMass(3.0), m_β = PointMass(3.0)), output = GammaShapeRate(3.0, 3.0)), + (input = (m_α = PointMass(42.0), m_β = PointMass(42.0)), output = GammaShapeRate(42.0, 42.0)) + ] + end + + @testset "Variational Message Passing: (q_α::Any, q_β::Any)" begin + @test_rules [check_type_promotion = true] GammaShapeRate(:out, Marginalisation) [ + (input = (q_α = PointMass(1.0), q_β = PointMass(2.0)), output = GammaShapeRate(1.0, 2.0)), + (input = (q_α = GammaShapeScale(1.0, 1.0), q_β = GammaShapeRate(1.0, 1.0)), output = GammaShapeRate(1.0, 1.0)) + ] + end +end # testset From 5f8bcf20b20eda9a7d1324d39eef2e5bef173bb8 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Wed, 26 Jun 2024 10:08:49 +0200 Subject: [PATCH 2/2] Generalize dispatch for `Gamma` node --- src/rules/gamma/marginals.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/gamma/marginals.jl b/src/rules/gamma/marginals.jl index 1158efcb2..8a29ad431 100644 --- a/src/rules/gamma/marginals.jl +++ b/src/rules/gamma/marginals.jl @@ -1,5 +1,5 @@ export marginalrule -@marginalrule Gamma(:out_α_θ) (m_out::Gamma, m_α::PointMass, m_θ::PointMass) = begin +@marginalrule Gamma(:out_α_θ) (m_out::GammaDistributionsFamily, m_α::PointMass, m_θ::PointMass) = begin return (out = prod(ClosedProd(), Gamma(mean(m_α), mean(m_θ)), m_out), α = m_α, θ = m_θ) end