Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow functional form constraints that are not sub-typed from AbstractFormConstraint #403

Merged
merged 3 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 54 additions & 13 deletions docs/src/custom/custom-functional-form.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ q(x) = f\left(\frac{\overrightarrow{\mu}(x)\overleftarrow{\mu}(x)}{\int \overrig
AbstractFormConstraint
UnspecifiedFormConstraint
CompositeFormConstraint
ReactiveMP.preprocess_form_constraints
```

### Form check strategy
Expand Down Expand Up @@ -50,35 +51,75 @@ constrain_form

## [Custom Functional Form Example](@id custom-functional-form-example)

In this demo we show how to build a custom functional form constraint that is compatible with the `ReactiveMP.jl` inference backend. An important part of the functional forms constraint implementation is the `prod` function in the [`BayesBase`](https://reactivebayes.github.io/BayesBase.jl/stable/) package. We show a relatively simple use-case, which might not be very useful in practice, but serves as a simple step-by-step guide. Assume that we want a specific posterior marginal of some random variable in our model to have a specific Gaussian parametrisation, for example mean-precision. We can use built-in `NormalMeanPrecision` distribution, but we still need to define our custom functional form constraint:
In this demo, we show how to build a custom functional form constraint that is compatible with the `ReactiveMP.jl` inference backend. An important part of the functional form constraint implementation is the `prod` function in the [`BayesBase`](https://reactivebayes.github.io/BayesBase.jl/stable/) package. We present a relatively simple use case, which may not be very practical but serves as a straightforward step-by-step guide.

Assume that we want a specific posterior marginal of some random variable in our model to have a specific Gaussian parameterization, such as mean-precision. Here, how we can achieve this with our custom `MeanPrecisionFormConstraint` functional form constraint:

```@example custom-functional-form-example
using ReactiveMP, BayesBase
using ReactiveMP, ExponentialFamily, Distributions, BayesBase

# First we define our functional form structure with no fields
# First, we define our functional form structure with no fields
struct MeanPrecisionFormConstraint <: AbstractFormConstraint end
```

Next we define the behaviour of our functional form constraint:

```@example custom-functional-form-example
ReactiveMP.default_form_check_strategy(::MeanPrecisionFormConstraint) = FormConstraintCheckLast()
ReactiveMP.default_prod_constraint(::MeanPrecisionFormConstraint) = GenericProd()
ReactiveMP.default_form_check_strategy(::MeanPrecisionFormConstraint) = FormConstraintCheckLast()
ReactiveMP.default_prod_constraint(::MeanPrecisionFormConstraint) = GenericProd()

function ReactiveMP.constrain_form(::MeanPrecisionFormConstraint, distribution)
# This is quite a naive assumption, that a given `dsitribution` object has `mean` and `precision` defined
# However this quantities might be approximated with some other external method, e.g. Laplace approximation
# This assumes that the given `distribution` object has `mean` and `precision` defined.
# These quantities might be approximated using other methods, such as Laplace approximation.
m = mean(distribution) # or approximate with some other method
p = precision(distribution) # or approximate with some other method
return NormalMeanPrecision(m, p)
end

function ReactiveMP.constrain_form(::MeanPrecisionFormConstraint, distribution::BayesBase.ProductOf)
# ProductOf is the special case, read about this type more in the corresponding documentation section
# of the `BayesBase` package
# `ProductOf` is a special case. Read more about this type in the corresponding
# documentation section of the `BayesBase` package.
# ...
end

constraint = ReactiveMP.preprocess_form_constraints(MeanPrecisionFormConstraint())

constrain_form(constraint, NormalMeanVariance(0, 2))
```

## Wrapped Form Constraints

Some constraint objects might not be subtypes of `AbstractFormConstraint`. This can occur, for instance, if the object is defined in a different package or needs to subtype a different abstract type. In such cases, `ReactiveMP` expects users to pass a `WrappedFormConstraint` object, which wraps the original object and makes it compatible with the `ReactiveMP` inference backend. Note that the [`ReactiveMP.preprocess_form_constraints`](@ref) function automatically wraps all objects that are not subtypes of `AbstractFormConstraint`.

Additionally, objects wrapped by `WrappedFormConstraints` may implement the `ReactiveMP.prepare_context` function. This function's output will be stored in the `WrappedFormConstraints` along with the original object. If `prepare_context` is implemented, the `constrain_form` function will take three arguments: the original constraint, the context, and the object that needs to be constrained.

```@docs
ReactiveMP.WrappedFormConstraint
ReactiveMP.prepare_context
ReactiveMP.constrain_form(::ReactiveMP.WrappedFormConstraint, something)
```

```@example wrapped-form-constraint-example
using ReactiveMP, Distributions, BayesBase, Random

# First, we define our custom form constraint that creates a set of samples
# Note that this is not a subtype of `AbstractFormConstraint`
struct MyCustomSampleListFormConstraint end

# Note that we still need to implement `default_form_check_strategy` and `default_prod_constraint` functions
# which are necessary for the `ReactiveMP` inference backend
ReactiveMP.default_form_check_strategy(::MyCustomSampleListFormConstraint) = FormConstraintCheckLast()
ReactiveMP.default_prod_constraint(::MyCustomSampleListFormConstraint) = GenericProd()

# We implement the `prepare_context` function, which returns a random number generator
function ReactiveMP.prepare_context(constraint::MyCustomSampleListFormConstraint)
return Random.default_rng()
end

# We implement the `constrain_form` function, which returns a set of samples
function ReactiveMP.constrain_form(constraint::MyCustomSampleListFormConstraint, context, distribution)
return rand(context, distribution, 10)
end

constraint = ReactiveMP.preprocess_form_constraints(MyCustomSampleListFormConstraint())

constrain_form(constraint, Normal(0, 10))
```


53 changes: 53 additions & 0 deletions src/constraints/form.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,63 @@ default_form_check_strategy(::UnspecifiedFormConstraint) = FormConstraintCheckLa

default_prod_constraint(::UnspecifiedFormConstraint) = GenericProd()

"""
constrain_form(constraint, something)

This function applies a given form constraint to a given object.
"""
function constrain_form end

constrain_form(::UnspecifiedFormConstraint, something) = something
constrain_form(::UnspecifiedFormConstraint, something::Union{ProductOf, LinearizedProductOf}) =
error("`ProductOf` object cannot be used as a functional form in inference backend. Use form constraints to restrict the functional form of marginal posteriors.")

"""
WrappedFormConstraint(constraint, context)

This is a wrapper for a form constraint object. It allows to pass additional context to the `constrain_form` function.
By default all objects that are not sub-typed from `AbstractFormConstraint` are wrapped into this object.
Use `ReactiveMP.prepare_context` to provide an extra context for a given form constraint, that can be reused between multiple `constrain_form` calls.
"""
struct WrappedFormConstraint{C, X} <: AbstractFormConstraint
constraint::C
context::X
end

struct WrappedFormConstraintNoContext end

"""
prepare_context(constraint)

This function prepares a context for a given form constraint. Returns `WrappedFormConstraintNoContext` if no context is needed (the default behaviour).
"""
prepare_context(constraint) = WrappedFormConstraintNoContext()

"""
constrain_form(wrapped::WrappedFormConstraint, something)

This function unwraps the `wrapped` object and calls `constrain_form` function with the provided context.
If the context is not provided, simply calls `constrain_form` with the wrapped constraint. Otherwise passes the context to the `constrain_form` function as the second argument.
"""
constrain_form(wrapped::WrappedFormConstraint, something) = constrain_form(wrapped, wrapped.context, something)
constrain_form(wrapped::WrappedFormConstraint, ::WrappedFormConstraintNoContext, something) = constrain_form(wrapped.constraint, something)
constrain_form(wrapped::WrappedFormConstraint, context, something) = constrain_form(wrapped.constraint, context, something)

default_form_check_strategy(wrapped::WrappedFormConstraint) = default_form_check_strategy(wrapped.constraint)
default_prod_constraint(wrapped::WrappedFormConstraint) = default_prod_constraint(wrapped.constraint)

"""
preprocess_form_constraints(constraints)

This function preprocesses form constraints and converts the provided objects into a form compatible with ReactiveMP inference backend (if possible).
If a tuple of constraints is passed, it creates a `CompositeFormConstraint` object. Wraps unknown form constraints into a `WrappedFormConstraint` object.
"""
function preprocess_form_constraints end

preprocess_form_constraints(constraints::Tuple) = CompositeFormConstraint(map(preprocess_form_constraints, constraints))
preprocess_form_constraints(constraint::AbstractFormConstraint) = constraint
preprocess_form_constraints(constraint) = WrappedFormConstraint(constraint, prepare_context(constraint))

"""
CompositeFormConstraint

Expand Down
139 changes: 139 additions & 0 deletions test/constraints/form_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
@testitem "`UnspecifiedFormConstraint` should not error on `Distribution` objects" begin
using Distributions
import ReactiveMP: constrain_form

@test constrain_form(UnspecifiedFormConstraint(), Beta(1, 1)) == Beta(1, 1)
@test constrain_form(UnspecifiedFormConstraint(), Normal(0, 1)) == Normal(0, 1)
@test constrain_form(UnspecifiedFormConstraint(), MvNormal([0.0, 0.0])) == MvNormal([0.0, 0.0])
end

@testitem "`UnspecifiedFormConstraint` should error on `ProductOf` and `LinearizedProductOf` objects" begin
using Distributions, BayesBase
import ReactiveMP: constrain_form

@test_throws "object cannot be used as a functional form in inference backend" constrain_form(UnspecifiedFormConstraint(), ProductOf(Beta(1, 1), Normal(0, 1)))
@test_throws "object cannot be used as a functional form in inference backend" constrain_form(UnspecifiedFormConstraint(), LinearizedProductOf([Beta(1, 1), Beta(1, 1)], 2))
end

@testitem "`CompositeFormConstraint` should call the constraints in the specified order" begin
import ReactiveMP: constrain_form

struct FormConstraint1 end
struct FormConstraint2 end

constrain_form(::FormConstraint1, x) = x + 1
constrain_form(::FormConstraint2, x) = x * 2

composite = CompositeFormConstraint((FormConstraint1(), FormConstraint2()))
@test constrain_form(composite, 1) == 4

composite = CompositeFormConstraint((FormConstraint2(), FormConstraint1()))
@test constrain_form(composite, 1) == 3
end

@testitem "`preprocess_form_constraints` should create `CompositeFormConstraint` from a tuple of constraints" begin
import ReactiveMP: preprocess_form_constraints, AbstractFormConstraint

struct FormConstraint1 <: AbstractFormConstraint end
struct FormConstraint2 <: AbstractFormConstraint end

constraints = (FormConstraint1(), FormConstraint2())
@test preprocess_form_constraints(constraints) == CompositeFormConstraint(constraints)
@test preprocess_form_constraints(FormConstraint1()) == FormConstraint1()
@test preprocess_form_constraints(FormConstraint2()) == FormConstraint2()
end

@testitem "`preprocess_form_constraints` should wrap unknown form constraints into a `WrappedFormConstraint`" begin
import ReactiveMP: preprocess_form_constraints, AbstractFormConstraint, WrappedFormConstraint, WrappedFormConstraintNoContext

struct FormConstraint1 <: AbstractFormConstraint end
struct FormConstraint2 end
struct FormConstraint3WithContext end
struct FormConstraint3Context end

ReactiveMP.prepare_context(::FormConstraint3WithContext) = FormConstraint3Context()

@test preprocess_form_constraints(FormConstraint1()) == FormConstraint1()
@test preprocess_form_constraints(FormConstraint2()) == WrappedFormConstraint(FormConstraint2(), WrappedFormConstraintNoContext())
@test preprocess_form_constraints(FormConstraint3WithContext()) == WrappedFormConstraint(FormConstraint3WithContext(), FormConstraint3Context())
@test preprocess_form_constraints((FormConstraint1(), FormConstraint2())) ==
CompositeFormConstraint((FormConstraint1(), WrappedFormConstraint(FormConstraint2(), WrappedFormConstraintNoContext())))
@test preprocess_form_constraints((FormConstraint1(), FormConstraint3WithContext())) ==
CompositeFormConstraint((FormConstraint1(), WrappedFormConstraint(FormConstraint3WithContext(), FormConstraint3Context())))
@test preprocess_form_constraints((FormConstraint2(), FormConstraint3WithContext())) == CompositeFormConstraint((
WrappedFormConstraint(FormConstraint2(), WrappedFormConstraintNoContext()), WrappedFormConstraint(FormConstraint3WithContext(), FormConstraint3Context())
))
@test preprocess_form_constraints((FormConstraint2(), FormConstraint3WithContext(), FormConstraint1())) == CompositeFormConstraint((
WrappedFormConstraint(FormConstraint2(), WrappedFormConstraintNoContext()), WrappedFormConstraint(FormConstraint3WithContext(), FormConstraint3Context()), FormConstraint1()
))

@test preprocess_form_constraints(preprocess_form_constraints(FormConstraint2())) == WrappedFormConstraint(FormConstraint2(), WrappedFormConstraintNoContext())
@test preprocess_form_constraints(preprocess_form_constraints(FormConstraint3WithContext())) == WrappedFormConstraint(FormConstraint3WithContext(), FormConstraint3Context())
end

@testitem "`WrappedFormConstraint` should simply redirect all the important functions to the underlying object" begin
import ReactiveMP: constrain_form, preprocess_form_constraints

struct FormConstraint end

ReactiveMP.default_form_check_strategy(::FormConstraint) = "hello"
ReactiveMP.default_prod_constraint(::FormConstraint) = "world"

constrain_form(::FormConstraint, x) = x + 1

constraint = preprocess_form_constraints(FormConstraint())

@test default_form_check_strategy(constraint) == "hello"
@test default_prod_constraint(constraint) == "world"
end

@testitem "`WrappedFormConstraint` should not pass empty context to the `constrain_form` call" begin
import ReactiveMP: constrain_form, preprocess_form_constraints

struct FormConstraintWithoutContext end

function constrain_form(::FormConstraintWithoutContext, x)
return x + 1
end

function constrain_form(::FormConstraintWithoutContext, context, x)
error("This function should not be called")
end

constraint = preprocess_form_constraints(FormConstraintWithoutContext())

@test constrain_form(constraint, 1) == 2
@test constrain_form(constraint, 2) == 3
@test constrain_form(constraint, 7) == 8
end

@testitem "`WrappedFormConstraint` should be able to reuse the context between multiple `constrain_form` calls" begin
import ReactiveMP: constrain_form, preprocess_form_constraints

struct FormConstraintWithContext end
mutable struct FormConstraintContext
value::Int
end

ReactiveMP.prepare_context(::FormConstraintWithContext) = FormConstraintContext(0)

function constrain_form(::FormConstraintWithContext, x)
error("This function should not be called")
end

function constrain_form(::FormConstraintWithContext, context::FormConstraintContext, x)
context.value += 1
return x + context.value
end

constraint = preprocess_form_constraints(FormConstraintWithContext())

@test constrain_form(constraint, 1) == 2
@test constraint.context.value === 1

@test constrain_form(constraint, 2) == 4
@test constraint.context.value === 2

@test constrain_form(constraint, 6) == 9
@test constraint.context.value === 3
end
Loading