diff --git a/src/plugins/variational_constraints/variational_constraints.jl b/src/plugins/variational_constraints/variational_constraints.jl index 41e6f863..e3787ee0 100644 --- a/src/plugins/variational_constraints/variational_constraints.jl +++ b/src/plugins/variational_constraints/variational_constraints.jl @@ -11,20 +11,12 @@ using DataStructures MeanField Generic factorisation constraint used to specify a mean-field factorisation for recognition distribution `q`. +This constraint ignores `default_constraints` from submodels and forces everything to be factorized. See also: [`BetheFactorisation`](@ref) """ struct MeanField end -""" - BetheFactorisation - -Generic factorisation constraint used to specify the Bethe factorisation for recognition distribution `q`. - -See also: [`MeanField`](@ref) -""" -struct BetheFactorization end - include("variational_constraints_macro.jl") include("variational_constraints_engine.jl") @@ -37,10 +29,22 @@ struct VariationalConstraintsPlugin{C} constraints::C end -const EmptyConstraints = UnspecifiedConstraints() +const UnspecifiedConstraints = Constraints((), (), (), (;), (;)) + +default_constraints(::Any) = UnspecifiedConstraints + +""" + BetheFactorization + +Generic factorisation constraint used to specify the Bethe factorisation for recognition distribution `q`. +An alias to `UnspecifiedConstraints`. + +See also: [`MeanField`](@ref) +""" +BetheFactorization() = UnspecifiedConstraints -VariationalConstraintsPlugin() = VariationalConstraintsPlugin(EmptyConstraints) -VariationalConstraintsPlugin(::Nothing) = VariationalConstraintsPlugin(EmptyConstraints) +VariationalConstraintsPlugin() = VariationalConstraintsPlugin(UnspecifiedConstraints) +VariationalConstraintsPlugin(::Nothing) = VariationalConstraintsPlugin(UnspecifiedConstraints) GraphPPL.plugin_type(::VariationalConstraintsPlugin) = FactorAndVariableNodesPlugin() diff --git a/src/plugins/variational_constraints/variational_constraints_engine.jl b/src/plugins/variational_constraints/variational_constraints_engine.jl index f4670ba9..c70611b9 100644 --- a/src/plugins/variational_constraints/variational_constraints_engine.jl +++ b/src/plugins/variational_constraints/variational_constraints_engine.jl @@ -271,12 +271,12 @@ getconstraint(c::SpecificSubModelConstraints) = c.constraints An instance of `Constraints` represents a set of constraints to be applied to a variational posterior in a factor graph model. """ -struct Constraints - factorization_constraints::Vector{FactorizationConstraint} - posterior_form_constraints::Vector{PosteriorFormConstraint} - message_form_constraints::Vector{MessageFormConstraint} - general_submodel_constraints::Dict{Function, GeneralSubModelConstraints} - specific_submodel_constraints::Dict{FactorID, SpecificSubModelConstraints} +struct Constraints{F, P, M, G, S} + factorization_constraints::F + posterior_form_constraints::P + message_form_constraints::M + general_submodel_constraints::G + specific_submodel_constraints::S end factorization_constraints(c::Constraints) = c.factorization_constraints @@ -287,9 +287,9 @@ specific_submodel_constraints(c::Constraints) = c.specific_submodel_constraints function Constraints() return Constraints( - Vector{FactorizationConstraint}[], - Vector{PosteriorFormConstraint}[], - Vector{MessageFormConstraint}[], + Vector{FactorizationConstraint}(), + Vector{PosteriorFormConstraint}(), + Vector{MessageFormConstraint}(), Dict{Function, GeneralSubModelConstraints}(), Dict{FactorID, SpecificSubModelConstraints}() ) @@ -369,16 +369,6 @@ getconstraints(c::Constraints) = Iterators.flatten(( Base.push!(c_set::GeneralSubModelConstraints, c) = push!(getconstraint(c_set), c) Base.push!(c_set::SpecificSubModelConstraints, c) = push!(getconstraint(c_set), c) -struct UnspecifiedConstraints end - -factorization_constraints(::UnspecifiedConstraints) = () -posterior_form_constraints(::UnspecifiedConstraints) = () -message_form_constraints(::UnspecifiedConstraints) = () -general_submodel_constraints(::UnspecifiedConstraints) = (;) -specific_submodel_constraints(::UnspecifiedConstraints) = (;) - -default_constraints(::Any) = UnspecifiedConstraints() - struct ResolvedIndexedVariable{T} variable::IndexedVariable{T} context::Context @@ -491,30 +481,19 @@ end Base.iterate(stack::ConstraintStack, state = 1) = iterate(constraints(stack), state) -function intersect_constraint_bitset!(nodedata::NodeData, constraint_data::BoundedBitSetTuple) - constraint = getextra(nodedata, VariationalConstraintsFactorizationBitSetKey)::BoundedBitSetTuple - intersect!(constraint, constraint_data) - return constraint -end - -function constant_constraint(num_neighbors::Int, index_constant::Int) - constraint = BoundedBitSetTuple(num_neighbors) - constraint[index_constant, :] = false - constraint[:, index_constant] = false - constraint[index_constant, index_constant] = true - return constraint -end - -function mean_field_constraint(num_neighbors::Int) - constraint = BoundedBitSetTuple(zeros, num_neighbors) - for i in 1:num_neighbors +function mean_field_constraint!(constraint::BoundedBitSetTuple) + fill!(contents(constraint), false) + for i in 1:length(constraint) constraint[i, i] = true end return constraint end -function mean_field_constraint(num_neighbors::Int, referenced_indices::NTuple{N, Int} where {N}) - constraint = BoundedBitSetTuple(num_neighbors) +function mean_field_constraint!(constraint::BoundedBitSetTuple, index::Int) + return mean_field_constraint!(constraint, (index,)) +end + +function mean_field_constraint!(constraint::BoundedBitSetTuple, referenced_indices::NTuple{N, Int} where {N}) for i in referenced_indices constraint[i, :] = false constraint[:, i] = false @@ -559,12 +538,9 @@ const VariationalConstraintsFactorizationBitSetKey = NodeDataExtraKey{:factoriza function materialize_constraints!(model::Model, node_label::NodeLabel, node_data::NodeData, properties::FactorNodeProperties) constraint_bitset = getextra(node_data, VariationalConstraintsFactorizationBitSetKey) - num_neighbors = length(constraint_bitset) - for (i, neighbor) in enumerate(neighbor_data(properties)) - if is_factorized(neighbor) - intersect_constraint_bitset!(node_data, constant_constraint(num_neighbors, i)) - end - end + + # Factorize out `neighbors` for which `is_factorized` is `true` + materialize_is_factorized_neighbors!(constraint_bitset, neighbor_data(properties)) constraint_set = unique(eachcol(contents(constraint_bitset))) @@ -573,10 +549,20 @@ function materialize_constraints!(model::Model, node_label::NodeLabel, node_data lazy"Factorization constraint set at node $node_label is not a valid constraint set. Please check your model definition and constraint specification. (Constraint set: $constraint_bitset)" ) end + rows = Tuple(map(row -> filter(!iszero, map(elem -> elem[2] == 1 ? elem[1] : 0, enumerate(row))), constraint_set)) setextra!(node_data, VariationalConstraintsFactorizationIndicesKey, rows) end +function materialize_is_factorized_neighbors!(constraint_bitset::BoundedBitSetTuple, neighbors) + for (i, neighbor) in enumerate(neighbors) + if is_factorized(neighbor) + mean_field_constraint!(constraint_bitset, i) + end + end + return constraint_bitset +end + function is_valid_partition(contents) max_element = length(first(contents)) for element in 1:max_element @@ -776,8 +762,6 @@ function convert_to_bitsets(model::Model, node::NodeLabel, neighbors, constraint return result end -apply_constraints!(model::Model, context::Context, constraints) = apply_constraints!(model, context, constraints, ConstraintStack()) - function apply_constraints!( model::Model, context::Context, posterior_constraint::PosteriorFormConstraint{T, F} where {T <: IndexedVariable, F} ) @@ -808,25 +792,22 @@ function apply_constraints!(model::Model, context::Context, message_constraint:: end end -function apply_constraints!(model::Model, context::Context, constraint::MeanField) - foreach(filter(as_node(), model)) do node - data = model[node] - intersect_constraint_bitset!(data, mean_field_constraint(length(neighbor_data(getproperties(data))))) - end +function apply_constraints!(model::Model, context::Context, constraints) + return apply_constraints!(model, context, constraints, ConstraintStack()) end -function apply_constraints!(model::Model, context::Context, constraint::BetheFactorization) - nothing # Change if the Bethe Factorization is no longer the default factorization +# Mean-field constraint simply applies the entire mean-field factorization to all the nodes in the model +# Ignores `default_constraints` from the submodels and forces everything to be `MeanField` +function apply_constraints!(model::Model, ::Context, ::MeanField, ::ConstraintStack) + factor_nodes(model) do _, data + constraint_bitset = getextra(data, VariationalConstraintsFactorizationBitSetKey) + mean_field_constraint!(constraint_bitset) + end end -function apply_constraints!( - model::Model, - context::Context, - constraint_set::Union{Constraints, UnspecifiedConstraints}, - resolved_factorization_constraints::ConstraintStack -) +function apply_constraints!(model::Model, context::Context, constraint_set::Constraints, stack::ConstraintStack) foreach(factorization_constraints(constraint_set)) do fc - push!(resolved_factorization_constraints, resolve(model, context, fc), context) + push!(stack, resolve(model, context, fc), context) end foreach(posterior_form_constraints(constraint_set)) do ffc apply_constraints!(model, context, ffc) @@ -834,23 +815,19 @@ function apply_constraints!( foreach(message_form_constraints(constraint_set)) do mc apply_constraints!(model, context, mc) end - foreach(constraints(resolved_factorization_constraints)) do rfc + foreach(constraints(stack)) do rfc apply_constraints!(model, context, rfc) end for (factor_id, child) in pairs(children(context)) if factor_id ∈ keys(specific_submodel_constraints(constraint_set)) - apply_constraints!( - model, child, getconstraint(specific_submodel_constraints(constraint_set)[factor_id]), resolved_factorization_constraints - ) + apply_constraints!(model, child, getconstraint(specific_submodel_constraints(constraint_set)[factor_id]), stack) elseif fform(factor_id) ∈ keys(general_submodel_constraints(constraint_set)) - apply_constraints!( - model, child, getconstraint(general_submodel_constraints(constraint_set)[fform(child)]), resolved_factorization_constraints - ) + apply_constraints!(model, child, getconstraint(general_submodel_constraints(constraint_set)[fform(child)]), stack) else - apply_constraints!(model, child, default_constraints(fform(factor_id)), resolved_factorization_constraints) + apply_constraints!(model, child, default_constraints(fform(factor_id)), stack) end end - while pop!(resolved_factorization_constraints, context) + while pop!(stack, context) continue end end @@ -882,9 +859,9 @@ function apply_constraints!( ) # Get data for the neighbors of the node and check if the constraint is applicable neighbors = neighbor_data(node_properties) + constraint_bitset = getextra(node_data, VariationalConstraintsFactorizationBitSetKey) if is_applicable(neighbors, constraint) - constraint = convert_to_bitsets(model, node, neighbors, constraint) - intersect_constraint_bitset!(node_data, constraint) + intersect!(constraint_bitset, convert_to_bitsets(model, node, neighbors, constraint)) end return nothing end diff --git a/test/plugins/variational_constraints/variational_constraints_engine_tests.jl b/test/plugins/variational_constraints/variational_constraints_engine_tests.jl index b3e24563..6745eb6a 100644 --- a/test/plugins/variational_constraints/variational_constraints_engine_tests.jl +++ b/test/plugins/variational_constraints/variational_constraints_engine_tests.jl @@ -504,14 +504,6 @@ end end end -@testitem "constant_constraint" begin - using BitSetTuples - import GraphPPL: constant_constraint - - @test tupled_contents(constant_constraint(1, 1)) == ((1,),) - @test tupled_contents(constant_constraint(5, 3)) == ((1, 2, 4, 5), (1, 2, 4, 5), (3,), (1, 2, 4, 5), (1, 2, 4, 5)) -end - @testitem "Application of PosteriorFormConstraint" begin import GraphPPL: PosteriorFormConstraint, IndexedVariable, apply_constraints!, getextra, hasextra @@ -597,32 +589,34 @@ end end end -@testitem "save constraints with constants" begin +@testitem "save constraints with constants via `mean_field_constraint!`" begin include("../../model_zoo.jl") using BitSetTuples using GraphPPL import GraphPPL: - intersect_constraint_bitset!, - constant_constraint, - factorization_constraint, + getextra, + mean_field_constraint!, getproperties, VariationalConstraintsPlugin, - PluginsCollection + PluginsCollection, + VariationalConstraintsFactorizationBitSetKey model = create_terminated_model(simple_model; plugins = GraphPPL.PluginsCollection(VariationalConstraintsPlugin())) ctx = GraphPPL.getcontext(model) - @test tupled_contents(constant_constraint(3, 1)) == ((1,), (2, 3), (2, 3)) - @test tupled_contents(constant_constraint(3, 2)) == ((1, 3), (2,), (1, 3)) - @test tupled_contents(constant_constraint(3, 3)) == ((1, 2), (1, 2), (3,)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(3), 1)) == ((1,), (2, 3), (2, 3)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(3), 2)) == ((1, 3), (2,), (1, 3)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(3), 3)) == ((1, 2), (1, 2), (3,)) node = ctx[NormalMeanVariance, 2] - @test tupled_contents(intersect_constraint_bitset!(model[node], constant_constraint(3, 1))) == ((1,), (2, 3), (2, 3)) - @test tupled_contents(intersect_constraint_bitset!(model[node], constant_constraint(3, 2))) == ((1,), (2,), (3,)) + constraint_bitset = getextra(model[node], VariationalConstraintsFactorizationBitSetKey) + @test tupled_contents(intersect!(constraint_bitset, mean_field_constraint!(BoundedBitSetTuple(3), 1))) == ((1,), (2, 3), (2, 3)) + @test tupled_contents(intersect!(constraint_bitset, mean_field_constraint!(BoundedBitSetTuple(3), 2))) == ((1,), (2,), (3,)) node = ctx[NormalMeanVariance, 1] + constraint_bitset = getextra(model[node], VariationalConstraintsFactorizationBitSetKey) # Here it is the mean field because the original model has `x ~ Normal(0, 1)` and `0` and `1` are constants - @test tupled_contents(intersect_constraint_bitset!(model[node], constant_constraint(3, 1))) == ((1,), (2,), (3,)) + @test tupled_contents(intersect!(constraint_bitset, mean_field_constraint!(BoundedBitSetTuple(3), 1))) == ((1,), (2,), (3,)) end @testitem "materialize_constraints!(:Model, ::NodeLabel, ::FactorNodeData)" begin @@ -1023,17 +1017,16 @@ end @testitem "default_constraints" begin import GraphPPL: default_constraints, - factorization_constraint, getproperties, PluginsCollection, VariationalConstraintsPlugin, hasextra, getextra, - EmptyConstraints + UnspecifiedConstraints include("../../model_zoo.jl") - @test default_constraints(simple_model) == EmptyConstraints + @test default_constraints(simple_model) == UnspecifiedConstraints @test default_constraints(model_with_default_constraints) == @constraints( begin q(a, d) = q(a)q(d) @@ -1082,82 +1075,20 @@ end end end -@testitem "mean_field_constraint" begin +@testitem "mean_field_constraint!" begin using BitSetTuples - import GraphPPL: mean_field_constraint - - @test tupled_contents(mean_field_constraint(5)) == ((1,), (2,), (3,), (4,), (5,)) - @test tupled_contents(mean_field_constraint(10)) == ((1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,)) - - @test tupled_contents(mean_field_constraint(1, (1,))) == ((1,),) - @test tupled_contents(mean_field_constraint(2, (1,))) == ((1,), (2,)) - @test tupled_contents(mean_field_constraint(2, (2,))) == ((1,), (2,)) - @test tupled_contents(mean_field_constraint(5, (1, 3, 5))) == ((1,), (2, 4), (3,), (2, 4), (5,)) - @test tupled_contents(mean_field_constraint(5, (1, 2, 3, 4, 5))) == ((1,), (2,), (3,), (4,), (5,)) - @test_throws BoundsError mean_field_constraint(5, (1, 2, 3, 4, 5, 6)) == ((1,), (2,), (3,), (4,), (5,)) - @test tupled_contents(mean_field_constraint(5, (1, 2))) == ((1,), (2,), (3, 4, 5), (3, 4, 5), (3, 4, 5)) -end - -@testitem "Apply MeanField constraints" begin - using GraphPPL - import GraphPPL: getproperties, neighbor_data - - include("../../model_zoo.jl") - - for model_fform in [ - simple_model, - vector_model, - tensor_model, - outer, - multidim_array, - node_with_only_anonymous, - node_with_two_anonymous, - node_with_ambiguous_anonymous, - multidim_array - ] - model = create_terminated_model( - model_fform; plugins = GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin(MeanField())) - ) - - for node in filter(as_node(), model) - node_data = model[node] - @test GraphPPL.getextra(node_data, :factorization_constraint_indices) == - Tuple([[i] for i in 1:(length(neighbor_data(getproperties(node_data))))]) - end - end -end - -@testitem "Apply BetheFactorization constraints" begin - using GraphPPL - import GraphPPL: getproperties, neighbor_data, is_factorized - - include("../../model_zoo.jl") - - for model_fform in [ - simple_model, - vector_model, - tensor_model, - outer, - multidim_array, - node_with_only_anonymous, - node_with_two_anonymous, - node_with_ambiguous_anonymous, - multidim_array - ] - model = create_terminated_model( - model_fform; plugins = GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin(BetheFactorization())) - ) - - for node in filter(as_node(), model) - node_data = model[node] - neighbors_data = neighbor_data(getproperties(node_data)) - factorized_neighbors = is_factorized.(neighbors_data) - new_constraint = [findall(!, factorized_neighbors)] - for j in findall(identity, factorized_neighbors) - push!(new_constraint, [j]) - end - sort!(new_constraint, by = first) - @test GraphPPL.getextra(node_data, :factorization_constraint_indices) == Tuple(new_constraint) - end - end + import GraphPPL: mean_field_constraint! + + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(5))) == ((1,), (2,), (3,), (4,), (5,)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(10))) == ((1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,)) + + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(1), 1)) == ((1,),) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(5), 3)) == ((1, 2, 4, 5), (1, 2, 4, 5), (3,), (1, 2, 4, 5), (1, 2, 4, 5)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(1), (1,))) == ((1,),) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(2), (1,))) == ((1,), (2,)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(2), (2,))) == ((1,), (2,)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(5), (1, 2))) == ((1,), (2,), (3, 4, 5), (3, 4, 5), (3, 4, 5)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(5), (1, 3, 5))) == ((1,), (2, 4), (3,), (2, 4), (5,)) + @test tupled_contents(mean_field_constraint!(BoundedBitSetTuple(5), (1, 2, 3, 4, 5))) == ((1,), (2,), (3,), (4,), (5,)) + @test_throws BoundsError mean_field_constraint!(BoundedBitSetTuple(5), (1, 2, 3, 4, 5, 6)) == ((1,), (2,), (3,), (4,), (5,)) end diff --git a/test/plugins/variational_constraints/variational_constraints_tests.jl b/test/plugins/variational_constraints/variational_constraints_tests.jl index ae92be11..3f0ee499 100644 --- a/test/plugins/variational_constraints/variational_constraints_tests.jl +++ b/test/plugins/variational_constraints/variational_constraints_tests.jl @@ -1,8 +1,8 @@ @testitem "Empty constraints" begin - import GraphPPL: VariationalConstraintsPlugin, EmptyConstraints + import GraphPPL: VariationalConstraintsPlugin, UnspecifiedConstraints - @test VariationalConstraintsPlugin() == VariationalConstraintsPlugin(EmptyConstraints) - @test VariationalConstraintsPlugin(nothing) == VariationalConstraintsPlugin(EmptyConstraints) + @test VariationalConstraintsPlugin() == VariationalConstraintsPlugin(UnspecifiedConstraints) + @test VariationalConstraintsPlugin(nothing) == VariationalConstraintsPlugin(UnspecifiedConstraints) end @testitem "simple @model + various constraints" begin @@ -925,3 +925,67 @@ end end end end + +@testitem "Apply MeanField constraints" begin + using GraphPPL + import GraphPPL: getproperties, neighbor_data + + include("../../model_zoo.jl") + + for model_fform in [ + simple_model, + vector_model, + tensor_model, + outer, + multidim_array, + node_with_only_anonymous, + node_with_two_anonymous, + node_with_ambiguous_anonymous, + multidim_array + ] + model = create_terminated_model( + model_fform; plugins = GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin(MeanField())) + ) + + for node in filter(as_node(), model) + node_data = model[node] + @test GraphPPL.getextra(node_data, :factorization_constraint_indices) == + Tuple([[i] for i in 1:(length(neighbor_data(getproperties(node_data))))]) + end + end +end + +@testitem "Apply BetheFactorization constraints" begin + using GraphPPL + import GraphPPL: getproperties, neighbor_data, is_factorized + + include("../../model_zoo.jl") + + for model_fform in [ + simple_model, + vector_model, + tensor_model, + outer, + multidim_array, + node_with_only_anonymous, + node_with_two_anonymous, + node_with_ambiguous_anonymous, + multidim_array + ] + model = create_terminated_model( + model_fform; plugins = GraphPPL.PluginsCollection(GraphPPL.VariationalConstraintsPlugin(BetheFactorization())) + ) + + for node in filter(as_node(), model) + node_data = model[node] + neighbors_data = neighbor_data(getproperties(node_data)) + factorized_neighbors = is_factorized.(neighbors_data) + new_constraint = [findall(!, factorized_neighbors)] + for j in findall(identity, factorized_neighbors) + push!(new_constraint, [j]) + end + sort!(new_constraint, by = first) + @test GraphPPL.getextra(node_data, :factorization_constraint_indices) == Tuple(new_constraint) + end + end +end \ No newline at end of file