Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add compile time keys for the extra properties of graph nodes #161

Merged
merged 1 commit into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,21 @@ hasextra(node::NodeData, key::Symbol) = haskey(node.extra, key)
getextra(node::NodeData, key::Symbol) = getindex(node.extra, key)
setextra!(node::NodeData, key::Symbol, value) = insert!(node.extra, key, value)

"""
A compile time key to access the `extra` properties of the `NodeData` structure.
"""
struct NodeDataExtraKey{K, T} end

function hasextra(node::NodeData, key::NodeDataExtraKey{K}) where K
return haskey(node.extra, K)
end
function getextra(node::NodeData, key::NodeDataExtraKey{K, T})::T where {K, T}
return getindex(node.extra, K)::T
end
function setextra!(node::NodeData, key::NodeDataExtraKey{K}, value::T) where {K, T}
return insert!(node.extra, K, value)
end

is_factor(node::NodeData) = is_factor(getproperties(node))
is_variable(node::NodeData) = is_variable(getproperties(node))

Expand Down
6 changes: 4 additions & 2 deletions src/plugins/meta/meta_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ function apply_meta!(model::Model, context::Context, meta::MetaObject{S, T} wher
end
end

const MetaExtraKey = NodeDataExtraKey{:meta, Any}()

function save_meta!(model::Model, node::NodeLabel, meta::MetaObject{S, T} where {S, T <: NamedTuple})
data = getmetainfo(meta)
if !haskey(data, :meta)
Expand All @@ -155,7 +157,7 @@ end

function save_meta!(model::Model, node::NodeLabel, meta::MetaObject{S, T} where {S, T})
nodedata = model[node]
if !hasextra(nodedata, :meta)
setextra!(nodedata, :meta, getmetainfo(meta))
if !hasextra(nodedata, MetaExtraKey)
setextra!(nodedata, MetaExtraKey, getmetainfo(meta))
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,8 @@ function materialize_constraints!(model::Model, node_label::NodeLabel, node_data
return materialize_constraints!(model, node_label, node_data, getproperties(node_data))
end

const VariationalConstraintsFactorizationIndicesKey = NodeDataExtraKey{:factorization_constraint_indices, Tuple}()

function materialize_constraints!(model::Model, node_label::NodeLabel, node_data::NodeData, ::FactorNodeProperties)
constraint_bitset = getextra(node_data, :factorization_constraint_bitset)
num_neighbors = length(constraint_bitset)
Expand All @@ -572,7 +574,7 @@ function materialize_constraints!(model::Model, node_label::NodeLabel, node_data
)
end
rows = Tuple(map(row -> filter(!iszero, map(elem -> elem[2] == 1 ? elem[1] : 0, enumerate(row))), constraint_set))
setextra!(node_data, :factorization_constraint_indices, rows)
setextra!(node_data, VariationalConstraintsFactorizationIndicesKey, rows)
end

function is_valid_partition(contents)
Expand Down
38 changes: 38 additions & 0 deletions test/graph_engine_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,44 @@ end
end
end

@testitem "NodeData extra properties" begin
import GraphPPL: create_model, getcontext, NodeData, FactorNodeProperties, VariableNodeProperties, getproperties, setextra!, getextra, hasextra, NodeDataExtraKey

model = create_model()
context = getcontext(model)

@testset for properties in (FactorNodeProperties(fform = String), VariableNodeProperties(name = :x, index = 1))
nodedata = NodeData(context, properties)

@test !hasextra(nodedata, :a)
setextra!(nodedata, :a, 1)
@test hasextra(nodedata, :a)
@test getextra(nodedata, :a) === 1

# In the current implementation it is not possible to update extra properties
@test_throws Exception setextra!(nodedata, :a, 2)

@test !hasextra(nodedata, :b)
setextra!(nodedata, :b, 2)
@test hasextra(nodedata, :b)
@test getextra(nodedata, :b) === 2

constkey_c_float = NodeDataExtraKey{:c, Float64}()

@test !@inferred(hasextra(nodedata, constkey_c_float))
@inferred(setextra!(nodedata, constkey_c_float, 3.0))
@test @inferred(hasextra(nodedata, constkey_c_float))
@test @inferred(getextra(nodedata, constkey_c_float)) === 3.0

constkey_d_int = NodeDataExtraKey{:d, Int64}()

@test !@inferred(hasextra(nodedata, constkey_d_int))
@inferred(setextra!(nodedata, constkey_d_int, 4))
@test @inferred(hasextra(nodedata, constkey_d_int))
@test @inferred(getextra(nodedata, constkey_d_int)) === 4
end
end

@testitem "factor_nodes" begin
import GraphPPL: factor_nodes, is_factor, labels
include("model_zoo.jl")
Expand Down
Loading