Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support warmup groups in from_pymc3 #1171

Merged
merged 5 commits into from
May 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* Add `num_chains` and `pred_dims` arguments to from_pyro and from_numpyro (#1090, #1125)
* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds (#1079)
* New grayscale style. This also adds two new cmaps, `cet_grey` and its reverse `cet_grey_r`. These are perceptually uniform gray scale cmaps from colorcet (linear_grey_10_95_c0) (#1164)
* Add warmup groups to InferenceData objects, initial support for PyStan (#1126) and PyMC3 (#1171)

### Maintenance and fixes
* Changed `diagonal` argument for `marginals` and fixed `point_estimate_marker_kwargs` in `plot_pair` (#1167)
Expand Down
2 changes: 1 addition & 1 deletion arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"predictions_constant_data",
]

WARMUP_TAG = "_warmup_"
WARMUP_TAG = "warmup_"

SUPPORTED_GROUPS_WARMUP = [
"{}posterior".format(WARMUP_TAG),
Expand Down
10 changes: 5 additions & 5 deletions arviz/data/inference_data.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ class InferenceData:
prior: Optional[xr.Dataset]
prior_predictive: Optional[xr.Dataset]
sample_stats_prior: Optional[xr.Dataset]
_warmup_posterior: Optional[xr.Dataset]
_warmup_posterior_predictive: Optional[xr.Dataset]
_warmup_predictions: Optional[xr.Dataset]
_warmup_log_likelihood: Optional[xr.Dataset]
_warmup_sample_stats: Optional[xr.Dataset]
warmup_posterior: Optional[xr.Dataset]
warmup_posterior_predictive: Optional[xr.Dataset]
warmup_predictions: Optional[xr.Dataset]
warmup_log_likelihood: Optional[xr.Dataset]
warmup_sample_stats: Optional[xr.Dataset]
def __init__(self, **kwargs): ...
def __repr__(self) -> str: ...
def __delattr__(self, group: str) -> None: ...
Expand Down
102 changes: 78 additions & 24 deletions arviz/data/io_pymc3.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .. import utils
from .inference_data import InferenceData, concat
from .base import requires, dict_to_dataset, generate_dims_coords, make_attrs, CoordSpec, DimSpec
from ..rcparams import rcParams

if TYPE_CHECKING:
import pymc3 as pm
Expand Down Expand Up @@ -60,7 +61,8 @@ def __init__(
predictions=None,
coords: Optional[Coords] = None,
dims: Optional[Dims] = None,
model=None
model=None,
save_warmup: Optional[bool] = None,
):
import pymc3
import theano
Expand All @@ -70,6 +72,7 @@ def __init__(
self.pymc3 = pymc3
self.theano = theano

self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
self.trace = trace

# this permits us to get the model from command-line argument or from with model:
Expand All @@ -78,26 +81,34 @@ def __init__(
except TypeError:
self.model = None

if self.model is None:
warnings.warn(
"Using `from_pymc3` without the model will be deprecated in a future release. "
"Not using the model will return less accurate and less useful results. "
"Make sure you use the model argument or call from_pymc3 within a model context.",
PendingDeprecationWarning,
)

# This next line is brittle and may not work forever, but is a secret
# way to access the model from the trace.
self.attrs = None
if trace is not None:
if self.model is None:
self.model = list(self.trace._straces.values())[ # pylint: disable=protected-access
0
].model
self.nchains = trace.nchains if hasattr(trace, "nchains") else 1
self.ndraws = len(trace)
if hasattr(trace.report, "n_tune"):
self.ndraws = trace.report.n_draws
self.attrs = {
"sampling_time": trace.report.t_sampling,
"tuning_steps": trace.report.n_tune,
}
else:
self.nchains = len(trace)
else:
self.nchains = self.ndraws = 0

if self.model is None:
warnings.warn(
"Using `from_pymc3` without the model will be deprecated in a future release. "
"Not using the model will return less accurate and less useful results. "
"Make sure you use the model argument or call from_pymc3 within a model context.",
PendingDeprecationWarning,
)

self.prior = prior
self.posterior_predictive = posterior_predictive
self.log_likelihood = log_likelihood
Expand Down Expand Up @@ -151,7 +162,7 @@ def log_likelihood_vals_point(self, point, var, log_like_fun):

@requires("trace")
@requires("model")
def _extract_log_likelihood(self):
def _extract_log_likelihood(self, trace):
"""Compute log likelihood of each observation."""
# If we have predictions, then we have a thinned trace which does not
# support extracting a log likelihood.
Expand All @@ -165,44 +176,76 @@ def _extract_log_likelihood(self):
]
try:
log_likelihood_dict = self.pymc3.sampling._DefaultTrace( # pylint: disable=protected-access
len(self.trace.chains)
len(trace.chains)
)
except AttributeError:
raise AttributeError(
"Installed version of ArviZ requires PyMC3>=3.8. Please upgrade with "
"`pip install pymc3>=3.8` or `conda install -c conda-forge pymc3>=3.8`."
)
for var, log_like_fun in cached:
for chain in self.trace.chains:
for chain in trace.chains:
log_like_chain = [
self.log_likelihood_vals_point(point, var, log_like_fun)
for point in self.trace.points([chain])
for point in trace.points([chain])
]
log_likelihood_dict.insert(var.name, np.stack(log_like_chain), chain)
return log_likelihood_dict.trace_dict

@requires("trace")
def posterior_to_xarray(self):
"""Convert the posterior to an xarray dataset."""
var_names = self.pymc3.util.get_default_varnames( # pylint: disable=no-member
var_names = self.pymc3.util.get_default_varnames(
self.trace.varnames, include_transformed=False
)
data = {}
data_warmup = {}
for var_name in var_names:
data[var_name] = np.array(self.trace.get_values(var_name, combine=False, squeeze=False))
return dict_to_dataset(data, library=self.pymc3, coords=self.coords, dims=self.dims)
if self.save_warmup:
data_warmup[var_name] = np.array(
self.trace[: -self.ndraws].get_values(var_name, combine=False, squeeze=False)
)
data[var_name] = np.array(
self.trace[-self.ndraws :].get_values(var_name, combine=False, squeeze=False)
)
return (
dict_to_dataset(
data, library=self.pymc3, coords=self.coords, dims=self.dims, attrs=self.attrs
),
dict_to_dataset(
data_warmup,
library=self.pymc3,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
),
)

@requires("trace")
def sample_stats_to_xarray(self):
"""Extract sample_stats from PyMC3 trace."""
data = {}
rename_key = {"model_logp": "lp"}
data = {}
data_warmup = {}
for stat in self.trace.stat_names:
name = rename_key.get(stat, stat)
data[name] = np.array(self.trace.get_sampler_stats(stat, combine=False))

return dict_to_dataset(data, library=self.pymc3, dims=None, coords=self.coords)
if name == "tune":
continue
if self.save_warmup:
data_warmup[name] = np.array(
self.trace[: -self.ndraws].get_sampler_stats(stat, combine=False)
)
data[name] = np.array(self.trace[-self.ndraws :].get_sampler_stats(stat, combine=False))

return (
dict_to_dataset(
data, library=self.pymc3, dims=None, coords=self.coords, attrs=self.attrs
),
dict_to_dataset(
data_warmup, library=self.pymc3, dims=None, coords=self.coords, attrs=self.attrs
),
)

@requires("trace")
@requires("model")
Expand All @@ -211,14 +254,20 @@ def log_likelihood_to_xarray(self):
if self.predictions or not self.log_likelihood:
return None
try:
data = self._extract_log_likelihood()
data = self._extract_log_likelihood(self.trace[-self.ndraws :])
except TypeError:
warnings.warn(
"""Could not compute log_likelihood, it will be omitted.
Check your model object or set log_likelihood=False"""
)
return None
return dict_to_dataset(data, library=self.pymc3, dims=self.dims, coords=self.coords)
data_warmup = {}
if self.save_warmup:
data_warmup = self._extract_log_likelihood(self.trace[: -self.ndraws])
return (
dict_to_dataset(data, library=self.pymc3, dims=self.dims, coords=self.coords),
dict_to_dataset(data_warmup, library=self.pymc3, dims=self.dims, coords=self.coords),
)

def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset:
"""Take Dict of variables to numpy ndarrays (samples) and translate into dataset."""
Expand Down Expand Up @@ -375,7 +424,7 @@ def to_inference_data(self):
id_dict["predictions_constant_data"] = self.constant_data_to_xarray()
else:
id_dict["constant_data"] = self.constant_data_to_xarray()
return InferenceData(**id_dict)
return InferenceData(save_warmup=self.save_warmup, **id_dict)


def from_pymc3(
Expand All @@ -386,7 +435,8 @@ def from_pymc3(
log_likelihood: Union[bool, Iterable[str]] = True,
coords: Optional[CoordSpec] = None,
dims: Optional[DimSpec] = None,
model: Optional[Model] = None
model: Optional[Model] = None,
save_warmup: Optional[bool] = None,
) -> InferenceData:
"""Convert pymc3 data into an InferenceData object.

Expand All @@ -413,6 +463,9 @@ def from_pymc3(
model : pymc3.Model, optional
Model used to generate ``trace``. It is not necessary to pass ``model`` if in
``with`` context.
save_warmup : bool, optional
Save warmup iterations to the InferenceData object. If not defined, the
default defined by the rcParams (``data.save_warmup``) is used.

Returns
-------
Expand All @@ -426,6 +479,7 @@ def from_pymc3(
coords=coords,
dims=dims,
model=model,
save_warmup=save_warmup,
).to_inference_data()


Expand Down
4 changes: 2 additions & 2 deletions arviz/plots/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def to_cds(
"posterior_groups_warmup"}
- posterior_groups: posterior, posterior_predictive, sample_stats
- prior_groups: prior, prior_predictive, sample_stats_prior
- posterior_groups_warmup: _warmup_posterior, _warmup_posterior_predictive,
_warmup_sample_stats
- posterior_groups_warmup: warmup_posterior, warmup_posterior_predictive,
warmup_sample_stats
ignore_groups : str or list of str, optional
Ignore specific groups from CDS.
dimension : str, or list of str, optional
Expand Down
32 changes: 32 additions & 0 deletions arviz/tests/external_tests/test_data_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,3 +397,35 @@ def test_no_model_deprecation(self):
}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails

@pytest.mark.parametrize("save_warmup", [False, True])
def test_save_warmup(self, save_warmup):
    """Check that from_pymc3 honors save_warmup for warmup_* groups.

    Samples a tiny 2-variable model keeping the tuning draws
    (``discard_tuned_samples=False``) and verifies that the warmup_posterior /
    warmup_sample_stats groups exist exactly when ``save_warmup=True`` and are
    absent otherwise, with the expected chain/draw dimensions.
    """
    with pm.Model():
        pm.Uniform("u1")
        pm.Normal("n1")
        # Keep tuning samples in the trace so warmup groups can be extracted.
        trace = pm.sample(
            tune=100,
            draws=200,
            chains=2,
            cores=1,
            step=pm.Metropolis(),
            discard_tuned_samples=False,
        )
    assert isinstance(trace, pm.backends.base.MultiTrace)
    idata = from_pymc3(trace, save_warmup=save_warmup)
    # NOTE(review): per check_multiple_attrs convention, a "~" prefix appears
    # to mark a group/variable that must be ABSENT — so warmup groups are
    # required when save_warmup is True and forbidden when it is False.
    prefix = "" if save_warmup else "~"
    test_dict = {
        "posterior": ["u1", "n1"],
        # "tune" stat is dropped from sample_stats in both cases.
        "sample_stats": ["~tune", "accept"],
        f"{prefix}warmup_posterior": ["u1", "n1"],
        f"{prefix}warmup_sample_stats": ["~tune"],
        # log_likelihood groups are never created for this model/settings.
        "~warmup_log_likelihood": [],
        "~log_likelihood": [],
    }
    fails = check_multiple_attrs(test_dict, idata)
    assert not fails
    # Post-warmup draws only: 2 chains x 200 draws.
    assert idata.posterior.dims["chain"] == 2
    assert idata.posterior.dims["draw"] == 200
    if save_warmup:
        # Warmup group holds the 100 tuning draws per chain.
        assert idata.warmup_posterior.dims["chain"] == 2
        assert idata.warmup_posterior.dims["draw"] == 100
2 changes: 1 addition & 1 deletion arviz/tests/external_tests/test_data_pystan.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_inference_data(self, data, eight_schools_params):
}
if pystan_version() == 2:
test_dict.update(
{"_warmup_posterior": ["theta"], "_warmup_sample_stats": ["diverging", "lp"]}
{"warmup_posterior": ["theta"], "warmup_sample_stats": ["diverging", "lp"]}
)
fails = check_multiple_attrs(test_dict, inference_data4)
assert not fails
Expand Down
6 changes: 3 additions & 3 deletions arviz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,8 @@ def flatten_inference_data_to_dict(
{"posterior_groups", "prior_groups", "posterior_groups_warmup"}
- posterior_groups: posterior, posterior_predictive, sample_stats
- prior_groups: prior, prior_predictive, sample_stats_prior
- posterior_groups_warmup: _warmup_posterior, _warmup_posterior_predictive,
_warmup_sample_stats
- posterior_groups_warmup: warmup_posterior, warmup_posterior_predictive,
warmup_sample_stats
ignore_groups : str or list of str, optional
Ignore specific groups from CDS.
dimension : str, or list of str, optional
Expand Down Expand Up @@ -455,7 +455,7 @@ def flatten_inference_data_to_dict(
elif groups.lower() == "prior_groups":
groups = ["prior", "prior_predictive", "sample_stats_prior"]
elif groups.lower() == "posterior_groups_warmup":
groups = ["_warmup_posterior", "_warmup_posterior_predictive", "_warmup_sample_stats"]
groups = ["warmup_posterior", "warmup_posterior_predictive", "warmup_sample_stats"]
else:
raise TypeError(
(
Expand Down
1 change: 1 addition & 0 deletions scripts/container.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ if [[ $* == *--clear-cache* ]]; then
echo "Removing cached files and models"
find -type d -name __pycache__ -exec rm -rf {} +
rm -f arviz/tests/saved_models/*.pkl
rm -f arviz/tests/saved_models/*.pkl.gzip
fi

# Build container for use of testing or notebook
Expand Down