Skip to content

Commit

Permalink
Merge pull request #162 from ReactiveBayes/nodedata-in-properties
Browse files Browse the repository at this point in the history
Nodedata in properties
  • Loading branch information
bvdmitri authored Mar 8, 2024
2 parents f760291 + c636643 commit 5bc5984
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
9 changes: 5 additions & 4 deletions src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ Data associated with a factor node in a probabilistic graphical model.
"""
struct FactorNodeProperties
fform::Any
neighbors::Vector{Tuple{NodeLabel, EdgeLabel}}
neighbors::Vector{Tuple{NodeLabel, EdgeLabel, Any}}
end

FactorNodeProperties(; fform, neighbors = Tuple{NodeLabel, EdgeLabel}[]) = FactorNodeProperties(fform, neighbors)
Expand All @@ -496,7 +496,8 @@ end

fform(properties::FactorNodeProperties) = properties.fform
neighbors(properties::FactorNodeProperties) = properties.neighbors
addneighbor!(properties::FactorNodeProperties, variable::NodeLabel, edge::EdgeLabel) = push!(properties.neighbors, (variable, edge))
addneighbor!(properties::FactorNodeProperties, variable::NodeLabel, edge::EdgeLabel, data::F) where {F} = push!(properties.neighbors, (variable, edge, data))
neighbor_data(properties::FactorNodeProperties) = Iterators.map(neighbor -> neighbor[3], neighbors(properties))

function Base.show(io::IO, properties::FactorNodeProperties)
print(io, "fform = ", properties.fform, ", neighbors = ", properties.neighbors)
Expand Down Expand Up @@ -1245,9 +1246,9 @@ function add_edge!(
index
)
label = EdgeLabel(interface_name, index)
neighbor_node_label = unroll(variable_node_id)
# TODO: (bvdmitri) perhaps we should use a different data structure for neighbors, tuples extension might be slow
addneighbor!(factor_node_propeties, unroll(variable_node_id), label)
# factor_node_propeties.neighbors = (factor_node_propeties.neighbors..., (unroll(variable_node_id), label))
addneighbor!(factor_node_propeties, neighbor_node_label, label, model[neighbor_node_label])
model.graph[unroll(variable_node_id), factor_node_id] = label
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -556,12 +556,11 @@ end

const VariationalConstraintsFactorizationIndicesKey = NodeDataExtraKey{:factorization_constraint_indices, Tuple}()

function materialize_constraints!(model::Model, node_label::NodeLabel, node_data::NodeData, ::FactorNodeProperties)
function materialize_constraints!(model::Model, node_label::NodeLabel, node_data::NodeData, properties::FactorNodeProperties)
constraint_bitset = getextra(node_data, :factorization_constraint_bitset)
num_neighbors = length(constraint_bitset)
for (i, neighbor) in enumerate(GraphPPL.neighbors(model, node_label))
neighbor_data = model[neighbor]
if is_factorized(neighbor_data)
for (i, neighbor) in enumerate(neighbor_data(properties))
if is_factorized(neighbor)
intersect_constraint_bitset!(node_data, constant_constraint(num_neighbors, i))
end
end
Expand Down Expand Up @@ -764,10 +763,10 @@ end
function convert_to_bitsets(model::Model, node::NodeLabel, neighbors, constraint::ResolvedFactorizationConstraint)
result = BoundedBitSetTuple(length(neighbors))
for (i, v1) in enumerate(neighbors)
for (j, v2) in enumerate(view(neighbors, (i + 1):lastindex(neighbors)))
if is_decoupled(v1, v2, constraint)
delete!(result, i, j + i)
delete!(result, j + i, i)
for (j, v2) in enumerate(neighbors)
if j > i && is_decoupled(v1, v2, constraint)
delete!(result, i, j)
delete!(result, j, i)
end
end
end
Expand Down Expand Up @@ -810,16 +809,16 @@ function apply_constraints!(
constraint_set::Union{Constraints, UnspecifiedConstraints},
resolved_factorization_constraints::ConstraintStack
)
for fc in factorization_constraints(constraint_set)
foreach(factorization_constraints(constraint_set)) do fc
push!(resolved_factorization_constraints, resolve(model, context, fc), context)
end
for ffc in posterior_form_constraints(constraint_set)
foreach(posterior_form_constraints(constraint_set)) do ffc
apply_constraints!(model, context, ffc)
end
for mc in message_form_constraints(constraint_set)
foreach(message_form_constraints(constraint_set)) do mc
apply_constraints!(model, context, mc)
end
for rfc in constraints(resolved_factorization_constraints)
foreach(constraints(resolved_factorization_constraints)) do rfc
apply_constraints!(model, context, rfc)
end
for (factor_id, child) in pairs(children(context))
Expand Down Expand Up @@ -866,7 +865,7 @@ function apply_constraints!(
constraint::ResolvedFactorizationConstraint
)
# Get data for the neighbors of the node and check if the constraint is applicable
neighbors = model[GraphPPL.neighbors(model, node)]
neighbors = neighbor_data(node_properties)
if is_applicable(neighbors, constraint)
constraint = convert_to_bitsets(model, node, neighbors, constraint)
intersect_constraint_bitset!(node_data, constraint)
Expand Down

0 comments on commit 5bc5984

Please sign in to comment.