From 32f3da34a9a8a797342641216f078b9cbab2d867 Mon Sep 17 00:00:00 2001 From: "Tim Maier (Ubuntu Desktop)" Date: Tue, 8 Nov 2022 09:11:35 +0100 Subject: [PATCH] record blackjax sample stats --- pymc/sampling/jax.py | 42 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index c47f1da2f87..6f8de6262e0 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -142,6 +142,40 @@ def _sample_stats_to_xarray(posterior): return data +def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict: + """Extract compatible stats from blackjax NUTS sampler + with PyMC/Arviz naming conventions. + + Parameters + ---------- + sample_stats: NUTSInfo + Blackjax NUTSInfo object containing sampler statistics + potential_energy: ArrayLike + Potential energy values of sampled positions. + + Returns + ------- + Dict[str, ArrayLike] + Dictionary of sampler statistics. + """ + rename_key = { + "is_divergent": "diverging", + "energy": "energy", + "num_trajectory_expansions": "tree_depth", + "num_integration_steps": "n_steps", + "acceptance_rate": "acceptance_rate", # naming here is + "acceptance_probability": "acceptance_rate", # depending on blackjax version + } + converted_stats = {} + converted_stats["lp"] = potential_energy + for old_name, new_name in rename_key.items(): + value = getattr(sample_stats, old_name, None) + if value is None: + continue + converted_stats[new_name] = value + return converted_stats + + def _get_log_likelihood(model: Model, samples, backend=None) -> Dict: """Compute log-likelihood for all observations""" elemwise_logp = model.logp(model.observed_RVs, sum=False) @@ -360,9 +394,9 @@ def sample_blackjax_nuts( "Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"' ) - states, _ = map_fn(get_posterior_samples)(keys, init_params) + states, stats = map_fn(get_posterior_samples)(keys, init_params) raw_mcmc_samples = states.position - + potential_energy = states.potential_energy tic3 = datetime.now() print("Sampling time = ", tic3 - tic2, file=sys.stdout) @@ -372,7 +406,7 @@ def sample_blackjax_nuts( *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0]) ) mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)} - + mcmc_stats = _blackjax_stats_to_dict(stats, potential_energy) tic4 = datetime.now() print("Transformation time = ", tic4 - tic3, file=sys.stdout) @@ -410,7 +444,7 @@ def sample_blackjax_nuts( dims=dims, attrs=make_attrs(attrs, library=blackjax), ) - az_trace = to_trace(posterior=posterior, **idata_kwargs) + az_trace = to_trace(posterior=posterior, sample_stats=mcmc_stats, **idata_kwargs) return az_trace