Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MF rules for CTransition #406

Merged
merged 3 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion src/nodes/predefined/continuous_transition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import LazyArrays
import StatsFuns: log2π

@doc raw"""
The ContinuousTransition node transforms an m-dimensional (dx) vector x into an n-dimensional (dy) vector y via a linear (or nonlinear) transformation with a `n×m`-dimensional matrix `A` that is constructed from a vector `a`.
The functional form of the ContinuousTransition node is given by:
y ~ Normal(K(a) * x, W⁻¹)

This node transforms an m-dimensional vector x into an n-dimensional vector y via a linear (or nonlinear) transformation with a `n×m`-dimensional matrix `A` that is constructed from a vector `a` via a transformation K(a).
ContinuousTransition node is primarily used in two regimes:

# When no structure on A is specified:
Expand Down Expand Up @@ -37,6 +40,20 @@ Interfaces:
4. W - `n×n`-dimensional precision matrix used to soften the transition and perform variational message passing.

Note that you can set W to a fixed value or put a prior on it to control the amount of jitter.

The ContinuousTransition node support two factorizations:
1. Mean-field factorization:
```julia
@constraints begin
q(y, x, a, W) = q(y)q(x)q(a)q(W)
end
```
2. Structured factorization:
```julia
@constraints begin
q(y, x, a, W) = q(y, x)q(a)q(W)
end
```
"""
struct ContinuousTransition end

Expand Down Expand Up @@ -121,3 +138,27 @@ end

return AE
end

@average_energy ContinuousTransition (q_y::Any, q_x::Any, q_a::Any, q_W::Any, meta::CTMeta) = begin
ma, Va = mean_cov(q_a)
my, Vy = mean_cov(q_y)
mx, Vx = mean_cov(q_x)
mW = mean(q_W)

Fs = getjacobians(meta, ma)
dy = length(Fs)

n = div(ndims(q_y), 2)
mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta)

trWSU, trkronxxWSU = zero(eltype(ma)), zero(eltype(ma))
xxt = mx * mx'
for (i, j) in Iterators.product(1:dy, 1:dy)
FjVaFi = Fs[j] * Va * Fs[i]'
trWSU += mW[j, i] * tr(FjVaFi)
trkronxxWSU += mW[j, i] * tr(xxt * FjVaFi)
end
AE = n / 2 * log2π - mean(logdet, q_W) + (tr(mW * (mA * Vx * mA' + Vy + (mA * mx - my) * (mA * mx - my)')) + trWSU + trkronxxWSU) / 2

