diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml
index bef8415953..d6a1895a34 100644
--- a/conda-envs/environment-dev.yml
+++ b/conda-envs/environment-dev.yml
@@ -7,7 +7,7 @@ dependencies:
 # Base dependencies
 - aeppl=0.0.38
 - aesara=2.8.7
-- arviz>=0.12.0
+- arviz>=0.13.0
 - blas
 - cachetools>=4.2.1
 - cloudpickle
diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml
index 5547af1f2a..2dea663920 100644
--- a/conda-envs/environment-test.yml
+++ b/conda-envs/environment-test.yml
@@ -7,7 +7,7 @@ dependencies:
 # Base dependencies
 - aeppl=0.0.38
 - aesara=2.8.7
-- arviz>=0.12.0
+- arviz>=0.13.0
 - blas
 - cachetools>=4.2.1
 - cloudpickle
diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml
index 7fdf2c6bde..6a7012eea7 100644
--- a/conda-envs/windows-environment-dev.yml
+++ b/conda-envs/windows-environment-dev.yml
@@ -7,7 +7,7 @@ dependencies:
 # Base dependencies (see install guide for Windows)
 - aeppl=0.0.38
 - aesara=2.8.7
-- arviz>=0.12.0
+- arviz>=0.13.0
 - blas
 - cachetools>=4.2.1
 - cloudpickle
diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml
index 6747477218..a510c4d4aa 100644
--- a/conda-envs/windows-environment-test.yml
+++ b/conda-envs/windows-environment-test.yml
@@ -7,7 +7,7 @@ dependencies:
 # Base dependencies (see install guide for Windows)
 - aeppl=0.0.38
 - aesara=2.8.7
-- arviz>=0.12.0
+- arviz>=0.13.0
 - blas
 - cachetools>=4.2.1
 - cloudpickle
diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py
index 7939378c97..1ff2c66e24 100644
--- a/pymc/backends/arviz.py
+++ b/pymc/backends/arviz.py
@@ -7,6 +7,7 @@
     Any,
     Dict,
     Iterable,
+    List,
     Mapping,
     Optional,
     Sequence,
@@ -15,7 +16,6 @@
 )

 import numpy as np
-import xarray as xr

 from aesara.graph.basic import Constant
 from aesara.tensor.sharedvar import SharedVariable
@@ -162,6 +162,7 @@ def __init__(
         predictions=None,
         coords: Optional[CoordSpec] = None,
         dims: Optional[DimSpec] = None,
+        sample_dims: Optional[List] = None,
         model=None,
         save_warmup: Optional[bool] = None,
         include_transformed: bool = False,
@@ -225,6 +226,9 @@ def __init__(
                 for var_name, dims in self.model.RV_dims.items()
             }
             self.dims = {**model_dims, **self.dims}
+        if sample_dims is None:
+            sample_dims = ["chain", "draw"]
+        self.sample_dims = sample_dims

         self.observations = find_observations(self.model)

@@ -423,36 +427,27 @@ def log_likelihood_to_xarray(self):
                 ),
             )

-    def translate_posterior_predictive_dict_to_xarray(self, dct, kind) -> xr.Dataset:
-        """Take Dict of variables to numpy ndarrays (samples) and translate into dataset."""
-        data = {}
-        warning_vars = []
-        for k, ary in dct.items():
-            if (ary.shape[0] == self.nchains) and (ary.shape[1] == self.ndraws):
-                data[k] = ary
-            else:
-                data[k] = np.expand_dims(ary, 0)
-                warning_vars.append(k)
-        if warning_vars:
-            warnings.warn(
-                f"The shape of variables {', '.join(warning_vars)} in {kind} group is not compatible "
-                "with number of chains and draws. The automatic dimension naming might not have worked. "
" - "This can also mean that some draws or even whole chains are not represented.", - UserWarning, - ) - return dict_to_dataset(data, library=pymc, coords=self.coords, dims=self.dims) + return dict_to_dataset( + data, library=pymc, coords=self.coords, dims=self.dims, default_dims=self.sample_dims + ) @requires(["posterior_predictive"]) def posterior_predictive_to_xarray(self): """Convert posterior_predictive samples to xarray.""" - return self.translate_posterior_predictive_dict_to_xarray( - self.posterior_predictive, "posterior_predictive" + data = self.posterior_predictive + dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data} + return dict_to_dataset( + data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims ) @requires(["predictions"]) def predictions_to_xarray(self): """Convert predictions (out of sample predictions) to xarray.""" - return self.translate_posterior_predictive_dict_to_xarray(self.predictions, "predictions") + data = self.predictions + dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data} + return dict_to_dataset( + data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims + ) def priors_to_xarray(self): """Convert prior samples (and if possible prior predictive too) to xarray.""" @@ -541,6 +536,7 @@ def to_inference_data( log_likelihood: Union[bool, Iterable[str]] = True, coords: Optional[CoordSpec] = None, dims: Optional[DimSpec] = None, + sample_dims: Optional[List] = None, model: Optional["Model"] = None, save_warmup: Optional[bool] = None, include_transformed: bool = False, @@ -594,6 +590,7 @@ def to_inference_data( log_likelihood=log_likelihood, coords=coords, dims=dims, + sample_dims=sample_dims, model=model, save_warmup=save_warmup, include_transformed=include_transformed, @@ -608,6 +605,7 @@ def predictions_to_inference_data( model: Optional["Model"] = None, coords: Optional[CoordSpec] = None, dims: Optional[DimSpec] = None, + sample_dims: Optional[List] = None, idata_orig: Optional[InferenceData] = None, inplace: bool = False, ) -> InferenceData: @@ -653,6 +651,7 @@ def predictions_to_inference_data( model=model, coords=coords, dims=dims, + sample_dims=sample_dims, log_likelihood=False, ) if hasattr(idata_orig, "posterior"): diff --git a/pymc/sampling.py b/pymc/sampling.py index 8d0379add3..41ac56921b 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -75,7 +75,6 @@ from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared from pymc.step_methods.hmc import quadpotential from pymc.util import ( - chains_and_samples, dataset_to_point_list, get_default_varnames, get_untransformed_name, @@ -1765,6 +1764,7 @@ def sample_posterior_predictive( trace, model: Optional[Model] = None, var_names: Optional[List[str]] = None, + sample_dims: Optional[List[str]] = None, random_seed: RandomState = None, progressbar: bool = True, return_inferencedata: bool = True, @@ -1785,6 +1785,10 @@ def sample_posterior_predictive( generally be the model used to generate the ``trace``, but it doesn't need to be. var_names : Iterable[str] Names of variables for which to compute the posterior predictive samples. + sample_dims : list of str, optional + Dimensions over which to loop and generate posterior predictive samples. + When `sample_dims` is ``None`` (default) both "chain" and "draw" are considered sample + dimensions. Only taken into account when `trace` is InferenceData or Dataset. 
     random_seed : int, RandomState or Generator, optional
         Seed for the random number generator.
     progressbar : bool
@@ -1821,6 +1825,14 @@ def sample_posterior_predictive(
         thinned_idata = idata.sel(draw=slice(None, None, 5))
         with model:
             idata.extend(pymc.sample_posterior_predictive(thinned_idata))
+
+    Generate 5 posterior predictive samples per posterior sample.
+
+    .. code:: python
+
+        expanded_data = idata.posterior.expand_dims(pred_id=5)
+        with model:
+            idata.extend(pymc.sample_posterior_predictive(expanded_data))
     """

     _trace: Union[MultiTrace, PointList]
@@ -1829,36 +1841,34 @@ def sample_posterior_predictive(
     if idata_kwargs is None:
         idata_kwargs = {}
     else:
         idata_kwargs = idata_kwargs.copy()
+    if sample_dims is None:
+        sample_dims = ["chain", "draw"]
     constant_data: Dict[str, np.ndarray] = {}
     trace_coords: Dict[str, np.ndarray] = {}
     if "coords" not in idata_kwargs:
         idata_kwargs["coords"] = {}
+    idata: Optional[InferenceData] = None
+    stacked_dims = None
     if isinstance(trace, InferenceData):
-        idata_kwargs["coords"].setdefault("draw", trace["posterior"]["draw"])
-        idata_kwargs["coords"].setdefault("chain", trace["posterior"]["chain"])
         _constant_data = getattr(trace, "constant_data", None)
         if _constant_data is not None:
             trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()})
             constant_data.update({str(k): v.data for k, v in _constant_data.items()})
-        trace_coords.update({str(k): v.data for k, v in trace["posterior"].coords.items()})
-        _trace = dataset_to_point_list(trace["posterior"])
-        nchain, len_trace = chains_and_samples(trace)
-    elif isinstance(trace, xarray.Dataset):
-        idata_kwargs["coords"].setdefault("draw", trace["draw"])
-        idata_kwargs["coords"].setdefault("chain", trace["chain"])
+        idata = trace
+        trace = trace["posterior"]
+    if isinstance(trace, xarray.Dataset):
         trace_coords.update({str(k): v.data for k, v in trace.coords.items()})
-        _trace = dataset_to_point_list(trace)
-        nchain, len_trace = chains_and_samples(trace)
+        _trace, stacked_dims = dataset_to_point_list(trace, sample_dims)
+        nchain = 1
     elif isinstance(trace, MultiTrace):
         _trace = trace
         nchain = _trace.nchains
-        len_trace = len(_trace)
     elif isinstance(trace, list) and all(isinstance(x, dict) for x in trace):
         _trace = trace
         nchain = 1
-        len_trace = len(_trace)
     else:
         raise TypeError(f"Unsupported type for `trace` argument: {type(trace)}.")
+    len_trace = len(_trace)
     if isinstance(_trace, MultiTrace):
         samples = sum(len(v) for v in _trace._straces.values())
@@ -1961,23 +1971,30 @@ def sample_posterior_predictive(
     ppc_trace = ppc_trace_t.trace_dict

     for k, ary in ppc_trace.items():
-        ppc_trace[k] = ary.reshape((nchain, len_trace, *ary.shape[1:]))
+        if stacked_dims is not None:
+            ppc_trace[k] = ary.reshape(
+                (*[len(coord) for coord in stacked_dims.values()], *ary.shape[1:])
+            )
+        else:
+            ppc_trace[k] = ary.reshape((nchain, len_trace, *ary.shape[1:]))

     if not return_inferencedata:
         return ppc_trace
     ikwargs: Dict[str, Any] = dict(model=model, **idata_kwargs)
+    ikwargs.setdefault("sample_dims", sample_dims)
+    if stacked_dims is not None:
+        coords = ikwargs.get("coords", {})
+        ikwargs["coords"] = {**stacked_dims, **coords}
     if predictions:
         if extend_inferencedata:
-            ikwargs.setdefault("idata_orig", trace)
+            ikwargs.setdefault("idata_orig", idata)
             ikwargs.setdefault("inplace", True)
         return pm.predictions_to_inference_data(ppc_trace, **ikwargs)

-    converter = pm.backends.arviz.InferenceDataConverter(posterior_predictive=ppc_trace, **ikwargs)
-    converter.nchains = nchain
-    converter.ndraws = len_trace
-    idata_pp = converter.to_inference_data()
-    if extend_inferencedata:
-        trace.extend(idata_pp)
-        return trace
+    idata_pp = pm.to_inference_data(posterior_predictive=ppc_trace, **ikwargs)
+
+    if extend_inferencedata and idata is not None:
+        idata.extend(idata_pp)
+        return idata
     return idata_pp
diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py
index 86516ff4b8..5efb34b89a 100644
--- a/pymc/tests/test_sampling.py
+++ b/pymc/tests/test_sampling.py
@@ -1621,6 +1621,22 @@ def test_aesara_function_kwargs(self):

         assert np.all(pp["y"] == np.arange(5) * 2)

+    def test_sample_dims(self, point_list_arg_bug_fixture):
+        pmodel, trace = point_list_arg_bug_fixture
+        with pmodel:
+            post = pm.to_inference_data(trace).posterior.stack(sample=["chain", "draw"])
+            pp = pm.sample_posterior_predictive(post, var_names=["d"], sample_dims=["sample"])
+            assert "sample" in pp.posterior_predictive
+            assert len(pp.posterior_predictive["sample"]) == len(post["sample"])
+            post = post.expand_dims(pred_id=5)
+            pp = pm.sample_posterior_predictive(
+                post, var_names=["d"], sample_dims=["sample", "pred_id"]
+            )
+            assert "sample" in pp.posterior_predictive
+            assert "pred_id" in pp.posterior_predictive
+            assert len(pp.posterior_predictive["sample"]) == len(post["sample"])
+            assert len(pp.posterior_predictive["pred_id"]) == 5
+

 class TestDraw(SeededTest):
     def test_univariate(self):
diff --git a/pymc/tests/test_util.py b/pymc/tests/test_util.py
index 1a0cfe1d17..570c070b78 100644
--- a/pymc/tests/test_util.py
+++ b/pymc/tests/test_util.py
@@ -154,7 +154,7 @@ def fn(a=UNSET):
 def test_dataset_to_point_list():
     ds = xarray.Dataset()
     ds["A"] = xarray.DataArray([[1, 2, 3]] * 2, dims=("chain", "draw"))
-    pl = dataset_to_point_list(ds)
+    pl, _ = dataset_to_point_list(ds, sample_dims=["chain", "draw"])
     assert isinstance(pl, list)
     assert len(pl) == 6
     assert isinstance(pl[0], dict)
@@ -163,4 +163,4 @@ def test_dataset_to_point_list():
     # Check that non-str keys are caught
     ds[3] = xarray.DataArray([1, 2, 3])
     with pytest.raises(ValueError, match="must be str"):
-        dataset_to_point_list(ds)
+        dataset_to_point_list(ds, sample_dims=["chain", "draw"])
diff --git a/pymc/util.py b/pymc/util.py
index 8fc6f3cf7d..bb126fc62d 100644
--- a/pymc/util.py
+++ b/pymc/util.py
@@ -14,9 +14,8 @@

 import functools

-from typing import Dict, Hashable, List, Tuple, Union, cast
+from typing import Any, Dict, List, Tuple, cast

-import arviz
 import cloudpickle
 import numpy as np
 import xarray
@@ -231,38 +230,26 @@ def enhanced(*args, **kwargs):
     return enhanced


-def dataset_to_point_list(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]:
+def dataset_to_point_list(
+    ds: xarray.Dataset, sample_dims: List
+) -> Tuple[List[Dict[str, np.ndarray]], Dict[str, Any]]:
     # All keys of the dataset must be a str
-    for vn in ds.keys():
+    var_names = list(ds.keys())
+    for vn in var_names:
         if not isinstance(vn, str):
             raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.")
-    # make dicts
-    points: List[Dict[Hashable, np.ndarray]] = []
-    da: "xarray.DataArray"
-    for c in ds.chain:
-        for d in ds.draw:
-            points.append({vn: da.sel(chain=c, draw=d).values for vn, da in ds.items()})
+    num_sample_dims = len(sample_dims)
+    stacked_dims = {dim_name: ds[dim_name] for dim_name in sample_dims}
+    ds = ds.transpose(*sample_dims, ...)
+    stacked_dict = {
+        vn: da.values.reshape((-1, *da.shape[num_sample_dims:])) for vn, da in ds.items()
+    }
+    points = [
+        {vn: stacked_dict[vn][i, ...] for vn in var_names}
+        for i in range(np.product([len(coords) for coords in stacked_dims.values()]))
+    ]
     # use the list of points
-    return cast(List[Dict[str, np.ndarray]], points)
-
-
-def chains_and_samples(data: Union[xarray.Dataset, arviz.InferenceData]) -> Tuple[int, int]:
-    """Extract and return number of chains and samples in xarray or arviz traces."""
-    dataset: xarray.Dataset
-    if isinstance(data, xarray.Dataset):
-        dataset = data
-    elif isinstance(data, arviz.InferenceData):
-        dataset = data["posterior"]
-    else:
-        raise ValueError(
-            "Argument must be xarray Dataset or arviz InferenceData. Got %s",
-            data.__class__,
-        )
-
-    coords = dataset.coords
-    nchains = coords["chain"].sizes["chain"]
-    nsamples = coords["draw"].sizes["draw"]
-    return nchains, nsamples
+    return cast(List[Dict[str, np.ndarray]], points), stacked_dims


 def hashable(a=None) -> int:
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 5255ddfc3d..9475a128f4 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -3,7 +3,7 @@

 aeppl==0.0.38
 aesara==2.8.7
-arviz>=0.12.0
+arviz>=0.13.0
 cachetools>=4.2.1
 cloudpickle
 fastprogress>=0.2.0
diff --git a/requirements.txt b/requirements.txt
index ef0af21c8a..c99dea2315 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 aeppl==0.0.38
 aesara==2.8.7
-arviz>=0.12.0
+arviz>=0.13.0
 cachetools>=4.2.1
 cloudpickle
 fastprogress>=0.2.0
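
A minimal usage sketch of the new sample_dims argument, mirroring the added test and docstring example. The toy model, variable names, and data here are hypothetical and not part of the patch; only pm.sample_posterior_predictive, its sample_dims keyword, and the xarray stack call reflect the changes above:

    import numpy as np
    import pymc as pm

    rng = np.random.default_rng(0)

    # Hypothetical toy model, only to illustrate the new keyword.
    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 1.0)
        pm.Normal("y", mu, 1.0, observed=rng.normal(size=100))
        idata = pm.sample()

    # Collapse ("chain", "draw") into a single "sample" dimension and let
    # sample_posterior_predictive loop over it instead of the default dims.
    post = idata.posterior.stack(sample=["chain", "draw"])
    with model:
        pp = pm.sample_posterior_predictive(post, sample_dims=["sample"])
    assert "sample" in pp.posterior_predictive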
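Likewise, a sketch of the changed dataset_to_point_list contract, taken from the updated test_dataset_to_point_list: the helper now requires an explicit sample_dims list and returns a (points, stacked_dims) tuple instead of the bare point list:

    import xarray
    from pymc.util import dataset_to_point_list

    ds = xarray.Dataset()
    ds["A"] = xarray.DataArray([[1, 2, 3]] * 2, dims=("chain", "draw"))

    # One point dict per element of the flattened (chain, draw) grid ...
    points, stacked_dims = dataset_to_point_list(ds, sample_dims=["chain", "draw"])
    assert len(points) == 6  # 2 chains x 3 draws
    # ... plus the stacked-dim coordinates, which the caller uses to restore
    # the sample dimensions when converting results back to a dataset.
    assert list(stacked_dims) == ["chain", "draw"]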