Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support xarray input to sample_posterior_predictive #3846

Merged
merged 4 commits into from
Mar 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- `DEMetropolisZ`, an improved variant of `DEMetropolis`, brings better parallelization and higher efficiency with fewer chains, at the cost of slower initial convergence. This implementation is experimental. See [#3784](https://github.com/pymc-devs/pymc3/pull/3784) for more info.
- Notebooks that give insight into `DEMetropolis`, `DEMetropolisZ` and the `DifferentialEquation` interface are now located in the [Tutorials/Deep Dive](https://docs.pymc.io/nb_tutorials/index.html) section.
- Add `fast_sample_posterior_predictive`, a vectorized alternative to `sample_posterior_predictive`. This alternative is substantially faster for large models.
- `sample_posterior_predictive` now accepts an `xarray.Dataset` as input - e.g. `InferenceData.posterior`. (see [#3846](https://github.com/pymc-devs/pymc3/pull/3846))
- `SamplerReport` (`MultiTrace.report`) now has properties `n_tune`, `n_draws`, `t_sampling` for increased convenience (see [#3827](https://github.com/pymc-devs/pymc3/pull/3827))

### Maintenance
Expand Down
9 changes: 7 additions & 2 deletions pymc3/distributions/posterior_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
import numpy as np
import theano
import theano.tensor as tt
from xarray import Dataset

from ..backends.base import MultiTrace #, TraceLike, TraceDict
from .distribution import _DrawValuesContext, _DrawValuesContextBlocker, is_fast_drawable, _compile_theano_function, vectorized_ppc
from ..model import Model, get_named_nodes_and_relations, ObservedRV, MultiObservedRV, modelcontext
from ..exceptions import IncorrectArgumentsError
from ..vartypes import theano_constant
from ..util import dataset_to_point_dict
# Failing tests:
# test_mixture_random_shape::test_mixture_random_shape
#
Expand Down Expand Up @@ -119,7 +121,7 @@ def __getitem__(self, item):



def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.ndarray]]],
def fast_sample_posterior_predictive(trace: Union[MultiTrace, Dataset, List[Dict[str, np.ndarray]]],
samples: Optional[int]=None,
model: Optional[Model]=None,
var_names: Optional[List[str]]=None,
Expand All @@ -135,7 +137,7 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.

Parameters
----------
trace : MultiTrace or List of points
trace : MultiTrace, xarray.Dataset, or List of points (dictionary)
Trace generated from MCMC sampling.
samples : int, optional
Number of posterior predictive samples to generate. Defaults to one posterior predictive
Expand Down Expand Up @@ -168,6 +170,9 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.
### greater than the number of samples in the trace parameter, we sample repeatedly. This
### makes the shape issues just a little easier to deal with.

if isinstance(trace, Dataset):
trace = dataset_to_point_dict(trace)

model = modelcontext(model)
assert model is not None
with model:
Expand Down
11 changes: 8 additions & 3 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import numpy as np
import theano.gradient as tg
from theano.tensor import Tensor
import xarray

from .backends.base import BaseTrace, MultiTrace
from .backends.ndarray import NDArray
Expand All @@ -53,6 +54,7 @@
get_untransformed_name,
is_transformed_name,
get_default_varnames,
dataset_to_point_dict,
)
from .vartypes import discrete_types
from .exceptions import IncorrectArgumentsError
Expand Down Expand Up @@ -1520,9 +1522,9 @@ def sample_posterior_predictive(

Parameters
----------
trace: backend, list, or MultiTrace
Trace generated from MCMC sampling. Or a list containing dicts from
find_MAP() or points
trace: backend, list, xarray.Dataset, or MultiTrace
Trace generated from MCMC sampling, or a list of dicts (e.g. points or from find_MAP()),
or xarray.Dataset (e.g. InferenceData.posterior or InferenceData.prior)
samples: int
Number of posterior predictive samples to generate. Defaults to one posterior predictive
sample per posterior sample, that is, the number of draws times the number of chains. It
Expand Down Expand Up @@ -1556,6 +1558,9 @@ def sample_posterior_predictive(
Dictionary with the variable names as keys, and values numpy arrays containing
posterior predictive samples.
"""
if isinstance(trace, xarray.Dataset):
trace = dataset_to_point_dict(trace)

len_trace = len(trace)
try:
nchain = trace.nchains
Expand Down
30 changes: 30 additions & 0 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import mock

import numpy.testing as npt
import arviz as az
import pymc3 as pm
import theano.tensor as tt
from theano import shared
Expand Down Expand Up @@ -880,3 +881,32 @@ def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture):
var_names=['d']
)

def test_sample_from_xarray_prior(self, point_list_arg_bug_fixture):
    """Posterior predictive sampling accepts the ``prior`` group of an InferenceData."""
    pmodel, trace = point_list_arg_bug_fixture

    with pmodel:
        prior = pm.sample_prior_predictive(samples=20)
    # Converting to InferenceData turns the prior samples into an xarray.Dataset.
    idat = az.from_pymc3(trace, prior=prior)

    with pmodel:
        pp = pm.sample_posterior_predictive(
            idat.prior,
            var_names=['d'],
        )
    # Beyond "did not raise": the requested variable must be in the result.
    assert 'd' in pp

def test_sample_from_xarray_posterior(self, point_list_arg_bug_fixture):
    """Posterior predictive sampling accepts the ``posterior`` group of an InferenceData."""
    pmodel, trace = point_list_arg_bug_fixture
    # Converting to InferenceData turns the trace into an xarray.Dataset.
    idat = az.from_pymc3(trace)

    with pmodel:
        pp = pm.sample_posterior_predictive(
            idat.posterior,
            var_names=['d'],
        )
    # Beyond "did not raise": the requested variable must be in the result.
    assert 'd' in pp

def test_sample_from_xarray_posterior_fast(self, point_list_arg_bug_fixture):
    """The vectorized sampler accepts the ``posterior`` group of an InferenceData."""
    pmodel, trace = point_list_arg_bug_fixture
    # Converting to InferenceData turns the trace into an xarray.Dataset.
    idat = az.from_pymc3(trace)

    with pmodel:
        pp = pm.fast_sample_posterior_predictive(
            idat.posterior,
            var_names=['d'],
        )
    # Beyond "did not raise": the requested variable must be in the result.
    assert 'd' in pp
24 changes: 23 additions & 1 deletion pymc3/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@

import re
import functools
from numpy import asscalar
from typing import List, Dict

import xarray
from numpy import asscalar, ndarray


LATEX_ESCAPE_RE = re.compile(r'(%|_|\$|#|&)', re.MULTILINE)

Expand Down Expand Up @@ -179,3 +183,21 @@ def enhanced(*args, **kwargs):
newwrapper = functools.partial(wrapper, *args, **kwargs)
return newwrapper
return enhanced

def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, ndarray]]:
    """Flatten an ``xarray.Dataset`` of samples into a list of point dicts.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with ``chain`` and ``draw`` dimensions (e.g.
        ``InferenceData.posterior``). Each variable is indexed as
        ``values[chain, draw]``, so ``chain`` and ``draw`` are assumed to be
        the two leading dimensions of every variable.

    Returns
    -------
    list of dict
        One ``{variable name: value}`` dict per (chain, draw) pair, ordered
        chain by chain with draws in their stored order.
    """
    # Extract the raw numpy arrays once instead of on every (chain, draw) pair.
    samples = {
        vn: ds[vn].values
        for vn in ds.keys()
    }
    # Index by position, not by coordinate label: the raw arrays are
    # positional, and the labels need not be 0..n-1 (e.g. after thinning the
    # draws or selecting a subset of chains).
    n_chains = ds.sizes["chain"]
    n_draws = ds.sizes["draw"]
    return [
        {vn: values[c, d] for vn, values in samples.items()}
        for c in range(n_chains)
        for d in range(n_draws)
    ]