Skip to content

Commit

Permalink
Merge branch 'dev-4.0.0' into 181-passing-a-vector-of-rvs-to-submodel…
Browse files Browse the repository at this point in the history
…-doesnt-work
  • Loading branch information
wouterwln authored Mar 19, 2024
2 parents 503e943 + 20a9458 commit e4052db
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 22 deletions.
29 changes: 12 additions & 17 deletions src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1469,18 +1469,11 @@ make_node!(
) where {F} =
make_node!(contains_nodelabel(rhs_interfaces), atomic, deterministic, model, ctx, options, fform, lhs_interface, rhs_interfaces)

# If the node should not be materialized (if it's Atomic, Deterministic and contains no NodeLabel objects), we return the function evaluated at the interfaces
make_node!(
::False,
::Atomic,
::Deterministic,
model::Model,
ctx::Context,
options::NodeCreationOptions,
fform::F,
lhs_interface,
rhs_interfaces::Tuple
) where {F} = (nothing, fform(rhs_interfaces...))
# If the node should not be materialized (if it's Atomic, Deterministic and contains no NodeLabel objects), we return the `fform` evaluated at the interfaces
# This works only if the `lhs_interface` is `AnonymousVariable` (or the corresponding `ProxyLabel` with `AnonymousVariable` as the proxied variable)
__evaluate_fform(fform::F, args::Tuple) where {F} = fform(args...)
__evaluate_fform(fform::F, args::NamedTuple) where {F} = fform(; args...)
__evaluate_fform(fform::F, args::MixedArguments) where {F} = fform(args.args...; args.kwargs...)

make_node!(
::False,
Expand All @@ -1490,10 +1483,12 @@ make_node!(
ctx::Context,
options::NodeCreationOptions,
fform::F,
lhs_interface,
rhs_interfaces::NamedTuple
) where {F} = (nothing, fform(; rhs_interfaces...))
lhs_interface::Union{AnonymousVariable, ProxyLabel{<:T, <:AnonymousVariable} where {T}},
rhs_interfaces::Union{Tuple, NamedTuple, MixedArguments}
) where {F} = (nothing, __evaluate_fform(fform, rhs_interfaces))

# In case if the `lhs_interface` is something else we throw an error saying that `fform` cannot be instantiated since
# arguments are not stochastic and the `fform` is not stochastic either, thus the usage of `~` is invalid
make_node!(
::False,
::Atomic,
Expand All @@ -1503,8 +1498,8 @@ make_node!(
options::NodeCreationOptions,
fform::F,
lhs_interface,
rhs_interfaces::MixedArguments
) where {F} = (nothing, fform(rhs_interfaces.args...; rhs_interfaces.kwargs...))
rhs_interfaces::Union{Tuple, NamedTuple, MixedArguments}
) where {F} = error("`$(fform)` cannot be used as a factor node. Both the arguments and the node are not stochastic.")

# If a node is Stochastic, we always materialize.
make_node!(
Expand Down
23 changes: 22 additions & 1 deletion test/graph_construction_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,7 @@ end
@test length(collect(filter(as_node(prior), model))) === 1
end


@testitem "Model that passes a slice to child model" begin
using GraphPPL
include("testutils.jl")
Expand Down Expand Up @@ -820,4 +821,24 @@ end
end

@test_throws GraphPPL.NotImplementedError local model = GraphPPL.create_model(mixed_m())
end
end

@testitem "Model creation should throw if a `~` using with a constant on RHS" begin
using Distributions
import GraphPPL: create_model, getorcreate!, NodeCreationOptions, LazyIndex

include("testutils.jl")

@model function broken_beta_bernoulli(y)
# This should throw an error since `Matrix` is not defined as a proper node
θ ~ Matrix([1.0 0.0; 0.0 1.0])
for i in eachindex(y)
y[i] ~ Bernoulli(θ)
end
end

@test_throws "`Matrix` cannot be used as a factor node. Both the arguments and the node are not stochastic." create_model(broken_beta_bernoulli()) do model, context
return (; y = getorcreate!(model, context, NodeCreationOptions(kind = :data), :y, LazyIndex(rand(10))))
end
end

14 changes: 10 additions & 4 deletions test/graph_engine_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1838,6 +1838,7 @@ end
make_node!,
create_model,
getorcreate!,
AnonymousVariable,
ProxyLabel,
getname,
label_for,
Expand All @@ -1855,10 +1856,15 @@ end
model = create_test_model()
ctx = getcontext(model)
options = NodeCreationOptions()
x = getorcreate!(model, ctx, :x, nothing)
x = AnonymousVariable(model, ctx)
@test make_node!(model, ctx, options, +, x, (1, 1)) == (nothing, 2)
@test make_node!(model, ctx, options, sin, x, (0,)) == (nothing, 0)
@test nv(model) == 0

x = ProxyLabel(:proxy, nothing, AnonymousVariable(model, ctx))
@test make_node!(model, ctx, options, +, x, (1, 1)) == (nothing, 2)
@test make_node!(model, ctx, options, sin, x, (0,)) == (nothing, 0)
@test nv(model) == 1
@test nv(model) == 0

# Test 2: Stochastic atomic call returns a new node id
node_id, _ = make_node!(model, ctx, options, Normal, x, (μ = 0, σ = 1))
Expand Down Expand Up @@ -1947,7 +1953,7 @@ end
model = create_test_model()
ctx = getcontext(model)
options = NodeCreationOptions()
out = getorcreate!(model, ctx, :out, nothing)
out = AnonymousVariable(model, ctx)
@test make_node!(model, ctx, options, abc, out, (a = 1, b = 2)) == (nothing, 3)

# Test 11: Deterministic node with mixed arguments
Expand All @@ -1957,7 +1963,7 @@ end
model = create_test_model()
ctx = getcontext(model)
options = NodeCreationOptions()
out = getorcreate!(model, ctx, :out, nothing)
out = AnonymousVariable(model, ctx)
@test make_node!(model, ctx, options, abc, out, MixedArguments((2,), (b = 2,))) == (nothing, 4)

# Test 12: Deterministic node with mixed arguments that has to be materialized should throw error
Expand Down

0 comments on commit e4052db

Please sign in to comment.