diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 56622eb9039..802f40f816b 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -8,6 +8,7 @@ - `DEMetropolisZ`, an improved variant of `DEMetropolis` brings better parallelization and higher efficiency with fewer chains with a slower initial convergence. This implementation is experimental. See [#3784](https://github.com/pymc-devs/pymc3/pull/3784) for more info. - Notebooks that give insight into `DEMetropolis`, `DEMetropolisZ` and the `DifferentialEquation` interface are now located in the [Tutorials/Deep Dive](https://docs.pymc.io/nb_tutorials/index.html) section. - Add `fast_sample_posterior_predictive`, a vectorized alternative to `sample_posterior_predictive`. This alternative is substantially faster for large models. +- `sample_posterior_predictive` can now accept an `xarray.Dataset` - e.g. from `InferenceData.posterior` (see [#3846](https://github.com/pymc-devs/pymc3/pull/3846)) - `SamplerReport` (`MultiTrace.report`) now has properties `n_tune`, `n_draws`, `t_sampling` for increased convenience (see [#3827](https://github.com/pymc-devs/pymc3/pull/3827)) ### Maintenance diff --git a/pymc3/distributions/posterior_predictive.py b/pymc3/distributions/posterior_predictive.py index c012329d3e3..46eff017b6d 100644 --- a/pymc3/distributions/posterior_predictive.py +++ b/pymc3/distributions/posterior_predictive.py @@ -12,12 +12,14 @@ import numpy as np import theano import theano.tensor as tt +from xarray import Dataset from ..backends.base import MultiTrace #, TraceLike, TraceDict from .distribution import _DrawValuesContext, _DrawValuesContextBlocker, is_fast_drawable, _compile_theano_function, vectorized_ppc from ..model import Model, get_named_nodes_and_relations, ObservedRV, MultiObservedRV, modelcontext from ..exceptions import IncorrectArgumentsError from ..vartypes import theano_constant from ..util import dataset_to_point_dict # Failing tests: # test_mixture_random_shape::test_mixture_random_shape # @@ 
-119,7 +121,7 @@ def __getitem__(self, item): -def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.ndarray]]], +def fast_sample_posterior_predictive(trace: Union[MultiTrace, Dataset, List[Dict[str, np.ndarray]]], samples: Optional[int]=None, model: Optional[Model]=None, var_names: Optional[List[str]]=None, @@ -135,7 +137,7 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np. Parameters ---------- - trace : MultiTrace or List of points + trace : MultiTrace, xarray.Dataset, or List of points (dictionary) Trace generated from MCMC sampling. samples : int, optional Number of posterior predictive samples to generate. Defaults to one posterior predictive @@ -168,6 +170,9 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np. ### greater than the number of samples in the trace parameter, we sample repeatedly. This ### makes the shape issues just a little easier to deal with. + if isinstance(trace, Dataset): + trace = dataset_to_point_dict(trace) + model = modelcontext(model) assert model is not None with model: diff --git a/pymc3/sampling.py b/pymc3/sampling.py index b0f920c672c..ca1773fb27e 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -30,6 +30,7 @@ import numpy as np import theano.gradient as tg from theano.tensor import Tensor +import xarray from .backends.base import BaseTrace, MultiTrace from .backends.ndarray import NDArray @@ -53,6 +54,7 @@ get_untransformed_name, is_transformed_name, get_default_varnames, + dataset_to_point_dict, ) from .vartypes import discrete_types from .exceptions import IncorrectArgumentsError @@ -1520,9 +1522,9 @@ def sample_posterior_predictive( Parameters ---------- - trace: backend, list, or MultiTrace - Trace generated from MCMC sampling. Or a list containing dicts from - find_MAP() or points + trace: backend, list, xarray.Dataset, or MultiTrace + Trace generated from MCMC sampling, or a list of dicts (eg. 
points or from find_MAP()), + or xarray.Dataset (eg. InferenceData.posterior or InferenceData.prior) samples: int Number of posterior predictive samples to generate. Defaults to one posterior predictive sample per posterior sample, that is, the number of draws times the number of chains. It @@ -1556,6 +1558,9 @@ def sample_posterior_predictive( Dictionary with the variable names as keys, and values numpy arrays containing posterior predictive samples. """ + if isinstance(trace, xarray.Dataset): + trace = dataset_to_point_dict(trace) + len_trace = len(trace) try: nchain = trace.nchains diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 8d1bf00bd44..6c0eece0379 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -22,6 +22,7 @@ import mock import numpy.testing as npt +import arviz as az import pymc3 as pm import theano.tensor as tt from theano import shared @@ -880,3 +881,32 @@ def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture): var_names=['d'] ) + def test_sample_from_xarray_prior(self, point_list_arg_bug_fixture): + pmodel, trace = point_list_arg_bug_fixture + + with pmodel: + prior = pm.sample_prior_predictive(samples=20) + idat = az.from_pymc3(trace, prior=prior) + with pmodel: + pp = pm.sample_posterior_predictive( + idat.prior, + var_names=['d'] + ) + + def test_sample_from_xarray_posterior(self, point_list_arg_bug_fixture): + pmodel, trace = point_list_arg_bug_fixture + idat = az.from_pymc3(trace) + with pmodel: + pp = pm.sample_posterior_predictive( + idat.posterior, + var_names=['d'] + ) + + def test_sample_from_xarray_posterior_fast(self, point_list_arg_bug_fixture): + pmodel, trace = point_list_arg_bug_fixture + idat = az.from_pymc3(trace) + with pmodel: + pp = pm.fast_sample_posterior_predictive( + idat.posterior, + var_names=['d'] + ) diff --git a/pymc3/util.py b/pymc3/util.py index a9bd3c5af6e..18a78aed97d 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -14,7 +14,11 @@ import re 
import functools -from numpy import asscalar +from typing import List, Dict + +import xarray +from numpy import asscalar, ndarray + LATEX_ESCAPE_RE = re.compile(r'(%|_|\$|#|&)', re.MULTILINE) @@ -179,3 +183,21 @@ def enhanced(*args, **kwargs): newwrapper = functools.partial(wrapper, *args, **kwargs) return newwrapper return enhanced + +def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, ndarray]]: + # grab posterior samples for each variable + _samples = { + vn : ds[vn].values + for vn in ds.keys() + } + # make dicts + points = [] + for c in ds.chain: + for d in ds.draw: + points.append({ + vn : s[c, d] + for vn, s in _samples.items() + }) + # use the list of points + ds = points + return ds