Skip to content

Commit

Permalink
Add structured VMP rules for softdot node
Browse files Browse the repository at this point in the history
  • Loading branch information
albertpod committed Jun 25, 2024
1 parent 881231e commit 060eeaf
Show file tree
Hide file tree
Showing 19 changed files with 291 additions and 5 deletions.
18 changes: 18 additions & 0 deletions src/nodes/predefined/softdot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,21 @@ const softdot = SoftDot
m_γ = mean(q_γ)
return (-mean(log, q_γ) + log2π + m_γ * (V_y + m_y^2 - 2m_γ * m_y * m_θ'm_x + mul_trace(V_θ, V_x) + m_x'V_θ * m_x + m_θ' * (V_x + m_x * m_x') * m_θ)) / 2
end

@average_energy softdot (q_y_x::MultivariateNormalDistributionsFamily, q_θ::NormalDistributionsFamily, q_γ::GammaShapeRate) = begin
mθ, Vθ = mean_cov(q_θ)
myx, Vyx = mean_cov(q_y_x)
= mean(q_γ)

order = length(mθ)
F = order == 1 ? Univariate : Multivariate

mx, Vx = ar_slice(F, myx, (order + 1):(2order)), ar_slice(F, Vyx, (order + 1):(2order), (order + 1):(2order))
my1, Vy1 = first(myx), first(Vyx)
Vy1x = ar_slice(F, Vyx, 1, (order + 1):(2order))

