From de2d6149453251b9b2496627c8fadffe59a83873 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Fri, 10 Jul 2020 11:00:36 -0500 Subject: [PATCH] Fix posterior pred. sampling keep_size w/ arviz input. Previously posterior predictive sampling functions did not properly handle the `keep_size` keyword argument when getting an xarray Dataset as parameter. Also extended these functions to accept InferenceData object as input. --- pymc3/distributions/posterior_predictive.py | 49 +++++++++---- pymc3/sampling.py | 79 +++++++++++++++------ pymc3/tests/test_sampling.py | 24 +++++++ pymc3/util.py | 45 ++++++++---- 4 files changed, 147 insertions(+), 50 deletions(-) diff --git a/pymc3/distributions/posterior_predictive.py b/pymc3/distributions/posterior_predictive.py index 69da02978ee..31a3dea74c0 100644 --- a/pymc3/distributions/posterior_predictive.py +++ b/pymc3/distributions/posterior_predictive.py @@ -4,10 +4,11 @@ import logging from collections import UserDict from contextlib import AbstractContextManager + if TYPE_CHECKING: import contextvars # noqa: F401 from typing import Set -from typing_extensions import Protocol +from typing_extensions import Protocol, Literal import numpy as np import theano @@ -17,9 +18,13 @@ from ..backends.base import MultiTrace #, TraceLike, TraceDict from .distribution import _DrawValuesContext, _DrawValuesContextBlocker, is_fast_drawable, _compile_theano_function, vectorized_ppc from ..model import Model, get_named_nodes_and_relations, ObservedRV, MultiObservedRV, modelcontext +from arviz import InferenceData + +from ..backends.base import MultiTrace # , TraceLike, TraceDict from ..exceptions import IncorrectArgumentsError from ..vartypes import theano_constant -from ..util import dataset_to_point_dict +from ..util import dataset_to_point_dict, chains_and_samples + # Failing tests: # test_mixture_random_shape::test_mixture_random_shape # @@ -121,12 +126,14 @@ def __getitem__(self, item): -def fast_sample_posterior_predictive(trace: 
Union[MultiTrace, Dataset, List[Dict[str, np.ndarray]]], - samples: Optional[int]=None, - model: Optional[Model]=None, - var_names: Optional[List[str]]=None, - keep_size: bool=False, - random_seed=None) -> Dict[str, np.ndarray]: +def fast_sample_posterior_predictive( + trace: Union[MultiTrace, Dataset, InferenceData, List[Dict[str, np.ndarray]]], + samples: Optional[int] = None, + model: Optional[Model] = None, + var_names: Optional[List[str]] = None, + keep_size: bool = False, + random_seed=None, +) -> Dict[str, np.ndarray]: """Generate posterior predictive samples from a model given a trace. This is a vectorized alternative to the standard ``sample_posterior_predictive`` function. @@ -137,7 +144,7 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, Dataset, List[Dict Parameters ---------- - trace: MultiTrace, xarray.Dataset, or List of points (dictionary) + trace: MultiTrace, xarray.Dataset, InferenceData, or List of points (dictionary) Trace generated from MCMC sampling. samples: int, optional Number of posterior predictive samples to generate. Defaults to one posterior predictive @@ -170,21 +177,33 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, Dataset, List[Dict ### greater than the number of samples in the trace parameter, we sample repeatedly. This ### makes the shape issues just a little easier to deal with. - if isinstance(trace, Dataset): + if isinstance(trace, InferenceData): + nchains, ndraws = chains_and_samples(trace) + trace = dataset_to_point_dict(trace.posterior) + elif isinstance(trace, Dataset): + nchains, ndraws = chains_and_samples(trace) trace = dataset_to_point_dict(trace) + elif isinstance(trace, MultiTrace): + nchains = trace.nchains + ndraws = len(trace) + else: + if keep_size: + # arguably this should be just a warning. 
+ raise IncorrectArgumentsError( + "For keep_size, cannot identify chains and length from %s.", trace + ) model = modelcontext(model) assert model is not None with model: if keep_size and samples is not None: - raise IncorrectArgumentsError("Should not specify both keep_size and samples arguments") - if keep_size and not isinstance(trace, MultiTrace): - # arguably this should be just a warning. - raise IncorrectArgumentsError("keep_size argument only applies when sampling from MultiTrace.") + raise IncorrectArgumentsError( + "Should not specify both keep_size and samples arguments" + ) if isinstance(trace, list) and all((isinstance(x, dict) for x in trace)): - _trace = _TraceDict(point_list=trace) + _trace = _TraceDict(point_list=trace) elif isinstance(trace, MultiTrace): _trace = _TraceDict(multi_trace=trace) else: diff --git a/pymc3/sampling.py b/pymc3/sampling.py index f550218302a..cfe22f244e6 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -29,6 +29,7 @@ import warnings import arviz +from arviz import InferenceData import numpy as np import theano.gradient as tg from theano.tensor import Tensor @@ -57,6 +58,7 @@ is_transformed_name, get_default_varnames, dataset_to_point_dict, + chains_and_samples, ) from .vartypes import discrete_types from .exceptions import IncorrectArgumentsError @@ -91,6 +93,8 @@ ) ArrayLike = Union[np.ndarray, List[float]] +PointType = Dict[str, np.ndarray] +PointList = List[PointType] _log = logging.getLogger("pymc3") @@ -248,10 +252,10 @@ def sample( callback=None, *, return_inferencedata=None, - idata_kwargs: dict=None, + idata_kwargs: dict = None, mp_ctx=None, - pickle_backend: str = 'pickle', - **kwargs + pickle_backend: str = "pickle", + **kwargs, ): """Draw samples from the posterior using the given step methods. 
@@ -1584,7 +1588,7 @@ def sample_posterior_predictive( Parameters ---------- - trace : backend, list, xarray.Dataset, or MultiTrace + trace : backend, list, xarray.Dataset, arviz.InferenceData, or MultiTrace Trace generated from MCMC sampling, or a list of dicts (eg. points or from find_MAP()), or xarray.Dataset (eg. InferenceData.posterior or InferenceData.prior) samples : int @@ -1598,8 +1602,7 @@ def sample_posterior_predictive( Variables for which to compute the posterior predictive samples. Deprecated: please use ``var_names`` instead. var_names : Iterable[str] - Alternative way to specify vars to sample, to make this function orthogonal with - others. + Names of variables for which to compute the posterior predictive samples. size : int The number of random draws from the distribution specified by the parameters in each sample of the trace. Not recommended unless more than ndraws times nchains posterior @@ -1620,29 +1623,48 @@ def sample_posterior_predictive( Dictionary with the variable names as keys, and values numpy arrays containing posterior predictive samples. 
""" - if isinstance(trace, xarray.Dataset): - trace = dataset_to_point_dict(trace) - len_trace = len(trace) - try: - nchain = trace.nchains - except AttributeError: - nchain = 1 + _trace: Union[MultiTrace, PointList] + if isinstance(trace, InferenceData): + _trace = dataset_to_point_dict(trace.posterior) + elif isinstance(trace, xarray.Dataset): + _trace = dataset_to_point_dict(trace) + else: + _trace = trace + + nchain: int + len_trace: int + if isinstance(trace, (InferenceData, xarray.Dataset)): + nchain, len_trace = chains_and_samples(trace) + else: + len_trace = len(_trace) + try: + nchain = _trace.nchains + except AttributeError: + nchain = 1 if keep_size and samples is not None: - raise IncorrectArgumentsError("Should not specify both keep_size and samples arguments") + raise IncorrectArgumentsError( + "Should not specify both keep_size and samples arguments" + ) if keep_size and size is not None: - raise IncorrectArgumentsError("Should not specify both keep_size and size arguments") + raise IncorrectArgumentsError( + "Should not specify both keep_size and size arguments" + ) if samples is None: - if isinstance(trace, MultiTrace): - samples = sum(len(v) for v in trace._straces.values()) - elif isinstance(trace, list) and all((isinstance(x, dict) for x in trace)): + if isinstance(_trace, MultiTrace): + samples = sum(len(v) for v in _trace._straces.values()) + elif isinstance(_trace, list) and all((isinstance(x, dict) for x in _trace)): # this is a list of points - samples = len(trace) + samples = len(_trace) else: - raise ValueError("Do not know how to compute number of samples for trace argument of type %s"%type(trace)) + raise ValueError( + "Do not know how to compute number of samples for trace argument of type %s" + % type(_trace) + ) + assert samples is not None if samples < len_trace * nchain: warnings.warn( "samples parameter is smaller than nchains times ndraws, some draws " @@ -1675,10 +1697,21 @@ def sample_posterior_predictive( try: for idx in 
indices: if nchain > 1: - chain_idx, point_idx = np.divmod(idx, len_trace) - param = trace._straces[chain_idx % nchain].point(point_idx) + # the trace object will either be a MultiTrace (and have _straces)... + if hasattr(_trace, "_straces"): + chain_idx, point_idx = np.divmod(idx, len_trace) + param = ( + cast(MultiTrace, _trace) + ._straces[chain_idx % nchain] + .point(point_idx) + ) + # ... or a PointList + else: + param = cast(PointList, _trace)[idx % len_trace] + # there's only a single chain, but the index might hit it multiple times if + # the number of indices is greater than the length of the trace. else: - param = trace[idx % len_trace] + param = _trace[idx % len_trace] values = draw_values(vars, point=param, size=size) for k, v in zip(vars, values): diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index e80599f912d..806ff2bd017 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -381,6 +381,13 @@ def test_normal_scalar(self): ppc = pm.fast_sample_posterior_predictive(trace, keep_size=True) assert ppc["a"].shape == (nchains, ndraws) + # test keep_size parameter and idata input + idata = az.from_pymc3(trace) + ppc = pm.sample_posterior_predictive(idata, keep_size=True) + assert ppc["a"].shape == (nchains, ndraws) + ppc = pm.fast_sample_posterior_predictive(trace, keep_size=True) + assert ppc["a"].shape == (nchains, ndraws) + # test default case ppc = pm.sample_posterior_predictive(trace, var_names=["a"]) assert "a" in ppc @@ -428,6 +435,15 @@ def test_normal_vector(self, caplog): assert "a" in ppc assert ppc["a"].shape == (12, 2) + # test keep_size parameter with inference data as input... 
+ idata = az.from_pymc3(trace) + ppc = pm.sample_posterior_predictive(idata, keep_size=True) + assert ppc["a"].shape == (trace.nchains, len(trace), 2) + with pytest.warns(UserWarning): + ppc = pm.sample_posterior_predictive(trace, samples=12, var_names=["a"]) + assert "a" in ppc + assert ppc["a"].shape == (12, 2) + # test keep_size parameter ppc = pm.fast_sample_posterior_predictive(trace, keep_size=True) assert ppc["a"].shape == (trace.nchains, len(trace), 2) @@ -436,6 +452,14 @@ def test_normal_vector(self, caplog): assert "a" in ppc assert ppc["a"].shape == (12, 2) + # test keep_size parameter with inference data as input + ppc = pm.fast_sample_posterior_predictive(idata, keep_size=True) + assert ppc["a"].shape == (trace.nchains, len(trace), 2) + with pytest.warns(UserWarning): + ppc = pm.fast_sample_posterior_predictive(trace, samples=12, var_names=["a"]) + assert "a" in ppc + assert ppc["a"].shape == (12, 2) + # size unsupported by fast_ version argument. [2019/08/19:rpg] ppc = pm.sample_posterior_predictive(trace, samples=10, var_names=["a"], size=4) diff --git a/pymc3/util.py b/pymc3/util.py index 8babb18f333..a5d5d51de8a 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -14,9 +14,10 @@ import re import functools -from typing import List, Dict +from typing import List, Dict, Tuple, Union import xarray +import arviz from numpy import asscalar, ndarray @@ -182,22 +183,42 @@ def enhanced(*args, **kwargs): else: newwrapper = functools.partial(wrapper, *args, **kwargs) return newwrapper + return enhanced + +# FIXME: this function is poorly named, because it returns a LIST of +# points, not a dictionary of points. 
def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, ndarray]]:
     # grab posterior samples for each variable
-    _samples = {
-        vn : ds[vn].values
-        for vn in ds.keys()
-    }
+    _samples: Dict[str, ndarray] = {vn: ds[vn].values for vn in ds.keys()}
     # make dicts
-    points = []
+    points: List[Dict[str, ndarray]] = []
+    vn: str
+    s: ndarray
     for c in ds.chain:
         for d in ds.draw:
-            points.append({
-                vn : s[c, d]
-                for vn, s in _samples.items()
-            })
+            points.append({vn: s[c, d] for vn, s in _samples.items()})
     # use the list of points
-    ds = points
-    return ds
+    return points
+
+
+def chains_and_samples(
+    data: Union[xarray.Dataset, arviz.InferenceData]
+) -> Tuple[int, int]:
+    """Return (number of chains, draws per chain) of an xarray Dataset or arviz InferenceData."""
+    dataset: xarray.Dataset
+    if isinstance(data, xarray.Dataset):
+        dataset = data
+    elif isinstance(data, arviz.InferenceData):
+        dataset = data.posterior
+    else:
+        raise ValueError(
+            "Argument must be xarray Dataset or arviz InferenceData. Got %s"
+            % data.__class__
+        )
+
+    coords = dataset.coords
+    nchains = coords["chain"].sizes["chain"]
+    nsamples = coords["draw"].sizes["draw"]
+    return nchains, nsamples