Skip to content

Commit

Permalink
Restore the previous API
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed Mar 4, 2024
1 parent 4408de4 commit 03ae1a0
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 58 deletions.
6 changes: 3 additions & 3 deletions src/nodes/nodes.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export Deterministic, Stochastic, isdeterministic, isstochastic, sdtype
export MeanField, FullFactorisation, Marginalisation, MomentMatching
export functionalform, interfaces, factorisation, localmarginals, localmarginalnames, metadata
export FactorNode
export FactorNode, factornode
export @node

using Rocket
Expand Down Expand Up @@ -197,8 +197,8 @@ struct FactorNode{F, I} <: AbstractFactorNode
end
end

FactorNode(::Type{F}, interfaces::I) where {F, I} = FactorNode(F, __prepare_interfaces_generic(interfaces))
FactorNode(::F, interfaces::I) where {F <: Function, I} = FactorNode(F, __prepare_interfaces_generic(interfaces))
factornode(::Type{F}, interfaces::I) where {F, I} = FactorNode(F, __prepare_interfaces_generic(interfaces))
factornode(::F, interfaces::I) where {F <: Function, I} = FactorNode(F, __prepare_interfaces_generic(interfaces))

functionalform(factornode::FactorNode{F}) where {F} = F
getinterfaces(factornode::FactorNode) = factornode.interfaces
Expand Down
4 changes: 3 additions & 1 deletion src/variables/constant.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export ConstVariable
export constvar

mutable struct ConstVariable <: AbstractVariable
marginal :: MarginalObservable
Expand All @@ -14,6 +14,8 @@ function ConstVariable(constant)
return ConstVariable(marginal, messageout, 0)
end

constvar(constant) = ConstVariable(constant)

degree(constvar::ConstVariable) = constvar.nconnected

israndom(::ConstVariable) = false
Expand Down
4 changes: 3 additions & 1 deletion src/variables/data.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export DataVariable, update!
export datavar, update!, DataVariableActivationOptions

mutable struct DataVariable{M, P} <: AbstractVariable
input_messages :: Vector{MessageObservable{AbstractMessage}}
Expand All @@ -14,6 +14,8 @@ function DataVariable()
return DataVariable(Vector{MessageObservable{AbstractMessage}}(), marginal, messageout, nothing) # MarginalObservable())
end

datavar() = DataVariable()

degree(datavar::DataVariable) = length(datavar.input_messages)

israndom(::DataVariable) = false
Expand Down
8 changes: 2 additions & 6 deletions src/variables/random.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export RandomVariable, RandomVariableActivationOptions
export randomvar, RandomVariableActivationOptions

## Random variable implementation

Expand All @@ -24,11 +24,7 @@ end
const DefaultMessageProdFn = messages_prod_fn(FoldLeftProdStrategy(), GenericProd(), UnspecifiedFormConstraint(), FormConstraintCheckLast())
const DefaultMarginalProdFn = marginal_prod_fn(FoldLeftProdStrategy(), GenericProd(), UnspecifiedFormConstraint(), FormConstraintCheckLast())

function RandomVariable()
return RandomVariable(DefaultMessageProdFn, DefaultMarginalProdFn)
end

function RandomVariable(messages_prod_fn::M, marginal_prod_fn::F) where {M, F}
function randomvar(messages_prod_fn::M = DefaultMessageProdFn, marginal_prod_fn::F = DefaultMarginalProdFn) where {M, F}
return RandomVariable{M, F}(Vector{MessageObservable{AbstractMessage}}(), Vector{MessageObservable{Message}}(), MarginalObservable(), messages_prod_fn, marginal_prod_fn)
end

Expand Down
24 changes: 12 additions & 12 deletions test/variables/constant_tests.jl
Original file line number Diff line number Diff line change
@@ -1,47 +1,47 @@

@testitem "ConstVariable: uninitialized" begin
import ReactiveMP: ConstVariable, messageout, messagein
import ReactiveMP: messageout, messagein

# Should throw if not initialised properly
@testset let constvar = ConstVariable(1)
@testset let var = constvar(1)
for i in 1:10
@test messageout(constvar, 1) === messageout(constvar, i)
@test_throws ErrorException messagein(constvar, i)
@test messageout(var, 1) === messageout(var, i)
@test_throws ErrorException messagein(var, i)
end
end
end

