Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move pointwise log likelihood data to log_likelihood group (Pyro + NumPyro) #1044

Merged
merged 10 commits into from
Feb 6, 2020
Next Next commit
numpyro - Add more data to InferenceData objects
VincentBt committed Feb 4, 2020
commit 7119756de8d5a7d8131c191e87e5b176244a1420
29 changes: 14 additions & 15 deletions arviz/data/io_numpyro.py
Original file line number Diff line number Diff line change
@@ -103,24 +103,22 @@ def sample_stats_to_xarray(self):
data[name] = value
if stat == "num_steps":
data["depth"] = np.log2(value).astype(int) + 1

# extract log_likelihood
dims = None
if self.observations is not None and len(self.observations) == 1:
return dict_to_dataset(data, library=self.numpyro, dims=None, coords=self.coords)

@requires("posterior")
@requires("model")
def log_likelihood_to_xarray(self):
"""Extract log likelihood from NumPyro posterior."""
data = {}
if self.observations is not None:
samples = self.posterior.get_samples(group_by_chain=False)
log_likelihood = self.numpyro.infer.log_likelihood(
log_likelihood_dict = self.numpyro.infer.log_likelihood(
self.model, samples, *self._args, **self._kwargs
)
obs_name, log_likelihood = list(log_likelihood.items())[0]
if self.dims is not None:
coord_name = self.dims.get("log_likelihood", self.dims.get(obs_name))
else:
coord_name = None
shape = (self.nchains, self.ndraws) + log_likelihood.shape[1:]
data["log_likelihood"] = np.reshape(log_likelihood.copy(), shape)
dims = {"log_likelihood": coord_name}

return dict_to_dataset(data, library=self.numpyro, dims=dims, coords=self.coords)
for obs_name, log_like in log_likelihood_dict.items():
shape = (self.nchains, self.ndraws) + log_like.shape[1:]
data[obs_name] = np.reshape(log_like.copy(), shape)
return dict_to_dataset(data, library=self.numpyro, dims=self.dims, coords=self.coords)

@requires("posterior_predictive")
def posterior_predictive_to_xarray(self):
@@ -197,6 +195,7 @@ def to_inference_data(self):
**{
"posterior": self.posterior_to_xarray(),
"sample_stats": self.sample_stats_to_xarray(),
"log_likelihood": self.log_likelihood_to_xarray(),
"posterior_predictive": self.posterior_predictive_to_xarray(),
**self.priors_to_xarray(),
"observed_data": self.observed_data_to_xarray(),
38 changes: 35 additions & 3 deletions arviz/tests/test_data_numpyro.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
import pytest
from jax.random import PRNGKey
from numpyro.infer import Predictive
import jax.numpy
VincentBt marked this conversation as resolved.
Show resolved Hide resolved

from ..data.io_numpyro import from_numpyro
from .helpers import ( # pylint: disable=unused-import
@@ -25,10 +26,10 @@ class Data:
def get_inference_data(self, data, eight_schools_params):
posterior_samples = data.obj.get_samples()
model = data.obj.sampler.model
posterior_predictive = Predictive(model, posterior_samples).get_samples(
posterior_predictive = Predictive(model, posterior_samples).__call__(
VincentBt marked this conversation as resolved.
Show resolved Hide resolved
PRNGKey(1), eight_schools_params["J"], eight_schools_params["sigma"]
)
prior = Predictive(model, num_samples=500).get_samples(
prior = Predictive(model, num_samples=500).__call__(
VincentBt marked this conversation as resolved.
Show resolved Hide resolved
PRNGKey(2), eight_schools_params["J"], eight_schools_params["sigma"]
)
return from_numpyro(
@@ -43,11 +44,42 @@ def test_inference_data(self, data, eight_schools_params):
inference_data = self.get_inference_data(data, eight_schools_params)
test_dict = {
"posterior": ["mu", "tau", "eta"],
"sample_stats": ["diverging", "tree_size", "depth", "log_likelihood"],
"sample_stats": ["diverging", "tree_size", "depth"],
"log_likelihood": ["obs"],
"posterior_predictive": ["obs"],
"prior": ["mu", "tau", "eta"],
"prior_predictive": ["obs"],
"observed_data": ["obs"],
}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails

def test_multiple_observed_rv(self):
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

y1 = np.random.randn(10)
y2 = np.random.randn(100)
def model_example_multiple_obs(y1=None, y2=None):
x = numpyro.sample('x', dist.Normal(1, 3))
numpyro.sample('y1', dist.Normal(x, 1), obs=y1)
numpyro.sample('y2', dist.Normal(x, 1), obs=y2)
nuts_kernel = NUTS(model_example_multiple_obs)
mcmc = MCMC(nuts_kernel, num_warmup=20, num_samples=20)
mcmc.run(PRNGKey(0), y1=y1, y2=y2)
inference_data = from_numpyro(mcmc)
test_dict = {
"posterior": ["x"],
"sample_stats": ["diverging"],
"log_likelihood": ["y1", "y2"],
"observed_data": ["y1", "y2"],
}
fails = check_multiple_attrs(test_dict, inference_data)
waic_results = az.stats.waic(inference_data)
print(waic_results)
print(waic_results.keys())
print(waic_results.waic, waic_results.waic_se)
assert not fails
assert not hasattr(inference_data.sample_stats, "log_likelihood")