Skip to content

Commit

Permalink
Adding log_likelihood, observed_data, and sample_stats to numpyro sam…
Browse files Browse the repository at this point in the history
…pler (#5189)

* Adding observed_data and sample_stats to numpyro sampler
* Refactor find_observations
* Add log likelihoods to trace object

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
  • Loading branch information
zaxtax and ricardoV94 authored Nov 18, 2021
1 parent c22859d commit fe2d101
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 23 deletions.
40 changes: 21 additions & 19 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,26 @@
Var = Any # pylint: disable=invalid-name


def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]:
    """Collect the observed data attached to a model's observed variables.

    Returns ``None`` when ``model`` is ``None``. Otherwise returns a dictionary
    keyed by variable name. A variable is left out, and a warning is emitted,
    when it carries no ``tag.observations`` or when extracting its data raises
    ``TypeError``.
    """
    if model is None:
        return None

    found = {}
    for rv in model.observed_RVs:
        raw = getattr(rv.tag, "observations", None)
        if raw is None:
            warnings.warn(f"No data for observation {rv}")
            continue
        try:
            found[rv.name] = extract_obs_data(raw)
        except TypeError:
            warnings.warn(f"Could not extract data from symbolic observation {rv}")

    return found


class _DefaultTrace:
"""
Utility for collecting samples into a dictionary.
Expand Down Expand Up @@ -196,25 +216,7 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
self.dims = {**model_dims, **self.dims}

self.density_dist_obs = density_dist_obs
self.observations = self.find_observations()

def find_observations(self) -> Optional[Dict[str, Var]]:
    """Return the observed data of ``self.model`` keyed by variable name.

    Gives ``None`` when no model is set. Observed variables without
    ``tag.observations``, or whose data extraction raises ``TypeError``, are
    omitted from the result and reported through a warning.
    """
    if self.model is None:
        return None

    collected = {}
    for observed_rv in self.model.observed_RVs:
        data_source = getattr(observed_rv.tag, "observations", None)
        if data_source is None:
            warnings.warn(f"No data for observation {observed_rv}")
        else:
            try:
                collected[observed_rv.name] = extract_obs_data(data_source)
            except TypeError:
                warnings.warn(
                    f"Could not extract data from symbolic observation {observed_rv}"
                )
    return collected
self.observations = find_observations(self.model)

def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
"""Split MultiTrace object into posterior and warmup.
Expand Down
62 changes: 58 additions & 4 deletions pymc/sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
from aesara.link.jax.dispatch import jax_funcify

from pymc import Model, modelcontext
from pymc.aesaraf import compile_rv_inplace, inputvars
from pymc.aesaraf import compile_rv_inplace
from pymc.backends.arviz import find_observations
from pymc.distributions import logpt
from pymc.util import get_default_varnames

warnings.warn("This module is experimental.")
Expand Down Expand Up @@ -95,6 +97,39 @@ def logp_fn_wrap(x):
return logp_fn_wrap


# Adapted from the ArviZ NumPyro extractor
def _sample_stats_to_xarray(posterior):
    """Gather the extra fields of a NumPyro run under ArviZ-style names.

    ``posterior`` must provide ``get_extra_fields(group_by_chain=True)``.
    Fields whose value is a ``dict`` or ``tuple`` are skipped; every other
    value is copied into the returned dictionary, renamed where a mapping
    exists. When ``num_steps`` is present, a ``tree_depth`` entry computed
    from it is added as well.
    """
    arviz_names = {
        "potential_energy": "lp",
        "adapt_state.step_size": "step_size",
        "num_steps": "n_steps",
        "accept_prob": "acceptance_rate",
    }
    extra_fields = posterior.get_extra_fields(group_by_chain=True)

    stats = {}
    for field, raw in extra_fields.items():
        if isinstance(raw, (dict, tuple)):
            continue
        copied = raw.copy()
        stats[arviz_names.get(field, field)] = copied
        if field == "num_steps":
            stats["tree_depth"] = np.log2(copied).astype(int) + 1
    return stats


