Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify variables structures for predictions functionality #248

Merged
merged 76 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from 74 commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
da44b05
Add MAR node
albertpod Oct 17, 2022
ef8e17e
Add rules prototypes
albertpod Oct 20, 2022
35aba35
Add rules for MAR
albertpod Oct 28, 2022
3462700
Update MAR
albertpod Oct 31, 2022
b160302
Merge branch 'master' into mar
albertpod Nov 13, 2022
24cfeaf
Merge branch 'master' into mar
albertpod Nov 27, 2022
91179d1
Update rules
albertpod Nov 28, 2022
6795aff
Update rules
albertpod Dec 30, 2022
bd7be22
WIP: Update mask MAR function
albertpod Jan 2, 2023
86a24e4
Bug fix
albertpod Jan 2, 2023
dcb5bd5
Update rules
albertpod Jan 3, 2023
adc58ad
Update rules
albertpod Jan 3, 2023
ee8b471
project: add BlockArrays dependency
bvdmitri Jan 4, 2023
a41b0d9
Merge branch 'mar' of github.com:biaslab/ReactiveMP.jl into mar
bvdmitri Jan 4, 2023
4d3e221
fix constructor
bvdmitri Jan 4, 2023
09140d1
Update rules
albertpod Jan 4, 2023
9e15486
Merge branch 'mar' of https://github.com/biaslab/ReactiveMP.jl into mar
albertpod Jan 4, 2023
78c5216
Update FE
albertpod Jan 4, 2023
a723fcd
Merge branch 'master' into mar
ismailsenoz Jan 4, 2023
67d47cc
Update rule
albertpod Jan 5, 2023
96731d0
Merge branch 'mar' of https://github.com/biaslab/ReactiveMP.jl into mar
albertpod Jan 5, 2023
f2fd4d5
Update MAR rules
albertpod Jan 6, 2023
5eab7be
Update rules
albertpod Jan 6, 2023
395ea37
WIP: Update marginals & lambda
albertpod Jan 6, 2023
2725813
Fix bug
albertpod Jan 7, 2023
cfc4243
Fix backward rule
albertpod Jan 9, 2023
7a91826
Update rules
albertpod Jan 10, 2023
d5248e0
Fix FE
albertpod Jan 10, 2023
01f59f1
Clean up
albertpod Jan 10, 2023
a839f3d
Update rules
albertpod Jan 11, 2023
b52f77b
Update MF rules
albertpod Jan 11, 2023
87399c2
Modify variables structures for predictions functionality
albertpod Jan 24, 2023
f173652
Merge branch 'master' into dev-predict
albertpod Jan 24, 2023
ae0e770
Make format
albertpod Jan 24, 2023
e2d56e8
Merge branch 'master' into dev-predict
albertpod Jan 30, 2023
696f8ea
WIP: Change data
albertpod Jan 30, 2023
22ce46e
feat: add allows_missings function & tests
bvdmitri Jan 30, 2023
54074d3
Make format
albertpod Feb 1, 2023
fe57269
Merge branch 'master' into mar
albertpod Feb 1, 2023
d9564ed
improve factorisation logic for prediction variables
bvdmitri Feb 1, 2023
33d7ebd
Merge branch 'dev-predict' into mar
albertpod Feb 1, 2023
e936392
fix: update warning for factorisation check
bvdmitri Feb 1, 2023
2ae7b4d
Merge branch 'dev-predict' into mar
albertpod Feb 2, 2023
07c5bdc
Update mapping for marginal
albertpod Feb 6, 2023
b99d989
Merge branch 'master' into dev-predict
albertpod Feb 6, 2023
9179163
Merge branch 'dev-predict' into mar
albertpod Feb 6, 2023
37d0a72
Make format
albertpod Feb 7, 2023
2bb563f
Delete WIPs
albertpod Feb 7, 2023
65cdae8
Make format
albertpod Feb 7, 2023
395cc62
Merge branch 'dev-predict' into mar
albertpod Feb 7, 2023
2a2aa82
Merge branch 'master' into mar
albertpod Feb 13, 2023
cbbc8f1
Merge branch 'master' into dev-predict
albertpod Feb 22, 2023
dd05659
Merge branch 'dev-predict' into mar
albertpod Feb 22, 2023
19547be
Merge master into dev-predict
albertpod Mar 6, 2023
a490f8a
fix tests
bvdmitri Mar 7, 2023
eaae1d2
Update rules
albertpod Mar 19, 2023
dc836fe
Fix MAR rules
albertpod Mar 19, 2023
644a8e8
Merge branch 'dev-predict' into mar
albertpod Mar 19, 2023
a3fcc37
Decrease allocs
albertpod Mar 19, 2023
c22d44a
Merge branch 'master' into mar
albertpod Mar 19, 2023
db06986
Optmize functions
albertpod Mar 21, 2023
3b4bd2b
Merge branch 'master' into dev-predict
albertpod Mar 21, 2023
ac9d8c0
Merge branch 'master' into mar
albertpod Mar 30, 2023
b89e419
Remove diffs
albertpod Mar 30, 2023
3ff75e8
Merge branch 'master' into mar
albertpod May 29, 2023
a6ce995
Merge branch 'master' into dev-predict
albertpod Jun 19, 2023
a84991e
Merge branch 'master' into mar
albertpod Jul 23, 2023
b04fb89
Merge branch 'mar' into dev-predict
albertpod Jul 23, 2023
59a79b0
Merge branch 'master' into dev-predict
albertpod Sep 8, 2023
e0edea1
Remove MV autoregressive node
albertpod Sep 8, 2023
0643188
Remove mv autoregressive from ReactiveMP.jl
albertpod Sep 12, 2023
c659769
Merge branch 'master' into dev-predict
albertpod Sep 12, 2023
a248d5b
Remove not needed exports
albertpod Sep 12, 2023
5ddad79
Remove BlockArrays
albertpod Sep 12, 2023
0c94a12
Update src/variables/data.jl
albertpod Sep 18, 2023
ad7b165
fix warning for predicted datavars
bvdmitri Sep 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,6 @@ uuid = "a194aa59-28ba-4574-a09c-4a745416d6e3"
authors = ["Dmitry Bagaev <d.v.bagaev@tue.nl>", "Albert Podusenko <a.podusenko@tue.nl>", "Bart van Erp <b.v.erp@tue.nl>", "Ismail Senoz <i.senoz@tue.nl>"]
version = "3.9.3"

