Skip to content

Commit

Permalink
Merge pull request #351 from biaslab/dev-matrix-correction-tools
Browse files Browse the repository at this point in the history
Use MatrixCorrectionTools.jl
  • Loading branch information
bvdmitri authored Sep 26, 2023
2 parents db592a4 + cc70acd commit fcfc9c1
Show file tree
Hide file tree
Showing 23 changed files with 203 additions and 303 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ReactiveMP"
uuid = "a194aa59-28ba-4574-a09c-4a745416d6e3"
authors = ["Dmitry Bagaev <d.v.bagaev@tue.nl>", "Albert Podusenko <a.podusenko@tue.nl>", "Bart van Erp <b.v.erp@tue.nl>", "Ismail Senoz <i.senoz@tue.nl>"]
version = "3.10.0"
version = "3.11.0"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -15,6 +15,7 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MatrixCorrectionTools = "41f81499-25de-46de-b591-c3cfc21e9eaf"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down Expand Up @@ -49,6 +50,7 @@ HCubature = "1.0.0"
LazyArrays = "0.21, 0.22, 1"
LoopVectorization = "0.12"
MacroTools = "0.5"
MatrixCorrectionTools = "1.2.0"
Optim = "1.0.0"
Optimisers = "0.2"
PositiveFactorizations = "0.2"
Expand Down
5 changes: 3 additions & 2 deletions src/ReactiveMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
module ReactiveMP

# List global dependencies here
using TinyHugeNumbers
using TinyHugeNumbers, MatrixCorrectionTools

import MatrixCorrectionTools: AbstractCorrectionStrategy, correction!

# Reexport `tiny` and `huge` from the `TinyHugeNumbers`
export tiny, huge
Expand All @@ -16,7 +18,6 @@ include("score/counting.jl")

include("helpers/algebra/cholesky.jl")
include("helpers/algebra/companion_matrix.jl")
include("helpers/algebra/correction.jl")
include("helpers/algebra/common.jl")
include("helpers/algebra/permutation_matrix.jl")
include("helpers/algebra/standard_basis_vector.jl")
Expand Down
44 changes: 24 additions & 20 deletions src/constraints/specifications/meta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,30 +65,34 @@ See also: [`ConstraintsSpecification`](@ref)
function resolve_meta(metaspec, fform, variables)
symfform = as_node_symbol(fform)

var_names = map(name, TupleTools.flatten(variables))
var_refs = map(resolve_variable_proxy, TupleTools.flatten(variables))
var_refs_names = map(r -> r[1], var_refs)

found = nothing

unrolled_foreach(getentries(metaspec)) do fentry
# We iterate over all entries in the meta specification
if functionalform(fentry) === symfform && (all(s -> s var_names, getnames(fentry)) || all(s -> s var_refs_names, getnames(fentry)))
if isnothing(found)
# if we find an appropriate meta spec we simply set it
found = fentry
elseif !isnothing(found) && issubset(getnames(fentry), getnames(found)) && issubset(getnames(found), getnames(fentry))
# The error case is the meta specification collision, two sets of names are exactly the same
error("Ambigous meta object resolution for the node $(fform). Check $(found) and $(fentry).")
elseif !isnothing(found) && issubset(getnames(fentry), getnames(found))
# If we find another matching meta spec, but it has fewer names in it we simply keep the previous one
nothing
elseif !isnothing(found) && issubset(getnames(found), getnames(fentry))
# If we find another matching meta spec, and it has more names we override the previous one
found = fentry
elseif !isnothing(found) && !issubset(getnames(fentry), getnames(found)) && !issubset(getnames(found), getnames(fentry))
# The error case is the meta specification collision, two sets of names are different and do not include each other
error("Ambigous meta object resolution for the node $(fform). Check $(found) and $(fentry).")
if functionalform(fentry) === symfform
# The `var_names` & `var_refs_names` should be done only if we hit the required entry
# otherwise it would be too error prone, because many nodes cannot properly resolve their `var_names` (e.g. deterministic nodes with more than one input)
# but there might be no meta specification for such nodes, currently the algorithm recompute those for each hit, this can probably be improved
local var_names = map(name, TupleTools.flatten(variables))
local var_refs = map(resolve_variable_proxy, TupleTools.flatten(variables))
local var_refs_names = map(r -> r[1], var_refs)
if (all(s -> s var_names, getnames(fentry)) || all(s -> s var_refs_names, getnames(fentry)))
if isnothing(found)
# if we find an appropriate meta spec we simply set it
found = fentry
elseif !isnothing(found) && issubset(getnames(fentry), getnames(found)) && issubset(getnames(found), getnames(fentry))
# The error case is the meta specification collision, two sets of names are exactly the same
error("Ambigous meta object resolution for the node $(fform). Check $(found) and $(fentry).")
elseif !isnothing(found) && issubset(getnames(fentry), getnames(found))
# If we find another matching meta spec, but it has fewer names in it we simply keep the previous one
nothing
elseif !isnothing(found) && issubset(getnames(found), getnames(fentry))
# If we find another matching meta spec, and it has more names we override the previous one
found = fentry
elseif !isnothing(found) && !issubset(getnames(fentry), getnames(found)) && !issubset(getnames(found), getnames(fentry))
# The error case is the meta specification collision, two sets of names are different and do not include each other
error("Ambigous meta object resolution for the node $(fform). Check $(found) and $(fentry).")
end
end
end
end
Expand Down
102 changes: 0 additions & 102 deletions src/helpers/algebra/correction.jl