return AE
end
19 changes: 19 additions & 0 deletions src/rules/continuous_transition/W.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
function compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, Fs)
dy = length(my)
G₁ = (my * my' + Vy)

G₂ = ((my * mx' + Vyx) * mA')
G₃ = transpose(G₂)
Ex_xx = rank1update(Vx, mx)
Expand All @@ -15,6 +16,7 @@ function compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, Fs)
return G₁ - G₂ - G₃ + G₅ + G₆
end

# VMP: Stuctured
@rule ContinuousTransition(:W, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::CTMeta) = begin
ma, Va = mean_cov(q_a)
Fs = getjacobians(meta, ma)
Expand All @@ -33,3 +35,20 @@ end

return WishartFast(dy + 2, Δ)
end

# VMP: Mean-field
@rule ContinuousTransition(:W, Marginalisation) (q_y::Any, q_x::Any, q_a::Any, meta::CTMeta) = begin
ma, Va = mean_cov(q_a)
my, Vy = mean_cov(q_y)
mx, Vx = mean_cov(q_x)

Fs = getjacobians(meta, ma)
dy = length(Fs)

epsilon = sqrt.(var(q_a))
mA = ctcompanion_matrix(ma, epsilon, meta)

Δ = compute_delta(my, Vy, mx, Vx, zeros(eltype(ma), dy, length(mx)), mA, Va, ma, Fs)

return WishartFast(dy + 2, Δ)
end
28 changes: 28 additions & 0 deletions src/rules/continuous_transition/a.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Important note: ContinuousTransition node requires q(a) as input to compute the update message for a. This is a particular requirement for the ContinuousTransition node as it might need the expansion point for the transformation. This is not a general requirement for the VMP rules.

# VMP: Stuctured
@rule ContinuousTransition(:a, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin
ma = mean(q_a)
mW = mean(q_W)
Expand All @@ -23,3 +26,28 @@

return MvNormalWeightedMeanPrecision(xi, W)
end

# VMP: Mean-field
@rule ContinuousTransition(:a, Marginalisation) (q_y::Any, q_x::Any, q_a::Any, q_W::Any, meta::CTMeta) = begin
mx, Vx = mean_cov(q_x)
mW = mean(q_W)
my = mean(q_y)
ma = mean(q_a)

Fs = getjacobians(meta, ma)
dy = length(Fs)

xi, W = zeros(eltype(ma), length(ma)), zeros(eltype(ma), length(ma), length(ma))

mxmy = mx * my'
Vxmx = rank1update(Vx, mx)

for i in 1:dy
xi += Fs[i]' * mxmy * mW[:, i]
for j in 1:dy
W += mW[j, i] * Fs[i]' * Vxmx * Fs[j]
end
end

return MvNormalWeightedMeanPrecision(xi, W)
end
24 changes: 24 additions & 0 deletions src/rules/continuous_transition/x.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# VMP: Stuctured
@rule ContinuousTransition(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin
ma, Va = mean_cov(q_a)
my, Wy = mean_precision(m_y)
Expand All @@ -22,3 +23,26 @@

return MvNormalWeightedMeanPrecision(z, Ξ)
end

# VMP: Mean-field
@rule ContinuousTransition(:x, Marginalisation) (q_y::Any, q_a::Any, q_W::Any, meta::CTMeta) = begin
ma, Va = mean_cov(q_a)
my = mean(q_y)
mW = mean(q_W)

Fs = getjacobians(meta, ma)
dy = length(Fs)

epsilon = sqrt.(var(q_a))
mA = ctcompanion_matrix(ma, epsilon, meta)

Ξ = mA' * mW * mA

for (i, j) in Iterators.product(1:dy, 1:dy)
Ξ += mW[j, i] * Fs[j] * Va * Fs[i]'
end

z = mA' * mW * my

return MvNormalWeightedMeanPrecision(z, Ξ)
end
6 changes: 6 additions & 0 deletions src/rules/continuous_transition/y.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# VMP: Stuctured
@rule ContinuousTransition(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin
ma = mean(q_a)
mx, Vx = mean_cov(m_x)
Expand All @@ -12,3 +13,8 @@

return MvNormalMeanCovariance(my, Vy)
end

# VMP: Mean-field
@rule ContinuousTransition(:y, Marginalisation) (q_x::Any, q_a::Any, q_W::Any, meta::CTMeta) = MvNormalMeanPrecision(
ctcompanion_matrix(mean(q_a), sqrt.(var(q_a)), meta) * mean(q_x), mean(q_W)
)
19 changes: 13 additions & 6 deletions test/nodes/predefined/continuous_transition_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,25 @@
using Test, ReactiveMP, Random, Distributions, BayesBase, ExponentialFamily

import ReactiveMP: getjacobians, gettransformation, ctcompanion_matrix
# TODO: A more rigorous test suit for the average energy of CTransition needs to be added
dy, dx = 2, 3
meta = CTMeta(a -> reshape(a, dy, dx))

@testset "AverageEnergy" begin
q_y_x = MvNormalMeanCovariance(zeros(5), diageye(5))
q_a = MvNormalMeanCovariance(zeros(6), diageye(6))
q_W = Wishart(3, diageye(2))
q_y = MvNormalMeanCovariance(zeros(dy), diageye(dy))
q_x = MvNormalMeanCovariance(zeros(dx), diageye(dx))

marginals = (Marginal(q_y_x, false, false, nothing), Marginal(q_a, false, false, nothing), Marginal(q_W, false, false, nothing))
q_y_x = MvNormalMeanCovariance([mean(q_y); mean(q_x)], [cov(q_y) zeros(dy, dx); zeros(dx, dy) cov(q_x)])
q_a = MvNormalMeanCovariance(zeros(dx * dy), diageye(dx * dy))
q_W = Wishart(dy + 1, diageye(dy))

@test score(AverageEnergy(), ContinuousTransition, Val{(:y_x, :a, :W)}(), marginals, meta) ≈ 13.0 atol = 1e-2
@show getjacobians(meta, mean(q_a))
marginals_st = (Marginal(q_y_x, false, false, nothing), Marginal(q_a, false, false, nothing), Marginal(q_W, false, false, nothing))
marginals_mf = (Marginal(q_y, false, false, nothing), Marginal(q_x, false, false, nothing), Marginal(q_a, false, false, nothing), Marginal(q_W, false, false, nothing))

# 12,992 is a result of manual calculation
@test score(AverageEnergy(), ContinuousTransition, Val{(:y_x, :a, :W)}(), marginals_st, meta) ≈ 12.992 atol = 1e-2
# 12,07336 is a result of manual calculation
@test score(AverageEnergy(), ContinuousTransition, Val{(:y, :x, :a, :W)}(), marginals_mf, meta) ≈ 12.07736 atol = 1e-2
end

@testset "ContinuousTransition Functionality" begin
Expand Down
39 changes: 37 additions & 2 deletions test/rules/continuous_transition/W_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
@testset "Linear transformation" begin
# the following rule is used for testing purposes only
# It is derived separately by Thijs van de Laar
function benchmark_rule(q_y_x, mA, ΣA, UA)
function benchmark_rule_structured(q_y_x, mA, ΣA, UA)
myx, Vyx = mean_cov(q_y_x)

dy = size(mA, 1)
Expand Down Expand Up @@ -39,7 +39,7 @@
qa = MvNormalMeanCovariance(vec(mA), kron(UA, ΣA))

@test_rules [check_type_promotion = true, atol = 1e-5] ContinuousTransition(:W, Marginalisation) [(
input = (q_y_x = qyx, q_a = qa, meta = metal), output = benchmark_rule(qyx, mA, ΣA, UA)
input = (q_y_x = qyx, q_a = qa, meta = metal), output = benchmark_rule_structured(qyx, mA, ΣA, UA)
)]
end
end
Expand All @@ -62,4 +62,39 @@
)]
end
end

# the following rule is used for testing purposes only
# It is derived separately by Thijs van de Laar
function benchmark_rule_meanfield(q_y, q_x, mA, ΣA, UA)
my, Vy = mean_cov(q_y)
mx, Vx = mean_cov(q_x)

dy = size(mA, 1)

G = tr(Vx * UA) * ΣA + mA * Vx * mA' + Vy + ΣA * mx' * UA * mx + (mA * mx - my) * (mA * mx - my)'

return WishartFast(dy + 2, G)
end

@testset "Mean-field: (q_y::Any, q_x::Any, q_a::Any, meta::CTMeta)" begin
for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)]
dydx = dy * dx
transformation = (a) -> reshape(a, dy, dx)
mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx)

