diff --git a/src/constraints_engine.jl b/src/constraints_engine.jl index 5297794d..128a1628 100644 --- a/src/constraints_engine.jl +++ b/src/constraints_engine.jl @@ -690,14 +690,7 @@ function is_decoupled(var_1::VariableNodeData, var_2::VariableNodeData, constrai linkvar_1 = getlink(var_1) linkvar_2 = getlink(var_2) - if !isnothing(linkvar_1) && !isnothing(linkvar_2) - error( - """ - Cannot resolve the factorization constraint $(constraint) for linked for anonymous variables anon_1 and anon_2 connected to variables $(join(linkvar_1, ',')) and $(join(linkvar_2, ',')) respectively. - As a workaround specify the name and the factorization constraint for the anonymous variables explicitly. - """ - ) - elseif !isnothing(linkvar_1) + if !isnothing(linkvar_1) return is_decoupled_one_linked(linkvar_1, var_2, constraint) elseif !isnothing(linkvar_2) return is_decoupled_one_linked(linkvar_2, var_1, constraint) diff --git a/src/graph_engine.jl b/src/graph_engine.jl index c6717271..98054af4 100644 --- a/src/graph_engine.jl +++ b/src/graph_engine.jl @@ -290,7 +290,8 @@ Graphs.ne(model::Model) = Graphs.ne(model.graph) Graphs.edges(model::Model) = Graphs.edges(model.graph) Graphs.neighbors(model::Model, node::NodeLabel) = Graphs.neighbors(model, node, model[node]) -Graphs.neighbors(model::Model, node::NodeLabel, nodedata::FactorNodeData) = map(neighbor -> neighbor[1], nodedata.neighbors) +Graphs.neighbors(model::Model, node::NodeLabel, nodedata::FactorNodeData) = Graphs.neighbors(model[node]) +Graphs.neighbors(nodedata::FactorNodeData) = map(neighbor -> neighbor[1], nodedata.neighbors) Graphs.neighbors(model::Model, node::NodeLabel, nodedata::VariableNodeData) = MetaGraphsNext.neighbor_labels(model.graph, node) Graphs.neighbors(model::Model, nodes::AbstractArray{<:NodeLabel}) = Iterators.flatten(map(node -> Graphs.neighbors(model, node), nodes)) diff --git a/test/constraints_engine_tests.jl b/test/constraints_engine_tests.jl index 3dac65b5..0dabcb26 100644 --- a/test/constraints_engine_tests.jl +++ b/test/constraints_engine_tests.jl @@ -761,13 +761,13 @@ end @test GraphPPL.is_applicable(neighbors, constraint) # This shouldn't throw and resolve because both anonymous variables are 1-to-1 and referenced by constraint. - @test_broken GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1, 2, 3], [1, 2], [1, 3]]) + @test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1, 2, 3], [1, 2], [1, 3]]) end # Test ResolvedFactorizationConstraints over ambiguous anonymouys variables model = create_terminated_model(node_with_ambiguous_anonymous) context = GraphPPL.getcontext(model) - normal_node = context[NormalMeanVariance, 6] + normal_node = last(filter(GraphPPL.as_node(NormalMeanVariance), model)) neighbors = model[GraphPPL.neighbors(model, normal_node)] let constraint = ResolvedFactorizationConstraint( ResolvedConstraintLHS((ResolvedIndexedVariable(:y, nothing, context),),), @@ -776,13 +776,31 @@ end @test GraphPPL.is_applicable(neighbors, constraint) # This test should throw since we cannot resolve the constraint - @test_broken ( - try - GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) - catch e - e - end - ) isa Exception + @test_throws ErrorException GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) + end + + # Test ResolvedFactorizationConstraint with a Mixture node + model = create_terminated_model(mixture) + context = GraphPPL.getcontext(model) + mixture_node = first(filter(GraphPPL.as_node(Mixture), model)) + neighbors = model[GraphPPL.neighbors(model, mixture_node)] + let constraint = ResolvedFactorizationConstraint( + ResolvedConstraintLHS(( + ResolvedIndexedVariable(:m1, nothing, context), + ResolvedIndexedVariable(:m2, nothing, context), + ResolvedIndexedVariable(:m3, nothing, context), + ResolvedIndexedVariable(:m4, nothing, context) + ),), + ( + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:m1, nothing, context),)), + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:m2, nothing, context),)), + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:m3, nothing, context),)), + ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:m4, nothing, context),)) + ) + ) + @test GraphPPL.is_applicable(neighbors, constraint) + @test GraphPPL.convert_to_bitsets(model, mixture_node, neighbors, constraint) == + BitSetTuple([collect(1:9), [1, 2, 6, 7, 8, 9], [1, 3, 6, 7, 8, 9], [1, 4, 6, 7, 8, 9], [1, 5, 6, 7, 8, 9], collect(1:9), collect(1:9), collect(1:9), collect(1:9)]) end end diff --git a/test/integration_tests.jl b/test/integration_tests.jl index 95edf9a9..89390b1e 100644 --- a/test/integration_tests.jl +++ b/test/integration_tests.jl @@ -120,7 +120,7 @@ end end end -@testitem "simple @model + structured @constraints + anonymous variable linked through a deterministic relation" begin +@testitem "simple @model + structured @constraints + anonymous variable linked through a deterministic relation with constants/datavars" begin using Distributions, LinearAlgebra using GraphPPL: create_model, getcontext, getorcreate!, add_terminated_submodel!, apply!, as_node, factorization_constraint, VariableNodeOptions diff --git a/test/model_zoo.jl b/test/model_zoo.jl index f814e7fe..a38c8dab 100644 --- a/test/model_zoo.jl +++ b/test/model_zoo.jl @@ -50,6 +50,12 @@ function create_terminated_model(fform) return __model__ end +struct Mixture end + +GraphPPL.interfaces(::Type{Mixture}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, :m, :τ)) + +GraphPPL.NodeBehaviour(::Type{Mixture}) = GraphPPL.Stochastic() + @model function simple_model() x ~ Normal(0, 1) y ~ Gamma(1, 1) @@ -117,8 +123,8 @@ end x[1] ~ Normal(0, 1) y[1] ~ Normal(0, 1) for i in 2:10 - y[i] ~ Normal(0, 1) - x[i] ~ Normal(y[i - 1] + y[i], 1) + x[i] ~ Normal(x[i - 1], 1) + y[i] ~ Normal(x[i] + y[i - 1], 1) end end @@ -225,6 +231,18 @@ end end end +@model function mixture() + m1 ~ Normal(0, 1) + m2 ~ Normal(0, 1) + m3 ~ Normal(0, 1) + m4 ~ Normal(0, 1) + t1 ~ Normal(0, 1) + t2 ~ Normal(0, 1) + t3 ~ Normal(0, 1) + t4 ~ Normal(0, 1) + y ~ Mixture(m = [m1, m2, m3, m4], τ = [t1, t2, t3, t4]) +end + GraphPPL.default_constraints(::typeof(model_with_default_constraints)) = @constraints( begin q(a, d) = q(a)q(d)