speed up posterior predictive sampling (#6208)
* refactor `dataset_to_point_list` for higher performance
* allow specifying `sample_dims` in posterior predictive (usage sketch below)
OriolAbril authored Oct 27, 2022
1 parent a025059 commit 570e6e8
Showing 11 changed files with 101 additions and 82 deletions.
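For orientation before the per-file diffs, here is a minimal sketch of what the new `sample_dims` argument enables. The toy model and sizes are illustrative only; the stack/expand_dims pattern follows the docstring example and the test added in this commit.

```python
import numpy as np
import pymc as pm

# Toy model for illustration; any model with observed variables works.
with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu, 1.0, observed=np.random.randn(100))
    idata = pm.sample(chains=2, draws=500)

    # Collapse chain and draw into one "sample" dimension and loop over
    # it instead of the default ["chain", "draw"].
    post = idata.posterior.stack(sample=["chain", "draw"])
    pp = pm.sample_posterior_predictive(post, sample_dims=["sample"])

    # Five predictive draws per posterior sample, via an extra sample dim.
    post5 = post.expand_dims(pred_id=5)
    pp5 = pm.sample_posterior_predictive(post5, sample_dims=["sample", "pred_id"])
```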
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
@@ -7,7 +7,7 @@ dependencies:
 # Base dependencies
 - aeppl=0.0.38
 - aesara=2.8.7
-- arviz>=0.12.0
+- arviz>=0.13.0
 - blas
 - cachetools>=4.2.1
 - cloudpickle
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
@@ -7,7 +7,7 @@ dependencies:
 # Base dependencies
 - aeppl=0.0.38
 - aesara=2.8.7
-- arviz>=0.12.0
+- arviz>=0.13.0
 - blas
 - cachetools>=4.2.1
 - cloudpickle
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
@@ -7,7 +7,7 @@ dependencies:
 # Base dependencies (see install guide for Windows)
 - aeppl=0.0.38
 - aesara=2.8.7
-- arviz>=0.12.0
+- arviz>=0.13.0
 - blas
 - cachetools>=4.2.1
 - cloudpickle
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
@@ -7,7 +7,7 @@ dependencies:
 # Base dependencies (see install guide for Windows)
 - aeppl=0.0.38
 - aesara=2.8.7
-- arviz>=0.12.0
+- arviz>=0.13.0
 - blas
 - cachetools>=4.2.1
 - cloudpickle
43 changes: 21 additions & 22 deletions pymc/backends/arviz.py
@@ -7,6 +7,7 @@
     Any,
     Dict,
     Iterable,
+    List,
     Mapping,
     Optional,
     Sequence,
@@ -15,7 +16,6 @@
 )
 
 import numpy as np
-import xarray as xr
 
 from aesara.graph.basic import Constant
 from aesara.tensor.sharedvar import SharedVariable
@@ -162,6 +162,7 @@ def __init__(
         predictions=None,
         coords: Optional[CoordSpec] = None,
         dims: Optional[DimSpec] = None,
+        sample_dims: Optional[List] = None,
         model=None,
         save_warmup: Optional[bool] = None,
         include_transformed: bool = False,
@@ -225,6 +226,9 @@ def __init__(
                 for var_name, dims in self.model.RV_dims.items()
             }
             self.dims = {**model_dims, **self.dims}
+        if sample_dims is None:
+            sample_dims = ["chain", "draw"]
+        self.sample_dims = sample_dims
 
         self.observations = find_observations(self.model)
 
@@ -423,36 +427,27 @@ def log_likelihood_to_xarray(self):
             ),
         )
 
-    def translate_posterior_predictive_dict_to_xarray(self, dct, kind) -> xr.Dataset:
-        """Take Dict of variables to numpy ndarrays (samples) and translate into dataset."""
-        data = {}
-        warning_vars = []
-        for k, ary in dct.items():
-            if (ary.shape[0] == self.nchains) and (ary.shape[1] == self.ndraws):
-                data[k] = ary
-            else:
-                data[k] = np.expand_dims(ary, 0)
-                warning_vars.append(k)
-        if warning_vars:
-            warnings.warn(
-                f"The shape of variables {', '.join(warning_vars)} in {kind} group is not compatible "
-                "with number of chains and draws. The automatic dimension naming might not have worked. "
-                "This can also mean that some draws or even whole chains are not represented.",
-                UserWarning,
-            )
-        return dict_to_dataset(data, library=pymc, coords=self.coords, dims=self.dims)
-
     @requires(["posterior_predictive"])
     def posterior_predictive_to_xarray(self):
         """Convert posterior_predictive samples to xarray."""
-        return self.translate_posterior_predictive_dict_to_xarray(
-            self.posterior_predictive, "posterior_predictive"
-        )
+        data = self.posterior_predictive
+        dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data}
+        return dict_to_dataset(
+            data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims
+        )
 
     @requires(["predictions"])
     def predictions_to_xarray(self):
         """Convert predictions (out of sample predictions) to xarray."""
-        return self.translate_posterior_predictive_dict_to_xarray(self.predictions, "predictions")
+        data = self.predictions
+        dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data}
+        return dict_to_dataset(
+            data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims
+        )
 
     def priors_to_xarray(self):
         """Convert prior samples (and if possible prior predictive too) to xarray."""
@@ -541,6 +536,7 @@ def to_inference_data(
     log_likelihood: Union[bool, Iterable[str]] = True,
     coords: Optional[CoordSpec] = None,
     dims: Optional[DimSpec] = None,
+    sample_dims: Optional[List] = None,
     model: Optional["Model"] = None,
     save_warmup: Optional[bool] = None,
     include_transformed: bool = False,
@@ -594,6 +590,7 @@ def to_inference_data(
         log_likelihood=log_likelihood,
         coords=coords,
         dims=dims,
+        sample_dims=sample_dims,
         model=model,
         save_warmup=save_warmup,
         include_transformed=include_transformed,
@@ -608,6 +605,7 @@ def predictions_to_inference_data(
     model: Optional["Model"] = None,
     coords: Optional[CoordSpec] = None,
     dims: Optional[DimSpec] = None,
+    sample_dims: Optional[List] = None,
     idata_orig: Optional[InferenceData] = None,
     inplace: bool = False,
 ) -> InferenceData:
@@ -653,6 +651,7 @@ def predictions_to_inference_data(
         model=model,
         coords=coords,
         dims=dims,
+        sample_dims=sample_dims,
         log_likelihood=False,
     )
     if hasattr(idata_orig, "posterior"):
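The dims bookkeeping in `posterior_predictive_to_xarray`/`predictions_to_xarray` is the heart of this file's change: each variable's dimension list is now the active `sample_dims` followed by its user-declared dims, rather than a hard-coded leading `("chain", "draw")`. A tiny self-contained illustration of that comprehension; the variable name, dims, and shapes are made up, not from the commit:

```python
import numpy as np

# With sample_dims=["sample"], a variable with user dims ["city"] is
# expected as an (n_samples, n_cities) array; the converter prefixes
# the sample dims instead of assuming ("chain", "draw").
sample_dims = ["sample"]
user_dims = {"d": ["city"]}        # stands in for self.dims
data = {"d": np.zeros((1000, 3))}  # stands in for self.posterior_predictive
dims = {var_name: sample_dims + user_dims.get(var_name, []) for var_name in data}
assert dims == {"d": ["sample", "city"]}
```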
61 changes: 39 additions & 22 deletions pymc/sampling.py
@@ -75,7 +75,6 @@
 from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
 from pymc.step_methods.hmc import quadpotential
 from pymc.util import (
-    chains_and_samples,
     dataset_to_point_list,
     get_default_varnames,
     get_untransformed_name,
@@ -1765,6 +1764,7 @@ def sample_posterior_predictive(
     trace,
     model: Optional[Model] = None,
     var_names: Optional[List[str]] = None,
+    sample_dims: Optional[List[str]] = None,
     random_seed: RandomState = None,
     progressbar: bool = True,
     return_inferencedata: bool = True,
@@ -1785,6 +1785,10 @@
         generally be the model used to generate the ``trace``, but it doesn't need to be.
     var_names : Iterable[str]
         Names of variables for which to compute the posterior predictive samples.
+    sample_dims : list of str, optional
+        Dimensions over which to loop and generate posterior predictive samples.
+        When `sample_dims` is ``None`` (default) both "chain" and "draw" are considered sample
+        dimensions. Only taken into account when `trace` is InferenceData or Dataset.
     random_seed : int, RandomState or Generator, optional
         Seed for the random number generator.
     progressbar : bool
@@ -1821,6 +1825,14 @@
         thinned_idata = idata.sel(draw=slice(None, None, 5))
         with model:
             idata.extend(pymc.sample_posterior_predictive(thinned_idata))
+
+    Generate 5 posterior predictive samples per posterior sample.
+
+    .. code:: python
+
+        expanded_data = idata.posterior.expand_dims(pred_id=5)
+        with model:
+            idata.extend(pymc.sample_posterior_predictive(expanded_data))
     """
 
     _trace: Union[MultiTrace, PointList]
@@ -1829,36 +1841,34 @@
         idata_kwargs = {}
     else:
         idata_kwargs = idata_kwargs.copy()
+    if sample_dims is None:
+        sample_dims = ["chain", "draw"]
     constant_data: Dict[str, np.ndarray] = {}
     trace_coords: Dict[str, np.ndarray] = {}
     if "coords" not in idata_kwargs:
         idata_kwargs["coords"] = {}
+    idata: Optional[InferenceData] = None
+    stacked_dims = None
     if isinstance(trace, InferenceData):
-        idata_kwargs["coords"].setdefault("draw", trace["posterior"]["draw"])
-        idata_kwargs["coords"].setdefault("chain", trace["posterior"]["chain"])
         _constant_data = getattr(trace, "constant_data", None)
         if _constant_data is not None:
             trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()})
             constant_data.update({str(k): v.data for k, v in _constant_data.items()})
-        trace_coords.update({str(k): v.data for k, v in trace["posterior"].coords.items()})
-        _trace = dataset_to_point_list(trace["posterior"])
-        nchain, len_trace = chains_and_samples(trace)
-    elif isinstance(trace, xarray.Dataset):
-        idata_kwargs["coords"].setdefault("draw", trace["draw"])
-        idata_kwargs["coords"].setdefault("chain", trace["chain"])
+        idata = trace
+        trace = trace["posterior"]
+    if isinstance(trace, xarray.Dataset):
         trace_coords.update({str(k): v.data for k, v in trace.coords.items()})
-        _trace = dataset_to_point_list(trace)
-        nchain, len_trace = chains_and_samples(trace)
+        _trace, stacked_dims = dataset_to_point_list(trace, sample_dims)
+        nchain = 1
     elif isinstance(trace, MultiTrace):
         _trace = trace
         nchain = _trace.nchains
-        len_trace = len(_trace)
     elif isinstance(trace, list) and all(isinstance(x, dict) for x in trace):
         _trace = trace
         nchain = 1
-        len_trace = len(_trace)
     else:
         raise TypeError(f"Unsupported type for `trace` argument: {type(trace)}.")
+    len_trace = len(_trace)
 
     if isinstance(_trace, MultiTrace):
         samples = sum(len(v) for v in _trace._straces.values())
@@ -1961,23 +1971,30 @@ def sample_posterior_predictive(
         ppc_trace = ppc_trace_t.trace_dict
 
     for k, ary in ppc_trace.items():
-        ppc_trace[k] = ary.reshape((nchain, len_trace, *ary.shape[1:]))
+        if stacked_dims is not None:
+            ppc_trace[k] = ary.reshape(
+                (*[len(coord) for coord in stacked_dims.values()], *ary.shape[1:])
+            )
+        else:
+            ppc_trace[k] = ary.reshape((nchain, len_trace, *ary.shape[1:]))
 
     if not return_inferencedata:
         return ppc_trace
 
     ikwargs: Dict[str, Any] = dict(model=model, **idata_kwargs)
+    ikwargs.setdefault("sample_dims", sample_dims)
+    if stacked_dims is not None:
+        coords = ikwargs.get("coords", {})
+        ikwargs["coords"] = {**stacked_dims, **coords}
     if predictions:
         if extend_inferencedata:
-            ikwargs.setdefault("idata_orig", trace)
+            ikwargs.setdefault("idata_orig", idata)
             ikwargs.setdefault("inplace", True)
         return pm.predictions_to_inference_data(ppc_trace, **ikwargs)
 
-    converter = pm.backends.arviz.InferenceDataConverter(posterior_predictive=ppc_trace, **ikwargs)
-    converter.nchains = nchain
-    converter.ndraws = len_trace
-    idata_pp = converter.to_inference_data()
-    if extend_inferencedata:
-        trace.extend(idata_pp)
-        return trace
+    idata_pp = pm.to_inference_data(posterior_predictive=ppc_trace, **ikwargs)
+
+    if extend_inferencedata and idata is not None:
+        idata.extend(idata_pp)
+        return idata
     return idata_pp


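The reshape near the end of `sample_posterior_predictive` is what lets arbitrary `sample_dims` round-trip: the predictive draws come back with a single flat leading axis (one entry per point in the point list), which is folded back onto the stacked sample dims. A numpy-only sketch with invented sizes:

```python
import numpy as np

# stacked_dims maps each sample dim to its coordinate values, as returned
# by dataset_to_point_list; the sizes here are invented.
stacked_dims = {"sample": np.arange(1000), "pred_id": np.arange(5)}

# Flat sampler output for one variable whose per-draw shape is (3,).
ary = np.zeros((1000 * 5, 3))

# The same reshape as in the diff above:
out = ary.reshape((*[len(coord) for coord in stacked_dims.values()], *ary.shape[1:]))
assert out.shape == (1000, 5, 3)
```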
16 changes: 16 additions & 0 deletions pymc/tests/test_sampling.py
@@ -1621,6 +1621,22 @@ def test_aesara_function_kwargs(self):
 
         assert np.all(pp["y"] == np.arange(5) * 2)
 
+    def test_sample_dims(self, point_list_arg_bug_fixture):
+        pmodel, trace = point_list_arg_bug_fixture
+        with pmodel:
+            post = pm.to_inference_data(trace).posterior.stack(sample=["chain", "draw"])
+            pp = pm.sample_posterior_predictive(post, var_names=["d"], sample_dims=["sample"])
+            assert "sample" in pp.posterior_predictive
+            assert len(pp.posterior_predictive["sample"]) == len(post["sample"])
+            post = post.expand_dims(pred_id=5)
+            pp = pm.sample_posterior_predictive(
+                post, var_names=["d"], sample_dims=["sample", "pred_id"]
+            )
+            assert "sample" in pp.posterior_predictive
+            assert "pred_id" in pp.posterior_predictive
+            assert len(pp.posterior_predictive["sample"]) == len(post["sample"])
+            assert len(pp.posterior_predictive["pred_id"]) == 5
+
 
 class TestDraw(SeededTest):
     def test_univariate(self):
4 changes: 2 additions & 2 deletions pymc/tests/test_util.py
@@ -154,7 +154,7 @@ def fn(a=UNSET):
 def test_dataset_to_point_list():
     ds = xarray.Dataset()
     ds["A"] = xarray.DataArray([[1, 2, 3]] * 2, dims=("chain", "draw"))
-    pl = dataset_to_point_list(ds)
+    pl, _ = dataset_to_point_list(ds, sample_dims=["chain", "draw"])
     assert isinstance(pl, list)
     assert len(pl) == 6
     assert isinstance(pl[0], dict)
@@ -163,4 +163,4 @@ def test_dataset_to_point_list():
     # Check that non-str keys are caught
     ds[3] = xarray.DataArray([1, 2, 3])
     with pytest.raises(ValueError, match="must be str"):
-        dataset_to_point_list(ds)
+        dataset_to_point_list(ds, sample_dims=["chain", "draw"])
47 changes: 17 additions & 30 deletions pymc/util.py
@@ -14,9 +14,8 @@
 
 import functools
 
-from typing import Dict, Hashable, List, Tuple, Union, cast
+from typing import Any, Dict, List, Tuple, cast
 
-import arviz
 import cloudpickle
 import numpy as np
 import xarray
@@ -231,38 +230,26 @@ def enhanced(*args, **kwargs):
     return enhanced
 
 
-def dataset_to_point_list(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]:
+def dataset_to_point_list(
+    ds: xarray.Dataset, sample_dims: List
+) -> Tuple[List[Dict[str, np.ndarray]], Dict[str, Any]]:
     # All keys of the dataset must be a str
-    for vn in ds.keys():
+    var_names = list(ds.keys())
+    for vn in var_names:
         if not isinstance(vn, str):
             raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.")
-    # make dicts
-    points: List[Dict[Hashable, np.ndarray]] = []
-    da: "xarray.DataArray"
-    for c in ds.chain:
-        for d in ds.draw:
-            points.append({vn: da.sel(chain=c, draw=d).values for vn, da in ds.items()})
+    num_sample_dims = len(sample_dims)
+    stacked_dims = {dim_name: ds[dim_name] for dim_name in sample_dims}
+    ds = ds.transpose(*sample_dims, ...)
+    stacked_dict = {
+        vn: da.values.reshape((-1, *da.shape[num_sample_dims:])) for vn, da in ds.items()
+    }
+    points = [
+        {vn: stacked_dict[vn][i, ...] for vn in var_names}
+        for i in range(np.product([len(coords) for coords in stacked_dims.values()]))
+    ]
     # use the list of points
-    return cast(List[Dict[str, np.ndarray]], points)
-
-
-def chains_and_samples(data: Union[xarray.Dataset, arviz.InferenceData]) -> Tuple[int, int]:
-    """Extract and return number of chains and samples in xarray or arviz traces."""
-    dataset: xarray.Dataset
-    if isinstance(data, xarray.Dataset):
-        dataset = data
-    elif isinstance(data, arviz.InferenceData):
-        dataset = data["posterior"]
-    else:
-        raise ValueError(
-            "Argument must be xarray Dataset or arviz InferenceData. Got %s",
-            data.__class__,
-        )
-
-    coords = dataset.coords
-    nchains = coords["chain"].sizes["chain"]
-    nsamples = coords["draw"].sizes["draw"]
-    return nchains, nsamples
+    return cast(List[Dict[str, np.ndarray]], points), stacked_dims
 
 
 def hashable(a=None) -> int:
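The performance claim in the commit message comes down to this function: the old version issued one label-based `.sel` per draw, while the new one transposes once and reshapes each variable's underlying numpy array. A rough before/after comparison on a made-up two-chain, three-draw dataset (not code from the commit):

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"A": (("chain", "draw"), np.arange(6).reshape(2, 3))},
    coords={"chain": [0, 1], "draw": [0, 1, 2]},
)
sample_dims = ["chain", "draw"]

# Old approach: one label-based selection per (chain, draw) pair.
slow = [
    {vn: da.sel(chain=c, draw=d).values for vn, da in ds.items()}
    for c in ds.chain
    for d in ds.draw
]

# New approach: transpose once, reshape each variable's numpy array,
# then slice rows off the flattened leading axis.
ds_t = ds.transpose(*sample_dims, ...)
flat = {vn: da.values.reshape((-1, *da.shape[len(sample_dims):])) for vn, da in ds_t.items()}
fast = [{vn: flat[vn][i] for vn in flat} for i in range(2 * 3)]

assert all(s["A"] == f["A"] for s, f in zip(slow, fast))
```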
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -3,7 +3,7 @@
 
 aeppl==0.0.38
 aesara==2.8.7
-arviz>=0.12.0
+arviz>=0.13.0
 cachetools>=4.2.1
 cloudpickle
 fastprogress>=0.2.0