Skip to content

Commit

Permalink
record blackjax sample stats
Browse files Browse the repository at this point in the history
  • Loading branch information
TimOliverMaier committed Nov 8, 2022
1 parent 1dd85c1 commit 32f3da3
Showing 1 changed file with 38 additions and 4 deletions.
42 changes: 38 additions & 4 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,40 @@ def _sample_stats_to_xarray(posterior):
return data


def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
"""Extract compatible stats from blackjax NUTS sampler
with PyMC/Arviz naming conventions.
Parameters
----------
sample_stats: NUTSInfo
Blackjax NUTSInfo object containing sampler statistics
potential_energy: ArrayLike
Potential energy values of sampled positions.
Returns
-------
Dict[str, ArrayLike]
Dictionary of sampler statistics.
"""
rename_key = {
"is_divergent": "diverging",
"energy": "energy",
"num_trajectory_expansions": "tree_depth",
"num_integration_steps": "n_steps",
"acceptance_rate": "acceptance_rate", # naming here is
"acceptance_probability": "acceptance_rate", # depending on blackjax version
}
converted_stats = {}
converted_stats["lp"] = potential_energy
for old_name, new_name in rename_key.items():
value = getattr(sample_stats, old_name, None)
if value is None:
continue
converted_stats[new_name] = value
return converted_stats


def _get_log_likelihood(model: Model, samples, backend=None) -> Dict:
"""Compute log-likelihood for all observations"""
elemwise_logp = model.logp(model.observed_RVs, sum=False)
Expand Down Expand Up @@ -360,9 +394,9 @@ def sample_blackjax_nuts(
"Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
)

states, _ = map_fn(get_posterior_samples)(keys, init_params)
states, stats = map_fn(get_posterior_samples)(keys, init_params)
raw_mcmc_samples = states.position

potential_energy = states.potential_energy
tic3 = datetime.now()
print("Sampling time = ", tic3 - tic2, file=sys.stdout)

Expand All @@ -372,7 +406,7 @@ def sample_blackjax_nuts(
*jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
)
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

mcmc_stats = _blackjax_stats_to_dict(stats, potential_energy)
tic4 = datetime.now()
print("Transformation time = ", tic4 - tic3, file=sys.stdout)

Expand Down Expand Up @@ -410,7 +444,7 @@ def sample_blackjax_nuts(
dims=dims,
attrs=make_attrs(attrs, library=blackjax),
)
az_trace = to_trace(posterior=posterior, **idata_kwargs)
az_trace = to_trace(posterior=posterior, sample_stats=mcmc_stats, **idata_kwargs)

return az_trace

Expand Down

0 comments on commit 32f3da3

Please sign in to comment.