Skip to content

Commit

Permalink
fix test metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Oct 4, 2024
1 parent fa2e70b commit ef4b434
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions tests/mcmc/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,12 @@ def test_gaussian_euclidean_dim_1(self):
assert momentum_val == expected_momentum_val
assert kinetic_energy_val == expected_kinetic_energy_val

inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False)
scaled_momentum = scale(arbitrary_position, momentum_val, False, False)
inv_scaled_momentum = scale(
arbitrary_position, momentum_val, inv=True, trans=False
)
scaled_momentum = scale(
arbitrary_position, momentum_val, inv=False, trans=False
)

expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix)
expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix)
Expand Down Expand Up @@ -164,8 +168,12 @@ def test_gaussian_euclidean_dim_2(self):
np.testing.assert_allclose(expected_momentum_val, momentum_val)
np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val)

inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False)
scaled_momentum = scale(arbitrary_position, momentum_val, False, False)
inv_scaled_momentum = scale(
arbitrary_position, momentum_val, inv=True, trans=False
)
scaled_momentum = scale(
arbitrary_position, momentum_val, inv=False, trans=False
)

expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val
expected_scaled_momentum = L_inv @ momentum_val
Expand Down Expand Up @@ -226,8 +234,12 @@ def test_gaussian_riemannian_dim_1(self):
np.testing.assert_allclose(expected_momentum_val, momentum_val)
np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val)

inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False)
scaled_momentum = scale(arbitrary_position, momentum_val, False, False)
inv_scaled_momentum = scale(
arbitrary_position, momentum_val, inv=True, trans=False
)
scaled_momentum = scale(
arbitrary_position, momentum_val, inv=False, trans=False
)
expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix)
expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix)

Expand Down Expand Up @@ -265,8 +277,12 @@ def test_gaussian_riemannian_dim_2(self):
np.testing.assert_allclose(expected_momentum_val, momentum_val)
np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val)

inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False)
scaled_momentum = scale(arbitrary_position, momentum_val, False, False)
inv_scaled_momentum = scale(
arbitrary_position, momentum_val, inv=True, trans=False
)
scaled_momentum = scale(
arbitrary_position, momentum_val, inv=False, trans=False
)
expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val
expected_scaled_momentum = L_inv @ momentum_val

Expand Down

0 comments on commit ef4b434

Please sign in to comment.