Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor optimizations and improvements #173

Merged
merged 4 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions src/plugins/variational_constraints/variational_constraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,12 @@ using DataStructures
MeanField

Generic factorisation constraint used to specify a mean-field factorisation for recognition distribution `q`.
This constraint ignores `default_constraints` from submodels and forces everything to be factorized.

See also: [`BetheFactorisation`](@ref)
"""
struct MeanField end

"""
BetheFactorisation

Generic factorisation constraint used to specify the Bethe factorisation for recognition distribution `q`.

See also: [`MeanField`](@ref)
"""
struct BetheFactorization end

include("variational_constraints_macro.jl")
include("variational_constraints_engine.jl")

Expand All @@ -37,10 +29,22 @@ struct VariationalConstraintsPlugin{C}
constraints::C
end

const EmptyConstraints = UnspecifiedConstraints()
const UnspecifiedConstraints = Constraints((), (), (), (;), (;))

default_constraints(::Any) = UnspecifiedConstraints

"""
BetheFactorization

Generic factorisation constraint used to specify the Bethe factorisation for recognition distribution `q`.
An alias to `UnspecifiedConstraints`.

See also: [`MeanField`](@ref)
"""
BetheFactorization() = UnspecifiedConstraints

VariationalConstraintsPlugin() = VariationalConstraintsPlugin(EmptyConstraints)
VariationalConstraintsPlugin(::Nothing) = VariationalConstraintsPlugin(EmptyConstraints)
VariationalConstraintsPlugin() = VariationalConstraintsPlugin(UnspecifiedConstraints)
VariationalConstraintsPlugin(::Nothing) = VariationalConstraintsPlugin(UnspecifiedConstraints)

GraphPPL.plugin_type(::VariationalConstraintsPlugin) = FactorAndVariableNodesPlugin()

Expand Down
119 changes: 48 additions & 71 deletions src/plugins/variational_constraints/variational_constraints_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,12 @@ getconstraint(c::SpecificSubModelConstraints) = c.constraints

An instance of `Constraints` represents a set of constraints to be applied to a variational posterior in a factor graph model.
"""
struct Constraints
factorization_constraints::Vector{FactorizationConstraint}
posterior_form_constraints::Vector{PosteriorFormConstraint}
message_form_constraints::Vector{MessageFormConstraint}
general_submodel_constraints::Dict{Function, GeneralSubModelConstraints}
specific_submodel_constraints::Dict{FactorID, SpecificSubModelConstraints}
struct Constraints{F, P, M, G, S}
factorization_constraints::F
posterior_form_constraints::P
message_form_constraints::M
general_submodel_constraints::G
specific_submodel_constraints::S
end

factorization_constraints(c::Constraints) = c.factorization_constraints
Expand All @@ -287,9 +287,9 @@ specific_submodel_constraints(c::Constraints) = c.specific_submodel_constraints

