Draft pre-conditioning matrix in Barker proposal.
This is a first draft of adding pre-conditioning to the Barker
proposal. It follows Algorithms 4 and 5 in Appendix G of the original
Barker proposal paper. It is somewhat unclear from the paper, but the
separate step size that was already implemented serves as a global
scale for the normal distribution of the proposal. The function
`_compute_acceptance_probability` now takes the transposed square root
of the mass matrix and its inverse, and it has been flattened to
accommodate the corresponding matrix multiplications.
ismael-mendoza committed Sep 1, 2024
1 parent 8a9b546 commit 68ede50
Showing 1 changed file with 74 additions and 31 deletions.
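For context, a minimal usage sketch of the interface this draft is building toward, assuming the algorithm is exposed as `blackjax.barker_proposal` as elsewhere in the library (the target density and all values below are stand-ins):

    import jax
    import jax.numpy as jnp
    import blackjax

    # Stand-in target: a standard normal in 2D.
    logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)

    # A 1D inverse mass matrix is treated as a diagonal pre-conditioner;
    # a 2D array would take the dense (Cholesky) branch added below.
    barker = blackjax.barker_proposal(
        logdensity_fn, step_size=0.5, inverse_mass_matrix=jnp.ones(2)
    )

    state = barker.init(jnp.zeros(2))
    new_state, info = barker.step(jax.random.PRNGKey(0), state)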
105 changes: 74 additions & 31 deletions blackjax/mcmc/barker.py
@@ -16,9 +16,9 @@

import jax
import jax.numpy as jnp
import jax.scipy as jscipy
from jax.flatten_util import ravel_pytree
from jax.scipy import stats
from jax.tree_util import tree_leaves, tree_map

from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.proposal import static_binomial_sampling
@@ -81,44 +81,57 @@ def build_kernel():
"""

def _compute_acceptance_probability(
state: BarkerState,
proposal: BarkerState,
state: BarkerState, proposal: BarkerState, C_t: jax.Array, C_t_inv: jax.Array
) -> float:
"""Compute the acceptance probability of the Barker's proposal kernel."""

def ratio_proposal_nd(y, x, log_y, log_x):
num = -_log1pexp(-log_y * (x - y))
den = -_log1pexp(-log_x * (y - x))
x_flat, _ = ravel_pytree(state.position)
y_flat, _ = ravel_pytree(proposal.position)
log_x_flat, _ = ravel_pytree(state.logdensity_grad)
log_y_flat, _ = ravel_pytree(proposal.logdensity_grad)

return jnp.sum(num - den)
z = C_t_inv.dot(y_flat - x_flat)
c_x = log_x_flat.dot(C_t)
c_y = log_y_flat.dot(C_t)

num = _log1pexp(-z * c_x)
denom = _log1pexp(z * c_y)

ratio_proposal = jnp.sum(num - denom)

ratios_proposals = tree_map(
ratio_proposal_nd,
proposal.position,
state.position,
proposal.logdensity_grad,
state.logdensity_grad,
)
ratio_proposal = sum(tree_leaves(ratios_proposals))
return proposal.logdensity - state.logdensity + ratio_proposal

def kernel(
rng_key: PRNGKey, state: BarkerState, logdensity_fn: Callable, step_size: float
rng_key: PRNGKey,
state: BarkerState,
logdensity_fn: Callable,
step_size: float,
inverse_mass_matrix: jax.Array,
) -> tuple[BarkerState, BarkerInfo]:
"""Generate a new sample with the MALA kernel."""
"""Generate a new sample with the Barker kernel."""
grad_fn = jax.value_and_grad(logdensity_fn)

key_sample, key_rmh = jax.random.split(rng_key)

mass_matrix_sqrt, inv_mass_matrix_sqrt = _get_mass_matrix_sqrt(
inverse_mass_matrix
)

proposed_pos = _barker_sample(
key_sample, state.position, state.logdensity_grad, step_size
key_sample,
state.position,
state.logdensity_grad,
step_size,
mass_matrix_sqrt,
)

proposed_logdensity, proposed_logdensity_grad = grad_fn(proposed_pos)
proposed_state = BarkerState(
proposed_pos, proposed_logdensity, proposed_logdensity_grad
)

log_p_accept = _compute_acceptance_probability(state, proposed_state)
log_p_accept = _compute_acceptance_probability(
state, proposed_state, mass_matrix_sqrt, inv_mass_matrix_sqrt
)
accepted_state, info = static_binomial_sampling(
key_rmh, log_p_accept, state, proposed_state
)
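For reference, the quantity returned by `_compute_acceptance_probability` above is the log Metropolis-Hastings acceptance ratio for the pre-conditioned proposal. In the notation of the code, a sketch of the algebra it implements (following Appendix G of :cite:p:`Livingstone2022Barker`):

    \log \alpha(x, y) = \log \pi(y) - \log \pi(x)
      + \sum_i \left[ \log\left(1 + e^{-z_i (c_x)_i}\right)
                    - \log\left(1 + e^{z_i (c_y)_i}\right) \right],

    \quad z = C_t^{-1}(y - x), \quad
    c_x = C_t^{\top} \nabla \log \pi(x), \quad
    c_y = C_t^{\top} \nabla \log \pi(y)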
@@ -129,8 +142,7 @@ def kernel(


def as_top_level_api(
logdensity_fn: Callable,
step_size: float,
logdensity_fn: Callable, step_size: float, inverse_mass_matrix: jax.Array
) -> SamplingAlgorithm:
"""Implements the (basic) user interface for the Barker's proposal :cite:p:`Livingstone2022Barker` kernel with a
Gaussian base kernel.
@@ -175,6 +187,8 @@ def as_top_level_api(
The log-density function we wish to draw samples from.
step_size
The value to use for the step size in the proposal distribution.
inverse_mass_matrix
The inverse mass matrix to use for pre-conditioning (see Appendix G of :cite:p:`Livingstone2022Barker`).
Returns
-------
@@ -189,12 +203,12 @@ def init_fn(position: ArrayLikeTree, rng_key=None):
return init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
return kernel(rng_key, state, logdensity_fn, step_size)
return kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix)

return SamplingAlgorithm(init_fn, step_fn)


def _barker_sample_nd(key, mean, a, scale):
def _barker_sample_nd(key, mean, a, scale, C_t):
"""
Sample from a multivariate Barker's proposal distribution. In 1D, this has the following probability density function:
@@ -214,8 +228,10 @@ def _barker_sample_nd(key, mean, a, scale):
a
The parameter :math:`a` in the equation above, an Array. This is a skewness parameter.
scale
The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\\sigma` in the equation above.
The global scale, a scalar. This corresponds to :math:`\\sigma` in the equation above.
It encodes the step size of the proposal.
C_t
The transpose of the square root of the mass matrix, an Array. It is not used in the 1D version of Barker's proposal and thus does not appear in the equation above.
Returns
-------
@@ -225,17 +241,18 @@

key1, key2 = jax.random.split(key)
z = scale * jax.random.normal(key1, shape=mean.shape)
c = a.dot(C_t)

# Sample b = 1 with probability p and b = 0 with probability 1 - p, where
# p = 1 / (1 + exp(-c * z))
log_p = -_log1pexp(-a * z)
log_p = -_log1pexp(-c * z)
b = jax.random.bernoulli(key2, p=jnp.exp(log_p), shape=mean.shape)

# return mean + C_t @ z if b == 1 else mean - C_t @ z
return mean + b * z - (1 - b) * z
return mean + C_t.dot(b * z - (1 - b) * z)
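As a sanity check (a sketch, not part of the diff): with `C_t` set to the identity, `c` equals `a` and the update reduces to the original un-preconditioned proposal `mean + b * z - (1 - b) * z`:

    import jax
    import jax.numpy as jnp
    from blackjax.mcmc.barker import _barker_sample_nd  # private helper changed here

    key = jax.random.PRNGKey(0)
    mean = jnp.zeros(3)
    a = jnp.array([0.5, -1.0, 2.0])  # skewness parameter (the gradient)

    # Identity pre-conditioner: matches the previous behaviour exactly.
    sample = _barker_sample_nd(key, mean, a, 0.1, jnp.eye(3))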


def _barker_sample(key, mean, a, scale):
def _barker_sample(key, mean, a, scale, C_t):
r"""
Sample from a multivariate Barker's proposal distribution for PyTrees.
@@ -248,21 +265,47 @@
a
The parameter :math:`a` in the equation above, the same PyTree as `mean`. This is a skewness parameter.
scale
The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\sigma` in the equation above.
The global scale, a scalar. This corresponds to :math:`\sigma` in the equation above.
It encodes the step size of the proposal.
C_t
The transpose of the square root of the mass matrix, an Array.
"""

flat_mean, unravel_fn = ravel_pytree(mean)
flat_a, _ = ravel_pytree(a)
flat_sample = _barker_sample_nd(key, flat_mean, flat_a, scale)
flat_sample = _barker_sample_nd(key, flat_mean, flat_a, scale, C_t)
return unravel_fn(flat_sample)


def _log1pexp(a):
return jnp.log1p(jnp.exp(a))
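A side note, not part of this diff: `jnp.log1p(jnp.exp(a))` overflows once `exp(a)` exceeds the floating-point range (around a ≈ 88 in float32). If that ever becomes an issue, JAX's built-in softplus is a numerically stable drop-in:

    import jax

    # Computes log(1 + exp(a)) without overflow for large positive a.
    _log1pexp_stable = jax.nn.softplus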


def _get_mass_matrix_sqrt(inverse_mass_matrix):
# we want the transposed Cholesky factor C_t of the mass matrix (see Appendix G of the paper)

ndim = jnp.ndim(inverse_mass_matrix) # type: ignore[arg-type]
shape = jnp.shape(inverse_mass_matrix)[:1] # type: ignore[arg-type]
if ndim == 1: # diagonal
inv_mass_matrix_sqrt = jnp.sqrt(inverse_mass_matrix)
mass_matrix_sqrt = jnp.reciprocal(inv_mass_matrix_sqrt)
elif ndim == 2:
# The inverse mass matrix factors as L @ L.T (Cholesky). We want
# C_t = inv(L.T), which satisfies C_t @ C_t.T = mass matrix.
L = jscipy.linalg.cholesky(inverse_mass_matrix, lower=True)
identity = jnp.identity(shape[0])
mass_matrix_sqrt = jscipy.linalg.solve_triangular(
L, identity, lower=True, trans=True
)
inv_mass_matrix_sqrt = L.T
else:
raise ValueError(
"The mass matrix has the wrong number of dimensions:"
f" expected 1 or 2, got {ndim}."
)
return mass_matrix_sqrt, inv_mass_matrix_sqrt
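A quick self-check of the dense branch (a sketch; the matrix below is an arbitrary SPD example): if the inverse mass matrix factors as L @ L.T, the returned `mass_matrix_sqrt` is inv(L.T), so it should reconstruct the mass matrix and cancel against `inv_mass_matrix_sqrt`:

    import jax.numpy as jnp
    from blackjax.mcmc.barker import _get_mass_matrix_sqrt  # helper added here

    inv_M = jnp.array([[2.0, 0.5],
                       [0.5, 1.0]])  # dense SPD inverse mass matrix
    C_t, C_t_inv = _get_mass_matrix_sqrt(inv_M)

    # C_t @ C_t.T recovers the mass matrix, and C_t_inv inverts C_t.
    assert jnp.allclose(C_t @ C_t.T, jnp.linalg.inv(inv_M), atol=1e-5)
    assert jnp.allclose(C_t @ C_t_inv, jnp.eye(2), atol=1e-5)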


def _barker_logpdf(x, mean, a, scale):
logpdf = jnp.log(2) + stats.norm.logpdf(x, mean, scale) - _log1pexp(-a * (x - mean))
return logpdf
