Add the control variates gradient estimator #299

Merged (9 commits) on Nov 20, 2022
123 changes: 47 additions & 76 deletions blackjax/kernels.py
@@ -462,10 +462,10 @@ def step_fn(rng_key: PRNGKey, state, delta: float):
class sgld:
"""Implements the (basic) user interface for the SGLD kernel.

The general sgld kernel (:meth:`blackjax.mcmc.sgld.kernel`, alias `blackjax.sgld.kernel`) can be
cumbersome to manipulate. Since most users only need to specify the kernel
parameters at initialization time, we provide a helper function that
specializes the general kernel.
The general sgld kernel (:meth:`blackjax.mcmc.sgld.kernel`, alias
`blackjax.sgld.kernel`) can be cumbersome to manipulate. Since most users
only need to specify the kernel parameters at initialization time, we
provide a helper function that specializes the general kernel.

Example
-------
@@ -476,35 +476,36 @@ class sgld:

.. code::

schedule_fn = lambda _: 1e-3
grad_fn = blackjax.sgmcmc.gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size)

We can now initialize the sgld kernel and the state:

.. code::

sgld = blackjax.sgld(grad_fn, schedule_fn)
sgld = blackjax.sgld(grad_fn)
state = sgld.init(position)

Assuming we have an iterator `batches` that yields batches of data we can perform one step:
Assuming we have an iterator `batches` that yields batches of data we can
perform one step:

.. code::

data_batch = next(batches)
new_state = sgld.step(rng_key, state, data_batch)
step_size = 1e-3
minibatch = next(batches)
new_state = sgld.step(rng_key, state, minibatch, step_size)

Kernels are not jit-compiled by default so you will need to do it manually:

.. code::

step = jax.jit(sgld.step)
new_state, info = step(rng_key, state)
new_state, info = step(rng_key, state, minibatch, step_size)

Parameters
----------
gradient_estimator_fn
A function which, given a position and a batch of data, returns an estimation
of the value of the gradient of the log-posterior distribution at this position.
gradient_estimator
A tuple of functions that initialize and update the gradient estimation
state.
schedule_fn
A function which returns a step size given a step number.

@@ -514,47 +515,28 @@ class sgld:

"""

init = staticmethod(sgmcmc.sgld.init)
kernel = staticmethod(sgmcmc.sgld.kernel)

def __new__( # type: ignore[misc]
cls,
grad_estimator_fn: Callable,
learning_rate: Union[Callable[[int], float], float],
) -> MCMCSamplingAlgorithm:

step = cls.kernel(grad_estimator_fn)

if callable(learning_rate):
learning_rate_fn = learning_rate
elif float(learning_rate):

def learning_rate_fn(_):
return learning_rate
grad_estimator: sgmcmc.gradients.GradientEstimator,
) -> Callable:

else:
raise TypeError(
"The learning rate must either be a float (which corresponds to a constant learning rate) "
f"or a function of the index of the current iteration. Got {type(learning_rate)} instead."
)

def init_fn(position: PyTree, data_batch: PyTree):
return cls.init(position, data_batch, grad_estimator_fn)
step = cls.kernel()

def step_fn(rng_key: PRNGKey, state, data_batch: PyTree):
step_size = learning_rate_fn(state.step)
return step(rng_key, state, data_batch, step_size)
def step_fn(rng_key: PRNGKey, state, minibatch: PyTree, step_size: float):
return step(rng_key, state, grad_estimator, minibatch, step_size)

return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]
return step_fn
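
Putting the docstring snippets together, a complete sampling loop might look like the sketch below. Everything model-specific here (the synthetic data, logprior_fn, a per-data-point loglikelihood_fn, the random-minibatch selection) is invented purely for illustration, and the init/step attributes simply mirror the docstring example above rather than a definitive API reference.

    import jax
    import jax.numpy as jnp
    import blackjax

    # Toy model: standard normal prior, Gaussian likelihood around a scalar mean.
    data = jax.random.normal(jax.random.PRNGKey(1), (1000,))
    data_size = data.shape[0]

    def logprior_fn(position):
        return -0.5 * jnp.sum(position**2)

    def loglikelihood_fn(position, x):
        # log-density of a single datum x given the current position
        return -0.5 * jnp.sum((x - position) ** 2)

    grad_fn = blackjax.sgmcmc.gradients.grad_estimator(
        logprior_fn, loglikelihood_fn, data_size
    )
    sgld = blackjax.sgld(grad_fn)

    position = jnp.zeros(())
    state = sgld.init(position)
    step = jax.jit(sgld.step)

    step_size = 1e-3
    rng_key = jax.random.PRNGKey(0)
    for _ in range(100):
        rng_key, step_key, batch_key = jax.random.split(rng_key, 3)
        idx = jax.random.randint(batch_key, (32,), 0, data_size)
        minibatch = data[idx]  # a random minibatch of 32 observations
        state = step(step_key, state, minibatch, step_size)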


class sghmc:
"""Implements the (basic) user interface for the SGHMC kernel.

The general sghmc kernel (:meth:`blackjax.mcmc.sghmc.kernel`, alias `blackjax.sghmc.kernel`) can be
cumbersome to manipulate. Since most users only need to specify the kernel
parameters at initialization time, we provide a helper function that
specializes the general kernel.
The general sghmc kernel (:meth:`blackjax.mcmc.sghmc.kernel`, alias
`blackjax.sghmc.kernel`) can be cumbersome to manipulate. Since most users
only need to specify the kernel parameters at initialization time, we
provide a helper function that specializes the general kernel.

Example
-------
@@ -565,35 +547,36 @@ class sghmc:

.. code::

schedule_fn = lambda _: 1e-3
grad_fn = blackjax.sgmcmc.gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size)
grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size)

We can now initialize the sghmc kernel and the state. Like HMC, SGHMC needs the user to specify a number of integration steps.

.. code::

sghmc = blackjax.sghmc(grad_fn, schedule_fn, num_integration_steps)
sghmc = blackjax.sghmc(grad_estimator, num_integration_steps)
state = sghmc.init(position)

Assuming we have an iterator `batches` that yields batches of data we can perform one step:
Assuming we have an iterator `batches` that yields batches of data we can
perform one step:

.. code::

data_batch = next(batches)
new_state = sghmc.step(rng_key, state, data_batch)
step_size = 1e-3
minibatch = next(batches)
new_state = sghmc.step(rng_key, state, minibatch, step_size)

Kernels are not jit-compiled by default so you will need to do it manually:

.. code::

step = jax.jit(sghmc.step)
new_state, info = step(rng_key, state)
new_state, info = step(rng_key, state, minibatch, step_size)

Parameters
----------
gradient_estimator_fn
A function which, given a position and a batch of data, returns an estimation
of the value of the gradient of the log-posterior distribution at this position.
gradient_estimator
A tuple of functions that initialize and update the gradient estimation
state.
schedule_fn
A function which returns a step size given a step number.

@@ -603,39 +586,27 @@ class sghmc:

"""

init = staticmethod(sgmcmc.sgld.init)
kernel = staticmethod(sgmcmc.sghmc.kernel)

def __new__( # type: ignore[misc]
cls,
grad_estimator_fn: Callable,
learning_rate: Union[Callable[[int], float], float],
grad_estimator: Callable,
num_integration_steps: int = 10,
) -> MCMCSamplingAlgorithm:

step = cls.kernel(grad_estimator_fn)
) -> Callable:

if callable(learning_rate):
learning_rate_fn = learning_rate
elif float(learning_rate):

def learning_rate_fn(_):
return learning_rate
step = cls.kernel()

else:
raise TypeError(
"The learning rate must either be a float (which corresponds to a constant learning rate) "
f"or a function of the index of the current iteration. Got {type(learning_rate)} instead."
def step_fn(rng_key: PRNGKey, state, minibatch: PyTree, step_size: float):
return step(
rng_key,
state,
grad_estimator,
minibatch,
step_size,
num_integration_steps,
)

def init_fn(position: PyTree, data_batch: PyTree):
return cls.init(position, data_batch, grad_estimator_fn)

def step_fn(rng_key: PRNGKey, state, data_batch: PyTree):
step_size = learning_rate_fn(state.step)
return step(rng_key, state, data_batch, step_size, num_integration_steps)

return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type]
return step_fn
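
The sghmc interface would be driven the same way; only the construction differs, since SGHMC additionally takes a number of integration steps. A sketch reusing the hypothetical grad_fn, position, data, data_size, rng_key, and step_size from the sgld example above:

    sghmc = blackjax.sghmc(grad_fn, num_integration_steps=10)
    state = sghmc.init(position)
    step = jax.jit(sghmc.step)

    for _ in range(100):
        rng_key, step_key, batch_key = jax.random.split(rng_key, 3)
        idx = jax.random.randint(batch_key, (32,), 0, data_size)
        state = step(step_key, state, data[idx], step_size)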


# -----------------------------------------------------------------------------
4 changes: 2 additions & 2 deletions blackjax/mcmc/diffusion.py → blackjax/mcmc/diffusions.py
@@ -16,7 +16,7 @@ class DiffusionState(NamedTuple):
logprob_grad: PyTree


def overdamped_langevin(logprob_and_grad_fn):
def overdamped_langevin(logprob_grad_fn):
"""Euler solver for overdamped Langevin diffusion."""

def one_step(rng_key, state: DiffusionState, step_size: float, batch: tuple = ()):
@@ -29,7 +29,7 @@ def one_step(rng_key, state: DiffusionState, step_size: float, batch: tuple = ()
noise,
)

logprob, logprob_grad = logprob_and_grad_fn(position, *batch)
logprob, logprob_grad = logprob_grad_fn(position, *batch)
return DiffusionState(position, logprob, logprob_grad)

return one_step
4 changes: 2 additions & 2 deletions blackjax/mcmc/mala.py
@@ -5,7 +5,7 @@
import jax
import jax.numpy as jnp

from blackjax.mcmc.diffusion import overdamped_langevin
import blackjax.mcmc.diffusions as diffusions
from blackjax.types import PRNGKey, PyTree

__all__ = ["MALAState", "MALAInfo", "init", "kernel"]
@@ -83,7 +83,7 @@ def one_step(

"""
grad_fn = jax.value_and_grad(logprob_fn)
integrator = overdamped_langevin(grad_fn)
integrator = diffusions.overdamped_langevin(grad_fn)

key_integrator, key_rmh = jax.random.split(rng_key)

4 changes: 2 additions & 2 deletions blackjax/sgmcmc/__init__.py
@@ -1,3 +1,3 @@
from . import sghmc, sgld
from . import gradients, sghmc, sgld

__all__ = ["sgld", "sghmc"]
__all__ = ["gradients", "sgld", "sghmc"]
59 changes: 31 additions & 28 deletions blackjax/sgmcmc/diffusion.py → blackjax/sgmcmc/diffusions.py
@@ -1,6 +1,4 @@
"""Solvers for Langevin diffusions."""
from typing import NamedTuple

import jax
import jax.numpy as jnp

@@ -10,18 +8,26 @@
__all__ = ["overdamped_langevin"]


class DiffusionState(NamedTuple):
position: PyTree
logprob_grad: PyTree
def overdamped_langevin():
"""Euler solver for overdamped Langevin diffusion.

This algorithm was ported from [0]_.

def overdamped_langevin(logprob_grad_fn):
"""Euler solver for overdamped Langevin diffusion."""
References
----------
.. [0]: Coullon, J., & Nemeth, C. (2022). SGMCMCJax: a lightweight JAX
library for stochastic gradient Markov chain Monte Carlo algorithms.
Journal of Open Source Software, 7(72), 4113.
"""

def one_step(
rng_key: PRNGKey, state: DiffusionState, step_size: float, batch: tuple = ()
):
position, logprob_grad = state
rng_key: PRNGKey,
position: PyTree,
logprob_grad: PyTree,
step_size: float,
minibatch: tuple = (),
) -> PyTree:

noise = generate_gaussian_noise(rng_key, position)
position = jax.tree_util.tree_map(
lambda p, g, n: p + step_size * g + jnp.sqrt(2 * step_size) * n,
@@ -30,33 +36,35 @@ def one_step(
noise,
)

logprob_grad = logprob_grad_fn(position, batch)
return DiffusionState(position, logprob_grad)
return position

return one_step
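
For orientation, the update one_step applies (leaf-wise over the position pytree) is the Euler-Maruyama discretization of the overdamped Langevin SDE,

    position <- position + step_size * logprob_grad + sqrt(2 * step_size) * noise,   noise ~ N(0, I)

with logprob_grad now supplied by the caller (typically the output of a gradient estimator evaluated on the current minibatch) rather than recomputed inside the solver, which is why the diffusion state is gone and the solver simply returns the new position.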


class SGHMCState(NamedTuple):
position: PyTree
momentum: PyTree
logprob_grad: PyTree


def sghmc(logprob_grad_fn, alpha: float = 0.01, beta: float = 0):
def sghmc(alpha: float = 0.01, beta: float = 0):
"""Solver for the diffusion equation of the SGHMC algorithm [0]_.

This algorithm was ported from [1]_.

References
----------
.. [0]: Chen, T., Fox, E., & Guestrin, C. (2014, June). Stochastic
gradient hamiltonian monte carlo. In International conference on
machine learning (pp. 1683-1691). PMLR.
.. [1]: Coullon, J., & Nemeth, C. (2022). SGMCMCJax: a lightweight JAX
library for stochastic gradient Markov chain Monte Carlo algorithms.
Journal of Open Source Software, 7(72), 4113.

"""

def one_step(
rng_key: PRNGKey, state: SGHMCState, step_size: float, batch: tuple = ()
) -> SGHMCState:
position, momentum, logprob_grad = state
rng_key: PRNGKey,
position: PyTree,
momentum: PyTree,
logprob_grad: PyTree,
step_size: float,
minibatch: tuple = (),
):
noise = generate_gaussian_noise(rng_key, position)
position = jax.tree_util.tree_map(lambda x, p: x + p, position, momentum)
momentum = jax.tree_util.tree_map(
@@ -68,11 +76,6 @@ def one_step(
noise,
)

logprob_grad = logprob_grad_fn(position, batch)
return SGHMCState(
position,
momentum,
logprob_grad,
)
return position, momentum

return one_step
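
The SGHMC solver follows the same pattern: the caller supplies logprob_grad, and the solver returns the updated (position, momentum) pair. The momentum-update line is elided from this hunk; the standard discretization from Chen et al. (2014), which the SGMCMCJax port referenced above also uses, reads

    position <- position + momentum
    momentum <- (1 - alpha) * momentum + step_size * logprob_grad + sqrt(2 * step_size * (alpha - beta)) * noise,   noise ~ N(0, I)

so the exact noise scaling here should be checked against the full file; the form above is given only as a reference point.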