Skip to content

Commit

Permalink
Merge pull request #346 from biaslab/dev-call-rule-addons
Browse files Browse the repository at this point in the history
Allow returning addons from the `@call_rule`
  • Loading branch information
bvdmitri authored Sep 1, 2023
2 parents b40a56a + 5160db4 commit a8e6f94
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 6 deletions.
50 changes: 44 additions & 6 deletions src/rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -436,14 +436,32 @@ macro rule(fform, lambda)
end

"""
@call_rule NodeType(:edge, Constraint) (argument1 = value1, argument2 = value2, ..., [ meta = ... ])
@call_rule NodeType(:edge, Constraint) (argument1 = value1, argument2 = value2, ..., [ meta = ..., addons = ... ])
The `@call_rule` macro helps to call the `rule` method with an easier syntax.
The structure of the macro is almost the same as in the `@rule` macro, but there is no `begin ... end` block, but instead each argument must have a specified value with the `=` operator.
The `@call_rule` accepts optional list of options before the functional form specification, for example:
```julia
@call_rule [ return_addons = true ] NodeType(:edge, Constraint) (argument1 = value1, argument2 = value2, ..., [ meta = ..., addons = ... ])
```
The list of available options is:
- `return_addons` - forces the `@call_rule` to return the tuple of `(result, addons)`
See also: [`@rule`](@ref), [`rule`](@ref), [`@call_marginalrule`](@ref)
"""
macro call_rule(options, fform, args)
return call_rule_expression(options, fform, args)
end

macro call_rule(fform, args)
return call_rule_expression(nothing, fform, args)
end

function call_rule_expression(options, fform, args)
@capture(fform, fformtype_(on_, vconstraint_)) || error("Error in macro. Functional form specification should in the form of 'fformtype_(on_, vconstraint_)'")

@capture(args, (inputs__, meta = meta_, addons = addons_) | (inputs__, addons = addons_) | (inputs__, meta = meta_) | (inputs__,)) ||
Expand All @@ -466,14 +484,34 @@ macro call_rule(fform, args)

on_arg = call_rule_macro_construct_on_arg(on_type, on_index)

output = quote
let
# TODO: (bvdmitri At the moment we cannot really get the result of the addon by calling `@call_rule`
local __distribution_sym, _ = ReactiveMP.rule($fbottomtype, $on_arg, $(vconstraint)(), $m_names_arg, $m_values_arg, $q_names_arg, $q_values_arg, $meta, $addons, $node)
__distribution_sym
# Options
# Option 1. Modifies the output of the `@call_rule` macro and returns a tuple of the result and the enabled addons
return_addons = false

if !isnothing(options)
@capture(options, [voptions__]) || error("Error in macro. Options should be in a form of `[ option1 = value1, ... ]`, got $(options).")
foreach(voptions) do option
@capture(option, key_ = value_) || error("Error in macro. An options should be in a form of `option = value`, got $(option).")
if key === :return_addons
return_addons = Bool(value)
else
@warn "Unknown option in the `@call_rule` macro: $(option)"
end
end
end

call = quote
local __distribution_sym, __addons_sym = ReactiveMP.rule(
$fbottomtype, $on_arg, $(vconstraint)(), $m_names_arg, $m_values_arg, $q_names_arg, $q_values_arg, $meta, $addons, $node
)
end

output = if !return_addons
:($call; __distribution_sym)
else
:($call)
end

return esc(output)
end

Expand Down
18 changes: 18 additions & 0 deletions test/test_rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,24 @@ import MacroTools: inexpr

@test_throws ReactiveMP.RuleMethodError (@call_rule DummyNode(:out, Marginalisation) (q_x = vague(NormalMeanPrecision), q_y = vague(NormalMeanPrecision), meta = 3))
end

@testset "Check the `return_addons` option" begin
# Enable LogScale addon
dist_and_addons = @call_rule [return_addons = true] Bernoulli(:out, Marginalisation) (m_p = Beta(1, 2), addons = (AddonLogScale(),))

@test dist_and_addons isa Tuple
@test length(dist_and_addons) === 2
@test dist_and_addons[1] isa Bernoulli
@test dist_and_addons[2] isa Tuple{AddonLogScale}

# Without addons but with the option
dist_and_nothing = @call_rule [return_addons = true] Bernoulli(:out, Marginalisation) (m_p = Beta(1, 2),)

@test dist_and_nothing isa Tuple
@test length(dist_and_nothing) === 2
@test dist_and_nothing[1] isa Bernoulli
@test dist_and_nothing[2] isa Nothing
end
end

end

0 comments on commit a8e6f94

Please sign in to comment.