diff --git a/src/nodes/nodes.jl b/src/nodes/nodes.jl index 45376ea04..fc5973dd5 100644 --- a/src/nodes/nodes.jl +++ b/src/nodes/nodes.jl @@ -1,7 +1,7 @@ export Deterministic, Stochastic, isdeterministic, isstochastic, sdtype export MeanField, FullFactorisation, Marginalisation, MomentMatching export functionalform, interfaces, factorisation, localmarginals, localmarginalnames, metadata -export FactorNode +export FactorNode, factornode export @node using Rocket @@ -197,8 +197,8 @@ struct FactorNode{F, I} <: AbstractFactorNode end end -FactorNode(::Type{F}, interfaces::I) where {F, I} = FactorNode(F, __prepare_interfaces_generic(interfaces)) -FactorNode(::F, interfaces::I) where {F <: Function, I} = FactorNode(F, __prepare_interfaces_generic(interfaces)) +factornode(::Type{F}, interfaces::I) where {F, I} = FactorNode(F, __prepare_interfaces_generic(interfaces)) +factornode(::F, interfaces::I) where {F <: Function, I} = FactorNode(F, __prepare_interfaces_generic(interfaces)) functionalform(factornode::FactorNode{F}) where {F} = F getinterfaces(factornode::FactorNode) = factornode.interfaces diff --git a/src/variables/constant.jl b/src/variables/constant.jl index bd31fdac4..034db5fc1 100644 --- a/src/variables/constant.jl +++ b/src/variables/constant.jl @@ -1,4 +1,4 @@ -export ConstVariable +export constvar mutable struct ConstVariable <: AbstractVariable marginal :: MarginalObservable @@ -14,6 +14,8 @@ function ConstVariable(constant) return ConstVariable(marginal, messageout, 0) end +constvar(constant) = ConstVariable(constant) + degree(constvar::ConstVariable) = constvar.nconnected israndom(::ConstVariable) = false diff --git a/src/variables/data.jl b/src/variables/data.jl index f1f8b3e68..fdfe2d74f 100644 --- a/src/variables/data.jl +++ b/src/variables/data.jl @@ -1,4 +1,4 @@ -export DataVariable, update! +export datavar, update!, DataVariableActivationOptions mutable struct DataVariable{M, P} <: AbstractVariable input_messages :: Vector{MessageObservable{AbstractMessage}} @@ -14,6 +14,8 @@ function DataVariable() return DataVariable(Vector{MessageObservable{AbstractMessage}}(), marginal, messageout, nothing) # MarginalObservable()) end +datavar() = DataVariable() + degree(datavar::DataVariable) = length(datavar.input_messages) israndom(::DataVariable) = false diff --git a/src/variables/random.jl b/src/variables/random.jl index c71065f95..6cfdf29b3 100644 --- a/src/variables/random.jl +++ b/src/variables/random.jl @@ -1,4 +1,4 @@ -export RandomVariable, RandomVariableActivationOptions +export randomvar, RandomVariableActivationOptions ## Random variable implementation @@ -24,11 +24,7 @@ end const DefaultMessageProdFn = messages_prod_fn(FoldLeftProdStrategy(), GenericProd(), UnspecifiedFormConstraint(), FormConstraintCheckLast()) const DefaultMarginalProdFn = marginal_prod_fn(FoldLeftProdStrategy(), GenericProd(), UnspecifiedFormConstraint(), FormConstraintCheckLast()) -function RandomVariable() - return RandomVariable(DefaultMessageProdFn, DefaultMarginalProdFn) -end - -function RandomVariable(messages_prod_fn::M, marginal_prod_fn::F) where {M, F} +function randomvar(messages_prod_fn::M = DefaultMessageProdFn, marginal_prod_fn::F = DefaultMarginalProdFn) where {M, F} return RandomVariable{M, F}(Vector{MessageObservable{AbstractMessage}}(), Vector{MessageObservable{Message}}(), MarginalObservable(), messages_prod_fn, marginal_prod_fn) end diff --git a/test/variables/constant_tests.jl b/test/variables/constant_tests.jl index 2b21e5aef..dd2530ef5 100644 --- a/test/variables/constant_tests.jl +++ b/test/variables/constant_tests.jl @@ -1,29 +1,29 @@ @testitem "ConstVariable: uninitialized" begin - import ReactiveMP: ConstVariable, messageout, messagein + import ReactiveMP: messageout, messagein # Should throw if not initialised properly - @testset let constvar = ConstVariable(1) + @testset let var = constvar(1) for i in 1:10 - @test messageout(constvar, 1) === messageout(constvar, i) - @test_throws ErrorException messagein(constvar, i) + @test messageout(var, 1) === messageout(var, i) + @test_throws ErrorException messagein(var, i) end end end @testitem "ConstVariable: getmessagein!" begin - import ReactiveMP: ConstVariable, MessageObservable, create_messagein!, messagein, degree + import ReactiveMP: MessageObservable, create_messagein!, messagein, degree # Test for different degrees `d` @testset for d in 1:5:100 - @testset let constvar = ConstVariable(1) + @testset let var = constvar(1) for i in 1:d - messagein, index = create_messagein!(constvar) + messagein, index = create_messagein!(var) @test messagein isa MessageObservable @test index === 1 - @test degree(constvar) === i + @test degree(var) === i end - @test degree(constvar) === d + @test degree(var) === d end end end @@ -31,17 +31,17 @@ end @testitem "ConstVariable: getmarginal" begin using BayesBase - import ReactiveMP: ConstVariable, MessageObservable, create_messagein!, messagein, degree, activate!, connect!, DataVariableActivationOptions, messageout + import ReactiveMP: MessageObservable, create_messagein!, messagein, degree, activate!, connect!, DataVariableActivationOptions, messageout include("../testutilities.jl") @testset begin # Test marginal computation @testset for d in 1:5:100, constant in rand(10) - @testset let constvar = ConstVariable(constant) + @testset let var = constvar(constant) marginal_expected = mgl(PointMass(constant)) - marginal_result = check_stream_updated_once(getmarginal(constvar)) do + marginal_result = check_stream_updated_once(getmarginal(var)) do nothing end diff --git a/test/variables/data_tests.jl b/test/variables/data_tests.jl index 1c7794392..7082bb211 100644 --- a/test/variables/data_tests.jl +++ b/test/variables/data_tests.jl @@ -1,29 +1,29 @@ @testitem "DataVariable: uninitialized" begin - import ReactiveMP: DataVariable, messageout, messagein + import ReactiveMP: messageout, messagein # Should throw if not initialised properly - @testset let datavar = DataVariable() + @testset let var = datavar() for i in 1:10 - @test messageout(datavar, 1) === messageout(datavar, i) - @test_throws BoundsError messagein(datavar, i) + @test messageout(var, 1) === messageout(var, i) + @test_throws BoundsError messagein(var, i) end end end @testitem "DataVariable: getmessagein!" begin - import ReactiveMP: DataVariable, MessageObservable, create_messagein!, messagein, degree + import ReactiveMP: MessageObservable, create_messagein!, messagein, degree # Test for different degrees `d` @testset for d in 1:5:100 - @testset let datavar = DataVariable() + @testset let var = datavar() for i in 1:d - messagein, index = create_messagein!(datavar) + messagein, index = create_messagein!(var) @test messagein isa MessageObservable @test index === i - @test degree(datavar) === i + @test degree(var) === i end - @test degree(datavar) === d + @test degree(var) === d end end end @@ -31,26 +31,26 @@ end @testitem "DataVariable: getmarginal" begin using BayesBase - import ReactiveMP: DataVariable, MessageObservable, create_messagein!, messagein, degree, activate!, connect!, DataVariableActivationOptions, messageout + import ReactiveMP: MessageObservable, create_messagein!, messagein, degree, activate!, connect!, DataVariableActivationOptions, messageout include("../testutilities.jl") @testset begin # Test marginal computation @testset for d in 1:5:100 - @testset let datavar = DataVariable() + @testset let var = datavar() messageins = map(1:d) do _ s = Subject(AbstractMessage) - m, i = create_messagein!(datavar) + m, i = create_messagein!(var) connect!(m, s) return s end - activate!(datavar, DataVariableActivationOptions(false)) + activate!(var, DataVariableActivationOptions(false)) messages = map(msg, rand(d)) - @test check_stream_not_updated(getmarginal(datavar)) do + @test check_stream_not_updated(getmarginal(var)) do foreach(zip(messageins, messages)) do (messagein, message) next!(messagein, message) end @@ -59,8 +59,8 @@ end data_point = rand() marginal_expected = mgl(PointMass(data_point)) - marginal_result = check_stream_updated_once(getmarginal(datavar)) do - update!(datavar, data_point) + marginal_result = check_stream_updated_once(getmarginal(var)) do + update!(var, data_point) end @test getdata(marginal_result) === getdata(marginal_expected) diff --git a/test/variables/random_tests.jl b/test/variables/random_tests.jl index c623da22c..a02b1421e 100644 --- a/test/variables/random_tests.jl +++ b/test/variables/random_tests.jl @@ -1,55 +1,55 @@ @testitem "RandomVariable: uninitialized" begin - import ReactiveMP: RandomVariable, messageout, messagein + import ReactiveMP: messageout, messagein # Should throw if not initialised properly - @testset let randomvar = RandomVariable() + @testset let var = randomvar() for i in 1:10 - @test_throws BoundsError messageout(randomvar, i) - @test_throws BoundsError messagein(randomvar, i) + @test_throws BoundsError messageout(var, i) + @test_throws BoundsError messagein(var, i) end end end @testitem "RandomVariable: getmessagein!" begin - import ReactiveMP: RandomVariable, MessageObservable, create_messagein!, messagein, degree + import ReactiveMP: MessageObservable, create_messagein!, messagein, degree # Test for different degrees `d` @testset for d in 1:5:100 - @testset let randomvar = RandomVariable() + @testset let var = randomvar() for i in 1:d - messagein, index = create_messagein!(randomvar) + messagein, index = create_messagein!(var) @test messagein isa MessageObservable @test index === i - @test degree(randomvar) === i + @test degree(var) === i end - @test degree(randomvar) === d + @test degree(var) === d end end end @testitem "RandomVariable: getmarginal" begin - import ReactiveMP: RandomVariable, MessageObservable, create_messagein!, messagein, degree, activate!, connect!, RandomVariableActivationOptions, messageout + import ReactiveMP: MessageObservable, create_messagein!, messagein, degree, activate!, connect!, RandomVariableActivationOptions, messageout include("../testutilities.jl") message_prod_fn = (msgs) -> error("Messages should not be called here") marginal_prod_fn = (msgs) -> mgl(sum(getdata.(msgs))) @testset for d in 1:5:100 - @testset let randomvar = RandomVariable(message_prod_fn, marginal_prod_fn) + @testset let var = randomvar(message_prod_fn, marginal_prod_fn) messageins = map(1:d) do _ s = Subject(AbstractMessage) - m, i = create_messagein!(randomvar) + m, i = create_messagein!(var) connect!(m, s) return s end - activate!(randomvar, RandomVariableActivationOptions()) + activate!(var, RandomVariableActivationOptions()) messages = map(msg, rand(d)) marginal_expected = marginal_prod_fn(messages) - marginal_result = check_stream_updated_once(getmarginal(randomvar)) do + marginal_result = check_stream_updated_once(getmarginal(var)) do foreach(zip(messageins, messages)) do (messagein, message) next!(messagein, message) end @@ -63,7 +63,7 @@ end end @testitem "RandomVariable: messageout" begin - import ReactiveMP: RandomVariable, MessageObservable, create_messagein!, messagein, degree, activate!, connect!, RandomVariableActivationOptions, messageout + import ReactiveMP: MessageObservable, create_messagein!, messagein, degree, activate!, connect!, RandomVariableActivationOptions, messageout include("../testutilities.jl") @@ -72,21 +72,21 @@ end # We start from `2` because `1` is not a valid degree for a random variable @testset for d in 2:5:100, k in 1:d - @testset let randomvar = RandomVariable(message_prod_fn, marginal_prod_fn) + @testset let var = randomvar(message_prod_fn, marginal_prod_fn) messageins = map(1:d) do _ s = Subject(AbstractMessage) - m, i = create_messagein!(randomvar) + m, i = create_messagein!(var) connect!(m, s) return s end - activate!(randomvar, RandomVariableActivationOptions()) + activate!(var, RandomVariableActivationOptions()) messages = map(msg, rand(d)) # the outbound message is the result of multiplication of `n - 1` messages excluding index `k` kmessage_expected = message_prod_fn(collect(skipindex(messages, k))) - kmessage_result = check_stream_updated_once(messageout(randomvar, k)) do + kmessage_result = check_stream_updated_once(messageout(var, k)) do foreach(zip(messageins, messages)) do (messagein, message) next!(messagein, message) end