def _get_log_likelihood(model, samples):
    """Compute the log-likelihood of every observed variable in ``model``.

    For each observed variable, its ``logpt`` graph (after
    ``replace_shared_variables``) is turned into a JAX function of
    ``model.value_vars`` and evaluated on ``samples``, which are unpacked as
    positional arguments. Returns a dictionary keyed by variable name.
    """

    def _logp_over_samples(observed_rv):
        # Build the log-probability graph for this variable alone.
        logp_outputs = replace_shared_variables([logpt(observed_rv)])
        logp_graph = FunctionGraph(model.value_vars, logp_outputs, clone=False)
        # Nested vmap maps the function over the two leading axes of each
        # sample array (presumably chain and draw; confirm against caller).
        batched_logp = jax.vmap(jax.vmap(jax_funcify(logp_graph)))
        # The graph has a single output, so take the first element.
        return batched_logp(*samples)[0]

    return {
        observed_rv.name: _logp_over_samples(observed_rv)
        for observed_rv in model.observed_RVs
    }


def sample_numpyro_nuts(
draws=1000,
tune=1000,
Expand Down Expand Up @@ -151,9 +186,23 @@ def sample_numpyro_nuts(
map_seed = jax.random.split(seed, chains)

if chains == 1:
pmap_numpyro.run(seed, init_params=init_state, extra_fields=("num_steps",))
init_params = init_state
map_seed = seed
else:
pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",))
init_params = init_state_batched

pmap_numpyro.run(
map_seed,
init_params=init_params,
extra_fields=(
"num_steps",
"potential_energy",
"energy",
"adapt_state.step_size",
"accept_prob",
"diverging",
),
)

raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

Expand All @@ -172,6 +221,11 @@ def sample_numpyro_nuts(
print("Transformation time = ", tic4 - tic3, file=sys.stdout)

posterior = mcmc_samples
az_trace = az.from_dict(posterior=posterior)
az_posterior = az.from_dict(posterior=posterior)

az_obs = az.from_dict(observed_data=find_observations(model))
az_stats = az.from_dict(sample_stats=_sample_stats_to_xarray(pmap_numpyro))
az_ll = az.from_dict(log_likelihood=_get_log_likelihood(model, raw_mcmc_samples))
az_trace = az.concat(az_posterior, az_ll, az_obs, az_stats)

return az_trace
19 changes: 19 additions & 0 deletions pymc/tests/test_sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pymc as pm

from pymc.sampling_jax import (
_get_log_likelihood,
get_jaxified_logp,
replace_shared_variables,
sample_numpyro_nuts,
Expand Down Expand Up @@ -61,6 +62,24 @@ def test_deterministic_samples():
assert np.allclose(trace.posterior["b"].values, trace.posterior["a"].values / 2)


def test_get_log_likelihood():
    """Check ``_get_log_likelihood`` against the values stored by ``pm.sample``."""
    observed = np.random.normal(10, 2, size=100)
    shared_obs = aesara.shared(observed, borrow=True, name="obs")
    with pm.Model() as model:
        mu = pm.Normal("a", 0, 2)
        scale = pm.HalfNormal("sigma")
        pm.Normal("b", mu, sigma=scale, observed=shared_obs)

        trace = pm.sample(tune=10, draws=10, chains=2, random_seed=1322)

    expected = trace.log_likelihood.b.values

    # The second input is passed on the log scale, hence the np.log of sigma.
    a_draws = np.array(trace.posterior.a)
    log_sigma_draws = np.log(np.array(trace.posterior.sigma))
    computed = _get_log_likelihood(model, [a_draws, log_sigma_draws])["b"]

    assert np.allclose(computed.reshape(-1), expected.reshape(-1))


def test_replace_shared_variables():
x = aesara.shared(5, name="shared_x")

Expand Down

0 comments on commit fe2d101

Please sign in to comment.