Move pointwise log likelihood data to log_likelihood group (Pyro + NumPyro) #1044

Merged · 10 commits · Feb 6, 2020
4 changes: 2 additions & 2 deletions .travis.yml
@@ -11,9 +11,9 @@ matrix:
     - name: "Python 3.6 Unit Test"
       python: 3.6
       env: PYTHON_VERSION=3.6 PYSTAN_VERSION=latest PYRO_VERSION=latest EMCEE_VERSION=latest COVERALLS_PARALLEL=true NAME="UNIT"
-    - name: "Python 3.6 Unit Test - PyStan=3 Pyro=0.5.1 Emcee=2 TF=1"
+    - name: "Python 3.6 Unit Test - PyStan=3 Pyro=1.0.0 Emcee=2 TF=1"
       python: 3.6
-      env: PYTHON_VERSION=3.6 PYSTAN_VERSION=preview PYRO_VERSION=0.5.1 PYTORCH_VERSION=1.3.0 EMCEE_VERSION=2 TF_VERSION=1 COVERALLS_PARALLEL=true NAME="UNIT"
+      env: PYTHON_VERSION=3.6 PYSTAN_VERSION=preview PYRO_VERSION=1.0.0 PYTORCH_VERSION=1.3.0 EMCEE_VERSION=2 TF_VERSION=1 COVERALLS_PARALLEL=true NAME="UNIT"
     - name: "Python 3.5 Unit Test"
       python: 3.5
       env: PYTHON_VERSION=3.5 PYSTAN_VERSION=latest PYRO_VERSION=latest EMCEE_VERSION=latest PYMC3_VERSION=3.8 COVERALLS_PARALLEL=true NAME="UNIT"
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -4,7 +4,7 @@
 
 ### New features
 * Add out-of-sample predictions (`predictions` and `predictions_constant_data` groups) to pymc3 and pystan translations (#983 and #1032)
-* Started adding pointwise log likelihood storage support (#794)
+* Started adding pointwise log likelihood storage support (#794, #1044)
 * Violinplot: rug-plot option (#997)
 * Integrated rcParams `plot.point_estimate` (#994), `stats.ic_scale` (#993) and `stats.credible_interval` (#1017)
 * Added `group` argument to `plot_ppc` (#1008), `plot_pair` (#1009) and `plot_joint` (#1012)
27 changes: 13 additions & 14 deletions arviz/data/io_numpyro.py
@@ -103,24 +103,22 @@ def sample_stats_to_xarray(self):
             data[name] = value
             if stat == "num_steps":
                 data["depth"] = np.log2(value).astype(int) + 1
+        return dict_to_dataset(data, library=self.numpyro, dims=None, coords=self.coords)
 
-        # extract log_likelihood
-        dims = None
-        if self.observations is not None and len(self.observations) == 1:
+    @requires("posterior")
+    @requires("model")
+    def log_likelihood_to_xarray(self):
+        """Extract log likelihood from NumPyro posterior."""
+        data = {}
+        if self.observations is not None:
             samples = self.posterior.get_samples(group_by_chain=False)
-            log_likelihood = self.numpyro.infer.log_likelihood(
+            log_likelihood_dict = self.numpyro.infer.log_likelihood(
                 self.model, samples, *self._args, **self._kwargs
             )
-            obs_name, log_likelihood = list(log_likelihood.items())[0]
-            if self.dims is not None:
-                coord_name = self.dims.get("log_likelihood", self.dims.get(obs_name))
-            else:
-                coord_name = None
-            shape = (self.nchains, self.ndraws) + log_likelihood.shape[1:]
-            data["log_likelihood"] = np.reshape(log_likelihood.copy(), shape)
-            dims = {"log_likelihood": coord_name}
-
-        return dict_to_dataset(data, library=self.numpyro, dims=dims, coords=self.coords)
+            for obs_name, log_like in log_likelihood_dict.items():
+                shape = (self.nchains, self.ndraws) + log_like.shape[1:]
+                data[obs_name] = np.reshape(log_like.copy(), shape)
+        return dict_to_dataset(data, library=self.numpyro, dims=self.dims, coords=self.coords)
 
     @requires("posterior_predictive")
     def posterior_predictive_to_xarray(self):
@@ -197,6 +195,7 @@ def to_inference_data(self):
             **{
                 "posterior": self.posterior_to_xarray(),
                 "sample_stats": self.sample_stats_to_xarray(),
+                "log_likelihood": self.log_likelihood_to_xarray(),
                 "posterior_predictive": self.posterior_predictive_to_xarray(),
                 **self.priors_to_xarray(),
                 "observed_data": self.observed_data_to_xarray(),
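For context, a rough usage sketch (not part of the diff): after this change `from_numpyro` stores one variable per observed site in the `log_likelihood` group instead of a single `log_likelihood` entry in `sample_stats`. The model, seed, and sample counts below are made up for illustration.

import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from jax.random import PRNGKey
import arviz as az

y = np.random.randn(8)

def model(y=None):
    mu = numpyro.sample("mu", dist.Normal(0.0, 5.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=y)

mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
mcmc.run(PRNGKey(0), y=y)

idata = az.from_numpyro(mcmc)
# pointwise values now live in their own group, one variable per observed site
print(idata.log_likelihood["obs"].shape)  # (n_chains, n_draws, 8)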
29 changes: 15 additions & 14 deletions arviz/data/io_pyro.py
@@ -85,27 +85,27 @@ def sample_stats_to_xarray(self):
         for i, k in enumerate(sorted(divergences)):
             diverging[i, divergences[k]] = True
         data = {"diverging": diverging}
+        return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=None)
 
-        # extract log_likelihood
+    @requires("posterior")
+    @requires("model")
+    def log_likelihood_to_xarray(self):
+        """Extract log likelihood from Pyro posterior."""
+        data = {}
         dims = None
-        if self.observations is not None and len(self.observations) == 1:
+        if self.observations is not None:
             try:
-                obs_name = list(self.observations.keys())[0]
                 samples = self.posterior.get_samples(group_by_chain=False)
                 predictive = self.pyro.infer.Predictive(self.model, samples)
                 vectorized_trace = predictive.get_vectorized_trace(*self._args, **self._kwargs)
-                obs_site = vectorized_trace.nodes[obs_name]
-                log_likelihood = obs_site["fn"].log_prob(obs_site["value"]).detach().cpu().numpy()
-                if self.dims is not None:
-                    coord_name = self.dims.get("log_likelihood", self.dims.get(obs_name))
-                else:
-                    coord_name = None
-                shape = (self.nchains, self.ndraws) + log_likelihood.shape[1:]
-                data["log_likelihood"] = np.reshape(log_likelihood, shape)
-                dims = {"log_likelihood": coord_name}
+                for obs_name in self.observations.keys():
+                    obs_site = vectorized_trace.nodes[obs_name]
+                    log_like = obs_site["fn"].log_prob(obs_site["value"]).detach().cpu().numpy()
+                    shape = (self.nchains, self.ndraws) + log_like.shape[1:]
+                    data[obs_name] = np.reshape(log_like, shape)
             except:  # pylint: disable=bare-except
                 # cannot get vectorized trace
-                pass
+                return None
         return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=dims)
 
     @requires("posterior_predictive")
@@ -122,7 +122,7 @@ def posterior_predictive_to_xarray(self):
             else:
                 data[k] = utils.expand_dims(ary)
                 _log.warning(
-                    "posterior predictive shape not compatible with number of chains and draws. "
+                    "posterior predictive shape not compatible with number of chains and draws."
                     "This can mean that some draws or even whole chains are not represented."
                 )
         return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=self.dims)
@@ -179,6 +179,7 @@ def to_inference_data(self):
             **{
                 "posterior": self.posterior_to_xarray(),
                 "sample_stats": self.sample_stats_to_xarray(),
+                "log_likelihood": self.log_likelihood_to_xarray(),
                 "posterior_predictive": self.posterior_predictive_to_xarray(),
                 **self.priors_to_xarray(),
                 "observed_data": self.observed_data_to_xarray(),
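A similar rough sketch for Pyro (again not part of the diff): the converter now records one `log_likelihood` variable per observed site by replaying the posterior samples through a vectorized trace. The model below is illustrative and assumes Pyro >= 1.0.

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
import arviz as az

y = torch.randn(8)

def model(y=None):
    mu = pyro.sample("mu", dist.Normal(0.0, 5.0))
    with pyro.plate("obs_plate", len(y)):
        pyro.sample("obs", dist.Normal(mu, 1.0), obs=y)

mcmc = MCMC(NUTS(model), num_samples=100, warmup_steps=100)
mcmc.run(y=y)

idata = az.from_pyro(mcmc)
print(idata.log_likelihood["obs"].shape)  # (n_chains, n_draws, 8)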
39 changes: 36 additions & 3 deletions arviz/tests/test_data_numpyro.py
@@ -25,10 +25,10 @@ class Data:
     def get_inference_data(self, data, eight_schools_params):
         posterior_samples = data.obj.get_samples()
         model = data.obj.sampler.model
-        posterior_predictive = Predictive(model, posterior_samples).get_samples(
+        posterior_predictive = Predictive(model, posterior_samples)(
             PRNGKey(1), eight_schools_params["J"], eight_schools_params["sigma"]
         )
-        prior = Predictive(model, num_samples=500).get_samples(
+        prior = Predictive(model, num_samples=500)(
             PRNGKey(2), eight_schools_params["J"], eight_schools_params["sigma"]
         )
         return from_numpyro(
@@ -43,11 +43,44 @@ def test_inference_data(self, data, eight_schools_params):
         inference_data = self.get_inference_data(data, eight_schools_params)
         test_dict = {
             "posterior": ["mu", "tau", "eta"],
-            "sample_stats": ["diverging", "tree_size", "depth", "log_likelihood"],
+            "sample_stats": ["diverging", "tree_size", "depth"],
+            "log_likelihood": ["obs"],
             "posterior_predictive": ["obs"],
             "prior": ["mu", "tau", "eta"],
             "prior_predictive": ["obs"],
             "observed_data": ["obs"],
         }
         fails = check_multiple_attrs(test_dict, inference_data)
         assert not fails
+
+    def test_multiple_observed_rv(self):
+        import numpyro
+        import numpyro.distributions as dist
+        from numpyro.infer import MCMC, NUTS
+
+        y1 = np.random.randn(10)
+        y2 = np.random.randn(100)
+
+        def model_example_multiple_obs(y1=None, y2=None):
+            x = numpyro.sample("x", dist.Normal(1, 3))
+            numpyro.sample("y1", dist.Normal(x, 1), obs=y1)
+            numpyro.sample("y2", dist.Normal(x, 1), obs=y2)
+
+        nuts_kernel = NUTS(model_example_multiple_obs)
+        mcmc = MCMC(nuts_kernel, num_samples=10, num_warmup=2)
+        mcmc.run(PRNGKey(0), y1=y1, y2=y2)
+        inference_data = from_numpyro(mcmc)
+        test_dict = {
+            "posterior": ["x"],
+            "sample_stats": ["diverging"],
+            "log_likelihood": ["y1", "y2"],
+            "observed_data": ["y1", "y2"],
+        }
+        fails = check_multiple_attrs(test_dict, inference_data)
+        # from ..stats import waic
+        # waic_results = waic(inference_data)
+        # print(waic_results)
+        # print(waic_results.keys())
+        # print(waic_results.waic, waic_results.waic_se)
+        assert not fails
+        assert not hasattr(inference_data.sample_stats, "log_likelihood")
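The commented-out `waic` lines in the new test hint at the intended follow-up: once each observed variable has its own entry in the `log_likelihood` group, the information criteria can read pointwise values straight from `InferenceData`. A rough sketch, assuming a recent ArviZ where `waic`/`loo` accept a `var_name` argument to pick among several observed variables (that argument is not part of this PR):

import arviz as az

# inference_data as built in test_multiple_observed_rv above;
# var_name selects which observed variable's pointwise log likelihood to use
waic_y1 = az.waic(inference_data, var_name="y1")
loo_y2 = az.loo(inference_data, var_name="y2")
print(waic_y1)
print(loo_y2)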