metal = CTMeta(transformation)
Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy)
μx, Σx = rand(rng, dx), Lx * Lx'
μy, Σy = rand(rng, dy), Ly * Ly'

qy = MvNormalMeanCovariance(μy, Σy)
qx = MvNormalMeanCovariance(μx, Σx)

qa = MvNormalMeanCovariance(vec(mA), kron(UA, ΣA))

@test_rules [check_type_promotion = true, atol = 1e-5] ContinuousTransition(:W, Marginalisation) [(
input = (q_y = qy, q_x = qx, q_a = qa, meta = metal), output = benchmark_rule_meanfield(qy, qx, mA, ΣA, UA)
)]
end
end
end
38 changes: 36 additions & 2 deletions test/rules/continuous_transition/a_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

# the following rule is used for testing purposes only
# It is derived separately by Thijs van de Laar
function benchmark_rule(q_y_x, q_W)
function benchmark_rule_structured(q_y_x, q_W)
myx, Vyx = mean_cov(q_y_x)
dy = size(q_W.S, 1)
Vx = Vyx[(dy + 1):end, (dy + 1):end]
Expand All @@ -36,7 +36,7 @@
qa = MvNormalMeanCovariance(a0, diageye(dydx))
qW = Wishart(dy + 1, diageye(dy))
@test_rules [check_type_promotion = false] ContinuousTransition(:a, Marginalisation) [(
input = (q_y_x = qyx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule(qyx, qW)
input = (q_y_x = qyx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule_structured(qyx, qW)
)]
end
end
Expand All @@ -60,4 +60,38 @@
)]
end
end

