From ef4b434253403618a59a6db9207395ad4ffefb23 Mon Sep 17 00:00:00 2001 From: Ismael Mendoza Date: Fri, 4 Oct 2024 11:40:41 -0500 Subject: [PATCH] fix test metrics --- tests/mcmc/test_metrics.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/tests/mcmc/test_metrics.py b/tests/mcmc/test_metrics.py index 098649a9a..e6aa5879f 100644 --- a/tests/mcmc/test_metrics.py +++ b/tests/mcmc/test_metrics.py @@ -131,8 +131,12 @@ def test_gaussian_euclidean_dim_1(self): assert momentum_val == expected_momentum_val assert kinetic_energy_val == expected_kinetic_energy_val - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False) - scaled_momentum = scale(arbitrary_position, momentum_val, False, False) + inv_scaled_momentum = scale( + arbitrary_position, momentum_val, inv=True, trans=False + ) + scaled_momentum = scale( + arbitrary_position, momentum_val, inv=False, trans=False + ) expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix) expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix) @@ -164,8 +168,12 @@ def test_gaussian_euclidean_dim_2(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False) - scaled_momentum = scale(arbitrary_position, momentum_val, False, False) + inv_scaled_momentum = scale( + arbitrary_position, momentum_val, inv=True, trans=False + ) + scaled_momentum = scale( + arbitrary_position, momentum_val, inv=False, trans=False + ) expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val expected_scaled_momentum = L_inv @ momentum_val @@ -226,8 +234,12 @@ def test_gaussian_riemannian_dim_1(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False) - scaled_momentum = scale(arbitrary_position, momentum_val, False, False) + inv_scaled_momentum = scale( + arbitrary_position, momentum_val, inv=True, trans=False + ) + scaled_momentum = scale( + arbitrary_position, momentum_val, inv=False, trans=False + ) expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix) expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix) @@ -265,8 +277,12 @@ def test_gaussian_riemannian_dim_2(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False) - scaled_momentum = scale(arbitrary_position, momentum_val, False, False) + inv_scaled_momentum = scale( + arbitrary_position, momentum_val, inv=True, trans=False + ) + scaled_momentum = scale( + arbitrary_position, momentum_val, inv=False, trans=False + ) expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val expected_scaled_momentum = L_inv @ momentum_val