WIP - Add gradient state
rlouf committed Sep 20, 2022
1 parent 45375a0 commit b3d5ee9
Showing 4 changed files with 48 additions and 21 deletions.
blackjax/sgmcmc/diffusion.py (19 changes: 11 additions & 8 deletions)
@@ -35,8 +35,8 @@ def one_step(
             noise,
         )
 
-        logprob_grad = logprob_grad_fn(position, batch)
-        return DiffusionState(position, logprob_grad)
+        gradient_state, logprob_grad = logprob_grad_fn(position, batch)
+        return DiffusionState(position, logprob_grad), gradient_state
 
     return one_step

@@ -60,7 +60,7 @@ def sghmc(logprob_grad_fn, alpha: float = 0.01, beta: float = 0):
 
     def one_step(
         rng_key: PRNGKey, state: SGHMCState, step_size: float, batch: tuple = ()
-    ) -> SGHMCState:
+    ):
         position, momentum, logprob_grad = state
         noise = generate_gaussian_noise(rng_key, position)
         position = jax.tree_util.tree_map(lambda x, p: x + p, position, momentum)
@@ -73,11 +73,14 @@ def one_step(
             noise,
         )
 
-        logprob_grad = logprob_grad_fn(position, batch)
-        return SGHMCState(
-            position,
-            momentum,
-            logprob_grad,
+        gradient_state, logprob_grad = logprob_grad_fn(position, batch)
+        return (
+            SGHMCState(
+                position,
+                momentum,
+                logprob_grad,
+            ),
+            gradient_state,
         )
 
     return one_step
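The hunks above change the integrator contract: a gradient estimator now returns a (gradient_state, logprob_grad) pair, and each diffusion one_step hands the extra state back to the caller instead of swallowing it. A minimal sketch of that convention with a toy stateless estimator (`toy_grad_fn` and its empty-tuple state are illustrative, not blackjax code):

```python
import jax
import jax.numpy as jnp

def toy_grad_fn(position, batch):
    # Gradient of the log-density -0.5 * ||x||^2; there is no state to
    # carry, so an empty tuple stands in for the gradient state.
    gradient_state = ()
    logprob_grad = jax.tree_util.tree_map(lambda x: -x, position)
    return gradient_state, logprob_grad

position = jnp.ones(3)
gradient_state, grad = toy_grad_fn(position, batch=())
print(grad)  # [-1. -1. -1.]
```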
blackjax/sgmcmc/gradients.py (19 changes: 13 additions & 6 deletions)
@@ -1,4 +1,4 @@
-from typing import Callable
+from typing import Callable, NamedTuple
 
 import jax
 import jax.numpy as jnp
@@ -61,6 +61,10 @@ def logposterior_estimator_fn(position: PyTree, data_batch: PyTree) -> float:
     return logprob_grad
 
 
+class GradientState(NamedTuple):
+    control_variate_grad: PyTree
+
+
 def cv_grad_estimator(
     logprior_fn: Callable,
     loglikelihood_fn: Callable,
@@ -93,10 +97,13 @@ def cv_grad_estimator(
         logprior_fn, loglikelihood_fn, data_size
     )
 
-    # Control Variates use the gradient on the full dataset
-    logposterior_grad_center = logposterior_grad_estimator_fn(centering_position, data)
+    def init(data):
+        """Compute the control variate on the whole dataset."""
+        return GradientState(logposterior_grad_estimator_fn(centering_position, data))
 
-    def logposterior_estimator_fn(position: PyTree, data_batch: PyTree) -> float:
+    def logposterior_estimator_fn(
+        gradient_state: GradientState, position: PyTree, data_batch: PyTree
+    ):
         """Return an approximation of the log-posterior density.
 
         Parameters
@@ -123,10 +130,10 @@ def logposterior_estimator_fn(position: PyTree, data_batch: PyTree) -> float:
         def control_variate(grad_estimate, center_grad_estimate, center_grad):
             return grad_estimate + center_grad - center_grad_estimate
 
-        return jax.tree_util.tree_map(
+        return gradient_state, jax.tree_util.tree_map(
             control_variate,
             logposterior_grad_estimate,
-            logposterior_grad_center,
+            gradient_state.control_variate_grad,
             logposterior_grad_center_estimate,
         )
 
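The GradientState introduced here stores the full-data gradient at a fixed centering position, computed once by the new init function rather than at estimator construction time. The estimator then applies the standard SVRG-style correction, grad(position, batch) + grad_full(center) - grad(center, batch), which stays unbiased while shrinking the variance when position is close to the center. A self-contained sketch of that idea (toy names and toy log-density, not the blackjax API):

```python
import jax.numpy as jnp

data = jnp.arange(10.0)      # toy dataset
center = jnp.array(0.5)      # centering position, e.g. a MAP estimate

def grad_estimate(position, batch):
    # Minibatch estimate of the full-data gradient of
    # sum_i -(x_i - position)^2 / 2, rescaled by the dataset size.
    return data.shape[0] * jnp.mean(batch - position)

# init: one pass over the full dataset, stored as the control variate
control_variate_grad = grad_estimate(center, data)

def cv_grad(position, batch):
    # Unbiased, and exactly the full-data gradient when batch == data
    return (
        grad_estimate(position, batch)
        + control_variate_grad
        - grad_estimate(center, batch)
    )

print(cv_grad(jnp.array(0.4), data[:2]))
```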
blackjax/sgmcmc/sghmc.py (12 changes: 10 additions & 2 deletions)
@@ -6,6 +6,7 @@
 
 from blackjax.sgmcmc.diffusion import SGHMCState, sghmc
 from blackjax.sgmcmc.sgld import SGLDState
+from blackjax.sgmcmc.sgld import init as sgld_init
 from blackjax.types import PRNGKey, PyTree
 
 __all__ = ["kernel"]
@@ -19,6 +20,9 @@ def sample_momentum(rng_key: PRNGKey, position: PyTree, step_size: float):
     return unravel_fn(noise_flat)
 
 
+init = sgld_init
+
+
 def kernel(
     grad_estimator_fn: Callable, alpha: float = 0.01, beta: float = 0
 ) -> Callable:
@@ -28,7 +32,7 @@ def one_step(
         rng_key: PRNGKey, state: SGLDState, data_batch: PyTree, step_size: float, L: int
     ) -> SGLDState:
 
-        step, position, logprob_grad = state
+        step, position, logprob_grad, _ = state
         momentum = sample_momentum(rng_key, position, step_size)
         diffusion_state = SGHMCState(position, momentum, logprob_grad)
 
@@ -39,6 +43,10 @@ def body_fn(state, rng_key):
         keys = jax.random.split(rng_key, L)
         last_state, _ = jax.lax.scan(body_fn, diffusion_state, keys)
 
-        return SGLDState(step + 1, last_state.position, last_state.logprob_grad)
+        position = last_state.position
+        logprob_grad = last_state.logprob_grad
+        gradient_state = last_state[-1]
+
+        return SGLDState(step + 1, position, logprob_grad, gradient_state)
 
     return one_step
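Since the updated integrator returns a (state, gradient_state) pair, the jax.lax.scan carry inside one_step now holds that pair, and `last_state[-1]` reads the gradient state off the final carry. A toy sketch of how such a carry stays structure-consistent across scan iterations (all names below are stand-ins, not blackjax code; the WIP wiring in this commit differs slightly):

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class ToyState(NamedTuple):
    position: jnp.ndarray

def integrator(state, rng_key):
    # Stand-in for the sghmc integrator: returns (state, gradient_state)
    new_state = ToyState(state.position + 0.1)
    return new_state, ()

def body_fn(carry, rng_key):
    state, _ = carry
    carry = integrator(state, rng_key)   # (ToyState, gradient_state)
    return carry, None

keys = jax.random.split(jax.random.PRNGKey(0), 5)
init_carry = (ToyState(jnp.zeros(2)), ())
last_state, _ = jax.lax.scan(body_fn, init_carry, keys)

gradient_state = last_state[-1]          # same indexing trick as above
print(last_state[0].position)            # [0.5 0.5]
```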
blackjax/sgmcmc/sgld.py (19 changes: 14 additions & 5 deletions)
@@ -2,6 +2,7 @@
 from typing import Callable, NamedTuple
 
 from blackjax.sgmcmc.diffusion import overdamped_langevin
+from blackjax.sgmcmc.gradients import GradientState
 from blackjax.types import PRNGKey, PyTree
 
 __all__ = ["SGLDState", "init", "kernel"]
@@ -11,23 +12,31 @@ class SGLDState(NamedTuple):
     step: int
     position: PyTree
     logprob_grad: PyTree
+    gradient_state: GradientState
 
 
+# We can compute the gradient at the beginning of the kernel step.
+# This lets us get rid of much of the init function, AND it
+# prevents a useless final gradient computation at the last step.
+
+
 def init(position: PyTree, batch, grad_estimator_fn: Callable):
-    logprob_grad = grad_estimator_fn(position, batch)
-    return SGLDState(0, position, logprob_grad)
+    gradient_state, logprob_grad = grad_estimator_fn(position, batch)
+    return SGLDState(0, position, logprob_grad, gradient_state)
 
 
 def kernel(grad_estimator_fn: Callable) -> Callable:
     integrator = overdamped_langevin(grad_estimator_fn)
 
     def one_step(
         rng_key: PRNGKey, state: SGLDState, data_batch: PyTree, step_size: float
-    ) -> SGLDState:
+    ):
 
         step, *diffusion_state = state
-        new_state = integrator(rng_key, diffusion_state, step_size, data_batch)
+        new_state, gradient_state = integrator(
+            rng_key, diffusion_state, step_size, data_batch
+        )
 
-        return SGLDState(step + 1, *new_state)
+        return SGLDState(step + 1, *new_state), gradient_state
 
     return one_step
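From a caller's perspective, one_step now returns a (SGLDState, gradient_state) pair instead of a bare state, so a sampling loop threads both values. A self-contained sketch of that loop, using a toy estimator and the overdamped Langevin update shown in diffusion.py (everything below is a stand-in, since the wiring in this WIP commit is still in flux):

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class SGLDState(NamedTuple):
    step: int
    position: jnp.ndarray
    logprob_grad: jnp.ndarray
    gradient_state: tuple

def grad_estimator(position, batch):
    # Toy estimator for logp(x) = -0.5 * ||x||^2; returns the new
    # (gradient_state, logprob_grad) pair, with no state to carry.
    return (), -position

def init(position, batch):
    gradient_state, logprob_grad = grad_estimator(position, batch)
    return SGLDState(0, position, logprob_grad, gradient_state)

def one_step(rng_key, state, batch, step_size):
    # Overdamped Langevin step, as in blackjax/sgmcmc/diffusion.py
    noise = jax.random.normal(rng_key, state.position.shape)
    position = (
        state.position
        + step_size * state.logprob_grad
        + jnp.sqrt(2 * step_size) * noise
    )
    gradient_state, logprob_grad = grad_estimator(position, batch)
    new_state = SGLDState(state.step + 1, position, logprob_grad, gradient_state)
    return new_state, gradient_state

rng_key = jax.random.PRNGKey(0)
state = init(jnp.zeros(2), batch=())
for _ in range(100):
    rng_key, subkey = jax.random.split(rng_key)
    state, gradient_state = one_step(subkey, state, batch=(), step_size=1e-2)
print(state.step, state.position)
```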
