From 467b07621975d81856792d0788f0e490c3ecaaa5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 12 Oct 2024 05:14:39 -0400 Subject: [PATCH] Add `getparameters` and `setparameters!!` (#86) * added state_from_transition, parameters and setparameters!! * Update src/AbstractMCMC.jl Co-authored-by: David Widmann * renamed state_from_transition to updatestate!! * adhere to julia convention * added docs * fixed docs * fixed docs * added example for why updatestate!! is useful * improved MixtureState example * further improvements to docs * renamed parameters and setparameters!! to values and setvalues!! * fixed typo in docs * fixed documenting values * improved and fixed some bugs in docs * fixed typo in docs * renamed values and setvalues!! to realize and realize!! * added model to updatestate!! * Apply suggestions from code review Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update docs/src/api.md Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> * Update docs/src/api.md Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> * version bump --------- Co-authored-by: David Widmann Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Xianda Sun --- Project.toml | 2 +- docs/src/api.md | 141 ++++++++++++++++++++++++++++++++++++++++++++ src/AbstractMCMC.jl | 21 +++++++ 3 files changed, 163 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f57b1ff0..c4ef9ae1 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.4.0" +version = "5.5.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/docs/src/api.md b/docs/src/api.md index f6dd1cf8..db4ce565 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -113,3 +113,144 @@ For chains of this type, AbstractMCMC defines the following two methods. AbstractMCMC.chainscat AbstractMCMC.chainsstack ``` + +## Interacting with states of samplers + +To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods: +```@docs +AbstractMCMC.getparams +AbstractMCMC.setparams!! +``` +These methods can also be useful for implementing samplers which wraps some inner samplers, e.g. a mixture of samplers. + +### Example: `MixtureSampler` + +In a `MixtureSampler` we need two things: +- `components`: collection of samplers. +- `weights`: collection of weights representing the probability of choosing the corresponding sampler. + +```julia +struct MixtureSampler{W,C} <: AbstractMCMC.AbstractSampler + components::C + weights::W +end +``` + +To implement the state, we need to keep track of a couple of things: +- `index`: the index of the sampler used in this `step`. +- `states`: the current states of _all_ the components. +We need to keep track of the states of _all_ components rather than just the state for the sampler we used previously. +The reason is that lots of samplers keep track of more than just the previous realizations of the variables, e.g. in `AdvancedHMC.jl` we keep track of the momentum used, the metric used, etc. + + +```julia +struct MixtureState{S} + index::Int + states::S +end +``` +The `step` for a `MixtureSampler` is defined by the following generative process +```math +\begin{aligned} +i &\sim \mathrm{Categorical}(w_1, \dots, w_k) \\ +X_t &\sim \mathcal{K}_i(\cdot \mid X_{t - 1}) +\end{aligned} +``` +where ``\mathcal{K}_i`` denotes the i-th kernel/sampler, and ``w_i`` denotes the weight/probability of choosing the i-th sampler. +[`AbstractMCMC.getparams`](@ref) and [`AbstractMCMC.setparams!!`](@ref) comes into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})`` since ``X_{t - 1}`` could be coming from a different sampler. + +If we let `state` be the current `MixtureState`, `i` the current component, and `i_prev` is the previous component we sampled from, then this translates into the following piece of code: + +```julia +# Update the corresponding state, i.e. `state.states[i]`, using +# the state and transition from the previous iteration. +state_current = AbstractMCMC.setparams!!( + state.states[i], + AbstractMCMC.getparams(state.states[i_prev]), +) + +# Take a `step` for this sampler using the updated state. +transition, state_current = AbstractMCMC.step( + rng, model, sampler_current, sampler_state; + kwargs... +) +``` + +The full [`AbstractMCMC.step`](@ref) implementation would then be something like: + +```julia +function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler, state; kwargs...) + # Sample the component to use in this `step`. + i = rand(Categorical(sampler.weights)) + sampler_current = sampler.components[i] + + # Update the corresponding state, i.e. `state.states[i]`, using + # the state and transition from the previous iteration. + i_prev = state.index + state_current = AbstractMCMC.setparams!!( + state.states[i], + AbstractMCMC.getparams(state.states[i_prev]), + ) + + # Take a `step` for this sampler using the updated state. + transition, state_current = AbstractMCMC.step( + rng, model, sampler_current, state_current; + kwargs... + ) + + # Create the new states. + # NOTE: Code below will result in `states_new` being a `Vector`. + # If we wanted to allow usage of alternative containers, e.g. `Tuple`, + # it would be better to use something like `@set states[i] = state_current` + # where `@set` is from Setfield.jl. + states_new = map(1:length(state.states)) do j + if j == i + # Replace the i-th state with the new one. + state_current + else + # Otherwise we just carry over the previous ones. + state.states[j] + end + end + + # Create the new `MixtureState`. + state_new = MixtureState(i, states_new) + + return transition, state_new +end +``` + +And for the initial [`AbstractMCMC.step`](@ref) we have: + +```julia +function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler; kwargs...) + # Initialize every state. + transitions_and_states = map(sampler.components) do spl + AbstractMCMC.step(rng, model, spl; kwargs...) + end + + # Sample the component to use this `step`. + i = rand(Categorical(sampler.weights)) + # Extract the corresponding transition. + transition = first(transitions_and_states[i]) + # Extract states. + states = map(last, transitions_and_states) + # Create new `MixtureState`. + state = MixtureState(i, states) + + return transition, state +end +``` + +Suppose we then wanted to use this with some of the packages which implements AbstractMCMC.jl's interface, e.g. [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl), then we'd simply have to implement `getparams` and `setparams!!`. + + +To use `MixtureSampler` with two samplers `sampler1` and `sampler2` from `AdvancedMH.jl` as components, we'd simply do + +```julia +sampler = MixtureSampler([sampler1, sampler2], [0.1, 0.9]) +transition, state = AbstractMCMC.step(rng, model, sampler) +while ... + transition, state = AbstractMCMC.step(rng, model, sampler, state) +end +``` diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index d7374dd2..8343bfa8 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -80,6 +80,27 @@ The `MCMCSerial` algorithm allows users to sample serially, with no thread or pr """ struct MCMCSerial <: AbstractMCMCEnsemble end +""" + getparams(state[; kwargs...]) + +Retrieve the values of parameters from the sampler's `state` as a `Vector{<:Real}`. +""" +function getparams end + +""" + setparams!!(state, params) + +Set the values of parameters in the sampler's `state` from a `Vector{<:Real}`. + +This function should follow the `BangBang` interface: mutate `state` in-place if possible and +return the mutated `state`. Otherwise, it should return a new `state` containing the updated parameters. + +Although not enforced, it should hold that `setparams!!(state, getparams(state)) == state`. In another +word, the sampler should implement a consistent transformation between its internal representation +and the vector representation of the parameter values. +""" +function setparams!! end + include("samplingstats.jl") include("logging.jl") include("interface.jl")