Skip to content

Commit

Permalink
Move pointwise log likelihood data to log_likelihood group (Pyro + Nu…
Browse files Browse the repository at this point in the history
…mPyro) (#1044)

* numpyro - Add more data to InferenceData objects

* resolves pyro deprecation warnings

* pyro - Add more data to InferenceData objects

* removing comment in pyro test

* resolves pylint failures

* resolving additional failures (Predictive().forward does not exist)

* updates pyro version requirements

* updates pyro version for travis (Predictive.forward does not exist)

* modifies handling of the case where the vectorized trace cannot be obtained (pyro)

* updates Pyro version for Azure
VincentBt authored Feb 6, 2020
1 parent 3899c39 commit a82ae30
Showing 8 changed files with 195 additions and 136 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -11,9 +11,9 @@ matrix:
- name: "Python 3.6 Unit Test"
python: 3.6
env: PYTHON_VERSION=3.6 PYSTAN_VERSION=latest PYRO_VERSION=latest EMCEE_VERSION=latest COVERALLS_PARALLEL=true NAME="UNIT"
- name: "Python 3.6 Unit Test - PyStan=3 Pyro=0.5.1 Emcee=2 TF=1"
- name: "Python 3.6 Unit Test - PyStan=3 Pyro=1.0.0 Emcee=2 TF=1"
python: 3.6
env: PYTHON_VERSION=3.6 PYSTAN_VERSION=preview PYRO_VERSION=0.5.1 PYTORCH_VERSION=1.3.0 EMCEE_VERSION=2 TF_VERSION=1 COVERALLS_PARALLEL=true NAME="UNIT"
env: PYTHON_VERSION=3.6 PYSTAN_VERSION=preview PYRO_VERSION=1.0.0 PYTORCH_VERSION=1.3.0 EMCEE_VERSION=2 TF_VERSION=1 COVERALLS_PARALLEL=true NAME="UNIT"
- name: "Python 3.5 Unit Test"
python: 3.5
env: PYTHON_VERSION=3.5 PYSTAN_VERSION=latest PYRO_VERSION=latest EMCEE_VERSION=latest PYMC3_VERSION=3.8 COVERALLS_PARALLEL=true NAME="UNIT"
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@

### New features
* Add out-of-sample predictions (`predictions` and `predictions_constant_data` groups) to pymc3 and pystan translations (#983 and #1032)
* Started adding pointwise log likelihood storage support (#794)
* Started adding pointwise log likelihood storage support (#794, #1044)
* Violinplot: rug-plot option (#997)
* Integrated rcParams `plot.point_estimate` (#994), `stats.ic_scale` (#993) and `stats.credible_interval` (#1017)
* Added `group` argument to `plot_ppc` (#1008), `plot_pair` (#1009) and `plot_joint` (#1012)
27 changes: 13 additions & 14 deletions arviz/data/io_numpyro.py
Original file line number Diff line number Diff line change
@@ -103,24 +103,22 @@ def sample_stats_to_xarray(self):
data[name] = value
if stat == "num_steps":
data["depth"] = np.log2(value).astype(int) + 1
return dict_to_dataset(data, library=self.numpyro, dims=None, coords=self.coords)

# extract log_likelihood
dims = None
if self.observations is not None and len(self.observations) == 1:
@requires("posterior")
@requires("model")
def log_likelihood_to_xarray(self):
"""Extract log likelihood from NumPyro posterior."""
data = {}
if self.observations is not None:
samples = self.posterior.get_samples(group_by_chain=False)
log_likelihood = self.numpyro.infer.log_likelihood(
log_likelihood_dict = self.numpyro.infer.log_likelihood(
self.model, samples, *self._args, **self._kwargs
)
obs_name, log_likelihood = list(log_likelihood.items())[0]
if self.dims is not None:
coord_name = self.dims.get("log_likelihood", self.dims.get(obs_name))
else:
coord_name = None
shape = (self.nchains, self.ndraws) + log_likelihood.shape[1:]
data["log_likelihood"] = np.reshape(log_likelihood.copy(), shape)
dims = {"log_likelihood": coord_name}

return dict_to_dataset(data, library=self.numpyro, dims=dims, coords=self.coords)
for obs_name, log_like in log_likelihood_dict.items():
shape = (self.nchains, self.ndraws) + log_like.shape[1:]
data[obs_name] = np.reshape(log_like.copy(), shape)
return dict_to_dataset(data, library=self.numpyro, dims=self.dims, coords=self.coords)

@requires("posterior_predictive")
def posterior_predictive_to_xarray(self):
@@ -197,6 +195,7 @@ def to_inference_data(self):
**{
"posterior": self.posterior_to_xarray(),
"sample_stats": self.sample_stats_to_xarray(),
"log_likelihood": self.log_likelihood_to_xarray(),
"posterior_predictive": self.posterior_predictive_to_xarray(),
**self.priors_to_xarray(),
"observed_data": self.observed_data_to_xarray(),
29 changes: 15 additions & 14 deletions arviz/data/io_pyro.py
Original file line number Diff line number Diff line change
@@ -85,27 +85,27 @@ def sample_stats_to_xarray(self):
for i, k in enumerate(sorted(divergences)):
diverging[i, divergences[k]] = True
data = {"diverging": diverging}
return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=None)

# extract log_likelihood
@requires("posterior")
@requires("model")
def log_likelihood_to_xarray(self):
"""Extract log likelihood from Pyro posterior."""
data = {}
dims = None
if self.observations is not None and len(self.observations) == 1:
if self.observations is not None:
try:
obs_name = list(self.observations.keys())[0]
samples = self.posterior.get_samples(group_by_chain=False)
predictive = self.pyro.infer.Predictive(self.model, samples)
vectorized_trace = predictive.get_vectorized_trace(*self._args, **self._kwargs)
obs_site = vectorized_trace.nodes[obs_name]
log_likelihood = obs_site["fn"].log_prob(obs_site["value"]).detach().cpu().numpy()
if self.dims is not None:
coord_name = self.dims.get("log_likelihood", self.dims.get(obs_name))
else:
coord_name = None
shape = (self.nchains, self.ndraws) + log_likelihood.shape[1:]
data["log_likelihood"] = np.reshape(log_likelihood, shape)
dims = {"log_likelihood": coord_name}
for obs_name in self.observations.keys():
obs_site = vectorized_trace.nodes[obs_name]
log_like = obs_site["fn"].log_prob(obs_site["value"]).detach().cpu().numpy()
shape = (self.nchains, self.ndraws) + log_like.shape[1:]
data[obs_name] = np.reshape(log_like, shape)
except: # pylint: disable=bare-except
# cannot get vectorized trace
pass
return None
return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=dims)

@requires("posterior_predictive")
@@ -122,7 +122,7 @@ def posterior_predictive_to_xarray(self):
else:
data[k] = utils.expand_dims(ary)
_log.warning(
"posterior predictive shape not compatible with number of chains and draws. "
"posterior predictive shape not compatible with number of chains and draws."
"This can mean that some draws or even whole chains are not represented."
)
return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=self.dims)
@@ -179,6 +179,7 @@ def to_inference_data(self):
**{
"posterior": self.posterior_to_xarray(),
"sample_stats": self.sample_stats_to_xarray(),
"log_likelihood": self.log_likelihood_to_xarray(),
"posterior_predictive": self.posterior_predictive_to_xarray(),
**self.priors_to_xarray(),
"observed_data": self.observed_data_to_xarray(),
39 changes: 36 additions & 3 deletions arviz/tests/test_data_numpyro.py
Original file line number Diff line number Diff line change
@@ -25,10 +25,10 @@ class Data:
def get_inference_data(self, data, eight_schools_params):
posterior_samples = data.obj.get_samples()
model = data.obj.sampler.model
posterior_predictive = Predictive(model, posterior_samples).get_samples(
posterior_predictive = Predictive(model, posterior_samples)(
PRNGKey(1), eight_schools_params["J"], eight_schools_params["sigma"]
)
prior = Predictive(model, num_samples=500).get_samples(
prior = Predictive(model, num_samples=500)(
PRNGKey(2), eight_schools_params["J"], eight_schools_params["sigma"]
)
return from_numpyro(
@@ -43,11 +43,44 @@ def test_inference_data(self, data, eight_schools_params):
inference_data = self.get_inference_data(data, eight_schools_params)
test_dict = {
"posterior": ["mu", "tau", "eta"],
"sample_stats": ["diverging", "tree_size", "depth", "log_likelihood"],
"sample_stats": ["diverging", "tree_size", "depth"],
"log_likelihood": ["obs"],
"posterior_predictive": ["obs"],
"prior": ["mu", "tau", "eta"],
"prior_predictive": ["obs"],
"observed_data": ["obs"],
}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails

def test_multiple_observed_rv(self):
    """Check that a model with two observed sites yields one log likelihood variable per site.

    The pointwise log likelihood must live in the ``log_likelihood`` group
    (under the site names ``y1`` and ``y2``) and must not appear in
    ``sample_stats``.
    """
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    # Two observation vectors of different length (10 and 100 points).
    obs_short = np.random.randn(10)
    obs_long = np.random.randn(100)

    def two_site_model(y1=None, y2=None):
        # Both observed sites share the single latent location ``x``.
        x = numpyro.sample("x", dist.Normal(1, 3))
        numpyro.sample("y1", dist.Normal(x, 1), obs=y1)
        numpyro.sample("y2", dist.Normal(x, 1), obs=y2)

    sampler = MCMC(NUTS(two_site_model), num_samples=10, num_warmup=2)
    sampler.run(PRNGKey(0), y1=obs_short, y2=obs_long)
    idata = from_numpyro(sampler)

    expected_groups = {
        "posterior": ["x"],
        "sample_stats": ["diverging"],
        "log_likelihood": ["y1", "y2"],
        "observed_data": ["y1", "y2"],
    }
    assert not check_multiple_attrs(expected_groups, idata)
    assert not hasattr(idata.sample_stats, "log_likelihood")
224 changes: 125 additions & 99 deletions arviz/tests/test_data_pyro.py
Original file line number Diff line number Diff line change
@@ -1,99 +1,125 @@
# pylint: disable=no-member, invalid-name, redefined-outer-name
import numpy as np
import packaging
import pytest
import torch
import pyro
from pyro.infer import Predictive

from ..data.io_pyro import from_pyro
from .helpers import ( # pylint: disable=unused-import
chains,
check_multiple_attrs,
draws,
eight_schools_params,
load_cached_models,
)


class TestDataPyro:
    """Tests for converting Pyro results to InferenceData with ``from_pyro``."""

    @pytest.fixture(scope="class")
    def data(self, eight_schools_params, draws, chains):
        """Return a holder class whose ``obj`` attribute is the cached ``"pyro"`` model."""

        class Data:
            # Looked up once when the class body runs; the fixture is class-scoped.
            obj = load_cached_models(eight_schools_params, draws, chains, "pyro")["pyro"]

        return Data

    def get_inference_data(self, data, eight_schools_params):
        """Call ``from_pyro`` with the posterior, prior and posterior predictive samples.

        Both predictive sample sets are drawn with ``Predictive(...).get_samples``
        using ``J`` and ``sigma`` from ``eight_schools_params``.
        """
        posterior_samples = data.obj.get_samples()
        model = data.obj.kernel.model
        posterior_predictive = Predictive(model, posterior_samples).get_samples(
            eight_schools_params["J"], torch.from_numpy(eight_schools_params["sigma"]).float()
        )
        # No posterior samples are given here, so 500 fresh samples are requested.
        prior = Predictive(model, num_samples=500).get_samples(
            eight_schools_params["J"], torch.from_numpy(eight_schools_params["sigma"]).float()
        )
        return from_pyro(
            posterior=data.obj,
            prior=prior,
            posterior_predictive=posterior_predictive,
            coords={"school": np.arange(eight_schools_params["J"])},
            dims={"theta": ["school"], "eta": ["school"]},
        )

    def test_inference_data(self, data, eight_schools_params):
        """Check the variables expected in each group of the full conversion."""
        inference_data = self.get_inference_data(data, eight_schools_params)
        test_dict = {
            "posterior": ["mu", "tau", "eta"],
            "sample_stats": ["diverging"],
            "posterior_predictive": ["obs"],
            "prior": ["mu", "tau", "eta"],
            "prior_predictive": ["obs"],
        }
        fails = check_multiple_attrs(test_dict, inference_data)
        assert not fails

    @pytest.mark.skipif(
        packaging.version.parse(pyro.__version__) < packaging.version.parse("1.0.0"),
        reason="requires pyro 1.0.0 or higher",
    )
    def test_inference_data_has_log_likelihood_and_observed_data(self, data):
        """Check for ``log_likelihood`` in ``sample_stats`` and ``obs`` in ``observed_data``."""
        idata = from_pyro(data.obj)
        test_dict = {"sample_stats": ["log_likelihood"], "observed_data": ["obs"]}
        fails = check_multiple_attrs(test_dict, idata)
        assert not fails

    def test_inference_data_no_posterior(self, data, eight_schools_params):
        """Check the conversion when only prior and posterior predictive samples are passed."""
        posterior_samples = data.obj.get_samples()
        model = data.obj.kernel.model
        posterior_predictive = Predictive(model, posterior_samples).get_samples(
            eight_schools_params["J"], torch.from_numpy(eight_schools_params["sigma"]).float()
        )
        prior = Predictive(model, num_samples=500).get_samples(
            eight_schools_params["J"], torch.from_numpy(eight_schools_params["sigma"]).float()
        )
        # ``posterior`` is deliberately not passed to ``from_pyro`` here.
        idata = from_pyro(
            prior=prior,
            posterior_predictive=posterior_predictive,
            coords={"school": np.arange(eight_schools_params["J"])},
            dims={"theta": ["school"], "eta": ["school"]},
        )
        test_dict = {"posterior_predictive": ["obs"], "prior": ["mu", "tau", "eta", "obs"]}
        fails = check_multiple_attrs(test_dict, idata)
        assert not fails

    def test_inference_data_only_posterior(self, data):
        """Check ``posterior`` and ``sample_stats`` when only the posterior object is passed."""
        idata = from_pyro(data.obj)
        test_dict = {"posterior": ["mu", "tau", "eta"], "sample_stats": ["diverging"]}
        fails = check_multiple_attrs(test_dict, idata)
        assert not fails

    @pytest.mark.skipif(
        packaging.version.parse(pyro.__version__) < packaging.version.parse("1.0.0"),
        reason="requires pyro 1.0.0 or higher",
    )
    def test_inference_data_only_posterior_has_log_likelihood(self, data):
        """Check for ``log_likelihood`` in ``sample_stats`` for a posterior-only conversion."""
        idata = from_pyro(data.obj)
        test_dict = {"sample_stats": ["log_likelihood"]}
        fails = check_multiple_attrs(test_dict, idata)
        assert not fails
# pylint: disable=no-member, invalid-name, redefined-outer-name
import numpy as np
import packaging
import pytest
import torch
import pyro
from pyro.infer import Predictive

from ..data.io_pyro import from_pyro
from .helpers import ( # pylint: disable=unused-import
chains,
check_multiple_attrs,
draws,
eight_schools_params,
load_cached_models,
)


class TestDataPyro:
    """Tests for the ``from_pyro`` converter."""

    @pytest.fixture(scope="class")
    def data(self, eight_schools_params, draws, chains):
        """Return a holder class whose ``obj`` attribute is the cached ``"pyro"`` model."""
        cached_models = load_cached_models(eight_schools_params, draws, chains, "pyro")

        class Data:
            obj = cached_models["pyro"]

        return Data

    def get_inference_data(self, data, eight_schools_params):
        """Call ``from_pyro`` with the posterior, prior and posterior predictive samples."""
        fitted = data.obj
        samples = fitted.get_samples()
        model = fitted.kernel.model
        school_count = eight_schools_params["J"]

        # Posterior predictive is drawn first, then 500 prior samples.
        posterior_predictive = Predictive(model, samples)(
            school_count, torch.from_numpy(eight_schools_params["sigma"]).float()
        )
        prior = Predictive(model, num_samples=500)(
            school_count, torch.from_numpy(eight_schools_params["sigma"]).float()
        )
        return from_pyro(
            posterior=fitted,
            prior=prior,
            posterior_predictive=posterior_predictive,
            coords={"school": np.arange(school_count)},
            dims={"theta": ["school"], "eta": ["school"]},
        )

    def test_inference_data(self, data, eight_schools_params):
        """Check the variables expected in each group of the full conversion."""
        idata = self.get_inference_data(data, eight_schools_params)
        expected = {
            "posterior": ["mu", "tau", "eta"],
            "sample_stats": ["diverging"],
            "posterior_predictive": ["obs"],
            "prior": ["mu", "tau", "eta"],
            "prior_predictive": ["obs"],
        }
        assert not check_multiple_attrs(expected, idata)

    @pytest.mark.skipif(
        packaging.version.parse(pyro.__version__) < packaging.version.parse("1.0.0"),
        reason="requires pyro 1.0.0 or higher",
    )
    def test_inference_data_has_log_likelihood_and_observed_data(self, data):
        """Check for ``obs`` in both the ``log_likelihood`` and ``observed_data`` groups."""
        expected = {"log_likelihood": ["obs"], "observed_data": ["obs"]}
        idata = from_pyro(data.obj)
        assert not check_multiple_attrs(expected, idata)

    def test_inference_data_no_posterior(self, data, eight_schools_params):
        """Check the conversion when only prior and posterior predictive samples are passed."""
        samples = data.obj.get_samples()
        model = data.obj.kernel.model
        school_count = eight_schools_params["J"]

        posterior_predictive = Predictive(model, samples)(
            school_count, torch.from_numpy(eight_schools_params["sigma"]).float()
        )
        prior = Predictive(model, num_samples=500)(
            school_count, torch.from_numpy(eight_schools_params["sigma"]).float()
        )
        # ``posterior`` is deliberately not passed to ``from_pyro`` here.
        idata = from_pyro(
            prior=prior,
            posterior_predictive=posterior_predictive,
            coords={"school": np.arange(school_count)},
            dims={"theta": ["school"], "eta": ["school"]},
        )
        expected = {"posterior_predictive": ["obs"], "prior": ["mu", "tau", "eta", "obs"]}
        assert not check_multiple_attrs(expected, idata)

    def test_inference_data_only_posterior(self, data):
        """Check ``posterior`` and ``sample_stats`` when only the posterior object is passed."""
        expected = {"posterior": ["mu", "tau", "eta"], "sample_stats": ["diverging"]}
        idata = from_pyro(data.obj)
        assert not check_multiple_attrs(expected, idata)

    @pytest.mark.skipif(
        packaging.version.parse(pyro.__version__) < packaging.version.parse("1.0.0"),
        reason="requires pyro 1.0.0 or higher",
    )
    def test_inference_data_only_posterior_has_log_likelihood(self, data):
        """Check for ``obs`` in the ``log_likelihood`` group for a posterior-only conversion."""
        expected = {"log_likelihood": ["obs"]}
        idata = from_pyro(data.obj)
        assert not check_multiple_attrs(expected, idata)

    def test_multiple_observed_rv(self):
        """Check that two observed sites give one ``log_likelihood`` variable each.

        The pointwise values must be stored under the site names ``y1`` and
        ``y2`` and must not appear in ``sample_stats``.
        """
        import pyro.distributions as dist
        from pyro.infer import MCMC, NUTS

        first_obs = torch.randn(10)
        second_obs = torch.randn(10)

        def two_site_model(y1=None, y2=None):
            # Both observed sites share the single latent location ``x``.
            x = pyro.sample("x", dist.Normal(1, 3))
            pyro.sample("y1", dist.Normal(x, 1), obs=y1)
            pyro.sample("y2", dist.Normal(x, 1), obs=y2)

        sampler = MCMC(NUTS(two_site_model), num_samples=10)
        sampler.run(y1=first_obs, y2=second_obs)
        idata = from_pyro(sampler)

        expected = {
            "posterior": ["x"],
            "sample_stats": ["diverging"],
            "log_likelihood": ["y1", "y2"],
            "observed_data": ["y1", "y2"],
        }
        assert not check_multiple_attrs(expected, idata)
        assert not hasattr(idata.sample_stats, "log_likelihood")
4 changes: 2 additions & 2 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
@@ -14,10 +14,10 @@ jobs:
PYRO_VERSION: "latest"
EMCEE_VERSION: "latest"
NAME: "UNIT"
Python_36_Unit_Test_PyStan_3_Pyro_0.5.1_Emcee_2_tf_1:
Python_36_Unit_Test_PyStan_3_Pyro_1.0.0_Emcee_2_tf_1:
PYTHON_VERSION: 3.6
PYSTAN_VERSION: "preview"
PYRO_VERSION: 0.5.1
PYRO_VERSION: 1.0.0
PYTORCH_VERSION: 1.3.0
EMCEE_VERSION: 2
TF_VERSION: 1
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@ nbsphinx
numpydoc
pydocstyle<5.0
pylint
pyro-ppl>=0.5.1
pyro-ppl>=1.0.0
tensorflow
tensorflow-probability
pytest

0 comments on commit a82ae30

Please sign in to comment.