Add GradientState and GradientEstimator data structures
rlouf committed Oct 7, 2022
1 parent da8947c commit 5410741
Showing 6 changed files with 129 additions and 51 deletions.
4 changes: 2 additions & 2 deletions blackjax/mcmc/diffusion.py
@@ -21,7 +21,7 @@ def generate_gaussian_noise(rng_key: PRNGKey, position):
     return unravel_fn(noise_flat)
 
 
-def overdamped_langevin(logprob_and_grad_fn):
+def overdamped_langevin(logprob_grad_fn):
     """Euler solver for overdamped Langevin diffusion."""
 
     def one_step(rng_key, state: DiffusionState, step_size: float, batch: tuple = ()):
@@ -34,7 +34,7 @@ def one_step(rng_key, state: DiffusionState, step_size: float, batch: tuple = ()
             noise,
         )
 
-        logprob, logprob_grad = logprob_and_grad_fn(position, *batch)
+        logprob, logprob_grad = logprob_grad_fn(position, *batch)
         return DiffusionState(position, logprob, logprob_grad)
 
     return one_step
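
Note that the renamed callback keeps its original contract: it must return both the log-density and its gradient. A minimal sketch, not part of this commit, of how such a callback is typically built with `jax.value_and_grad`, assuming the definitions in this module are in scope and using an illustrative standard-normal target:

import jax
import jax.numpy as jnp

def logdensity(position):
    # Illustrative standard-normal log-density; any log-density works here.
    return -0.5 * jnp.sum(position**2)

# `jax.value_and_grad` returns exactly the (logprob, logprob_grad) pair
# that `one_step` above unpacks.
logprob_grad_fn = jax.value_and_grad(logdensity)
step_fn = overdamped_langevin(logprob_grad_fn)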
31 changes: 17 additions & 14 deletions blackjax/sgmcmc/diffusion.py
@@ -4,6 +4,7 @@
 import jax
 import jax.numpy as jnp
 
+from blackjax.sgmcmc.gradients import GradientState
 from blackjax.types import PRNGKey, PyTree
 
 __all__ = ["overdamped_langevin"]
@@ -12,6 +13,7 @@
 class DiffusionState(NamedTuple):
     position: PyTree
     logprob_grad: PyTree
+    grad_estimator_state: GradientState
 
 
 def generate_gaussian_noise(rng_key: PRNGKey, position: PyTree):
@@ -20,13 +22,13 @@ def generate_gaussian_noise(rng_key: PRNGKey, position: PyTree):
     return unravel_fn(noise_flat)
 
 
-def overdamped_langevin(logprob_grad_fn):
+def overdamped_langevin(grad_estimator_fn):
     """Euler solver for overdamped Langevin diffusion."""
 
     def one_step(
-        rng_key: PRNGKey, state: DiffusionState, step_size: float, batch: tuple = ()
+        rng_key: PRNGKey, state: DiffusionState, step_size: float, minibatch: tuple = ()
     ):
-        position, logprob_grad = state
+        position, logprob_grad, grad_estimator_state = state
         noise = generate_gaussian_noise(rng_key, position)
         position = jax.tree_util.tree_map(
             lambda p, g, n: p + step_size * g + jnp.sqrt(2 * step_size) * n,
@@ -35,8 +37,10 @@ def one_step(
             noise,
         )
 
-        logprob_grad = logprob_grad_fn(position, batch)
-        return DiffusionState(position, logprob_grad)
+        logprob_grad, gradient_estimator_state = grad_estimator_fn(
+            grad_estimator_state, position, minibatch
+        )
+        return DiffusionState(position, logprob_grad, gradient_estimator_state)
 
     return one_step
 
@@ -45,9 +49,10 @@ class SGHMCState(NamedTuple):
     position: PyTree
     momentum: PyTree
     logprob_grad: PyTree
+    grad_estimator_state: GradientState
 
 
-def sghmc(logprob_grad_fn, alpha: float = 0.01, beta: float = 0):
+def sghmc(grad_estimator_fn, alpha: float = 0.01, beta: float = 0):
     """Solver for the diffusion equation of the SGHMC algorithm [0]_.
 
     References
@@ -59,9 +64,9 @@ def sghmc(grad_estimator_fn, alpha: float = 0.01, beta: float = 0):
     """
 
     def one_step(
-        rng_key: PRNGKey, state: SGHMCState, step_size: float, batch: tuple = ()
-    ) -> SGHMCState:
-        position, momentum, logprob_grad = state
+        rng_key: PRNGKey, state: SGHMCState, step_size: float, minibatch: tuple = ()
+    ):
+        position, momentum, logprob_grad, grad_estimator_state = state
         noise = generate_gaussian_noise(rng_key, position)
         position = jax.tree_util.tree_map(lambda x, p: x + p, position, momentum)
         momentum = jax.tree_util.tree_map(
@@ -73,11 +78,9 @@ def one_step(
             noise,
        )
 
-        logprob_grad = logprob_grad_fn(position, batch)
-        return SGHMCState(
-            position,
-            momentum,
-            logprob_grad,
+        logprob_grad, gradient_estimator_state = grad_estimator_fn(
+            grad_estimator_state, position, minibatch
        )
+        return SGHMCState(position, momentum, logprob_grad, gradient_estimator_state)
 
     return one_step
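
The estimator callback now threads state through the solver: it receives `(grad_estimator_state, position, minibatch)` and returns `(logprob_grad, new_state)`, and both `DiffusionState` and `SGHMCState` carry that state between steps. A minimal sketch of a conforming stateless estimator and one Euler step, assuming the definitions above are in scope; the estimator and its standard-normal target are illustrative, not from the commit:

import jax
import jax.numpy as jnp

def toy_grad_estimator_fn(grad_estimator_state, position, minibatch):
    # Stateless estimator for a standard-normal target: the gradient of
    # the log-density is -position, and the carried state stays `None`.
    return jax.tree_util.tree_map(lambda p: -p, position), None

step_fn = overdamped_langevin(toy_grad_estimator_fn)
grad, est_state = toy_grad_estimator_fn(None, jnp.zeros(2), ())
state = DiffusionState(jnp.zeros(2), grad, est_state)
state = step_fn(jax.random.PRNGKey(0), state, step_size=1e-2)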
72 changes: 51 additions & 21 deletions blackjax/sgmcmc/gradients.py
@@ -1,14 +1,19 @@
-from typing import Callable
+from typing import Callable, NamedTuple, Tuple, Union
 
 import jax
 import jax.numpy as jnp
 
 from blackjax.types import PyTree
 
 
+class GradientEstimator(NamedTuple):
+    init: Callable
+    estimate: Callable
+
+
 def grad_estimator(
     logprior_fn: Callable, loglikelihood_fn: Callable, data_size: int
-) -> Callable:
+) -> GradientEstimator:
     """Builds a simple gradient estimator.
 
     This estimator first appeared in [1]_. The `logprior_fn` function has a
@@ -34,7 +39,12 @@ def grad_estimator(
     """
 
-    def logposterior_estimator_fn(position: PyTree, data_batch: PyTree) -> float:
+    def init_fn(_) -> None:
+        return None
+
+    def logposterior_estimator_fn(
+        position: PyTree, minibatch: PyTree
+    ) -> Tuple[PyTree, None]:
         """Returns an approximation of the log-posterior density.
 
         Parameters
@@ -53,20 +63,30 @@ def logposterior_estimator_fn(position: PyTree, data_batch: PyTree) -> float:
         logprior = logprior_fn(position)
         batch_loglikelihood = jax.vmap(loglikelihood_fn, in_axes=(None, 0))
         return logprior + data_size * jnp.mean(
-            batch_loglikelihood(position, data_batch), axis=0
+            batch_loglikelihood(position, minibatch), axis=0
         )
 
-    logprob_grad = jax.grad(logposterior_estimator_fn)
+    def grad_estimator_fn(_, position, data_batch):
+        return jax.grad(logposterior_estimator_fn)(position, data_batch), None
+
+    return GradientEstimator(init_fn, grad_estimator_fn)
+
 
-    return logprob_grad
+class CVGradientState(NamedTuple):
+    """The state of the CV gradient estimator contains the gradient of the
+    Control Variate computed on the whole dataset at initialization.
+    """
+
+    control_variate_grad: PyTree
 
 
 def cv_grad_estimator(
     logprior_fn: Callable,
     loglikelihood_fn: Callable,
     data: PyTree,
     centering_position: PyTree,
-) -> Callable:
+) -> GradientEstimator:
     """Builds a control variate gradient estimator [1]_.
 
     Parameters
@@ -91,12 +111,17 @@ def cv_grad_estimator(
     data_size = jax.tree_leaves(data)[0].shape[0]
     logposterior_grad_estimator_fn = grad_estimator(
         logprior_fn, loglikelihood_fn, data_size
-    )
+    ).estimate
 
-    # Control Variates use the gradient on the full dataset
-    logposterior_grad_center = logposterior_grad_estimator_fn(centering_position, data)
+    def init_fn(full_dataset: PyTree) -> CVGradientState:
+        """Compute the control variate on the whole dataset."""
+        return CVGradientState(
+            logposterior_grad_estimator_fn(None, centering_position, full_dataset)[0]
+        )
 
-    def logposterior_estimator_fn(position: PyTree, data_batch: PyTree) -> float:
+    def grad_estimator_fn(
+        grad_estimator_state: CVGradientState, position: PyTree, minibatch: PyTree
+    ) -> Tuple[PyTree, CVGradientState]:
         """Return an approximation of the log-posterior density.
 
         Parameters
@@ -114,20 +139,25 @@ def logposterior_estimator_fn(position: PyTree, data_batch: PyTree) -> float:
         """
         logposterior_grad_estimate = logposterior_grad_estimator_fn(
-            position, data_batch
-        )
+            None, position, minibatch
+        )[0]
         logposterior_grad_center_estimate = logposterior_grad_estimator_fn(
-            centering_position, data_batch
-        )
+            None, centering_position, minibatch
+        )[0]
 
         def control_variate(grad_estimate, center_grad_estimate, center_grad):
             return grad_estimate + center_grad - center_grad_estimate
 
-        return jax.tree_util.tree_map(
-            control_variate,
-            logposterior_grad_estimate,
-            logposterior_grad_center,
-            logposterior_grad_center_estimate,
+        return (
+            control_variate(
+                logposterior_grad_estimate,
+                grad_estimator_state.control_variate_grad,
+                logposterior_grad_center_estimate,
+            ),
+            grad_estimator_state,
         )
 
-    return logposterior_estimator_fn
+    return GradientEstimator(init_fn, grad_estimator_fn)
 
 
+GradientState = Union[None, CVGradientState]
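
Both constructors now return a `GradientEstimator` pair: `init` builds the estimator state (trivially `None` for the plain estimator; for the CV one, the gradient at the centering position computed once on the full dataset, which is then reused to reduce the variance of minibatch estimates), and `estimate` maps `(state, position, minibatch)` to `(gradient, state)`. A hedged usage sketch, with an illustrative Gaussian model that is not from the commit:

import jax
import jax.numpy as jnp

# Illustrative model functions, for demonstration only.
def logprior_fn(position):
    return -0.5 * jnp.sum(position**2)

def loglikelihood_fn(position, x):
    return -0.5 * jnp.sum((x - position) ** 2)

data = jax.random.normal(jax.random.PRNGKey(0), (1000, 5))
minibatch = data[:100]
position = jnp.zeros(5)

# Plain estimator: the carried state is always `None`.
estimator = grad_estimator(logprior_fn, loglikelihood_fn, data_size=1000)
state = estimator.init(minibatch)
grad, state = estimator.estimate(state, position, minibatch)

# Control-variate estimator: `init` touches the full dataset once.
cv_estimator = cv_grad_estimator(logprior_fn, loglikelihood_fn, data, jnp.zeros(5))
cv_state = cv_estimator.init(data)
cv_grad, cv_state = cv_estimator.estimate(cv_state, position, minibatch)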
22 changes: 16 additions & 6 deletions blackjax/sgmcmc/sghmc.py
@@ -5,7 +5,9 @@
 import jax.numpy as jnp
 
 from blackjax.sgmcmc.diffusion import SGHMCState, sghmc
+from blackjax.sgmcmc.gradients import GradientEstimator
 from blackjax.sgmcmc.sgld import SGLDState
+from blackjax.sgmcmc.sgld import init as sgld_init
 from blackjax.types import PRNGKey, PyTree
 
 __all__ = ["kernel"]
@@ -19,26 +21,34 @@ def sample_momentum(rng_key: PRNGKey, position: PyTree, step_size: float):
     return unravel_fn(noise_flat)
 
 
+init = sgld_init
+
+
 def kernel(
-    grad_estimator_fn: Callable, alpha: float = 0.01, beta: float = 0
+    gradient_estimator: GradientEstimator, alpha: float = 0.01, beta: float = 0
 ) -> Callable:
 
+    grad_estimator_fn = gradient_estimator.estimate
     integrator = sghmc(grad_estimator_fn)
 
     def one_step(
-        rng_key: PRNGKey, state: SGLDState, data_batch: PyTree, step_size: float, L: int
+        rng_key: PRNGKey, state: SGLDState, minibatch: PyTree, step_size: float, L: int
     ) -> SGLDState:
 
-        step, position, logprob_grad = state
+        step, position, logprob_grad, grad_estimator_state = state
         momentum = sample_momentum(rng_key, position, step_size)
-        diffusion_state = SGHMCState(position, momentum, logprob_grad)
+        diffusion_state = SGHMCState(
+            position, momentum, logprob_grad, grad_estimator_state
+        )
 
         def body_fn(state, rng_key):
-            new_state = integrator(rng_key, state, step_size, data_batch)
+            new_state = integrator(rng_key, state, step_size, minibatch)
             return new_state, new_state
 
         keys = jax.random.split(rng_key, L)
         last_state, _ = jax.lax.scan(body_fn, diffusion_state, keys)
+        position, _, logprob_grad, grad_estimator_state = last_state
 
-        return SGLDState(step + 1, last_state.position, last_state.logprob_grad)
+        return SGLDState(step + 1, position, logprob_grad, grad_estimator_state)
 
     return one_step
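
Putting the pieces together, the kernel at this level takes the `GradientEstimator` itself rather than a bare callback, and `init` is re-exported from sgld so both kernels share `SGLDState`. A hedged sketch of driving it directly, assuming direct module imports and the same illustrative model as above (the top-level `blackjax.sghmc` wrapper used in the tests below hides this plumbing):

import jax
import jax.numpy as jnp
import blackjax.sgmcmc.gradients as gradients
import blackjax.sgmcmc.sghmc as sghmc_mod

# Illustrative model, for demonstration only.
logprior_fn = lambda p: -0.5 * jnp.sum(p**2)
loglikelihood_fn = lambda p, x: -0.5 * jnp.sum((x - p) ** 2)

data = jax.random.normal(jax.random.PRNGKey(0), (1000, 5))
estimator = gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size=1000)

# `init` is sgld's init, so the SGHMC kernel starts from an SGLDState.
state = sghmc_mod.init(jnp.zeros(5), data[:100], estimator)
step = sghmc_mod.kernel(estimator)
state = step(jax.random.PRNGKey(1), state, data[:100], step_size=1e-3, L=10)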
26 changes: 19 additions & 7 deletions blackjax/sgmcmc/sgld.py
@@ -2,6 +2,7 @@
 from typing import Callable, NamedTuple
 
 from blackjax.sgmcmc.diffusion import overdamped_langevin
+from blackjax.sgmcmc.gradients import GradientEstimator, GradientState
 from blackjax.types import PRNGKey, PyTree
 
 __all__ = ["SGLDState", "init", "kernel"]
@@ -11,22 +12,33 @@ class SGLDState(NamedTuple):
     step: int
     position: PyTree
     logprob_grad: PyTree
+    grad_estimator_state: GradientState
 
 
-def init(position: PyTree, batch, grad_estimator_fn: Callable):
-    logprob_grad = grad_estimator_fn(position, batch)
-    return SGLDState(0, position, logprob_grad)
+# We could compute the gradient at the beginning of the kernel step instead.
+# This would let us get rid of much of the init function, and would prevent
+# a useless gradient computation at the last step.
 
 
-def kernel(grad_estimator_fn: Callable) -> Callable:
+def init(position: PyTree, minibatch, gradient_estimator: GradientEstimator):
+    grad_estimator_state = gradient_estimator.init(minibatch)
+    logprob_grad, grad_estimator_state = gradient_estimator.estimate(
+        grad_estimator_state, position, minibatch
+    )
+    return SGLDState(0, position, logprob_grad, grad_estimator_state)
+
+
+def kernel(gradient_estimator: GradientEstimator) -> Callable:
 
+    grad_estimator_fn = gradient_estimator.estimate
     integrator = overdamped_langevin(grad_estimator_fn)
 
     def one_step(
-        rng_key: PRNGKey, state: SGLDState, data_batch: PyTree, step_size: float
-    ) -> SGLDState:
+        rng_key: PRNGKey, state: SGLDState, minibatch: PyTree, step_size: float
+    ):
 
         step, *diffusion_state = state
-        new_state = integrator(rng_key, diffusion_state, step_size, data_batch)
+        new_state = integrator(rng_key, diffusion_state, step_size, minibatch)
 
         return SGLDState(step + 1, *new_state)
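
The resulting SGLD API mirrors SGHMC's: `init` now takes the `GradientEstimator` itself, and the returned `SGLDState` carries the estimator state alongside the gradient. A hedged sketch, with the same illustrative model as above, showing how that state is threaded through successive steps:

import jax
import jax.numpy as jnp
import blackjax.sgmcmc.gradients as gradients
import blackjax.sgmcmc.sgld as sgld

# Illustrative model, for demonstration only.
logprior_fn = lambda p: -0.5 * jnp.sum(p**2)
loglikelihood_fn = lambda p, x: -0.5 * jnp.sum((x - p) ** 2)

data = jax.random.normal(jax.random.PRNGKey(0), (1000, 5))
estimator = gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size=1000)

state = sgld.init(jnp.zeros(5), data[:100], estimator)
step = sgld.kernel(estimator)
for i in range(5):
    minibatch = data[100 * i : 100 * (i + 1)]
    state = step(jax.random.PRNGKey(i), state, minibatch, step_size=1e-3)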
25 changes: 24 additions & 1 deletion tests/test_sampling.py
@@ -268,7 +268,7 @@ def test_linear_regression_sgld_cv(self):
         data_size = 1000
         X_data = jax.random.normal(data_key, shape=(data_size, 5))
 
-        centering_position = jnp.ones(5)
+        centering_position = 1.0
         grad_fn = blackjax.sgmcmc.gradients.cv_grad_estimator(
             self.logprior_fn, self.loglikelihood_fn, X_data, centering_position
         )
@@ -310,6 +310,29 @@ def test_linear_regression_sghmc(self, learning_rate, error_expected):
         data_batch = X_data[100:200, :]
         _ = sghmc.step(rng_key, init_state, data_batch)
 
+    def test_linear_regression_sghmc_cv(self):
+        """Test the SGHMC kernel with a control-variates gradient estimator."""
+        import blackjax.sgmcmc.gradients
+
+        rng_key, data_key = jax.random.split(self.key, 2)
+
+        data_size = 1000
+        X_data = jax.random.normal(data_key, shape=(data_size, 5))
+
+        centering_position = 1.0
+        grad_fn = blackjax.sgmcmc.gradients.cv_grad_estimator(
+            self.logprior_fn, self.loglikelihood_fn, X_data, centering_position
+        )
+
+        sghmc = blackjax.sghmc(grad_fn, 1e-3, 10)
+        init_position = 1.0
+        data_batch = X_data[:100, :]
+        init_state = sghmc.init(init_position, data_batch)
+
+        _, rng_key = jax.random.split(rng_key)
+        data_batch = X_data[100:200, :]
+        _ = sghmc.step(rng_key, init_state, data_batch)
+
 
 class LatentGaussianTest(chex.TestCase):
     """Test sampling of a linear regression model."""