This file was deleted.

4 changes: 2 additions & 2 deletions src/nodes/dot_product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ import LinearAlgebra: dot

@node typeof(dot) Deterministic [out, in1, in2]

# By default dot-product node uses TinyCorrection() strategy for precision matrix on `in1` and `in2` edges to ensure precision is always invertible
default_meta(::typeof(dot)) = TinyCorrection()
# By default dot-product node uses `MatrixCorrectionTools.ReplaceZeroDiagonalEntries(tiny)` strategy for precision matrix on `in1` and `in2` edges to ensure precision is always invertible
default_meta(::typeof(dot)) = MatrixCorrectionTools.ReplaceZeroDiagonalEntries(tiny)
4 changes: 2 additions & 2 deletions src/nodes/multiplication.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

@node typeof(*) Deterministic [out, A, in]

# By default multiplication node uses TinyCorrection() strategy for precision matrix on `in` edge to ensure precision is always invertible
default_meta(::typeof(*)) = TinyCorrection()
# By default multiplication node uses `MatrixCorrectionTools.ReplaceZeroDiagonalEntries(tiny)` strategy for precision matrix on `in` edge to ensure precision is always invertible
default_meta(::typeof(*)) = MatrixCorrectionTools.ReplaceZeroDiagonalEntries(tiny)
2 changes: 1 addition & 1 deletion src/rules/dot_product/in1.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

