From b3d5ee9b36e6a4558cd637550035ef86bb9f300b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Tue, 20 Sep 2022 21:00:06 +0200
Subject: [PATCH] WIP - Add gradient state

---
 blackjax/sgmcmc/diffusion.py | 19 +++++++++++--------
 blackjax/sgmcmc/gradients.py | 19 +++++++++++++------
 blackjax/sgmcmc/sghmc.py     | 12 ++++++++++--
 blackjax/sgmcmc/sgld.py      | 19 ++++++++++++++-----
 4 files changed, 48 insertions(+), 21 deletions(-)

diff --git a/blackjax/sgmcmc/diffusion.py b/blackjax/sgmcmc/diffusion.py
index da04cce48..29cee1d74 100644
--- a/blackjax/sgmcmc/diffusion.py
+++ b/blackjax/sgmcmc/diffusion.py
@@ -35,8 +35,8 @@ def one_step(
             noise,
         )
 
-        logprob_grad = logprob_grad_fn(position, batch)
-        return DiffusionState(position, logprob_grad)
+        gradient_state, logprob_grad = logprob_grad_fn(position, batch)
+        return DiffusionState(position, logprob_grad), gradient_state
 
     return one_step
 
@@ -60,7 +60,7 @@ def sghmc(logprob_grad_fn, alpha: float = 0.01, beta: float = 0):
 
     def one_step(
         rng_key: PRNGKey, state: SGHMCState, step_size: float, batch: tuple = ()
-    ) -> SGHMCState:
+    ):
         position, momentum, logprob_grad = state
         noise = generate_gaussian_noise(rng_key, position)
         position = jax.tree_util.tree_map(lambda x, p: x + p, position, momentum)
@@ -73,11 +73,14 @@ def one_step(
             noise,
         )
 
-        logprob_grad = logprob_grad_fn(position, batch)
-        return SGHMCState(
-            position,
-            momentum,
-            logprob_grad,
+        gradient_state, logprob_grad = logprob_grad_fn(position, batch)
+        return (
+            SGHMCState(
+                position,
+                momentum,
+                logprob_grad,
+            ),
+            gradient_state,
         )
 
     return one_step
diff --git a/blackjax/sgmcmc/gradients.py b/blackjax/sgmcmc/gradients.py
index 7ea7ee9f7..4b5389498 100644
--- a/blackjax/sgmcmc/gradients.py
+++ b/blackjax/sgmcmc/gradients.py
@@ -1,4 +1,4 @@
-from typing import Callable
+from typing import Callable, NamedTuple
 
 import jax
 import jax.numpy as jnp
@@ -61,6 +61,10 @@ def logposterior_estimator_fn(position: PyTree, data_batch: PyTree) -> float:
     return logprob_grad
 
 
+class GradientState(NamedTuple):
+    control_variate_grad: PyTree
+
+
 def cv_grad_estimator(
     logprior_fn: Callable,
     loglikelihood_fn: Callable,
@@ -93,10 +97,13 @@
         logprior_fn, loglikelihood_fn, data_size
     )
 
-    # Control Variates use the gradient on the full dataset
-    logposterior_grad_center = logposterior_grad_estimator_fn(centering_position, data)
+    def init(data):
+        """Compute the control variate on the whole dataset."""
+        return GradientState(logposterior_grad_estimator_fn(centering_position, data))
 
-    def logposterior_estimator_fn(position: PyTree, data_batch: PyTree) -> float:
+    def logposterior_estimator_fn(
+        gradient_state: GradientState, position: PyTree, data_batch: PyTree
+    ):
         """Return an approximation of the log-posterior density.
 
         Parameters
@@ -123,10 +130,10 @@ def logposterior_estimator_fn(position: PyTree, data_batch: PyTree) -> float:
         def control_variate(grad_estimate, center_grad_estimate, center_grad):
             return grad_estimate + center_grad - center_grad_estimate
 
-        return jax.tree_util.tree_map(
+        return gradient_state, jax.tree_util.tree_map(
             control_variate,
             logposterior_grad_estimate,
-            logposterior_grad_center,
+            gradient_state.control_variate_grad,
             logposterior_grad_center_estimate,
         )
 
diff --git a/blackjax/sgmcmc/sghmc.py b/blackjax/sgmcmc/sghmc.py
index a1840fab7..a663f6ac4 100644
--- a/blackjax/sgmcmc/sghmc.py
+++ b/blackjax/sgmcmc/sghmc.py
@@ -6,6 +6,7 @@
 
 from blackjax.sgmcmc.diffusion import SGHMCState, sghmc
 from blackjax.sgmcmc.sgld import SGLDState
+from blackjax.sgmcmc.sgld import init as sgld_init
 from blackjax.types import PRNGKey, PyTree
 
 __all__ = ["kernel"]
@@ -19,6 +20,9 @@ def sample_momentum(rng_key: PRNGKey, position: PyTree, step_size: float):
     return unravel_fn(noise_flat)
 
 
+init = sgld_init
+
+
 def kernel(
     grad_estimator_fn: Callable, alpha: float = 0.01, beta: float = 0
 ) -> Callable:
@@ -28,7 +32,7 @@ def one_step(
         rng_key: PRNGKey, state: SGLDState, data_batch: PyTree, step_size: float, L: int
     ) -> SGLDState:
-        step, position, logprob_grad = state
+        step, position, logprob_grad, _ = state
 
         momentum = sample_momentum(rng_key, position, step_size)
         diffusion_state = SGHMCState(position, momentum, logprob_grad)
 
@@ -39,6 +43,10 @@ def body_fn(state, rng_key):
         keys = jax.random.split(rng_key, L)
         last_state, _ = jax.lax.scan(body_fn, diffusion_state, keys)
 
-        return SGLDState(step + 1, last_state.position, last_state.logprob_grad)
+        position = last_state.position
+        logprob_grad = last_state.logprob_grad
+        gradient_state = last_state[-1]
+
+        return SGLDState(step + 1, position, logprob_grad, gradient_state)
 
     return one_step
diff --git a/blackjax/sgmcmc/sgld.py b/blackjax/sgmcmc/sgld.py
index 41d53b882..8abc3caf9 100644
--- a/blackjax/sgmcmc/sgld.py
+++ b/blackjax/sgmcmc/sgld.py
@@ -2,6 +2,7 @@
 from typing import Callable, NamedTuple
 
 from blackjax.sgmcmc.diffusion import overdamped_langevin
+from blackjax.sgmcmc.gradients import GradientState
 from blackjax.types import PRNGKey, PyTree
 
 __all__ = ["SGLDState", "init", "kernel"]
@@ -11,11 +12,17 @@ class SGLDState(NamedTuple):
     step: int
     position: PyTree
     logprob_grad: PyTree
+    gradient_state: GradientState
+
+
+# We can compute the gradient at the beginning of the kernel step.
+# This lets us get rid of much of the init function, and it
+# avoids one useless gradient computation at the last step.
 
 
 def init(position: PyTree, batch, grad_estimator_fn: Callable):
-    logprob_grad = grad_estimator_fn(position, batch)
-    return SGLDState(0, position, logprob_grad)
+    gradient_state, logprob_grad = grad_estimator_fn(position, batch)
+    return SGLDState(0, position, logprob_grad, gradient_state)
 
 
 def kernel(grad_estimator_fn: Callable) -> Callable:
@@ -23,11 +30,13 @@ def one_step(
         rng_key: PRNGKey, state: SGLDState, data_batch: PyTree, step_size: float
-    ) -> SGLDState:
+    ):
 
         step, *diffusion_state = state
 
-        new_state = integrator(rng_key, diffusion_state, step_size, data_batch)
+        new_state, gradient_state = integrator(
+            rng_key, diffusion_state, step_size, data_batch
+        )
 
-        return SGLDState(step + 1, *new_state)
+        return SGLDState(step + 1, *new_state), gradient_state
 
     return one_step
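
Note on the calling convention this patch moves toward: gradient estimators return a
(gradient_state, logprob_grad) pair, and kernels return (state, gradient_state), so a
quantity such as the control-variate gradient computed once on the full dataset can be
threaded through the sampling loop instead of being recomputed. The following is a
minimal, self-contained sketch of that convention, not blackjax code: toy_grad_estimator
and sgld_like_step are hypothetical stand-ins, and the estimator signature follows the
cv_grad_estimator variant above (the diffusion.py call sites in this WIP patch still pass
only position and batch).

from typing import NamedTuple

import jax
import jax.numpy as jnp


class GradientState(NamedTuple):
    """Carries the control-variate gradient computed once on the full dataset."""

    control_variate_grad: jnp.ndarray


def toy_grad_estimator(gradient_state, position, minibatch):
    # Hypothetical estimator: returns the (unchanged) gradient state together with a
    # stochastic gradient of a Gaussian log-density centred on the minibatch mean.
    # A real control-variate estimator would also use gradient_state here.
    logprob_grad = -(position - jnp.mean(minibatch))
    return gradient_state, logprob_grad


def sgld_like_step(rng_key, position, gradient_state, minibatch, step_size):
    # One overdamped-Langevin step written against the new convention: the estimator
    # returns (gradient_state, logprob_grad) and the step returns
    # (new_position, gradient_state) so the caller can thread the state along.
    gradient_state, logprob_grad = toy_grad_estimator(gradient_state, position, minibatch)
    noise = jax.random.normal(rng_key, jnp.shape(position))
    position = position + step_size * logprob_grad + jnp.sqrt(2 * step_size) * noise
    return position, gradient_state


rng_key = jax.random.PRNGKey(0)
position = jnp.array(0.0)
gradient_state = GradientState(control_variate_grad=jnp.array(0.0))
minibatch = jnp.array([0.5, 1.5, 1.0])

for _ in range(10):
    rng_key, step_key = jax.random.split(rng_key)
    position, gradient_state = sgld_like_step(
        step_key, position, gradient_state, minibatch, step_size=1e-2
    )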