From 750e983f0ed31469044d9cf72f80793d1fedb6e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 19 Sep 2022 12:50:17 +0200 Subject: [PATCH 1/9] Add the control variates gradient estimator --- blackjax/sgmcmc/gradients.py | 72 ++++++++++++++++++++++++++++++++++++ tests/test_sampling.py | 23 ++++++++++++ 2 files changed, 95 insertions(+) diff --git a/blackjax/sgmcmc/gradients.py b/blackjax/sgmcmc/gradients.py index ff5be17ba..7ea7ee9f7 100644 --- a/blackjax/sgmcmc/gradients.py +++ b/blackjax/sgmcmc/gradients.py @@ -59,3 +59,75 @@ def logposterior_estimator_fn(position: PyTree, data_batch: PyTree) -> float: logprob_grad = jax.grad(logposterior_estimator_fn) return logprob_grad + + +def cv_grad_estimator( + logprior_fn: Callable, + loglikelihood_fn: Callable, + data: PyTree, + centering_position: PyTree, +) -> Callable: + """Builds a control variate gradient estimator [1]_. + + Parameters + ---------- + logprior_fn + The log-probability density function corresponding to the prior + distribution. + loglikelihood_fn + The log-probability density function corresponding to the likelihood. + data + The full dataset. + centering_position + Centering position for the control variates (typically the MAP). + + References + ---------- + .. [1]: Baker, J., Fearnhead, P., Fox, E. B., & Nemeth, C. (2019). + Control variates for stochastic gradient MCMC. Statistics + and Computing, 29(3), 599-615. + + """ + data_size = jax.tree_leaves(data)[0].shape[0] + logposterior_grad_estimator_fn = grad_estimator( + logprior_fn, loglikelihood_fn, data_size + ) + + # Control Variates use the gradient on the full dataset + logposterior_grad_center = logposterior_grad_estimator_fn(centering_position, data) + + def logposterior_estimator_fn(position: PyTree, data_batch: PyTree) -> float: + """Return an approximation of the log-posterior density. + + Parameters + ---------- + position + The current value of the random variables. + batch + The current batch of data. The first dimension is assumed to be the + batch dimension. + + Returns + ------- + An approximation of the value of the log-posterior density function for + the current value of the random variables. + + """ + logposterior_grad_estimate = logposterior_grad_estimator_fn( + position, data_batch + ) + logposterior_grad_center_estimate = logposterior_grad_estimator_fn( + centering_position, data_batch + ) + + def control_variate(grad_estimate, center_grad_estimate, center_grad): + return grad_estimate + center_grad - center_grad_estimate + + return jax.tree_util.tree_map( + control_variate, + logposterior_grad_estimate, + logposterior_grad_center, + logposterior_grad_center_estimate, + ) + + return logposterior_estimator_fn diff --git a/tests/test_sampling.py b/tests/test_sampling.py index 8f87e8b04..f6e989055 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -262,6 +262,29 @@ def test_linear_regression_sgld(self, learning_rate, error_expected): data_batch = X_data[100:200, :] _ = sgld.step(rng_key, init_state, data_batch) + def test_linear_regression_sgld_cv(self): + """Test the HMC kernel and the Stan warmup.""" + import blackjax.sgmcmc.gradients + + rng_key, data_key = jax.random.split(self.key, 2) + + data_size = 1000 + X_data = jax.random.normal(data_key, shape=(data_size, 5)) + + centering_position = jnp.ones(5) + grad_fn = blackjax.sgmcmc.gradients.cv_grad_estimator( + self.logprior_fn, self.loglikelihood_fn, X_data, centering_position + ) + + sgld = blackjax.sgld(grad_fn, 1e-3) + init_position = 1.0 + data_batch = X_data[:100, :] + init_state = sgld.init(init_position, data_batch) + + _, rng_key = jax.random.split(rng_key) + data_batch = X_data[100:200, :] + _ = sgld.step(rng_key, init_state, data_batch) + @parameterized.parameters((1e-3, False), (constant_step_size, False), (1, True)) def test_linear_regression_sghmc(self, learning_rate, error_expected): """Test the HMC kernel and the Stan warmup.""" From 7d1c5810708dd5e6d24a6ee25f22514bb46b8943 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 20 Sep 2022 21:00:06 +0200 Subject: [PATCH 2/9] Add GradientState and GradientEstimator data structures --- blackjax/mcmc/diffusion.py | 4 +- blackjax/sgmcmc/diffusion.py | 31 +++++++++------- blackjax/sgmcmc/gradients.py | 72 +++++++++++++++++++++++++----------- blackjax/sgmcmc/sghmc.py | 19 +++++++--- blackjax/sgmcmc/sgld.py | 26 +++++++++---- tests/test_sampling.py | 25 ++++++++++++- 6 files changed, 126 insertions(+), 51 deletions(-) diff --git a/blackjax/mcmc/diffusion.py b/blackjax/mcmc/diffusion.py index e138b6b72..8c80caa4e 100644 --- a/blackjax/mcmc/diffusion.py +++ b/blackjax/mcmc/diffusion.py @@ -16,7 +16,7 @@ class DiffusionState(NamedTuple): logprob_grad: PyTree -def overdamped_langevin(logprob_and_grad_fn): +def overdamped_langevin(logprob_grad_fn): """Euler solver for overdamped Langevin diffusion.""" def one_step(rng_key, state: DiffusionState, step_size: float, batch: tuple = ()): @@ -29,7 +29,7 @@ def one_step(rng_key, state: DiffusionState, step_size: float, batch: tuple = () noise, ) - logprob, logprob_grad = logprob_and_grad_fn(position, *batch) + logprob, logprob_grad = logprob_grad_fn(position, *batch) return DiffusionState(position, logprob, logprob_grad) return one_step diff --git a/blackjax/sgmcmc/diffusion.py b/blackjax/sgmcmc/diffusion.py index 86cefd09d..f14001126 100644 --- a/blackjax/sgmcmc/diffusion.py +++ b/blackjax/sgmcmc/diffusion.py @@ -4,6 +4,7 @@ import jax import jax.numpy as jnp +from blackjax.sgmcmc.gradients import GradientState from blackjax.types import PRNGKey, PyTree from blackjax.util import generate_gaussian_noise @@ -13,15 +14,16 @@ class DiffusionState(NamedTuple): position: PyTree logprob_grad: PyTree + grad_estimator_state: GradientState -def overdamped_langevin(logprob_grad_fn): +def overdamped_langevin(grad_estimator_fn): """Euler solver for overdamped Langevin diffusion.""" def one_step( - rng_key: PRNGKey, state: DiffusionState, step_size: float, batch: tuple = () + rng_key: PRNGKey, state: DiffusionState, step_size: float, minibatch: tuple = () ): - position, logprob_grad = state + position, logprob_grad, grad_estimator_state = state noise = generate_gaussian_noise(rng_key, position) position = jax.tree_util.tree_map( lambda p, g, n: p + step_size * g + jnp.sqrt(2 * step_size) * n, @@ -30,8 +32,10 @@ def one_step( noise, ) - logprob_grad = logprob_grad_fn(position, batch) - return DiffusionState(position, logprob_grad) + logprob_grad, gradient_estimator_state = grad_estimator_fn( + grad_estimator_state, position, minibatch + ) + return DiffusionState(position, logprob_grad, gradient_estimator_state) return one_step @@ -40,9 +44,10 @@ class SGHMCState(NamedTuple): position: PyTree momentum: PyTree logprob_grad: PyTree + grad_estimator_state: GradientState -def sghmc(logprob_grad_fn, alpha: float = 0.01, beta: float = 0): +def sghmc(grad_estimator_fn, alpha: float = 0.01, beta: float = 0): """Solver for the diffusion equation of the SGHMC algorithm [0]_. References @@ -54,9 +59,9 @@ def sghmc(logprob_grad_fn, alpha: float = 0.01, beta: float = 0): """ def one_step( - rng_key: PRNGKey, state: SGHMCState, step_size: float, batch: tuple = () - ) -> SGHMCState: - position, momentum, logprob_grad = state + rng_key: PRNGKey, state: SGHMCState, step_size: float, minibatch: tuple = () + ): + position, momentum, logprob_grad, grad_estimator_state = state noise = generate_gaussian_noise(rng_key, position) position = jax.tree_util.tree_map(lambda x, p: x + p, position, momentum) momentum = jax.tree_util.tree_map( @@ -68,11 +73,9 @@ def one_step( noise, ) - logprob_grad = logprob_grad_fn(position, batch) - return SGHMCState( - position, - momentum, - logprob_grad, + logprob_grad, gradient_estimator_state = grad_estimator_fn( + grad_estimator_state, position, minibatch ) + return SGHMCState(position, momentum, logprob_grad, gradient_estimator_state) return one_step diff --git a/blackjax/sgmcmc/gradients.py b/blackjax/sgmcmc/gradients.py index 7ea7ee9f7..23098b21e 100644 --- a/blackjax/sgmcmc/gradients.py +++ b/blackjax/sgmcmc/gradients.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, NamedTuple, Tuple, Union import jax import jax.numpy as jnp @@ -6,9 +6,14 @@ from blackjax.types import PyTree +class GradientEstimator(NamedTuple): + init: Callable + estimate: Callable + + def grad_estimator( logprior_fn: Callable, loglikelihood_fn: Callable, data_size: int -) -> Callable: +) -> GradientEstimator: """Builds a simple gradient estimator. This estimator first appeared in [1]_. The `logprior_fn` function has a @@ -34,7 +39,12 @@ def grad_estimator( """ - def logposterior_estimator_fn(position: PyTree, data_batch: PyTree) -> float: + def init_fn(_) -> None: + return None + + def logposterior_estimator_fn( + position: PyTree, minibatch: PyTree + ) -> Tuple[PyTree, None]: """Returns an approximation of the log-posterior density. Parameters @@ -53,12 +63,22 @@ def logposterior_estimator_fn(position: PyTree, data_batch: PyTree) -> float: logprior = logprior_fn(position) batch_loglikelihood = jax.vmap(loglikelihood_fn, in_axes=(None, 0)) return logprior + data_size * jnp.mean( - batch_loglikelihood(position, data_batch), axis=0 + batch_loglikelihood(position, minibatch), axis=0 ) - logprob_grad = jax.grad(logposterior_estimator_fn) + def grad_estimator_fn(_, position, data_batch): + return jax.grad(logposterior_estimator_fn)(position, data_batch), None + + return GradientEstimator(init_fn, grad_estimator_fn) - return logprob_grad + +class CVGradientState(NamedTuple): + """The state of the CV gradient estimator contains the gradient of the + Control Variate computed on the whole dataset at initialization. + + """ + + control_variate_grad: PyTree def cv_grad_estimator( @@ -66,7 +86,7 @@ def cv_grad_estimator( loglikelihood_fn: Callable, data: PyTree, centering_position: PyTree, -) -> Callable: +) -> GradientEstimator: """Builds a control variate gradient estimator [1]_. Parameters @@ -91,12 +111,17 @@ def cv_grad_estimator( data_size = jax.tree_leaves(data)[0].shape[0] logposterior_grad_estimator_fn = grad_estimator( logprior_fn, loglikelihood_fn, data_size - ) + ).estimate - # Control Variates use the gradient on the full dataset - logposterior_grad_center = logposterior_grad_estimator_fn(centering_position, data) + def init_fn(full_dataset: PyTree) -> CVGradientState: + """Compute the control variate on the whole dataset.""" + return CVGradientState( + logposterior_grad_estimator_fn(None, centering_position, full_dataset)[0] + ) - def logposterior_estimator_fn(position: PyTree, data_batch: PyTree) -> float: + def grad_estimator_fn( + grad_estimator_state: CVGradientState, position: PyTree, minibatch: PyTree + ) -> Tuple[PyTree, CVGradientState]: """Return an approximation of the log-posterior density. Parameters @@ -114,20 +139,25 @@ def logposterior_estimator_fn(position: PyTree, data_batch: PyTree) -> float: """ logposterior_grad_estimate = logposterior_grad_estimator_fn( - position, data_batch - ) + None, position, minibatch + )[0] logposterior_grad_center_estimate = logposterior_grad_estimator_fn( - centering_position, data_batch - ) + None, centering_position, minibatch + )[0] def control_variate(grad_estimate, center_grad_estimate, center_grad): return grad_estimate + center_grad - center_grad_estimate - return jax.tree_util.tree_map( - control_variate, - logposterior_grad_estimate, - logposterior_grad_center, - logposterior_grad_center_estimate, + return ( + control_variate( + logposterior_grad_estimate, + grad_estimator_state.control_variate_grad, + logposterior_grad_center_estimate, + ), + grad_estimator_state, ) - return logposterior_estimator_fn + return GradientEstimator(init_fn, grad_estimator_fn) + + +GradientState = Union[None, CVGradientState] diff --git a/blackjax/sgmcmc/sghmc.py b/blackjax/sgmcmc/sghmc.py index 5bd455fa0..9eb0ada51 100644 --- a/blackjax/sgmcmc/sghmc.py +++ b/blackjax/sgmcmc/sghmc.py @@ -5,7 +5,9 @@ import jax.numpy as jnp from blackjax.sgmcmc.diffusion import SGHMCState, sghmc +from blackjax.sgmcmc.gradients import GradientEstimator from blackjax.sgmcmc.sgld import SGLDState +from blackjax.sgmcmc.sgld import init as sgld_init from blackjax.types import PRNGKey, PyTree from blackjax.util import generate_gaussian_noise @@ -13,25 +15,30 @@ def kernel( - grad_estimator_fn: Callable, alpha: float = 0.01, beta: float = 0 + gradient_estimator: GradientEstimator, alpha: float = 0.01, beta: float = 0 ) -> Callable: + + grad_estimator_fn = gradient_estimator.estimate integrator = sghmc(grad_estimator_fn) def one_step( - rng_key: PRNGKey, state: SGLDState, data_batch: PyTree, step_size: float, L: int + rng_key: PRNGKey, state: SGLDState, minibatch: PyTree, step_size: float, L: int ) -> SGLDState: - step, position, logprob_grad = state + step, position, logprob_grad, grad_estimator_state = state momentum = generate_gaussian_noise(rng_key, position, jnp.sqrt(step_size)) - diffusion_state = SGHMCState(position, momentum, logprob_grad) + diffusion_state = SGHMCState( + position, momentum, logprob_grad, grad_estimator_state + ) def body_fn(state, rng_key): - new_state = integrator(rng_key, state, step_size, data_batch) + new_state = integrator(rng_key, state, step_size, minibatch) return new_state, new_state keys = jax.random.split(rng_key, L) last_state, _ = jax.lax.scan(body_fn, diffusion_state, keys) + position, _, logprob_grad, grad_estimator_state = last_state - return SGLDState(step + 1, last_state.position, last_state.logprob_grad) + return SGLDState(step + 1, position, logprob_grad, grad_estimator_state) return one_step diff --git a/blackjax/sgmcmc/sgld.py b/blackjax/sgmcmc/sgld.py index 41d53b882..293745b67 100644 --- a/blackjax/sgmcmc/sgld.py +++ b/blackjax/sgmcmc/sgld.py @@ -2,6 +2,7 @@ from typing import Callable, NamedTuple from blackjax.sgmcmc.diffusion import overdamped_langevin +from blackjax.sgmcmc.gradients import GradientEstimator, GradientState from blackjax.types import PRNGKey, PyTree __all__ = ["SGLDState", "init", "kernel"] @@ -11,22 +12,33 @@ class SGLDState(NamedTuple): step: int position: PyTree logprob_grad: PyTree + grad_estimator_state: GradientState -def init(position: PyTree, batch, grad_estimator_fn: Callable): - logprob_grad = grad_estimator_fn(position, batch) - return SGLDState(0, position, logprob_grad) +# We can compute the gradient at the begining of the kernel step +# This allows to get rid of much of the init function, AND +# Prevents a last useless gradient computation at the last step -def kernel(grad_estimator_fn: Callable) -> Callable: +def init(position: PyTree, minibatch, gradient_estimator: GradientEstimator): + grad_estimator_state = gradient_estimator.init(minibatch) + logprob_grad, grad_estimator_state = gradient_estimator.estimate( + grad_estimator_state, position, minibatch + ) + return SGLDState(0, position, logprob_grad, grad_estimator_state) + + +def kernel(gradient_estimator: GradientEstimator) -> Callable: + + grad_estimator_fn = gradient_estimator.estimate integrator = overdamped_langevin(grad_estimator_fn) def one_step( - rng_key: PRNGKey, state: SGLDState, data_batch: PyTree, step_size: float - ) -> SGLDState: + rng_key: PRNGKey, state: SGLDState, minibatch: PyTree, step_size: float + ): step, *diffusion_state = state - new_state = integrator(rng_key, diffusion_state, step_size, data_batch) + new_state = integrator(rng_key, diffusion_state, step_size, minibatch) return SGLDState(step + 1, *new_state) diff --git a/tests/test_sampling.py b/tests/test_sampling.py index f6e989055..8b9c9a432 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -271,7 +271,7 @@ def test_linear_regression_sgld_cv(self): data_size = 1000 X_data = jax.random.normal(data_key, shape=(data_size, 5)) - centering_position = jnp.ones(5) + centering_position = 1.0 grad_fn = blackjax.sgmcmc.gradients.cv_grad_estimator( self.logprior_fn, self.loglikelihood_fn, X_data, centering_position ) @@ -313,6 +313,29 @@ def test_linear_regression_sghmc(self, learning_rate, error_expected): data_batch = X_data[100:200, :] _ = sghmc.step(rng_key, init_state, data_batch) + def test_linear_regression_sghmc_cv(self): + """Test the HMC kernel and the Stan warmup.""" + import blackjax.sgmcmc.gradients + + rng_key, data_key = jax.random.split(self.key, 2) + + data_size = 1000 + X_data = jax.random.normal(data_key, shape=(data_size, 5)) + + centering_position = 1.0 + grad_fn = blackjax.sgmcmc.gradients.cv_grad_estimator( + self.logprior_fn, self.loglikelihood_fn, X_data, centering_position + ) + + sghmc = blackjax.sghmc(grad_fn, 1e-3, 10) + init_position = 1.0 + data_batch = X_data[:100, :] + init_state = sghmc.init(init_position, data_batch) + + _, rng_key = jax.random.split(rng_key) + data_batch = X_data[100:200, :] + _ = sghmc.step(rng_key, init_state, data_batch) + class LatentGaussianTest(chex.TestCase): """Test sampling of a linear regression model.""" From 31b94b913f9844f7d5a7943144487e0d70bcb32f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 7 Oct 2022 15:23:11 +0200 Subject: [PATCH 3/9] Move gradient computation outside of diffusion integrators This simplifies the solvers a lot. --- blackjax/sgmcmc/diffusion.py | 48 +++++++++++++----------------------- blackjax/sgmcmc/sghmc.py | 37 ++++++++++++++------------- blackjax/sgmcmc/sgld.py | 18 ++++++-------- 3 files changed, 45 insertions(+), 58 deletions(-) diff --git a/blackjax/sgmcmc/diffusion.py b/blackjax/sgmcmc/diffusion.py index f14001126..a85cbce46 100644 --- a/blackjax/sgmcmc/diffusion.py +++ b/blackjax/sgmcmc/diffusion.py @@ -1,29 +1,24 @@ """Solvers for Langevin diffusions.""" -from typing import NamedTuple - import jax import jax.numpy as jnp -from blackjax.sgmcmc.gradients import GradientState from blackjax.types import PRNGKey, PyTree from blackjax.util import generate_gaussian_noise __all__ = ["overdamped_langevin"] -class DiffusionState(NamedTuple): - position: PyTree - logprob_grad: PyTree - grad_estimator_state: GradientState - - -def overdamped_langevin(grad_estimator_fn): +def overdamped_langevin(): """Euler solver for overdamped Langevin diffusion.""" def one_step( - rng_key: PRNGKey, state: DiffusionState, step_size: float, minibatch: tuple = () - ): - position, logprob_grad, grad_estimator_state = state + rng_key: PRNGKey, + position: PyTree, + logprob_grad: PyTree, + step_size: float, + minibatch: tuple = (), + ) -> PyTree: + noise = generate_gaussian_noise(rng_key, position) position = jax.tree_util.tree_map( lambda p, g, n: p + step_size * g + jnp.sqrt(2 * step_size) * n, @@ -32,22 +27,12 @@ def one_step( noise, ) - logprob_grad, gradient_estimator_state = grad_estimator_fn( - grad_estimator_state, position, minibatch - ) - return DiffusionState(position, logprob_grad, gradient_estimator_state) + return position return one_step -class SGHMCState(NamedTuple): - position: PyTree - momentum: PyTree - logprob_grad: PyTree - grad_estimator_state: GradientState - - -def sghmc(grad_estimator_fn, alpha: float = 0.01, beta: float = 0): +def sghmc(alpha: float = 0.01, beta: float = 0): """Solver for the diffusion equation of the SGHMC algorithm [0]_. References @@ -59,9 +44,13 @@ def sghmc(grad_estimator_fn, alpha: float = 0.01, beta: float = 0): """ def one_step( - rng_key: PRNGKey, state: SGHMCState, step_size: float, minibatch: tuple = () + rng_key: PRNGKey, + position: PyTree, + momentum: PyTree, + logprob_grad: PyTree, + step_size: float, + minibatch: tuple = (), ): - position, momentum, logprob_grad, grad_estimator_state = state noise = generate_gaussian_noise(rng_key, position) position = jax.tree_util.tree_map(lambda x, p: x + p, position, momentum) momentum = jax.tree_util.tree_map( @@ -73,9 +62,6 @@ def one_step( noise, ) - logprob_grad, gradient_estimator_state = grad_estimator_fn( - grad_estimator_state, position, minibatch - ) - return SGHMCState(position, momentum, logprob_grad, gradient_estimator_state) + return position, momentum return one_step diff --git a/blackjax/sgmcmc/sghmc.py b/blackjax/sgmcmc/sghmc.py index 9eb0ada51..ce658abef 100644 --- a/blackjax/sgmcmc/sghmc.py +++ b/blackjax/sgmcmc/sghmc.py @@ -2,12 +2,10 @@ from typing import Callable import jax -import jax.numpy as jnp -from blackjax.sgmcmc.diffusion import SGHMCState, sghmc +from blackjax.sgmcmc.diffusion import sghmc from blackjax.sgmcmc.gradients import GradientEstimator from blackjax.sgmcmc.sgld import SGLDState -from blackjax.sgmcmc.sgld import init as sgld_init from blackjax.types import PRNGKey, PyTree from blackjax.util import generate_gaussian_noise @@ -18,27 +16,32 @@ def kernel( gradient_estimator: GradientEstimator, alpha: float = 0.01, beta: float = 0 ) -> Callable: - grad_estimator_fn = gradient_estimator.estimate - integrator = sghmc(grad_estimator_fn) + integrator = sghmc(alpha, beta) def one_step( rng_key: PRNGKey, state: SGLDState, minibatch: PyTree, step_size: float, L: int ) -> SGLDState: - - step, position, logprob_grad, grad_estimator_state = state - momentum = generate_gaussian_noise(rng_key, position, jnp.sqrt(step_size)) - diffusion_state = SGHMCState( - position, momentum, logprob_grad, grad_estimator_state - ) - def body_fn(state, rng_key): - new_state = integrator(rng_key, state, step_size, minibatch) - return new_state, new_state + position, momentum, grad_estimator_state = state + logprob_grad, grad_estimator_state = gradient_estimator.estimate( + grad_estimator_state, position, minibatch + ) + position, momentum = integrator( + rng_key, position, momentum, logprob_grad, step_size, minibatch + ) + return ( + (position, momentum, grad_estimator_state), + (position, grad_estimator_state), + ) + + step, position, grad_estimator_state = state + momentum = generate_gaussian_noise(rng_key, position, step_size) + init_diffusion_state = (position, momentum, grad_estimator_state) keys = jax.random.split(rng_key, L) - last_state, _ = jax.lax.scan(body_fn, diffusion_state, keys) - position, _, logprob_grad, grad_estimator_state = last_state + last_state, _ = jax.lax.scan(body_fn, init_diffusion_state, keys) + position, _, grad_estimator_state = last_state - return SGLDState(step + 1, position, logprob_grad, grad_estimator_state) + return SGLDState(step + 1, position, grad_estimator_state) return one_step diff --git a/blackjax/sgmcmc/sgld.py b/blackjax/sgmcmc/sgld.py index 293745b67..8f5546bf8 100644 --- a/blackjax/sgmcmc/sgld.py +++ b/blackjax/sgmcmc/sgld.py @@ -11,7 +11,6 @@ class SGLDState(NamedTuple): step: int position: PyTree - logprob_grad: PyTree grad_estimator_state: GradientState @@ -22,24 +21,23 @@ class SGLDState(NamedTuple): def init(position: PyTree, minibatch, gradient_estimator: GradientEstimator): grad_estimator_state = gradient_estimator.init(minibatch) - logprob_grad, grad_estimator_state = gradient_estimator.estimate( - grad_estimator_state, position, minibatch - ) - return SGLDState(0, position, logprob_grad, grad_estimator_state) + return SGLDState(0, position, grad_estimator_state) def kernel(gradient_estimator: GradientEstimator) -> Callable: - grad_estimator_fn = gradient_estimator.estimate - integrator = overdamped_langevin(grad_estimator_fn) + integrator = overdamped_langevin() def one_step( rng_key: PRNGKey, state: SGLDState, minibatch: PyTree, step_size: float ): - step, *diffusion_state = state - new_state = integrator(rng_key, diffusion_state, step_size, minibatch) + step, position, grad_estimator_state = state + logprob_grad, grad_estimator_state = gradient_estimator.estimate( + grad_estimator_state, position, minibatch + ) + new_position = integrator(rng_key, position, logprob_grad, step_size, minibatch) - return SGLDState(step + 1, *new_state) + return SGLDState(step + 1, new_position, grad_estimator_state) return one_step From 2d1b754365dcaf9c17f9d76bce90665cd83e7abe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 7 Oct 2022 15:26:31 +0200 Subject: [PATCH 4/9] Add references to SGMCMCJAX --- blackjax/sgmcmc/diffusion.py | 16 +++++++++++++++- blackjax/sgmcmc/gradients.py | 11 +++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/blackjax/sgmcmc/diffusion.py b/blackjax/sgmcmc/diffusion.py index a85cbce46..46aa14c5b 100644 --- a/blackjax/sgmcmc/diffusion.py +++ b/blackjax/sgmcmc/diffusion.py @@ -9,7 +9,16 @@ def overdamped_langevin(): - """Euler solver for overdamped Langevin diffusion.""" + """Euler solver for overdamped Langevin diffusion. + + This algorithm was ported from [0]_. + + References + ---------- + .. [0]: Coullon, J., & Nemeth, C. (2022). SGMCMCJax: a lightweight JAX + library for stochastic gradient Markov chain Monte Carlo algorithms. + Journal of Open Source Software, 7(72), 4113. + """ def one_step( rng_key: PRNGKey, @@ -35,11 +44,16 @@ def one_step( def sghmc(alpha: float = 0.01, beta: float = 0): """Solver for the diffusion equation of the SGHMC algorithm [0]_. + This algorithm was ported from [1]_. + References ---------- .. [0]: Chen, T., Fox, E., & Guestrin, C. (2014, June). Stochastic gradient hamiltonian monte carlo. In International conference on machine learning (pp. 1683-1691). PMLR. + .. [1]: Coullon, J., & Nemeth, C. (2022). SGMCMCJax: a lightweight JAX + library for stochastic gradient Markov chain Monte Carlo algorithms. + Journal of Open Source Software, 7(72), 4113. """ diff --git a/blackjax/sgmcmc/gradients.py b/blackjax/sgmcmc/gradients.py index 23098b21e..e29032825 100644 --- a/blackjax/sgmcmc/gradients.py +++ b/blackjax/sgmcmc/gradients.py @@ -22,6 +22,8 @@ def grad_estimator( data; if there are several variables (as, for instance, in a supervised learning contexts), they are passed in a tuple. + This algorithm was ported from [2]_. + Parameters ---------- logprior_fn @@ -36,6 +38,10 @@ def grad_estimator( ---------- .. [1]: Robbins H. and Monro S. A stochastic approximation method. Annals of Mathematical Statistics, 22(30):400-407, 1951. + .. [2]: Coullon, J., & Nemeth, C. (2022). SGMCMCJax: a lightweight JAX + library for stochastic gradient Markov chain Monte Carlo algorithms. + Journal of Open Source Software, 7(72), 4113. + """ @@ -89,6 +95,8 @@ def cv_grad_estimator( ) -> GradientEstimator: """Builds a control variate gradient estimator [1]_. + This algorithm was ported from [2]_. + Parameters ---------- logprior_fn @@ -106,6 +114,9 @@ def cv_grad_estimator( .. [1]: Baker, J., Fearnhead, P., Fox, E. B., & Nemeth, C. (2019). Control variates for stochastic gradient MCMC. Statistics and Computing, 29(3), 599-615. + .. [2]: Coullon, J., & Nemeth, C. (2022). SGMCMCJax: a lightweight JAX + library for stochastic gradient Markov chain Monte Carlo algorithms. + Journal of Open Source Software, 7(72), 4113. """ data_size = jax.tree_leaves(data)[0].shape[0] From c176290646de82c1b05a5f58e14cf35e98ec66fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 7 Oct 2022 15:52:16 +0200 Subject: [PATCH 5/9] Remove `schedule_fn` interface from the high-level interface This is impractical in practice. --- blackjax/kernels.py | 110 +++++++++++++++------------------------ blackjax/sgmcmc/sghmc.py | 12 +++-- blackjax/sgmcmc/sgld.py | 12 ++--- examples/SGMCMC.md | 8 +-- tests/test_sampling.py | 35 ++++--------- 5 files changed, 66 insertions(+), 111 deletions(-) diff --git a/blackjax/kernels.py b/blackjax/kernels.py index b45fed624..678c390e3 100644 --- a/blackjax/kernels.py +++ b/blackjax/kernels.py @@ -462,10 +462,10 @@ def step_fn(rng_key: PRNGKey, state, delta: float): class sgld: """Implements the (basic) user interface for the SGLD kernel. - The general sgld kernel (:meth:`blackjax.mcmc.sgld.kernel`, alias `blackjax.sgld.kernel`) can be - cumbersome to manipulate. Since most users only need to specify the kernel - parameters at initialization time, we provide a helper function that - specializes the general kernel. + The general sgld kernel (:meth:`blackjax.mcmc.sgld.kernel`, alias + `blackjax.sgld.kernel`) can be cumbersome to manipulate. Since most users + only need to specify the kernel parameters at initialization time, we + provide a helper function that specializes the general kernel. Example ------- @@ -476,35 +476,36 @@ class sgld: .. code:: - schedule_fn = lambda _: 1e-3 grad_fn = blackjax.sgmcmc.gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size) We can now initialize the sgld kernel and the state: .. code:: - sgld = blackjax.sgld(grad_fn, schedule_fn) + sgld = blackjax.sgld(grad_fn) state = sgld.init(position) - Assuming we have an iterator `batches` that yields batches of data we can perform one step: + Assuming we have an iterator `batches` that yields batches of data we can + perform one step: .. code:: - data_batch = next(batches) - new_state = sgld.step(rng_key, state, data_batch) + step_size = 1e-3 + minibatch = next(batches) + new_state = sgld.step(rng_key, state, minibatch, step_size) Kernels are not jit-compiled by default so you will need to do it manually: .. code:: step = jax.jit(sgld.step) - new_state, info = step(rng_key, state) + new_state, info = step(rng_key, state, minibatch, step_size) Parameters ---------- - gradient_estimator_fn - A function which, given a position and a batch of data, returns an estimation - of the value of the gradient of the log-posterior distribution at this position. + gradient_estimator + A tuple of functions that initialize and update the gradient estimation + state. schedule_fn A function which returns a step size given a step number. @@ -519,31 +520,16 @@ class sgld: def __new__( # type: ignore[misc] cls, - grad_estimator_fn: Callable, - learning_rate: Union[Callable[[int], float], float], + grad_estimator: sgmcmc.gradients.GradientEstimator, ) -> MCMCSamplingAlgorithm: - step = cls.kernel(grad_estimator_fn) - - if callable(learning_rate): - learning_rate_fn = learning_rate - elif float(learning_rate): + step = cls.kernel(grad_estimator) - def learning_rate_fn(_): - return learning_rate + def init_fn(position: PyTree, minibatch: PyTree): + return cls.init(position, minibatch, grad_estimator) - else: - raise TypeError( - "The learning rate must either be a float (which corresponds to a constant learning rate) " - f"or a function of the index of the current iteration. Got {type(learning_rate)} instead." - ) - - def init_fn(position: PyTree, data_batch: PyTree): - return cls.init(position, data_batch, grad_estimator_fn) - - def step_fn(rng_key: PRNGKey, state, data_batch: PyTree): - step_size = learning_rate_fn(state.step) - return step(rng_key, state, data_batch, step_size) + def step_fn(rng_key: PRNGKey, state, minibatch: PyTree, step_size: float): + return step(rng_key, state, minibatch, step_size) return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] @@ -551,10 +537,10 @@ def step_fn(rng_key: PRNGKey, state, data_batch: PyTree): class sghmc: """Implements the (basic) user interface for the SGHMC kernel. - The general sghmc kernel (:meth:`blackjax.mcmc.sghmc.kernel`, alias `blackjax.sghmc.kernel`) can be - cumbersome to manipulate. Since most users only need to specify the kernel - parameters at initialization time, we provide a helper function that - specializes the general kernel. + The general sghmc kernel (:meth:`blackjax.mcmc.sghmc.kernel`, alias + `blackjax.sghmc.kernel`) can be cumbersome to manipulate. Since most users + only need to specify the kernel parameters at initialization time, we + provide a helper function that specializes the general kernel. Example ------- @@ -565,35 +551,36 @@ class sghmc: .. code:: - schedule_fn = lambda _: 1e-3 - grad_fn = blackjax.sgmcmc.gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size) + grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size) We can now initialize the sghmc kernel and the state. Like HMC, SGHMC needs the user to specify a number of integration steps. .. code:: - sghmc = blackjax.sghmc(grad_fn, schedule_fn, num_integration_steps) + sghmc = blackjax.sghmc(grad_estimator, num_integration_steps) state = sghmc.init(position) - Assuming we have an iterator `batches` that yields batches of data we can perform one step: + Assuming we have an iterator `batches` that yields batches of data we can + perform one step: .. code:: - data_batch = next(batches) - new_state = sghmc.step(rng_key, state, data_batch) + step_size = 1e-3 + minibatch = next(batches) + new_state = sghmc.step(rng_key, state, minibatch, step_size) Kernels are not jit-compiled by default so you will need to do it manually: .. code:: step = jax.jit(sghmc.step) - new_state, info = step(rng_key, state) + new_state, info = step(rng_key, state, minibatch, step_size) Parameters ---------- - gradient_estimator_fn - A function which, given a position and a batch of data, returns an estimation - of the value of the gradient of the log-posterior distribution at this position. + gradient_estimator + A tuple of functions that initialize and update the gradient estimation + state. schedule_fn A function which returns a step size given a step number. @@ -608,32 +595,17 @@ class sghmc: def __new__( # type: ignore[misc] cls, - grad_estimator_fn: Callable, - learning_rate: Union[Callable[[int], float], float], + grad_estimator: sgmcmc.gradients.GradientEstimator, num_integration_steps: int = 10, ) -> MCMCSamplingAlgorithm: - step = cls.kernel(grad_estimator_fn) - - if callable(learning_rate): - learning_rate_fn = learning_rate - elif float(learning_rate): - - def learning_rate_fn(_): - return learning_rate - - else: - raise TypeError( - "The learning rate must either be a float (which corresponds to a constant learning rate) " - f"or a function of the index of the current iteration. Got {type(learning_rate)} instead." - ) + step = cls.kernel(grad_estimator) - def init_fn(position: PyTree, data_batch: PyTree): - return cls.init(position, data_batch, grad_estimator_fn) + def init_fn(position: PyTree, minibatch: PyTree): + return cls.init(position, minibatch, grad_estimator) - def step_fn(rng_key: PRNGKey, state, data_batch: PyTree): - step_size = learning_rate_fn(state.step) - return step(rng_key, state, data_batch, step_size, num_integration_steps) + def step_fn(rng_key: PRNGKey, state, minibatch: PyTree, step_size: float): + return step(rng_key, state, minibatch, step_size, num_integration_steps) return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/blackjax/sgmcmc/sghmc.py b/blackjax/sgmcmc/sghmc.py index ce658abef..fa4853f43 100644 --- a/blackjax/sgmcmc/sghmc.py +++ b/blackjax/sgmcmc/sghmc.py @@ -19,7 +19,11 @@ def kernel( integrator = sghmc(alpha, beta) def one_step( - rng_key: PRNGKey, state: SGLDState, minibatch: PyTree, step_size: float, L: int + rng_key: PRNGKey, + state: SGLDState, + minibatch: PyTree, + step_size: float, + num_integration_steps: int, ) -> SGLDState: def body_fn(state, rng_key): position, momentum, grad_estimator_state = state @@ -34,14 +38,14 @@ def body_fn(state, rng_key): (position, grad_estimator_state), ) - step, position, grad_estimator_state = state + position, grad_estimator_state = state momentum = generate_gaussian_noise(rng_key, position, step_size) init_diffusion_state = (position, momentum, grad_estimator_state) - keys = jax.random.split(rng_key, L) + keys = jax.random.split(rng_key, num_integration_steps) last_state, _ = jax.lax.scan(body_fn, init_diffusion_state, keys) position, _, grad_estimator_state = last_state - return SGLDState(step + 1, position, grad_estimator_state) + return SGLDState(position, grad_estimator_state) return one_step diff --git a/blackjax/sgmcmc/sgld.py b/blackjax/sgmcmc/sgld.py index 8f5546bf8..ac3bca239 100644 --- a/blackjax/sgmcmc/sgld.py +++ b/blackjax/sgmcmc/sgld.py @@ -9,19 +9,13 @@ class SGLDState(NamedTuple): - step: int position: PyTree grad_estimator_state: GradientState -# We can compute the gradient at the begining of the kernel step -# This allows to get rid of much of the init function, AND -# Prevents a last useless gradient computation at the last step - - def init(position: PyTree, minibatch, gradient_estimator: GradientEstimator): grad_estimator_state = gradient_estimator.init(minibatch) - return SGLDState(0, position, grad_estimator_state) + return SGLDState(position, grad_estimator_state) def kernel(gradient_estimator: GradientEstimator) -> Callable: @@ -32,12 +26,12 @@ def one_step( rng_key: PRNGKey, state: SGLDState, minibatch: PyTree, step_size: float ): - step, position, grad_estimator_state = state + position, grad_estimator_state = state logprob_grad, grad_estimator_state = gradient_estimator.estimate( grad_estimator_state, position, minibatch ) new_position = integrator(rng_key, position, logprob_grad, step_size, minibatch) - return SGLDState(step + 1, new_position, grad_estimator_state) + return SGLDState(new_position, grad_estimator_state) return one_step diff --git a/examples/SGMCMC.md b/examples/SGMCMC.md index 77bd930cd..5e6380444 100644 --- a/examples/SGMCMC.md +++ b/examples/SGMCMC.md @@ -161,7 +161,7 @@ init_positions = jax.jit(model.init)(rng_key, jnp.ones(X_train.shape[-1])) # Build the SGLD kernel with a constant learning rate grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data_size) -sgld = blackjax.sgld(grad_fn, lambda _: step_size) +sgld = blackjax.sgld(grad_fn) state = sgld.init(init_positions, next(batches)) @@ -172,7 +172,7 @@ steps = [] for step in progress_bar(range(num_samples + num_warmup)): _, rng_key = jax.random.split(rng_key) batch = next(batches) - state = jax.jit(sgld.step)(rng_key, state, batch) + state = jax.jit(sgld.step)(rng_key, state, batch, step_size) if step % 100 == 0: accuracy = compute_accuracy(state.position, X_test, y_test) accuracies.append(accuracy) @@ -208,7 +208,7 @@ We can also use SGHMC to samples from this model # Build the SGHMC kernel with a constant learning rate step_size = 9e-6 grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data_size) -sghmc = blackjax.sghmc(grad_fn, lambda _: step_size) +sghmc = blackjax.sghmc(grad_fn) # Batch the data state = sghmc.init(init_positions, next(batches)) @@ -220,7 +220,7 @@ steps = [] for step in progress_bar(range(num_samples + num_warmup)): _, rng_key = jax.random.split(rng_key) batch = next(batches) - state = jax.jit(sghmc.step)(rng_key, state, batch) + state = jax.jit(sghmc.step)(rng_key, state, batch, step_size) if step % 100 == 0: sghmc_accuracy = compute_accuracy(state.position, X_test, y_test) sghmc_accuracies.append(sghmc_accuracy) diff --git a/tests/test_sampling.py b/tests/test_sampling.py index 8b9c9a432..435f3c8f4 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -235,9 +235,7 @@ def loglikelihood_fn(self, position, x): def constant_step_size(_): return 1e-3 - @parameterized.parameters((1e-3, False), (constant_step_size, False), (1, True)) - def test_linear_regression_sgld(self, learning_rate, error_expected): - """Test the HMC kernel and the Stan warmup.""" + def test_linear_regression_sgld(self): import blackjax.sgmcmc.gradients rng_key, data_key = jax.random.split(self.key, 2) @@ -248,22 +246,17 @@ def test_linear_regression_sgld(self, learning_rate, error_expected): grad_fn = blackjax.sgmcmc.gradients.grad_estimator( self.logprior_fn, self.loglikelihood_fn, data_size ) + sgld = blackjax.sgld(grad_fn) - if error_expected: - self.assertRaises(TypeError, blackjax.sgld(grad_fn, learning_rate)) - return - - sgld = blackjax.sgld(grad_fn, learning_rate) init_position = 1.0 data_batch = X_data[:100, :] init_state = sgld.init(init_position, data_batch) _, rng_key = jax.random.split(rng_key) data_batch = X_data[100:200, :] - _ = sgld.step(rng_key, init_state, data_batch) + _ = sgld.step(rng_key, init_state, data_batch, 1e-3) def test_linear_regression_sgld_cv(self): - """Test the HMC kernel and the Stan warmup.""" import blackjax.sgmcmc.gradients rng_key, data_key = jax.random.split(self.key, 2) @@ -276,18 +269,16 @@ def test_linear_regression_sgld_cv(self): self.logprior_fn, self.loglikelihood_fn, X_data, centering_position ) - sgld = blackjax.sgld(grad_fn, 1e-3) + sgld = blackjax.sgld(grad_fn) init_position = 1.0 data_batch = X_data[:100, :] init_state = sgld.init(init_position, data_batch) _, rng_key = jax.random.split(rng_key) data_batch = X_data[100:200, :] - _ = sgld.step(rng_key, init_state, data_batch) + _ = sgld.step(rng_key, init_state, data_batch, 1e-3) - @parameterized.parameters((1e-3, False), (constant_step_size, False), (1, True)) - def test_linear_regression_sghmc(self, learning_rate, error_expected): - """Test the HMC kernel and the Stan warmup.""" + def test_linear_regression_sghmc(self): import blackjax.sgmcmc.gradients rng_key, data_key = jax.random.split(self.key, 2) @@ -298,12 +289,7 @@ def test_linear_regression_sghmc(self, learning_rate, error_expected): grad_fn = blackjax.sgmcmc.gradients.grad_estimator( self.logprior_fn, self.loglikelihood_fn, data_size ) - - if error_expected: - self.assertRaises(TypeError, blackjax.sgld(grad_fn, learning_rate)) - return - - sghmc = blackjax.sghmc(grad_fn, learning_rate, 10) + sghmc = blackjax.sghmc(grad_fn, 10) init_position = 1.0 data_batch = X_data[:100, :] @@ -311,10 +297,9 @@ def test_linear_regression_sghmc(self, learning_rate, error_expected): _, rng_key = jax.random.split(rng_key) data_batch = X_data[100:200, :] - _ = sghmc.step(rng_key, init_state, data_batch) + _ = sghmc.step(rng_key, init_state, data_batch, 1e-3) def test_linear_regression_sghmc_cv(self): - """Test the HMC kernel and the Stan warmup.""" import blackjax.sgmcmc.gradients rng_key, data_key = jax.random.split(self.key, 2) @@ -327,14 +312,14 @@ def test_linear_regression_sghmc_cv(self): self.logprior_fn, self.loglikelihood_fn, X_data, centering_position ) - sghmc = blackjax.sghmc(grad_fn, 1e-3, 10) + sghmc = blackjax.sghmc(grad_fn, 10) init_position = 1.0 data_batch = X_data[:100, :] init_state = sghmc.init(init_position, data_batch) _, rng_key = jax.random.split(rng_key) data_batch = X_data[100:200, :] - _ = sghmc.step(rng_key, init_state, data_batch) + _ = sghmc.step(rng_key, init_state, data_batch, 1e-3) class LatentGaussianTest(chex.TestCase): From fb02c8ca10debf86c318d3af7199ebfac96cf88c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Sat, 22 Oct 2022 16:34:12 +0200 Subject: [PATCH 6/9] Build Control Variates by wrapping the Robbins-Monro estimator --- blackjax/kernels.py | 33 +++++++------- blackjax/sgmcmc/__init__.py | 4 +- blackjax/sgmcmc/gradients.py | 87 +++++++++--------------------------- blackjax/sgmcmc/sghmc.py | 31 +++++-------- blackjax/sgmcmc/sgld.py | 31 +++++-------- tests/test_sampling.py | 56 +++++++++++------------ 6 files changed, 88 insertions(+), 154 deletions(-) diff --git a/blackjax/kernels.py b/blackjax/kernels.py index 678c390e3..c5f12815a 100644 --- a/blackjax/kernels.py +++ b/blackjax/kernels.py @@ -515,23 +515,19 @@ class sgld: """ - init = staticmethod(sgmcmc.sgld.init) kernel = staticmethod(sgmcmc.sgld.kernel) def __new__( # type: ignore[misc] cls, grad_estimator: sgmcmc.gradients.GradientEstimator, - ) -> MCMCSamplingAlgorithm: - - step = cls.kernel(grad_estimator) + ) -> Callable: - def init_fn(position: PyTree, minibatch: PyTree): - return cls.init(position, minibatch, grad_estimator) + step = cls.kernel() def step_fn(rng_key: PRNGKey, state, minibatch: PyTree, step_size: float): - return step(rng_key, state, minibatch, step_size) + return step(rng_key, state, grad_estimator, minibatch, step_size) - return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return step_fn class sghmc: @@ -590,24 +586,27 @@ class sghmc: """ - init = staticmethod(sgmcmc.sgld.init) kernel = staticmethod(sgmcmc.sghmc.kernel) def __new__( # type: ignore[misc] cls, - grad_estimator: sgmcmc.gradients.GradientEstimator, + grad_estimator: Callable, num_integration_steps: int = 10, - ) -> MCMCSamplingAlgorithm: + ) -> Callable: - step = cls.kernel(grad_estimator) - - def init_fn(position: PyTree, minibatch: PyTree): - return cls.init(position, minibatch, grad_estimator) + step = cls.kernel() def step_fn(rng_key: PRNGKey, state, minibatch: PyTree, step_size: float): - return step(rng_key, state, minibatch, step_size, num_integration_steps) + return step( + rng_key, + state, + grad_estimator, + minibatch, + step_size, + num_integration_steps, + ) - return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] + return step_fn # ----------------------------------------------------------------------------- diff --git a/blackjax/sgmcmc/__init__.py b/blackjax/sgmcmc/__init__.py index 0f0046c00..40a748c13 100644 --- a/blackjax/sgmcmc/__init__.py +++ b/blackjax/sgmcmc/__init__.py @@ -1,3 +1,3 @@ -from . import sghmc, sgld +from . import gradients, sghmc, sgld -__all__ = ["sgld", "sghmc"] +__all__ = ["gradients", "sgld", "sghmc"] diff --git a/blackjax/sgmcmc/gradients.py b/blackjax/sgmcmc/gradients.py index e29032825..73e65a211 100644 --- a/blackjax/sgmcmc/gradients.py +++ b/blackjax/sgmcmc/gradients.py @@ -1,4 +1,4 @@ -from typing import Callable, NamedTuple, Tuple, Union +from typing import Callable, NamedTuple import jax import jax.numpy as jnp @@ -11,7 +11,7 @@ class GradientEstimator(NamedTuple): estimate: Callable -def grad_estimator( +def estimator( logprior_fn: Callable, loglikelihood_fn: Callable, data_size: int ) -> GradientEstimator: """Builds a simple gradient estimator. @@ -45,12 +45,7 @@ def grad_estimator( """ - def init_fn(_) -> None: - return None - - def logposterior_estimator_fn( - position: PyTree, minibatch: PyTree - ) -> Tuple[PyTree, None]: + def logposterior_estimator_fn(position: PyTree, minibatch: PyTree) -> PyTree: """Returns an approximation of the log-posterior density. Parameters @@ -72,38 +67,22 @@ def logposterior_estimator_fn( batch_loglikelihood(position, minibatch), axis=0 ) - def grad_estimator_fn(_, position, data_batch): - return jax.grad(logposterior_estimator_fn)(position, data_batch), None - - return GradientEstimator(init_fn, grad_estimator_fn) - - -class CVGradientState(NamedTuple): - """The state of the CV gradient estimator contains the gradient of the - Control Variate computed on the whole dataset at initialization. - - """ - - control_variate_grad: PyTree + return jax.grad(logposterior_estimator_fn) -def cv_grad_estimator( - logprior_fn: Callable, - loglikelihood_fn: Callable, - data: PyTree, +def control_variates( + grad_estimator: Callable, centering_position: PyTree, -) -> GradientEstimator: + data: PyTree, +) -> Callable: """Builds a control variate gradient estimator [1]_. This algorithm was ported from [2]_. Parameters ---------- - logprior_fn - The log-probability density function corresponding to the prior - distribution. - loglikelihood_fn - The log-probability density function corresponding to the likelihood. + grad_estimator + A function that approximates the target's gradient function. data The full dataset. centering_position @@ -119,20 +98,10 @@ def cv_grad_estimator( Journal of Open Source Software, 7(72), 4113. """ - data_size = jax.tree_leaves(data)[0].shape[0] - logposterior_grad_estimator_fn = grad_estimator( - logprior_fn, loglikelihood_fn, data_size - ).estimate - - def init_fn(full_dataset: PyTree) -> CVGradientState: - """Compute the control variate on the whole dataset.""" - return CVGradientState( - logposterior_grad_estimator_fn(None, centering_position, full_dataset)[0] - ) - def grad_estimator_fn( - grad_estimator_state: CVGradientState, position: PyTree, minibatch: PyTree - ) -> Tuple[PyTree, CVGradientState]: + cv_grad_value = grad_estimator(centering_position, data) + + def cv_grad_estimator_fn(position: PyTree, minibatch: PyTree) -> PyTree: """Return an approximation of the log-posterior density. Parameters @@ -149,26 +118,14 @@ def grad_estimator_fn( the current value of the random variables. """ - logposterior_grad_estimate = logposterior_grad_estimator_fn( - None, position, minibatch - )[0] - logposterior_grad_center_estimate = logposterior_grad_estimator_fn( - None, centering_position, minibatch - )[0] - - def control_variate(grad_estimate, center_grad_estimate, center_grad): - return grad_estimate + center_grad - center_grad_estimate - - return ( - control_variate( - logposterior_grad_estimate, - grad_estimator_state.control_variate_grad, - logposterior_grad_center_estimate, - ), - grad_estimator_state, + grad_estimate = grad_estimator(position, minibatch) + center_grad_estimate = grad_estimator(centering_position, minibatch) + + return jax.tree_map( + lambda grad_est, cv_grad_est, cv_grad: cv_grad + grad_est - cv_grad_est, + grad_estimate, + center_grad_estimate, + cv_grad_value, ) - return GradientEstimator(init_fn, grad_estimator_fn) - - -GradientState = Union[None, CVGradientState] + return cv_grad_estimator_fn diff --git a/blackjax/sgmcmc/sghmc.py b/blackjax/sgmcmc/sghmc.py index fa4853f43..0e7ea55c1 100644 --- a/blackjax/sgmcmc/sghmc.py +++ b/blackjax/sgmcmc/sghmc.py @@ -4,48 +4,37 @@ import jax from blackjax.sgmcmc.diffusion import sghmc -from blackjax.sgmcmc.gradients import GradientEstimator -from blackjax.sgmcmc.sgld import SGLDState from blackjax.types import PRNGKey, PyTree from blackjax.util import generate_gaussian_noise __all__ = ["kernel"] -def kernel( - gradient_estimator: GradientEstimator, alpha: float = 0.01, beta: float = 0 -) -> Callable: +def kernel(alpha: float = 0.01, beta: float = 0) -> Callable: + """Stochastic gradient Hamiltonian Monte Carlo (SgHMC) algorithm.""" integrator = sghmc(alpha, beta) def one_step( rng_key: PRNGKey, - state: SGLDState, + position: PyTree, + grad_estimator: Callable, minibatch: PyTree, step_size: float, num_integration_steps: int, - ) -> SGLDState: + ) -> PyTree: def body_fn(state, rng_key): - position, momentum, grad_estimator_state = state - logprob_grad, grad_estimator_state = gradient_estimator.estimate( - grad_estimator_state, position, minibatch - ) + position, momentum = state + logprob_grad = grad_estimator(position, minibatch) position, momentum = integrator( rng_key, position, momentum, logprob_grad, step_size, minibatch ) - return ( - (position, momentum, grad_estimator_state), - (position, grad_estimator_state), - ) + return ((position, momentum), position) - position, grad_estimator_state = state momentum = generate_gaussian_noise(rng_key, position, step_size) - init_diffusion_state = (position, momentum, grad_estimator_state) - keys = jax.random.split(rng_key, num_integration_steps) - last_state, _ = jax.lax.scan(body_fn, init_diffusion_state, keys) - position, _, grad_estimator_state = last_state + position, _ = jax.lax.scan(body_fn, (position, momentum), keys) - return SGLDState(position, grad_estimator_state) + return position return one_step diff --git a/blackjax/sgmcmc/sgld.py b/blackjax/sgmcmc/sgld.py index ac3bca239..f066177ad 100644 --- a/blackjax/sgmcmc/sgld.py +++ b/blackjax/sgmcmc/sgld.py @@ -1,37 +1,28 @@ """Public API for the Stochastic gradient Langevin Dynamics kernel.""" -from typing import Callable, NamedTuple +from typing import Callable from blackjax.sgmcmc.diffusion import overdamped_langevin -from blackjax.sgmcmc.gradients import GradientEstimator, GradientState from blackjax.types import PRNGKey, PyTree -__all__ = ["SGLDState", "init", "kernel"] +__all__ = ["kernel"] -class SGLDState(NamedTuple): - position: PyTree - grad_estimator_state: GradientState - - -def init(position: PyTree, minibatch, gradient_estimator: GradientEstimator): - grad_estimator_state = gradient_estimator.init(minibatch) - return SGLDState(position, grad_estimator_state) - - -def kernel(gradient_estimator: GradientEstimator) -> Callable: +def kernel() -> Callable: + """Stochastic gradient Langevin Dynamics (SgLD) algorithm.""" integrator = overdamped_langevin() def one_step( - rng_key: PRNGKey, state: SGLDState, minibatch: PyTree, step_size: float + rng_key: PRNGKey, + position: PyTree, + grad_estimator: Callable, + minibatch: PyTree, + step_size: float, ): - position, grad_estimator_state = state - logprob_grad, grad_estimator_state = gradient_estimator.estimate( - grad_estimator_state, position, minibatch - ) + logprob_grad = grad_estimator(position, minibatch) new_position = integrator(rng_key, position, logprob_grad, step_size, minibatch) - return SGLDState(new_position, grad_estimator_state) + return new_position return one_step diff --git a/tests/test_sampling.py b/tests/test_sampling.py index 435f3c8f4..89337beee 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -243,18 +243,15 @@ def test_linear_regression_sgld(self): data_size = 1000 X_data = jax.random.normal(data_key, shape=(data_size, 5)) - grad_fn = blackjax.sgmcmc.gradients.grad_estimator( + grad_fn = blackjax.sgmcmc.gradients.estimator( self.logprior_fn, self.loglikelihood_fn, data_size ) sgld = blackjax.sgld(grad_fn) - init_position = 1.0 - data_batch = X_data[:100, :] - init_state = sgld.init(init_position, data_batch) - _, rng_key = jax.random.split(rng_key) - data_batch = X_data[100:200, :] - _ = sgld.step(rng_key, init_state, data_batch, 1e-3) + data_batch = X_data[:100, :] + init_position = 1.0 + _ = sgld(rng_key, init_position, data_batch, 1e-3) def test_linear_regression_sgld_cv(self): import blackjax.sgmcmc.gradients @@ -265,18 +262,20 @@ def test_linear_regression_sgld_cv(self): X_data = jax.random.normal(data_key, shape=(data_size, 5)) centering_position = 1.0 - grad_fn = blackjax.sgmcmc.gradients.cv_grad_estimator( - self.logprior_fn, self.loglikelihood_fn, X_data, centering_position + + grad_fn = blackjax.sgmcmc.gradients.estimator( + self.logprior_fn, self.loglikelihood_fn, data_size + ) + cv_grad_fn = blackjax.sgmcmc.gradients.control_variates( + grad_fn, centering_position, X_data ) - sgld = blackjax.sgld(grad_fn) - init_position = 1.0 - data_batch = X_data[:100, :] - init_state = sgld.init(init_position, data_batch) + sgld = blackjax.sgld(cv_grad_fn) _, rng_key = jax.random.split(rng_key) - data_batch = X_data[100:200, :] - _ = sgld.step(rng_key, init_state, data_batch, 1e-3) + init_position = 1.0 + data_batch = X_data[:100, :] + _ = sgld(rng_key, init_position, data_batch, 1e-3) def test_linear_regression_sghmc(self): import blackjax.sgmcmc.gradients @@ -286,18 +285,16 @@ def test_linear_regression_sghmc(self): data_size = 1000 X_data = jax.random.normal(data_key, shape=(data_size, 5)) - grad_fn = blackjax.sgmcmc.gradients.grad_estimator( + grad_fn = blackjax.sgmcmc.gradients.estimator( self.logprior_fn, self.loglikelihood_fn, data_size ) sghmc = blackjax.sghmc(grad_fn, 10) - init_position = 1.0 - data_batch = X_data[:100, :] - init_state = sghmc.init(init_position, data_batch) - _, rng_key = jax.random.split(rng_key) data_batch = X_data[100:200, :] - _ = sghmc.step(rng_key, init_state, data_batch, 1e-3) + init_position = 1.0 + data_batch = X_data[:100, :] + _ = sghmc(rng_key, init_position, data_batch, 1e-3) def test_linear_regression_sghmc_cv(self): import blackjax.sgmcmc.gradients @@ -308,18 +305,19 @@ def test_linear_regression_sghmc_cv(self): X_data = jax.random.normal(data_key, shape=(data_size, 5)) centering_position = 1.0 - grad_fn = blackjax.sgmcmc.gradients.cv_grad_estimator( - self.logprior_fn, self.loglikelihood_fn, X_data, centering_position + grad_fn = blackjax.sgmcmc.gradients.estimator( + self.logprior_fn, self.loglikelihood_fn, data_size + ) + cv_grad_fn = blackjax.sgmcmc.gradients.control_variates( + grad_fn, centering_position, X_data ) - sghmc = blackjax.sghmc(grad_fn, 10) - init_position = 1.0 - data_batch = X_data[:100, :] - init_state = sghmc.init(init_position, data_batch) + sghmc = blackjax.sghmc(cv_grad_fn, 10) _, rng_key = jax.random.split(rng_key) - data_batch = X_data[100:200, :] - _ = sghmc.step(rng_key, init_state, data_batch, 1e-3) + init_position = 1.0 + data_batch = X_data[:100, :] + _ = sghmc(rng_key, init_position, data_batch, 1e-3) class LatentGaussianTest(chex.TestCase): From d996dbe076a794a2db0e7a2a21ae2b7eac733eac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Sat, 22 Oct 2022 17:04:41 +0200 Subject: [PATCH 7/9] s/diffusion/diffusions --- blackjax/mcmc/{diffusion.py => diffusions.py} | 0 blackjax/mcmc/mala.py | 4 ++-- blackjax/sgmcmc/{diffusion.py => diffusions.py} | 0 blackjax/sgmcmc/sghmc.py | 4 ++-- blackjax/sgmcmc/sgld.py | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) rename blackjax/mcmc/{diffusion.py => diffusions.py} (100%) rename blackjax/sgmcmc/{diffusion.py => diffusions.py} (100%) diff --git a/blackjax/mcmc/diffusion.py b/blackjax/mcmc/diffusions.py similarity index 100% rename from blackjax/mcmc/diffusion.py rename to blackjax/mcmc/diffusions.py diff --git a/blackjax/mcmc/mala.py b/blackjax/mcmc/mala.py index d8ff0bb3f..26618b5e3 100644 --- a/blackjax/mcmc/mala.py +++ b/blackjax/mcmc/mala.py @@ -5,7 +5,7 @@ import jax import jax.numpy as jnp -from blackjax.mcmc.diffusion import overdamped_langevin +import blackjax.mcmc.diffusions as diffusions from blackjax.types import PRNGKey, PyTree __all__ = ["MALAState", "MALAInfo", "init", "kernel"] @@ -83,7 +83,7 @@ def one_step( """ grad_fn = jax.value_and_grad(logprob_fn) - integrator = overdamped_langevin(grad_fn) + integrator = diffusions.overdamped_langevin(grad_fn) key_integrator, key_rmh = jax.random.split(rng_key) diff --git a/blackjax/sgmcmc/diffusion.py b/blackjax/sgmcmc/diffusions.py similarity index 100% rename from blackjax/sgmcmc/diffusion.py rename to blackjax/sgmcmc/diffusions.py diff --git a/blackjax/sgmcmc/sghmc.py b/blackjax/sgmcmc/sghmc.py index 0e7ea55c1..c92442f65 100644 --- a/blackjax/sgmcmc/sghmc.py +++ b/blackjax/sgmcmc/sghmc.py @@ -3,7 +3,7 @@ import jax -from blackjax.sgmcmc.diffusion import sghmc +import blackjax.sgmcmc.diffusions as diffusions from blackjax.types import PRNGKey, PyTree from blackjax.util import generate_gaussian_noise @@ -13,7 +13,7 @@ def kernel(alpha: float = 0.01, beta: float = 0) -> Callable: """Stochastic gradient Hamiltonian Monte Carlo (SgHMC) algorithm.""" - integrator = sghmc(alpha, beta) + integrator = diffusions.sghmc(alpha, beta) def one_step( rng_key: PRNGKey, diff --git a/blackjax/sgmcmc/sgld.py b/blackjax/sgmcmc/sgld.py index f066177ad..66ec4e563 100644 --- a/blackjax/sgmcmc/sgld.py +++ b/blackjax/sgmcmc/sgld.py @@ -1,7 +1,7 @@ """Public API for the Stochastic gradient Langevin Dynamics kernel.""" from typing import Callable -from blackjax.sgmcmc.diffusion import overdamped_langevin +import blackjax.sgmcmc.diffusions as diffusions from blackjax.types import PRNGKey, PyTree __all__ = ["kernel"] @@ -10,7 +10,7 @@ def kernel() -> Callable: """Stochastic gradient Langevin Dynamics (SgLD) algorithm.""" - integrator = overdamped_langevin() + integrator = diffusions.overdamped_langevin() def one_step( rng_key: PRNGKey, From d0b3d8bd455bcc4ee8ca3cacf54b5ae0bb6423cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Sun, 20 Nov 2022 18:32:06 +0100 Subject: [PATCH 8/9] Simplify resampling test --- tests/test_resampling.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/test_resampling.py b/tests/test_resampling.py index b75c017cb..fb0cb871c 100644 --- a/tests/test_resampling.py +++ b/tests/test_resampling.py @@ -1,6 +1,4 @@ """Test the resampling functions for SMC.""" -import itertools - import chex import jax import numpy as np @@ -28,10 +26,10 @@ def integrand(x): class ResamplingTest(chex.TestCase): @chex.all_variants(with_pmap=False) - @parameterized.parameters( - itertools.product([100, 500, 1_000, 100_000], resampling_methods.keys()) - ) - def test_resampling_methods(self, N, method_name): + @parameterized.parameters(resampling_methods.keys()) + def test_resampling_methods(self, method_name): + N = 10_000 + np.random.seed(42) batch_size = 100 w = np.random.rand(N) From 8684689dede7639a73b99d124469832962f64c65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Sun, 20 Nov 2022 18:33:05 +0100 Subject: [PATCH 9/9] Simplify ESS tests --- tests/test_smc_ess.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_smc_ess.py b/tests/test_smc_ess.py index ed4b19520..e8ec5fb64 100644 --- a/tests/test_smc_ess.py +++ b/tests/test_smc_ess.py @@ -16,7 +16,7 @@ class SMCEffectiveSampleSizeTest(chex.TestCase): @chex.all_variants(with_pmap=False) - @parameterized.parameters([100, 1000, 5000]) + @parameterized.parameters([1000, 5000]) def test_ess(self, N): log_ess_fn = self.variant(functools.partial(ess.ess, log=True)) ess_fn = self.variant(functools.partial(ess.ess, log=False)) @@ -39,7 +39,7 @@ def test_ess(self, N): ) @chex.all_variants(with_pmap=False) - @parameterized.parameters(itertools.product([0.25, 0.5], [100, 1000, 5000])) + @parameterized.parameters(itertools.product([0.25, 0.5], [1000, 5000])) def test_ess_solver(self, target_ess, N): potential_fn = lambda pytree: -univariate_logpdf(pytree, scale=0.1) potential = jax.vmap(lambda x: potential_fn(x), in_axes=[0]) @@ -47,7 +47,7 @@ def test_ess_solver(self, target_ess, N): self.ess_solver_test_case(potential, particles, target_ess, N, 1.0) @chex.all_variants(with_pmap=False) - @parameterized.parameters(itertools.product([0.25, 0.5], [100, 1000, 5000])) + @parameterized.parameters(itertools.product([0.25, 0.5], [1000, 5000])) def test_ess_solver_multivariate(self, target_ess, N): """ Posterior with more than one variable. Let's assume we want to @@ -63,7 +63,7 @@ def test_ess_solver_multivariate(self, target_ess, N): self.ess_solver_test_case(potential, particles, target_ess, N, 10.0) @chex.all_variants(with_pmap=False) - @parameterized.parameters(itertools.product([0.25, 0.5], [100, 1000, 5000])) + @parameterized.parameters(itertools.product([0.25, 0.5], [1000, 5000])) def test_ess_solver_posterior_signature(self, target_ess, N): """ Posterior with more than one variable. Let's assume we want to