Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce JAX post-processing memory usage #7311

Merged
merged 8 commits into from
Jul 11, 2024
61 changes: 18 additions & 43 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,11 @@ def _postprocess_samples(
)
return [jnp.swapaxes(t, 0, 1) for t in outs]
elif postprocessing_vectorize == "vmap":
return jax.vmap(jax.vmap(jax_fn))(*_device_put(raw_mcmc_samples, postprocessing_backend))

def process_fn(x):
return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))

return jax.jit(process_fn, donate_argnums=0)(raw_mcmc_samples)
else:
raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")

Expand Down Expand Up @@ -253,7 +257,16 @@ def _blackjax_inference_loop(
def _one_step(state, xs):
_, rng_key = xs
state, info = kernel(rng_key, state)
return state, (state, info)
position = state.position
stats = {
"diverging": info.is_divergent,
"energy": info.energy,
"tree_depth": info.num_trajectory_expansions,
"n_steps": info.num_integration_steps,
"acceptance_rate": info.acceptance_rate,
"lp": state.logdensity,
}
return state, (position, stats)

progress_bar = adaptation_kwargs.pop("progress_bar", False)
if progress_bar:
Expand All @@ -264,43 +277,9 @@ def _one_step(state, xs):
one_step = jax.jit(_one_step)

keys = jax.random.split(seed, draws)
_, (states, infos) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys))

return states, infos

_, (samples, stats) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys))

def _blackjax_stats_to_dict(sample_stats, potential_energy) -> dict:
"""Extract compatible stats from blackjax NUTS sampler
with PyMC/Arviz naming conventions.

Parameters
----------
sample_stats: NUTSInfo
Blackjax NUTSInfo object containing sampler statistics
potential_energy: ArrayLike
Potential energy values of sampled positions.

Returns
-------
Dict[str, ArrayLike]
Dictionary of sampler statistics.
"""
rename_key = {
"is_divergent": "diverging",
"energy": "energy",
"num_trajectory_expansions": "tree_depth",
"num_integration_steps": "n_steps",
"acceptance_rate": "acceptance_rate", # naming here is
"acceptance_probability": "acceptance_rate", # depending on blackjax version
}
converted_stats = {}
converted_stats["lp"] = potential_energy
for old_name, new_name in rename_key.items():
value = getattr(sample_stats, old_name, None)
if value is None:
continue
converted_stats[new_name] = value
return converted_stats
return samples, stats


def _sample_blackjax_nuts(
Expand Down Expand Up @@ -410,11 +389,7 @@ def _sample_blackjax_nuts(
**nuts_kwargs,
)

states, stats = map_fn(get_posterior_samples)(keys, initial_points)
raw_mcmc_samples = states.position
potential_energy = states.logdensity.block_until_ready()
sample_stats = _blackjax_stats_to_dict(stats, potential_energy)

raw_mcmc_samples, sample_stats = map_fn(get_posterior_samples)(keys, initial_points)
return raw_mcmc_samples, sample_stats, blackjax


Expand Down
3 changes: 3 additions & 0 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def _sample_external_nuts(
var_names: Sequence[str] | None,
progressbar: bool,
idata_kwargs: dict | None,
compute_convergence_checks: bool,
nuts_sampler_kwargs: dict | None,
**kwargs,
):
Expand Down Expand Up @@ -360,6 +361,7 @@ def _sample_external_nuts(
progressbar=progressbar,
nuts_sampler=sampler,
idata_kwargs=idata_kwargs,
compute_convergence_checks=compute_convergence_checks,
**nuts_sampler_kwargs,
)
return idata
Expand Down Expand Up @@ -697,6 +699,7 @@ def sample(
var_names=var_names,
progressbar=progressbar,
idata_kwargs=idata_kwargs,
compute_convergence_checks=compute_convergence_checks,
nuts_sampler_kwargs=nuts_sampler_kwargs,
**kwargs,
)
Expand Down
Loading