Skip to content

Commit

Permalink
Fix posterior pred. sampling keep_size w/ arviz input.
Browse files Browse the repository at this point in the history
Previously, the posterior predictive sampling functions did not properly
handle the `keep_size` keyword argument when given an xarray Dataset
as the `trace` argument.

Also extended these functions to accept an InferenceData object as input.
  • Loading branch information
rpgoldman committed Jul 10, 2020
1 parent 90f48ed commit de2d614
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 50 deletions.
49 changes: 34 additions & 15 deletions pymc3/distributions/posterior_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import logging
from collections import UserDict
from contextlib import AbstractContextManager

if TYPE_CHECKING:
import contextvars # noqa: F401
from typing import Set
from typing_extensions import Protocol
from typing_extensions import Protocol, Literal

import numpy as np
import theano
Expand All @@ -17,9 +18,13 @@
from ..backends.base import MultiTrace #, TraceLike, TraceDict
from .distribution import _DrawValuesContext, _DrawValuesContextBlocker, is_fast_drawable, _compile_theano_function, vectorized_ppc
from ..model import Model, get_named_nodes_and_relations, ObservedRV, MultiObservedRV, modelcontext
from arviz import InferenceData

from ..backends.base import MultiTrace # , TraceLike, TraceDict
from ..exceptions import IncorrectArgumentsError
from ..vartypes import theano_constant
from ..util import dataset_to_point_dict
from ..util import dataset_to_point_dict, chains_and_samples

# Failing tests:
# test_mixture_random_shape::test_mixture_random_shape
#
Expand Down Expand Up @@ -121,12 +126,14 @@ def __getitem__(self, item):



