update posterior predictive sampling and notebook #5268

Merged (9 commits, Dec 20, 2021)
2 changes: 0 additions & 2 deletions .github/workflows/pytest.yml
@@ -46,7 +46,6 @@ jobs:
--ignore=pymc/tests/test_dist_math.py
--ignore=pymc/tests/test_minibatches.py
--ignore=pymc/tests/test_pickling.py
--ignore=pymc/tests/test_plots.py
--ignore=pymc/tests/test_updates.py
--ignore=pymc/tests/test_gp.py
--ignore=pymc/tests/test_model.py
@@ -68,7 +67,6 @@ jobs:
pymc/tests/test_dist_math.py
pymc/tests/test_minibatches.py
pymc/tests/test_pickling.py
pymc/tests/test_plots.py
pymc/tests/test_updates.py
pymc/tests/test_transforms.py

4,161 changes: 3,939 additions & 222 deletions docs/source/learn/examples/posterior_predictive.ipynb

Large diffs are not rendered by default.
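The rewritten notebook itself is too large to display here. For context, below is a minimal sketch of the posterior predictive workflow the notebook covers, using a hypothetical toy model; the exact return types of pm.sample and pm.sample_posterior_predictive (dict vs. InferenceData) depend on the PyMC version, so treat this as illustrative rather than the notebook's actual code.

import numpy as np
import pymc as pm

# Hypothetical toy data and model, not taken from the notebook.
y = np.random.standard_normal(100)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, sigma=1.0)
    pm.Normal("obs", mu=mu, sigma=1.0, observed=y)
    idata = pm.sample(1000, tune=1000, return_inferencedata=True)
    # Sample the posterior predictive distribution of "obs" for each posterior draw.
    ppc = pm.sample_posterior_predictive(idata)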

74 changes: 29 additions & 45 deletions pymc/backends/arviz.py
@@ -138,7 +138,6 @@ def __init__(
dims: Optional[DimSpec] = None,
model=None,
save_warmup: Optional[bool] = None,
density_dist_obs: bool = True,
):

self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
@@ -175,29 +174,11 @@ def __init__(
self.log_likelihood = log_likelihood
self.predictions = predictions

def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
return next(iter(dct.values()))

if trace is None:
# if you have a posterior_predictive built with keep_dims,
# you'll lose here, but there's nothing I can do about that.
self.nchains = 1
get_from = None
if predictions is not None:
get_from = predictions
elif posterior_predictive is not None:
get_from = posterior_predictive
elif prior is not None:
get_from = prior
if get_from is None:
# pylint: disable=line-too-long
raise ValueError(
"When constructing InferenceData must have at least"
" one of trace, prior, posterior_predictive or predictions."
)

aelem = arbitrary_element(get_from)
self.ndraws = aelem.shape[0]
if all(elem is None for elem in (trace, predictions, posterior_predictive, prior)):
raise ValueError(
"When constructing InferenceData you must pass at least"
" one of trace, prior, posterior_predictive or predictions."
)

self.coords = {**self.model.coords, **(coords or {})}
self.coords = {
@@ -214,7 +195,6 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
}
self.dims = {**model_dims, **self.dims}

self.density_dist_obs = density_dist_obs
self.observations = find_observations(self.model)

def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
@@ -396,34 +376,35 @@ def log_likelihood_to_xarray(self):
),
)

def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset:
def translate_posterior_predictive_dict_to_xarray(self, dct, kind) -> xr.Dataset:
"""Take Dict of variables to numpy ndarrays (samples) and translate into dataset."""
data = {}
warning_vars = []
for k, ary in dct.items():
shape = ary.shape
if shape[0] == self.nchains and shape[1] == self.ndraws:
if (ary.shape[0] == self.nchains) and (ary.shape[1] == self.ndraws):
data[k] = ary
elif shape[0] == self.nchains * self.ndraws:
data[k] = ary.reshape((self.nchains, self.ndraws, *shape[1:]))
else:
data[k] = np.expand_dims(ary, 0)
# pylint: disable=line-too-long
_log.warning(
"posterior predictive variable %s's shape not compatible with number of chains and draws. "
"This can mean that some draws or even whole chains are not represented.",
k,
)
warning_vars.append(k)
warnings.warn(
f"The shape of variables {', '.join(warning_vars)} in {kind} group is not compatible "
"with number of chains and draws. The automatic dimension naming might not have worked. "
"This can also mean that some draws or even whole chains are not represented.",
UserWarning,
)
return dict_to_dataset(data, library=pymc, coords=self.coords, dims=self.dims)

Contributor:
@OriolAbril I think you intended for this warning to be printed only if warning_vars was not empty. Is this correct? I am currently always getting this output:

/Users/thomas/repo/pymc/pymc/backends/arviz.py:389: UserWarning: The shape of variables  in posterior_predictive group is not compatible with number of chains and draws. The automatic dimension naming might not have worked. This can also mean that some draws or even whole chains are not represented.

Contributor:
To be extra clear: the warning prints a space in place of {', '.join(warning_vars)}.

Member:
Yeah, sorry, I also missed this in my review. It needs to be indented into an if warning_vars: block. Thanks for reporting this! Do you want to open a PR? Otherwise I'll do it in the morning.
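A sketch of the fix the reviewers describe above, reusing the surrounding names (warning_vars, kind) from this method so the warning is only emitted when at least one variable had an incompatible shape. This is the suggested follow-up, not part of this PR's merged diff.

if warning_vars:
    warnings.warn(
        f"The shape of variables {', '.join(warning_vars)} in {kind} group is not compatible "
        "with number of chains and draws. The automatic dimension naming might not have worked. "
        "This can also mean that some draws or even whole chains are not represented.",
        UserWarning,
    )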

@requires(["posterior_predictive"])
def posterior_predictive_to_xarray(self):
"""Convert posterior_predictive samples to xarray."""
return self.translate_posterior_predictive_dict_to_xarray(self.posterior_predictive)
return self.translate_posterior_predictive_dict_to_xarray(
self.posterior_predictive, "posterior_predictive"
)

@requires(["predictions"])
def predictions_to_xarray(self):
"""Convert predictions (out of sample predictions) to xarray."""
return self.translate_posterior_predictive_dict_to_xarray(self.predictions)
return self.translate_posterior_predictive_dict_to_xarray(self.predictions, "predictions")

def priors_to_xarray(self):
"""Convert prior samples (and if possible prior predictive too) to xarray."""
@@ -545,7 +526,6 @@ def to_inference_data(
dims: Optional[DimSpec] = None,
model: Optional["Model"] = None,
save_warmup: Optional[bool] = None,
density_dist_obs: bool = True,
) -> InferenceData:
"""Convert pymc data into an InferenceData object.

@@ -578,9 +558,6 @@
save_warmup : bool, optional
Save warmup iterations InferenceData object. If not defined, use default
defined by the rcParams.
density_dist_obs : bool, default True
Store variables passed with ``observed`` arg to
:class:`~pymc.distributions.DensityDist` in the generated InferenceData.

Returns
-------
@@ -598,7 +575,6 @@
dims=dims,
model=model,
save_warmup=save_warmup,
density_dist_obs=density_dist_obs,
).to_inference_data()


@@ -620,6 +596,7 @@ def predictions_to_inference_data(
predictions: Dict[str, np.ndarray]
The predictions are the return value of :func:`~pymc.sample_posterior_predictive`,
a dictionary of strings (variable names) to numpy ndarrays (draws).
Requires the arrays to follow the convention ``chain, draw, *shape``.
posterior_trace: MultiTrace
This should be a trace that has been thinned appropriately for
``pymc.sample_posterior_predictive``. Specifically, any variable whose shape is
@@ -648,14 +625,21 @@
raise ValueError(
"Do not pass True for inplace unless passing" "an existing InferenceData as idata_orig"
)
new_idata = InferenceDataConverter(
converter = InferenceDataConverter(
trace=posterior_trace,
predictions=predictions,
model=model,
coords=coords,
dims=dims,
log_likelihood=False,
).to_inference_data()
)
if hasattr(idata_orig, "posterior"):
converter.nchains = idata_orig.posterior.dims["chain"]
converter.ndraws = idata_orig.posterior.dims["draw"]
else:
aelem = next(iter(predictions.values()))
converter.nchains, converter.ndraws = aelem.shape[:2]
new_idata = converter.to_inference_data()
if idata_orig is None:
return new_idata
elif inplace:
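The updated docstring above requires each prediction array to follow the chain, draw, *shape convention. A minimal illustration of a compliant dictionary, with hypothetical sizes of 4 chains, 500 draws, and 100 observations:

import numpy as np

# One entry per predicted variable; the two leading axes are (chain, draw).
predictions = {"obs": np.random.standard_normal((4, 500, 100))}
assert predictions["obs"].shape[:2] == (4, 500)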
3 changes: 0 additions & 3 deletions pymc/plots/__init__.py
@@ -56,8 +56,6 @@ def wrapped(*args, **kwargs):
compareplot = alias_deprecation(az.plot_compare, alias="compareplot")


from pymc.plots.posteriorplot import plot_posterior_predictive_glm

__all__ = tuple(az.plots.__all__) + (
"autocorrplot",
"compareplot",
@@ -67,5 +65,4 @@ def wrapped(*args, **kwargs):
"energyplot",
"densityplot",
"pairplot",
"plot_posterior_predictive_glm",
)
86 changes: 0 additions & 86 deletions pymc/plots/posteriorplot.py

This file was deleted.