[weakdeps]
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"

[extensions]
ReactiveMPOptimisersExt = "Optimisers"
ReactiveMPZygoteExt = "Zygote"
ReactiveMPRequiresExt = "Requires"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -38,6 +28,16 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8"

[weakdeps]
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
ReactiveMPOptimisersExt = "Optimisers"
ReactiveMPRequiresExt = "Requires"
ReactiveMPZygoteExt = "Zygote"

[compat]
DataStructures = "0.17, 0.18"
Distributions = "0.24, 0.25"
Expand Down Expand Up @@ -73,9 +73,9 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand Down
68 changes: 54 additions & 14 deletions src/constraints/specifications/factorisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,19 +187,56 @@ resolve_factorisation(::UnspecifiedConstraints, any, allvariables, fform, variab
# Preoptimised dispatch rule for unspecified constraints and a deterministic node with any number of inputs
resolve_factorisation(::UnspecifiedConstraints, ::Deterministic, allvariables, fform, variables) = FullFactorisation()

# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 2 inputs
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: RandomVariable} = ((1, 2),)
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: Union{<:ConstVariable, <:DataVariable}, V2 <: RandomVariable} = ((1,), (2,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: Union{<:ConstVariable, <:DataVariable}} = ((1,), (2,))

# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 3 inputs
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: RandomVariable, V3 <: RandomVariable} = ((1, 2, 3),)
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: Union{<:ConstVariable, <:DataVariable}, V2 <: RandomVariable, V3 <: RandomVariable} = ((1,), (2, 3))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: Union{<:ConstVariable, <:DataVariable}, V3 <: RandomVariable} = ((1, 3), (2,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: RandomVariable, V3 <: Union{<:ConstVariable, <:DataVariable}} = ((1, 2), (3,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: Union{<:ConstVariable, <:DataVariable}, V3 <: Union{<:ConstVariable, <:DataVariable}} = ((1,), (2,), (3,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: Union{<:ConstVariable, <:DataVariable}, V2 <: RandomVariable, V3 <: Union{<:ConstVariable, <:DataVariable}} = ((1,), (2,), (3,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: Union{<:ConstVariable, <:DataVariable}, V2 <: Union{<:ConstVariable, <:DataVariable}, V3 <: RandomVariable} = ((1,), (2,), (3,))
# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 2 inputs, random variable & constant variable
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: RandomVariable} = ((1, 2),)
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: ConstVariable, V2 <: RandomVariable} = ((1,), (2,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: ConstVariable} = ((1,), (2,))

# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 2 inputs, random variable & data variable
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2}) where {V1 <: DataVariable, V2 <: RandomVariable} =
allows_missings(variables[1]) ? ((1, 2),) : ((1,), (2,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2}) where {V1 <: RandomVariable, V2 <: DataVariable} =
allows_missings(variables[2]) ? ((1, 2),) : ((1,), (2,))

# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 3 inputs, random variable & constant variables
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: RandomVariable, V3 <: RandomVariable} = ((1, 2, 3),)
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: ConstVariable, V2 <: RandomVariable, V3 <: RandomVariable} = ((1,), (2, 3))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: ConstVariable, V3 <: RandomVariable} = ((1, 3), (2,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: RandomVariable, V3 <: ConstVariable} = ((1, 2), (3,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: RandomVariable, V2 <: ConstVariable, V3 <: ConstVariable} = ((1,), (2,), (3,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: ConstVariable, V2 <: RandomVariable, V3 <: ConstVariable} = ((1,), (2,), (3,))
resolve_factorisation(::UnspecifiedConstraints, ::Stochastic, allvariables, fform, ::Tuple{V1, V2, V3}) where {V1 <: ConstVariable, V2 <: ConstVariable, V3 <: RandomVariable} = ((1,), (2,), (3,))

# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 3 inputs, random variable & data variable
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: DataVariable, V2 <: RandomVariable, V3 <: RandomVariable} = allows_missings(variables[1]) ? ((1, 2, 3),) : ((1,), (2, 3))
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: RandomVariable, V2 <: DataVariable, V3 <: RandomVariable} = allows_missings(variables[2]) ? ((1, 2, 3),) : ((1, 3), (2,))
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: RandomVariable, V2 <: RandomVariable, V3 <: DataVariable} = allows_missings(variables[3]) ? ((1, 2, 3),) : ((1, 2), (3,))

