From 07c9ed93693d214d5a158225bd4594627d02311a Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Fri, 8 Mar 2024 11:41:56 +0100 Subject: [PATCH] Add compile time keys for the extra properties --- src/graph_engine.jl | 15 ++++++++ src/plugins/meta/meta_engine.jl | 6 ++- .../variational_constraints_engine.jl | 4 +- test/graph_engine_tests.jl | 38 +++++++++++++++++++ 4 files changed, 60 insertions(+), 3 deletions(-) diff --git a/src/graph_engine.jl b/src/graph_engine.jl index b07c0d8f..bd4987aa 100644 --- a/src/graph_engine.jl +++ b/src/graph_engine.jl @@ -537,6 +537,21 @@ hasextra(node::NodeData, key::Symbol) = haskey(node.extra, key) getextra(node::NodeData, key::Symbol) = getindex(node.extra, key) setextra!(node::NodeData, key::Symbol, value) = insert!(node.extra, key, value) +""" +A compile time key to access the `extra` properties of the `NodeData` structure. +""" +struct NodeDataExtraKey{K, T} end + +function hasextra(node::NodeData, key::NodeDataExtraKey{K}) where K + return haskey(node.extra, K) +end +function getextra(node::NodeData, key::NodeDataExtraKey{K, T})::T where {K, T} + return getindex(node.extra, K)::T +end +function setextra!(node::NodeData, key::NodeDataExtraKey{K}, value::T) where {K, T} + return insert!(node.extra, K, value) +end + is_factor(node::NodeData) = is_factor(getproperties(node)) is_variable(node::NodeData) = is_variable(getproperties(node)) diff --git a/src/plugins/meta/meta_engine.jl b/src/plugins/meta/meta_engine.jl index 98e14d48..73b1ed00 100644 --- a/src/plugins/meta/meta_engine.jl +++ b/src/plugins/meta/meta_engine.jl @@ -141,6 +141,8 @@ function apply_meta!(model::Model, context::Context, meta::MetaObject{S, T} wher end end +const MetaExtraKey = NodeDataExtraKey{:meta, Any}() + function save_meta!(model::Model, node::NodeLabel, meta::MetaObject{S, T} where {S, T <: NamedTuple}) data = getmetainfo(meta) if !haskey(data, :meta) @@ -155,7 +157,7 @@ end function save_meta!(model::Model, node::NodeLabel, meta::MetaObject{S, T} where {S, T}) nodedata = model[node] - if !hasextra(nodedata, :meta) - setextra!(nodedata, :meta, getmetainfo(meta)) + if !hasextra(nodedata, MetaExtraKey) + setextra!(nodedata, MetaExtraKey, getmetainfo(meta)) end end diff --git a/src/plugins/variational_constraints/variational_constraints_engine.jl b/src/plugins/variational_constraints/variational_constraints_engine.jl index 093d5cd0..bddf20f4 100644 --- a/src/plugins/variational_constraints/variational_constraints_engine.jl +++ b/src/plugins/variational_constraints/variational_constraints_engine.jl @@ -554,6 +554,8 @@ function materialize_constraints!(model::Model, node_label::NodeLabel, node_data return materialize_constraints!(model, node_label, node_data, getproperties(node_data)) end +const VariationalConstraintsFactorizationIndicesKey = NodeDataExtraKey{:factorization_constraint_indices, Tuple}() + function materialize_constraints!(model::Model, node_label::NodeLabel, node_data::NodeData, ::FactorNodeProperties) constraint_bitset = getextra(node_data, :factorization_constraint_bitset) num_neighbors = length(constraint_bitset) @@ -572,7 +574,7 @@ function materialize_constraints!(model::Model, node_label::NodeLabel, node_data ) end rows = Tuple(map(row -> filter(!iszero, map(elem -> elem[2] == 1 ? elem[1] : 0, enumerate(row))), constraint_set)) - setextra!(node_data, :factorization_constraint_indices, rows) + setextra!(node_data, VariationalConstraintsFactorizationIndicesKey, rows) end function is_valid_partition(contents) diff --git a/test/graph_engine_tests.jl b/test/graph_engine_tests.jl index eadd5f3f..903c5a79 100644 --- a/test/graph_engine_tests.jl +++ b/test/graph_engine_tests.jl @@ -101,6 +101,44 @@ end end end +@testitem "NodeData extra properties" begin + import GraphPPL: create_model, getcontext, NodeData, FactorNodeProperties, VariableNodeProperties, getproperties, setextra!, getextra, hasextra, NodeDataExtraKey + + model = create_model() + context = getcontext(model) + + @testset for properties in (FactorNodeProperties(fform = String), VariableNodeProperties(name = :x, index = 1)) + nodedata = NodeData(context, properties) + + @test !hasextra(nodedata, :a) + setextra!(nodedata, :a, 1) + @test hasextra(nodedata, :a) + @test getextra(nodedata, :a) === 1 + + # In the current implementation it is not possible to update extra properties + @test_throws Exception setextra!(nodedata, :a, 2) + + @test !hasextra(nodedata, :b) + setextra!(nodedata, :b, 2) + @test hasextra(nodedata, :b) + @test getextra(nodedata, :b) === 2 + + constkey_c_float = NodeDataExtraKey{:c, Float64}() + + @test !@inferred(hasextra(nodedata, constkey_c_float)) + @inferred(setextra!(nodedata, constkey_c_float, 3.0)) + @test @inferred(hasextra(nodedata, constkey_c_float)) + @test @inferred(getextra(nodedata, constkey_c_float)) === 3.0 + + constkey_d_int = NodeDataExtraKey{:d, Int64}() + + @test !@inferred(hasextra(nodedata, constkey_d_int)) + @inferred(setextra!(nodedata, constkey_d_int, 4)) + @test @inferred(hasextra(nodedata, constkey_d_int)) + @test @inferred(getextra(nodedata, constkey_d_int)) === 4 + end +end + @testitem "factor_nodes" begin import GraphPPL: factor_nodes, is_factor, labels include("model_zoo.jl")