Skip to content

Commit

Permalink
propagate changes of inv, trans as required kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Oct 4, 2024
1 parent eb9acbf commit fa2e70b
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 30 deletions.
12 changes: 6 additions & 6 deletions blackjax/mcmc/barker.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ def _compute_acceptance_probability(

y_minus_x = jax.tree_util.tree_map(lambda a, b: a - b, y, x)
x_minus_y = jax.tree_util.tree_map(lambda a: -a, y_minus_x)
z_tilde_x_to_y = metric.scale(x, y_minus_x, True, True)
z_tilde_y_to_x = metric.scale(y, x_minus_y, True, True)
z_tilde_x_to_y = metric.scale(x, y_minus_x, inv=True, trans=True)
z_tilde_y_to_x = metric.scale(y, x_minus_y, inv=True, trans=True)

c_x_to_y = metric.scale(x, log_x, False, True)
c_y_to_x = metric.scale(y, log_y, False, True)
c_x_to_y = metric.scale(x, log_x, inv=False, trans=True)
c_y_to_x = metric.scale(y, log_y, inv=False, trans=True)

z_tilde_x_to_y_flat, _ = ravel_pytree(z_tilde_x_to_y)
z_tilde_y_to_x_flat, _ = ravel_pytree(z_tilde_y_to_x)
Expand Down Expand Up @@ -256,7 +256,7 @@ def _barker_sample(key, mean, a, scale, metric):
key1, key2 = jax.random.split(key)

z = generate_gaussian_noise(key1, mean, sigma=scale)
c = metric.scale(mean, a, False, True)
c = metric.scale(mean, a, inv=False, trans=True)

# Sample b=1 with probability p and 0 with probability 1 - p where
# p = 1 / (1 + exp(-a * (z - mean)))
Expand All @@ -267,7 +267,7 @@ def _barker_sample(key, mean, a, scale, metric):
bz = jax.tree_util.tree_map(lambda x, y: x * y - (1 - x) * y, b, z)

return jax.tree_util.tree_map(
lambda a, b: a + b, mean, metric.scale(mean, bz, False, False)
lambda a, b: a + b, mean, metric.scale(mean, bz, inv=False, trans=False)
)


Expand Down
35 changes: 14 additions & 21 deletions blackjax/mcmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
"""
from typing import Callable, NamedTuple, Optional, Protocol, Union

import jax
import jax.numpy as jnp
import jax.scipy as jscipy
from jax.flatten_util import ravel_pytree
Expand Down Expand Up @@ -194,8 +193,9 @@ def is_turning(
def scale(
position: ArrayLikeTree,
element: ArrayLikeTree,
inv: ArrayLikeTree,
trans: ArrayLikeTree,
*,
inv: bool,
trans: bool,
) -> ArrayLikeTree:
"""Scale elements by the mass matrix.
Expand All @@ -205,10 +205,9 @@ def scale(
The current position. Not used in this metric.
elements
Elements to scale
invs
inv
Whether to scale the elements by the inverse mass matrix or the mass matrix.
If True, the element is scaled by the inverse square root mass matrix, i.e., elem <- (M^{1/2})^{-1} elem.
Same pytree structure as `elements`.
trans
whether to transpose mass matrix when scaling
Expand Down Expand Up @@ -296,8 +295,9 @@ def is_turning(
def scale(
position: ArrayLikeTree,
element: ArrayLikeTree,
inv: ArrayLikeTree,
trans: ArrayLikeTree,
*,
inv: bool,
trans: bool,
) -> ArrayLikeTree:
"""Scale elements by the mass matrix.
Expand All @@ -317,21 +317,14 @@ def scale(
)
ravelled_element, unravel_fn = ravel_pytree(element)

def _linear_map_transpose():
return jax.lax.cond(
inv,
lambda: linear_map(inv_mass_matrix_sqrt.T, ravelled_element),
lambda: linear_map(mass_matrix_sqrt.T, ravelled_element),
)

def _linear_map():
return jax.lax.cond(
inv,
lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
lambda: linear_map(mass_matrix_sqrt, ravelled_element),
)
if inv:
left_hand_side_matrix = inv_mass_matrix_sqrt
else:
left_hand_side_matrix = mass_matrix_sqrt
if trans:
left_hand_side_matrix = left_hand_side_matrix.T

scaled = jax.lax.cond(trans, _linear_map_transpose, _linear_map)
scaled = linear_map(left_hand_side_matrix, ravelled_element)

return unravel_fn(scaled)

Expand Down
6 changes: 3 additions & 3 deletions tests/mcmc/test_barker.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,15 @@ def logdensity(x, data):

# scaled, trivial pre-conditioning
def scaled_logdensity(x_scaled, data, metric):
x = metric.scale(x_scaled, x_scaled, False, False)
x = metric.scale(x_scaled, x_scaled, inv=False, trans=False)
return logdensity(x, data)

logposterior_fn2 = functools.partial(
scaled_logdensity, data=data, metric=metric
)
barker2 = blackjax.barker_proposal(logposterior_fn2, 1e-1, jnp.eye(2))

true_x_trans = metric.scale(true_x, true_x, True, True)
true_x_trans = metric.scale(true_x, true_x, inv=True, trans=True)
state2 = barker2.init(true_x_trans)

n_steps = 10
Expand All @@ -125,7 +125,7 @@ def scaled_logdensity(x_scaled, data, metric):
states2_trans = []
for ii in range(n_steps):
s = states2[ii]
states2_trans.append(metric.scale(s, s, False, False))
states2_trans.append(metric.scale(s, s, inv=False, trans=False))
states2_trans = jnp.array(states2_trans)
assert jnp.allclose(states1, states2_trans)

Expand Down

0 comments on commit fa2e70b

Please sign in to comment.