Skip to content

Commit

Permalink
Merge pull request #161 from ReactiveBayes/dev-extra-compile-time-keys
Browse files Browse the repository at this point in the history
Add compile time keys for the extra properties of graph nodes
  • Loading branch information
wouterwln authored Mar 8, 2024
2 parents d52ab51 + 07c9ed9 commit f760291
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 3 deletions.
15 changes: 15 additions & 0 deletions src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,21 @@ hasextra(node::NodeData, key::Symbol) = haskey(node.extra, key)
getextra(node::NodeData, key::Symbol) = getindex(node.extra, key)
setextra!(node::NodeData, key::Symbol, value) = insert!(node.extra, key, value)

"""
A compile time key to access the `extra` properties of the `NodeData` structure.
"""
struct NodeDataExtraKey{K, T} end

function hasextra(node::NodeData, key::NodeDataExtraKey{K}) where K
return haskey(node.extra, K)
end
function getextra(node::NodeData, key::NodeDataExtraKey{K, T})::T where {K, T}
return getindex(node.extra, K)::T
end
function setextra!(node::NodeData, key::NodeDataExtraKey{K}, value::T) where {K, T}
return insert!(node.extra, K, value)
end

is_factor(node::NodeData) = is_factor(getproperties(node))
is_variable(node::NodeData) = is_variable(getproperties(node))

Expand Down
6 changes: 4 additions & 2 deletions src/plugins/meta/meta_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ function apply_meta!(model::Model, context::Context, meta::MetaObject{S, T} wher
end
end

const MetaExtraKey = NodeDataExtraKey{:meta, Any}()

function save_meta!(model::Model, node::NodeLabel, meta::MetaObject{S, T} where {S, T <: NamedTuple})
data = getmetainfo(meta)
if !haskey(data, :meta)
Expand All @@ -155,7 +157,7 @@ end

function save_meta!(model::Model, node::NodeLabel, meta::MetaObject{S, T} where {S, T})
nodedata = model[node]
if !hasextra(nodedata, :meta)
setextra!(nodedata, :meta, getmetainfo(meta))
if !hasextra(nodedata, MetaExtraKey)
setextra!(nodedata, MetaExtraKey, getmetainfo(meta))
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,8 @@ function materialize_constraints!(model::Model, node_label::NodeLabel, node_data
return materialize_constraints!(model, node_label, node_data, getproperties(node_data))
end

const VariationalConstraintsFactorizationIndicesKey = NodeDataExtraKey{:factorization_constraint_indices, Tuple}()

function materialize_constraints!(model::Model, node_label::NodeLabel, node_data::NodeData, ::FactorNodeProperties)
constraint_bitset = getextra(node_data, :factorization_constraint_bitset)
num_neighbors = length(constraint_bitset)
Expand All @@ -572,7 +574,7 @@ function materialize_constraints!(model::Model, node_label::NodeLabel, node_data
)
end
rows = Tuple(map(row -> filter(!iszero, map(elem -> elem[2] == 1 ? elem[1] : 0, enumerate(row))), constraint_set))
setextra!(node_data, :factorization_constraint_indices, rows)
setextra!(node_data, VariationalConstraintsFactorizationIndicesKey, rows)
end

function is_valid_partition(contents)
Expand Down
38 changes: 38 additions & 0 deletions test/graph_engine_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,44 @@ end
end
end

@testitem "NodeData extra properties" begin
import GraphPPL: create_model, getcontext, NodeData, FactorNodeProperties, VariableNodeProperties, getproperties, setextra!, getextra, hasextra, NodeDataExtraKey

model = create_model()
context = getcontext(model)

@testset for properties in (FactorNodeProperties(fform = String), VariableNodeProperties(name = :x, index = 1))
nodedata = NodeData(context, properties)

@test !hasextra(nodedata, :a)
setextra!(nodedata, :a, 1)
@test hasextra(nodedata, :a)
@test getextra(nodedata, :a) === 1

# In the current implementation it is not possible to update extra properties
@test_throws Exception setextra!(nodedata, :a, 2)

@test !hasextra(nodedata, :b)
setextra!(nodedata, :b, 2)
@test hasextra(nodedata, :b)
@test getextra(nodedata, :b) === 2

constkey_c_float = NodeDataExtraKey{:c, Float64}()

@test !@inferred(hasextra(nodedata, constkey_c_float))
@inferred(setextra!(nodedata, constkey_c_float, 3.0))
@test @inferred(hasextra(nodedata, constkey_c_float))
@test @inferred(getextra(nodedata, constkey_c_float)) === 3.0

constkey_d_int = NodeDataExtraKey{:d, Int64}()

@test !@inferred(hasextra(nodedata, constkey_d_int))
@inferred(setextra!(nodedata, constkey_d_int, 4))
@test @inferred(hasextra(nodedata, constkey_d_int))
@test @inferred(getextra(nodedata, constkey_d_int)) === 4
end
end

@testitem "factor_nodes" begin
import GraphPPL: factor_nodes, is_factor, labels
include("model_zoo.jl")
Expand Down

0 comments on commit f760291

Please sign in to comment.