From 73077827f3b0f2afb4ee419d0b76702218f80836 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Fri, 1 Mar 2024 15:01:30 +0100 Subject: [PATCH 1/5] Reimplement node macro --- src/nodes/nodes.jl | 211 +++++++-------------------------------------- test/node_tests.jl | 107 ++++------------------- 2 files changed, 51 insertions(+), 267 deletions(-) diff --git a/src/nodes/nodes.jl b/src/nodes/nodes.jl index 8fcc78b5c..0c9478c06 100644 --- a/src/nodes/nodes.jl +++ b/src/nodes/nodes.jl @@ -424,195 +424,48 @@ include("dependencies.jl") ## macro helpers -import .MacroHelpers - -# Are still needed for the `@node` macro -function make_node end -function interface_get_index end -function interface_get_name end - -""" - @node(fformtype, sdtype, interfaces_list) - - -`@node` macro creates a node for a `fformtype` type object. To obtain a list of available nodes use `?make_node`. - -# Arguments - -- `fformtype`: Either an existing type identifier, e.g. `Normal` or a function type identifier, e.g. `typeof(+)` -- `sdtype`: Either `Stochastic` or `Deterministic`. Defines the type of the functional relationship -- `interfaces_list`: Defines a fixed list of edges of a factor node, by convention the first element should be `out`. Example: `[ out, mean, variance ]` - -Note: `interfaces_list` must not include names that contain `_` symbol in them, as it is reserved to identify joint posteriors around the node object. - -# Examples -```julia - -struct MyNormalDistribution - mean :: Float64 - var :: Float64 -end - -@node MyNormalDistribution Stochastic [ out, mean, var ] -``` - -```julia - -@node typeof(+) Deterministic [ out, in1, in2 ] -``` - -# List of available nodes - -See `?make_node`. - -See also: [`make_node`](@ref), [`Stochastic`](@ref), [`Deterministic`](@ref) -""" -macro node(fformtype, sdtype, interfaces_list) - fbottomtype = MacroHelpers.bottom_type(fformtype) - fuppertype = MacroHelpers.upper_type(fformtype) +## macro helpers - @assert sdtype ∈ [:Stochastic, :Deterministic] "Invalid sdtype $(sdtype). Can be either Stochastic or Deterministic." +import .MacroHelpers - @capture(interfaces_list, [interfaces_args__]) || error("Invalid interfaces specification.") +function correct_interfaces end - interfaces = map(interfaces_args) do arg - if @capture(arg, name_Symbol) - return (name, []) - elseif @capture(arg, (name_Symbol, aliases = [aliases__])) - @assert all(a -> a isa Symbol && !isequal(a, name), aliases) - return (name, aliases) - else - error("Interface specification should have a 'name' or (name, aliases = [ alias1, alias2,... ]) signature.") - end +macro node(node_fform, node_type, node_interfaces, interface_aliases) + # Assert that the node type is either Stochastic or Deterministic, and that all interfaces are symbols + @assert node_type ∈ [:Stochastic, :Deterministic] + @assert length(node_interfaces.args) > 0 + for interface in node_interfaces.args + @assert isa(interface, Symbol) end - @assert length(interfaces) !== 0 "Node should have at least one interface." - - names = map(d -> d[1], interfaces) - aliases = map(d -> d[2], interfaces) - - foreach(names) do name - @assert !occursin('_', string(name)) "Node interfaces names (and aliases) must not contain `_` symbol in them, found in $(name)." + # Determine whether we should dispatch on `typeof($fform)` or `Type{$node_fform}` + if @capture(node_fform, typeof(fform_)) + dispatch_type = quote typeof($fform) end + else + dispatch_type = quote Type{$node_fform} end end - foreach(Iterators.flatten(aliases)) do alias - @assert !occursin('_', string(alias)) "Node interfaces names (and aliases) must not contain `_` symbol in them, found in $(alias)." + # Define the necessary function types + result = quote + ReactiveMP.as_node_functional_form(::$dispatch_type) = ReactiveMP.ValidNodeFunctionalForm() + ReactiveMP.sdtype(::$dispatch_type) = (ReactiveMP.$node_type)() end - names_quoted_tuple = Expr(:tuple, map(name -> Expr(:quote, name), names)...) - names_indices = Expr(:tuple, map(i -> i, 1:length(names))...) - names_splitted_indices = Expr(:tuple, map(i -> Expr(:tuple, i), 1:length(names))...) - names_indexed = Expr(:tuple, map(name -> Expr(:call, :(ReactiveMP.indexed_name), name), names)...) - - interface_names = map(name -> :(ReactiveMP.indexed_name($name)), names) - interface_args = map(name -> :($name), names) - interface_connections = map(name -> :(ReactiveMP.connect!(node, $(Expr(:quote, name)), $name)), names) - - joined_interface_names = :(join((($(interface_names...)),), ", ")) - - # Check that all arguments within interface refer to the unique var objects - non_unique_error_sym = gensym(:non_unique_error_sym) - non_unique_error_msg = :($non_unique_error_sym = (fformtype, names) -> """ - Non-unique variables used for the creation of the `$(fformtype)` node, which is disallowed. - Check creation of the `$(fformtype)` with the `[ $(join(names, ", ")) ]` arguments. - """) - interface_uniqueness = map(enumerate(names)) do (index, name) - names_without_current = skipindex(names, index) - return quote - if Base.in($(name), ($(names_without_current...),)) - Base.error($(non_unique_error_sym)($fformtype, $names_indexed)) + # If there are any aliases, define the alias correction function + if @capture(interface_aliases, aliases = aliases_) + for alias in aliases.args + result = quote + $result + ReactiveMP.correct_interfaces(::$dispatch_type, nt::NamedTuple{Tuple($(alias.args))}) = NamedTuple{$(Tuple(node_interfaces.args))}(values(nt)) end end end - - # Here we create helpers function for GraphPPL.jl interfacing - # They are used to convert interface names from `where { q = q(x, y)q(z) }` to an equivalent tuple respresentation, e.g. `((1, 2), (3, ))` - # The general recipe to get a proper index is to call `interface_get_index(Val{ :NodeTypeName }, interface_get_name(Val{ :NodeTypeName }, Val{ :name_expression }))` - interface_name_getters = map(enumerate(interfaces)) do (index, interface) - name = first(interface) - aliases = last(interface) - - index_name_getter = :(ReactiveMP.interface_get_index(::Type{Val{$(Expr(:quote, fbottomtype))}}, ::Type{Val{$(Expr(:quote, name))}}) = $(index)) - name_symbol_getter = :(ReactiveMP.interface_get_name(::Type{Val{$(Expr(:quote, fbottomtype))}}, ::Type{Val{$(Expr(:quote, name))}}) = $(Expr(:quote, name))) - name_index_getter = :(ReactiveMP.interface_get_name(::Type{Val{$(Expr(:quote, fbottomtype))}}, ::Type{Val{$index}}) = $(Expr(:quote, name))) - - alias_getters = map(aliases) do alias - return :(ReactiveMP.interface_get_name(::Type{Val{$(Expr(:quote, fbottomtype))}}, ::Type{Val{$(Expr(:quote, alias))}}) = $(Expr(:quote, name))) - end - - return quote - $index_name_getter - $name_symbol_getter - $name_index_getter - $(alias_getters...) - end - end - - # By default every argument passed to a factorisation option of the node is transformed by - # `collect_factorisation` function to have a tuple like structure. - # The default recipe is simple: for stochastic nodes we convert `FullFactorisation` and `MeanField` objects - # to their tuple of indices equivalents. For deterministic nodes any factorisation is replaced by a FullFactorisation equivalent - factorisation_collectors = if sdtype === :Stochastic - quote - ReactiveMP.collect_factorisation(::$fuppertype, ::Nothing) = ($names_indices,) - ReactiveMP.collect_factorisation(::$fuppertype, factorisation::Tuple) = factorisation - ReactiveMP.collect_factorisation(::$fuppertype, ::ReactiveMP.FullFactorisation) = ($names_indices,) - ReactiveMP.collect_factorisation(::$fuppertype, ::ReactiveMP.MeanField) = $names_splitted_indices - end - - elseif sdtype === :Deterministic - quote - ReactiveMP.collect_factorisation(::$fuppertype, ::Nothing) = ($names_indices,) - ReactiveMP.collect_factorisation(::$fuppertype, factorisation::Tuple) = ($names_indices,) - ReactiveMP.collect_factorisation(::$fuppertype, ::ReactiveMP.FullFactorisation) = ($names_indices,) - ReactiveMP.collect_factorisation(::$fuppertype, ::ReactiveMP.MeanField) = ($names_indices,) - end - else - error("Unreachable in @node macro.") - end - - doctype = rpad(fbottomtype, 30) - docsdtype = rpad(sdtype, 15) - docedges = string(interfaces_list) - doc = """ - $doctype : $docsdtype : $docedges - """ - - res = quote - ReactiveMP.as_node_functional_form(::$fuppertype) = ReactiveMP.ValidNodeFunctionalForm() - ReactiveMP.as_node_functional_form(::Type{$fuppertype}) = ReactiveMP.ValidNodeFunctionalForm() - - ReactiveMP.sdtype(::$fuppertype) = (ReactiveMP.$sdtype)() - - ReactiveMP.as_node_symbol(::$fuppertype) = $(QuoteNode(fbottomtype)) - - @doc $doc function ReactiveMP.make_node(::Union{$fuppertype, Type{$fuppertype}}, options::FactorNodeCreationOptions) - return ReactiveMP.FactorNode( - $fbottomtype, - $names_quoted_tuple, - ReactiveMP.collect_factorisation($fbottomtype, ReactiveMP.factorisation(options)), - ReactiveMP.collect_meta($fbottomtype, ReactiveMP.metadata(options)), - ReactiveMP.collect_pipeline($fbottomtype, ReactiveMP.getpipeline(options)) - ) - end - - function ReactiveMP.make_node(::Union{$fuppertype, Type{$fuppertype}}, options::FactorNodeCreationOptions, $(interface_args...)) - node = ReactiveMP.make_node($fbottomtype, options) - $(non_unique_error_msg) - $(interface_uniqueness...) - $(interface_connections...) - return node - end - - # Fallback method for unsupported number of arguments, e.g. if node expects 2 inputs, but only 1 was given - function ReactiveMP.make_node(::Union{$fuppertype, Type{$fuppertype}}, options::FactorNodeCreationOptions, args...) - ReactiveMP.make_node_incompatible_number_of_arguments_error($fuppertype, $fbottomtype, $interfaces, args) - end - - $(interface_name_getters...) - - $factorisation_collectors - end - - return esc(res) + + return esc(result) end + +macro node(node_fform, node_type, node_interfaces) + esc(quote + @node($node_fform, $node_type, $node_interfaces, nothing) + end) +end \ No newline at end of file diff --git a/test/node_tests.jl b/test/node_tests.jl index 0edc20d03..77c5ed5b2 100644 --- a/test/node_tests.jl +++ b/test/node_tests.jl @@ -1,4 +1,3 @@ - @testitem "FactorNode" begin using ReactiveMP, Rocket, BayesBase, Distributions @@ -25,36 +24,19 @@ struct CustomStochasticNode end - @node CustomStochasticNode Stochastic [out, (x, aliases = [xx]), (y, aliases = [yy]), z] - - @test ReactiveMP.interface_get_index(Val{:CustomStochasticNode}, Val{:out}) === 1 - @test ReactiveMP.interface_get_index(Val{:CustomStochasticNode}, Val{:x}) === 2 - @test ReactiveMP.interface_get_index(Val{:CustomStochasticNode}, Val{:y}) === 3 - @test ReactiveMP.interface_get_index(Val{:CustomStochasticNode}, Val{:z}) === 4 - - @test ReactiveMP.interface_get_name(Val{:CustomStochasticNode}, Val{:out}) === :out - @test ReactiveMP.interface_get_name(Val{:CustomStochasticNode}, Val{:x}) === :x - @test ReactiveMP.interface_get_name(Val{:CustomStochasticNode}, Val{:y}) === :y - @test ReactiveMP.interface_get_name(Val{:CustomStochasticNode}, Val{:z}) === :z + @node CustomStochasticNode Stochastic [out, x, y, z] aliases = [(out, xx, yy, z)] - @test ReactiveMP.interface_get_name(Val{:CustomStochasticNode}, Val{:xx}) === :x - @test ReactiveMP.interface_get_name(Val{:CustomStochasticNode}, Val{:yy}) === :y + @test ReactiveMP.sdtype(CustomStochasticNode) === Stochastic() + @test ReactiveMP.correct_interfaces(CustomStochasticNode, (out = 1, xx = 2, yy = 3, z = 4)) === (out = 1, x = 2, y = 3, z = 4) - @test ReactiveMP.interface_get_name(Val{:CustomStochasticNode}, Val{1}) === :out - @test ReactiveMP.interface_get_name(Val{:CustomStochasticNode}, Val{2}) === :x - @test ReactiveMP.interface_get_name(Val{:CustomStochasticNode}, Val{3}) === :y - @test ReactiveMP.interface_get_name(Val{:CustomStochasticNode}, Val{4}) === :z + # Testing stochastic function node specification - @test ReactiveMP.collect_factorisation(CustomStochasticNode, ((1,), (2,), (3,), (4,))) === ((1,), (2,), (3,), (4,)) - @test ReactiveMP.collect_factorisation(CustomStochasticNode, ((1, 2), (3,), (4,))) === ((1, 2), (3,), (4,)) - @test ReactiveMP.collect_factorisation(CustomStochasticNode, ((1, 2, 3), (4,))) === ((1, 2, 3), (4,)) - @test ReactiveMP.collect_factorisation(CustomStochasticNode, ((1, 2, 3), (4,))) === ((1, 2, 3), (4,)) - @test ReactiveMP.collect_factorisation(CustomStochasticNode, ((1, 2, 3, 4),)) === ((1, 2, 3, 4),) + function customstochasticnode end - @test ReactiveMP.collect_factorisation(CustomStochasticNode, MeanField()) === ((1,), (2,), (3,), (4,)) - @test ReactiveMP.collect_factorisation(CustomStochasticNode, FullFactorisation()) === ((1, 2, 3, 4),) + @node typeof(customstochasticnode) Stochastic [out, x, y, z] aliases = [(out, xx, yy, z)] - @test sdtype(CustomStochasticNode) === Stochastic() + @test ReactiveMP.sdtype(customstochasticnode) === Stochastic() + @test ReactiveMP.correct_interfaces(customstochasticnode, (out = 1, xx = 2, yy = 3, z = 4)) === (out = 1, x = 2, y = 3, z = 4) # Testing Deterministic node specification @@ -62,72 +44,22 @@ CustomDeterministicNode(x, y, z) = x + y + z - @node CustomDeterministicNode Deterministic [out, (x, aliases = [xx]), (y, aliases = [yy]), z] - - @test ReactiveMP.interface_get_index(Val{:CustomDeterministicNode}, Val{:out}) === 1 - @test ReactiveMP.interface_get_index(Val{:CustomDeterministicNode}, Val{:x}) === 2 - @test ReactiveMP.interface_get_index(Val{:CustomDeterministicNode}, Val{:y}) === 3 - @test ReactiveMP.interface_get_index(Val{:CustomDeterministicNode}, Val{:z}) === 4 - - @test ReactiveMP.interface_get_name(Val{:CustomDeterministicNode}, Val{:out}) === :out - @test ReactiveMP.interface_get_name(Val{:CustomDeterministicNode}, Val{:x}) === :x - @test ReactiveMP.interface_get_name(Val{:CustomDeterministicNode}, Val{:y}) === :y - @test ReactiveMP.interface_get_name(Val{:CustomDeterministicNode}, Val{:z}) === :z - - @test ReactiveMP.interface_get_name(Val{:CustomDeterministicNode}, Val{:xx}) === :x - @test ReactiveMP.interface_get_name(Val{:CustomDeterministicNode}, Val{:yy}) === :y - - @test ReactiveMP.interface_get_name(Val{:CustomDeterministicNode}, Val{1}) === :out - @test ReactiveMP.interface_get_name(Val{:CustomDeterministicNode}, Val{2}) === :x - @test ReactiveMP.interface_get_name(Val{:CustomDeterministicNode}, Val{3}) === :y - @test ReactiveMP.interface_get_name(Val{:CustomDeterministicNode}, Val{4}) === :z - - @test ReactiveMP.collect_factorisation(CustomDeterministicNode, ((1,), (2,), (3,), (4,))) === ((1, 2, 3, 4),) - @test ReactiveMP.collect_factorisation(CustomDeterministicNode, ((1, 2), (3,), (4,))) === ((1, 2, 3, 4),) - @test ReactiveMP.collect_factorisation(CustomDeterministicNode, ((1, 2, 3), (4,))) === ((1, 2, 3, 4),) - @test ReactiveMP.collect_factorisation(CustomDeterministicNode, ((1, 2, 3), (4,))) === ((1, 2, 3, 4),) - @test ReactiveMP.collect_factorisation(CustomDeterministicNode, ((1, 2, 3, 4),)) === ((1, 2, 3, 4),) - - @test ReactiveMP.collect_factorisation(CustomDeterministicNode, MeanField()) === ((1, 2, 3, 4),) - @test ReactiveMP.collect_factorisation(CustomDeterministicNode, FullFactorisation()) === ((1, 2, 3, 4),) - - @test sdtype(CustomDeterministicNode) === Deterministic() - - # Check that same variables are not allowed - - struct DummyNodeCheckUniqueness end + @node CustomDeterministicNode Deterministic [out, x, y, z] aliases = [(out, xx, yy, z)] - @node DummyNodeCheckUniqueness Stochastic [a, b, c] + @test ReactiveMP.sdtype(CustomDeterministicNode) === Deterministic() + @test ReactiveMP.correct_interfaces(CustomDeterministicNode, (out = 1, xx = 2, yy = 3, z = 4)) === (out = 1, x = 2, y = 3, z = 4) - sx = randomvar(:rx) - sd = datavar(:rd, Float64) - sc = constvar(:sc, 1.0) + # Testing deterministic function node specification - vs = (sx, sd, sc) + function customdeterministicnode end - for a in vs, b in vs, c in vs - input = (a, b, c) - if length(input) != length(Set(input)) - @test_throws ErrorException make_node(DummyNodeCheckUniqueness, FactorNodeCreationOptions(), a, b, c) - end - end + customdeterministicnode(x, y, z) = x + y + z - # `make_node` must show a warning in case if factorisation include the `PointMass` distributed variables jointly with other variables - struct DummyNodeCheckFactorisationWarning end + @node typeof(customdeterministicnode) Deterministic [out, x, y, z] aliases = [(out, xx, yy, z)] - @node DummyNodeCheckFactorisationWarning Stochastic [a, b, c] + @test ReactiveMP.sdtype(customdeterministicnode) === Deterministic() + @test ReactiveMP.correct_interfaces(customdeterministicnode, (out = 1, xx = 2, yy = 3, z = 4)) === (out = 1, x = 2, y = 3, z = 4) - for a in (datavar(:a, Float64), constvar(:a, 1.0)), b in (randomvar(:b),), c in (randomvar(:c),) - @test_logs (:warn, r".*replace `q\(a, b, c\)` with `q\(a\)q\(\.\.\.\)`.*") make_node( - DummyNodeCheckFactorisationWarning, FactorNodeCreationOptions(FullFactorisation(), nothing, nothing), a, b, c - ) - @test_logs (:warn, r".*replace `q\(a, b, c\)` with `q\(a\)q\(\.\.\.\)`.*") make_node( - DummyNodeCheckFactorisationWarning, FactorNodeCreationOptions(((1, 2, 3),), nothing, nothing), a, b, c - ) - @test_logs (:warn, r".*replace `q\(a, b\)` with `q\(a\)q\(\.\.\.\)`.*") make_node( - DummyNodeCheckFactorisationWarning, FactorNodeCreationOptions(((1, 2), (3,)), nothing, nothing), a, b, c - ) - end # Testing expected exceptions @@ -135,12 +67,11 @@ @test_throws Exception eval(:(@node DummyStruct NotStochasticAndNotDeterministic [out, in, x])) @test_throws Exception eval(:(@node DummyStruct Stochastic [1, in, x])) - @test_throws Exception eval(:(@node DummyStruct Stochastic [(1, aliases = [out]), in, x])) + @test_throws Exception eval(:(@node DummyStruct Stochastic [p, in, x] aliases = [([z], y, x)])) @test_throws Exception eval(:(@node DummyStruct Stochastic [(out, aliases = [out]), in, x])) @test_throws Exception eval(:(@node DummyStruct Stochastic [(out, aliases = [1]), in, x])) @test_throws Exception eval(:(@node DummyStruct Stochastic [])) - @test_throws LoadError eval(:(@node DummyStruct Stochastic [out, interfaces_with_underscore])) @test_throws LoadError eval(:(@node DummyStruct Stochastic [out, (interface, aliases = [alias_with_underscore])])) end @@ -149,4 +80,4 @@ @test sdtype(DummyDistribution) === Stochastic() end -end +end \ No newline at end of file From 3b346d594f0a76bcebd2c97f650e646b572cdc1f Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Fri, 1 Mar 2024 15:38:35 +0100 Subject: [PATCH 2/5] Update node macro with old alias system --- src/nodes/nodes.jl | 47 ++++++++++++++++++++++++++++++++++------------ test/node_tests.jl | 4 +--- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/src/nodes/nodes.jl b/src/nodes/nodes.jl index 0c9478c06..15562c3a8 100644 --- a/src/nodes/nodes.jl +++ b/src/nodes/nodes.jl @@ -422,7 +422,7 @@ end include("dependencies.jl") -## macro helpers +function make_node end # TODO (wouterwln) remove this, but it breaks precompilation because of node definitions downstream ## macro helpers @@ -430,37 +430,60 @@ import .MacroHelpers function correct_interfaces end +alias_group(s::Symbol) = [s] +function alias_group(e::Expr) + if @capture(e, (s_, aliases = aliases_)) + result = [s, aliases.args...] + if length(result) != length(unique(result)) + error("Aliases should be unique") + end + return result + else + return [e] + end +end + +check_all_symbol(::AbstractArray{T} where {T <: NTuple{N, Symbol} where {N}}) = nothing +check_all_symbol(::Any) = error("All interfaces should be symbols") + macro node(node_fform, node_type, node_interfaces, interface_aliases) # Assert that the node type is either Stochastic or Deterministic, and that all interfaces are symbols @assert node_type ∈ [:Stochastic, :Deterministic] @assert length(node_interfaces.args) > 0 - for interface in node_interfaces.args - @assert isa(interface, Symbol) - end + + interface_alias_groups = map(alias_group, node_interfaces.args) + all_aliases = vec(collect(Iterators.product(interface_alias_groups...))) + + # Determine whether we should dispatch on `typeof($fform)` or `Type{$node_fform}` if @capture(node_fform, typeof(fform_)) dispatch_type = quote typeof($fform) end else + dispatch_type = quote Type{$node_fform} end end - # Define the necessary function types result = quote ReactiveMP.as_node_functional_form(::$dispatch_type) = ReactiveMP.ValidNodeFunctionalForm() ReactiveMP.sdtype(::$dispatch_type) = (ReactiveMP.$node_type)() end - # If there are any aliases, define the alias correction function if @capture(interface_aliases, aliases = aliases_) - for alias in aliases.args - result = quote - $result - ReactiveMP.correct_interfaces(::$dispatch_type, nt::NamedTuple{Tuple($(alias.args))}) = NamedTuple{$(Tuple(node_interfaces.args))}(values(nt)) - end + defined_aliases = map(alias_group -> Tuple(alias_group.args), aliases.args) + all_aliases = vcat(all_aliases, defined_aliases) + end + + check_all_symbol(all_aliases) + + first_interfaces = map(first, interface_alias_groups) + + for alias in all_aliases + result = quote + $result + ReactiveMP.correct_interfaces(::$dispatch_type, nt::NamedTuple{$alias}) = NamedTuple{$(Tuple(first_interfaces))}(values(nt)) end end - return esc(result) end diff --git a/test/node_tests.jl b/test/node_tests.jl index 77c5ed5b2..4359b0c14 100644 --- a/test/node_tests.jl +++ b/test/node_tests.jl @@ -19,7 +19,7 @@ end @testset "@node macro" begin - + # Testing Stochastic node specification struct CustomStochasticNode end @@ -71,8 +71,6 @@ @test_throws Exception eval(:(@node DummyStruct Stochastic [(out, aliases = [out]), in, x])) @test_throws Exception eval(:(@node DummyStruct Stochastic [(out, aliases = [1]), in, x])) @test_throws Exception eval(:(@node DummyStruct Stochastic [])) - - @test_throws LoadError eval(:(@node DummyStruct Stochastic [out, (interface, aliases = [alias_with_underscore])])) end @testset "sdtype of an arbitrary distribution is Stochastic" begin From 7b7998b22ad8d22448d6767c96e8f154737844a6 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 4 Mar 2024 15:49:00 +0100 Subject: [PATCH 3/5] after merge --- src/variables/constant.jl | 2 +- src/variables/data.jl | 2 +- src/variables/random.jl | 2 +- test/node_tests.jl | 81 ---------------------------------- test/nodes/nodes_tests.jl | 93 ++++++++++++++++++++++++++++++++++++--- 5 files changed, 89 insertions(+), 91 deletions(-) diff --git a/src/variables/constant.jl b/src/variables/constant.jl index 034db5fc1..3e03803e2 100644 --- a/src/variables/constant.jl +++ b/src/variables/constant.jl @@ -1,4 +1,4 @@ -export constvar +export constvar, ConstVariable mutable struct ConstVariable <: AbstractVariable marginal :: MarginalObservable diff --git a/src/variables/data.jl b/src/variables/data.jl index fdfe2d74f..c39c80090 100644 --- a/src/variables/data.jl +++ b/src/variables/data.jl @@ -1,4 +1,4 @@ -export datavar, update!, DataVariableActivationOptions +export datavar, DataVariable, update!, DataVariableActivationOptions mutable struct DataVariable{M, P} <: AbstractVariable input_messages :: Vector{MessageObservable{AbstractMessage}} diff --git a/src/variables/random.jl b/src/variables/random.jl index 6cfdf29b3..29083b895 100644 --- a/src/variables/random.jl +++ b/src/variables/random.jl @@ -1,4 +1,4 @@ -export randomvar, RandomVariableActivationOptions +export randomvar, RandomVariable, RandomVariableActivationOptions ## Random variable implementation diff --git a/test/node_tests.jl b/test/node_tests.jl index 4359b0c14..e69de29bb 100644 --- a/test/node_tests.jl +++ b/test/node_tests.jl @@ -1,81 +0,0 @@ -@testitem "FactorNode" begin - using ReactiveMP, Rocket, BayesBase, Distributions - - @testset "Common" begin - @test ReactiveMP.as_node_functional_form(() -> nothing) === ReactiveMP.UndefinedNodeFunctionalForm() - @test ReactiveMP.as_node_functional_form(2) === ReactiveMP.UndefinedNodeFunctionalForm() - - @test isdeterministic(Deterministic()) === true - @test isdeterministic(Deterministic) === true - @test isdeterministic(Stochastic()) === false - @test isdeterministic(Stochastic) === false - @test isstochastic(Deterministic()) === false - @test isstochastic(Deterministic) === false - @test isstochastic(Stochastic()) === true - @test isstochastic(Stochastic) === true - - @test sdtype(() -> nothing) === Deterministic() - @test_throws MethodError sdtype(0) - end - - @testset "@node macro" begin - - # Testing Stochastic node specification - - struct CustomStochasticNode end - - @node CustomStochasticNode Stochastic [out, x, y, z] aliases = [(out, xx, yy, z)] - - @test ReactiveMP.sdtype(CustomStochasticNode) === Stochastic() - @test ReactiveMP.correct_interfaces(CustomStochasticNode, (out = 1, xx = 2, yy = 3, z = 4)) === (out = 1, x = 2, y = 3, z = 4) - - # Testing stochastic function node specification - - function customstochasticnode end - - @node typeof(customstochasticnode) Stochastic [out, x, y, z] aliases = [(out, xx, yy, z)] - - @test ReactiveMP.sdtype(customstochasticnode) === Stochastic() - @test ReactiveMP.correct_interfaces(customstochasticnode, (out = 1, xx = 2, yy = 3, z = 4)) === (out = 1, x = 2, y = 3, z = 4) - - # Testing Deterministic node specification - - struct CustomDeterministicNode end - - CustomDeterministicNode(x, y, z) = x + y + z - - @node CustomDeterministicNode Deterministic [out, x, y, z] aliases = [(out, xx, yy, z)] - - @test ReactiveMP.sdtype(CustomDeterministicNode) === Deterministic() - @test ReactiveMP.correct_interfaces(CustomDeterministicNode, (out = 1, xx = 2, yy = 3, z = 4)) === (out = 1, x = 2, y = 3, z = 4) - - # Testing deterministic function node specification - - function customdeterministicnode end - - customdeterministicnode(x, y, z) = x + y + z - - @node typeof(customdeterministicnode) Deterministic [out, x, y, z] aliases = [(out, xx, yy, z)] - - @test ReactiveMP.sdtype(customdeterministicnode) === Deterministic() - @test ReactiveMP.correct_interfaces(customdeterministicnode, (out = 1, xx = 2, yy = 3, z = 4)) === (out = 1, x = 2, y = 3, z = 4) - - - # Testing expected exceptions - - struct DummyStruct end - - @test_throws Exception eval(:(@node DummyStruct NotStochasticAndNotDeterministic [out, in, x])) - @test_throws Exception eval(:(@node DummyStruct Stochastic [1, in, x])) - @test_throws Exception eval(:(@node DummyStruct Stochastic [p, in, x] aliases = [([z], y, x)])) - @test_throws Exception eval(:(@node DummyStruct Stochastic [(out, aliases = [out]), in, x])) - @test_throws Exception eval(:(@node DummyStruct Stochastic [(out, aliases = [1]), in, x])) - @test_throws Exception eval(:(@node DummyStruct Stochastic [])) - end - - @testset "sdtype of an arbitrary distribution is Stochastic" begin - struct DummyDistribution <: Distribution{Univariate, Continuous} end - - @test sdtype(DummyDistribution) === Stochastic() - end -end \ No newline at end of file diff --git a/test/nodes/nodes_tests.jl b/test/nodes/nodes_tests.jl index ecd79d6a7..2739d2113 100644 --- a/test/nodes/nodes_tests.jl +++ b/test/nodes/nodes_tests.jl @@ -1,25 +1,104 @@ @testitem "GenericFactorNode constructor" begin - import ReactiveMP: GenericFactorNode, functionalform, getinterfaces, getinterface + import ReactiveMP: functionalform, getinterfaces, getinterface struct ArbitraryNodeType end function foo end @testset "functionalform" begin - @test @inferred(functionalform(GenericFactorNode(ArbitraryNodeType, (; )))) === ArbitraryNodeType - @test @inferred(functionalform(GenericFactorNode(foo, (; )))) === typeof(foo) + @test @inferred(functionalform(factornode(ArbitraryNodeType, (;)))) === ArbitraryNodeType + @test @inferred(functionalform(factornode(foo, (;)))) === typeof(foo) end @testset "getinterfaces" for fform in (ArbitraryNodeType, foo) - a = RandomVariable() - b = DataVariable() - c = ConstVariable(1) + a = randomvar() + b = datavar() + c = constvar(1) - let node = GenericFactorNode(fform, (a = a, b = b, c = c)) + let node = factornode(fform, (a = a, b = b, c = c)) @test name.(getinterfaces(node)) == (:a, :b, :c) @test name(getinterface(node, 1)) == :a @test name(getinterface(node, 2)) == :b @test name(getinterface(node, 3)) == :c end end +end + +@testitem "sdtype" begin + @test ReactiveMP.as_node_functional_form(() -> nothing) === ReactiveMP.UndefinedNodeFunctionalForm() + @test ReactiveMP.as_node_functional_form(2) === ReactiveMP.UndefinedNodeFunctionalForm() + + @test isdeterministic(Deterministic()) === true + @test isdeterministic(Deterministic) === true + @test isdeterministic(Stochastic()) === false + @test isdeterministic(Stochastic) === false + @test isstochastic(Deterministic()) === false + @test isstochastic(Deterministic) === false + @test isstochastic(Stochastic()) === true + @test isstochastic(Stochastic) === true + + @test sdtype(() -> nothing) === Deterministic() + @test_throws MethodError sdtype(0) +end + +@testitem "@node macro" begin + + # Testing Stochastic node specification + + struct CustomStochasticNode end + + @node CustomStochasticNode Stochastic [out, x, y, z] aliases = [(out, xx, yy, z)] + + @test ReactiveMP.sdtype(CustomStochasticNode) === Stochastic() + @test ReactiveMP.correct_interfaces(CustomStochasticNode, (out = 1, xx = 2, yy = 3, z = 4)) === (out = 1, x = 2, y = 3, z = 4) + + # Testing stochastic function node specification + + function customstochasticnode end + + @node typeof(customstochasticnode) Stochastic [out, x, y, z] aliases = [(out, xx, yy, z)] + + @test ReactiveMP.sdtype(customstochasticnode) === Stochastic() + @test ReactiveMP.correct_interfaces(customstochasticnode, (out = 1, xx = 2, yy = 3, z = 4)) === (out = 1, x = 2, y = 3, z = 4) + + # Testing Deterministic node specification + + struct CustomDeterministicNode end + + CustomDeterministicNode(x, y, z) = x + y + z + + @node CustomDeterministicNode Deterministic [out, x, y, z] aliases = [(out, xx, yy, z)] + + @test ReactiveMP.sdtype(CustomDeterministicNode) === Deterministic() + @test ReactiveMP.correct_interfaces(CustomDeterministicNode, (out = 1, xx = 2, yy = 3, z = 4)) === (out = 1, x = 2, y = 3, z = 4) + + # Testing deterministic function node specification + + function customdeterministicnode end + + customdeterministicnode(x, y, z) = x + y + z + + @node typeof(customdeterministicnode) Deterministic [out, x, y, z] aliases = [(out, xx, yy, z)] + + @test ReactiveMP.sdtype(customdeterministicnode) === Deterministic() + @test ReactiveMP.correct_interfaces(customdeterministicnode, (out = 1, xx = 2, yy = 3, z = 4)) === (out = 1, x = 2, y = 3, z = 4) + + # Testing expected exceptions + + struct DummyStruct end + + @test_throws Exception eval(:(@node DummyStruct NotStochasticAndNotDeterministic [out, in, x])) + @test_throws Exception eval(:(@node DummyStruct Stochastic [1, in, x])) + @test_throws Exception eval(:(@node DummyStruct Stochastic [p, in, x] aliases = [([z], y, x)])) + @test_throws Exception eval(:(@node DummyStruct Stochastic [(out, aliases = [out]), in, x])) + @test_throws Exception eval(:(@node DummyStruct Stochastic [(out, aliases = [1]), in, x])) + @test_throws Exception eval(:(@node DummyStruct Stochastic [])) +end + +@testitem "sdtype of an arbitrary distribution is Stochastic" begin + using Distributions + + struct DummyDistribution <: Distribution{Univariate, Continuous} end + + @test sdtype(DummyDistribution) === Stochastic() end \ No newline at end of file From 4841a7d47734abf258075d18f6861aec8795a966 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 4 Mar 2024 15:54:22 +0100 Subject: [PATCH 4/5] return to the old Beta interface names --- src/nodes/nodes.jl | 4 ++-- src/nodes/predefined/beta.jl | 4 ++-- src/rules/beta/marginals.jl | 4 ++-- src/rules/beta/out.jl | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/nodes/nodes.jl b/src/nodes/nodes.jl index 50f1cdd19..2f62aa325 100644 --- a/src/nodes/nodes.jl +++ b/src/nodes/nodes.jl @@ -201,8 +201,8 @@ struct FactorNode{F, I} <: AbstractFactorNode end end -factornode(::Type{F}, interfaces::I) where {F, I} = FactorNode(F, __prepare_interfaces_generic(interfaces)) -factornode(::F, interfaces::I) where {F <: Function, I} = FactorNode(F, __prepare_interfaces_generic(interfaces)) +factornode(::Type{F}, interfaces::I) where {F, I} = FactorNode(F, __prepare_interfaces_generic(correct_interfaces(F, interfaces))) +factornode(::F, interfaces::I) where {F <: Function, I} = FactorNode(F, __prepare_interfaces_generic(correct_interfaces(F, interfaces))) functionalform(factornode::FactorNode{F}) where {F} = F getinterfaces(factornode::FactorNode) = factornode.interfaces diff --git a/src/nodes/predefined/beta.jl b/src/nodes/predefined/beta.jl index bccd8922e..e70959958 100644 --- a/src/nodes/predefined/beta.jl +++ b/src/nodes/predefined/beta.jl @@ -1,6 +1,6 @@ import SpecialFunctions: logbeta -@node Beta Stochastic [out, α, β] +@node Beta Stochastic [out, (a, aliases = [α]), (b, aliases = [β])] -@average_energy Beta (q_out::Any, q_α::Any, q_β::Any) = logbeta(mean(q_α), mean(q_β)) - (mean(q_α) - 1.0) * mean(log, q_out) - (mean(q_β) - 1.0) * mean(mirrorlog, q_out) +@average_energy Beta (q_out::Any, q_a::Any, q_b::Any) = logbeta(mean(q_a), mean(q_b)) - (mean(q_a) - 1.0) * mean(log, q_out) - (mean(q_b) - 1.0) * mean(mirrorlog, q_out) diff --git a/src/rules/beta/marginals.jl b/src/rules/beta/marginals.jl index 6ff7b03ad..624837b04 100644 --- a/src/rules/beta/marginals.jl +++ b/src/rules/beta/marginals.jl @@ -1,4 +1,4 @@ -@marginalrule Beta(:out_α_β) (m_out::Beta, m_α::PointMass, m_β::PointMass) = begin - return convert_paramfloattype((out = prod(ClosedProd(), Beta(mean(m_α), mean(m_β)), m_out), a = m_α, b = m_β)) +@marginalrule Beta(:out_a_b) (m_out::Beta, m_a::PointMass, m_b::PointMass) = begin + return convert_paramfloattype((out = prod(ClosedProd(), Beta(mean(m_a), mean(m_b)), m_out), a = m_a, b = m_b)) end diff --git a/src/rules/beta/out.jl b/src/rules/beta/out.jl index dc60165da..498056d38 100644 --- a/src/rules/beta/out.jl +++ b/src/rules/beta/out.jl @@ -1,4 +1,4 @@ -@rule Beta(:out, Marginalisation) (m_α::PointMass, m_β::PointMass) = Beta(mean(m_α), mean(m_β)) +@rule Beta(:out, Marginalisation) (m_a::PointMass, m_b::PointMass) = Beta(mean(m_a), mean(m_b)) -@rule Beta(:out, Marginalisation) (q_α::PointMass, q_β::PointMass) = Beta(mean(q_α), mean(q_β)) +@rule Beta(:out, Marginalisation) (q_a::PointMass, q_b::PointMass) = Beta(mean(q_a), mean(q_b)) From 6361f97a37f50ab73721cf5657c410888a6b7653 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 4 Mar 2024 15:55:36 +0100 Subject: [PATCH 5/5] 2prev --- test/node_tests.jl | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 test/node_tests.jl diff --git a/test/node_tests.jl b/test/node_tests.jl deleted file mode 100644 index e69de29bb..000000000