Skip to content

Commit

Permalink
Update noneuclidean integrator (#626)
Browse files Browse the repository at this point in the history
* Update noneuclidean integrator

to match default implmentation in https://github.com/JakobRobnik/MicroCanonicalHMC/blob/master/mclmc/dynamics.py

* fix formatting

* Fix #625
  • Loading branch information
junpenglao authored Dec 11, 2023
1 parent d8fd15a commit fac1d5e
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 7 deletions.
5 changes: 2 additions & 3 deletions blackjax/mcmc/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ def esh_dynamics_momentum_update_one_step(
There are no exponentials e^delta, which prevents overflows when the gradient norm
is large.
"""
del is_last_call

flatten_grads, unravel_fn = ravel_pytree(logdensity_grad)
flatten_momentum, _ = ravel_pytree(momentum)
Expand All @@ -324,11 +325,9 @@ def esh_dynamics_momentum_update_one_step(
delta
- jnp.log(2)
+ jnp.log(1 + momentum_proj + (1 - momentum_proj) * zeta**2)
)
) * (dims - 1)
if previous_kinetic_energy_change is not None:
kinetic_energy_change += previous_kinetic_energy_change
if is_last_call:
kinetic_energy_change *= dims - 1
return next_momentum, next_momentum, kinetic_energy_change


Expand Down
7 changes: 3 additions & 4 deletions blackjax/mcmc/mclmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def kernel(
)

# Langevin-like noise
momentum, dim = partially_refresh_momentum(
momentum = partially_refresh_momentum(
momentum=momentum, rng_key=rng_key, L=L, step_size=step_size
)

Expand All @@ -93,8 +93,7 @@ def kernel(
), MCLMCInfo(
logdensity=logdensity,
energy_change=kinetic_change - logdensity + state.logdensity,
# TODO: Potential bug here, see #625
kinetic_change=kinetic_change * (dim - 1),
kinetic_change=kinetic_change,
)

return kernel
Expand Down Expand Up @@ -191,4 +190,4 @@ def partially_refresh_momentum(momentum, rng_key, step_size, L):
dim = m.shape[0]
nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim)
z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype)
return unravel_fn((m + z) / jnp.linalg.norm(m + z)), dim
return unravel_fn((m + z) / jnp.linalg.norm(m + z))
53 changes: 53 additions & 0 deletions tests/mcmc/test_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,59 @@ def test_esh_momentum_update(self, dims):
next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0)
np.testing.assert_array_almost_equal(next_momentum, next_momentum1)

@chex.all_variants(with_pmap=False)
def test_noneuclidean_leapfrog(self):
cov = jnp.asarray([[1.0, 0.5, 0.1], [0.5, 2.0, -0.1], [0.1, -0.1, 3.0]])
logdensity_fn = lambda x: stats.multivariate_normal.logpdf(
x, jnp.zeros([3]), cov
)

step = self.variant(integrators.noneuclidean_leapfrog(logdensity_fn))

rng = jax.random.key(4263456)
key0, key1 = jax.random.split(rng, 2)
position_init = jax.random.normal(key0, (3,))
momentum_init = generate_unit_vector(key1, position_init)
step_size = 0.0001
initial_state = integrators.new_integrator_state(
logdensity_fn, position_init, momentum_init
)
next_state, kinetic_energy_change = step(initial_state, step_size)

# explicit integration
op1 = esh_dynamics_momentum_update_one_step
op2 = integrators.euclidean_position_update_fn(logdensity_fn)
position, momentum, _, logdensity_grad = initial_state
momentum, kinetic_grad, kinetic_energy_change0 = op1(
momentum,
logdensity_grad,
step_size,
0.5,
None,
)
position, logdensity, logdensity_grad, position_update_info = op2(
position,
kinetic_grad,
step_size,
1.0,
None,
)
momentum, kinetic_grad, kinetic_energy_change1 = op1(
momentum,
logdensity_grad,
step_size,
0.5,
None,
)
next_state_ = integrators.IntegratorState(
position, momentum, logdensity, logdensity_grad
)

chex.assert_trees_all_close(next_state, next_state_)
np.testing.assert_almost_equal(
kinetic_energy_change, kinetic_energy_change0 + kinetic_energy_change1
)

@chex.all_variants(with_pmap=False)
@parameterized.parameters(
[
Expand Down

0 comments on commit fac1d5e

Please sign in to comment.