function Constraints()
return Constraints(
Vector{FactorizationConstraint}[],
Vector{PosteriorFormConstraint}[],
Vector{MessageFormConstraint}[],
Vector{FactorizationConstraint}(),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This actually a funny one @wouterwln , not sure who of us made this mistake, but it basically created a vector of vector of constraints instead of just vector of constraints. It somehow worked though (probs some weird convert mechanics)

Vector{PosteriorFormConstraint}(),
Vector{MessageFormConstraint}(),
Dict{Function, GeneralSubModelConstraints}(),
Dict{FactorID, SpecificSubModelConstraints}()
)
Expand Down Expand Up @@ -369,16 +369,6 @@ getconstraints(c::Constraints) = Iterators.flatten((
Base.push!(c_set::GeneralSubModelConstraints, c) = push!(getconstraint(c_set), c)
Base.push!(c_set::SpecificSubModelConstraints, c) = push!(getconstraint(c_set), c)

struct UnspecifiedConstraints end

factorization_constraints(::UnspecifiedConstraints) = ()
posterior_form_constraints(::UnspecifiedConstraints) = ()
message_form_constraints(::UnspecifiedConstraints) = ()
general_submodel_constraints(::UnspecifiedConstraints) = (;)
specific_submodel_constraints(::UnspecifiedConstraints) = (;)

default_constraints(::Any) = UnspecifiedConstraints()

struct ResolvedIndexedVariable{T}
variable::IndexedVariable{T}
context::Context
Expand Down Expand Up @@ -491,30 +481,19 @@ end

Base.iterate(stack::ConstraintStack, state = 1) = iterate(constraints(stack), state)

function intersect_constraint_bitset!(nodedata::NodeData, constraint_data::BoundedBitSetTuple)
constraint = getextra(nodedata, VariationalConstraintsFactorizationBitSetKey)::BoundedBitSetTuple
intersect!(constraint, constraint_data)
return constraint
end

function constant_constraint(num_neighbors::Int, index_constant::Int)
constraint = BoundedBitSetTuple(num_neighbors)
constraint[index_constant, :] = false
constraint[:, index_constant] = false
constraint[index_constant, index_constant] = true
return constraint
end

function mean_field_constraint(num_neighbors::Int)
constraint = BoundedBitSetTuple(zeros, num_neighbors)
for i in 1:num_neighbors
function mean_field_constraint!(constraint::BoundedBitSetTuple)
fill!(contents(constraint), false)
for i in 1:length(constraint)
constraint[i, i] = true
end
return constraint
end

function mean_field_constraint(num_neighbors::Int, referenced_indices::NTuple{N, Int} where {N})
constraint = BoundedBitSetTuple(num_neighbors)
function mean_field_constraint!(constraint::BoundedBitSetTuple, index::Int)
return mean_field_constraint!(constraint, (index,))
end

function mean_field_constraint!(constraint::BoundedBitSetTuple, referenced_indices::NTuple{N, Int} where {N})
for i in referenced_indices
constraint[i, :] = false
constraint[:, i] = false
Expand Down Expand Up @@ -559,12 +538,9 @@ const VariationalConstraintsFactorizationBitSetKey = NodeDataExtraKey{:factoriza

function materialize_constraints!(model::Model, node_label::NodeLabel, node_data::NodeData, properties::FactorNodeProperties)
constraint_bitset = getextra(node_data, VariationalConstraintsFactorizationBitSetKey)
num_neighbors = length(constraint_bitset)
for (i, neighbor) in enumerate(neighbor_data(properties))
if is_factorized(neighbor)
intersect_constraint_bitset!(node_data, constant_constraint(num_neighbors, i))
end
end

# Factorize out `neighbors` for which `is_factorized` is `true`
materialize_is_factorized_neighbors!(constraint_bitset, neighbor_data(properties))

constraint_set = unique(eachcol(contents(constraint_bitset)))

Expand All @@ -573,10 +549,20 @@ function materialize_constraints!(model::Model, node_label::NodeLabel, node_data
lazy"Factorization constraint set at node $node_label is not a valid constraint set. Please check your model definition and constraint specification. (Constraint set: $constraint_bitset)"
)
end

rows = Tuple(map(row -> filter(!iszero, map(elem -> elem[2] == 1 ? elem[1] : 0, enumerate(row))), constraint_set))
setextra!(node_data, VariationalConstraintsFactorizationIndicesKey, rows)
end

function materialize_is_factorized_neighbors!(constraint_bitset::BoundedBitSetTuple, neighbors)
for (i, neighbor) in enumerate(neighbors)
if is_factorized(neighbor)
mean_field_constraint!(constraint_bitset, i)
end
end
return constraint_bitset
end

function is_valid_partition(contents)
max_element = length(first(contents))
for element in 1:max_element
Expand Down Expand Up @@ -776,8 +762,6 @@ function convert_to_bitsets(model::Model, node::NodeLabel, neighbors, constraint
return result
end

apply_constraints!(model::Model, context::Context, constraints) = apply_constraints!(model, context, constraints, ConstraintStack())

function apply_constraints!(
model::Model, context::Context, posterior_constraint::PosteriorFormConstraint{T, F} where {T <: IndexedVariable, F}
)
Expand Down Expand Up @@ -808,49 +792,42 @@ function apply_constraints!(model::Model, context::Context, message_constraint::
end
end

function apply_constraints!(model::Model, context::Context, constraint::MeanField)
foreach(filter(as_node(), model)) do node
data = model[node]
intersect_constraint_bitset!(data, mean_field_constraint(length(neighbor_data(getproperties(data)))))
end
function apply_constraints!(model::Model, context::Context, constraints)
return apply_constraints!(model, context, constraints, ConstraintStack())
end

function apply_constraints!(model::Model, context::Context, constraint::BetheFactorization)
nothing # Change if the Bethe Factorization is no longer the default factorization
# Mean-field constraint simply applies the entire mean-field factorization to all the nodes in the model
# Ignores `default_constraints` from the submodels and forces everything to be `MeanField`
function apply_constraints!(model::Model, ::Context, ::MeanField, ::ConstraintStack)
factor_nodes(model) do _, data
constraint_bitset = getextra(data, VariationalConstraintsFactorizationBitSetKey)
mean_field_constraint!(constraint_bitset)
end
end

function apply_constraints!(
model::Model,
context::Context,
constraint_set::Union{Constraints, UnspecifiedConstraints},
resolved_factorization_constraints::ConstraintStack
)
function apply_constraints!(model::Model, context::Context, constraint_set::Constraints, stack::ConstraintStack)
foreach(factorization_constraints(constraint_set)) do fc
push!(resolved_factorization_constraints, resolve(model, context, fc), context)
push!(stack, resolve(model, context, fc), context)
end
foreach(posterior_form_constraints(constraint_set)) do ffc
apply_constraints!(model, context, ffc)
end
foreach(message_form_constraints(constraint_set)) do mc
apply_constraints!(model, context, mc)
end
foreach(constraints(resolved_factorization_constraints)) do rfc
foreach(constraints(stack)) do rfc
apply_constraints!(model, context, rfc)
end
for (factor_id, child) in pairs(children(context))
if factor_id ∈ keys(specific_submodel_constraints(constraint_set))
apply_constraints!(
model, child, getconstraint(specific_submodel_constraints(constraint_set)[factor_id]), resolved_factorization_constraints
)
apply_constraints!(model, child, getconstraint(specific_submodel_constraints(constraint_set)[factor_id]), stack)
elseif fform(factor_id) ∈ keys(general_submodel_constraints(constraint_set))
apply_constraints!(
model, child, getconstraint(general_submodel_constraints(constraint_set)[fform(child)]), resolved_factorization_constraints
)
apply_constraints!(model, child, getconstraint(general_submodel_constraints(constraint_set)[fform(child)]), stack)
else
apply_constraints!(model, child, default_constraints(fform(factor_id)), resolved_factorization_constraints)
apply_constraints!(model, child, default_constraints(fform(factor_id)), stack)
end
end
while pop!(resolved_factorization_constraints, context)
while pop!(stack, context)
continue
end
end
Expand Down Expand Up @@ -882,9 +859,9 @@ function apply_constraints!(
)
# Get data for the neighbors of the node and check if the constraint is applicable
neighbors = neighbor_data(node_properties)
constraint_bitset = getextra(node_data, VariationalConstraintsFactorizationBitSetKey)
if is_applicable(neighbors, constraint)
constraint = convert_to_bitsets(model, node, neighbors, constraint)
intersect_constraint_bitset!(node_data, constraint)
intersect!(constraint_bitset, convert_to_bitsets(model, node, neighbors, constraint))
end
return nothing
end
Expand Down
Loading
Loading