From e05b5fe30586326dc2f6dcc1f57fcb318d569e4c Mon Sep 17 00:00:00 2001 From: Reuben Date: Sun, 10 Dec 2023 13:25:45 -0500 Subject: [PATCH] Remove transform from MCLMC (#623) * ADD TRANSFORM * ADD TRANSFORM * ADD DOCSTRING AND TEST * REMOVE TRANSFORMED_X FROM INFO * REMOVE TRANSFORMED_X FROM INFO --- blackjax/adaptation/mclmc_adaptation.py | 18 +++++++++++++----- blackjax/mcmc/mclmc.py | 16 +++------------- tests/mcmc/test_sampling.py | 18 +++++++----------- 3 files changed, 23 insertions(+), 29 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 5b1d4a4ed..4e1f4ee75 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -259,14 +259,22 @@ def adaptation_L(state, params, num_steps, key): adaptation_L_keys = jax.random.split(key, num_steps) # run kernel in the normal way - state, info = jax.lax.scan( - f=lambda s, k: ( - kernel(rng_key=k, state=s, L=params.L, step_size=params.step_size) - ), + def step(state, key): + next_state, _ = kernel( + rng_key=key, + state=state, + L=params.L, + step_size=params.step_size, + ) + + return next_state, next_state.position + + state, samples = jax.lax.scan( + f=step, init=state, xs=adaptation_L_keys, ) - samples = info.transformed_position # tranform is the identity here + flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) ess = effective_sample_size(flat_samples[None, ...]) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index a48d9bfeb..b65e5b6e4 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -21,7 +21,7 @@ from blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan -from blackjax.types import Array, ArrayLike, PRNGKey +from blackjax.types import ArrayLike, PRNGKey from blackjax.util import generate_unit_vector, pytree_size __all__ = ["MCLMCInfo", "init", "build_kernel", "mclmc"] @@ -31,8 +31,6 @@ class MCLMCInfo(NamedTuple): """ Additional information on the MCLMC transition. - transformed_position - The value of the samples after a transformation. This is typically a projection onto a lower dimensional subspace. logdensity The log-density of the distribution at the current step of the MCLMC chain. kinetic_change @@ -41,7 +39,6 @@ class MCLMCInfo(NamedTuple): The difference in energy between the current and previous step. """ - transformed_position: Array logdensity: float kinetic_change: float energy_change: float @@ -58,15 +55,13 @@ def init(x_initial: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, integrator, transform): +def build_kernel(logdensity_fn, integrator): """Build a HMC kernel. Parameters ---------- integrator The symplectic integrator to use to integrate the Hamiltonian dynamics. - transform - Value of the difference in energy above which we consider that the transition is divergent. L the momentum decoherence rate. step_size @@ -98,7 +93,6 @@ def kernel( return IntegratorState( position, momentum, logdensity, logdensitygrad ), MCLMCInfo( - transformed_position=transform(position), logdensity=logdensity, energy_change=kinetic_change - logdensity + state.logdensity, kinetic_change=kinetic_change * (dim - 1), @@ -125,7 +119,6 @@ class mclmc: mclmc = blackjax.mcmc.mclmc.mclmc( logdensity_fn=logdensity_fn, - transform=lambda x: x, L=L, step_size=step_size ) @@ -143,8 +136,6 @@ class mclmc: ---------- logdensity_fn The log-density function we wish to draw samples from. - transform - A function to perform on the samples drawn from the target distribution L the momentum decoherence rate step_size @@ -165,11 +156,10 @@ def __new__( # type: ignore[misc] logdensity_fn: Callable, L, step_size, - transform: Callable = (lambda x: x), integrator=noneuclidean_mclachlan, seed=1, ) -> SamplingAlgorithm: - kernel = cls.build_kernel(logdensity_fn, integrator, transform) + kernel = cls.build_kernel(logdensity_fn, integrator) def update_fn(rng_key, state): return kernel(rng_key, state, L, step_size) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index ec47f1180..2a9fd07c5 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -84,7 +84,6 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): kernel = blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=blackjax.mcmc.integrators.noneuclidean_mclachlan, - transform=lambda x: x, ) ( @@ -97,24 +96,21 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): rng_key=tune_key, ) - keys = jax.random.split(run_key, num_steps) - sampling_alg = blackjax.mclmc( logdensity_fn, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, ) - _, blackjax_mclmc_result = jax.lax.scan( - f=lambda state, k: sampling_alg.step( - rng_key=k, - state=state, - ), - xs=keys, - init=blackjax_state_after_tuning, + _, samples, _ = run_inference_algorithm( + rng_key=run_key, + initial_state_or_position=blackjax_state_after_tuning, + inference_algorithm=sampling_alg, + num_steps=num_steps, + transform=lambda x: x.position, ) - return blackjax_mclmc_result.transformed_position + return samples @parameterized.parameters(itertools.product(regression_test_cases, [True, False])) def test_window_adaptation(self, case, is_mass_matrix_diagonal):