Pathfinder #194

Merged: 24 commits, May 24, 2022.
4 changes: 2 additions & 2 deletions blackjax/adaptation/__init__.py
@@ -1,3 +1,3 @@
-from . import window_adaptation
+from . import pathfinder_adaptation, window_adaptation

-__all__ = ["window_adaptation"]
+__all__ = ["window_adaptation", "pathfinder_adaptation"]
126 changes: 126 additions & 0 deletions blackjax/adaptation/pathfinder_adaptation.py
@@ -0,0 +1,126 @@
"""Implementation of the Pathinder warmup for the HMC family of sampling algorithms."""
from typing import Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp

from blackjax.adaptation.step_size import (
DualAveragingAdaptationState,
dual_averaging_adaptation,
)
from blackjax.mcmc.hmc import HMCState
from blackjax.types import Array, PRNGKey, PyTree
from blackjax.vi.pathfinder import init as pathfinder_init_fn
from blackjax.vi.pathfinder import lbfgs_inverse_hessian_formula_1, sample_from_state

__all__ = ["base"]


class PathfinderAdaptationState(NamedTuple):
da_state: DualAveragingAdaptationState
inverse_mass_matrix: Array


def base(
kernel_factory: Callable,
logprob_fn: Callable,
target_acceptance_rate: float = 0.65,
):
"""Warmup scheme for sampling procedures based on euclidean manifold HMC.
This function tunes the values of the step size and the mass matrix according
to this schema:
* pathfinder algorithm is run and an estimation of the inverse mass matrix
is derived, as well as an initialization point for the markov chain
* Nesterov's dual averaging adaptation is then run to tune the step size

Parameters
----------
kernel_factory
A function which returns a transition kernel given a step size and a
mass matrix.
logprob_fn
        The log probability density function from which we wish to sample.
    target_acceptance_rate
The target acceptance rate for the step size adaptation.

Returns
-------
init
Function that initializes the warmup.
update
Function that moves the warmup one step.
final
Function that returns the step size and mass matrix given a warmup state.

"""
da_init, da_update, da_final = dual_averaging_adaptation(
target=target_acceptance_rate
)

def init(
rng_key: PRNGKey, initial_position: Array, initial_step_size: float
) -> Tuple[PathfinderAdaptationState, PyTree]:
"""Initialize the warmup.

        To initialize the warmup we run Pathfinder to estimate the inverse mass
        matrix, and we then set up the dual averaging adaptation algorithm.
"""
da_state = da_init(initial_step_size)

pathfinder_rng_key, sample_rng_key = jax.random.split(rng_key, 2)
pathfinder_state = pathfinder_init_fn(
pathfinder_rng_key, logprob_fn, initial_position
)
new_initial_position = sample_from_state(sample_rng_key, pathfinder_state)
inverse_mass_matrix = lbfgs_inverse_hessian_formula_1(
pathfinder_state.alpha, pathfinder_state.beta, pathfinder_state.gamma
)

warmup_state = PathfinderAdaptationState(da_state, inverse_mass_matrix)

return warmup_state, new_initial_position

def update(
rng_key: PRNGKey,
chain_state: HMCState,
adaptation_state: PathfinderAdaptationState,
) -> Tuple[HMCState, PathfinderAdaptationState, NamedTuple]:
"""Move the warmup by one step.

We first create a new kernel with the current values of the step size
and mass matrix and move the chain one step. Then, we update the dual
averaging adaptation algorithm.

Parameters
----------
rng_key
The key used in JAX's random number generator.
chain_state
Current state of the chain.
        adaptation_state
Current warmup state.

Returns
-------
The updated states of the chain and the warmup.

"""
step_size = jnp.exp(adaptation_state.da_state.log_step_size)
inverse_mass_matrix = adaptation_state.inverse_mass_matrix
kernel = kernel_factory(step_size, inverse_mass_matrix)

chain_state, chain_info = kernel(rng_key, chain_state)
new_da_state = da_update(adaptation_state.da_state, chain_info)
new_warmup_state = PathfinderAdaptationState(
new_da_state, adaptation_state.inverse_mass_matrix
)

return chain_state, new_warmup_state, chain_info

def final(warmup_state: PathfinderAdaptationState) -> Tuple[float, Array]:
"""Return the step size and mass matrix."""
step_size = jnp.exp(warmup_state.da_state.log_step_size_avg)
inverse_mass_matrix = warmup_state.inverse_mass_matrix
return step_size, inverse_mass_matrix

return init, update, final
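
For orientation, here is a minimal usage sketch of the `init`/`update`/`final` triple returned by `base`. It is not part of the PR: `kernel_factory`, `logprob_fn`, `initial_position` and `num_warmup_steps` are placeholder names, and the chain is initialized with `blackjax.mcmc.hmc.init` as in `kernels.py` below.

    import jax
    import blackjax

    init, update, final = base(kernel_factory, logprob_fn)

    rng_key = jax.random.PRNGKey(0)
    rng_key, init_key = jax.random.split(rng_key)
    # Pathfinder runs here: it estimates the inverse mass matrix and
    # proposes a starting position for the chain.
    warmup_state, position = init(init_key, initial_position, 1.0)
    chain_state = blackjax.mcmc.hmc.init(position, logprob_fn)

    for _ in range(num_warmup_steps):
        rng_key, step_key = jax.random.split(rng_key)
        # One chain step with the current step size, then a dual averaging update.
        chain_state, warmup_state, info = update(step_key, chain_state, warmup_state)

    step_size, inverse_mass_matrix = final(warmup_state)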
172 changes: 151 additions & 21 deletions blackjax/kernels.py
@@ -7,6 +7,7 @@
import blackjax.mcmc as mcmc
import blackjax.sgmcmc as sgmcmc
import blackjax.smc as smc
import blackjax.vi as vi
from blackjax.base import AdaptationAlgorithm, SamplingAlgorithm
from blackjax.progress_bar import progress_bar_scan
from blackjax.types import Array, PRNGKey, PyTree
@@ -21,6 +22,8 @@
"sgld",
"tempered_smc",
"window_adaptation",
"pathfinder",
"pathfinder_adaptation",
]


@@ -647,30 +650,21 @@ def step_fn(rng_key: PRNGKey, state):

class orbital_hmc:
"""Implements the (basic) user interface for the Periodic orbital MCMC kernel

Each iteration of the periodic orbital MCMC outputs ``period`` weighted samples from
a single Hamiltonian orbit connecting the previous sample and momentum (latent) variable
with precision matrix ``inverse_mass_matrix``, evaluated using the ``bijection`` as an
integrator with discretization parameter ``step_size``.

Examples
--------

A new Periodic orbital MCMC kernel can be initialized and used with the following code:

.. code::

per_orbit = blackjax.orbital_hmc(logprob_fn, step_size, inverse_mass_matrix, period)
state = per_orbit.init(position)
new_state, info = per_orbit.step(rng_key, state)

We can JIT-compile the step function for better performance

.. code::

step = jax.jit(per_orbit.step)
new_state, info = step(rng_key, state)

Parameters
----------
logprob_fn
@@ -685,11 +679,9 @@ class orbital_hmc:
The number of steps used to build the orbit.
bijection
(algorithm parameter) The symplectic integrator to use to build the orbit.

Returns
-------
A ``SamplingAlgorithm``.

"""

init = staticmethod(mcmc.periodic_orbital.init)
@@ -725,36 +717,26 @@ def step_fn(rng_key: PRNGKey, state):

class elliptical_slice:
"""Implements the (basic) user interface for the Elliptical Slice sampling kernel

Examples
--------

A new Elliptical Slice sampling kernel can be initialized and used with the following code:

.. code::

ellip_slice = blackjax.elliptical_slice(loglikelihood_fn, cov_matrix)
state = ellip_slice.init(position)
new_state, info = ellip_slice.step(rng_key, state)

We can JIT-compile the step function for better performance

.. code::

step = jax.jit(ellip_slice.step)
new_state, info = step(rng_key, state)

Parameters
----------
loglikelihood_fn
        Only the log likelihood function of the posterior distribution we wish to sample from.
    cov_matrix
        The covariance matrix of the Gaussian prior distribution of the posterior we wish to sample from.

Returns
-------
A ``SamplingAlgorithm``.

"""

init = staticmethod(mcmc.elliptical_slice.init)
@@ -781,3 +763,151 @@ def step_fn(rng_key: PRNGKey, state):
)

return SamplingAlgorithm(init_fn, step_fn)


# -----------------------------------------------------------------------------
# VARIATIONAL INFERENCE
# -----------------------------------------------------------------------------


class pathfinder:
"""Implements the (basic) user interface for the pathfinder kernel.

Pathfinder locates normal approximations to the target density along a
quasi-Newton optimization path, with local covariance estimated using
the inverse Hessian estimates produced by the L-BFGS optimizer.
Pathfinder returns draws from the approximation with the lowest estimated
Kullback-Leibler (KL) divergence to the true posterior.

    Note: all the heavy computation is performed in the ``init`` function; the
    ``step`` function merely draws a sample from the resulting normal
    approximation.
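
    Examples
    --------

    A Pathfinder approximation can be built and sampled from with code along
    the following lines (a sketch; ``logprob_fn`` stands for the log density
    of the target distribution):

    .. code::

        pathfinder = blackjax.pathfinder(rng_key, logprob_fn)
        state = pathfinder.init(position)
        new_state, info = pathfinder.step(sample_key, state)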


Returns
-------
A ``SamplingAlgorithm``.

"""

init = staticmethod(vi.pathfinder.init)
kernel = staticmethod(vi.pathfinder.kernel)

def __new__( # type: ignore[misc]
cls,
rng_key: PRNGKey,
logprob_fn: Callable,
num_samples: int = 200,
**lbfgs_kwargs,
) -> SamplingAlgorithm:

step = cls.kernel()

def init_fn(position: PyTree):
return cls.init(
rng_key, logprob_fn, position, num_samples, False, **lbfgs_kwargs
)

def step_fn(rng_key: PRNGKey, state):
return step(
rng_key,
state,
)

return SamplingAlgorithm(init_fn, step_fn)


def pathfinder_adaptation(
algorithm: Union[hmc, nuts],
logprob_fn: Callable,
num_steps: int = 400,
initial_step_size: float = 1.0,
target_acceptance_rate: float = 0.65,
**parameters,
) -> AdaptationAlgorithm:
"""Adapt the parameters of algorithms in the HMC family.

    Algorithms in the HMC family on a Euclidean manifold depend on the value of
    at least two parameters: the step size, related to the trajectory
    integrator, and the mass matrix, linked to the Euclidean metric.

    Good tuning is very important, especially for algorithms like NUTS which can
    be extremely inefficient with the wrong parameter values.
    This function tunes the values of these parameters according to this scheme:
    * the Pathfinder algorithm is run and an estimate of the inverse mass matrix
      is derived, as well as an initialization point for the Markov chain
    * Nesterov's dual averaging adaptation is then run to tune the step size

Parameters
----------
algorithm
The algorithm whose parameters are being tuned.
logprob_fn
        The log probability density function from which we wish to sample.
num_steps
The number of adaptation steps for the dual averaging adaptation scheme.
initial_step_size
The initial step size used in the algorithm.
target_acceptance_rate
The acceptance rate that we target during step size adaptation.
**parameters
The extra parameters to pass to the algorithm, e.g. the number of
integration steps for HMC.

Returns
-------
A function that returns the last chain state and a sampling kernel with the tuned parameter values from an initial state.

"""

kernel = algorithm.kernel()

def kernel_factory(step_size: float, inverse_mass_matrix: Array):
def kernel_fn(rng_key, state):
return kernel(
rng_key,
state,
logprob_fn,
step_size,
inverse_mass_matrix,
**parameters,
)

return kernel_fn

init, update, final = adaptation.pathfinder_adaptation.base(
kernel_factory,
logprob_fn,
target_acceptance_rate=target_acceptance_rate,
)

@jax.jit
def one_step(carry, rng_key):
state, adaptation_state = carry
state, adaptation_state, info = update(rng_key, state, adaptation_state)
return ((state, adaptation_state), (state, info, adaptation_state.da_state))

def run(rng_key: PRNGKey, position: PyTree):

rng_key_init, rng_key_chain = jax.random.split(rng_key, 2)

        init_warmup_state, init_position = init(
            rng_key_init, position, initial_step_size
        )
        init_state = algorithm.init(init_position, logprob_fn)

        keys = jax.random.split(rng_key_chain, num_steps)
last_state, warmup_chain = jax.lax.scan(
one_step,
(init_state, init_warmup_state),
keys,
)
last_chain_state, last_warmup_state = last_state
history_state, history_info, history_da = warmup_chain
history_adaptation = last_warmup_state._replace(da_state=history_da)

warmup_chain = (history_state, history_info, history_adaptation)

step_size, inverse_mass_matrix = final(last_warmup_state)
kernel = kernel_factory(step_size, inverse_mass_matrix)

return last_chain_state, kernel, warmup_chain

return AdaptationAlgorithm(run)
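
A brief usage sketch (not part of the PR): assuming ``AdaptationAlgorithm`` exposes the wrapped function as ``.run``, as it does for ``window_adaptation``, warmup and subsequent sampling would look like this, with ``logprob_fn`` and ``initial_position`` as placeholders:

    import jax
    import blackjax

    rng_key, warmup_key, sample_key = jax.random.split(jax.random.PRNGKey(0), 3)

    adapt = blackjax.pathfinder_adaptation(blackjax.nuts, logprob_fn)
    # Runs Pathfinder once, then num_steps dual averaging iterations.
    last_state, kernel, _ = adapt.run(warmup_key, initial_position)

    # The returned kernel has the tuned step size and inverse mass matrix baked in.
    new_state, info = kernel(sample_key, last_state)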
3 changes: 3 additions & 0 deletions blackjax/vi/__init__.py
@@ -0,0 +1,3 @@
from . import pathfinder

__all__ = ["pathfinder"]