@testitem "ConstVariable: getmessagein!" begin
import ReactiveMP: ConstVariable, MessageObservable, create_messagein!, messagein, degree
import ReactiveMP: MessageObservable, create_messagein!, messagein, degree

# Test for different degrees `d`
@testset for d in 1:5:100
@testset let constvar = ConstVariable(1)
@testset let var = constvar(1)
for i in 1:d
messagein, index = create_messagein!(constvar)
messagein, index = create_messagein!(var)
@test messagein isa MessageObservable
@test index === 1
@test degree(constvar) === i
@test degree(var) === i
end
@test degree(constvar) === d
@test degree(var) === d
end
end
end

@testitem "ConstVariable: getmarginal" begin
using BayesBase

import ReactiveMP: ConstVariable, MessageObservable, create_messagein!, messagein, degree, activate!, connect!, DataVariableActivationOptions, messageout
import ReactiveMP: MessageObservable, create_messagein!, messagein, degree, activate!, connect!, DataVariableActivationOptions, messageout

include("../testutilities.jl")

@testset begin
# Test marginal computation
@testset for d in 1:5:100, constant in rand(10)
@testset let constvar = ConstVariable(constant)
@testset let var = constvar(constant)

marginal_expected = mgl(PointMass(constant))
marginal_result = check_stream_updated_once(getmarginal(constvar)) do
marginal_result = check_stream_updated_once(getmarginal(var)) do
nothing
end

Expand Down
32 changes: 16 additions & 16 deletions test/variables/data_tests.jl
Original file line number Diff line number Diff line change
@@ -1,56 +1,56 @@

@testitem "DataVariable: uninitialized" begin
import ReactiveMP: DataVariable, messageout, messagein
import ReactiveMP: messageout, messagein

# Should throw if not initialised properly
@testset let datavar = DataVariable()
@testset let var = datavar()
for i in 1:10
@test messageout(datavar, 1) === messageout(datavar, i)
@test_throws BoundsError messagein(datavar, i)
@test messageout(var, 1) === messageout(var, i)
@test_throws BoundsError messagein(var, i)
end
end
end

@testitem "DataVariable: getmessagein!" begin
import ReactiveMP: DataVariable, MessageObservable, create_messagein!, messagein, degree
import ReactiveMP: MessageObservable, create_messagein!, messagein, degree

# Test for different degrees `d`
@testset for d in 1:5:100
@testset let datavar = DataVariable()
@testset let var = datavar()
for i in 1:d
messagein, index = create_messagein!(datavar)
messagein, index = create_messagein!(var)
@test messagein isa MessageObservable
@test index === i
@test degree(datavar) === i
@test degree(var) === i
end
@test degree(datavar) === d
@test degree(var) === d
end
end
end

@testitem "DataVariable: getmarginal" begin
using BayesBase

import ReactiveMP: DataVariable, MessageObservable, create_messagein!, messagein, degree, activate!, connect!, DataVariableActivationOptions, messageout
import ReactiveMP: MessageObservable, create_messagein!, messagein, degree, activate!, connect!, DataVariableActivationOptions, messageout

include("../testutilities.jl")

@testset begin
# Test marginal computation
@testset for d in 1:5:100
@testset let datavar = DataVariable()
@testset let var = datavar()
messageins = map(1:d) do _
s = Subject(AbstractMessage)
m, i = create_messagein!(datavar)
m, i = create_messagein!(var)
connect!(m, s)
return s
end

activate!(datavar, DataVariableActivationOptions(false))
activate!(var, DataVariableActivationOptions(false))

messages = map(msg, rand(d))

@test check_stream_not_updated(getmarginal(datavar)) do
@test check_stream_not_updated(getmarginal(var)) do
foreach(zip(messageins, messages)) do (messagein, message)
next!(messagein, message)
end
Expand All @@ -59,8 +59,8 @@ end
data_point = rand()

marginal_expected = mgl(PointMass(data_point))
marginal_result = check_stream_updated_once(getmarginal(datavar)) do
update!(datavar, data_point)
marginal_result = check_stream_updated_once(getmarginal(var)) do
update!(var, data_point)
end

