diff --git a/src/model_macro.jl b/src/model_macro.jl index f415afd6..827a4400 100644 --- a/src/model_macro.jl +++ b/src/model_macro.jl @@ -702,7 +702,7 @@ function get_make_node_function(ms_body, ms_args, ms_name) __parent_context__::GraphPPL.Context, __options__::GraphPPL.NodeCreationOptions, ::typeof($ms_name), - __lhs_interface__::GraphPPL.ProxyLabel, + __lhs_interface__::Union{GraphPPL.NodeLabel, GraphPPL.ProxyLabel}, __rhs_interfaces__::NamedTuple, __n_interfaces__::GraphPPL.StaticInt{$(length(ms_args))} ) diff --git a/test/graph_construction_tests.jl b/test/graph_construction_tests.jl index a8983adb..d522d6ee 100644 --- a/test/graph_construction_tests.jl +++ b/test/graph_construction_tests.jl @@ -616,7 +616,7 @@ end return (y = y,) end - variable_nodes(model) do label, nodedata + variable_nodes(model) do label, nodedata properties = getproperties(nodedata) if is_random(properties) # Shouldn't be any anonymous variables here @@ -634,7 +634,50 @@ end @test length(collect(filter(as_node(sum), model))) === 0 @test length(collect(filter(as_variable(:C), model))) === 1 @test length(collect(filter(as_variable(:m), model))) === 1 + end +end + +@testitem "Submodels can be used in the keyword arguments" begin + using Distributions, LinearAlgebra + + import GraphPPL: create_model, getorcreate!, NodeCreationOptions, LazyIndex, variable_nodes, getproperties, is_random, getname + @model function prod_distributions(a, b, c) + a ~ b * c + end + # The test tests if we can write `μ = prod_distributions(b = A, c = x_prev)` + @model function state_transition_with_submodel(y_next, x_next, x_prev, A, B, P, Q) + x_next ~ MvNormal(μ = prod_distributions(b = A, c = x_prev), Σ = Q) + y_next ~ MvNormal(μ = prod_distributions(b = B, c = x_next), Σ = P) end + + @model function multivariate_lgssm_model_with_several_submodels(y, mean0, cov0, A, B, Q, P) + x_prev ~ MvNormal(μ = mean0, Σ = cov0) + for i in eachindex(y) + x[i] ~ state_transition_with_submodel(y_next = y[i], x_prev = x_prev, A = A, B = B, P = P, Q = Q) + x_prev = x[i] + end + end + + ydata = rand(10) + A = rand(3, 3) + B = rand(3, 3) + Q = rand(3, 3) + P = rand(3, 3) + mean0 = rand(3) + cov0 = rand(3, 3) + + model = + create_model(multivariate_lgssm_model_with_several_submodels(mean0 = mean0, cov0 = cov0, A = A, B = B, Q = Q, P = P)) do model, ctx + y = getorcreate!(model, ctx, NodeCreationOptions(kind = :data), :y, LazyIndex(ydata)) + return (y = y,) + end + + @test length(collect(filter(as_node(MvNormal), model))) === 21 + @test length(collect(filter(as_node(prod), model))) === 20 + + @test length(collect(filter(as_variable(:a), model))) === 0 + @test length(collect(filter(as_variable(:b), model))) === 0 + @test length(collect(filter(as_variable(:x), model))) === 10 end \ No newline at end of file