Move pointwise log likelihood data to log_likelihood group (Pyro + NumPyro) #1044

Merged: 10 commits merged on Feb 6, 2020
resolves pylint failures
VincentBt committed Feb 4, 2020
commit 08db69325af175ed50d0d6e10d6d078cd8dd65cb
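For orientation, this is the behavior the PR as a whole moves toward: pointwise log likelihood values end up in their own log_likelihood group instead of sample_stats. The sketch below is illustrative only, mirrors the multiple-observed-sites test further down, and is not part of this commit's diff; `idata` and the model name are placeholders.

```python
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from jax.random import PRNGKey
import arviz as az

y1 = np.random.randn(10)
y2 = np.random.randn(100)

def model_example_multiple_obs(y1=None, y2=None):
    x = numpyro.sample("x", dist.Normal(1, 3))
    numpyro.sample("y1", dist.Normal(x, 1), obs=y1)
    numpyro.sample("y2", dist.Normal(x, 1), obs=y2)

mcmc = MCMC(NUTS(model_example_multiple_obs), num_samples=10, num_warmup=2)
mcmc.run(PRNGKey(0), y1=y1, y2=y2)

idata = az.from_numpyro(mcmc)

# After this PR, the pointwise values are expected in their own group,
# one variable per observed site (y1, y2), and no longer in sample_stats.
assert not hasattr(idata.sample_stats, "log_likelihood")
print(idata.log_likelihood)
```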
2 changes: 1 addition & 1 deletion arviz/data/io_numpyro.py
@@ -104,7 +104,7 @@ def sample_stats_to_xarray(self):
if stat == "num_steps":
data["depth"] = np.log2(value).astype(int) + 1
return dict_to_dataset(data, library=self.numpyro, dims=None, coords=self.coords)

@requires("posterior")
@requires("model")
def log_likelihood_to_xarray(self):
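As a side note on the sample_stats context shown above (unchanged by this commit): the tree depth is recovered from the number of leapfrog steps because a full binary NUTS tree of depth d uses 2**d - 1 steps. A small standalone check of that formula:

```python
import numpy as np

num_steps = np.array([1, 3, 7, 15, 31])   # 2**d - 1 for depth d = 1..5
depth = np.log2(num_steps).astype(int) + 1
print(depth)                               # [1 2 3 4 5]
```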
9 changes: 4 additions & 5 deletions arviz/data/io_pyro.py
@@ -91,7 +91,6 @@ def sample_stats_to_xarray(self):
@requires("model")
def log_likelihood_to_xarray(self):
"""Extract log likelihood from Pyro posterior."""
data = {}
dims = None
if self.observations is not None:
try:
@@ -101,9 +100,9 @@ def log_likelihood_to_xarray(self):
data = {}
for obs_name in self.observations.keys():
obs_site = vectorized_trace.nodes[obs_name]
log_likelihood = obs_site["fn"].log_prob(obs_site["value"]).detach().cpu().numpy()
shape = (self.nchains, self.ndraws) + log_likelihood.shape[1:]
data[obs_name] = np.reshape(log_likelihood.copy(), shape)
log_like = obs_site["fn"].log_prob(obs_site["value"]).detach().cpu().numpy()
shape = (self.nchains, self.ndraws) + log_like.shape[1:]
data[obs_name] = np.reshape(log_like, shape)
except: # pylint: disable=bare-except
# cannot get vectorized trace
pass
@@ -123,7 +122,7 @@ def posterior_predictive_to_xarray(self):
else:
data[k] = utils.expand_dims(ary)
_log.warning(
"posterior predictive shape not compatible with number of chains and draws. "
"posterior predictive shape not compatible with number of chains and draws."
"This can mean that some draws or even whole chains are not represented."
)
return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=self.dims)
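The renamed `log_like` array above is reshaped so that chain and draw become explicit leading dimensions before it is stored per observed site. A standalone sketch of that reshaping with made-up shapes, standing in for the actual Pyro trace:

```python
import numpy as np

nchains, ndraws, nobs = 4, 500, 10

# Stand-in for obs_site["fn"].log_prob(obs_site["value"]) on a vectorized
# trace: one row per posterior draw, one column per observation.
log_like = np.random.randn(nchains * ndraws, nobs)

shape = (nchains, ndraws) + log_like.shape[1:]
log_like = np.reshape(log_like, shape)
assert log_like.shape == (nchains, ndraws, nobs)
```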
25 changes: 13 additions & 12 deletions arviz/tests/test_data_numpyro.py
@@ -3,7 +3,6 @@
import pytest
from jax.random import PRNGKey
from numpyro.infer import Predictive
import jax.numpy

from ..data.io_numpyro import from_numpyro
from .helpers import ( # pylint: disable=unused-import
@@ -26,10 +25,10 @@ class Data:
def get_inference_data(self, data, eight_schools_params):
posterior_samples = data.obj.get_samples()
model = data.obj.sampler.model
posterior_predictive = Predictive(model, posterior_samples).__call__(
posterior_predictive = Predictive(model, posterior_samples)(
PRNGKey(1), eight_schools_params["J"], eight_schools_params["sigma"]
)
prior = Predictive(model, num_samples=500).__call__(
prior = Predictive(model, num_samples=500)(
PRNGKey(2), eight_schools_params["J"], eight_schools_params["sigma"]
)
return from_numpyro(
@@ -58,12 +57,15 @@ def test_multiple_observed_rv(self):
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

y1 = np.random.randn(10)
y2 = np.random.randn(100)

def model_example_multiple_obs(y1=None, y2=None):
x = numpyro.sample('x', dist.Normal(1, 3))
numpyro.sample('y1', dist.Normal(x, 1), obs=y1)
numpyro.sample('y2', dist.Normal(x, 1), obs=y2)
x = numpyro.sample("x", dist.Normal(1, 3))
numpyro.sample("y1", dist.Normal(x, 1), obs=y1)
numpyro.sample("y2", dist.Normal(x, 1), obs=y2)

nuts_kernel = NUTS(model_example_multiple_obs)
mcmc = MCMC(nuts_kernel, num_samples=10, num_warmup=2)
mcmc.run(PRNGKey(0), y1=y1, y2=y2)
@@ -75,11 +77,10 @@ def model_example_multiple_obs(y1=None, y2=None):
"observed_data": ["y1", "y2"],
}
fails = check_multiple_attrs(test_dict, inference_data)
# from ..stats import waic
# waic_results = waic(inference_data)
# print(waic_results)
# print(waic_results.keys())
# print(waic_results.waic, waic_results.waic_se)
assert not fails
assert not hasattr(inference_data.sample_stats, "log_likelihood")
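The commented-out lines above hint at the downstream use: once the pointwise values live in the log_likelihood group, information criteria can be computed from the InferenceData directly. A hedged sketch, not part of the test; `var_name` selects one of the two observed sites and is assumed to be available in the ArviZ version in use:

```python
import arviz as az

idata = az.from_numpyro(mcmc)            # mcmc as in the test above
waic_y1 = az.waic(idata, var_name="y1")  # var_name picks one observed site
print(waic_y1)
```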

15 changes: 8 additions & 7 deletions arviz/tests/test_data_pyro.py
@@ -38,7 +38,7 @@ def get_inference_data(self, data, eight_schools_params):
prior=prior,
posterior_predictive=posterior_predictive,
coords={"school": np.arange(eight_schools_params["J"])},
dims={"theta": ["school"], "eta": ["school"]},
dims={"theta": ["school"], "eta": ["school"]}
)

def test_inference_data(self, data, eight_schools_params):
@@ -99,16 +99,17 @@ def test_inference_data_only_posterior_has_log_likelihood(self, data):
assert not fails

def test_multiple_observed_rv(self):
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
import torch

y1 = torch.randn(10)
y2 = torch.randn(10)

def model_example_multiple_obs(y1=None, y2=None):
x = pyro.sample('x', dist.Normal(1, 3))
pyro.sample('y1', dist.Normal(x, 1), obs=y1)
pyro.sample('y2', dist.Normal(x, 1), obs=y2)
x = pyro.sample("x", dist.Normal(1, 3))
pyro.sample("y1", dist.Normal(x, 1), obs=y1)
pyro.sample("y2", dist.Normal(x, 1), obs=y2)

nuts_kernel = NUTS(model_example_multiple_obs)
mcmc = MCMC(nuts_kernel, num_samples=10)
mcmc.run(y1=y1, y2=y2)
@@ -121,4 +122,4 @@ def model_example_multiple_obs(y1=None, y2=None):
}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails
assert not hasattr(inference_data.sample_stats, "log_likelihood")