From 14a98527b2c23b3cb289c88d0203222abfd75b69 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Mon, 27 Nov 2023 16:15:50 +0100 Subject: [PATCH] Change saving of neighbors between variable and factor nodes --- src/constraints_engine.jl | 1 - src/graph_engine.jl | 18 ++++++++++++------ test/graph_engine_tests.jl | 15 +++++++-------- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/src/constraints_engine.jl b/src/constraints_engine.jl index 4ebc4883..596fd1cd 100644 --- a/src/constraints_engine.jl +++ b/src/constraints_engine.jl @@ -568,7 +568,6 @@ function materialize_constraints!(model::Model, node_label::NodeLabel, node_data edges = GraphPPL.edges(model, node_label) constraint = Tuple(sort!(collect(constraint_set), by = first)) constraint = map(clusters -> Tuple(getindex.(Ref(edges), clusters)), constraint) - node_data.factorization_constraint = constraint end diff --git a/src/graph_engine.jl b/src/graph_engine.jl index 256499b8..7230da6b 100644 --- a/src/graph_engine.jl +++ b/src/graph_engine.jl @@ -70,7 +70,6 @@ to_symbol(label::NodeLabel) = Symbol(String(label.name) * "_" * string(label.glo Base.show(io::IO, label::NodeLabel) = print(io, label.name, "_", label.global_counter) - struct EdgeLabel name::Symbol index::Union{Int, Nothing} @@ -278,9 +277,17 @@ Graphs.ne(model::Model) = Graphs.ne(model.graph) Graphs.edges(model::Model) = Graphs.edges(model.graph) MetaGraphsNext.label_for(model::Model, node_id::Int) = MetaGraphsNext.label_for(model.graph, node_id) -Graphs.neighbors(model::Model, node::NodeLabel) = map(neighbor -> neighbor[1], model[node].neighbors) +Graphs.neighbors(model::Model, node::NodeLabel) = Graphs.neighbors(model, node, model[node]) +Graphs.neighbors(model::Model, node::NodeLabel, nodedata::FactorNodeData) = map(neighbor -> neighbor[1], nodedata.neighbors) +Graphs.neighbors(model::Model, node::NodeLabel, nodedata::VariableNodeData) = MetaGraphsNext.neighbor_labels(model.graph, node) Graphs.neighbors(model::Model, nodes::AbstractArray{<:NodeLabel}) = Iterators.flatten(map(node -> Graphs.neighbors(model, node), nodes)) -Graphs.edges(model::Model, node::NodeLabel) = map(edge -> edge[2], model[node].neighbors) + +Graphs.edges(model::Model, node::NodeLabel) = Graphs.edges(model, node, model[node]) +Graphs.edges(model::Model, node::NodeLabel, nodedata::FactorNodeData) = map(neighbor -> neighbor[2], nodedata.neighbors) +function Graphs.edges(model::Model, node::NodeLabel, nodedata::VariableNodeData) + return Tuple(model[node, dst] for dst in MetaGraphsNext.neighbor_labels(model.graph, node)) +end +Graphs.edges(model::Model, nodes::AbstractArray{<:NodeLabel}) = Iterators.flatten(map(node -> Graphs.edges(model, node), nodes)) abstract type AbstractModelFilterPredicate end @@ -768,9 +775,8 @@ iterator(interfaces::NamedTuple) = zip(keys(interfaces), values(interfaces)) function add_edge!(model::Model, factor_node_id::NodeLabel, variable_node_id::Union{ProxyLabel, NodeLabel}, interface_name::Symbol; index = nothing) label = EdgeLabel(interface_name, index) - # model.graph[unroll(variable_node_id), factor_node_id] = label model[factor_node_id].neighbors = (model[factor_node_id].neighbors..., (unroll(variable_node_id), label)) - # model[unroll(variable_node_id)].neighbors = (model[unroll(variable_node_id)].neighbors..., (factor_node_id, label)) + model.graph[unroll(variable_node_id), factor_node_id] = label end function add_edge!(model::Model, factor_node_id::NodeLabel, variable_nodes::Union{AbstractArray, Tuple, NamedTuple}, interface_name::Symbol; index = 1) @@ -783,7 +789,7 @@ increase_index(any) = 1 increase_index(x::AbstractArray) = length(x) function add_factorization_constraint!(model::Model, factor_node_id::NodeLabel) - out_degree = length(model[factor_node_id].neighbors) + out_degree = length(model[factor_node_id].neighbors) constraint = BitSetTuple(out_degree) set_factorization_constraint!(model[factor_node_id], constraint) end diff --git a/test/graph_engine_tests.jl b/test/graph_engine_tests.jl index b62d81a0..7cd354b8 100644 --- a/test/graph_engine_tests.jl +++ b/test/graph_engine_tests.jl @@ -321,12 +321,12 @@ end b = NodeLabel(:b, 2) model[a] = VariableNodeData(:a, VariableNodeOptions(), nothing, nothing, nothing, ()) model[b] = VariableNodeData(:b, VariableNodeOptions(), nothing, nothing, nothing, ()) - add_edge!(model, a, b, :edge; index=1) + add_edge!(model, a, b, :edge; index = 1) @test length(edges(model)) == 1 c = NodeLabel(:c, 2) model[NodeLabel(:c, 2)] = VariableNodeData(:b, VariableNodeOptions(), nothing, nothing, nothing, ()) - add_edge!(model, a, c, :edge; index=2) + add_edge!(model, a, c, :edge; index = 2) @test length(edges(model)) == 2 # Test 2: Test getting all edges from a model with a specific node @@ -345,7 +345,7 @@ end b = NodeLabel(:b, 2) model[a] = VariableNodeData(:a, VariableNodeOptions(), nothing, nothing, __context__, ()) model[b] = VariableNodeData(:b, VariableNodeOptions(), nothing, nothing, __context__, ()) - add_edge!(model, a, b, :edge; index=1) + add_edge!(model, a, b, :edge; index = 1) @test collect(neighbors(model, NodeLabel(:a, 1))) == [NodeLabel(:b, 2)] model = create_model() @@ -357,7 +357,7 @@ end model[a[i]] = VariableNodeData(:a, VariableNodeOptions(), i, nothing, __context__, ()) b[i] = NodeLabel(:b, i) model[b[i]] = VariableNodeData(:b, VariableNodeOptions(), i, nothing, __context__, ()) - add_edge!(model, a[i], b[i], :edge; index=i) + add_edge!(model, a[i], b[i], :edge; index = i) end for n in b @test n ∈ neighbors(model, a) @@ -1294,8 +1294,8 @@ end end @testitem "sort_interfaces" begin - import GraphPPL: sort_interfaces - include("model_zoo.jl") + import GraphPPL: sort_interfaces + include("model_zoo.jl") # Test 1: Test that sort_interfaces sorts the interfaces in the correct order @test sort_interfaces(NormalMeanVariance, (μ = 1, σ = 1, out = 1)) == (out = 1, μ = 1, σ = 1) @@ -1308,5 +1308,4 @@ end @test sort_interfaces(NormalMeanPrecision, (τ = 1, μ = 1, out = 1)) == (out = 1, μ = 1, τ = 1) @test_throws ErrorException sort_interfaces(NormalMeanVariance, (σ = 1, μ = 1, τ = 1)) - -end \ No newline at end of file +end