Add the control variates gradient estimator
rlouf committed Sep 19, 2022
1 parent 1beea20 commit 45375a0
Showing 2 changed files with 95 additions and 0 deletions.
blackjax/sgmcmc/gradients.py: 72 additions and 0 deletions
@@ -59,3 +59,75 @@ def logposterior_estimator_fn(position: PyTree, data_batch: PyTree) -> float:
    logprob_grad = jax.grad(logposterior_estimator_fn)

    return logprob_grad


def cv_grad_estimator(
    logprior_fn: Callable,
    loglikelihood_fn: Callable,
    data: PyTree,
    centering_position: PyTree,
) -> Callable:
"""Builds a control variate gradient estimator [1]_.
Parameters
----------
logprior_fn
The log-probability density function corresponding to the prior
distribution.
loglikelihood_fn
The log-probability density function corresponding to the likelihood.
data
The full dataset.
centering_position
Centering position for the control variates (typically the MAP).
References
----------
.. [1]: Baker, J., Fearnhead, P., Fox, E. B., & Nemeth, C. (2019).
Control variates for stochastic gradient MCMC. Statistics
and Computing, 29(3), 599-615.
"""
    data_size = jax.tree_leaves(data)[0].shape[0]
    logposterior_grad_estimator_fn = grad_estimator(
        logprior_fn, loglikelihood_fn, data_size
    )

    # Control variates need the gradient over the full dataset, evaluated
    # once at the centering position. The estimator then corrects each
    # minibatch gradient with this quantity:
    #   grad(position) ~ grad_batch(position) + grad_full(center) - grad_batch(center)
    logposterior_grad_center = logposterior_grad_estimator_fn(centering_position, data)

    def cv_grad_estimator_fn(position: PyTree, data_batch: PyTree) -> PyTree:
        """Return a control variates estimate of the gradient of the
        log-posterior density.

        Parameters
        ----------
        position
            The current value of the random variables.
        data_batch
            The current batch of data. The first dimension is assumed to be
            the batch dimension.

        Returns
        -------
        An estimate of the gradient of the log-posterior density function
        evaluated at the current value of the random variables.
        """
        logposterior_grad_estimate = logposterior_grad_estimator_fn(
            position, data_batch
        )
        logposterior_grad_center_estimate = logposterior_grad_estimator_fn(
            centering_position, data_batch
        )

        def control_variate(grad_estimate, center_grad_estimate, center_grad):
            return grad_estimate + center_grad - center_grad_estimate

        # Pass the minibatch gradient at the centering position second and
        # the full-dataset gradient third, matching the parameter order of
        # `control_variate` above.
        return jax.tree_util.tree_map(
            control_variate,
            logposterior_grad_estimate,
            logposterior_grad_center_estimate,
            logposterior_grad_center,
        )

    return cv_grad_estimator_fn
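
A quick property check on this estimator (a sketch, not part of the commit): at the centering position the two minibatch terms cancel, so the control variates estimate equals the full-dataset gradient exactly, whatever the batch. The toy model below is illustrative, and assumes loglikelihood_fn is evaluated per data point, as grad_estimator does above.

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats

from blackjax.sgmcmc.gradients import cv_grad_estimator, grad_estimator

# Toy model: Gaussian likelihood with unknown mean, Gaussian prior.
logprior_fn = lambda position: stats.norm.logpdf(position, 0.0, 10.0)
loglikelihood_fn = lambda position, x: stats.norm.logpdf(x, position, 1.0)

data = jax.random.normal(jax.random.PRNGKey(0), shape=(1000,))
centering_position = jnp.array(0.5)

cv_grad_fn = cv_grad_estimator(logprior_fn, loglikelihood_fn, data, centering_position)
full_grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data.shape[0])

# At the centering position the minibatch terms cancel, so these two values
# coincide; away from it they differ, but with low variance as long as the
# position stays close to the centering position.
print(cv_grad_fn(centering_position, data[:100]))
print(full_grad_fn(centering_position, data))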
tests/test_sampling.py: 23 additions and 0 deletions
@@ -259,6 +259,29 @@ def test_linear_regression_sgld(self, learning_rate, error_expected):
        data_batch = X_data[100:200, :]
        _ = sgld.step(rng_key, init_state, data_batch)

    def test_linear_regression_sgld_cv(self):
        """Test the SGLD kernel with a control variates gradient estimator."""
        import blackjax.sgmcmc.gradients

        rng_key, data_key = jax.random.split(self.key, 2)

        data_size = 1000
        X_data = jax.random.normal(data_key, shape=(data_size, 5))

        centering_position = jnp.ones(5)
        grad_fn = blackjax.sgmcmc.gradients.cv_grad_estimator(
            self.logprior_fn, self.loglikelihood_fn, X_data, centering_position
        )

        sgld = blackjax.sgld(grad_fn, 1e-3)
        init_position = 1.0
        data_batch = X_data[:100, :]
        init_state = sgld.init(init_position, data_batch)

        _, rng_key = jax.random.split(rng_key)
        data_batch = X_data[100:200, :]
        _ = sgld.step(rng_key, init_state, data_batch)

    @parameterized.parameters((1e-3, False), (constant_step_size, False), (1, True))
    def test_linear_regression_sghmc(self, learning_rate, error_expected):
        """Test the SGHMC kernel."""
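
The test takes a single step to exercise the API; in practice one draws a fresh minibatch at every iteration. A minimal sketch of such a loop (not part of the commit), reusing sgld, init_state, rng_key, data_size and X_data from the test above; the uniform without-replacement subsampling is an illustrative choice:

state = init_state
for _ in range(100):
    rng_key, batch_key, step_key = jax.random.split(rng_key, 3)
    idx = jax.random.choice(batch_key, data_size, shape=(100,), replace=False)
    state = sgld.step(step_key, state, X_data[idx])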