# Equivalent to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2
AE = (-mean(log, q_γ) + log2π +* (Vy1 + my1^2 - 2 *' * (Vy1x + mx * my1) + mul_trace(Vθ, Vx) + dot(mx, Vθ, mx) + dot(mθ, Vx, mθ) + abs2(dot(mθ, mx)))) / 2

return AE
end
4 changes: 4 additions & 0 deletions src/rules/dot_product/in1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@
@rule typeof(dot)(:in1, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in2::PointMass, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
return @call_rule typeof(dot)(:in2, Marginalisation) (m_out = m_out, m_in1 = m_in2, meta = meta)
end

@rule typeof(dot)(:in1, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in2::NormalDistributionsFamily, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
return error("The rule for the dot product node between two NormalDistributionsFamily instances is not available in closed form. Please use SoftDot instead.")
end
4 changes: 4 additions & 0 deletions src/rules/dot_product/in2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@

return convert(promote_variate_type(variate_form(typeof(m_in1)), NormalWeightedMeanPrecision), ξ, W)
end

@rule typeof(dot)(:in2, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in1::NormalDistributionsFamily, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
return error("The rule for the dot product node between two NormalDistributionsFamily instances is not available in closed form. Please use SoftDot instead.")
end
4 changes: 4 additions & 0 deletions src/rules/dot_product/out.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ end
in2_mean, in2_cov = mean_cov(m_in2)
return NormalMeanVariance(dot(A, in2_mean), dot(A, in2_cov, A))
end

@rule typeof(dot)(:out, Marginalisation) (m_in1::NormalDistributionsFamily, m_in2::NormalDistributionsFamily, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
return error("The rule for the dot product node between two NormalDistributionsFamily instances is not available in closed form. Please use SoftDot instead.")
end
11 changes: 6 additions & 5 deletions src/rules/predefined.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,6 @@ include("dot_product/out.jl")
include("dot_product/in1.jl")
include("dot_product/in2.jl")

include("softdot/y.jl")
include("softdot/x.jl")
include("softdot/theta.jl")
include("softdot/gamma.jl")

include("transition/marginals.jl")
include("transition/out.jl")
include("transition/in.jl")
Expand All @@ -123,6 +118,12 @@ include("autoregressive/theta.jl")
include("autoregressive/gamma.jl")
include("autoregressive/marginals.jl")

include("softdot/y.jl")
include("softdot/x.jl")
include("softdot/theta.jl")
include("softdot/gamma.jl")
include("softdot/marginals.jl")

include("probit/marginals.jl")
include("probit/in.jl")
include("probit/out.jl")
Expand Down
22 changes: 22 additions & 0 deletions src/rules/softdot/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,25 @@
β += (mul_trace(Vx, Vθ) +'Vx *+ mx'* mx +'mx * mx'mθ) / 2
return GammaShapeRate(α, β)
end

# Variational MP: Structured
@rule softdot(, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_θ::Any) = begin
# q_y is always Univariate
order = length(q_y_x) - 1
F = order == 1 ? Univariate : Multivariate

y_x_mean, y_x_cov = mean_cov(q_y_x)
mθ, Vθ = mean_cov(q_θ)

my, Vy = first(y_x_mean), first(y_x_cov)
mx, Vx = ar_slice(F, y_x_mean, 2:(order + 1)), ar_slice(F, y_x_cov, 2:(order + 1), 2:(order + 1))
Vyx = ar_slice(F, y_x_cov, 2:(order + 1))

C = rank1update(Vx, mx)
R = rank1update(Vy, my)
L = Vyx + mx * my

B = first(R) - 2 * first(mθ' * L) + first(mθ' * C * mθ) + mul_trace(Vθ, C)

return GammaShapeRate(convert(eltype(B), 3//2), B / 2)
end
27 changes: 27 additions & 0 deletions src/rules/softdot/marginals.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

# The following marginal rule is adaptation of the marginal rule for Autoregressive node
@marginalrule SoftDot(:y_x) (m_y::NormalDistributionsFamily, m_x::NormalDistributionsFamily, q_θ::NormalDistributionsFamily, q_γ::Any) = begin
mθ, Vθ = mean_cov(q_θ)
= mean(q_γ)

b_my, b_Vy = mean_cov(m_y)
f_mx, f_Vx = mean_cov(m_x)

inv_b_Vy = cholinv(b_Vy)
inv_f_Vx = cholinv(f_Vx)

D = inv_f_Vx +*

W_11 = inv_b_Vy +

W_12 = -*'

W_21 = -*

W_22 = D +**'

W = [W_11 W_12; W_21 W_22]
ξ = [inv_b_Vy * b_my; inv_f_Vx * f_mx]

return MvNormalWeightedMeanPrecision(ξ, W)
end
20 changes: 20 additions & 0 deletions src/rules/softdot/theta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,23 @@
=* mx * my
return convert(promote_variate_type(variate_form(typeof(q_x)), NormalWeightedMeanPrecision), zθ, Dθ)
end

# Variational MP: Structured
@rule softdot(, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_γ::Any) = begin
# q_y is always Univariate
order = length(q_y_x) - 1
F = order == 1 ? Univariate : Multivariate

myx, Vyx = mean_cov(q_y_x)
my, Vy = first(myx), first(Vyx)
mx, Vx = ar_slice(F, myx, 2:(order + 1)), ar_slice(F, Vyx, 2:(order + 1), 2:(order + 1))
Vyx = ar_slice(F, Vyx, 2:(order + 1))

= mean(q_γ)

W =* (Vx + mx * mx')

ξ = (Vyx + mx * my') *

return convert(promote_variate_type(F, NormalWeightedMeanPrecision), ξ, W)
end
18 changes: 18 additions & 0 deletions src/rules/softdot/x.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,21 @@
zx =** my
return convert(promote_variate_type(variate_form(typeof(q_θ)), NormalWeightedMeanPrecision), zx, Dx)
end

# Variational MP: Structured
@rule softdot(:x, Marginalisation) (m_y::UnivariateNormalDistributionsFamily, q_θ::Any, q_γ::Any) = begin
# the naive call of AR rule is not possible, because the softdot rule expects m_y to be a UnivariateNormalDistributionsFamily
mθ, Vθ = mean_cov(q_θ)
my, Vy = mean_cov(m_y)

= mean(q_γ)

mV = inv(mγ)

C =* inv(add_transition(Vy, mV))

W = C *' +*
ξ = C * my

return convert(promote_variate_type(variate_form(typeof(q_θ)), NormalWeightedMeanPrecision), ξ, W)
end
4 changes: 4 additions & 0 deletions src/rules/softdot/y.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@

# Variational MP: Mean-field
@rule softdot(:y, Marginalisation) (q_θ::Any, q_x::Any, q_γ::Any) = NormalMeanPrecision(mean(q_θ)'mean(q_x), mean(q_γ))

@rule softdot(:y, Marginalisation) (q_θ::Any, m_x::Any, q_γ::Any) = NormalMeanVariance(
first.(mean_cov((@call_rule AR(:y, Marginalisation) (m_x = m_x, q_θ = q_θ, q_γ = q_γ, meta = ARMeta(variate_form(typeof(m_x)), length(q_θ), ARsafe())))))...
)
9 changes: 9 additions & 0 deletions test/nodes/predefined/softdot_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,14 @@

@test score(AverageEnergy(), SoftDot, Val{(:y, :θ, :x, :γ)}(), marginals, nothing) 8.15193210352257
end

begin
q_y_x = MvNormalMeanCovariance(zeros(2), diageye(2))
q_θ = NormalMeanVariance(0.0, 1.0)
q_γ = GammaShapeRate(2.0, 3.0)

marginals = (Marginal(q_y_x, false, false, nothing), Marginal(q_θ, false, false, nothing), Marginal(q_γ, false, false, nothing))
@test score(AverageEnergy(), SoftDot, Val{(:y_x, :θ, :γ)}(), marginals, nothing) 1.92351917665616
end
end # testset: AverageEnergy
end # testset
8 changes: 8 additions & 0 deletions test/rules/dot_product/in1_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,12 @@
)
]
end

@testset "Error Belief Propagation: (m_out::UnivariateNormalDistributionsFamily, m_in2::NormalDistributionsFamily)" begin
@test_throws r"The rule for the dot product node between two NormalDistributionsFamily instances is not available in closed form. Please use SoftDot instead." @call_rule typeof(
dot
)(
:in1, Marginalisation
) (m_out = NormalMeanVariance(2.0, 2.0), m_in2 = MvNormalMeanCovariance([-1.0, 1.0], [2.0 -1.0; -1.0 4.0]), meta = NoCorrection())
end
end
8 changes: 8 additions & 0 deletions test/rules/dot_product/in2_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,12 @@
)
]
end

@testset "Error Belief Propagation: (m_out::UnivariateNormalDistributionsFamily, m_in1::NormalDistributionsFamily)" begin
@test_throws r"The rule for the dot product node between two NormalDistributionsFamily instances is not available in closed form. Please use SoftDot instead." @call_rule typeof(
dot
)(
:in2, Marginalisation
) (m_out = NormalMeanVariance(2.0, 2.0), m_in1 = MvNormalMeanCovariance([-1.0, 1.0], [2.0 -1.0; -1.0 4.0]), meta = NoCorrection())
end
end
8 changes: 8 additions & 0 deletions test/rules/dot_product/out_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,12 @@
)
]
end

@testset "Error belief Propagation: (m_in1::NormalDistributionsFamily, m_in2::NormalDistributionsFamily)" begin
@test_throws r"The rule for the dot product node between two NormalDistributionsFamily instances is not available in closed form. Please use SoftDot instead." @call_rule typeof(
dot
)(
:out, Marginalisation
) (m_in1 = NormalMeanVariance(2.0, 2.0), m_in2 = NormalMeanVariance(2.0, 2.0), meta = NoCorrection())
end
end
21 changes: 21 additions & 0 deletions test/rules/softdot/gamma_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,25 @@
)
end
end # testset: mean-field

@testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_γ::Any)" begin
@test_rules [check_type_promotion = true] SoftDot(, Marginalisation) [
(input = (q_y_x = MvNormalMeanCovariance(ones(2), diageye(2)), q_θ = NormalMeanPrecision(1.0, 1.0)), output = GammaShapeRate(3 / 2, 2.0)),
(input = (q_y_x = MvNormalMeanCovariance(2 * ones(2), diageye(2)), q_θ = NormalMeanPrecision(2.0, 1.0)), output = GammaShapeRate(3 / 2, 7.0))
]
end

@testset "Structured : (q_y_x::MultivariateNormalDistributionsFamily, q_θ::Any)" begin
order = 2
@test_rules [check_type_promotion = true] SoftDot(, Marginalisation) [
(
input = (q_y_x = MvNormalMeanCovariance(ones(order + 1), diageye(order + 1)), q_θ = MvNormalMeanPrecision(ones(order), diageye(order))),
output = GammaShapeRate(3 / 2, 4.0)
),
(
input = (q_y_x = MvNormalMeanCovariance(ones(order + 1), diageye(order + 1)), q_θ = MvNormalMeanPrecision(zeros(order), diageye(order))),
output = GammaShapeRate(3 / 2, 3.0)
)
]
end
end # testset
26 changes: 26 additions & 0 deletions test/rules/softdot/test_marginals.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

@testitem "marginalrules:SoftDot" begin
using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions

import ReactiveMP: @test_marginalrules

@testset "y_x: (m_y::UnivariateNormalDistributionsFamily, m_x::UnivariateNormalDistributionsFamily, q_θ::UnivariateNormalDistributionsFamily, q_γ::Any)" begin
@test_marginalrules [check_type_promotion = true] SoftDot(:y_x) [(
input = (m_y = NormalMeanPrecision(0.0, 1.0), m_x = NormalMeanPrecision(0.0, 1.0), q_θ = NormalMeanPrecision(1.0, 1.0), q_γ = GammaShapeRate(1.0, 1.0)),
output = MvNormalWeightedMeanPrecision(zeros(2), [2.0 -1.0; -1.0 3.0])
)]
end

@testset "y_x: (m_y::UnivariateNormalDistributionsFamily), m_x::MultivariateNormalDistributionsFamily, q_θ::MultivariateNormalDistributionsFamily, q_γ::Any)" begin
order = 2
@test_marginalrules [check_type_promotion = true] SoftDot(:y_x) [(
input = (
m_y = NormalMeanPrecision(1.0, 1.0),
m_x = MvNormalMeanCovariance(ones(order), diageye(order)),
q_θ = MvNormalMeanCovariance(ones(order), diageye(order)),
q_γ = GammaShapeRate(1.0, 1.0)
),
output = MvNormalWeightedMeanPrecision(ones(3), [2.0 -1.0 -1.0; -1.0 3.0 1.0; -1.0 1.0 3.0])
)]
end
end
21 changes: 21 additions & 0 deletions test/rules/softdot/theta_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,25 @@
end
# NOTE: γ can theoretically be Any, so also NormalMeanVariance
end

@testset "Structured: (q_y_x::MultivariateNormalDistributionsFamily, q_γ::Any)" begin
@test_rules [check_type_promotion = true] SoftDot(, Marginalisation) [
(input = (q_y_x = MvNormalMeanCovariance(ones(2), diageye(2)), q_γ = GammaShapeRate(1.0, 1.0)), output = NormalWeightedMeanPrecision(1.0, 2.0)),
(input = (q_y_x = MvNormalMeanCovariance(2 * ones(2), diageye(2)), q_γ = GammaShapeScale(2.0, 1.0)), output = NormalWeightedMeanPrecision(8.0, 10.0))
]
end

@testset "Structured : (q_y_x::MultivariateNormalDistributionsFamily, q_γ::Any)" begin
order = 2
@test_rules [check_type_promotion = true] SoftDot(, Marginalisation) [
(
input = (q_y_x = MvNormalMeanCovariance(ones(order + 1), diageye(order + 1)), q_γ = GammaShapeRate(1.0, 1.0)),
output = MvNormalWeightedMeanPrecision(ones(order), [2.0 1.0; 1.0 2.0])
),
(
input = (q_y_x = MvNormalMeanCovariance(zeros(order + 1), diageye(order + 1)), q_γ = GammaShapeRate(1.0, 1.0)),
output = MvNormalWeightedMeanPrecision(zeros(order), [1.0 0.0; 0.0 1.0])
)
]
end
end # testset
30 changes: 30 additions & 0 deletions test/rules/softdot/x_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,34 @@
end
# NOTE: γ can theoretically be Any, so also NormalMeanVariance
end

@testset "VMP: structured rules" begin
@testset "(m_y::NormalMeanVariance, q_θ::NormalMeanVariance, q_γ::Any)" begin
@test_rules [check_type_promotion = true] SoftDot(:x, Marginalisation) [
(input = (m_y = NormalMeanVariance(1.0, 1.0), q_θ = NormalMeanVariance(1.0, 1.0), q_γ = GammaShapeRate(1.0, 1.0)), output = NormalWeightedMeanPrecision(0.5, 1.5)),
(
input = (m_y = NormalWeightedMeanPrecision(1.0, 1.0), q_θ = NormalMeanPrecision(1.0, 2.0), q_γ = GammaShapeScale(1.0, 1.0)),
output = NormalWeightedMeanPrecision(0.5, 1.0)
)
]
end

@testset "(m_y::UnivariateNormalDistributionsFamily, q_θ::MultivariateNormalDistributionsFamily, q_γ::Any)" begin
order = 2
@test_rules [check_type_promotion = false] SoftDot(:x, Marginalisation) [
(
input = (m_y = NormalMeanVariance(0.0, 1.0), q_θ = MvNormalMeanCovariance(ones(order), diageye(order)), q_γ = GammaShapeRate(1.0, 1.0)),
output = MvNormalWeightedMeanPrecision([0.0, 0.0], [1.5 0.5; 0.5 1.5])
),
(
input = (m_y = NormalMeanVariance(1.0, 1.0), q_θ = MvNormalMeanCovariance(zeros(order), diageye(order)), q_γ = GammaShapeScale(1.0, 1.0)),
output = MvNormalWeightedMeanPrecision([0.0, 0.0], [1.0 0.0; 0.0 1.0])
),
(
input = (m_y = NormalMeanVariance(1.0, 1.0), q_θ = MvNormalMeanCovariance(ones(order), diageye(order)), q_γ = Gamma(1.0, 1.0)),
output = MvNormalWeightedMeanPrecision([0.5, 0.5], [1.5 0.5; 0.5 1.5])
)
]
end
end
end # testset
33 changes: 33 additions & 0 deletions test/rules/softdot/y_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,37 @@
)
end
end

@testset "VMP: structured rules" begin
@testset "(q_θ::NormalMeanVariance, m_x::NormalMeanVariance, q_γ::Any" begin
@test_rules [check_type_promotion = true] SoftDot(:y, Marginalisation) [
(input = (q_θ = PointMass(3.0), q_x = PointMass(11.0), q_γ = GammaShapeRate(7.0, 5.0)), output = NormalMeanPrecision(33.0, 1.4)),
(input = (q_θ = PointMass(3.0), q_x = PointMass(11.0), q_γ = GammaShapeScale(7.0, 5.0)), output = NormalMeanPrecision(33.0, 35.0))
]

@test_rules [check_type_promotion = true] SoftDot(:y, Marginalisation) [
(input = (m_x = NormalMeanVariance(1.0, 1.0), q_θ = NormalMeanVariance(1.0, 1.0), q_γ = GammaShapeRate(1.0, 1.0)), output = NormalMeanVariance(0.5, 1.5)),
(
input = (m_x = NormalWeightedMeanPrecision(1.0, 1.0), q_θ = NormalMeanPrecision(1.0, 2.0), q_γ = GammaShapeScale(2.0, 1.0)),
output = NormalMeanVariance(0.5, 1.0)
)
]
end

@testset "(q_θ::MvNormalMeanCovariance, m_x::MvNormalMeanCovariance, q_γ::Any" begin
order = 2
@test_rules [check_type_promotion = true] SoftDot(:y, Marginalisation) [
(
input = (
m_x = MvNormalMeanCovariance(ones(order), diageye(order)), q_θ = MvNormalMeanCovariance(zeros(order), diageye(order)), q_γ = GammaShapeScale(1.0, 1.0)
),
output = NormalMeanVariance(0.0, 1.0)
),
(
input = (m_x = MvNormalMeanCovariance(ones(order), diageye(order)), q_θ = MvNormalMeanCovariance(ones(order), diageye(order)), q_γ = Gamma(1.0, 1.0)),
output = NormalMeanVariance(1.0, 2.0)
)
]
end
end
end # testset

0 comments on commit 060eeaf

Please sign in to comment.