diff --git a/CHANGELOG.md b/CHANGELOG.md index cc0dcd1467..fb08719949 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ * Add `num_chains` and `pred_dims` arguments to from_pyro and from_numpyro (#1090, #1125) * Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds (#1079) * New grayscale style. This also add two new cmaps `cet_grey_r` and `cet_grey_r`. These are perceptually uniform gray scale cmaps from colorcet (linear_grey_10_95_c0) (#1164) +* Add warmup groups to InferenceData objects, initial support for PyStan (#1126) and PyMC3 (#1171) ### Maintenance and fixes * Changed `diagonal` argument for `marginals` and fixed `point_estimate_marker_kwargs` in `plot_pair` (#1167) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index f2c52bdac4..027416efca 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -25,7 +25,7 @@ "predictions_constant_data", ] -WARMUP_TAG = "_warmup_" +WARMUP_TAG = "warmup_" SUPPORTED_GROUPS_WARMUP = [ "{}posterior".format(WARMUP_TAG), diff --git a/arviz/data/inference_data.pyi b/arviz/data/inference_data.pyi index 1b025ef681..3503909b0a 100644 --- a/arviz/data/inference_data.pyi +++ b/arviz/data/inference_data.pyi @@ -19,11 +19,11 @@ class InferenceData: prior: Optional[xr.Dataset] prior_predictive: Optional[xr.Dataset] sample_stats_prior: Optional[xr.Dataset] - _warmup_posterior: Optional[xr.Dataset] - _warmup_posterior_predictive: Optional[xr.Dataset] - _warmup_predictions: Optional[xr.Dataset] - _warmup_log_likelihood: Optional[xr.Dataset] - _warmup_sample_stats: Optional[xr.Dataset] + warmup_posterior: Optional[xr.Dataset] + warmup_posterior_predictive: Optional[xr.Dataset] + warmup_predictions: Optional[xr.Dataset] + warmup_log_likelihood: Optional[xr.Dataset] + warmup_sample_stats: Optional[xr.Dataset] def __init__(self, **kwargs): ... def __repr__(self) -> str: ... def __delattr__(self, group: str) -> None: ... 
diff --git a/arviz/data/io_pymc3.py b/arviz/data/io_pymc3.py index f7900d35f7..3ffcdc7e03 100644 --- a/arviz/data/io_pymc3.py +++ b/arviz/data/io_pymc3.py @@ -9,6 +9,7 @@ from .. import utils from .inference_data import InferenceData, concat from .base import requires, dict_to_dataset, generate_dims_coords, make_attrs, CoordSpec, DimSpec +from ..rcparams import rcParams if TYPE_CHECKING: import pymc3 as pm @@ -60,7 +61,8 @@ def __init__( predictions=None, coords: Optional[Coords] = None, dims: Optional[Dims] = None, - model=None + model=None, + save_warmup: Optional[bool] = None, ): import pymc3 import theano @@ -70,6 +72,7 @@ def __init__( self.pymc3 = pymc3 self.theano = theano + self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup self.trace = trace # this permits us to get the model from command-line argument or from with model: @@ -78,26 +81,34 @@ def __init__( except TypeError: self.model = None + if self.model is None: + warnings.warn( + "Using `from_pymc3` without the model will be deprecated in a future release. " + "Not using the model will return less accurate and less useful results. " + "Make sure you use the model argument or call from_pymc3 within a model context.", + PendingDeprecationWarning, + ) + # This next line is brittle and may not work forever, but is a secret # way to access the model from the trace. 
+ self.attrs = None if trace is not None: if self.model is None: self.model = list(self.trace._straces.values())[ # pylint: disable=protected-access 0 ].model self.nchains = trace.nchains if hasattr(trace, "nchains") else 1 - self.ndraws = len(trace) + if hasattr(trace.report, "n_tune"): + self.ndraws = trace.report.n_draws + self.attrs = { + "sampling_time": trace.report.t_sampling, + "tuning_steps": trace.report.n_tune, + } + else: + self.ndraws = len(trace) else: self.nchains = self.ndraws = 0 - if self.model is None: - warnings.warn( - "Using `from_pymc3` without the model will be deprecated in a future release. " - "Not using the model will return less accurate and less useful results. " - "Make sure you use the model argument or call from_pymc3 within a model context.", - PendingDeprecationWarning, - ) - self.prior = prior self.posterior_predictive = posterior_predictive self.log_likelihood = log_likelihood @@ -151,7 +162,7 @@ def log_likelihood_vals_point(self, point, var, log_like_fun): @requires("trace") @requires("model") - def _extract_log_likelihood(self): + def _extract_log_likelihood(self, trace): """Compute log likelihood of each observation.""" # If we have predictions, then we have a thinned trace which does not # support extracting a log likelihood. @@ -165,7 +176,7 @@ def _extract_log_likelihood(self): ] try: log_likelihood_dict = self.pymc3.sampling._DefaultTrace( # pylint: disable=protected-access - len(self.trace.chains) + len(trace.chains) ) except AttributeError: raise AttributeError( "Installed version of ArviZ requires PyMC3>=3.8. Please upgrade with " "`pip install pymc3>=3.8` or `conda install -c conda-forge pymc3>=3.8`."
) for var, log_like_fun in cached: - for chain in self.trace.chains: + for chain in trace.chains: log_like_chain = [ self.log_likelihood_vals_point(point, var, log_like_fun) - for point in self.trace.points([chain]) + for point in trace.points([chain]) ] log_likelihood_dict.insert(var.name, np.stack(log_like_chain), chain) return log_likelihood_dict.trace_dict @@ -184,13 +195,31 @@ def _extract_log_likelihood(self): @requires("trace") def posterior_to_xarray(self): """Convert the posterior to an xarray dataset.""" - var_names = self.pymc3.util.get_default_varnames( # pylint: disable=no-member + var_names = self.pymc3.util.get_default_varnames( self.trace.varnames, include_transformed=False ) data = {} + data_warmup = {} for var_name in var_names: - data[var_name] = np.array(self.trace.get_values(var_name, combine=False, squeeze=False)) - return dict_to_dataset(data, library=self.pymc3, coords=self.coords, dims=self.dims) + if self.save_warmup: + data_warmup[var_name] = np.array( + self.trace[: -self.ndraws].get_values(var_name, combine=False, squeeze=False) + ) + data[var_name] = np.array( + self.trace[-self.ndraws :].get_values(var_name, combine=False, squeeze=False) + ) + return ( + dict_to_dataset( + data, library=self.pymc3, coords=self.coords, dims=self.dims, attrs=self.attrs + ), + dict_to_dataset( + data_warmup, + library=self.pymc3, + coords=self.coords, + dims=self.dims, + attrs=self.attrs, + ), + ) @requires("trace") def sample_stats_to_xarray(self): @@ -198,11 +227,25 @@ def sample_stats_to_xarray(self): data = {} rename_key = {"model_logp": "lp"} data = {} + data_warmup = {} for stat in self.trace.stat_names: name = rename_key.get(stat, stat) - data[name] = np.array(self.trace.get_sampler_stats(stat, combine=False)) - - return dict_to_dataset(data, library=self.pymc3, dims=None, coords=self.coords) + if name == "tune": + continue + if self.save_warmup: + data_warmup[name] = np.array( + self.trace[: -self.ndraws].get_sampler_stats(stat, combine=False) + 
) + data[name] = np.array(self.trace[-self.ndraws :].get_sampler_stats(stat, combine=False)) + + return ( + dict_to_dataset( + data, library=self.pymc3, dims=None, coords=self.coords, attrs=self.attrs + ), + dict_to_dataset( + data_warmup, library=self.pymc3, dims=None, coords=self.coords, attrs=self.attrs + ), + ) @requires("trace") @requires("model") @@ -211,14 +254,20 @@ def log_likelihood_to_xarray(self): if self.predictions or not self.log_likelihood: return None try: - data = self._extract_log_likelihood() + data = self._extract_log_likelihood(self.trace[-self.ndraws :]) except TypeError: warnings.warn( """Could not compute log_likelihood, it will be omitted. Check your model object or set log_likelihood=False""" ) return None - return dict_to_dataset(data, library=self.pymc3, dims=self.dims, coords=self.coords) + data_warmup = {} + if self.save_warmup: + data_warmup = self._extract_log_likelihood(self.trace[: -self.ndraws]) + return ( + dict_to_dataset(data, library=self.pymc3, dims=self.dims, coords=self.coords), + dict_to_dataset(data_warmup, library=self.pymc3, dims=self.dims, coords=self.coords), + ) def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset: """Take Dict of variables to numpy ndarrays (samples) and translate into dataset.""" @@ -375,7 +424,7 @@ def to_inference_data(self): id_dict["predictions_constant_data"] = self.constant_data_to_xarray() else: id_dict["constant_data"] = self.constant_data_to_xarray() - return InferenceData(**id_dict) + return InferenceData(save_warmup=self.save_warmup, **id_dict) def from_pymc3( @@ -386,7 +435,8 @@ def from_pymc3( log_likelihood: Union[bool, Iterable[str]] = True, coords: Optional[CoordSpec] = None, dims: Optional[DimSpec] = None, - model: Optional[Model] = None + model: Optional[Model] = None, + save_warmup: Optional[bool] = None, ) -> InferenceData: """Convert pymc3 data into an InferenceData object. 
@@ -413,6 +463,9 @@ def from_pymc3( model : pymc3.Model, optional Model used to generate ``trace``. It is not necessary to pass ``model`` if in ``with`` context. + save_warmup : bool, optional + Save warmup iterations into InferenceData object. If not defined, use default + defined by the rcParams. Returns ------- @@ -426,6 +479,7 @@ coords=coords, dims=dims, model=model, + save_warmup=save_warmup, ).to_inference_data() diff --git a/arviz/plots/backends/__init__.py b/arviz/plots/backends/__init__.py index 1151f504ed..2ee57a34d6 100644 --- a/arviz/plots/backends/__init__.py +++ b/arviz/plots/backends/__init__.py @@ -33,8 +33,8 @@ def to_cds( "posterior_groups_warmup"} - posterior_groups: posterior, posterior_predictive, sample_stats - prior_groups: prior, prior_predictive, sample_stats_prior - - posterior_groups_warmup: _warmup_posterior, _warmup_posterior_predictive, - _warmup_sample_stats + - posterior_groups_warmup: warmup_posterior, warmup_posterior_predictive, + warmup_sample_stats ignore_groups : str or list of str, optional Ignore specific groups from CDS.
dimension : str, or list of str, optional diff --git a/arviz/tests/external_tests/test_data_pymc.py b/arviz/tests/external_tests/test_data_pymc.py index 5de64db809..ccab273b9d 100644 --- a/arviz/tests/external_tests/test_data_pymc.py +++ b/arviz/tests/external_tests/test_data_pymc.py @@ -397,3 +397,35 @@ def test_no_model_deprecation(self): } fails = check_multiple_attrs(test_dict, inference_data) assert not fails + + @pytest.mark.parametrize("save_warmup", [False, True]) + def test_save_warmup(self, save_warmup): + with pm.Model(): + pm.Uniform("u1") + pm.Normal("n1") + trace = pm.sample( + tune=100, + draws=200, + chains=2, + cores=1, + step=pm.Metropolis(), + discard_tuned_samples=False, + ) + assert isinstance(trace, pm.backends.base.MultiTrace) + idata = from_pymc3(trace, save_warmup=save_warmup) + prefix = "" if save_warmup else "~" + test_dict = { + "posterior": ["u1", "n1"], + "sample_stats": ["~tune", "accept"], + f"{prefix}warmup_posterior": ["u1", "n1"], + f"{prefix}warmup_sample_stats": ["~tune"], + "~warmup_log_likelihood": [], + "~log_likelihood": [], + } + fails = check_multiple_attrs(test_dict, idata) + assert not fails + assert idata.posterior.dims["chain"] == 2 + assert idata.posterior.dims["draw"] == 200 + if save_warmup: + assert idata.warmup_posterior.dims["chain"] == 2 + assert idata.warmup_posterior.dims["draw"] == 100 diff --git a/arviz/tests/external_tests/test_data_pystan.py b/arviz/tests/external_tests/test_data_pystan.py index fe3dcc8dcc..f144d661ba 100644 --- a/arviz/tests/external_tests/test_data_pystan.py +++ b/arviz/tests/external_tests/test_data_pystan.py @@ -179,7 +179,7 @@ def test_inference_data(self, data, eight_schools_params): } if pystan_version() == 2: test_dict.update( - {"_warmup_posterior": ["theta"], "_warmup_sample_stats": ["diverging", "lp"]} + {"warmup_posterior": ["theta"], "warmup_sample_stats": ["diverging", "lp"]} ) fails = check_multiple_attrs(test_dict, inference_data4) assert not fails diff --git 
a/arviz/utils.py b/arviz/utils.py index 8eb2268f0a..897295e2b3 100644 --- a/arviz/utils.py +++ b/arviz/utils.py @@ -406,8 +406,8 @@ def flatten_inference_data_to_dict( {"posterior_groups", "prior_groups", "posterior_groups_warmup"} - posterior_groups: posterior, posterior_predictive, sample_stats - prior_groups: prior, prior_predictive, sample_stats_prior - - posterior_groups_warmup: _warmup_posterior, _warmup_posterior_predictive, - _warmup_sample_stats + - posterior_groups_warmup: warmup_posterior, warmup_posterior_predictive, + warmup_sample_stats ignore_groups : str or list of str, optional Ignore specific groups from CDS. dimension : str, or list of str, optional @@ -455,7 +455,7 @@ def flatten_inference_data_to_dict( elif groups.lower() == "prior_groups": groups = ["prior", "prior_predictive", "sample_stats_prior"] elif groups.lower() == "posterior_groups_warmup": - groups = ["_warmup_posterior", "_warmup_posterior_predictive", "_warmup_sample_stats"] + groups = ["warmup_posterior", "warmup_posterior_predictive", "warmup_sample_stats"] else: raise TypeError( ( diff --git a/scripts/container.sh b/scripts/container.sh index 766f999aac..25ee38b482 100755 --- a/scripts/container.sh +++ b/scripts/container.sh @@ -5,6 +5,7 @@ if [[ $* == *--clear-cache* ]]; then echo "Removing cached files and models" find -type d -name __pycache__ -exec rm -rf {} + rm -f arviz/tests/saved_models/*.pkl + rm -f arviz/tests/saved_models/*.pkl.gzip fi # Build container for use of testing or notebook