Skip to content

Commit

Permalink
Update posterior predictive sampling and notebook (#5268)
Browse files Browse the repository at this point in the history
* improve idata-sample_posterior_predictive integration and add thinning example to docstring
* remove plot_posterior_predictive_glm (arviz.plot_lm should be used now)

Co-authored-by: Michael Osthege <michael.osthege@outlook.com>
  • Loading branch information
OriolAbril and michaelosthege authored Dec 20, 2021
1 parent c200802 commit cf95a78
Show file tree
Hide file tree
Showing 15 changed files with 4,092 additions and 510 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ jobs:
--ignore=pymc/tests/test_dist_math.py
--ignore=pymc/tests/test_minibatches.py
--ignore=pymc/tests/test_pickling.py
--ignore=pymc/tests/test_plots.py
--ignore=pymc/tests/test_updates.py
--ignore=pymc/tests/test_gp.py
--ignore=pymc/tests/test_model.py
Expand All @@ -68,7 +67,6 @@ jobs:
pymc/tests/test_dist_math.py
pymc/tests/test_minibatches.py
pymc/tests/test_pickling.py
pymc/tests/test_plots.py
pymc/tests/test_updates.py
pymc/tests/test_transforms.py
Expand Down
4,161 changes: 3,939 additions & 222 deletions docs/source/learn/examples/posterior_predictive.ipynb

Large diffs are not rendered by default.

74 changes: 29 additions & 45 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def __init__(
dims: Optional[DimSpec] = None,
model=None,
save_warmup: Optional[bool] = None,
density_dist_obs: bool = True,
):

self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
Expand Down Expand Up @@ -175,29 +174,11 @@ def __init__(
self.log_likelihood = log_likelihood
self.predictions = predictions

def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
return next(iter(dct.values()))

if trace is None:
# if you have a posterior_predictive built with keep_dims,
# you'll lose here, but there's nothing I can do about that.
self.nchains = 1
get_from = None
if predictions is not None:
get_from = predictions
elif posterior_predictive is not None:
get_from = posterior_predictive
elif prior is not None:
get_from = prior
if get_from is None:
# pylint: disable=line-too-long
raise ValueError(
"When constructing InferenceData must have at least"
" one of trace, prior, posterior_predictive or predictions."
)

aelem = arbitrary_element(get_from)
self.ndraws = aelem.shape[0]
if all(elem is None for elem in (trace, predictions, posterior_predictive, prior)):
raise ValueError(
"When constructing InferenceData you must pass at least"
" one of trace, prior, posterior_predictive or predictions."
)

self.coords = {**self.model.coords, **(coords or {})}
self.coords = {
Expand All @@ -214,7 +195,6 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
}
self.dims = {**model_dims, **self.dims}

self.density_dist_obs = density_dist_obs
self.observations = find_observations(self.model)

def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
Expand Down Expand Up @@ -396,34 +376,35 @@ def log_likelihood_to_xarray(self):
),
)

def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset:
def translate_posterior_predictive_dict_to_xarray(self, dct, kind) -> xr.Dataset:
"""Take Dict of variables to numpy ndarrays (samples) and translate into dataset."""
data = {}
warning_vars = []
for k, ary in dct.items():
shape = ary.shape
if shape[0] == self.nchains and shape[1] == self.ndraws:
if (ary.shape[0] == self.nchains) and (ary.shape[1] == self.ndraws):
data[k] = ary
elif shape[0] == self.nchains * self.ndraws:
data[k] = ary.reshape((self.nchains, self.ndraws, *shape[1:]))
else:
data[k] = np.expand_dims(ary, 0)
# pylint: disable=line-too-long
_log.warning(
"posterior predictive variable %s's shape not compatible with number of chains and draws. "
"This can mean that some draws or even whole chains are not represented.",
k,
)
warning_vars.append(k)
warnings.warn(
f"The shape of variables {', '.join(warning_vars)} in {kind} group is not compatible "
"with number of chains and draws. The automatic dimension naming might not have worked. "
"This can also mean that some draws or even whole chains are not represented.",
UserWarning,
)
return dict_to_dataset(data, library=pymc, coords=self.coords, dims=self.dims)

