diff --git a/blackjax/mcmc/barker.py b/blackjax/mcmc/barker.py index 56138aa05..7ae7d2463 100644 --- a/blackjax/mcmc/barker.py +++ b/blackjax/mcmc/barker.py @@ -94,11 +94,11 @@ def _compute_acceptance_probability( y_minus_x = jax.tree_util.tree_map(lambda a, b: a - b, y, x) x_minus_y = jax.tree_util.tree_map(lambda a: -a, y_minus_x) - z_tilde_x_to_y = metric.scale(x, y_minus_x, True, True) - z_tilde_y_to_x = metric.scale(y, x_minus_y, True, True) + z_tilde_x_to_y = metric.scale(x, y_minus_x, inv=True, trans=True) + z_tilde_y_to_x = metric.scale(y, x_minus_y, inv=True, trans=True) - c_x_to_y = metric.scale(x, log_x, False, True) - c_y_to_x = metric.scale(y, log_y, False, True) + c_x_to_y = metric.scale(x, log_x, inv=False, trans=True) + c_y_to_x = metric.scale(y, log_y, inv=False, trans=True) z_tilde_x_to_y_flat, _ = ravel_pytree(z_tilde_x_to_y) z_tilde_y_to_x_flat, _ = ravel_pytree(z_tilde_y_to_x) @@ -256,7 +256,7 @@ def _barker_sample(key, mean, a, scale, metric): key1, key2 = jax.random.split(key) z = generate_gaussian_noise(key1, mean, sigma=scale) - c = metric.scale(mean, a, False, True) + c = metric.scale(mean, a, inv=False, trans=True) # Sample b=1 with probability p and 0 with probability 1 - p where # p = 1 / (1 + exp(-a * (z - mean))) @@ -267,7 +267,7 @@ def _barker_sample(key, mean, a, scale, metric): bz = jax.tree_util.tree_map(lambda x, y: x * y - (1 - x) * y, b, z) return jax.tree_util.tree_map( - lambda a, b: a + b, mean, metric.scale(mean, bz, False, False) + lambda a, b: a + b, mean, metric.scale(mean, bz, inv=False, trans=False) ) diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index 51e60fbcf..f0720acf4 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -30,7 +30,6 @@ """ from typing import Callable, NamedTuple, Optional, Protocol, Union -import jax import jax.numpy as jnp import jax.scipy as jscipy from jax.flatten_util import ravel_pytree @@ -194,8 +193,9 @@ def is_turning( def scale( position: ArrayLikeTree, element: ArrayLikeTree, - inv: ArrayLikeTree, - trans: ArrayLikeTree, + *, + inv: bool, + trans: bool, ) -> ArrayLikeTree: """Scale elements by the mass matrix. @@ -205,10 +205,9 @@ def scale( The current position. Not used in this metric. elements Elements to scale - invs + inv Whether to scale the elements by the inverse mass matrix or the mass matrix. If True, the element is scaled by the inverse square root mass matrix, i.e., elem <- (M^{1/2})^{-1} elem. - Same pytree structure as `elements`. trans whether to transpose mass matrix when scaling @@ -296,8 +295,9 @@ def is_turning( def scale( position: ArrayLikeTree, element: ArrayLikeTree, - inv: ArrayLikeTree, - trans: ArrayLikeTree, + *, + inv: bool, + trans: bool, ) -> ArrayLikeTree: """Scale elements by the mass matrix. @@ -317,21 +317,14 @@ def scale( ) ravelled_element, unravel_fn = ravel_pytree(element) - def _linear_map_transpose(): - return jax.lax.cond( - inv, - lambda: linear_map(inv_mass_matrix_sqrt.T, ravelled_element), - lambda: linear_map(mass_matrix_sqrt.T, ravelled_element), - ) - - def _linear_map(): - return jax.lax.cond( - inv, - lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element), - lambda: linear_map(mass_matrix_sqrt, ravelled_element), - ) + if inv: + left_hand_side_matrix = inv_mass_matrix_sqrt + else: + left_hand_side_matrix = mass_matrix_sqrt + if trans: + left_hand_side_matrix = left_hand_side_matrix.T - scaled = jax.lax.cond(trans, _linear_map_transpose, _linear_map) + scaled = linear_map(left_hand_side_matrix, ravelled_element) return unravel_fn(scaled) diff --git a/tests/mcmc/test_barker.py b/tests/mcmc/test_barker.py index 65571b61c..04a86d1d4 100644 --- a/tests/mcmc/test_barker.py +++ b/tests/mcmc/test_barker.py @@ -93,7 +93,7 @@ def logdensity(x, data): # scaled, trivial pre-conditioning def scaled_logdensity(x_scaled, data, metric): - x = metric.scale(x_scaled, x_scaled, False, False) + x = metric.scale(x_scaled, x_scaled, inv=False, trans=False) return logdensity(x, data) logposterior_fn2 = functools.partial( @@ -101,7 +101,7 @@ def scaled_logdensity(x_scaled, data, metric): ) barker2 = blackjax.barker_proposal(logposterior_fn2, 1e-1, jnp.eye(2)) - true_x_trans = metric.scale(true_x, true_x, True, True) + true_x_trans = metric.scale(true_x, true_x, inv=True, trans=True) state2 = barker2.init(true_x_trans) n_steps = 10 @@ -125,7 +125,7 @@ def scaled_logdensity(x_scaled, data, metric): states2_trans = [] for ii in range(n_steps): s = states2[ii] - states2_trans.append(metric.scale(s, s, False, False)) + states2_trans.append(metric.scale(s, s, inv=False, trans=False)) states2_trans = jnp.array(states2_trans) assert jnp.allclose(states1, states2_trans)