Skip to content

Commit

Permalink
Merge pull request #386 from ReactiveBayes/reimplement-node-macro
Browse files Browse the repository at this point in the history
Reimplement node macro
  • Loading branch information
bvdmitri authored Mar 4, 2024
2 parents 03ae1a0 + 6361f97 commit 9f438e5
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 317 deletions.
202 changes: 53 additions & 149 deletions src/nodes/nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ import Rocket: getscheduler
import Base: show, +, push!, iterate, IteratorSize, IteratorEltype, eltype, length, size
import Base: getindex, setindex!, firstindex, lastindex

##

function make_node end # TODO (bvdmitri): remove this

## Node traits

"""
Expand Down Expand Up @@ -197,8 +201,8 @@ struct FactorNode{F, I} <: AbstractFactorNode
end
end

factornode(::Type{F}, interfaces::I) where {F, I} = FactorNode(F, __prepare_interfaces_generic(interfaces))
factornode(::F, interfaces::I) where {F <: Function, I} = FactorNode(F, __prepare_interfaces_generic(interfaces))
factornode(::Type{F}, interfaces::I) where {F, I} = FactorNode(F, __prepare_interfaces_generic(correct_interfaces(F, interfaces)))
factornode(::F, interfaces::I) where {F <: Function, I} = FactorNode(F, __prepare_interfaces_generic(correct_interfaces(F, interfaces)))

functionalform(factornode::FactorNode{F}) where {F} = F
getinterfaces(factornode::FactorNode) = factornode.interfaces
Expand Down Expand Up @@ -234,170 +238,70 @@ end

import .MacroHelpers

# Are still needed for the `@node` macro
function make_node end
function interface_get_index end
function interface_get_name end

"""
@node(fformtype, sdtype, interfaces_list)
`@node` macro creates a node for a `fformtype` type object. To obtain a list of available nodes use `?make_node`.
# Arguments
- `fformtype`: Either an existing type identifier, e.g. `Normal` or a function type identifier, e.g. `typeof(+)`
- `sdtype`: Either `Stochastic` or `Deterministic`. Defines the type of the functional relationship
- `interfaces_list`: Defines a fixed list of edges of a factor node, by convention the first element should be `out`. Example: `[ out, mean, variance ]`
Note: `interfaces_list` must not include names that contain `_` symbol in them, as it is reserved to identify joint posteriors around the node object.
# Examples
```julia
function correct_interfaces end

struct MyNormalDistribution
mean :: Float64
var :: Float64
end
@node MyNormalDistribution Stochastic [ out, mean, var ]
```
```julia
@node typeof(+) Deterministic [ out, in1, in2 ]
```
# List of available nodes
See `?make_node`.
See also: [`make_node`](@ref), [`Stochastic`](@ref), [`Deterministic`](@ref)
"""
macro node(fformtype, sdtype, interfaces_list)
fbottomtype = MacroHelpers.bottom_type(fformtype)
fuppertype = MacroHelpers.upper_type(fformtype)

@assert sdtype [:Stochastic, :Deterministic] "Invalid sdtype $(sdtype). Can be either Stochastic or Deterministic."

@capture(interfaces_list, [interfaces_args__]) || error("Invalid interfaces specification.")

interfaces = map(interfaces_args) do arg
if @capture(arg, name_Symbol)
return (name, [])
elseif @capture(arg, (name_Symbol, aliases = [aliases__]))
@assert all(a -> a isa Symbol && !isequal(a, name), aliases)
return (name, aliases)
else
error("Interface specification should have a 'name' or (name, aliases = [ alias1, alias2,... ]) signature.")
alias_group(s::Symbol) = [s]
function alias_group(e::Expr)
if @capture(e, (s_, aliases = aliases_))
result = [s, aliases.args...]
if length(result) != length(unique(result))
error("Aliases should be unique")
end
return result
else
return [e]
end
end

@assert length(interfaces) !== 0 "Node should have at least one interface."
check_all_symbol(::AbstractArray{T} where {T <: NTuple{N, Symbol} where {N}}) = nothing
check_all_symbol(::Any) = error("All interfaces should be symbols")

names = map(d -> d[1], interfaces)
aliases = map(d -> d[2], interfaces)
function generate_node_expression(node_fform, node_type, node_interfaces, interface_aliases)
# Assert that the node type is either Stochastic or Deterministic, and that all interfaces are symbols
@assert node_type [:Stochastic, :Deterministic]
@assert length(node_interfaces.args) > 0

foreach(names) do name
@assert !occursin('_', string(name)) "Node interfaces names (and aliases) must not contain `_` symbol in them, found in $(name)."
end
interface_alias_groups = map(alias_group, node_interfaces.args)
all_aliases = vec(collect(Iterators.product(interface_alias_groups...)))

foreach(Iterators.flatten(aliases)) do alias
@assert !occursin('_', string(alias)) "Node interfaces names (and aliases) must not contain `_` symbol in them, found in $(alias)."
# Determine whether we should dispatch on `typeof($fform)` or `Type{$node_fform}`
dispatch_type = if @capture(node_fform, typeof(fform_))
:(typeof($fform))
else
:(Type{$node_fform})
end

names_quoted_tuple = Expr(:tuple, map(name -> Expr(:quote, name), names)...)
names_indices = Expr(:tuple, map(i -> i, 1:length(names))...)
names_splitted_indices = Expr(:tuple, map(i -> Expr(:tuple, i), 1:length(names))...)
names_indexed = Expr(:tuple, map(name -> Expr(:call, :(ReactiveMP.indexed_name), name), names)...)

interface_names = map(name -> :(ReactiveMP.indexed_name($name)), names)
interface_args = map(name -> :($name), names)
interface_connections = map(name -> :(ReactiveMP.connect!(node, $(Expr(:quote, name)), $name)), names)

joined_interface_names = :(join((($(interface_names...)),), ", "))

# Check that all arguments within interface refer to the unique var objects
non_unique_error_sym = gensym(:non_unique_error_sym)
non_unique_error_msg = :($non_unique_error_sym = (fformtype, names) -> """
Non-unique variables used for the creation of the `$(fformtype)` node, which is disallowed.
Check creation of the `$(fformtype)` with the `[ $(join(names, ", ")) ]` arguments.
""")
interface_uniqueness = map(enumerate(names)) do (index, name)
names_without_current = skipindex(names, index)
return quote
if Base.in($(name), ($(names_without_current...),))
Base.error($(non_unique_error_sym)($fformtype, $names_indexed))
end
end
# If there are any aliases, define the alias correction function
if @capture(interface_aliases, aliases = aliases_)
defined_aliases = map(alias_group -> Tuple(alias_group.args), aliases.args)
all_aliases = vcat(all_aliases, defined_aliases)
end

# Here we create helpers function for GraphPPL.jl interfacing
# They are used to convert interface names from `where { q = q(x, y)q(z) }` to an equivalent tuple respresentation, e.g. `((1, 2), (3, ))`
# The general recipe to get a proper index is to call `interface_get_index(Val{ :NodeTypeName }, interface_get_name(Val{ :NodeTypeName }, Val{ :name_expression }))`
interface_name_getters = map(enumerate(interfaces)) do (index, interface)
name = first(interface)
aliases = last(interface)
check_all_symbol(all_aliases)

index_name_getter = :(ReactiveMP.interface_get_index(::Type{Val{$(Expr(:quote, fbottomtype))}}, ::Type{Val{$(Expr(:quote, name))}}) = $(index))
name_symbol_getter = :(ReactiveMP.interface_get_name(::Type{Val{$(Expr(:quote, fbottomtype))}}, ::Type{Val{$(Expr(:quote, name))}}) = $(Expr(:quote, name)))
name_index_getter = :(ReactiveMP.interface_get_name(::Type{Val{$(Expr(:quote, fbottomtype))}}, ::Type{Val{$index}}) = $(Expr(:quote, name)))
first_interfaces = map(first, interface_alias_groups)

alias_getters = map(aliases) do alias
return :(ReactiveMP.interface_get_name(::Type{Val{$(Expr(:quote, fbottomtype))}}, ::Type{Val{$(Expr(:quote, alias))}}) = $(Expr(:quote, name)))
end

return quote
$index_name_getter
$name_symbol_getter
$name_index_getter
$(alias_getters...)
end
alias_corrections = Expr(:block)
alias_corrections.args = map(all_aliases) do alias
:(ReactiveMP.correct_interfaces(::$dispatch_type, nt::NamedTuple{$alias}) = NamedTuple{$(Tuple(first_interfaces))}(values(nt)))
end

# By default every argument passed to a factorisation option of the node is transformed by
# `collect_factorisation` function to have a tuple like structure.
# The default recipe is simple: for stochastic nodes we convert `FullFactorisation` and `MeanField` objects
# to their tuple of indices equivalents. For deterministic nodes any factorisation is replaced by a FullFactorisation equivalent
factorisation_collectors = if sdtype === :Stochastic
quote
ReactiveMP.collect_factorisation(::$fuppertype, ::Nothing) = ($names_indices,)
ReactiveMP.collect_factorisation(::$fuppertype, factorisation::Tuple) = factorisation
ReactiveMP.collect_factorisation(::$fuppertype, ::ReactiveMP.FullFactorisation) = ($names_indices,)
ReactiveMP.collect_factorisation(::$fuppertype, ::ReactiveMP.MeanField) = $names_splitted_indices
end
# Define the necessary function types
result = quote
ReactiveMP.as_node_functional_form(::$dispatch_type) = ReactiveMP.ValidNodeFunctionalForm()
ReactiveMP.sdtype(::$dispatch_type) = (ReactiveMP.$node_type)()
ReactiveMP.collect_factorisation(::$dispatch_type, factorisation::Tuple) = factorisation

elseif sdtype === :Deterministic
quote
ReactiveMP.collect_factorisation(::$fuppertype, ::Nothing) = ($names_indices,)
ReactiveMP.collect_factorisation(::$fuppertype, factorisation::Tuple) = ($names_indices,)
ReactiveMP.collect_factorisation(::$fuppertype, ::ReactiveMP.FullFactorisation) = ($names_indices,)
ReactiveMP.collect_factorisation(::$fuppertype, ::ReactiveMP.MeanField) = ($names_indices,)
end
else
error("Unreachable in @node macro.")
$alias_corrections
end

doctype = rpad(fbottomtype, 30)
docsdtype = rpad(sdtype, 15)
docedges = string(interfaces_list)
doc = """
$doctype : $docsdtype : $docedges
"""

res = quote
ReactiveMP.as_node_functional_form(::$fuppertype) = ReactiveMP.ValidNodeFunctionalForm()
ReactiveMP.as_node_functional_form(::Type{$fuppertype}) = ReactiveMP.ValidNodeFunctionalForm()

ReactiveMP.sdtype(::$fuppertype) = (ReactiveMP.$sdtype)()

ReactiveMP.as_node_symbol(::$fuppertype) = $(QuoteNode(fbottomtype))

$(interface_name_getters...)

$factorisation_collectors
end
return result
end

return esc(res)
macro node(node_fform, node_type, node_interfaces, interface_aliases)
return esc(generate_node_expression(node_fform, node_type, node_interfaces, interface_aliases))
end

macro node(node_fform, node_type, node_interfaces)
return esc(generate_node_expression(node_fform, node_type, node_interfaces, nothing))
end
4 changes: 2 additions & 2 deletions src/nodes/predefined/beta.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

import SpecialFunctions: logbeta

@node Beta Stochastic [out, α, β]
@node Beta Stochastic [out, (a, aliases = [α]), (b, aliases = [β])]

@average_energy Beta (q_out::Any, q_α::Any, q_β::Any) = logbeta(mean(q_α), mean(q_β)) - (mean(q_α) - 1.0) * mean(log, q_out) - (mean(q_β) - 1.0) * mean(mirrorlog, q_out)
@average_energy Beta (q_out::Any, q_a::Any, q_b::Any) = logbeta(mean(q_a), mean(q_b)) - (mean(q_a) - 1.0) * mean(log, q_out) - (mean(q_b) - 1.0) * mean(mirrorlog, q_out)
4 changes: 2 additions & 2 deletions src/rules/beta/marginals.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

@marginalrule Beta(:out_α_β) (m_out::Beta, m_α::PointMass, m_β::PointMass) = begin
return convert_paramfloattype((out = prod(ClosedProd(), Beta(mean(m_α), mean(m_β)), m_out), a = m_α, b = m_β))
@marginalrule Beta(:out_a_b) (m_out::Beta, m_a::PointMass, m_b::PointMass) = begin
return convert_paramfloattype((out = prod(ClosedProd(), Beta(mean(m_a), mean(m_b)), m_out), a = m_a, b = m_b))
end
4 changes: 2 additions & 2 deletions src/rules/beta/out.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

@rule Beta(:out, Marginalisation) (m_α::PointMass, m_β::PointMass) = Beta(mean(m_α), mean(m_β))
@rule Beta(:out, Marginalisation) (m_a::PointMass, m_b::PointMass) = Beta(mean(m_a), mean(m_b))

@rule Beta(:out, Marginalisation) (q_α::PointMass, q_β::PointMass) = Beta(mean(q_α), mean(q_β))
@rule Beta(:out, Marginalisation) (q_a::PointMass, q_b::PointMass) = Beta(mean(q_a), mean(q_b))
2 changes: 1 addition & 1 deletion src/variables/constant.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export constvar
export constvar, ConstVariable

mutable struct ConstVariable <: AbstractVariable
marginal :: MarginalObservable
Expand Down
2 changes: 1 addition & 1 deletion src/variables/data.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export datavar, update!, DataVariableActivationOptions
export datavar, DataVariable, update!, DataVariableActivationOptions

mutable struct DataVariable{M, P} <: AbstractVariable
input_messages :: Vector{MessageObservable{AbstractMessage}}
Expand Down
2 changes: 1 addition & 1 deletion src/variables/random.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export randomvar, RandomVariableActivationOptions
export randomvar, RandomVariable, RandomVariableActivationOptions

## Random variable implementation

Expand Down
Loading

0 comments on commit 9f438e5

Please sign in to comment.