@rule typeof(dot)(:in1, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in2::PointMass, meta::AbstractCorrection) = begin
@rule typeof(dot)(:in1, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in2::PointMass, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
return @call_rule typeof(dot)(:in2, Marginalisation) (m_out = m_out, m_in1 = m_in2, meta = meta)
end
2 changes: 1 addition & 1 deletion src/rules/dot_product/in2.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

@rule typeof(dot)(:in2, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in1::PointMass, meta::AbstractCorrection) = begin
@rule typeof(dot)(:in2, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_in1::PointMass, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
A = mean(m_in1)
out_wmean, out_prec = weightedmean_precision(m_out)

Expand Down
4 changes: 2 additions & 2 deletions src/rules/dot_product/marginals.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

@marginalrule typeof(dot)(:in1_in2) (m_out::NormalDistributionsFamily, m_in1::PointMass, m_in2::NormalDistributionsFamily, meta::AbstractCorrection) = begin
@marginalrule typeof(dot)(:in1_in2) (m_out::NormalDistributionsFamily, m_in1::PointMass, m_in2::NormalDistributionsFamily, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin

# Forward message towards `in2` edge
mf_in2 = @call_rule typeof(dot)(:in2, Marginalisation) (m_out = m_out, m_in1 = m_in1, meta = meta)
Expand All @@ -8,7 +8,7 @@
return convert_paramfloattype((in1 = m_in1, in2 = q_in2))
end

@marginalrule typeof(dot)(:in1_in2) (m_out::NormalDistributionsFamily, m_in1::NormalDistributionsFamily, m_in2::PointMass, meta::AbstractCorrection) = begin
@marginalrule typeof(dot)(:in1_in2) (m_out::NormalDistributionsFamily, m_in1::NormalDistributionsFamily, m_in2::PointMass, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
symmetric = @call_marginalrule typeof(dot)(:in1_in2) (m_out = m_out, m_in1 = m_in2, m_in2 = m_in1, meta = meta)
return convert_paramfloattype((in1 = symmetric[:in2], in2 = symmetric[:in1]))
end
4 changes: 2 additions & 2 deletions src/rules/dot_product/out.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@

@rule typeof(dot)(:out, Marginalisation) (m_in1::NormalDistributionsFamily, m_in2::PointMass, meta::AbstractCorrection) = begin
@rule typeof(dot)(:out, Marginalisation) (m_in1::NormalDistributionsFamily, m_in2::PointMass, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
return @call_rule typeof(dot)(:out, Marginalisation) (m_in1 = m_in2, m_in2 = m_in1, meta = meta)
end

@rule typeof(dot)(:out, Marginalisation) (m_in1::PointMass, m_in2::NormalDistributionsFamily, meta::AbstractCorrection) = begin
@rule typeof(dot)(:out, Marginalisation) (m_in1::PointMass, m_in2::NormalDistributionsFamily, meta::Union{AbstractCorrectionStrategy, Nothing}) = begin
A = mean(m_in1)
in2_mean, in2_cov = mean_cov(m_in2)
return NormalMeanVariance(dot(A, in2_mean), dot(A, in2_cov, A))
Expand Down
20 changes: 11 additions & 9 deletions src/rules/multiplication/A.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@

@rule typeof(*)(:A, Marginalisation) (m_out::PointMass, m_in::PointMass, meta::Union{<:AbstractCorrection, Nothing}) = PointMass(mean(m_in) \ mean(m_out))
@rule typeof(*)(:A, Marginalisation) (m_out::PointMass, m_in::PointMass, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = PointMass(mean(m_in) \ mean(m_out))

@rule typeof(*)(:A, Marginalisation) (m_out::GammaDistributionsFamily, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrection, Nothing}) = begin
@rule typeof(*)(:A, Marginalisation) (m_out::GammaDistributionsFamily, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
return GammaShapeRate(shape(m_out), rate(m_out) * mean(m_in))
end

# if A is a matrix, then the result is multivariate
@rule typeof(*)(:A, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrection, Nothing}) = begin
@rule typeof(*)(:A, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
A = mean(m_in)
ξ_out, W_out = weightedmean_precision(m_out)
W = correction!(meta, A' * W_out * A)
Expand All @@ -15,23 +15,23 @@ end

# if A is a vector, then the result is univariate
# this rule links to the special case (AbstractVector * Univariate) for forward (:out) rule
@rule typeof(*)(:A, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrection, Nothing}) = begin
@rule typeof(*)(:A, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
A = mean(m_in)
ξ_out, W_out = weightedmean_precision(m_out)
W = correction!(meta, dot(A, W_out, A))
return NormalWeightedMeanPrecision(dot(A, ξ_out), W)
end

# if A is a scalar, then the input is either univariate or multivariate
@rule typeof(*)(:A, Marginalisation) (m_out::F, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrection, Nothing}) where {F <: NormalDistributionsFamily} = begin
@rule typeof(*)(:A, Marginalisation) (m_out::F, m_in::PointMass{<:Real}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) where {F <: NormalDistributionsFamily} = begin
A = mean(m_in)
ξ_out, W_out = weightedmean_precision(m_out)
W = correction!(meta, A^2 * W_out)
return convert(promote_variate_type(F, NormalWeightedMeanPrecision), A * ξ_out, W)
end

# specialized versions for mean-covariance parameterization
@rule typeof(*)(:A, Marginalisation) (m_out::MvNormalMeanCovariance, m_in::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrection, Nothing}) = begin
@rule typeof(*)(:A, Marginalisation) (m_out::MvNormalMeanCovariance, m_in::PointMass{<:AbstractMatrix}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
A = mean(m_in)
μ_out, Σ_out = mean_cov(m_out)

Expand All @@ -42,7 +42,7 @@ end
return MvNormalWeightedMeanPrecision(tmp * μ_out, W)
end

@rule typeof(*)(:A, Marginalisation) (m_out::MvNormalMeanCovariance, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrection, Nothing}) = begin
@rule typeof(*)(:A, Marginalisation) (m_out::MvNormalMeanCovariance, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
A = mean(m_in)
μ_out, Σ_out = mean_cov(m_out)

Expand All @@ -53,14 +53,16 @@ end
return NormalWeightedMeanPrecision(dot(tmp, μ_out), W)
end

@rule typeof(*)(:A, Marginalisation) (m_out::UnivariateGaussianDistributionsFamily, m_in::UnivariateGaussianDistributionsFamily, meta::Union{<:AbstractCorrection, Nothing}) = begin
@rule typeof(*)(:A, Marginalisation) (
m_out::UnivariateGaussianDistributionsFamily, m_in::UnivariateGaussianDistributionsFamily, meta::Union{<:AbstractCorrectionStrategy, Nothing}
) = begin
μ_in, var_in = mean_var(m_in)
μ_out, var_out = mean_var(m_out)
log_backwardpass = (x) -> -log(abs(x)) - 0.5 * log(2π * (var_in + var_out / x^2)) - 1 / 2 * (μ_out - x * μ_in)^2 / (var_in * x^2 + var_out)
return ContinuousUnivariateLogPdf(log_backwardpass)
end

@rule typeof(*)(:A, Marginalisation) (m_out::UnivariateDistribution, m_in::UnivariateDistribution, meta::Union{<:AbstractCorrection, Nothing}) = begin
@rule typeof(*)(:A, Marginalisation) (m_out::UnivariateDistribution, m_in::UnivariateDistribution, meta::Union{<:AbstractCorrectionStrategy, Nothing}) = begin
nsamples = 3000
samples_in = rand(m_in, nsamples)
p = make_inversedist_message(samples_in, m_out)
Expand Down
Loading

2 comments on commit fcfc9c1

@bvdmitri
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/92259

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v3.11.0 -m "<description of version>" fcfc9c1e682d0ead4d23e627d1f07ca38f268a93
git push origin v3.11.0

Please sign in to comment.