Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typo in DeferredMessage #419

Merged
merged 3 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions src/message.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export AbstractMessage, Message, DefferedMessage
export AbstractMessage, Message, DeferredMessage
export getdata, is_clamped, is_initial, as_message

using Distributions
Expand Down Expand Up @@ -192,22 +192,22 @@ MacroHelpers.@proxy_methods Message getdata [

Distributions.mean(fn::Function, message::Message) = mean(fn, getdata(message))

## Deffered Message
## Deferred Message

"""
A special type of a message, for which the actual message is not computed immediately, but is computed later on demand (potentially never).
To compute and get the actual message, one needs to call the `as_message` method.
"""
mutable struct DefferedMessage{R, S, F} <: AbstractMessage
mutable struct DeferredMessage{R, S, F} <: AbstractMessage
const messages :: R
const marginals :: S
const mappingFn :: F
cache :: Union{Nothing, Message}
end

DefferedMessage(messages::R, marginals::S, mappingFn::F) where {R, S, F} = DefferedMessage(messages, marginals, mappingFn, nothing)
DeferredMessage(messages::R, marginals::S, mappingFn::F) where {R, S, F} = DeferredMessage(messages, marginals, mappingFn, nothing)

function Base.show(io::IO, message::DefferedMessage)
function Base.show(io::IO, message::DeferredMessage)
cache = getcache(message)
if isnothing(cache)
print(io, "DeferredMessage([ use `as_message` to compute the message ])")
Expand All @@ -216,22 +216,22 @@ function Base.show(io::IO, message::DefferedMessage)
end
end

getcache(message::DefferedMessage) = message.cache
setcache!(message::DefferedMessage, cache::Message) = message.cache = cache
getcache(message::DeferredMessage) = message.cache
setcache!(message::DeferredMessage, cache::Message) = message.cache = cache

function as_message(message::DefferedMessage)::Message
function as_message(message::DeferredMessage)::Message
return as_message(message, getcache(message))
end

function as_message(message::DefferedMessage, cache::Message)::Message
function as_message(message::DeferredMessage, cache::Message)::Message
return cache
end

function as_message(message::DefferedMessage, cache::Nothing)::Message
function as_message(message::DeferredMessage, cache::Nothing)::Message
return as_message(message, cache, getrecent(message.messages), getrecent(message.marginals))
end

function as_message(message::DefferedMessage, cache::Nothing, messages, marginals)::Message
function as_message(message::DeferredMessage, cache::Nothing, messages, marginals)::Message
computed = message.mappingFn(messages, marginals)
setcache!(message, computed)
return computed
Expand Down
2 changes: 1 addition & 1 deletion src/nodes/dependencies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ function activate!(dependencies::FunctionalDependencies, factornode, options)
vmessageout = combineLatest((messages, marginals), PushNew())

mapping = let messagemap = MessageMapping(fform, vtag, vconstraint, messagestag, marginalstag, meta, addons, node_if_required(fform, factornode), rulefallback)
(dependencies) -> DefferedMessage(dependencies[1], dependencies[2], messagemap)
(dependencies) -> DeferredMessage(dependencies[1], dependencies[2], messagemap)
end

vmessageout = vmessageout |> map(AbstractMessage, mapping)
Expand Down
2 changes: 1 addition & 1 deletion src/nodes/predefined/delta/layouts/cvi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ function deltafn_apply_layout(::CVIApproximationDeltaFnRuleLayout, ::Val{:m_out}
vmessageout = combineLatest((msgs_observable, marginals_observable), PushNew())

mapping = let messagemap = MessageMapping(fform, vtag, vconstraint, msgs_names, marginal_names, meta, addons, factornode, rulefallback)
(dependencies) -> DefferedMessage(dependencies[1], dependencies[2], messagemap)
(dependencies) -> DeferredMessage(dependencies[1], dependencies[2], messagemap)
end

vmessageout = with_statics(factornode, vmessageout)
Expand Down
6 changes: 3 additions & 3 deletions src/nodes/predefined/delta/layouts/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ function deltafn_apply_layout(::DeltaFnDefaultRuleLayout, ::Val{:m_out}, factorn
vmessageout = combineLatest((msgs_observable, marginals_observable), PushNew())

mapping = let messagemap = MessageMapping(fform, vtag, vconstraint, msgs_names, marginal_names, meta, addons, factornode, rulefallback)
(dependencies) -> DefferedMessage(dependencies[1], dependencies[2], messagemap)
(dependencies) -> DeferredMessage(dependencies[1], dependencies[2], messagemap)
end

vmessageout = with_statics(factornode, vmessageout)
Expand Down Expand Up @@ -116,7 +116,7 @@ function deltafn_apply_layout(::DeltaFnDefaultRuleLayout, ::Val{:m_in}, factorno
vmessageout = combineLatest((msgs_observable, marginals_observable), PushNew())

mapping = let messagemap = MessageMapping(fform, vtag, vconstraint, msgs_names, marginal_names, meta, addons, factornode, rulefallback)
(dependencies) -> DefferedMessage(dependencies[1], dependencies[2], messagemap)
(dependencies) -> DeferredMessage(dependencies[1], dependencies[2], messagemap)
end

vmessageout = with_statics(factornode, vmessageout)
Expand Down Expand Up @@ -184,7 +184,7 @@ function deltafn_apply_layout(::DeltaFnDefaultKnownInverseRuleLayout, ::Val{:m_i
vmessageout = combineLatest((msgs_observable, marginals_observable), PushNew())

mapping = let messagemap = MessageMapping(fform, vtag, vconstraint, msgs_names, marginal_names, meta, addons, factornode, rulefallback)
(dependencies) -> DefferedMessage(dependencies[1], dependencies[2], messagemap)
(dependencies) -> DeferredMessage(dependencies[1], dependencies[2], messagemap)
end

vmessageout = with_statics(factornode, vmessageout)
Expand Down
Loading