
Add non-reversible parallel tempering #740

Open
pawel-czyz opened this issue Sep 22, 2024 · 6 comments

@pawel-czyz

pawel-czyz commented Sep 22, 2024

Presentation of the new sampler

Parallel tempering, also known as replica exchange MCMC, maintains $K$ Markov chains at different temperatures, ranging from the reference distribution $\pi_0$ (for example, the prior) to the target distribution $\pi =: \pi_{K-1}$.
In addition to local exploration kernels, which target each distribution individually, it uses swap kernels that attempt to exchange states between chains, so that the overall sampler targets the product distribution
$$(x_0, \dotsc, x_{K-1})\mapsto \prod_{i=0}^{K-1} \pi_i(x_i),$$
defined on the product space $\mathcal X^K = \mathcal X\times \cdots \times \mathcal X$. By retaining the samples from only the last coordinate, one can sample from $\pi = \pi_{K-1}$. (A sketch of a single swap move is given below.)
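For example, a swap between chains $i$ and $j$ is a Metropolis move on this product space. A minimal JAX sketch (the helper names here are illustrative, not part of BlackJAX):

import jax
import jax.numpy as jnp


def swap_log_accept_ratio(log_pi_i, log_pi_j, x_i, x_j):
    # Metropolis log-ratio for exchanging the states of chains i and j:
    # log [ pi_i(x_j) pi_j(x_i) / (pi_i(x_i) pi_j(x_j)) ].
    return (log_pi_i(x_j) + log_pi_j(x_i)) - (log_pi_i(x_i) + log_pi_j(x_j))


def attempt_swap(rng_key, log_pi_i, log_pi_j, x_i, x_j):
    # Accept or reject the exchange of states x_i and x_j.
    log_alpha = swap_log_accept_ratio(log_pi_i, log_pi_j, x_i, x_j)
    accept = jnp.log(jax.random.uniform(rng_key)) < log_alpha
    new_x_i = jnp.where(accept, x_j, x_i)
    new_x_j = jnp.where(accept, x_i, x_j)
    return new_x_i, new_x_j, accept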

Similarly to sequential Monte Carlo (SMC) samplers, this strategy can be highly efficient for sampling from multimodal posteriors. A modern variant of parallel tempering, called non-reversible parallel tempering (NRPT), achieves state-of-the-art performance in sampling from complex high-dimensional distributions. NRPT works on both discrete and continuous spaces and allows one to leverage preliminary runs to tune the tempering schedule.
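The non-reversible variant replaces randomized swap proposals with a deterministic even–odd (DEO) scheme: even-numbered rounds attempt the pairs $(0, 1), (2, 3), \dotsc$ and odd-numbered rounds the pairs $(1, 2), (3, 4), \dotsc$. A minimal, non-jitted sketch reusing the attempt_swap helper above:

def deo_swap_round(rng_key, states, log_pis, round_index):
    # states: list of K chain states; log_pis: list of K tempered log-densities.
    # Even rounds attempt the pairs (0, 1), (2, 3), ...; odd rounds (1, 2), (3, 4), ...
    states = list(states)
    keys = jax.random.split(rng_key, len(states))
    for i in range(round_index % 2, len(states) - 1, 2):
        states[i], states[i + 1], _ = attempt_swap(
            keys[i], log_pis[i], log_pis[i + 1], states[i], states[i + 1]
        )
    return states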

Resources

How does it compare to other algorithms in blackjax?

Compared with a single-chain MCMC sampler:

  • It requires maintaining $K$ chains. Hence, it can generally require more evaluations for the same number of samples collected.
  • However, it can lead to much faster mixing (especially for complex geometries or multimodal posteriors), resulting in a higher effective sample size.
  • Computations for different chains can be parallelized.
  • Similarly to single-chain MCMC methods, warm-up samples have to be discarded.

Compared with a tempered SMC sampler:

  • It maintains $K$ chains (each at its own temperature) over $T$ timesteps, whereas SMC samplers keep all particles at the same temperature and use a resampling step to move them to the next temperature.
  • Similarly to SMC samplers, it works on arbitrary spaces (continuous or discrete), provided that efficient local exploration kernels are available for the tempered distributions $\pi_i$ (one common choice of tempered path is given below).
  • Similarly to some variants of tempered SMC samplers, non-reversible parallel tempering allows one to tune the tempering schedule based on previous runs.
  • NRPT requires discarding the warm-up samples.
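For reference, both families of samplers typically interpolate between the reference and the target along the log-linear path (this is also the convention used in the prototype further below):

$$\pi_{\beta}(x) \propto \pi_0(x)^{1-\beta}\, \pi(x)^{\beta}, \qquad 0 = \beta_0 < \beta_1 < \dotsb < \beta_{K-1} = 1,$$

so that $\log \pi_{\beta_i}(x) = (1-\beta_i)\log \pi_0(x) + \beta_i \log \pi(x)$ up to an additive constant.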

Where does it fit in blackjax

BlackJAX offers a large collection of MCMC kernels. These can be leveraged to build non-reversible parallel tempering samplers that explore different temperatures simultaneously, which leads to faster mixing and allows one to sample from multimodal posteriors.

Are you willing to open a PR?

Yes. I have a prototype implemented in a blog post, which I would be willing to refactor and contribute.

I am however unsure about two design choices:

  1. Non-reversible parallel tempering requires applying local kernels $K_i$ to the individual coordinate chains $x_i \in \mathcal X$ (corresponding to the tempered distributions $\pi_i$). What would be the best practice for storing and applying different kernels? For example, one may want to use HMC with a large step size for sampling from $\pi_0$ (or even sample from the prior directly, e.g., when using a probabilistic programming language) and HMC with a small step size for sampling from $\pi=\pi_{K-1}$. In my prototype I use kernels from the same family, so that I can rely on jax.vmap rather than a Python for loop. I guess this problem is less apparent in tempered SMC samplers, where at temperature $T$ all particles are moved by the same kernel $K_T$ (even though the kernels $K_T$ may differ between temperatures).
  2. How should the computation be parallelized? Pigeons.jl uses MPI communication between different machines to study very high-dimensional problems. Should a BlackJAX version use some form of sharding, or could I keep it simple and rely on built-in parallelism?
@junpenglao
Copy link
Member

Thank you for the detailed write-up, much appreciated. And yes, a contribution will be very welcome!

Regarding the design choice, jax.vmap would be the answer to both of your questions.

I have not read your blog post in detail, but I am wondering whether you have compared your implementation with the one in TFP, which I am a bit more familiar with.

@junpenglao
Member

BTW, I am a huge fan of parallel tempering - very excited about this! Looking forward to your PR!

@AdrienCorenflos
Contributor

Same here, I had planned to do it at some point but I've not been able to commit the time to it :D Very happy someone is doing it!

Design-choice wise, I actually do not think it would be a good idea to vmap everything at the lower level, in particular with a view to supporting proper sharding.
IMO there are two components to the method: 1) the swap kernel, 2) reversible vs non-reversible application (this can probably be vmapped, as it's only a chain on the indices conditionally on the log-likelihood values).

Once that's done, the choice of parallelism for the state chains is very much user-driven, and it's hard to enforce a coherent interface supporting all JAX models.

@junpenglao
Member

IIUC, the swap kernel is basically swapping parameters (e.g., step size), which means you update the input parameters with some advanced indexing. The base kernel would remain the same, like step = jax.vmap(kernel.step).
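Something like this rough, untested sketch (here permutation would come from the swap step, and log_target / log_reference are the user's densities):

import jax
import blackjax

base_step = blackjax.mcmc.mala.build_kernel()


def swapped_parameter_step(rng_key, states, betas, step_sizes, permutation,
                           log_target, log_reference):
    # Exchange the per-chain temperatures and step sizes with advanced indexing;
    # the base kernel stays the same and is vmapped over the chains.
    betas, step_sizes = betas[permutation], step_sizes[permutation]

    def one_step(key, state, beta, step_size):
        logdensity_fn = lambda x: beta * log_target(x) + (1 - beta) * log_reference(x)
        return base_step(key, state, logdensity_fn, step_size)

    keys = jax.random.split(rng_key, betas.shape[0])
    return jax.vmap(one_step)(keys, states, betas, step_sizes)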

@pawel-czyz
Author

pawel-czyz commented Sep 23, 2024

Thank you for your kind feedback and all the suggestions!

I have not read your blog post in detail, but I am wondering whether you have compared your implementation with the one in TFP, which I am a bit more familiar with.

Thanks, I was not aware that a TFP implementation exists! I like it; the major differences seem to be:

  1. TFP allows different step sizes by using batching. (If I understand the BlackJAX philosophy correctly, jax.vmap is preferred over batching, isn't it? I.e., it's not possible to have log_p return a batch of log-PDFs compatible with a "batched" kernel?)
  2. TFP also allows other kinds of parallel tempering, with different swapping schemes. I think Adrien's suggestion will be useful here, essentially making this a variable argument resulting in different kernels.
  3. I don't think TFP records rejection statistics and uses them to tune the tempering schedule. (A sketch of the tuning rule is given below.)
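Regarding point 3, the schedule-tuning rule from the NRPT paper is easy to sketch: estimate the cumulative rejection curve from the observed per-pair rejection rates, then invert it so that all adjacent pairs reject at (approximately) the same rate. Roughly:

import jax.numpy as jnp


def update_schedule(betas, rejection_rates):
    # betas: current inverse temperatures, shape (K,), increasing from 0 to 1.
    # rejection_rates: observed swap rejection rate of each adjacent pair, shape (K - 1,).
    # Piecewise-linear estimate of the cumulative rejection curve:
    barrier = jnp.concatenate([jnp.zeros(1), jnp.cumsum(rejection_rates)])
    # New schedule: place the betas at equal increments of the cumulative rejection.
    targets = jnp.linspace(0.0, barrier[-1], betas.shape[0])
    return jnp.interp(targets, barrier, betas)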

Once that's done, the choice of parallelism for the state chains is very much user-driven, and it's hard to enforce a coherent interface supporting all JAX models.

I think this is a very good point. It'd be convenient to have a utility function allowing the end user to quickly build a reasonable (even if not optimally sharded) parallel tempering kernel out of an existing one using jax.vmap. I've been thinking about something along these lines:

import jax
import jax.random as jrandom
import jax.numpy as jnp
import blackjax


def tempered_logdensity_fn(log_target, log_reference, inverse_temperature):
    # Log-density on the log-linear path between the reference and the target.
    def log_p(x):
        return inverse_temperature * log_target(x) + (1.0 - inverse_temperature) * log_reference(x)

    return log_p


def init(
    init_fn,
    positions,
    log_target,
    log_reference,
    inverse_temperatures,
):
    # Initialize one chain state per inverse temperature.
    def init_fn_temp(position, inverse_temperature):
        return init_fn(position, tempered_logdensity_fn(log_target, log_reference, inverse_temperature))

    return jax.vmap(init_fn_temp)(positions, inverse_temperatures)


def build_kernel(
    base_kernel_fn,
    log_target,
    log_reference,
    inverse_temperatures,
    parameters,
):
    # Wrap the base kernel so that each chain targets its own tempered distribution.
    def kernel(rng_key, state, inverse_temperature, parameter):
        return base_kernel_fn(
            rng_key,
            state,
            tempered_logdensity_fn(log_target, log_reference, inverse_temperature),
            parameter,
        )

    n_chains = inverse_temperatures.shape[0]

    def step_fn(rng_key, state):
        # Apply the local kernels to all chains in parallel.
        keys = jrandom.split(rng_key, n_chains)
        return jax.vmap(kernel)(keys, state, inverse_temperatures, parameters)

    return step_fn


def log_p(x):
    # Toy target: a Gaussian with a quartic perturbation.
    return -(jnp.sum(jnp.square(x)) + jnp.sum(jnp.power(x, 4)))


def log_ref(x):
    # Reference distribution: a zero-mean Gaussian (up to an additive constant).
    return -jnp.sum(jnp.square(x))


n_chains = 10
# Geometric schedule of inverse temperatures, from 1.0 down to 0.9.
inverse_temperatures = 0.9 ** jnp.linspace(0, 1, n_chains)
initial_positions = jnp.ones((n_chains, 2))
# Per-chain MALA step sizes.
parameters = jnp.linspace(0.1, 1, n_chains)


init_state = init(
    blackjax.mcmc.mala.init,
    initial_positions,
    log_p,
    log_ref,
    inverse_temperatures,
)

kernel = build_kernel(
    blackjax.mcmc.mala.build_kernel(),
    log_p,
    log_ref,
    inverse_temperatures,
    parameters,
)

rng_key = jrandom.PRNGKey(42)
new_state, info = kernel(rng_key, init_state)

Please let me know what you think! Also:

Question: how should the parameters be passed to the individual kernels? This solution works only if each kernel has a single parameter. That parameter could be dictionary-valued, though, allowing users to write wrappers around kernel initialisers. E.g., one could create a wrapper around the HMC kernel builder function, which passes the step size, mass matrix, and number of steps as a dictionary, but I'm not sure how convenient that is for the end users. (A sketch of the dictionary option is given below.)
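To illustrate the dictionary option: jax.vmap maps over pytrees, so a dictionary works out of the box as long as every leaf has a leading chain axis. A toy stand-in kernel for illustration:

import jax
import jax.numpy as jnp

n_chains = 10
# Dictionary-valued per-chain parameters: every leaf carries a leading
# chain axis, so jax.vmap maps over the whole dictionary at once.
parameters = {
    "step_size": jnp.linspace(0.1, 1.0, n_chains),
    "inverse_mass_matrix": jnp.ones((n_chains, 2)),
}


def toy_kernel(rng_key, x, parameter):
    # Stand-in for a real kernel step; parameter is a single dictionary slice,
    # e.g. {"step_size": 0.1, "inverse_mass_matrix": jnp.ones(2)}.
    noise = jax.random.normal(rng_key, x.shape)
    return x + parameter["step_size"] * parameter["inverse_mass_matrix"] * noise


keys = jax.random.split(jax.random.PRNGKey(0), n_chains)
positions = jnp.zeros((n_chains, 2))
new_positions = jax.vmap(toy_kernel)(keys, positions, parameters)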

IIUC, the swap kernel is basically swapping parameters (e.g., step size), which means you update the input parameters with some advanced indexing. The base kernel would remain the same, like step = jax.vmap(kernel.step).

I see! This can potentially lead to better parallelism, but I think it'd be easier for me to swap the states $x_i$, rather than the parameters and temperatures. One reason is that I then do not have to build the explicit index process, and I can easily record the rejection rates, which allows one to tune the tempering schedule – would such a solution still be fine? 🙂

Question: I'm also not sure about the best design choice regarding the composed kernels. Currently each kernel $K$ records some information (e.g., acceptance rates, divergences, ...). In this case, we have a kernel $K_\text{ind}$ (applying the kernels $K_i$ independently to the individual components $x_i$) and a kernel $K_\text{swap}$, applied to the joint state $(x_0, \dotsc, x_{K-1})$. Should I define the information object as a named tuple constructed out of the information of kernel $K_\text{ind}$, which builds upon the $K_i$, and the auxiliary information from $K_\text{swap}$? In that case, how should one handle the auxiliary information coming from applying $K_\text{ind}$, e.g., 3 times for every swap attempt? (In other words, the joint kernel can be $K_\text{ind}^3 K_\text{swap}$, resulting in information about the independent moves that is 3 times longer...) One possible shape is sketched below.
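Purely as an illustration (none of these types exist in BlackJAX), the combined information object could look like this; jax.lax.scan over the local steps already returns the local infos stacked along a leading axis, so the asymmetry between the two kernels is just that extra axis:

from typing import Any, NamedTuple

import jax.numpy as jnp


class PTInfo(NamedTuple):
    # Info pytree from the independent local moves; with 3 local steps per
    # swap attempt, jax.lax.scan stacks it along a leading axis of length 3.
    local_info: Any
    # Statistics from the swap kernel, one entry per adjacent pair of chains.
    swap_accepted: jnp.ndarray  # shape (K - 1,), boolean
    swap_log_ratios: jnp.ndarray  # shape (K - 1,)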

More generally,
Question: Are there existing utilities for combining kernels? I know that the Metropolis-within-Gibbs tutorial constructs a kernel explicitly, but in theory one could imagine an operation composing the kernels that correspond to the updates of different variables. Such utilities could be useful not only in the context of parallel tempering or Metropolis-within-Gibbs, but also for building non-reversible kernels employing some information about the given target. For example, when sampling phylogenetic trees one has several kernels (e.g., kernels changing the tree topology and kernels permuting the taxa between different nodes), and they are combined either by composition $K = K_1 K_2 K_3$ or by a mixture $K = \frac{1}{3}(K_1 + K_2 + K_3)$. For concreteness, a sketch of the two operations is below.
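A sketch, assuming each kernel is a step function with signature (rng_key, state) -> (state, info):

import jax


def compose(*step_fns):
    # Deterministic composition K = K_1 K_2 ... K_n: apply the steps in turn.
    def step(rng_key, state):
        keys = jax.random.split(rng_key, len(step_fns))
        infos = []
        for key, step_fn in zip(keys, step_fns):
            state, info = step_fn(key, state)
            infos.append(info)
        return state, infos

    return step


def mixture(*step_fns):
    # Mixture K = (K_1 + ... + K_n) / n: pick one component uniformly at random.
    # (Under jit, all components must return the same info structure.)
    def step(rng_key, state):
        key_choice, key_step = jax.random.split(rng_key)
        index = jax.random.randint(key_choice, (), 0, len(step_fns))
        return jax.lax.switch(index, list(step_fns), key_step, state)

    return step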

@junpenglao
Member

I think for simplicity, let's start with building the functionality assuming we use the same base kernel (e.g., HMC) with different parameters (e.g., step_size).
