Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sample stats for blackjax nuts #6264

Merged
merged 5 commits into from
Nov 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,40 @@ def _sample_stats_to_xarray(posterior):
return data


def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
"""Extract compatible stats from blackjax NUTS sampler
with PyMC/Arviz naming conventions.

Parameters
----------
sample_stats: NUTSInfo
Blackjax NUTSInfo object containing sampler statistics
potential_energy: ArrayLike
Potential energy values of sampled positions.

Returns
-------
Dict[str, ArrayLike]
Dictionary of sampler statistics.
"""
rename_key = {
"is_divergent": "diverging",
"energy": "energy",
"num_trajectory_expansions": "tree_depth",
"num_integration_steps": "n_steps",
"acceptance_rate": "acceptance_rate", # naming here is
"acceptance_probability": "acceptance_rate", # depending on blackjax version
}
converted_stats = {}
converted_stats["lp"] = potential_energy
for old_name, new_name in rename_key.items():
value = getattr(sample_stats, old_name, None)
if value is None:
continue
converted_stats[new_name] = value
return converted_stats


def _get_log_likelihood(model: Model, samples, backend=None) -> Dict:
"""Compute log-likelihood for all observations"""
elemwise_logp = model.logp(model.observed_RVs, sum=False)
Expand Down Expand Up @@ -360,9 +394,9 @@ def sample_blackjax_nuts(
"Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
)

states, _ = map_fn(get_posterior_samples)(keys, init_params)
states, stats = map_fn(get_posterior_samples)(keys, init_params)
raw_mcmc_samples = states.position

potential_energy = states.potential_energy
tic3 = datetime.now()
print("Sampling time = ", tic3 - tic2, file=sys.stdout)

Expand All @@ -372,7 +406,7 @@ def sample_blackjax_nuts(
*jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
)
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

mcmc_stats = _blackjax_stats_to_dict(stats, potential_energy)
tic4 = datetime.now()
print("Transformation time = ", tic4 - tic3, file=sys.stdout)

Expand Down Expand Up @@ -406,6 +440,7 @@ def sample_blackjax_nuts(
log_likelihood=log_likelihood,
observed_data=find_observations(model),
constant_data=find_constants(model),
sample_stats=mcmc_stats,
coords=coords,
dims=dims,
attrs=make_attrs(attrs, library=blackjax),
Expand Down
49 changes: 49 additions & 0 deletions pymc/tests/sampling/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,3 +365,52 @@ def test_numpyro_nuts_kwargs_are_used(mocked: mock.MagicMock):
assert nuts_sampler._adapt_step_size == adapt_step_size
assert nuts_sampler._adapt_mass_matrix
assert nuts_sampler._target_accept_prob == target_accept


@pytest.mark.parametrize(
    "sampler_name",
    [
        "sample_blackjax_nuts",
        "sample_numpyro_nuts",
    ],
)
def test_idata_contains_stats(sampler_name: str):
    """Check that the sampler writes its statistics to the ``sample_stats``
    group of the returned InferenceData, with one value per chain and draw."""
    # Stats expected from both samplers.
    stat_names = ["acceptance_rate", "diverging", "energy", "tree_depth", "lp"]

    if sampler_name == "sample_blackjax_nuts":
        sampler = sample_blackjax_nuts
        # No stats are expected only for blackjax nuts.
    elif sampler_name == "sample_numpyro_nuts":
        sampler = sample_numpyro_nuts
        # Stats expected only for numpyro nuts.
        stat_names += ["step_size", "n_steps"]

    with pm.Model():
        pm.Normal("a")
        idata = sampler(tune=50, draws=50)

    stats = idata.get("sample_stats")
    assert stats is not None

    # Every stat is expected to have exactly this (chain, draw) shape.
    expected_shape = (stats.dims["chain"], stats.dims["draw"])

    # Test existence and dimensionality.
    for stat_name in stat_names:
        assert stat_name in stats.variables
        assert stats.get(stat_name).values.shape == expected_shape