# Preoptimised dispatch rules for unspecified constraints and a stochastic node with 3 inputs, random variable & data variable & const variable
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: DataVariable, V2 <: ConstVariable, V3 <: RandomVariable} = allows_missings(variables[1]) ? ((1, 3), (2,)) : ((1,), (2,), (3,))
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: DataVariable, V2 <: RandomVariable, V3 <: ConstVariable} = allows_missings(variables[1]) ? ((1, 2), (3,)) : ((1,), (2,), (3,))
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: ConstVariable, V2 <: DataVariable, V3 <: RandomVariable} = allows_missings(variables[2]) ? ((1,), (2, 3)) : ((1,), (3,), (2,))
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: RandomVariable, V2 <: DataVariable, V3 <: ConstVariable} = allows_missings(variables[2]) ? ((1, 2), (3,)) : ((1,), (2,), (3,))
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: ConstVariable, V2 <: RandomVariable, V3 <: DataVariable} = allows_missings(variables[3]) ? ((1,), (2, 3)) : ((1,), (2,), (3,))
resolve_factorisation(
::UnspecifiedConstraints, ::Stochastic, allvariables, fform, variables::Tuple{V1, V2, V3}
) where {V1 <: RandomVariable, V2 <: ConstVariable, V3 <: DataVariable} = allows_missings(variables[3]) ? ((1, 3), (2,)) : ((1,), (2,), (3,))

