From 29782b4704355469a6aacc7ae1dbd135c8a4b055 Mon Sep 17 00:00:00 2001 From: Albert Date: Thu, 27 Jun 2024 21:09:33 +0200 Subject: [PATCH 1/3] Add MF rules for CTransition --- src/nodes/predefined/continuous_transition.jl | 26 +++++++++++++ src/rules/continuous_transition/W.jl | 19 +++++++++ src/rules/continuous_transition/a.jl | 28 +++++++++++++ src/rules/continuous_transition/x.jl | 24 ++++++++++++ src/rules/continuous_transition/y.jl | 6 +++ test/rules/continuous_transition/W_tests.jl | 39 ++++++++++++++++++- test/rules/continuous_transition/a_tests.jl | 38 +++++++++++++++++- test/rules/continuous_transition/x_tests.jl | 37 ++++++++++++++++-- test/rules/continuous_transition/y_tests.jl | 21 ++++++++++ 9 files changed, 231 insertions(+), 7 deletions(-) diff --git a/src/nodes/predefined/continuous_transition.jl b/src/nodes/predefined/continuous_transition.jl index e39ecbea2..ae96a6b99 100644 --- a/src/nodes/predefined/continuous_transition.jl +++ b/src/nodes/predefined/continuous_transition.jl @@ -121,3 +121,29 @@ end return AE end + +@average_energy ContinuousTransition (q_y::Any, q_x::Any, q_W::Any, meta::CTMeta) = begin + ma, Va = mean_cov(q_a) + my, Vy = mean_cov(q_y) + mx, Vx = mean_cov(q_x) + mW = mean(q_W) + + Fs = getjacobians(meta, ma) + dy = length(Fs) + + n = div(ndims(q_y), 2) + mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta) + + g1 = -mA + g2 = g1' + trWSU, trkronxxWSU = zero(eltype(ma)), zero(eltype(ma)) + xxt = mx * mx' + for (i, j) in Iterators.product(1:dy, 1:dy) + FjVaFi = Fs[j] * Va * Fs[i]' + trWSU += mW[j, i] * tr(FjVaFi) + trkronxxWSU += mW[j, i] * tr(xxt * FjVaFi) + end + AE = n / 2 * log2π - mean(logdet, q_W) + (tr(mW * (mA * Vx * mA' + g1 + g2 + Vy + (mA * mx - my) * (mA * mx - my)')) + trWSU + trkronxxWSU) / 2 + + return AE +end diff --git a/src/rules/continuous_transition/W.jl b/src/rules/continuous_transition/W.jl index 340ae5577..5ec9f6d99 100644 --- a/src/rules/continuous_transition/W.jl +++ b/src/rules/continuous_transition/W.jl @@ -1,6 +1,7 @@ function compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, Fs) dy = length(my) G₁ = (my * my' + Vy) + G₂ = ((my * mx' + Vyx) * mA') G₃ = transpose(G₂) Ex_xx = rank1update(Vx, mx) @@ -15,6 +16,7 @@ function compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, Fs) return G₁ - G₂ - G₃ + G₅ + G₆ end +# VMP: Stuctured @rule ContinuousTransition(:W, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::CTMeta) = begin ma, Va = mean_cov(q_a) Fs = getjacobians(meta, ma) @@ -33,3 +35,20 @@ end return WishartFast(dy + 2, Δ) end + +# VMP: Mean-field +@rule ContinuousTransition(:W, Marginalisation) (q_y::Any, q_x::Any, q_a::Any, meta::CTMeta) = begin + ma, Va = mean_cov(q_a) + my, Vy = mean_cov(q_y) + mx, Vx = mean_cov(q_x) + + Fs = getjacobians(meta, ma) + dy = length(Fs) + + epsilon = sqrt.(var(q_a)) + mA = ctcompanion_matrix(ma, epsilon, meta) + + Δ = compute_delta(my, Vy, mx, Vx, zeros(eltype(ma), dy, length(mx)), mA, Va, ma, Fs) + + return WishartFast(dy + 2, Δ) +end diff --git a/src/rules/continuous_transition/a.jl b/src/rules/continuous_transition/a.jl index ccc768598..4a506134f 100644 --- a/src/rules/continuous_transition/a.jl +++ b/src/rules/continuous_transition/a.jl @@ -1,3 +1,6 @@ +# NOTE: Both rules require q_a as input. This is a particular requirement for the ContinuousTransition node as it might need the expansion point for the transformation. This is not a general requirement for the VMP rules. + +# VMP: Stuctured @rule ContinuousTransition(:a, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin ma = mean(q_a) mW = mean(q_W) @@ -23,3 +26,28 @@ return MvNormalWeightedMeanPrecision(xi, W) end + +# VMP: Mean-field +@rule ContinuousTransition(:a, Marginalisation) (q_y::Any, q_x::Any, q_a::Any, q_W::Any, meta::CTMeta) = begin + mx, Vx = mean_cov(q_x) + mW = mean(q_W) + my = mean(q_y) + ma = mean(q_a) + + Fs = getjacobians(meta, ma) + dy = length(Fs) + + xi, W = zeros(eltype(ma), length(ma)), zeros(eltype(ma), length(ma), length(ma)) + + mxmy = mx * my' + Vxmx = rank1update(Vx, mx) + + for i in 1:dy + xi += Fs[i]' * mxmy * mW[:, i] + for j in 1:dy + W += mW[j, i] * Fs[i]' * Vxmx * Fs[j] + end + end + + return MvNormalWeightedMeanPrecision(xi, W) +end diff --git a/src/rules/continuous_transition/x.jl b/src/rules/continuous_transition/x.jl index 56169db82..20a917ebe 100644 --- a/src/rules/continuous_transition/x.jl +++ b/src/rules/continuous_transition/x.jl @@ -1,3 +1,4 @@ +# VMP: Stuctured @rule ContinuousTransition(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin ma, Va = mean_cov(q_a) my, Wy = mean_precision(m_y) @@ -22,3 +23,26 @@ return MvNormalWeightedMeanPrecision(z, Ξ) end + +# VMP: Mean-field +@rule ContinuousTransition(:x, Marginalisation) (q_y::Any, q_a::Any, q_W::Any, meta::CTMeta) = begin + ma, Va = mean_cov(q_a) + my = mean(q_y) + mW = mean(q_W) + + Fs = getjacobians(meta, ma) + dy = length(Fs) + + epsilon = sqrt.(var(q_a)) + mA = ctcompanion_matrix(ma, epsilon, meta) + + Ξ = mA' * mW * mA + + for (i, j) in Iterators.product(1:dy, 1:dy) + Ξ += mW[j, i] * Fs[j] * Va * Fs[i]' + end + + z = mA' * mW * my + + return MvNormalWeightedMeanPrecision(z, Ξ) +end diff --git a/src/rules/continuous_transition/y.jl b/src/rules/continuous_transition/y.jl index 864ce4237..480eff682 100644 --- a/src/rules/continuous_transition/y.jl +++ b/src/rules/continuous_transition/y.jl @@ -1,3 +1,4 @@ +# VMP: Stuctured @rule ContinuousTransition(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin ma = mean(q_a) mx, Vx = mean_cov(m_x) @@ -12,3 +13,8 @@ return MvNormalMeanCovariance(my, Vy) end + +# VMP: Mean-field +@rule ContinuousTransition(:y, Marginalisation) (q_x::Any, q_a::Any, q_W::Any, meta::CTMeta) = MvNormalMeanPrecision( + ctcompanion_matrix(mean(q_a), sqrt.(var(q_a)), meta) * mean(q_x), mean(q_W) +) diff --git a/test/rules/continuous_transition/W_tests.jl b/test/rules/continuous_transition/W_tests.jl index 3f3f5f060..12a4ed678 100644 --- a/test/rules/continuous_transition/W_tests.jl +++ b/test/rules/continuous_transition/W_tests.jl @@ -9,7 +9,7 @@ @testset "Linear transformation" begin # the following rule is used for testing purposes only # It is derived separately by Thijs van de Laar - function benchmark_rule(q_y_x, mA, ΣA, UA) + function benchmark_rule_structured(q_y_x, mA, ΣA, UA) myx, Vyx = mean_cov(q_y_x) dy = size(mA, 1) @@ -39,7 +39,7 @@ qa = MvNormalMeanCovariance(vec(mA), kron(UA, ΣA)) @test_rules [check_type_promotion = true, atol = 1e-5] ContinuousTransition(:W, Marginalisation) [( - input = (q_y_x = qyx, q_a = qa, meta = metal), output = benchmark_rule(qyx, mA, ΣA, UA) + input = (q_y_x = qyx, q_a = qa, meta = metal), output = benchmark_rule_structured(qyx, mA, ΣA, UA) )] end end @@ -62,4 +62,39 @@ )] end end + + # the following rule is used for testing purposes only + # It is derived separately by Thijs van de Laar + function benchmark_rule_meanfield(q_y, q_x, mA, ΣA, UA) + my, Vy = mean_cov(q_y) + mx, Vx = mean_cov(q_x) + + dy = size(mA, 1) + + G = tr(Vx * UA) * ΣA + mA * Vx * mA' + Vy + ΣA * mx' * UA * mx + (mA * mx - my) * (mA * mx - my)' + + return WishartFast(dy + 2, G) + end + + @testset "Mean-field: (q_y::Any, q_x::Any, q_a::Any, meta::CTMeta)" begin + for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)] + dydx = dy * dx + transformation = (a) -> reshape(a, dy, dx) + mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx) + + metal = CTMeta(transformation) + Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy) + μx, Σx = rand(rng, dx), Lx * Lx' + μy, Σy = rand(rng, dy), Ly * Ly' + + qy = MvNormalMeanCovariance(μy, Σy) + qx = MvNormalMeanCovariance(μx, Σx) + + qa = MvNormalMeanCovariance(vec(mA), kron(UA, ΣA)) + + @test_rules [check_type_promotion = true, atol = 1e-5] ContinuousTransition(:W, Marginalisation) [( + input = (q_y = qy, q_x = qx, q_a = qa, meta = metal), output = benchmark_rule_meanfield(qy, qx, mA, ΣA, UA) + )] + end + end end diff --git a/test/rules/continuous_transition/a_tests.jl b/test/rules/continuous_transition/a_tests.jl index 289b8b900..e16700546 100644 --- a/test/rules/continuous_transition/a_tests.jl +++ b/test/rules/continuous_transition/a_tests.jl @@ -10,7 +10,7 @@ # the following rule is used for testing purposes only # It is derived separately by Thijs van de Laar - function benchmark_rule(q_y_x, q_W) + function benchmark_rule_structured(q_y_x, q_W) myx, Vyx = mean_cov(q_y_x) dy = size(q_W.S, 1) Vx = Vyx[(dy + 1):end, (dy + 1):end] @@ -36,7 +36,7 @@ qa = MvNormalMeanCovariance(a0, diageye(dydx)) qW = Wishart(dy + 1, diageye(dy)) @test_rules [check_type_promotion = false] ContinuousTransition(:a, Marginalisation) [( - input = (q_y_x = qyx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule(qyx, qW) + input = (q_y_x = qyx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule_structured(qyx, qW) )] end end @@ -60,4 +60,38 @@ )] end end + + # the following rule is used for testing purposes only + # It is derived separately by Thijs van de Laar + # NOTE: this test rule does not allow q_x as a PointMass as it involves the covariance matrix of q_x + function benchmark_rule_meanfield(q_y, q_x, q_W) + my = mean(q_y) + mx, Vx = mean_cov(q_x) + mW = mean(q_W) + Λ = kron(Vx + mx * mx', mW) + return MvNormalWeightedMeanPrecision(Λ * (vec(my * mx' * inv(Vx + mx * mx'))), Λ) + end + + @testset "Mean-field: (q_y::Any, q_x::Any, q_a::Any, q_W::Any, meta::CTMeta)" begin + for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)] + dydx = dy * dx + transformation = (a) -> reshape(a, dy, dx) + a0 = rand(Float32, dydx) + metal = CTMeta(transformation) + Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy) + μx, Σx = rand(rng, dx), Lx * Lx' + μy, Σy = rand(rng, dy), Ly * Ly' + qy = MvNormalMeanCovariance(μy, Σy) + qx = MvNormalMeanCovariance(μx, Σx) + qa = MvNormalMeanCovariance(a0, diageye(dydx)) + qW = Wishart(dy + 1, diageye(dy)) + @test_rules [check_type_promotion = false] ContinuousTransition(:a, Marginalisation) [( + input = (q_y = qy, q_x = qx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule_meanfield(qy, qx, qW) + )] + + @test_rules [check_type_promotion = false] ContinuousTransition(:a, Marginalisation) [( + input = (q_y = PointMass(μy), q_x = qx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule_meanfield(PointMass(μy), qx, qW) + )] + end + end end diff --git a/test/rules/continuous_transition/x_tests.jl b/test/rules/continuous_transition/x_tests.jl index 05b9a6794..83ad2352f 100644 --- a/test/rules/continuous_transition/x_tests.jl +++ b/test/rules/continuous_transition/x_tests.jl @@ -9,7 +9,7 @@ @testset "Linear transformation" begin # the following rule is used for testing purposes only # It is derived separately by Thijs van de Laar - function benchmark_rule(q_y, q_W, mA, ΣA, UA) + function benchmark_rule_strucutred(q_y, q_W, mA, ΣA, UA) my, Vy = mean_cov(q_y) mW = mean(q_W) @@ -27,7 +27,7 @@ mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx) metal = CTMeta(transformation) - Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy) + Ly = rand(rng, dy, dy) μy, Σy = rand(rng, dy), Ly * Ly' qy = MvNormalMeanCovariance(μy, Σy) @@ -35,7 +35,7 @@ qW = Wishart(dy + 1, diageye(dy)) @test_rules [check_type_promotion = true, atol = 1e-4] ContinuousTransition(:x, Marginalisation) [( - input = (m_y = qy, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule(qy, qW, mA, ΣA, UA) + input = (m_y = qy, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule_strucutred(qy, qW, mA, ΣA, UA) )] end end @@ -59,4 +59,35 @@ )] end end + + # the following rule is used for testing purposes only + # It is derived separately by Thijs van de Laar + function benchmark_rule_meanfield(q_y, q_W, mA, ΣA, UA) + mW = mean(q_W) + + Λ = mA'mW * mA + tr(mW * ΣA) * UA + ξ = mA' * mW * mean(q_y) + return MvNormalWeightedMeanPrecision(ξ, Λ) + end + + @testset "Mean-field: (q_y::Any, q_a::Any, q_W::Any, meta::CTMeta)" begin + for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)] + dydx = dy * dx + transformation = (a) -> reshape(a, dy, dx) + + mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx) + + metal = CTMeta(transformation) + Ly = rand(rng, dy, dy) + μy, Σy = rand(rng, dy), Ly * Ly' + + qy = MvNormalMeanCovariance(μy, Σy) + qa = MvNormalMeanCovariance(vec(mA), diageye(dydx)) + qW = Wishart(dy + 1, diageye(dy)) + + @test_rules [check_type_promotion = true, atol = 1e-4] ContinuousTransition(:x, Marginalisation) [( + input = (q_y = qy, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule_meanfield(qy, qW, mA, ΣA, UA) + )] + end + end end diff --git a/test/rules/continuous_transition/y_tests.jl b/test/rules/continuous_transition/y_tests.jl index 4416e86d6..46b623215 100644 --- a/test/rules/continuous_transition/y_tests.jl +++ b/test/rules/continuous_transition/y_tests.jl @@ -55,4 +55,25 @@ )] end end + + @testset "Mean-field: (q_y::Any, q_a::Any, q_W::Any, meta::CTMeta)" begin + for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)] + dydx = dy * dx + transformation = (a) -> reshape(a, dy, dx) + + mA = rand(rng, dy, dx) + + metal = CTMeta(transformation) + Lx = rand(rng, dx, dx) + μx, Σx = rand(rng, dx), Lx * Lx' + + qx = MvNormalMeanCovariance(μx, Σx) + qa = MvNormalMeanCovariance(vec(mA), diageye(dydx)) + qW = Wishart(dy + 1, diageye(dy)) + + @test_rules [check_type_promotion = true, atol = 1e-5] ContinuousTransition(:y, Marginalisation) [( + input = (q_x = qx, q_a = qa, q_W = qW, meta = metal), output = MvNormalMeanPrecision(mA * μx, mean(qW)) + )] + end + end end From be3ab4246957d094574ea91d7690a1669b574eff Mon Sep 17 00:00:00 2001 From: Albert Date: Thu, 27 Jun 2024 21:28:32 +0200 Subject: [PATCH 2/3] Add tests for CTransition --- src/nodes/predefined/continuous_transition.jl | 6 ++---- .../predefined/continuous_transition_tests.jl | 19 +++++++++++++------ 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/nodes/predefined/continuous_transition.jl b/src/nodes/predefined/continuous_transition.jl index ae96a6b99..b4247c63e 100644 --- a/src/nodes/predefined/continuous_transition.jl +++ b/src/nodes/predefined/continuous_transition.jl @@ -122,7 +122,7 @@ end return AE end -@average_energy ContinuousTransition (q_y::Any, q_x::Any, q_W::Any, meta::CTMeta) = begin +@average_energy ContinuousTransition (q_y::Any, q_x::Any, q_a::Any, q_W::Any, meta::CTMeta) = begin ma, Va = mean_cov(q_a) my, Vy = mean_cov(q_y) mx, Vx = mean_cov(q_x) @@ -134,8 +134,6 @@ end n = div(ndims(q_y), 2) mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta) - g1 = -mA - g2 = g1' trWSU, trkronxxWSU = zero(eltype(ma)), zero(eltype(ma)) xxt = mx * mx' for (i, j) in Iterators.product(1:dy, 1:dy) @@ -143,7 +141,7 @@ end trWSU += mW[j, i] * tr(FjVaFi) trkronxxWSU += mW[j, i] * tr(xxt * FjVaFi) end - AE = n / 2 * log2π - mean(logdet, q_W) + (tr(mW * (mA * Vx * mA' + g1 + g2 + Vy + (mA * mx - my) * (mA * mx - my)')) + trWSU + trkronxxWSU) / 2 + AE = n / 2 * log2π - mean(logdet, q_W) + (tr(mW * (mA * Vx * mA' + Vy + (mA * mx - my) * (mA * mx - my)')) + trWSU + trkronxxWSU) / 2 return AE end diff --git a/test/nodes/predefined/continuous_transition_tests.jl b/test/nodes/predefined/continuous_transition_tests.jl index b3c56ae1b..b4415df48 100644 --- a/test/nodes/predefined/continuous_transition_tests.jl +++ b/test/nodes/predefined/continuous_transition_tests.jl @@ -3,18 +3,25 @@ using Test, ReactiveMP, Random, Distributions, BayesBase, ExponentialFamily import ReactiveMP: getjacobians, gettransformation, ctcompanion_matrix + # TODO: A more rigorous test suit for the average energy of CTransition needs to be added dy, dx = 2, 3 meta = CTMeta(a -> reshape(a, dy, dx)) @testset "AverageEnergy" begin - q_y_x = MvNormalMeanCovariance(zeros(5), diageye(5)) - q_a = MvNormalMeanCovariance(zeros(6), diageye(6)) - q_W = Wishart(3, diageye(2)) + q_y = MvNormalMeanCovariance(zeros(dy), diageye(dy)) + q_x = MvNormalMeanCovariance(zeros(dx), diageye(dx)) - marginals = (Marginal(q_y_x, false, false, nothing), Marginal(q_a, false, false, nothing), Marginal(q_W, false, false, nothing)) + q_y_x = MvNormalMeanCovariance([mean(q_y); mean(q_x)], [cov(q_y) zeros(dy, dx); zeros(dx, dy) cov(q_x)]) + q_a = MvNormalMeanCovariance(zeros(dx * dy), diageye(dx * dy)) + q_W = Wishart(dy + 1, diageye(dy)) - @test score(AverageEnergy(), ContinuousTransition, Val{(:y_x, :a, :W)}(), marginals, meta) ≈ 13.0 atol = 1e-2 - @show getjacobians(meta, mean(q_a)) + marginals_st = (Marginal(q_y_x, false, false, nothing), Marginal(q_a, false, false, nothing), Marginal(q_W, false, false, nothing)) + marginals_mf = (Marginal(q_y, false, false, nothing), Marginal(q_x, false, false, nothing), Marginal(q_a, false, false, nothing), Marginal(q_W, false, false, nothing)) + + # 12,992 is a result of manual calculation + @test score(AverageEnergy(), ContinuousTransition, Val{(:y_x, :a, :W)}(), marginals_st, meta) ≈ 12.992 atol = 1e-2 + # 12,07336 is a result of manual calculation + @test score(AverageEnergy(), ContinuousTransition, Val{(:y, :x, :a, :W)}(), marginals_mf, meta) ≈ 12.07736 atol = 1e-2 end @testset "ContinuousTransition Functionality" begin From 748bb7aa8151461f784a801add7eb50bb6033adf Mon Sep 17 00:00:00 2001 From: Albert Date: Mon, 1 Jul 2024 19:17:29 +0200 Subject: [PATCH 3/3] Update docstrings --- src/nodes/predefined/continuous_transition.jl | 19 ++++++++++++++++++- src/rules/continuous_transition/a.jl | 2 +- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/nodes/predefined/continuous_transition.jl b/src/nodes/predefined/continuous_transition.jl index b4247c63e..9845b40b8 100644 --- a/src/nodes/predefined/continuous_transition.jl +++ b/src/nodes/predefined/continuous_transition.jl @@ -4,7 +4,10 @@ import LazyArrays import StatsFuns: log2π @doc raw""" -The ContinuousTransition node transforms an m-dimensional (dx) vector x into an n-dimensional (dy) vector y via a linear (or nonlinear) transformation with a `n×m`-dimensional matrix `A` that is constructed from a vector `a`. +The functional form of the ContinuousTransition node is given by: +y ~ Normal(K(a) * x, W⁻¹) + +This node transforms an m-dimensional vector x into an n-dimensional vector y via a linear (or nonlinear) transformation with a `n×m`-dimensional matrix `A` that is constructed from a vector `a` via a transformation K(a). ContinuousTransition node is primarily used in two regimes: # When no structure on A is specified: @@ -37,6 +40,20 @@ Interfaces: 4. W - `n×n`-dimensional precision matrix used to soften the transition and perform variational message passing. Note that you can set W to a fixed value or put a prior on it to control the amount of jitter. + +The ContinuousTransition node support two factorizations: +1. Mean-field factorization: +```julia +@constraints begin + q(y, x, a, W) = q(y)q(x)q(a)q(W) +end +``` +2. Structured factorization: +```julia +@constraints begin + q(y, x, a, W) = q(y, x)q(a)q(W) +end +``` """ struct ContinuousTransition end diff --git a/src/rules/continuous_transition/a.jl b/src/rules/continuous_transition/a.jl index 4a506134f..d3dd2f87b 100644 --- a/src/rules/continuous_transition/a.jl +++ b/src/rules/continuous_transition/a.jl @@ -1,4 +1,4 @@ -# NOTE: Both rules require q_a as input. This is a particular requirement for the ContinuousTransition node as it might need the expansion point for the transformation. This is not a general requirement for the VMP rules. +# Important note: ContinuousTransition node requires q(a) as input to compute the update message for a. This is a particular requirement for the ContinuousTransition node as it might need the expansion point for the transformation. This is not a general requirement for the VMP rules. # VMP: Stuctured @rule ContinuousTransition(:a, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin