Skip to content

Commit

Permalink
Fix submodels as keyword arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed Mar 12, 2024
1 parent 42a1ba8 commit 39a68ce
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/model_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ function get_make_node_function(ms_body, ms_args, ms_name)
__parent_context__::GraphPPL.Context,
__options__::GraphPPL.NodeCreationOptions,
::typeof($ms_name),
__lhs_interface__::GraphPPL.ProxyLabel,
__lhs_interface__::Union{GraphPPL.NodeLabel, GraphPPL.ProxyLabel},
__rhs_interfaces__::NamedTuple,
__n_interfaces__::GraphPPL.StaticInt{$(length(ms_args))}
)
Expand Down
45 changes: 44 additions & 1 deletion test/graph_construction_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ end
return (y = y,)
end

variable_nodes(model) do label, nodedata
variable_nodes(model) do label, nodedata
properties = getproperties(nodedata)
if is_random(properties)
# Shouldn't be any anonymous variables here
Expand All @@ -634,7 +634,50 @@ end
@test length(collect(filter(as_node(sum), model))) === 0
@test length(collect(filter(as_variable(:C), model))) === 1
@test length(collect(filter(as_variable(:m), model))) === 1
end
end

@testitem "Submodels can be used in the keyword arguments" begin
using Distributions, LinearAlgebra

import GraphPPL: create_model, getorcreate!, NodeCreationOptions, LazyIndex, variable_nodes, getproperties, is_random, getname

@model function prod_distributions(a, b, c)
a ~ b * c
end

# The test tests if we can write `μ = prod_distributions(b = A, c = x_prev)`
@model function state_transition_with_submodel(y_next, x_next, x_prev, A, B, P, Q)
x_next ~ MvNormal= prod_distributions(b = A, c = x_prev), Σ = Q)
y_next ~ MvNormal= prod_distributions(b = B, c = x_next), Σ = P)
end

@model function multivariate_lgssm_model_with_several_submodels(y, mean0, cov0, A, B, Q, P)
x_prev ~ MvNormal= mean0, Σ = cov0)
for i in eachindex(y)
x[i] ~ state_transition_with_submodel(y_next = y[i], x_prev = x_prev, A = A, B = B, P = P, Q = Q)
x_prev = x[i]
end
end

ydata = rand(10)
A = rand(3, 3)
B = rand(3, 3)
Q = rand(3, 3)
P = rand(3, 3)
mean0 = rand(3)
cov0 = rand(3, 3)

model =
create_model(multivariate_lgssm_model_with_several_submodels(mean0 = mean0, cov0 = cov0, A = A, B = B, Q = Q, P = P)) do model, ctx
y = getorcreate!(model, ctx, NodeCreationOptions(kind = :data), :y, LazyIndex(ydata))
return (y = y,)
end

@test length(collect(filter(as_node(MvNormal), model))) === 21
@test length(collect(filter(as_node(prod), model))) === 20

@test length(collect(filter(as_variable(:a), model))) === 0
@test length(collect(filter(as_variable(:b), model))) === 0
@test length(collect(filter(as_variable(:x), model))) === 10
end

0 comments on commit 39a68ce

Please sign in to comment.