Add Contour SGMCMC sampler. #396

Merged: 4 commits, Jan 5, 2023
2 changes: 2 additions & 0 deletions blackjax/__init__.py
@@ -4,6 +4,7 @@
from .diagnostics import potential_scale_reduction as rhat
from .kernels import (
adaptive_tempered_smc,
csgld,
elliptical_slice,
ghmc,
hmc,
@@ -39,6 +40,7 @@
"meads",
"sgld", # stochastic gradient mcmc
"sghmc",
"csgld",
"window_adaptation", # mcmc adaptation
"pathfinder_adaptation",
"adaptive_tempered_smc", # smc
56 changes: 53 additions & 3 deletions blackjax/kernels.py
@@ -35,6 +35,7 @@
"rmh",
"sgld",
"sghmc",
"csgld",
"tempered_smc",
"window_adaptation",
"irmh",
@@ -533,13 +534,21 @@ class sgld:

def __new__( # type: ignore[misc]
cls,
grad_estimator: sgmcmc.gradients.GradientEstimator,
grad_estimator: Callable,
) -> Callable:

step = cls.kernel()

def step_fn(rng_key: PRNGKey, state, minibatch: PyTree, step_size: float):
return step(rng_key, state, grad_estimator, minibatch, step_size)
def step_fn(
rng_key: PRNGKey,
state,
minibatch: PyTree,
step_size: float,
temperature: float = 1,
):
return step(
rng_key, state, grad_estimator, minibatch, step_size, temperature
)

return step_fn
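
As an aside, a minimal sketch of calling the extended `sgld` step with the new optional `temperature` argument (the quadratic `grad_estimator` and the data below are hypothetical placeholders, not part of this diff):

    import jax
    import jax.numpy as jnp

    import blackjax

    # Hypothetical gradient estimator: gradient of a toy minibatch
    # log-density with respect to the position.
    def grad_estimator(position, minibatch):
        return jax.grad(lambda p: -0.5 * jnp.sum((minibatch - p) ** 2))(position)

    sgld_step = blackjax.sgld(grad_estimator)

    rng_key = jax.random.PRNGKey(0)
    position = jnp.zeros(2)
    minibatch = jax.random.normal(rng_key, (32, 2))

    # temperature defaults to 1; values below 1 shrink the injected noise.
    position = sgld_step(rng_key, position, minibatch, 1e-3, temperature=0.5)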

@@ -623,6 +632,47 @@ def step_fn(rng_key: PRNGKey, state, minibatch: PyTree, step_size: float):
return step_fn


class csgld:

init = staticmethod(sgmcmc.csgld.init)
kernel = staticmethod(sgmcmc.csgld.kernel)

def __new__( # type: ignore[misc]
cls,
logdensity_estimator_fn: Callable,
zeta: float = 1,
temperature: float = 0.01,
num_partitions: int = 512,
energy_gap: float = 100,
min_energy: float = 0,
) -> MCMCSamplingAlgorithm:

step = cls.kernel(num_partitions, energy_gap, min_energy)

def init_fn(position: PyTree):
return cls.init(position, num_partitions)

def step_fn(
rng_key: PRNGKey,
state,
minibatch: PyTree,
step_size_diff: float,
step_size_stoch: float,
):
return step(
rng_key,
state,
logdensity_estimator_fn,
minibatch,
step_size_diff,
step_size_stoch,
zeta,
temperature,
)

return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]


# -----------------------------------------------------------------------------
# ADAPTATION
# -----------------------------------------------------------------------------
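To give a feel for the new interface, a minimal sketch of driving the `csgld` algorithm defined above (the `logdensity_estimator` and data are hypothetical placeholders; in practice one would build the estimator with `blackjax.sgmcmc.gradients.logdensity_estimator`):

    import jax
    import jax.numpy as jnp

    import blackjax

    # Hypothetical estimator of the log-density over a minibatch.
    def logdensity_estimator(position, minibatch):
        return -0.5 * jnp.sum((minibatch - position) ** 2)

    csgld = blackjax.csgld(
        logdensity_estimator,
        zeta=1.0,
        temperature=0.01,
        num_partitions=512,
        energy_gap=100.0,
        min_energy=0.0,
    )

    rng_key = jax.random.PRNGKey(0)
    state = csgld.init(jnp.zeros(2))
    minibatch = jax.random.normal(rng_key, (32, 2))

    # step_size_diff drives the Langevin diffusion, step_size_stoch the
    # stochastic approximation of the energy histogram.
    for _ in range(1_000):
        rng_key, step_key = jax.random.split(rng_key)
        state = csgld.step(step_key, state, minibatch, 1e-3, 1e-2)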
5 changes: 3 additions & 2 deletions blackjax/sgmcmc/__init__.py
@@ -1,3 +1,4 @@
from . import gradients, sghmc, sgld
from . import csgld, sghmc, sgld
from .gradients import grad_estimator, logdensity_estimator

__all__ = ["gradients", "sgld", "sghmc"]
__all__ = ["grad_estimator", "logdensity_estimator", "csgld", "sgld", "sghmc"]
176 changes: 176 additions & 0 deletions blackjax/sgmcmc/csgld.py
@@ -0,0 +1,176 @@
"""Public API for the Contour Stochastic gradient Langevin Dynamics kernel.

References
----------
.. [0]: Deng, W., Lin, G., Liang, F. (2020).
A Contour Stochastic Gradient Langevin Dynamics Algorithm
for Simulations of Multi-modal Distributions.
In Neural Information Processing Systems (NeurIPS 2020).

.. [1]: Deng, W., Liang, S., Hao, B., Lin, G., Liang, F. (2022)
Interacting Contour Stochastic Gradient Langevin Dynamics
In International Conference on Learning Representations (ICLR)
"""
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

from blackjax.sgmcmc.diffusions import overdamped_langevin
from blackjax.types import Array, PRNGKey, PyTree

__all__ = ["ContourSGLDState", "init", "kernel"]


class ContourSGLDState(NamedTuple):
r"""State of the Contour SgLD algorithm.

Parameters
----------
position
Current position in the sample space.
energy_pdf
Vector with `m` non-negative values that sum to 1. The `i`-th value
of the vector is equal to :math:`\int_{S_i} \pi(\mathrm{d}x)` where
:math:`S_i` is the `i`-th energy partition.
energy_idx
Index `i` such that the current position belongs to :math:`S_i`.

"""
position: PyTree
energy_pdf: Array
energy_idx: int


def init(position: PyTree, num_partitions=512):
energy_pdf = (
jnp.arange(num_partitions, 0, -1) / jnp.arange(num_partitions, 0, -1).sum()
)
return ContourSGLDState(position, energy_pdf, num_partitions - 1)


def kernel(num_partitions=512, energy_gap=10, min_energy=0) -> Callable:
r"""

Parameters
----------
num_partitions
The number of partitions we divide the energy landscape into.
energy_gap
The difference in energy :math:`\Delta u` between two successive
partitions. It can be determined by, e.g., running an optimizer to
estimate the range of energies; `num_partitions` * `energy_gap`
should match that range.
min_energy
A rough estimate of the minimum energy of the dataset, which must be
strictly smaller than the exact minimum energy. For instance, if the
exact minimum energy is 3456, any value below 3456 is valid. Setting
it to 0 is acceptable but inefficient: the closer the estimate is to
the exact minimum, the better.
"""

integrator = overdamped_langevin()

def one_step(
rng_key: PRNGKey,
state: ContourSGLDState,
logdensity_estimator_fn: Callable,
minibatch: PyTree,
step_size_diff: float, # step size for Langevin diffusion
step_size_stoch: float = 1e-3, # step size for stochastic approximation
zeta: float = 1,
temperature: float = 1.0,
) -> ContourSGLDState:
r"""Multil-modal sampling via Contour SGLD.

We are interested in sampling from :math:`\pi(x) \propto \exp(-U(x) / T)`,
where :math:`U` is an energy function and :math:`T` is the temperature.

To do so we partition the energy space into :math:`m` partitions:

.. math::

    S_0 &= \{x: U(x) \leq u_1\} \\
    S_1 &= \{x: u_1 < U(x) \leq u_2\} \\
    &\vdots \\
    S_{m-2} &= \{x: u_{m-2} < U(x) \leq u_{m-1}\} \\
    S_{m-1} &= \{x: U(x) > u_{m-1}\}

where :math:`-\infty < u_1 < u_2 < \cdots < u_{m-1} < \infty`. We assume
:math:`u_{i+1} - u_i = \Delta u` for :math:`i = 1, \dots, m-2`.

Parameters
----------
rng_key
State of the pseudo-random number generator.
state
Current state of the CSGLD sampler
logdensity_estimator_fn
Function that returns an estimation of the value of the density
function at the current position.
minibatch
Minibatch of data.
step_size_diff
Step size for the dynamics integration. Also called learning rate.
step_size_stoch
Step size for the update of the energy estimation.
zeta
Hyperparameter that controls the geometric property of the flattened
density. If `zeta=0` the function reduces to the SGLD step function.
temperature
Temperature parameter :math:`T`.

References
----------
.. [0]: Deng, W., Lin, G., Liang, F. (2020).
A Contour Stochastic Gradient Langevin Dynamics Algorithm
for Simulations of Multi-modal Distributions.
In Neural Information Processing Systems (NeurIPS 2020).

.. [1]: Deng, W., Liang, S., Hao, B., Lin, G., Liang, F. (2022)
Interacting Contour Stochastic Gradient Langevin Dynamics
In International Conference on Learning Representations (ICLR)
"""

position, energy_pdf, idx = state

# Update the position using the overdamped Langevin diffusion
gradient_multiplier = (
1.0
+ zeta
* temperature
* (jnp.log(energy_pdf[idx]) - jnp.log(energy_pdf[idx - 1]))
/ energy_gap
)

logprob_grad = jax.grad(logdensity_estimator_fn)(position, minibatch)
position = integrator(
rng_key,
position,
jax.tree_util.tree_map(lambda g: gradient_multiplier * g, logprob_grad),
step_size_diff,
temperature,
)

# Update the stochastic approximation to the energy histogram
neg_logprob = -logdensity_estimator_fn(position, minibatch)
idx = jax.lax.min(
jax.lax.max(
jax.lax.floor((neg_logprob - min_energy) / energy_gap + 1).astype(
"int32"
),
1,
),
num_partitions - 1,
)

energy_pdf_update = -energy_pdf.copy()
energy_pdf_update = energy_pdf_update.at[idx].set(energy_pdf_update[idx] + 1)
energy_pdf = jax.tree_util.tree_map(
lambda e: e + step_size_stoch * energy_pdf[idx] * energy_pdf_update,
energy_pdf,
)
Comment on lines +167 to +172 (Member): Since energy_pdf_update is not used after that, we can probably make this more efficient.
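One possible shape for that optimisation (an untested sketch, algebraically equivalent to the three statements above without materialising energy_pdf_update):

    # new_pdf = (1 - delta) * old_pdf + delta * one_hot(idx),
    # where delta = step_size_stoch * old_pdf[idx].
    delta = step_size_stoch * energy_pdf[idx]
    energy_pdf = ((1.0 - delta) * energy_pdf).at[idx].add(delta)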


return ContourSGLDState(position, energy_pdf, idx)

return one_step
10 changes: 7 additions & 3 deletions blackjax/sgmcmc/diffusions.py
@@ -18,7 +18,7 @@
from blackjax.types import PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise

__all__ = ["overdamped_langevin"]
__all__ = ["overdamped_langevin", "sghmc"]


def overdamped_langevin():
@@ -39,11 +39,14 @@ def one_step(
position: PyTree,
logdensity_grad: PyTree,
step_size: float,
temperature: float = 1.0,
) -> PyTree:

noise = generate_gaussian_noise(rng_key, position)
position = jax.tree_util.tree_map(
lambda p, g, n: p + step_size * g + jnp.sqrt(2 * step_size) * n,
lambda p, g, n: p
+ step_size * g
+ jnp.sqrt(2 * temperature * step_size) * n,
position,
logdensity_grad,
noise,
@@ -76,13 +79,14 @@ def one_step(
momentum: PyTree,
logdensity_grad: PyTree,
step_size: float,
temperature: float = 1.0,
):
noise = generate_gaussian_noise(rng_key, position)
position = jax.tree_util.tree_map(lambda x, p: x + p, position, momentum)
momentum = jax.tree_util.tree_map(
lambda p, g, n: (1.0 - alpha) * p
+ step_size * g
+ jnp.sqrt(2 * step_size * (alpha - beta)) * n,
+ jnp.sqrt(2 * step_size * (alpha - beta) * temperature) * n,
momentum,
logdensity_grad,
noise,
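For reference, with the new `temperature` argument the overdamped Langevin update above becomes, for each leaf of the position pytree,

.. math::

    x_{t+1} = x_t + \epsilon \nabla \log \pi(x_t) + \sqrt{2 \epsilon T} \, \xi_t,
    \qquad \xi_t \sim \mathcal{N}(0, I),

so `temperature=1` recovers the previous behaviour exactly; the SGHMC momentum update scales its noise by the same :math:`\sqrt{T}` factor.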