def fast_sample_posterior_predictive(trace: Union[MultiTrace, Dataset, List[Dict[str, np.ndarray]]],
samples: Optional[int]=None,
model: Optional[Model]=None,
var_names: Optional[List[str]]=None,
keep_size: bool=False,
random_seed=None) -> Dict[str, np.ndarray]:
def fast_sample_posterior_predictive(
trace: Union[MultiTrace, Dataset, InferenceData, List[Dict[str, np.ndarray]]],
samples: Optional[int] = None,
model: Optional[Model] = None,
var_names: Optional[List[str]] = None,
keep_size: bool = False,
random_seed=None,
) -> Dict[str, np.ndarray]:
"""Generate posterior predictive samples from a model given a trace.
This is a vectorized alternative to the standard ``sample_posterior_predictive`` function.
Expand All @@ -137,7 +144,7 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, Dataset, List[Dict
Parameters
----------
trace: MultiTrace, xarray.Dataset, or List of points (dictionary)
trace: MultiTrace, xarray.Dataset, InferenceData, or List of points (dictionary)
Trace generated from MCMC sampling.
samples: int, optional
Number of posterior predictive samples to generate. Defaults to one posterior predictive
Expand Down Expand Up @@ -170,21 +177,33 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, Dataset, List[Dict
### greater than the number of samples in the trace parameter, we sample repeatedly. This
### makes the shape issues just a little easier to deal with.

if isinstance(trace, Dataset):
if isinstance(trace, InferenceData):
nchains, ndraws = chains_and_samples(trace)
trace = dataset_to_point_dict(trace.posterior)
elif isinstance(trace, Dataset):
nchains, ndraws = chains_and_samples(trace)
trace = dataset_to_point_dict(trace)
elif isinstance(trace, MultiTrace):
nchains = trace.nchains
ndraws = len(trace)
else:
if keep_size:
# arguably this should be just a warning.
raise IncorrectArgumentsError(
"For keep_size, cannot identify chains and length from %s.", trace
)

model = modelcontext(model)
assert model is not None
with model:

if keep_size and samples is not None:
raise IncorrectArgumentsError("Should not specify both keep_size and samples arguments")
if keep_size and not isinstance(trace, MultiTrace):
# arguably this should be just a warning.
raise IncorrectArgumentsError("keep_size argument only applies when sampling from MultiTrace.")
raise IncorrectArgumentsError(
"Should not specify both keep_size and samples arguments"
)

if isinstance(trace, list) and all((isinstance(x, dict) for x in trace)):
_trace = _TraceDict(point_list=trace)
_trace = _TraceDict(point_list=trace)
elif isinstance(trace, MultiTrace):
_trace = _TraceDict(multi_trace=trace)
else:
Expand Down
79 changes: 56 additions & 23 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import warnings

import arviz
from arviz import InferenceData
import numpy as np
import theano.gradient as tg
from theano.tensor import Tensor
Expand Down Expand Up @@ -57,6 +58,7 @@
is_transformed_name,
get_default_varnames,
dataset_to_point_dict,
chains_and_samples,
)
from .vartypes import discrete_types
from .exceptions import IncorrectArgumentsError
Expand Down Expand Up @@ -91,6 +93,8 @@
)

ArrayLike = Union[np.ndarray, List[float]]
PointType = Dict[str, np.ndarray]
PointList = List[PointType]

_log = logging.getLogger("pymc3")

Expand Down Expand Up @@ -248,10 +252,10 @@ def sample(
callback=None,
*,
return_inferencedata=None,
idata_kwargs: dict=None,
idata_kwargs: dict = None,
mp_ctx=None,
pickle_backend: str = 'pickle',
**kwargs
pickle_backend: str = "pickle",
**kwargs,
):
"""Draw samples from the posterior using the given step methods.
Expand Down Expand Up @@ -1584,7 +1588,7 @@ def sample_posterior_predictive(
Parameters
----------
trace : backend, list, xarray.Dataset, or MultiTrace
trace : backend, list, xarray.Dataset, arviz.InferenceData, or MultiTrace
Trace generated from MCMC sampling, or a list of dicts (eg. points or from find_MAP()),
or xarray.Dataset (eg. InferenceData.posterior or InferenceData.prior)
samples : int
Expand All @@ -1598,8 +1602,7 @@ def sample_posterior_predictive(
Variables for which to compute the posterior predictive samples.
Deprecated: please use ``var_names`` instead.
var_names : Iterable[str]
Alternative way to specify vars to sample, to make this function orthogonal with
others.
Names of variables for which to compute the posterior predictive samples.
size : int
The number of random draws from the distribution specified by the parameters in each
sample of the trace. Not recommended unless more than ndraws times nchains posterior
Expand All @@ -1620,29 +1623,48 @@ def sample_posterior_predictive(
Dictionary with the variable names as keys, and values numpy arrays containing
posterior predictive samples.
"""
if isinstance(trace, xarray.Dataset):
trace = dataset_to_point_dict(trace)

len_trace = len(trace)
try:
nchain = trace.nchains
except AttributeError:
nchain = 1
_trace: Union[MultiTrace, PointList]
if isinstance(trace, InferenceData):
_trace = dataset_to_point_dict(trace.posterior)
elif isinstance(trace, xarray.Dataset):
_trace = dataset_to_point_dict(trace)
else:
_trace = trace

nchain: int
len_trace: int
if isinstance(trace, (InferenceData, xarray.Dataset)):
nchain, len_trace = chains_and_samples(trace)
else:
len_trace = len(_trace)
try:
nchain = _trace.nchains
except AttributeError:
nchain = 1

if keep_size and samples is not None:
raise IncorrectArgumentsError("Should not specify both keep_size and samples arguments")
raise IncorrectArgumentsError(
"Should not specify both keep_size and samples arguments"
)
if keep_size and size is not None:
raise IncorrectArgumentsError("Should not specify both keep_size and size arguments")
raise IncorrectArgumentsError(
"Should not specify both keep_size and size arguments"
)

if samples is None:
if isinstance(trace, MultiTrace):
samples = sum(len(v) for v in trace._straces.values())
elif isinstance(trace, list) and all((isinstance(x, dict) for x in trace)):
if isinstance(_trace, MultiTrace):
samples = sum(len(v) for v in _trace._straces.values())
elif isinstance(_trace, list) and all((isinstance(x, dict) for x in _trace)):
# this is a list of points
samples = len(trace)
samples = len(_trace)
else:
raise ValueError("Do not know how to compute number of samples for trace argument of type %s"%type(trace))
raise ValueError(
"Do not know how to compute number of samples for trace argument of type %s"
% type(_trace)
)

assert samples is not None
if samples < len_trace * nchain:
warnings.warn(
"samples parameter is smaller than nchains times ndraws, some draws "
Expand Down Expand Up @@ -1675,10 +1697,21 @@ def sample_posterior_predictive(
try:
for idx in indices:
if nchain > 1:
chain_idx, point_idx = np.divmod(idx, len_trace)
param = trace._straces[chain_idx % nchain].point(point_idx)
# the trace object will either be a MultiTrace (and have _straces)...
if hasattr(_trace, "_straces"):
chain_idx, point_idx = np.divmod(idx, len_trace)
param = (
cast(MultiTrace, _trace)
._straces[chain_idx % nchain]
.point(point_idx)
)
# ... or a PointList
else:
param = cast(PointList, _trace)[idx % len_trace]
# there's only a single chain, but the index might hit it multiple times if
# the number of indices is greater than the length of the trace.
else:
param = trace[idx % len_trace]
param = _trace[idx % len_trace]

values = draw_values(vars, point=param, size=size)
for k, v in zip(vars, values):
Expand Down
24 changes: 24 additions & 0 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,13 @@ def test_normal_scalar(self):
ppc = pm.fast_sample_posterior_predictive(trace, keep_size=True)
assert ppc["a"].shape == (nchains, ndraws)

# test keep_size parameter and idata input
idata = az.from_pymc3(trace)
ppc = pm.sample_posterior_predictive(idata, keep_size=True)
assert ppc["a"].shape == (nchains, ndraws)
ppc = pm.fast_sample_posterior_predictive(trace, keep_size=True)
assert ppc["a"].shape == (nchains, ndraws)

# test default case
ppc = pm.sample_posterior_predictive(trace, var_names=["a"])
assert "a" in ppc
Expand Down Expand Up @@ -428,6 +435,15 @@ def test_normal_vector(self, caplog):
assert "a" in ppc
assert ppc["a"].shape == (12, 2)

# test keep_size parameter with inference data as input...
idata = az.from_pymc3(trace)
ppc = pm.sample_posterior_predictive(idata, keep_size=True)
assert ppc["a"].shape == (trace.nchains, len(trace), 2)
with pytest.warns(UserWarning):
ppc = pm.sample_posterior_predictive(trace, samples=12, var_names=["a"])
assert "a" in ppc
assert ppc["a"].shape == (12, 2)

# test keep_size parameter
ppc = pm.fast_sample_posterior_predictive(trace, keep_size=True)
assert ppc["a"].shape == (trace.nchains, len(trace), 2)
Expand All @@ -436,6 +452,14 @@ def test_normal_vector(self, caplog):
assert "a" in ppc
assert ppc["a"].shape == (12, 2)

# test keep_size parameter with inference data as input
ppc = pm.fast_sample_posterior_predictive(idata, keep_size=True)
assert ppc["a"].shape == (trace.nchains, len(trace), 2)
with pytest.warns(UserWarning):
ppc = pm.fast_sample_posterior_predictive(trace, samples=12, var_names=["a"])
assert "a" in ppc
assert ppc["a"].shape == (12, 2)


# size unsupported by fast_ version argument. [2019/08/19:rpg]
ppc = pm.sample_posterior_predictive(trace, samples=10, var_names=["a"], size=4)
Expand Down
45 changes: 33 additions & 12 deletions pymc3/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@

import re
import functools
from typing import List, Dict
from typing import List, Dict, Tuple, Union

import xarray
import arviz
from numpy import asscalar, ndarray


Expand Down Expand Up @@ -182,22 +183,42 @@ def enhanced(*args, **kwargs):
else:
newwrapper = functools.partial(wrapper, *args, **kwargs)
return newwrapper

return enhanced


# FIXME: this function is poorly named, because it returns a LIST of
# points, not a dictionary of points.
def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, ndarray]]:
# grab posterior samples for each variable
_samples = {
vn : ds[vn].values
for vn in ds.keys()
}
_samples: Dict[str, ndarray] = {vn: ds[vn].values for vn in ds.keys()}
# make dicts
points = []
points: List[Dict[str, ndarray]] = []
vn: str
s: ndarray
for c in ds.chain:
for d in ds.draw:
points.append({
vn : s[c, d]
for vn, s in _samples.items()
})
points.append({vn: s[c, d] for vn, s in _samples.items()})
# use the list of points
ds = points
return ds
return points


def chains_and_samples(
    data: Union[xarray.Dataset, arviz.InferenceData]
) -> Tuple[int, int]:
    """Extract and return number of chains and samples in xarray or arviz traces.

    Parameters
    ----------
    data : xarray.Dataset or arviz.InferenceData
        The trace to inspect. For an ``InferenceData`` object, its
        ``posterior`` group is the dataset that gets inspected.

    Returns
    -------
    Tuple[int, int]
        ``(nchains, nsamples)``: the sizes of the ``chain`` and ``draw``
        coordinates of the dataset, in that order.

    Raises
    ------
    ValueError
        If ``data`` is neither an ``xarray.Dataset`` nor an
        ``arviz.InferenceData``.
    """
    dataset: xarray.Dataset
    if isinstance(data, xarray.Dataset):
        dataset = data
    elif isinstance(data, arviz.InferenceData):
        dataset = data.posterior
    else:
        # Interpolate the offending type into the message here: exceptions do
        # not apply %-formatting to extra positional arguments, so passing the
        # class as a second argument would leave a literal "%s" in the text.
        raise ValueError(
            "Argument must be xarray Dataset or arviz InferenceData. Got %s"
            % (data.__class__,)
        )

    coords = dataset.coords
    nchains = coords["chain"].sizes["chain"]
    nsamples = coords["draw"].sizes["draw"]
    return nchains, nsamples

0 comments on commit de2d614

Please sign in to comment.