Skip to content

Commit

Permalink
Remove transform from MCLMC (#623)
Browse files Browse the repository at this point in the history
* ADD TRANSFORM

* ADD TRANSFORM

* ADD DOCSTRING AND TEST

* REMOVE TRANSFORMED_X FROM INFO

* REMOVE TRANSFORMED_X FROM INFO
  • Loading branch information
reubenharry authored Dec 10, 2023
1 parent 4ae2faf commit e05b5fe
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 29 deletions.
18 changes: 13 additions & 5 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,14 +259,22 @@ def adaptation_L(state, params, num_steps, key):
adaptation_L_keys = jax.random.split(key, num_steps)

# run kernel in the normal way
state, info = jax.lax.scan(
f=lambda s, k: (
kernel(rng_key=k, state=s, L=params.L, step_size=params.step_size)
),
def step(state, key):
next_state, _ = kernel(
rng_key=key,
state=state,
L=params.L,
step_size=params.step_size,
)

return next_state, next_state.position

state, samples = jax.lax.scan(
f=step,
init=state,
xs=adaptation_L_keys,
)
samples = info.transformed_position # transform is the identity here

flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples)
ess = effective_sample_size(flat_samples[None, ...])

Expand Down
16 changes: 3 additions & 13 deletions blackjax/mcmc/mclmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan
from blackjax.types import Array, ArrayLike, PRNGKey
from blackjax.types import ArrayLike, PRNGKey
from blackjax.util import generate_unit_vector, pytree_size

__all__ = ["MCLMCInfo", "init", "build_kernel", "mclmc"]
Expand All @@ -31,8 +31,6 @@ class MCLMCInfo(NamedTuple):
"""
Additional information on the MCLMC transition.
transformed_position
The value of the samples after a transformation. This is typically a projection onto a lower dimensional subspace.
logdensity
The log-density of the distribution at the current step of the MCLMC chain.
kinetic_change
Expand All @@ -41,7 +39,6 @@ class MCLMCInfo(NamedTuple):
The difference in energy between the current and previous step.
"""

transformed_position: Array
logdensity: float
kinetic_change: float
energy_change: float
Expand All @@ -58,15 +55,13 @@ def init(x_initial: ArrayLike, logdensity_fn, rng_key):
)


def build_kernel(logdensity_fn, integrator, transform):
def build_kernel(logdensity_fn, integrator):
"""Build a HMC kernel.
Parameters
----------
integrator
The symplectic integrator to use to integrate the Hamiltonian dynamics.
transform
Value of the difference in energy above which we consider that the transition is divergent.
L
the momentum decoherence rate.
step_size
Expand Down Expand Up @@ -98,7 +93,6 @@ def kernel(
return IntegratorState(
position, momentum, logdensity, logdensitygrad
), MCLMCInfo(
transformed_position=transform(position),
logdensity=logdensity,
energy_change=kinetic_change - logdensity + state.logdensity,
kinetic_change=kinetic_change * (dim - 1),
Expand All @@ -125,7 +119,6 @@ class mclmc:
mclmc = blackjax.mcmc.mclmc.mclmc(
logdensity_fn=logdensity_fn,
transform=lambda x: x,
L=L,
step_size=step_size
)
Expand All @@ -143,8 +136,6 @@ class mclmc:
----------
logdensity_fn
The log-density function we wish to draw samples from.
transform
A function to perform on the samples drawn from the target distribution
L
the momentum decoherence rate
step_size
Expand All @@ -165,11 +156,10 @@ def __new__( # type: ignore[misc]
logdensity_fn: Callable,
L,
step_size,
transform: Callable = (lambda x: x),
integrator=noneuclidean_mclachlan,
seed=1,
) -> SamplingAlgorithm:
kernel = cls.build_kernel(logdensity_fn, integrator, transform)
kernel = cls.build_kernel(logdensity_fn, integrator)

def update_fn(rng_key, state):
return kernel(rng_key, state, L, step_size)
Expand Down
18 changes: 7 additions & 11 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key):
kernel = blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=logdensity_fn,
integrator=blackjax.mcmc.integrators.noneuclidean_mclachlan,
transform=lambda x: x,
)

(
Expand All @@ -97,24 +96,21 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key):
rng_key=tune_key,
)

keys = jax.random.split(run_key, num_steps)

sampling_alg = blackjax.mclmc(
logdensity_fn,
L=blackjax_mclmc_sampler_params.L,
step_size=blackjax_mclmc_sampler_params.step_size,
)

_, blackjax_mclmc_result = jax.lax.scan(
f=lambda state, k: sampling_alg.step(
rng_key=k,
state=state,
),
xs=keys,
init=blackjax_state_after_tuning,
_, samples, _ = run_inference_algorithm(
rng_key=run_key,
initial_state_or_position=blackjax_state_after_tuning,
inference_algorithm=sampling_alg,
num_steps=num_steps,
transform=lambda x: x.position,
)

return blackjax_mclmc_result.transformed_position
return samples

@parameterized.parameters(itertools.product(regression_test_cases, [True, False]))
def test_window_adaptation(self, case, is_mass_matrix_diagonal):
Expand Down

0 comments on commit e05b5fe

Please sign in to comment.