"""
resolve_factorisation(constraints, allvariables, fform, variables)
Expand Down Expand Up @@ -419,8 +456,11 @@ function resolve_factorisation(::Stochastic, constraints, allvariables, fform, _
index::Int = 1
shift::Int = 0
for varref in var_refs
if israndom(varref[3])
if israndom(varref[3]) || (isdata(varref[3]) && allows_missings(varref[3]))
# We process everything as usual if varref is a random variable
# or if the variable is data variable and it allows missing
# We probably should change the logic from "allows missings" to "used as prediction"
# For now we assume that if data variable allows missing input it is indeed "used as prediction"
__process_factorisation_entry!(varref[1], varref[2], shift)
else
# We filter out varref from all clusters if it is not random
Expand Down
8 changes: 7 additions & 1 deletion src/marginal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,13 @@ function (mapping::MarginalMapping)(dependencies)
# Marginal is initial if it is not clamped and all of the inputs are either clamped or initial
is_marginal_initial = !is_marginal_clamped && (__check_all(is_clamped_or_initial, messages) && __check_all(is_clamped_or_initial, marginals))

marginal = marginalrule(marginal_mapping_fform(mapping), mapping.vtag, mapping.msgs_names, messages, mapping.marginals_names, marginals, mapping.meta, mapping.factornode)
marginal = if !isnothing(messages) && any(ismissing, TupleTools.flatten(getdata.(messages)))
missing
elseif !isnothing(marginals) && any(ismissing, TupleTools.flatten(getdata.(marginals)))
missing
else
marginalrule(marginal_mapping_fform(mapping), mapping.vtag, mapping.msgs_names, messages, mapping.marginals_names, marginals, mapping.meta, mapping.factornode)
end

return Marginal(marginal, is_marginal_clamped, is_marginal_initial, nothing)
end
Expand Down
30 changes: 18 additions & 12 deletions src/message.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,18 +324,24 @@ function materialize!(mapping::MessageMapping, messages, marginals)
# Message is initial if it is not clamped and all of the inputs are either clamped or initial
is_message_initial = !is_message_clamped && (__check_all(is_clamped_or_initial, messages) && __check_all(is_clamped_or_initial, marginals))

result, addons = rule(
message_mapping_fform(mapping),
mapping.vtag,
mapping.vconstraint,
mapping.msgs_names,
messages,
mapping.marginals_names,
marginals,
mapping.meta,
mapping.addons,
mapping.factornode
)
result, addons = if !isnothing(messages) && any(ismissing, TupleTools.flatten(getdata.(messages)))
missing, mapping.addons
elseif !isnothing(marginals) && any(ismissing, TupleTools.flatten(getdata.(marginals)))
missing, mapping.addons
else
rule(
message_mapping_fform(mapping),
mapping.vtag,
mapping.vconstraint,
mapping.msgs_names,
messages,
mapping.marginals_names,
marginals,
mapping.meta,
mapping.addons,
mapping.factornode
)
end

# Inject extra addons after the rule has been executed
addons = message_mapping_addons(mapping, getdata(messages), getdata(marginals), result, addons)
Expand Down
2 changes: 1 addition & 1 deletion src/node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,7 @@ macro node(fformtype, sdtype, interfaces_list)
missingclustererr = "Cannot find the cluster for the variable connected to the `$(name)` interface around the `$fformtype` node."
quote
# If a variable `$name` is a constvar or a datavar
if ReactiveMP.isconst($(name)) || ReactiveMP.isdata($(name))
if ReactiveMP.isconst($(name)) || (ReactiveMP.isdata($(name)) && !ReactiveMP.allows_missings($(name)))
local __factorisation = ReactiveMP.factorisation(node)
# Find the factorization cluster associated with the constvar `$name`
local __index = ReactiveMP.interface_get_index(Val{$(QuoteNode(fbottomtype))}, Val{$(QuoteNode(name))})
Expand Down
42 changes: 37 additions & 5 deletions src/variables/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import Base: show
mutable struct DataVariable{D, S} <: AbstractVariable
name :: Symbol
collection_type :: AbstractVariableCollectionType
prediction :: MarginalObservable
input_messages :: Vector{MessageObservable{AbstractMessage}}
messageout :: S
nconnected :: Int
isproxy :: Bool
Expand Down Expand Up @@ -74,12 +76,16 @@ datavar(name::Symbol, ::Type{D}, dims::Tuple) where {D}
datavar(name::Symbol, ::Type{D}, dims::Vararg{Int}) where {D} = datavar(DataVariableCreationOptions(D), name, D, dims)

datavar(options::DataVariableCreationOptions{S}, name::Symbol, ::Type{D}, collection_type::AbstractVariableCollectionType = VariableIndividual()) where {S, D} =
DataVariable{D, S}(name, collection_type, options.subject, 0, options.isproxy, options.isused)
DataVariable{D, S}(name, collection_type, MarginalObservable(), Vector{MessageObservable{AbstractMessage}}(), options.subject, 0, options.isproxy, options.isused)

function datavar(options::DataVariableCreationOptions, name::Symbol, ::Type{D}, length::Int) where {D}
return map(i -> datavar(similar(options), name, D, VariableVector(i)), 1:length)
end

function datavar(options::DataVariableCreationOptions, name::Symbol, ::Type{D}, dim1::Int, extra_dims::Vararg{Int}) where {D}
return datavar(options, name, D, (dim1, extra_dims...))
end

function datavar(options::DataVariableCreationOptions, name::Symbol, ::Type{D}, dims::Tuple) where {D}
indices = CartesianIndices(dims)
size = axes(indices)
Expand All @@ -106,11 +112,17 @@ isdata(::AbstractArray{<:DataVariable}) = true
isconst(::DataVariable) = false
isconst(::AbstractArray{<:DataVariable}) = false

allows_missings(datavar::DataVariable) = allows_missings(datavar, eltype(datavar.messageout))

allows_missings(datavars::AbstractArray{<:DataVariable}) = all(allows_missings, datavars)
allows_missings(datavar::DataVariable, ::Type{Message{D}}) where {D} = false
allows_missings(datavar::DataVariable, ::Type{Union{Message{Missing}, Message{D}}} where {D}) = true

function Base.getindex(datavar::DataVariable, i...)
error("Variable $(indexed_name(datavar)) has been indexed with `[$(join(i, ','))]`. Direct indexing of `data` variables is not allowed.")
end

getlastindex(::DataVariable) = 1
getlastindex(datavar::DataVariable) = degree(datavar) + 1

messageout(datavar::DataVariable, ::Int) = datavar.messageout
messagein(datavar::DataVariable, ::Int) = error("It is not possible to get a reference for inbound message for datavar")
Expand Down Expand Up @@ -163,8 +175,28 @@ _makemarginal(datavar::DataVariable) = error("It is not possible to

setanonymous!(::DataVariable, ::Bool) = nothing

function setmessagein!(datavar::DataVariable, ::Int, messagein)
datavar.nconnected += 1
datavar.isused = true
function setmessagein!(datavar::DataVariable, index::Int, messagein)
if index === (degree(datavar) + 1)
push!(datavar.input_messages, messagein)
datavar.nconnected += 1
datavar.isused = true
else
error(
"Inconsistent state in setmessagein! function for data variable $(datavar). `index` should be equal to `degree(datavar) + 1 = $(degree(datavar) + 1)`, $(index) is given instead"
)
end
return nothing
end

marginal_prod_fn(datavar::DataVariable) = marginal_prod_fn(FoldLeftProdStrategy(), ProdAnalytical(), UnspecifiedFormConstraint(), FormConstraintCheckLast())

_getprediction(datavar::DataVariable) = datavar.prediction
_setprediction!(datavar::DataVariable, observable) = connect!(_getprediction(datavar), observable)
_makeprediction(datavar::DataVariable) = collectLatest(AbstractMessage, Marginal, datavar.input_messages, marginal_prod_fn(datavar))

# options here must implement at least `Rocket.getscheduler`
albertpod marked this conversation as resolved.
Show resolved Hide resolved
albertpod marked this conversation as resolved.
Show resolved Hide resolved
function activate!(datavar::DataVariable, options)
_setprediction!(datavar, _makeprediction(datavar))

return nothing
end
5 changes: 4 additions & 1 deletion src/variables/variable.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export AbstractVariable, degree
export is_clamped, is_marginalisation, is_moment_matching
export FoldLeftProdStrategy, FoldRightProdStrategy, CustomProdStrategy
export getmarginal, getmarginals, setmarginal!, setmarginals!, name, as_variable
export getprediction, getpredictions, getmarginal, getmarginals, setmarginal!, setmarginals!, name, as_variable
export setmessage!, setmessages!

using Rocket
Expand Down Expand Up @@ -80,6 +80,9 @@ add_pipeline_stage!(variable::AbstractVariable, stage) = error("Its not possible
# Helper functions
# Getters

getprediction(variable::AbstractVariable) = _getprediction(variable)
getpredictions(variables::AbstractArray{<:AbstractVariable}) = collectLatest(map(v -> getprediction(v), variables))

getmarginal(variable::AbstractVariable) = getmarginal(variable, SkipInitial())
getmarginal(variable::AbstractVariable, skip_strategy::MarginalSkipStrategy) = apply_skip_filter(_getmarginal(variable), skip_strategy)

Expand Down
Loading
Loading