@test getdata(marginal_result) === getdata(marginal_expected)
Expand Down
38 changes: 19 additions & 19 deletions test/variables/random_tests.jl
Original file line number Diff line number Diff line change
@@ -1,55 +1,55 @@

@testitem "RandomVariable: uninitialized" begin
import ReactiveMP: RandomVariable, messageout, messagein
import ReactiveMP: messageout, messagein

# Should throw if not initialised properly
@testset let randomvar = RandomVariable()
@testset let var = randomvar()
for i in 1:10
@test_throws BoundsError messageout(randomvar, i)
@test_throws BoundsError messagein(randomvar, i)
@test_throws BoundsError messageout(var, i)
@test_throws BoundsError messagein(var, i)
end
end
end

@testitem "RandomVariable: getmessagein!" begin
import ReactiveMP: RandomVariable, MessageObservable, create_messagein!, messagein, degree
import ReactiveMP: MessageObservable, create_messagein!, messagein, degree

# Test for different degrees `d`
@testset for d in 1:5:100
@testset let randomvar = RandomVariable()
@testset let var = randomvar()
for i in 1:d
messagein, index = create_messagein!(randomvar)
messagein, index = create_messagein!(var)
@test messagein isa MessageObservable
@test index === i
@test degree(randomvar) === i
@test degree(var) === i
end
@test degree(randomvar) === d
@test degree(var) === d
end
end
end

@testitem "RandomVariable: getmarginal" begin
import ReactiveMP: RandomVariable, MessageObservable, create_messagein!, messagein, degree, activate!, connect!, RandomVariableActivationOptions, messageout
import ReactiveMP: MessageObservable, create_messagein!, messagein, degree, activate!, connect!, RandomVariableActivationOptions, messageout

include("../testutilities.jl")

message_prod_fn = (msgs) -> error("Messages should not be called here")
marginal_prod_fn = (msgs) -> mgl(sum(getdata.(msgs)))
@testset for d in 1:5:100
@testset let randomvar = RandomVariable(message_prod_fn, marginal_prod_fn)
@testset let var = randomvar(message_prod_fn, marginal_prod_fn)
messageins = map(1:d) do _
s = Subject(AbstractMessage)
m, i = create_messagein!(randomvar)
m, i = create_messagein!(var)
connect!(m, s)
return s
end

activate!(randomvar, RandomVariableActivationOptions())
activate!(var, RandomVariableActivationOptions())

messages = map(msg, rand(d))

marginal_expected = marginal_prod_fn(messages)
marginal_result = check_stream_updated_once(getmarginal(randomvar)) do
marginal_result = check_stream_updated_once(getmarginal(var)) do
foreach(zip(messageins, messages)) do (messagein, message)
next!(messagein, message)
end
Expand All @@ -63,7 +63,7 @@ end
end

@testitem "RandomVariable: messageout" begin
import ReactiveMP: RandomVariable, MessageObservable, create_messagein!, messagein, degree, activate!, connect!, RandomVariableActivationOptions, messageout
import ReactiveMP: MessageObservable, create_messagein!, messagein, degree, activate!, connect!, RandomVariableActivationOptions, messageout

include("../testutilities.jl")

Expand All @@ -72,21 +72,21 @@ end

# We start from `2` because `1` is not a valid degree for a random variable
@testset for d in 2:5:100, k in 1:d
@testset let randomvar = RandomVariable(message_prod_fn, marginal_prod_fn)
@testset let var = randomvar(message_prod_fn, marginal_prod_fn)
messageins = map(1:d) do _
s = Subject(AbstractMessage)
m, i = create_messagein!(randomvar)
m, i = create_messagein!(var)
connect!(m, s)
return s
end

activate!(randomvar, RandomVariableActivationOptions())
activate!(var, RandomVariableActivationOptions())

messages = map(msg, rand(d))

# the outbound message is the result of multiplication of `n - 1` messages excluding index `k`
kmessage_expected = message_prod_fn(collect(skipindex(messages, k)))
kmessage_result = check_stream_updated_once(messageout(randomvar, k)) do
kmessage_result = check_stream_updated_once(messageout(var, k)) do
foreach(zip(messageins, messages)) do (messagein, message)
next!(messagein, message)
end
Expand Down

0 comments on commit 03ae1a0

Please sign in to comment.