# the following rule is used for testing purposes only
# It is derived separately by Thijs van de Laar
# NOTE: this test rule does not allow q_x as a PointMass as it involves the covariance matrix of q_x
function benchmark_rule_meanfield(q_y, q_x, q_W)
my = mean(q_y)
mx, Vx = mean_cov(q_x)
mW = mean(q_W)
Λ = kron(Vx + mx * mx', mW)
return MvNormalWeightedMeanPrecision(Λ * (vec(my * mx' * inv(Vx + mx * mx'))), Λ)
end

@testset "Mean-field: (q_y::Any, q_x::Any, q_a::Any, q_W::Any, meta::CTMeta)" begin
for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)]
dydx = dy * dx
transformation = (a) -> reshape(a, dy, dx)
a0 = rand(Float32, dydx)
metal = CTMeta(transformation)
Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy)
μx, Σx = rand(rng, dx), Lx * Lx'
μy, Σy = rand(rng, dy), Ly * Ly'
qy = MvNormalMeanCovariance(μy, Σy)
qx = MvNormalMeanCovariance(μx, Σx)
qa = MvNormalMeanCovariance(a0, diageye(dydx))
qW = Wishart(dy + 1, diageye(dy))
@test_rules [check_type_promotion = false] ContinuousTransition(:a, Marginalisation) [(
input = (q_y = qy, q_x = qx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule_meanfield(qy, qx, qW)
)]

@test_rules [check_type_promotion = false] ContinuousTransition(:a, Marginalisation) [(
input = (q_y = PointMass(μy), q_x = qx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule_meanfield(PointMass(μy), qx, qW)
)]
end
end
end
37 changes: 34 additions & 3 deletions test/rules/continuous_transition/x_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
@testset "Linear transformation" begin
# the following rule is used for testing purposes only
# It is derived separately by Thijs van de Laar
function benchmark_rule(q_y, q_W, mA, ΣA, UA)
function benchmark_rule_strucutred(q_y, q_W, mA, ΣA, UA)
my, Vy = mean_cov(q_y)

mW = mean(q_W)
Expand All @@ -27,15 +27,15 @@
mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx)

metal = CTMeta(transformation)
Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy)
Ly = rand(rng, dy, dy)
μy, Σy = rand(rng, dy), Ly * Ly'

qy = MvNormalMeanCovariance(μy, Σy)
qa = MvNormalMeanCovariance(vec(mA), diageye(dydx))
qW = Wishart(dy + 1, diageye(dy))

@test_rules [check_type_promotion = true, atol = 1e-4] ContinuousTransition(:x, Marginalisation) [(
input = (m_y = qy, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule(qy, qW, mA, ΣA, UA)
input = (m_y = qy, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule_strucutred(qy, qW, mA, ΣA, UA)
)]
end
end
Expand All @@ -59,4 +59,35 @@
)]
end
end

# the following rule is used for testing purposes only
# It is derived separately by Thijs van de Laar
function benchmark_rule_meanfield(q_y, q_W, mA, ΣA, UA)
mW = mean(q_W)

Λ = mA'mW * mA + tr(mW * ΣA) * UA
ξ = mA' * mW * mean(q_y)
return MvNormalWeightedMeanPrecision(ξ, Λ)
end

@testset "Mean-field: (q_y::Any, q_a::Any, q_W::Any, meta::CTMeta)" begin
for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)]
dydx = dy * dx
transformation = (a) -> reshape(a, dy, dx)

mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx)

metal = CTMeta(transformation)
Ly = rand(rng, dy, dy)
μy, Σy = rand(rng, dy), Ly * Ly'

qy = MvNormalMeanCovariance(μy, Σy)
qa = MvNormalMeanCovariance(vec(mA), diageye(dydx))
qW = Wishart(dy + 1, diageye(dy))

@test_rules [check_type_promotion = true, atol = 1e-4] ContinuousTransition(:x, Marginalisation) [(
input = (q_y = qy, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule_meanfield(qy, qW, mA, ΣA, UA)
)]
end
end
end
Loading
Loading