@requires(["posterior_predictive"])
def posterior_predictive_to_xarray(self):
"""Convert posterior_predictive samples to xarray."""
return self.translate_posterior_predictive_dict_to_xarray(self.posterior_predictive)
return self.translate_posterior_predictive_dict_to_xarray(
self.posterior_predictive, "posterior_predictive"
)

@requires(["predictions"])
def predictions_to_xarray(self):
"""Convert predictions (out of sample predictions) to xarray."""
return self.translate_posterior_predictive_dict_to_xarray(self.predictions)
return self.translate_posterior_predictive_dict_to_xarray(self.predictions, "predictions")

def priors_to_xarray(self):
"""Convert prior samples (and if possible prior predictive too) to xarray."""
Expand Down Expand Up @@ -545,7 +526,6 @@ def to_inference_data(
dims: Optional[DimSpec] = None,
model: Optional["Model"] = None,
save_warmup: Optional[bool] = None,
density_dist_obs: bool = True,
) -> InferenceData:
"""Convert pymc data into an InferenceData object.
Expand Down Expand Up @@ -578,9 +558,6 @@ def to_inference_data(
save_warmup : bool, optional
Save warmup iterations InferenceData object. If not defined, use default
defined by the rcParams.
density_dist_obs : bool, default True
Store variables passed with ``observed`` arg to
:class:`~pymc.distributions.DensityDist` in the generated InferenceData.
Returns
-------
Expand All @@ -598,7 +575,6 @@ def to_inference_data(
dims=dims,
model=model,
save_warmup=save_warmup,
density_dist_obs=density_dist_obs,
).to_inference_data()


Expand All @@ -620,6 +596,7 @@ def predictions_to_inference_data(
predictions: Dict[str, np.ndarray]
The predictions are the return value of :func:`~pymc.sample_posterior_predictive`,
a dictionary of strings (variable names) to numpy ndarrays (draws).
Requires the arrays to follow the convention ``chain, draw, *shape``.
posterior_trace: MultiTrace
This should be a trace that has been thinned appropriately for
``pymc.sample_posterior_predictive``. Specifically, any variable whose shape is
Expand Down Expand Up @@ -648,14 +625,21 @@ def predictions_to_inference_data(
raise ValueError(
"Do not pass True for inplace unless passing" "an existing InferenceData as idata_orig"
)
new_idata = InferenceDataConverter(
converter = InferenceDataConverter(
trace=posterior_trace,
predictions=predictions,
model=model,
coords=coords,
dims=dims,
log_likelihood=False,
).to_inference_data()
)
if hasattr(idata_orig, "posterior"):
converter.nchains = idata_orig.posterior.dims["chain"]
converter.ndraws = idata_orig.posterior.dims["draw"]
else:
aelem = next(iter(predictions.values()))
converter.nchains, converter.ndraws = aelem.shape[:2]
new_idata = converter.to_inference_data()
if idata_orig is None:
return new_idata
elif inplace:
Expand Down
3 changes: 0 additions & 3 deletions pymc/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ def wrapped(*args, **kwargs):
compareplot = alias_deprecation(az.plot_compare, alias="compareplot")


from pymc.plots.posteriorplot import plot_posterior_predictive_glm

__all__ = tuple(az.plots.__all__) + (
"autocorrplot",
"compareplot",
Expand All @@ -67,5 +65,4 @@ def wrapped(*args, **kwargs):
"energyplot",
"densityplot",
"pairplot",
"plot_posterior_predictive_glm",
)
86 changes: 0 additions & 86 deletions pymc/plots/posteriorplot.py

This file was deleted.

Loading

0 comments on commit cf95a78

Please sign in to comment.