Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch constraints over anonymous variables bug #150

Merged
merged 3 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions src/constraints_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -690,14 +690,7 @@ function is_decoupled(var_1::VariableNodeData, var_2::VariableNodeData, constrai
linkvar_1 = getlink(var_1)
linkvar_2 = getlink(var_2)

if !isnothing(linkvar_1) && !isnothing(linkvar_2)
error(
"""
Cannot resolve the factorization constraint $(constraint) for linked for anonymous variables anon_1 and anon_2 connected to variables $(join(linkvar_1, ',')) and $(join(linkvar_2, ',')) respectively.
As a workaround specify the name and the factorization constraint for the anonymous variables explicitly.
"""
)
elseif !isnothing(linkvar_1)
if !isnothing(linkvar_1)
return is_decoupled_one_linked(linkvar_1, var_2, constraint)
elseif !isnothing(linkvar_2)
return is_decoupled_one_linked(linkvar_2, var_1, constraint)
Expand Down
3 changes: 2 additions & 1 deletion src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ Graphs.ne(model::Model) = Graphs.ne(model.graph)
Graphs.edges(model::Model) = Graphs.edges(model.graph)

Graphs.neighbors(model::Model, node::NodeLabel) = Graphs.neighbors(model, node, model[node])
Graphs.neighbors(model::Model, node::NodeLabel, nodedata::FactorNodeData) = map(neighbor -> neighbor[1], nodedata.neighbors)
Graphs.neighbors(model::Model, node::NodeLabel, nodedata::FactorNodeData) = Graphs.neighbors(model[node])
Graphs.neighbors(nodedata::FactorNodeData) = map(neighbor -> neighbor[1], nodedata.neighbors)
Graphs.neighbors(model::Model, node::NodeLabel, nodedata::VariableNodeData) = MetaGraphsNext.neighbor_labels(model.graph, node)
Graphs.neighbors(model::Model, nodes::AbstractArray{<:NodeLabel}) = Iterators.flatten(map(node -> Graphs.neighbors(model, node), nodes))

Expand Down
36 changes: 27 additions & 9 deletions test/constraints_engine_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -761,13 +761,13 @@ end
@test GraphPPL.is_applicable(neighbors, constraint)

# This shouldn't throw and resolve because both anonymous variables are 1-to-1 and referenced by constraint.
@test_broken GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1, 2, 3], [1, 2], [1, 3]])
@test GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint) == BitSetTuple([[1, 2, 3], [1, 2], [1, 3]])
end

# Test ResolvedFactorizationConstraints over ambiguous anonymouys variables
model = create_terminated_model(node_with_ambiguous_anonymous)
context = GraphPPL.getcontext(model)
normal_node = context[NormalMeanVariance, 6]
normal_node = last(filter(GraphPPL.as_node(NormalMeanVariance), model))
neighbors = model[GraphPPL.neighbors(model, normal_node)]
let constraint = ResolvedFactorizationConstraint(
ResolvedConstraintLHS((ResolvedIndexedVariable(:y, nothing, context),),),
Expand All @@ -776,13 +776,31 @@ end
@test GraphPPL.is_applicable(neighbors, constraint)

# This test should throw since we cannot resolve the constraint
@test_broken (
try
GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)
catch e
e
end
) isa Exception
@test_throws ErrorException GraphPPL.convert_to_bitsets(model, normal_node, neighbors, constraint)
end

# Test ResolvedFactorizationConstraint with a Mixture node
model = create_terminated_model(mixture)
context = GraphPPL.getcontext(model)
mixture_node = first(filter(GraphPPL.as_node(Mixture), model))
neighbors = model[GraphPPL.neighbors(model, mixture_node)]
let constraint = ResolvedFactorizationConstraint(
ResolvedConstraintLHS((
ResolvedIndexedVariable(:m1, nothing, context),
ResolvedIndexedVariable(:m2, nothing, context),
ResolvedIndexedVariable(:m3, nothing, context),
ResolvedIndexedVariable(:m4, nothing, context)
),),
(
ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:m1, nothing, context),)),
ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:m2, nothing, context),)),
ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:m3, nothing, context),)),
ResolvedFactorizationConstraintEntry((ResolvedIndexedVariable(:m4, nothing, context),))
)
)
@test GraphPPL.is_applicable(neighbors, constraint)
@test GraphPPL.convert_to_bitsets(model, mixture_node, neighbors, constraint) ==
BitSetTuple([collect(1:9), [1, 2, 6, 7, 8, 9], [1, 3, 6, 7, 8, 9], [1, 4, 6, 7, 8, 9], [1, 5, 6, 7, 8, 9], collect(1:9), collect(1:9), collect(1:9), collect(1:9)])
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/integration_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ end
end
end

@testitem "simple @model + structured @constraints + anonymous variable linked through a deterministic relation" begin
@testitem "simple @model + structured @constraints + anonymous variable linked through a deterministic relation with constants/datavars" begin
using Distributions, LinearAlgebra
using GraphPPL: create_model, getcontext, getorcreate!, add_terminated_submodel!, apply!, as_node, factorization_constraint, VariableNodeOptions

Expand Down
22 changes: 20 additions & 2 deletions test/model_zoo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ function create_terminated_model(fform)
return __model__
end

struct Mixture end

GraphPPL.interfaces(::Type{Mixture}, ::StaticInt{3}) = GraphPPL.StaticInterfaces((:out, :m, :τ))

GraphPPL.NodeBehaviour(::Type{Mixture}) = GraphPPL.Stochastic()

@model function simple_model()
x ~ Normal(0, 1)
y ~ Gamma(1, 1)
Expand Down Expand Up @@ -117,8 +123,8 @@ end
x[1] ~ Normal(0, 1)
y[1] ~ Normal(0, 1)
for i in 2:10
y[i] ~ Normal(0, 1)
x[i] ~ Normal(y[i - 1] + y[i], 1)
x[i] ~ Normal(x[i - 1], 1)
y[i] ~ Normal(x[i] + y[i - 1], 1)
end
end

Expand Down Expand Up @@ -225,6 +231,18 @@ end
end
end

@model function mixture()
m1 ~ Normal(0, 1)
m2 ~ Normal(0, 1)
m3 ~ Normal(0, 1)
m4 ~ Normal(0, 1)
t1 ~ Normal(0, 1)
t2 ~ Normal(0, 1)
t3 ~ Normal(0, 1)
t4 ~ Normal(0, 1)
y ~ Mixture(m = [m1, m2, m3, m4], τ = [t1, t2, t3, t4])
end

GraphPPL.default_constraints(::typeof(model_with_default_constraints)) = @constraints(
begin
q(a, d) = q(a)q(d)
Expand Down
Loading