diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py
index 840693f81..be96fa4b1 100644
--- a/blackjax/mcmc/integrators.py
+++ b/blackjax/mcmc/integrators.py
@@ -306,6 +306,7 @@ def esh_dynamics_momentum_update_one_step(
     There are no exponentials e^delta, which prevents overflows when the
     gradient norm is large.
     """
+    del is_last_call
 
     flatten_grads, unravel_fn = ravel_pytree(logdensity_grad)
     flatten_momentum, _ = ravel_pytree(momentum)
@@ -324,11 +325,9 @@ def esh_dynamics_momentum_update_one_step(
         delta
         - jnp.log(2)
         + jnp.log(1 + momentum_proj + (1 - momentum_proj) * zeta**2)
-    )
+    ) * (dims - 1)
     if previous_kinetic_energy_change is not None:
         kinetic_energy_change += previous_kinetic_energy_change
-    if is_last_call:
-        kinetic_energy_change *= dims - 1
 
     return next_momentum, next_momentum, kinetic_energy_change
 
diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py
index 33f6060c2..26bbda2b8 100644
--- a/blackjax/mcmc/mclmc.py
+++ b/blackjax/mcmc/mclmc.py
@@ -84,7 +84,7 @@ def kernel(
         )
 
         # Langevin-like noise
-        momentum, dim = partially_refresh_momentum(
+        momentum = partially_refresh_momentum(
            momentum=momentum, rng_key=rng_key, L=L, step_size=step_size
        )
 
@@ -93,8 +93,7 @@ def kernel(
         ), MCLMCInfo(
             logdensity=logdensity,
             energy_change=kinetic_change - logdensity + state.logdensity,
-            # TODO: Potential bug here, see #625
-            kinetic_change=kinetic_change * (dim - 1),
+            kinetic_change=kinetic_change,
         )
 
     return kernel
@@ -191,4 +190,4 @@ def partially_refresh_momentum(momentum, rng_key, step_size, L):
     dim = m.shape[0]
     nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim)
     z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype)
-    return unravel_fn((m + z) / jnp.linalg.norm(m + z)), dim
+    return unravel_fn((m + z) / jnp.linalg.norm(m + z))
diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py
index 2f5020d00..2ef285dd2 100644
--- a/tests/mcmc/test_integrators.py
+++ b/tests/mcmc/test_integrators.py
@@ -228,6 +228,59 @@ def test_esh_momentum_update(self, dims):
         next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0)
         np.testing.assert_array_almost_equal(next_momentum, next_momentum1)
 
+    @chex.all_variants(with_pmap=False)
+    def test_noneuclidean_leapfrog(self):
+        cov = jnp.asarray([[1.0, 0.5, 0.1], [0.5, 2.0, -0.1], [0.1, -0.1, 3.0]])
+        logdensity_fn = lambda x: stats.multivariate_normal.logpdf(
+            x, jnp.zeros([3]), cov
+        )
+
+        step = self.variant(integrators.noneuclidean_leapfrog(logdensity_fn))
+
+        rng = jax.random.key(4263456)
+        key0, key1 = jax.random.split(rng, 2)
+        position_init = jax.random.normal(key0, (3,))
+        momentum_init = generate_unit_vector(key1, position_init)
+        step_size = 0.0001
+        initial_state = integrators.new_integrator_state(
+            logdensity_fn, position_init, momentum_init
+        )
+        next_state, kinetic_energy_change = step(initial_state, step_size)
+
+        # explicit integration
+        op1 = esh_dynamics_momentum_update_one_step
+        op2 = integrators.euclidean_position_update_fn(logdensity_fn)
+        position, momentum, _, logdensity_grad = initial_state
+        momentum, kinetic_grad, kinetic_energy_change0 = op1(
+            momentum,
+            logdensity_grad,
+            step_size,
+            0.5,
+            None,
+        )
+        position, logdensity, logdensity_grad, position_update_info = op2(
+            position,
+            kinetic_grad,
+            step_size,
+            1.0,
+            None,
+        )
+        momentum, kinetic_grad, kinetic_energy_change1 = op1(
+            momentum,
+            logdensity_grad,
+            step_size,
+            0.5,
+            None,
+        )
+        next_state_ = integrators.IntegratorState(
+            position, momentum, logdensity, logdensity_grad
+        )
+
+        chex.assert_trees_all_close(next_state, next_state_)
+        np.testing.assert_almost_equal(
+            kinetic_energy_change, kinetic_energy_change0 + kinetic_energy_change1
+        )
+
     @chex.all_variants(with_pmap=False)
     @parameterized.parameters(
         [
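
For context, a standalone sketch of the momentum update after this change (illustration only, not part of the patch): `esh_momentum_update` below is a hypothetical, simplified stand-in for blackjax's `esh_dynamics_momentum_update_one_step` that operates on a flat array and skips the pytree raveling. It shows the behavioural point of the fix: the kinetic-energy change is scaled by `(dims - 1)` at every call, so per-step contributions accumulate directly through `previous_kinetic_energy_change` and no `is_last_call` flag has to be threaded through the integrator.

# Hypothetical, simplified sketch of the patched ESH momentum update.
import jax.numpy as jnp


def esh_momentum_update(
    momentum, logdensity_grad, step_size, coef, previous_kinetic_energy_change=None
):
    dims = momentum.shape[0]
    gradient_norm = jnp.linalg.norm(logdensity_grad)
    normalized_gradient = logdensity_grad / gradient_norm
    momentum_proj = jnp.dot(momentum, normalized_gradient)
    delta = step_size * coef * gradient_norm / (dims - 1)
    zeta = jnp.exp(-delta)  # only e^{-delta}: large gradient norms cannot overflow
    next_momentum = (
        normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta))
        + 2 * zeta * momentum
    )
    next_momentum = next_momentum / jnp.linalg.norm(next_momentum)
    # Scale by (dims - 1) here, at every call, so successive half-step
    # contributions can simply be summed; no last-call bookkeeping needed.
    kinetic_energy_change = (
        delta
        - jnp.log(2)
        + jnp.log(1 + momentum_proj + (1 - momentum_proj) * zeta**2)
    ) * (dims - 1)
    if previous_kinetic_energy_change is not None:
        kinetic_energy_change += previous_kinetic_energy_change
    return next_momentum, kinetic_energy_change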