Skip to content

Commit

Permalink
Merge pull request #185 from ReactiveBayes/dev-engine-ignores-constan…
Browse files Browse the repository at this point in the history
…t-nodes

Fix invalid deterministic fform evaluation
  • Loading branch information
bvdmitri authored Mar 19, 2024
2 parents 37c54e3 + 21c7478 commit 20a9458
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 22 deletions.
30 changes: 12 additions & 18 deletions src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,6 @@ check_variate_compatability(node::ResizableArray{NodeLabel, V, N}, index::NTuple
check_variate_compatability(node::ResizableArray{NodeLabel, V, N}, index::Vararg{Int, M}) where {V, N, M} =
error("Index of length $(length(index)) not possible for $N-dimensional vector of random variables")


check_variate_compatability(node::ResizableArray{NodeLabel, V, N}, index::Nothing) where {V, N} =
error("Cannot call vector of random variables on the left-hand-side by an unindexed statement")

Expand Down Expand Up @@ -1463,18 +1462,11 @@ make_node!(
) where {F} =
make_node!(contains_nodelabel(rhs_interfaces), atomic, deterministic, model, ctx, options, fform, lhs_interface, rhs_interfaces)

# If the node should not be materialized (if it's Atomic, Deterministic and contains no NodeLabel objects), we return the function evaluated at the interfaces
make_node!(
::False,
::Atomic,
::Deterministic,
model::Model,
ctx::Context,
options::NodeCreationOptions,
fform::F,
lhs_interface,
rhs_interfaces::Tuple
) where {F} = (nothing, fform(rhs_interfaces...))
# If the node should not be materialized (if it's Atomic, Deterministic and contains no NodeLabel objects), we return the `fform` evaluated at the interfaces
# This works only if the `lhs_interface` is `AnonymousVariable` (or the corresponding `ProxyLabel` with `AnonymousVariable` as the proxied variable)
__evaluate_fform(fform::F, args::Tuple) where {F} = fform(args...)
__evaluate_fform(fform::F, args::NamedTuple) where {F} = fform(; args...)
__evaluate_fform(fform::F, args::MixedArguments) where {F} = fform(args.args...; args.kwargs...)

make_node!(
::False,
Expand All @@ -1484,10 +1476,12 @@ make_node!(
ctx::Context,
options::NodeCreationOptions,
fform::F,
lhs_interface,
rhs_interfaces::NamedTuple
) where {F} = (nothing, fform(; rhs_interfaces...))
lhs_interface::Union{AnonymousVariable, ProxyLabel{<:T, <:AnonymousVariable} where {T}},
rhs_interfaces::Union{Tuple, NamedTuple, MixedArguments}
) where {F} = (nothing, __evaluate_fform(fform, rhs_interfaces))

# In case if the `lhs_interface` is something else we throw an error saying that `fform` cannot be instantiated since
# arguments are not stochastic and the `fform` is not stochastic either, thus the usage of `~` is invalid
make_node!(
::False,
::Atomic,
Expand All @@ -1497,8 +1491,8 @@ make_node!(
options::NodeCreationOptions,
fform::F,
lhs_interface,
rhs_interfaces::MixedArguments
) where {F} = (nothing, fform(rhs_interfaces.args...; rhs_interfaces.kwargs...))
rhs_interfaces::Union{Tuple, NamedTuple, MixedArguments}
) where {F} = error("`$(fform)` cannot be used as a factor node. Both the arguments and the node are not stochastic.")

# If a node is Stochastic, we always materialize.
make_node!(
Expand Down
19 changes: 19 additions & 0 deletions test/graph_construction_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -752,3 +752,22 @@ end
@test length(collect(filter(as_node(Bernoulli), model))) === 10
@test length(collect(filter(as_node(prior), model))) === 1
end

@testitem "Model creation should throw if a `~` using with a constant on RHS" begin
using Distributions
import GraphPPL: create_model, getorcreate!, NodeCreationOptions, LazyIndex

include("testutils.jl")

@model function broken_beta_bernoulli(y)
# This should throw an error since `Matrix` is not defined as a proper node
θ ~ Matrix([1.0 0.0; 0.0 1.0])
for i in eachindex(y)
y[i] ~ Bernoulli(θ)
end
end

@test_throws "`Matrix` cannot be used as a factor node. Both the arguments and the node are not stochastic." create_model(broken_beta_bernoulli()) do model, context
return (; y = getorcreate!(model, context, NodeCreationOptions(kind = :data), :y, LazyIndex(rand(10))))
end
end
14 changes: 10 additions & 4 deletions test/graph_engine_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1838,6 +1838,7 @@ end
make_node!,
create_model,
getorcreate!,
AnonymousVariable,
ProxyLabel,
getname,
label_for,
Expand All @@ -1855,10 +1856,15 @@ end
model = create_test_model()
ctx = getcontext(model)
options = NodeCreationOptions()
x = getorcreate!(model, ctx, :x, nothing)
x = AnonymousVariable(model, ctx)
@test make_node!(model, ctx, options, +, x, (1, 1)) == (nothing, 2)
@test make_node!(model, ctx, options, sin, x, (0,)) == (nothing, 0)
@test nv(model) == 0

x = ProxyLabel(:proxy, nothing, AnonymousVariable(model, ctx))
@test make_node!(model, ctx, options, +, x, (1, 1)) == (nothing, 2)
@test make_node!(model, ctx, options, sin, x, (0,)) == (nothing, 0)
@test nv(model) == 1
@test nv(model) == 0

# Test 2: Stochastic atomic call returns a new node id
node_id, _ = make_node!(model, ctx, options, Normal, x, (μ = 0, σ = 1))
Expand Down Expand Up @@ -1947,7 +1953,7 @@ end
model = create_test_model()
ctx = getcontext(model)
options = NodeCreationOptions()
out = getorcreate!(model, ctx, :out, nothing)
out = AnonymousVariable(model, ctx)
@test make_node!(model, ctx, options, abc, out, (a = 1, b = 2)) == (nothing, 3)

# Test 11: Deterministic node with mixed arguments
Expand All @@ -1957,7 +1963,7 @@ end
model = create_test_model()
ctx = getcontext(model)
options = NodeCreationOptions()
out = getorcreate!(model, ctx, :out, nothing)
out = AnonymousVariable(model, ctx)
@test make_node!(model, ctx, options, abc, out, MixedArguments((2,), (b = 2,))) == (nothing, 4)

# Test 12: Deterministic node with mixed arguments that has to be materialized should throw error
Expand Down

0 comments on commit 20a9458

Please sign in to comment.