diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index add775fa970..08cb82358df 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -59,6 +59,8 @@ jobs: pymc/tests/distributions/test_censored.py pymc/tests/distributions/test_simulator.py pymc/tests/distributions/test_truncated.py + pymc/tests/test_sampling_predictive.py + pymc/tests/stats/test_convergence.py - | pymc/tests/tuning/test_scaling.py @@ -147,7 +149,7 @@ jobs: python-version: ["3.8"] test-subset: - pymc/tests/variational/test_approximations.py pymc/tests/variational/test_callbacks.py pymc/tests/variational/test_inference.py pymc/tests/variational/test_opvi.py pymc/tests/test_initial_point.py - - pymc/tests/test_model.py pymc/tests/step_methods/test_compound.py pymc/tests/step_methods/hmc/test_hmc.py + - pymc/tests/test_model.py pymc/tests/test_sampling_utils.py pymc/tests/step_methods/test_compound.py pymc/tests/step_methods/hmc/test_hmc.py - pymc/tests/gp/test_cov.py pymc/tests/gp/test_gp.py pymc/tests/gp/test_mean.py pymc/tests/gp/test_util.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py pymc/tests/smc/test_smc.py pymc/tests/test_parallel_sampling.py - pymc/tests/test_sampling.py pymc/tests/step_methods/test_metropolis.py pymc/tests/step_methods/test_slicer.py pymc/tests/step_methods/hmc/test_nuts.py diff --git a/pymc/__init__.py b/pymc/__init__.py index 09314aa5c30..fb0287675ff 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -68,6 +68,8 @@ def __set_compiler_flags(): from pymc.plots import * from pymc.printing import * from pymc.sampling import * +from pymc.sampling_predictive import * +from pymc.sampling_utils import * from pymc.smc import * from pymc.stats import * from pymc.step_methods import * diff --git a/pymc/sampling.py b/pymc/sampling.py index 1a132009637..d94bdfd8a5c 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -22,42 +22,18 @@ from collections import defaultdict from copy import copy -from typing import ( - Any, - 
Callable, - Dict, - Iterable, - Iterator, - List, - Optional, - Sequence, - Set, - Tuple, - Union, - cast, -) +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union, cast import aesara.gradient as tg import cloudpickle import numpy as np -import xarray - -from aesara import tensor as at -from aesara.graph.basic import Apply, Constant, Variable, general_toposort, walk -from aesara.graph.fg import FunctionGraph -from aesara.tensor.random.var import ( - RandomGeneratorSharedVariable, - RandomStateSharedVariable, -) -from aesara.tensor.sharedvar import SharedVariable + from arviz import InferenceData from fastprogress.fastprogress import progress_bar from typing_extensions import TypeAlias import pymc as pm -from pymc.aesaraf import compile_pymc -from pymc.backends.arviz import _DefaultTrace from pymc.backends.base import BaseTrace, MultiTrace from pymc.backends.ndarray import NDArray from pymc.blocking import DictToArrayBijection @@ -70,18 +46,17 @@ ) from pymc.model import Model, modelcontext from pymc.parallel_sampling import Draw, _cpu_count +from pymc.sampling_utils import ( + RandomSeed, + RandomState, + _get_seeds_per_chain, + all_continuous, +) from pymc.stats.convergence import SamplerWarning, log_warning, run_convergence_checks from pymc.step_methods import NUTS, CompoundStep, DEMetropolis from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared from pymc.step_methods.hmc import quadpotential -from pymc.util import ( - dataset_to_point_list, - drop_warning_stat, - get_default_varnames, - get_untransformed_name, - is_transformed_name, - point_wrapper, -) +from pymc.util import drop_warning_stat, get_untransformed_name, is_transformed_name from pymc.vartypes import discrete_types sys.setrecursionlimit(10000) @@ -89,22 +64,11 @@ __all__ = [ "sample", "iter_sample", - "compile_forward_sampling_function", - "sample_posterior_predictive", - "sample_posterior_predictive_w", "init_nuts", - "sample_prior_predictive", - 
"draw", ] Step: TypeAlias = Union[BlockedStep, CompoundStep] -ArrayLike: TypeAlias = Union[np.ndarray, List[float]] -PointList: TypeAlias = List[PointType] -Backend: TypeAlias = Union[BaseTrace, MultiTrace, NDArray] - -RandomSeed = Optional[Union[int, Sequence[int], np.ndarray]] -RandomState = Union[RandomSeed, np.random.RandomState, np.random.Generator] _log = logging.getLogger("pymc") @@ -244,67 +208,6 @@ def _print_step_hierarchy(s: Step, level: int = 0) -> None: _log.info(">" * level + f"{s.__class__.__name__}: [{varnames}]") -def all_continuous(vars): - """Check that vars not include discrete variables, excepting observed RVs.""" - - vars_ = [var for var in vars if not hasattr(var.tag, "observations")] - - if any([(var.dtype in discrete_types) for var in vars_]): - return False - else: - return True - - -def _get_seeds_per_chain( - random_state: RandomState, - chains: int, -) -> Union[Sequence[int], np.ndarray]: - """Obtain or validate specified integer seeds per chain. - - This function process different possible sources of seeding and returns one integer - seed per chain: - 1. If the input is an integer and a single chain is requested, the input is - returned inside a tuple. - 2. If the input is a sequence or NumPy array with as many entries as chains, - the input is returned. - 3. If the input is an integer and multiple chains are requested, new unique seeds - are generated from NumPy default Generator seeded with that integer. - 4. If the input is None new unique seeds are generated from an unseeded NumPy default - Generator. - 5. If a RandomState or Generator is provided, new unique seeds are generated from it. 
- - Raises - ------ - ValueError - If none of the conditions above are met - """ - - def _get_unique_seeds_per_chain(integers_fn): - seeds = [] - while len(set(seeds)) != chains: - seeds = [int(seed) for seed in integers_fn(2**30, dtype=np.int64, size=chains)] - return seeds - - if random_state is None or isinstance(random_state, int): - if chains == 1 and isinstance(random_state, int): - return (random_state,) - return _get_unique_seeds_per_chain(np.random.default_rng(random_state).integers) - if isinstance(random_state, np.random.Generator): - return _get_unique_seeds_per_chain(random_state.integers) - if isinstance(random_state, np.random.RandomState): - return _get_unique_seeds_per_chain(random_state.randint) - - if not isinstance(random_state, (list, tuple, np.ndarray)): - raise ValueError(f"The `seeds` must be array-like. Got {type(random_state)} instead.") - - if len(random_state) != chains: - raise ValueError( - f"Number of seeds ({len(random_state)}) does not match the number of chains ({chains})." - ) - - return random_state - - def sample( draws: int = 1000, step=None, @@ -1598,660 +1501,6 @@ def stop_tuning(step): return step -def get_vars_in_point_list(trace, model): - """Get the list of Variable instances in the model that have values stored in the trace.""" - if not isinstance(trace, MultiTrace): - names_in_trace = list(trace[0]) - else: - names_in_trace = trace.varnames - vars_in_trace = [model[v] for v in names_in_trace if v in model] - return vars_in_trace - - -def compile_forward_sampling_function( - outputs: List[Variable], - vars_in_trace: List[Variable], - basic_rvs: Optional[List[Variable]] = None, - givens_dict: Optional[Dict[Variable, Any]] = None, - constant_data: Optional[Dict[str, np.ndarray]] = None, - constant_coords: Optional[Set[str]] = None, - **kwargs, -) -> Tuple[Callable[..., Union[np.ndarray, List[np.ndarray]]], Set[Variable]]: - """Compile a function to draw samples, conditioned on the values of some variables. 
- - The goal of this function is to walk the aesara computational graph from the list - of output nodes down to the root nodes, and then compile a function that will produce - values for these output nodes. The compiled function will take as inputs the subset of - variables in the ``vars_in_trace`` that are deemed to not be **volatile**. - - Volatile variables are variables whose values could change between runs of the - compiled function or after inference has been run. These variables are: - - - Variables in the outputs list - - ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``, and whose values changed with respect to what they were at inference time - - Variables that are in the `basic_rvs` list but not in the ``vars_in_trace`` list - - Variables that are keys in the ``givens_dict`` - - Variables that have volatile inputs - - Concretely, this function can be used to compile a function to sample from the - posterior predictive distribution of a model that has variables that are conditioned - on ``MutableData`` instances. The variables that depend on the mutable data that have changed - will be considered volatile, and as such, they wont be included as inputs into the compiled - function. This means that if they have values stored in the posterior, these values will be - ignored and new values will be computed (in the case of deterministics and potentials) or - sampled (in the case of random variables). - - This function also enables a way to impute values for any variable in the computational - graph that produces the desired outputs: the ``givens_dict``. This dictionary can be used - to set the ``givens`` argument of the aesara function compilation. This will essentially - replace a node in the computational graph with any other expression that has the same - type as the desired node. 
Passing variables in the givens_dict is considered an intervention - that might lead to different variable values from those that could have been seen during - inference, as such, **any variable that is passed in the ``givens_dict`` will be considered - volatile**. - - Parameters - ---------- - outputs : List[aesara.graph.basic.Variable] - The list of variables that will be returned by the compiled function - vars_in_trace : List[aesara.graph.basic.Variable] - The list of variables that are assumed to have values stored in the trace - basic_rvs : Optional[List[aesara.graph.basic.Variable]] - A list of random variables that are defined in the model. This list (which could be the - output of ``model.basic_RVs``) should have a reference to the variables that should - be considered as random variable instances. This includes variables that have - a ``RandomVariable`` owner op, but also unpure random variables like Mixtures, or - Censored distributions. - givens_dict : Optional[Dict[aesara.graph.basic.Variable, Any]] - A dictionary that maps tensor variables to the values that should be used to replace them - in the compiled function. The types of the key and value should match or an error will be - raised during compilation. - constant_data : Optional[Dict[str, numpy.ndarray]] - A dictionary that maps the names of ``MutableData`` or ``ConstantData`` instances to their - corresponding values at inference time. If a model was created with ``MutableData``, these - are stored as ``SharedVariable`` with the name of the data variable and a value equal to - the initial data. At inference time, this information is stored in ``InferenceData`` - objects under the ``constant_data`` group, which allows us to check whether a - ``SharedVariable`` instance changed its values after inference or not. If the values have - changed, then the ``SharedVariable`` is assumed to be volatile. If it has not changed, then - the ``SharedVariable`` is assumed to not be volatile. 
If a ``SharedVariable`` is not found - in either ``constant_data`` or ``constant_coords``, then it is assumed to be volatile. - Setting ``constant_data`` to ``None`` is equivalent to passing an empty dictionary. - constant_coords : Optional[Set[str]] - A set with the names of the mutable coordinates that have not changed their shape after - inference. If a model was created with mutable coordinates, these are stored as - ``SharedVariable`` with the name of the coordinate and a value equal to the length of said - coordinate. This set let's us check if a ``SharedVariable`` is a mutated coordinate, in - which case, it is considered volatile. If a ``SharedVariable`` is not found - in either ``constant_data`` or ``constant_coords``, then it is assumed to be volatile. - Setting ``constant_coords`` to ``None`` is equivalent to passing an empty set. - - Returns - ------- - function: Callable - Compiled forward sampling Aesara function - volatile_basic_rvs: Set of Variable - Set of all basic_rvs that were considered volatile and will be resampled when - the function is evaluated - """ - if givens_dict is None: - givens_dict = {} - - if basic_rvs is None: - basic_rvs = [] - - if constant_data is None: - constant_data = {} - if constant_coords is None: - constant_coords = set() - - # We define a helper function to check if shared values match to an array - def shared_value_matches(var): - try: - old_array_value = constant_data[var.name] - except KeyError: - return var.name in constant_coords - current_shared_value = var.get_value(borrow=True) - return np.array_equal(old_array_value, current_shared_value) - - # We need a function graph to walk the clients and propagate the volatile property - fg = FunctionGraph(outputs=outputs, clone=False) - - # Walk the graph from inputs to outputs and tag the volatile variables - nodes: List[Variable] = general_toposort( - fg.outputs, deps=lambda x: x.owner.inputs if x.owner else [] - ) - volatile_nodes: Set[Any] = set() - for node in 
nodes: - if ( - node in fg.outputs - or node in givens_dict - or ( # SharedVariables, except RandomState/Generators - isinstance(node, SharedVariable) - and not isinstance(node, (RandomStateSharedVariable, RandomGeneratorSharedVariable)) - and not shared_value_matches(node) - ) - or ( # Basic RVs that are not in the trace - node in basic_rvs and node not in vars_in_trace - ) - or ( # Variables that have any volatile input - node.owner and any(inp in volatile_nodes for inp in node.owner.inputs) - ) - ): - volatile_nodes.add(node) - - # Collect the function inputs by walking the graph from the outputs. Inputs will be: - # 1. Random variables that are not volatile - # 2. Variables that have no owner and are not constant or shared - inputs = [] - - def expand(node): - if ( - ( - node.owner is None and not isinstance(node, (Constant, SharedVariable)) - ) # Variables without owners that are not constant or shared - or node in vars_in_trace # Variables in the trace - ) and node not in volatile_nodes: - # This test will include variables without owners, and that are not constant - # or shared, because these nodes will never be considered volatile - inputs.append(node) - if node.owner: - return node.owner.inputs - - # walk produces a generator, so we have to actually exhaust the generator in a list to walk - # the entire graph - list(walk(fg.outputs, expand)) - - # Populate the givens list - givens = [ - ( - node, - value - if isinstance(value, (Variable, Apply)) - else at.constant(value, dtype=getattr(node, "dtype", None), name=node.name), - ) - for node, value in givens_dict.items() - ] - - return ( - compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs), - set(basic_rvs) & (volatile_nodes - set(givens_dict)), # Basic RVs that will be resampled - ) - - -def sample_posterior_predictive( - trace, - model: Optional[Model] = None, - var_names: Optional[List[str]] = None, - sample_dims: Optional[List[str]] = None, - random_seed: RandomState = 
None, - progressbar: bool = True, - return_inferencedata: bool = True, - extend_inferencedata: bool = False, - predictions: bool = False, - idata_kwargs: dict = None, - compile_kwargs: dict = None, -) -> Union[InferenceData, Dict[str, np.ndarray]]: - """Generate posterior predictive samples from a model given a trace. - - Parameters - ---------- - trace : backend, list, xarray.Dataset, arviz.InferenceData, or MultiTrace - Trace generated from MCMC sampling, or a list of dicts (eg. points or from find_MAP()), - or xarray.Dataset (eg. InferenceData.posterior or InferenceData.prior) - model : Model (optional if in ``with`` context) - Model to be used to generate the posterior predictive samples. It will - generally be the model used to generate the ``trace``, but it doesn't need to be. - var_names : Iterable[str] - Names of variables for which to compute the posterior predictive samples. - sample_dims : list of str, optional - Dimensions over which to loop and generate posterior predictive samples. - When `sample_dims` is ``None`` (default) both "chain" and "draw" are considered sample - dimensions. Only taken into account when `trace` is InferenceData or Dataset. - random_seed : int, RandomState or Generator, optional - Seed for the random number generator. - progressbar : bool - Whether or not to display a progress bar in the command line. The bar shows the percentage - of completion, the sampling speed in samples per second (SPS), and the estimated remaining - time until completion ("expected time of arrival"; ETA). - return_inferencedata : bool, default True - Whether to return an :class:`arviz:arviz.InferenceData` (True) object or a dictionary (False). - extend_inferencedata : bool, default False - Whether to automatically use :meth:`arviz.InferenceData.extend` to add the posterior predictive samples to - ``trace`` or not. If True, ``trace`` is modified inplace but still returned. 
- predictions : bool, default False - Choose the function used to convert the samples to inferencedata. See ``idata_kwargs`` - for more details. - idata_kwargs : dict, optional - Keyword arguments for :func:`pymc.to_inference_data` if ``predictions=False`` or to - :func:`pymc.predictions_to_inference_data` otherwise. - compile_kwargs: dict, optional - Keyword arguments for :func:`pymc.aesaraf.compile_pymc`. - - Returns - ------- - arviz.InferenceData or Dict - An ArviZ ``InferenceData`` object containing the posterior predictive samples (default), or - a dictionary with variable names as keys, and samples as numpy arrays. - - Examples - -------- - Thin a sampled inferencedata by keeping 1 out of every 5 draws - before passing it to sample_posterior_predictive - - .. code:: python - - thinned_idata = idata.sel(draw=slice(None, None, 5)) - with model: - idata.extend(pymc.sample_posterior_predictive(thinned_idata)) - - Generate 5 posterior predictive samples per posterior sample. - - .. code:: python - - expanded_data = idata.posterior.expand_dims(pred_id=5) - with model: - idata.extend(pymc.sample_posterior_predictive(expanded_data)) - """ - - _trace: Union[MultiTrace, PointList] - nchain: int - if idata_kwargs is None: - idata_kwargs = {} - else: - idata_kwargs = idata_kwargs.copy() - if sample_dims is None: - sample_dims = ["chain", "draw"] - constant_data: Dict[str, np.ndarray] = {} - trace_coords: Dict[str, np.ndarray] = {} - if "coords" not in idata_kwargs: - idata_kwargs["coords"] = {} - idata: Optional[InferenceData] = None - stacked_dims = None - if isinstance(trace, InferenceData): - _constant_data = getattr(trace, "constant_data", None) - if _constant_data is not None: - trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()}) - constant_data.update({str(k): v.data for k, v in _constant_data.items()}) - idata = trace - trace = trace["posterior"] - if isinstance(trace, xarray.Dataset): - trace_coords.update({str(k): v.data for k, v in 
trace.coords.items()}) - _trace, stacked_dims = dataset_to_point_list(trace, sample_dims) - nchain = 1 - elif isinstance(trace, MultiTrace): - _trace = trace - nchain = _trace.nchains - elif isinstance(trace, list) and all(isinstance(x, dict) for x in trace): - _trace = trace - nchain = 1 - else: - raise TypeError(f"Unsupported type for `trace` argument: {type(trace)}.") - len_trace = len(_trace) - - if isinstance(_trace, MultiTrace): - samples = sum(len(v) for v in _trace._straces.values()) - elif isinstance(_trace, list): - # this is a list of points - samples = len(_trace) - else: - raise TypeError( - "Do not know how to compute number of samples for trace argument of type %s" - % type(_trace) - ) - - assert samples is not None - - model = modelcontext(model) - - if model.potentials: - warnings.warn( - "The effect of Potentials on other parameters is ignored during posterior predictive sampling. " - "This is likely to lead to invalid or biased predictive samples.", - UserWarning, - stacklevel=2, - ) - - constant_coords = set() - for dim, coord in trace_coords.items(): - current_coord = model.coords.get(dim, None) - if ( - current_coord is not None - and len(coord) == len(current_coord) - and np.all(coord == current_coord) - ): - constant_coords.add(dim) - - if var_names is not None: - vars_ = [model[x] for x in var_names] - else: - vars_ = model.observed_RVs + model.auto_deterministics - - indices = np.arange(samples) - if progressbar: - indices = progress_bar(indices, total=samples, display=progressbar) - - vars_to_sample = list(get_default_varnames(vars_, include_transformed=False)) - - if not vars_to_sample: - if return_inferencedata and not extend_inferencedata: - return InferenceData() - elif return_inferencedata and extend_inferencedata: - return trace - return {} - - vars_in_trace = get_vars_in_point_list(_trace, model) - - if random_seed is not None: - (random_seed,) = _get_seeds_per_chain(random_seed, 1) - - if compile_kwargs is None: - compile_kwargs = 
{} - compile_kwargs.setdefault("allow_input_downcast", True) - compile_kwargs.setdefault("accept_inplace", True) - - _sampler_fn, volatile_basic_rvs = compile_forward_sampling_function( - outputs=vars_to_sample, - vars_in_trace=vars_in_trace, - basic_rvs=model.basic_RVs, - givens_dict=None, - random_seed=random_seed, - constant_data=constant_data, - constant_coords=constant_coords, - **compile_kwargs, - ) - sampler_fn = point_wrapper(_sampler_fn) - # All model variables have a name, but mypy does not know this - _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore - ppc_trace_t = _DefaultTrace(samples) - try: - for idx in indices: - if nchain > 1: - # the trace object will either be a MultiTrace (and have _straces)... - if hasattr(_trace, "_straces"): - chain_idx, point_idx = np.divmod(idx, len_trace) - chain_idx = chain_idx % nchain - param = cast(MultiTrace, _trace)._straces[chain_idx].point(point_idx) - # ... or a PointList - else: - param = cast(PointList, _trace)[idx % (len_trace * nchain)] - # there's only a single chain, but the index might hit it multiple times if - # the number of indices is greater than the length of the trace. 
- else: - param = _trace[idx % len_trace] - - values = sampler_fn(**param) - - for k, v in zip(vars_, values): - ppc_trace_t.insert(k.name, v, idx) - except KeyboardInterrupt: - pass - - ppc_trace = ppc_trace_t.trace_dict - - for k, ary in ppc_trace.items(): - if stacked_dims is not None: - ppc_trace[k] = ary.reshape( - (*[len(coord) for coord in stacked_dims.values()], *ary.shape[1:]) - ) - else: - ppc_trace[k] = ary.reshape((nchain, len_trace, *ary.shape[1:])) - - if not return_inferencedata: - return ppc_trace - ikwargs: Dict[str, Any] = dict(model=model, **idata_kwargs) - ikwargs.setdefault("sample_dims", sample_dims) - if stacked_dims is not None: - coords = ikwargs.get("coords", {}) - ikwargs["coords"] = {**stacked_dims, **coords} - if predictions: - if extend_inferencedata: - ikwargs.setdefault("idata_orig", idata) - ikwargs.setdefault("inplace", True) - return pm.predictions_to_inference_data(ppc_trace, **ikwargs) - idata_pp = pm.to_inference_data(posterior_predictive=ppc_trace, **ikwargs) - - if extend_inferencedata and idata is not None: - idata.extend(idata_pp) - return idata - return idata_pp - - -def sample_posterior_predictive_w( - traces, - samples: Optional[int] = None, - models: Optional[List[Model]] = None, - weights: Optional[ArrayLike] = None, - random_seed: RandomState = None, - progressbar: bool = True, - return_inferencedata: bool = True, - idata_kwargs: dict = None, -): - """Generate weighted posterior predictive samples from a list of models and - a list of traces according to a set of weights. - - Parameters - ---------- - traces : list or list of lists - List of traces generated from MCMC sampling (xarray.Dataset, arviz.InferenceData, or - MultiTrace), or a list of list containing dicts from find_MAP() or points. The number of - traces should be equal to the number of weights. - samples : int, optional - Number of posterior predictive samples to generate. Defaults to the - length of the shorter trace in traces. 
- models : list of Model - List of models used to generate the list of traces. The number of models should be equal to - the number of weights and the number of observed RVs should be the same for all models. - By default a single model will be inferred from ``with`` context, in this case results will - only be meaningful if all models share the same distributions for the observed RVs. - weights : array-like, optional - Individual weights for each trace. Default, same weight for each model. - random_seed : int, RandomState or Generator, optional - Seed for the random number generator. - progressbar : bool, optional default True - Whether or not to display a progress bar in the command line. The bar shows the percentage - of completion, the sampling speed in samples per second (SPS), and the estimated remaining - time until completion ("expected time of arrival"; ETA). - return_inferencedata : bool - Whether to return an :class:`arviz:arviz.InferenceData` (True) object or a dictionary (False). - Defaults to True. - idata_kwargs : dict, optional - Keyword arguments for :func:`pymc.to_inference_data` - - Returns - ------- - arviz.InferenceData or Dict - An ArviZ ``InferenceData`` object containing the posterior predictive samples from the - weighted models (default), or a dictionary with variable names as keys, and samples as - numpy arrays. - """ - raise FutureWarning( - "The function `sample_posterior_predictive_w` has been removed in PyMC 4.3.0. " - "Switch to `arviz.stats.weight_predictions`" - ) - - -def sample_prior_predictive( - samples: int = 500, - model: Optional[Model] = None, - var_names: Optional[Iterable[str]] = None, - random_seed: RandomState = None, - return_inferencedata: bool = True, - idata_kwargs: dict = None, - compile_kwargs: dict = None, -) -> Union[InferenceData, Dict[str, np.ndarray]]: - """Generate samples from the prior predictive distribution. 
- - Parameters - ---------- - samples : int - Number of samples from the prior predictive to generate. Defaults to 500. - model : Model (optional if in ``with`` context) - var_names : Iterable[str] - A list of names of variables for which to compute the prior predictive - samples. Defaults to both observed and unobserved RVs. Transformed values - are not included unless explicitly defined in var_names. - random_seed : int, RandomState or Generator, optional - Seed for the random number generator. - return_inferencedata : bool - Whether to return an :class:`arviz:arviz.InferenceData` (True) object or a dictionary (False). - Defaults to True. - idata_kwargs : dict, optional - Keyword arguments for :func:`pymc.to_inference_data` - compile_kwargs: dict, optional - Keyword arguments for :func:`pymc.aesaraf.compile_pymc`. - - Returns - ------- - arviz.InferenceData or Dict - An ArviZ ``InferenceData`` object containing the prior and prior predictive samples (default), - or a dictionary with variable names as keys and samples as numpy arrays. - """ - model = modelcontext(model) - - if model.potentials: - warnings.warn( - "The effect of Potentials on other parameters is ignored during prior predictive sampling. " - "This is likely to lead to invalid or biased predictive samples.", - UserWarning, - stacklevel=2, - ) - - if var_names is None: - prior_pred_vars = model.observed_RVs + model.auto_deterministics - prior_vars = ( - get_default_varnames(model.unobserved_RVs, include_transformed=True) + model.potentials - ) - vars_: Set[str] = {var.name for var in prior_vars + prior_pred_vars} - else: - vars_ = set(var_names) - - names = sorted(get_default_varnames(vars_, include_transformed=False)) - vars_to_sample = [model[name] for name in names] - - # Any variables from var_names that are missing must be transformed variables. - # Misspelled variables would have raised a KeyError above. 
- missing_names = vars_.difference(names) - for name in sorted(missing_names): - transformed_value_var = model[name] - rv_var = model.values_to_rvs[transformed_value_var] - transform = transformed_value_var.tag.transform - transformed_rv_var = transform.forward(rv_var, *rv_var.owner.inputs) - - names.append(name) - vars_to_sample.append(transformed_rv_var) - - # If the user asked for the transformed variable in var_names, but not the - # original RV, we add it manually here - if rv_var.name not in names: - names.append(rv_var.name) - vars_to_sample.append(rv_var) - - if random_seed is not None: - (random_seed,) = _get_seeds_per_chain(random_seed, 1) - - if compile_kwargs is None: - compile_kwargs = {} - compile_kwargs.setdefault("allow_input_downcast", True) - compile_kwargs.setdefault("accept_inplace", True) - - sampler_fn, volatile_basic_rvs = compile_forward_sampling_function( - vars_to_sample, - vars_in_trace=[], - basic_rvs=model.basic_RVs, - givens_dict=None, - random_seed=random_seed, - **compile_kwargs, - ) - - # All model variables have a name, but mypy does not know this - _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore - values = zip(*(sampler_fn() for i in range(samples))) - - data = {k: np.stack(v) for k, v in zip(names, values)} - if data is None: - raise AssertionError("No variables sampled: attempting to sample %s" % names) - - prior: Dict[str, np.ndarray] = {} - for var_name in vars_: - if var_name in data: - prior[var_name] = data[var_name] - - if not return_inferencedata: - return prior - ikwargs: Dict[str, Any] = dict(model=model) - if idata_kwargs: - ikwargs.update(idata_kwargs) - return pm.to_inference_data(prior=prior, **ikwargs) - - -def draw( - vars: Union[Variable, Sequence[Variable]], - draws: int = 1, - random_seed: RandomState = None, - **kwargs, -) -> Union[np.ndarray, List[np.ndarray]]: - """Draw samples for one variable or a list of variables - - Parameters - ---------- - vars : 
TensorVariable or iterable of TensorVariable - A variable or a list of variables for which to draw samples. - draws : int, default 1 - Number of samples needed to draw. - random_seed : int, RandomState or numpy_Generator, optional - Seed for the random number generator. - **kwargs : dict, optional - Keyword arguments for :func:`pymc.aesaraf.compile_pymc`. - - Returns - ------- - list of ndarray - A list of numpy arrays. - - Examples - -------- - .. code-block:: python - - import pymc as pm - - # Draw samples for one variable - with pm.Model(): - x = pm.Normal("x") - x_draws = pm.draw(x, draws=100) - print(x_draws.shape) - - # Draw 1000 samples for several variables - with pm.Model(): - x = pm.Normal("x") - y = pm.Normal("y", shape=10) - z = pm.Uniform("z", shape=5) - num_draws = 1000 - # Draw samples of a list variables - draws = pm.draw([x, y, z], draws=num_draws) - assert draws[0].shape == (num_draws,) - assert draws[1].shape == (num_draws, 10) - assert draws[2].shape == (num_draws, 5) - """ - if random_seed is not None: - (random_seed,) = _get_seeds_per_chain(random_seed, 1) - - draw_fn = compile_pymc(inputs=[], outputs=vars, random_seed=random_seed, **kwargs) - - if draws == 1: - return draw_fn() - - # Single variable output - if not isinstance(vars, (list, tuple)): - cast(Callable[[], np.ndarray], draw_fn) - return np.stack([draw_fn() for _ in range(draws)]) - - # Multiple variable output - cast(Callable[[], List[np.ndarray]], draw_fn) - drawn_values = zip(*(draw_fn() for _ in range(draws))) - return [np.stack(v) for v in drawn_values] - - def _init_jitter( model: Model, initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]], diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 8852691ead1..f1a36f04584 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -7,7 +7,8 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union from pymc.initial_point import StartDict -from pymc.sampling import RandomSeed, 
_get_seeds_per_chain, _init_jitter +from pymc.sampling import _init_jitter +from pymc.sampling_utils import RandomSeed, _get_seeds_per_chain xla_flags = os.getenv("XLA_FLAGS", "") xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split() diff --git a/pymc/sampling_predictive.py b/pymc/sampling_predictive.py new file mode 100644 index 00000000000..93ec11fa35a --- /dev/null +++ b/pymc/sampling_predictive.py @@ -0,0 +1,457 @@ +# Copyright 2022 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Functions for prior and posterior predictive sampling.""" + +import logging +import warnings + +from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast + +import numpy as np +import xarray + +from arviz import InferenceData +from fastprogress.fastprogress import progress_bar + +import pymc as pm + +from pymc.backends.arviz import _DefaultTrace +from pymc.backends.base import MultiTrace +from pymc.initial_point import PointType +from pymc.model import Model, modelcontext +from pymc.sampling_utils import ( + ArrayLike, + PointList, + RandomState, + _get_seeds_per_chain, + compile_forward_sampling_function, + get_vars_in_point_list, +) +from pymc.util import dataset_to_point_list, get_default_varnames, point_wrapper + +__all__ = ( + "sample_prior_predictive", + "sample_posterior_predictive", + "sample_posterior_predictive_w", +) + + +_log = logging.getLogger("pymc") + + +def sample_prior_predictive( + samples: int = 500, + model: Optional[Model] = None, + var_names: Optional[Iterable[str]] = None, + random_seed: RandomState = None, + return_inferencedata: bool = True, + idata_kwargs: dict = None, + compile_kwargs: dict = None, +) -> Union[InferenceData, Dict[str, np.ndarray]]: + """Generate samples from the prior predictive distribution. + + Parameters + ---------- + samples : int + Number of samples from the prior predictive to generate. Defaults to 500. + model : Model (optional if in ``with`` context) + var_names : Iterable[str] + A list of names of variables for which to compute the prior predictive + samples. Defaults to both observed and unobserved RVs. Transformed values + are not included unless explicitly defined in var_names. + random_seed : int, RandomState or Generator, optional + Seed for the random number generator. + return_inferencedata : bool + Whether to return an :class:`arviz:arviz.InferenceData` (True) object or a dictionary (False). + Defaults to True. 
+ idata_kwargs : dict, optional + Keyword arguments for :func:`pymc.to_inference_data` + compile_kwargs: dict, optional + Keyword arguments for :func:`pymc.aesaraf.compile_pymc`. + + Returns + ------- + arviz.InferenceData or Dict + An ArviZ ``InferenceData`` object containing the prior and prior predictive samples (default), + or a dictionary with variable names as keys and samples as numpy arrays. + """ + model = modelcontext(model) + + if model.potentials: + warnings.warn( + "The effect of Potentials on other parameters is ignored during prior predictive sampling. " + "This is likely to lead to invalid or biased predictive samples.", + UserWarning, + stacklevel=2, + ) + + if var_names is None: + prior_pred_vars = model.observed_RVs + model.auto_deterministics + prior_vars = ( + get_default_varnames(model.unobserved_RVs, include_transformed=True) + model.potentials + ) + vars_: Set[str] = {var.name for var in prior_vars + prior_pred_vars} + else: + vars_ = set(var_names) + + names = sorted(get_default_varnames(vars_, include_transformed=False)) + vars_to_sample = [model[name] for name in names] + + # Any variables from var_names that are missing must be transformed variables. + # Misspelled variables would have raised a KeyError above. 
+ missing_names = vars_.difference(names) + for name in sorted(missing_names): + transformed_value_var = model[name] + rv_var = model.values_to_rvs[transformed_value_var] + transform = transformed_value_var.tag.transform + transformed_rv_var = transform.forward(rv_var, *rv_var.owner.inputs) + + names.append(name) + vars_to_sample.append(transformed_rv_var) + + # If the user asked for the transformed variable in var_names, but not the + # original RV, we add it manually here + if rv_var.name not in names: + names.append(rv_var.name) + vars_to_sample.append(rv_var) + + if random_seed is not None: + (random_seed,) = _get_seeds_per_chain(random_seed, 1) + + if compile_kwargs is None: + compile_kwargs = {} + compile_kwargs.setdefault("allow_input_downcast", True) + compile_kwargs.setdefault("accept_inplace", True) + + sampler_fn, volatile_basic_rvs = compile_forward_sampling_function( + vars_to_sample, + vars_in_trace=[], + basic_rvs=model.basic_RVs, + givens_dict=None, + random_seed=random_seed, + **compile_kwargs, + ) + + # All model variables have a name, but mypy does not know this + _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore + values = zip(*(sampler_fn() for i in range(samples))) + + data = {k: np.stack(v) for k, v in zip(names, values)} + if data is None: + raise AssertionError("No variables sampled: attempting to sample %s" % names) + + prior: Dict[str, np.ndarray] = {} + for var_name in vars_: + if var_name in data: + prior[var_name] = data[var_name] + + if not return_inferencedata: + return prior + ikwargs: Dict[str, Any] = dict(model=model) + if idata_kwargs: + ikwargs.update(idata_kwargs) + return pm.to_inference_data(prior=prior, **ikwargs) + + +def sample_posterior_predictive( + trace, + model: Optional[Model] = None, + var_names: Optional[List[str]] = None, + sample_dims: Optional[List[str]] = None, + random_seed: RandomState = None, + progressbar: bool = True, + return_inferencedata: bool = True, + 
extend_inferencedata: bool = False, + predictions: bool = False, + idata_kwargs: dict = None, + compile_kwargs: dict = None, +) -> Union[InferenceData, Dict[str, np.ndarray]]: + """Generate posterior predictive samples from a model given a trace. + + Parameters + ---------- + trace : backend, list, xarray.Dataset, arviz.InferenceData, or MultiTrace + Trace generated from MCMC sampling, or a list of dicts (eg. points or from find_MAP()), + or xarray.Dataset (eg. InferenceData.posterior or InferenceData.prior) + model : Model (optional if in ``with`` context) + Model to be used to generate the posterior predictive samples. It will + generally be the model used to generate the ``trace``, but it doesn't need to be. + var_names : Iterable[str] + Names of variables for which to compute the posterior predictive samples. + sample_dims : list of str, optional + Dimensions over which to loop and generate posterior predictive samples. + When `sample_dims` is ``None`` (default) both "chain" and "draw" are considered sample + dimensions. Only taken into account when `trace` is InferenceData or Dataset. + random_seed : int, RandomState or Generator, optional + Seed for the random number generator. + progressbar : bool + Whether or not to display a progress bar in the command line. The bar shows the percentage + of completion, the sampling speed in samples per second (SPS), and the estimated remaining + time until completion ("expected time of arrival"; ETA). + return_inferencedata : bool, default True + Whether to return an :class:`arviz:arviz.InferenceData` (True) object or a dictionary (False). + extend_inferencedata : bool, default False + Whether to automatically use :meth:`arviz.InferenceData.extend` to add the posterior predictive samples to + ``trace`` or not. If True, ``trace`` is modified inplace but still returned. + predictions : bool, default False + Choose the function used to convert the samples to inferencedata. See ``idata_kwargs`` + for more details. 
+ idata_kwargs : dict, optional + Keyword arguments for :func:`pymc.to_inference_data` if ``predictions=False`` or to + :func:`pymc.predictions_to_inference_data` otherwise. + compile_kwargs: dict, optional + Keyword arguments for :func:`pymc.aesaraf.compile_pymc`. + + Returns + ------- + arviz.InferenceData or Dict + An ArviZ ``InferenceData`` object containing the posterior predictive samples (default), or + a dictionary with variable names as keys, and samples as numpy arrays. + + Examples + -------- + Thin a sampled inferencedata by keeping 1 out of every 5 draws + before passing it to sample_posterior_predictive + + .. code:: python + + thinned_idata = idata.sel(draw=slice(None, None, 5)) + with model: + idata.extend(pymc.sample_posterior_predictive(thinned_idata)) + + Generate 5 posterior predictive samples per posterior sample. + + .. code:: python + + expanded_data = idata.posterior.expand_dims(pred_id=5) + with model: + idata.extend(pymc.sample_posterior_predictive(expanded_data)) + """ + + _trace: Union[MultiTrace, PointList] + nchain: int + if idata_kwargs is None: + idata_kwargs = {} + else: + idata_kwargs = idata_kwargs.copy() + if sample_dims is None: + sample_dims = ["chain", "draw"] + constant_data: Dict[str, np.ndarray] = {} + trace_coords: Dict[str, np.ndarray] = {} + if "coords" not in idata_kwargs: + idata_kwargs["coords"] = {} + idata: Optional[InferenceData] = None + stacked_dims = None + if isinstance(trace, InferenceData): + _constant_data = getattr(trace, "constant_data", None) + if _constant_data is not None: + trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()}) + constant_data.update({str(k): v.data for k, v in _constant_data.items()}) + idata = trace + trace = trace["posterior"] + if isinstance(trace, xarray.Dataset): + trace_coords.update({str(k): v.data for k, v in trace.coords.items()}) + _trace, stacked_dims = dataset_to_point_list(trace, sample_dims) + nchain = 1 + elif isinstance(trace, MultiTrace): + 
_trace = trace + nchain = _trace.nchains + elif isinstance(trace, list) and all(isinstance(x, dict) for x in trace): + _trace = trace + nchain = 1 + else: + raise TypeError(f"Unsupported type for `trace` argument: {type(trace)}.") + len_trace = len(_trace) + + if isinstance(_trace, MultiTrace): + samples = sum(len(v) for v in _trace._straces.values()) + elif isinstance(_trace, list): + # this is a list of points + samples = len(_trace) + else: + raise TypeError( + "Do not know how to compute number of samples for trace argument of type %s" + % type(_trace) + ) + + assert samples is not None + + model = modelcontext(model) + + if model.potentials: + warnings.warn( + "The effect of Potentials on other parameters is ignored during posterior predictive sampling. " + "This is likely to lead to invalid or biased predictive samples.", + UserWarning, + stacklevel=2, + ) + + constant_coords = set() + for dim, coord in trace_coords.items(): + current_coord = model.coords.get(dim, None) + if ( + current_coord is not None + and len(coord) == len(current_coord) + and np.all(coord == current_coord) + ): + constant_coords.add(dim) + + if var_names is not None: + vars_ = [model[x] for x in var_names] + else: + vars_ = model.observed_RVs + model.auto_deterministics + + indices = np.arange(samples) + if progressbar: + indices = progress_bar(indices, total=samples, display=progressbar) + + vars_to_sample = list(get_default_varnames(vars_, include_transformed=False)) + + if not vars_to_sample: + if return_inferencedata and not extend_inferencedata: + return InferenceData() + elif return_inferencedata and extend_inferencedata: + return trace + return {} + + vars_in_trace = get_vars_in_point_list(_trace, model) + + if random_seed is not None: + (random_seed,) = _get_seeds_per_chain(random_seed, 1) + + if compile_kwargs is None: + compile_kwargs = {} + compile_kwargs.setdefault("allow_input_downcast", True) + compile_kwargs.setdefault("accept_inplace", True) + + _sampler_fn, 
volatile_basic_rvs = compile_forward_sampling_function( + outputs=vars_to_sample, + vars_in_trace=vars_in_trace, + basic_rvs=model.basic_RVs, + givens_dict=None, + random_seed=random_seed, + constant_data=constant_data, + constant_coords=constant_coords, + **compile_kwargs, + ) + sampler_fn = point_wrapper(_sampler_fn) + # All model variables have a name, but mypy does not know this + _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore + ppc_trace_t = _DefaultTrace(samples) + try: + for idx in indices: + if nchain > 1: + # the trace object will either be a MultiTrace (and have _straces)... + if hasattr(_trace, "_straces"): + chain_idx, point_idx = np.divmod(idx, len_trace) + chain_idx = chain_idx % nchain + param = cast(MultiTrace, _trace)._straces[chain_idx].point(point_idx) + # ... or a PointList + else: + param = cast(PointList, _trace)[idx % (len_trace * nchain)] + # there's only a single chain, but the index might hit it multiple times if + # the number of indices is greater than the length of the trace. 
+ else: + param = _trace[idx % len_trace] + + values = sampler_fn(**param) + + for k, v in zip(vars_, values): + ppc_trace_t.insert(k.name, v, idx) + except KeyboardInterrupt: + pass + + ppc_trace = ppc_trace_t.trace_dict + + for k, ary in ppc_trace.items(): + if stacked_dims is not None: + ppc_trace[k] = ary.reshape( + (*[len(coord) for coord in stacked_dims.values()], *ary.shape[1:]) + ) + else: + ppc_trace[k] = ary.reshape((nchain, len_trace, *ary.shape[1:])) + + if not return_inferencedata: + return ppc_trace + ikwargs: Dict[str, Any] = dict(model=model, **idata_kwargs) + ikwargs.setdefault("sample_dims", sample_dims) + if stacked_dims is not None: + coords = ikwargs.get("coords", {}) + ikwargs["coords"] = {**stacked_dims, **coords} + if predictions: + if extend_inferencedata: + ikwargs.setdefault("idata_orig", idata) + ikwargs.setdefault("inplace", True) + return pm.predictions_to_inference_data(ppc_trace, **ikwargs) + idata_pp = pm.to_inference_data(posterior_predictive=ppc_trace, **ikwargs) + + if extend_inferencedata and idata is not None: + idata.extend(idata_pp) + return idata + return idata_pp + + +def sample_posterior_predictive_w( + traces, + samples: Optional[int] = None, + models: Optional[List[Model]] = None, + weights: Optional[ArrayLike] = None, + random_seed: RandomState = None, + progressbar: bool = True, + return_inferencedata: bool = True, + idata_kwargs: dict = None, +): + """Generate weighted posterior predictive samples from a list of models and + a list of traces according to a set of weights. + + Parameters + ---------- + traces : list or list of lists + List of traces generated from MCMC sampling (xarray.Dataset, arviz.InferenceData, or + MultiTrace), or a list of list containing dicts from find_MAP() or points. The number of + traces should be equal to the number of weights. + samples : int, optional + Number of posterior predictive samples to generate. Defaults to the + length of the shorter trace in traces. 
+ models : list of Model + List of models used to generate the list of traces. The number of models should be equal to + the number of weights and the number of observed RVs should be the same for all models. + By default a single model will be inferred from ``with`` context, in this case results will + only be meaningful if all models share the same distributions for the observed RVs. + weights : array-like, optional + Individual weights for each trace. Default, same weight for each model. + random_seed : int, RandomState or Generator, optional + Seed for the random number generator. + progressbar : bool, optional default True + Whether or not to display a progress bar in the command line. The bar shows the percentage + of completion, the sampling speed in samples per second (SPS), and the estimated remaining + time until completion ("expected time of arrival"; ETA). + return_inferencedata : bool + Whether to return an :class:`arviz:arviz.InferenceData` (True) object or a dictionary (False). + Defaults to True. + idata_kwargs : dict, optional + Keyword arguments for :func:`pymc.to_inference_data` + + Returns + ------- + arviz.InferenceData or Dict + An ArviZ ``InferenceData`` object containing the posterior predictive samples from the + weighted models (default), or a dictionary with variable names as keys, and samples as + numpy arrays. + """ + raise FutureWarning( + "The function `sample_posterior_predictive_w` has been removed in PyMC 4.3.0. " + "Switch to `arviz.stats.weight_predictions`" + ) diff --git a/pymc/sampling_utils.py b/pymc/sampling_utils.py new file mode 100644 index 00000000000..7d8e787001c --- /dev/null +++ b/pymc/sampling_utils.py @@ -0,0 +1,372 @@ +# Copyright 2022 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper functions for MCMC, prior and posterior predictive sampling.""" + +import logging + +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Set, + Tuple, + Union, + cast, +) + +import numpy as np + +from aesara import tensor as at +from aesara.graph.basic import Apply, Constant, Variable, general_toposort, walk +from aesara.graph.fg import FunctionGraph +from aesara.tensor.random.var import ( + RandomGeneratorSharedVariable, + RandomStateSharedVariable, +) +from aesara.tensor.sharedvar import SharedVariable +from typing_extensions import TypeAlias + +from pymc.aesaraf import compile_pymc +from pymc.backends.base import BaseTrace, MultiTrace +from pymc.backends.ndarray import NDArray +from pymc.initial_point import PointType +from pymc.vartypes import discrete_types + +ArrayLike: TypeAlias = Union[np.ndarray, List[float]] +PointList: TypeAlias = List[PointType] +Backend: TypeAlias = Union[BaseTrace, MultiTrace, NDArray] + +RandomSeed = Optional[Union[int, Sequence[int], np.ndarray]] +RandomState = Union[RandomSeed, np.random.RandomState, np.random.Generator] + +_log = logging.getLogger("pymc") + + +__all__ = ( + "compile_forward_sampling_function", + "draw", +) + + +def all_continuous(vars): + """Check that vars not include discrete variables, excepting observed RVs.""" + + vars_ = [var for var in vars if not hasattr(var.tag, "observations")] + + if any([(var.dtype in discrete_types) for var in vars_]): + return False + else: + return True + + +def _get_seeds_per_chain( + random_state: RandomState, + chains: int, +) -> 
Union[Sequence[int], np.ndarray]: + """Obtain or validate specified integer seeds per chain. + + This function process different possible sources of seeding and returns one integer + seed per chain: + 1. If the input is an integer and a single chain is requested, the input is + returned inside a tuple. + 2. If the input is a sequence or NumPy array with as many entries as chains, + the input is returned. + 3. If the input is an integer and multiple chains are requested, new unique seeds + are generated from NumPy default Generator seeded with that integer. + 4. If the input is None new unique seeds are generated from an unseeded NumPy default + Generator. + 5. If a RandomState or Generator is provided, new unique seeds are generated from it. + + Raises + ------ + ValueError + If none of the conditions above are met + """ + + def _get_unique_seeds_per_chain(integers_fn): + seeds = [] + while len(set(seeds)) != chains: + seeds = [int(seed) for seed in integers_fn(2**30, dtype=np.int64, size=chains)] + return seeds + + if random_state is None or isinstance(random_state, int): + if chains == 1 and isinstance(random_state, int): + return (random_state,) + return _get_unique_seeds_per_chain(np.random.default_rng(random_state).integers) + if isinstance(random_state, np.random.Generator): + return _get_unique_seeds_per_chain(random_state.integers) + if isinstance(random_state, np.random.RandomState): + return _get_unique_seeds_per_chain(random_state.randint) + + if not isinstance(random_state, (list, tuple, np.ndarray)): + raise ValueError(f"The `seeds` must be array-like. Got {type(random_state)} instead.") + + if len(random_state) != chains: + raise ValueError( + f"Number of seeds ({len(random_state)}) does not match the number of chains ({chains})." 
+ ) + + return random_state + + +def get_vars_in_point_list(trace, model): + """Get the list of Variable instances in the model that have values stored in the trace.""" + if not isinstance(trace, MultiTrace): + names_in_trace = list(trace[0]) + else: + names_in_trace = trace.varnames + vars_in_trace = [model[v] for v in names_in_trace if v in model] + return vars_in_trace + + +def compile_forward_sampling_function( + outputs: List[Variable], + vars_in_trace: List[Variable], + basic_rvs: Optional[List[Variable]] = None, + givens_dict: Optional[Dict[Variable, Any]] = None, + constant_data: Optional[Dict[str, np.ndarray]] = None, + constant_coords: Optional[Set[str]] = None, + **kwargs, +) -> Tuple[Callable[..., Union[np.ndarray, List[np.ndarray]]], Set[Variable]]: + """Compile a function to draw samples, conditioned on the values of some variables. + + The goal of this function is to walk the aesara computational graph from the list + of output nodes down to the root nodes, and then compile a function that will produce + values for these output nodes. The compiled function will take as inputs the subset of + variables in the ``vars_in_trace`` that are deemed to not be **volatile**. + + Volatile variables are variables whose values could change between runs of the + compiled function or after inference has been run. These variables are: + + - Variables in the outputs list + - ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``, and whose values changed with respect to what they were at inference time + - Variables that are in the `basic_rvs` list but not in the ``vars_in_trace`` list + - Variables that are keys in the ``givens_dict`` + - Variables that have volatile inputs + + Concretely, this function can be used to compile a function to sample from the + posterior predictive distribution of a model that has variables that are conditioned + on ``MutableData`` instances. 
The variables that depend on the mutable data that have changed + will be considered volatile, and as such, they won't be included as inputs into the compiled + function. This means that if they have values stored in the posterior, these values will be + ignored and new values will be computed (in the case of deterministics and potentials) or + sampled (in the case of random variables). + + This function also enables a way to impute values for any variable in the computational + graph that produces the desired outputs: the ``givens_dict``. This dictionary can be used + to set the ``givens`` argument of the aesara function compilation. This will essentially + replace a node in the computational graph with any other expression that has the same + type as the desired node. Passing variables in the givens_dict is considered an intervention + that might lead to different variable values from those that could have been seen during + inference, as such, **any variable that is passed in the ``givens_dict`` will be considered + volatile**. + + Parameters + ---------- + outputs : List[aesara.graph.basic.Variable] + The list of variables that will be returned by the compiled function + vars_in_trace : List[aesara.graph.basic.Variable] + The list of variables that are assumed to have values stored in the trace + basic_rvs : Optional[List[aesara.graph.basic.Variable]] + A list of random variables that are defined in the model. This list (which could be the + output of ``model.basic_RVs``) should have a reference to the variables that should + be considered as random variable instances. This includes variables that have + a ``RandomVariable`` owner op, but also impure random variables like Mixtures, or + Censored distributions. + givens_dict : Optional[Dict[aesara.graph.basic.Variable, Any]] + A dictionary that maps tensor variables to the values that should be used to replace them + in the compiled function. 
The types of the key and value should match or an error will be + raised during compilation. + constant_data : Optional[Dict[str, numpy.ndarray]] + A dictionary that maps the names of ``MutableData`` or ``ConstantData`` instances to their + corresponding values at inference time. If a model was created with ``MutableData``, these + are stored as ``SharedVariable`` with the name of the data variable and a value equal to + the initial data. At inference time, this information is stored in ``InferenceData`` + objects under the ``constant_data`` group, which allows us to check whether a + ``SharedVariable`` instance changed its values after inference or not. If the values have + changed, then the ``SharedVariable`` is assumed to be volatile. If it has not changed, then + the ``SharedVariable`` is assumed to not be volatile. If a ``SharedVariable`` is not found + in either ``constant_data`` or ``constant_coords``, then it is assumed to be volatile. + Setting ``constant_data`` to ``None`` is equivalent to passing an empty dictionary. + constant_coords : Optional[Set[str]] + A set with the names of the mutable coordinates that have not changed their shape after + inference. If a model was created with mutable coordinates, these are stored as + ``SharedVariable`` with the name of the coordinate and a value equal to the length of said + coordinate. This set lets us check if a ``SharedVariable`` is a mutated coordinate, in + which case, it is considered volatile. If a ``SharedVariable`` is not found + in either ``constant_data`` or ``constant_coords``, then it is assumed to be volatile. + Setting ``constant_coords`` to ``None`` is equivalent to passing an empty set. 
+ + Returns + ------- + function: Callable + Compiled forward sampling Aesara function + volatile_basic_rvs: Set of Variable + Set of all basic_rvs that were considered volatile and will be resampled when + the function is evaluated + """ + if givens_dict is None: + givens_dict = {} + + if basic_rvs is None: + basic_rvs = [] + + if constant_data is None: + constant_data = {} + if constant_coords is None: + constant_coords = set() + + # We define a helper function to check if shared values match to an array + def shared_value_matches(var): + try: + old_array_value = constant_data[var.name] + except KeyError: + return var.name in constant_coords + current_shared_value = var.get_value(borrow=True) + return np.array_equal(old_array_value, current_shared_value) + + # We need a function graph to walk the clients and propagate the volatile property + fg = FunctionGraph(outputs=outputs, clone=False) + + # Walk the graph from inputs to outputs and tag the volatile variables + nodes: List[Variable] = general_toposort( + fg.outputs, deps=lambda x: x.owner.inputs if x.owner else [] + ) + volatile_nodes: Set[Any] = set() + for node in nodes: + if ( + node in fg.outputs + or node in givens_dict + or ( # SharedVariables, except RandomState/Generators + isinstance(node, SharedVariable) + and not isinstance(node, (RandomStateSharedVariable, RandomGeneratorSharedVariable)) + and not shared_value_matches(node) + ) + or ( # Basic RVs that are not in the trace + node in basic_rvs and node not in vars_in_trace + ) + or ( # Variables that have any volatile input + node.owner and any(inp in volatile_nodes for inp in node.owner.inputs) + ) + ): + volatile_nodes.add(node) + + # Collect the function inputs by walking the graph from the outputs. Inputs will be: + # 1. Random variables that are not volatile + # 2. 
Variables that have no owner and are not constant or shared + inputs = [] + + def expand(node): + if ( + ( + node.owner is None and not isinstance(node, (Constant, SharedVariable)) + ) # Variables without owners that are not constant or shared + or node in vars_in_trace # Variables in the trace + ) and node not in volatile_nodes: + # This test will include variables without owners, and that are not constant + # or shared, because these nodes will never be considered volatile + inputs.append(node) + if node.owner: + return node.owner.inputs + + # walk produces a generator, so we have to actually exhaust the generator in a list to walk + # the entire graph + list(walk(fg.outputs, expand)) + + # Populate the givens list + givens = [ + ( + node, + value + if isinstance(value, (Variable, Apply)) + else at.constant(value, dtype=getattr(node, "dtype", None), name=node.name), + ) + for node, value in givens_dict.items() + ] + + return ( + compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs), + set(basic_rvs) & (volatile_nodes - set(givens_dict)), # Basic RVs that will be resampled + ) + + +def draw( + vars: Union[Variable, Sequence[Variable]], + draws: int = 1, + random_seed: RandomState = None, + **kwargs, +) -> Union[np.ndarray, List[np.ndarray]]: + """Draw samples for one variable or a list of variables + + Parameters + ---------- + vars : TensorVariable or iterable of TensorVariable + A variable or a list of variables for which to draw samples. + draws : int, default 1 + Number of samples needed to draw. + random_seed : int, RandomState or numpy_Generator, optional + Seed for the random number generator. + **kwargs : dict, optional + Keyword arguments for :func:`pymc.aesaraf.compile_pymc`. + + Returns + ------- + list of ndarray + A list of numpy arrays. + + Examples + -------- + .. 
code-block:: python + + import pymc as pm + + # Draw samples for one variable + with pm.Model(): + x = pm.Normal("x") + x_draws = pm.draw(x, draws=100) + print(x_draws.shape) + + # Draw 1000 samples for several variables + with pm.Model(): + x = pm.Normal("x") + y = pm.Normal("y", shape=10) + z = pm.Uniform("z", shape=5) + num_draws = 1000 + # Draw samples of a list variables + draws = pm.draw([x, y, z], draws=num_draws) + assert draws[0].shape == (num_draws,) + assert draws[1].shape == (num_draws, 10) + assert draws[2].shape == (num_draws, 5) + """ + if random_seed is not None: + (random_seed,) = _get_seeds_per_chain(random_seed, 1) + + draw_fn = compile_pymc(inputs=[], outputs=vars, random_seed=random_seed, **kwargs) + + if draws == 1: + return draw_fn() + + # Single variable output + if not isinstance(vars, (list, tuple)): + cast(Callable[[], np.ndarray], draw_fn) + return np.stack([draw_fn() for _ in range(draws)]) + + # Multiple variable output + cast(Callable[[], List[np.ndarray]], draw_fn) + drawn_values = zip(*(draw_fn() for _ in range(draws))) + return [np.stack(v) for v in drawn_values] diff --git a/pymc/smc/kernels.py b/pymc/smc/kernels.py index 5ff2b9b5cc4..15300d8be58 100644 --- a/pymc/smc/kernels.py +++ b/pymc/smc/kernels.py @@ -35,7 +35,7 @@ from pymc.backends.ndarray import NDArray from pymc.blocking import DictToArrayBijection from pymc.model import Point, modelcontext -from pymc.sampling import sample_prior_predictive +from pymc.sampling_predictive import sample_prior_predictive from pymc.step_methods.metropolis import MultivariateNormalProposal from pymc.vartypes import discrete_types diff --git a/pymc/tests/distributions/test_mixture.py b/pymc/tests/distributions/test_mixture.py index da439d58853..e62f202db92 100644 --- a/pymc/tests/distributions/test_mixture.py +++ b/pymc/tests/distributions/test_mixture.py @@ -55,12 +55,12 @@ from pymc.distributions.transforms import _default_transform from pymc.math import expand_packed_triangular from 
pymc.model import Model -from pymc.sampling import ( - draw, - sample, +from pymc.sampling import sample +from pymc.sampling_predictive import ( sample_posterior_predictive, sample_prior_predictive, ) +from pymc.sampling_utils import draw from pymc.step_methods import Metropolis from pymc.tests.distributions.util import ( Domain, diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index 1f3e31cfdc6..9083b4cdb65 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -41,7 +41,7 @@ ) from pymc.distributions.shape_utils import change_dist_size, to_tuple from pymc.math import kronecker -from pymc.sampling import draw +from pymc.sampling_utils import draw from pymc.tests.distributions.util import ( BaseTestDistributionRandom, Domain, diff --git a/pymc/tests/distributions/test_timeseries.py b/pymc/tests/distributions/test_timeseries.py index e4dd45900c9..aae609dd502 100644 --- a/pymc/tests/distributions/test_timeseries.py +++ b/pymc/tests/distributions/test_timeseries.py @@ -42,7 +42,9 @@ RandomWalk, ) from pymc.model import Model -from pymc.sampling import draw, sample, sample_posterior_predictive +from pymc.sampling import sample +from pymc.sampling_predictive import sample_posterior_predictive +from pymc.sampling_utils import draw from pymc.tests.distributions.util import assert_moment_is_expected from pymc.tests.helpers import select_by_precision diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index b351d04d15f..64fdb2a40ff 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -12,44 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -import re import unittest.mock as mock import warnings from contextlib import ExitStack as does_not_raise from copy import copy -from typing import Tuple import aesara import aesara.tensor as at import numpy as np -import numpy.random as npr import numpy.testing as npt import pytest import scipy.special -import xarray as xr -from aesara import Mode, shared -from aesara.compile import SharedVariable +from aesara import shared from aesara.compile.ops import as_op from arviz import InferenceData -from arviz import from_dict as az_from_dict -from arviz.tests.helpers import check_multiple_attrs -from scipy import stats import pymc as pm -from pymc.aesaraf import compile_pymc from pymc.backends.base import MultiTrace from pymc.backends.ndarray import NDArray from pymc.distributions import transforms from pymc.exceptions import SamplingError -from pymc.sampling import ( - _get_seeds_per_chain, - assign_step_methods, - compile_forward_sampling_function, - get_vars_in_point_list, -) +from pymc.sampling import assign_step_methods from pymc.stats.convergence import SamplerWarning, WarningType from pymc.step_methods import ( NUTS, @@ -614,569 +600,6 @@ def test_errors_and_warnings(self): pm.sampling._choose_backend(trace=strace) -class TestSamplePPC(SeededTest): - def test_normal_scalar(self): - nchains = 2 - ndraws = 500 - with pm.Model() as model: - mu = pm.Normal("mu", 0.0, 1.0) - a = pm.Normal("a", mu=mu, sigma=1, observed=0.0) - trace = pm.sample( - draws=ndraws, - chains=nchains, - ) - - with model: - # test list input - ppc0 = pm.sample_posterior_predictive( - 10 * [model.initial_point()], return_inferencedata=False - ) - assert "a" in ppc0 - assert len(ppc0["a"][0]) == 10 - # test empty ppc - ppc = pm.sample_posterior_predictive(trace, var_names=[], return_inferencedata=False) - assert len(ppc) == 0 - - # test keep_size parameter - ppc = pm.sample_posterior_predictive(trace, return_inferencedata=False) - assert ppc["a"].shape == (nchains, ndraws) - - # 
test default case - random_state = self.get_random_state() - idata_ppc = pm.sample_posterior_predictive( - trace, var_names=["a"], random_seed=random_state - ) - ppc = idata_ppc.posterior_predictive - assert "a" in ppc - assert ppc["a"].shape == (nchains, ndraws) - # mu's standard deviation may have changed thanks to a's observed - _, pval = stats.kstest( - (ppc["a"] - trace.posterior["mu"]).values.flatten(), stats.norm(loc=0, scale=1).cdf - ) - assert pval > 0.001 - - def test_normal_scalar_idata(self): - nchains = 2 - ndraws = 500 - with pm.Model() as model: - mu = pm.Normal("mu", 0.0, 1.0) - a = pm.Normal("a", mu=mu, sigma=1, observed=0.0) - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "Tuning samples will be included.*", UserWarning) - trace = pm.sample( - draws=ndraws, - chains=nchains, - return_inferencedata=False, - discard_tuned_samples=False, - ) - - assert not isinstance(trace, InferenceData) - - with model: - # test keep_size parameter and idata input - idata = pm.to_inference_data(trace) - assert isinstance(idata, InferenceData) - - ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False) - assert ppc["a"].shape == (nchains, ndraws) - - def test_normal_vector(self): - with pm.Model() as model: - mu = pm.Normal("mu", 0.0, 1.0) - a = pm.Normal("a", mu=mu, sigma=1, observed=np.array([0.5, 0.2])) - trace = pm.sample(return_inferencedata=False, draws=12, chains=1) - - with model: - # test list input - ppc0 = pm.sample_posterior_predictive( - 10 * [model.initial_point()], - return_inferencedata=False, - ) - ppc = pm.sample_posterior_predictive(trace, return_inferencedata=False, var_names=[]) - assert len(ppc) == 0 - - ppc = pm.sample_posterior_predictive(trace, return_inferencedata=False) - assert ppc["a"].shape == (trace.nchains, len(trace), 2) - assert ppc0["a"].shape == (1, 10, 2) - - def test_normal_vector_idata(self): - with pm.Model() as model: - mu = pm.Normal("mu", 0.0, 1.0) - a = pm.Normal("a", mu=mu, sigma=1, 
observed=np.array([0.5, 0.2])) - trace = pm.sample(return_inferencedata=False) - - assert not isinstance(trace, InferenceData) - - with model: - # test keep_size parameter with inference data as input... - idata = pm.to_inference_data(trace) - assert isinstance(idata, InferenceData) - - ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False) - assert ppc["a"].shape == (trace.nchains, len(trace), 2) - - def test_exceptions(self): - with pm.Model() as model: - mu = pm.Normal("mu", 0.0, 1.0) - a = pm.Normal("a", mu=mu, sigma=1, observed=np.array([0.5, 0.2])) - idata = pm.sample(idata_kwargs={"log_likelihood": False}) - - with model: - # test wrong type argument - bad_trace = {"mu": stats.norm.rvs(size=1000)} - with pytest.raises(TypeError, match="type for `trace`"): - ppc = pm.sample_posterior_predictive(bad_trace) - - def test_sum_normal(self): - with pm.Model() as model: - a = pm.Normal("a", sigma=0.2) - b = pm.Normal("b", mu=a) - idata = pm.sample(draws=1000, chains=1) - - with model: - # test list input - ppc0 = pm.sample_posterior_predictive( - 10 * [model.initial_point()], return_inferencedata=False - ) - assert ppc0 == {} - ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False, var_names=["b"]) - assert len(ppc) == 1 - assert ppc["b"].shape == ( - 1, - 1000, - ) - scale = np.sqrt(1 + 0.2**2) - _, pval = stats.kstest(ppc["b"].flatten(), stats.norm(scale=scale).cdf) - assert pval > 0.001 - - def test_model_not_drawable_prior(self): - data = np.random.poisson(lam=10, size=200) - model = pm.Model() - with model: - mu = pm.HalfFlat("sigma") - pm.Poisson("foo", mu=mu, observed=data) - with aesara.config.change_flags(mode=fast_unstable_sampling_mode): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) - idata = pm.sample(tune=10, draws=40, chains=1) - - with model: - with pytest.raises(NotImplementedError) as excinfo: - pm.sample_prior_predictive(50) - assert "Cannot sample" in 
str(excinfo.value) - samples = pm.sample_posterior_predictive(idata, return_inferencedata=False) - assert samples["foo"].shape == (1, 40, 200) - - def test_model_shared_variable(self): - rng = np.random.RandomState(9832) - - x = rng.randn(100) - y = x > 0 - x_shared = aesara.shared(x) - y_shared = aesara.shared(y) - samples = 100 - with pm.Model() as model: - coeff = pm.Normal("x", mu=0, sigma=1) - logistic = pm.Deterministic("p", pm.math.sigmoid(coeff * x_shared)) - - obs = pm.Bernoulli("obs", p=logistic, observed=y_shared) - trace = pm.sample( - samples, - chains=1, - return_inferencedata=False, - compute_convergence_checks=False, - random_seed=rng, - ) - - x_shared.set_value([-1, 0, 1.0]) - y_shared.set_value([0, 0, 0]) - - with model: - post_pred = pm.sample_posterior_predictive( - trace, return_inferencedata=False, var_names=["p", "obs"] - ) - - expected_p = np.array([[logistic.eval({coeff: val}) for val in trace["x"][:samples]]]) - assert post_pred["obs"].shape == (1, samples, 3) - npt.assert_allclose(post_pred["p"], expected_p) - - def test_deterministic_of_observed(self): - rng = np.random.RandomState(8442) - - meas_in_1 = pm.aesaraf.floatX(2 + 4 * rng.randn(10)) - meas_in_2 = pm.aesaraf.floatX(5 + 4 * rng.randn(10)) - nchains = 2 - with pm.Model() as model: - mu_in_1 = pm.Normal("mu_in_1", 0, 2) - sigma_in_1 = pm.HalfNormal("sd_in_1", 1) - mu_in_2 = pm.Normal("mu_in_2", 0, 2) - sigma_in_2 = pm.HalfNormal("sd__in_2", 1) - - in_1 = pm.Normal("in_1", mu_in_1, sigma_in_1, observed=meas_in_1) - in_2 = pm.Normal("in_2", mu_in_2, sigma_in_2, observed=meas_in_2) - out_diff = in_1 + in_2 - pm.Deterministic("out", out_diff) - - with aesara.config.change_flags(mode=fast_unstable_sampling_mode): - trace = pm.sample( - tune=100, - draws=100, - chains=nchains, - step=pm.Metropolis(), - return_inferencedata=False, - compute_convergence_checks=False, - random_seed=rng, - ) - - rtol = 1e-5 if aesara.config.floatX == "float64" else 1e-4 - - ppc = 
pm.sample_posterior_predictive( - return_inferencedata=False, - model=model, - trace=trace, - random_seed=0, - var_names=[var.name for var in (model.deterministics + model.basic_RVs)], - ) - - npt.assert_allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol) - - def test_deterministic_of_observed_modified_interface(self): - rng = np.random.RandomState(4982) - - meas_in_1 = pm.aesaraf.floatX(2 + 4 * rng.randn(100)) - meas_in_2 = pm.aesaraf.floatX(5 + 4 * rng.randn(100)) - with pm.Model() as model: - mu_in_1 = pm.Normal("mu_in_1", 0, 1, initval=0) - sigma_in_1 = pm.HalfNormal("sd_in_1", 1, initval=1) - mu_in_2 = pm.Normal("mu_in_2", 0, 1, initval=0) - sigma_in_2 = pm.HalfNormal("sd__in_2", 1, initval=1) - - in_1 = pm.Normal("in_1", mu_in_1, sigma_in_1, observed=meas_in_1) - in_2 = pm.Normal("in_2", mu_in_2, sigma_in_2, observed=meas_in_2) - out_diff = in_1 + in_2 - pm.Deterministic("out", out_diff) - - with aesara.config.change_flags(mode=fast_unstable_sampling_mode): - trace = pm.sample( - tune=100, - draws=100, - step=pm.Metropolis(), - return_inferencedata=False, - compute_convergence_checks=False, - random_seed=rng, - ) - varnames = [v for v in trace.varnames if v != "out"] - ppc_trace = [ - dict(zip(varnames, row)) for row in zip(*(trace.get_values(v) for v in varnames)) - ] - ppc = pm.sample_posterior_predictive( - return_inferencedata=False, - model=model, - trace=ppc_trace, - var_names=[x.name for x in (model.deterministics + model.basic_RVs)], - ) - - rtol = 1e-5 if aesara.config.floatX == "float64" else 1e-3 - npt.assert_allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol) - - def test_variable_type(self): - with pm.Model() as model: - mu = pm.HalfNormal("mu", 1) - a = pm.Normal("a", mu=mu, sigma=2, observed=np.array([1, 2])) - b = pm.Poisson("b", mu, observed=np.array([1, 2])) - with aesara.config.change_flags(mode=fast_unstable_sampling_mode): - trace = pm.sample( - tune=10, draws=10, compute_convergence_checks=False, return_inferencedata=False 
- ) - - with model: - ppc = pm.sample_posterior_predictive(trace, return_inferencedata=False) - assert ppc["a"].dtype.kind == "f" - assert ppc["b"].dtype.kind == "i" - - def test_potentials_warning(self): - warning_msg = "The effect of Potentials on other parameters is ignored during" - with pm.Model() as m: - a = pm.Normal("a", 0, 1) - p = pm.Potential("p", a + 1) - obs = pm.Normal("obs", a, 1, observed=5) - - trace = az_from_dict({"a": np.random.rand(5)}) - with m: - with pytest.warns(UserWarning, match=warning_msg): - pm.sample_posterior_predictive(trace) - - def test_idata_extension(self): - """Testing if sample_posterior_predictive() extends inferenceData""" - - with pm.Model() as model: - mu = pm.Normal("mu", 0.0, 1.0) - a = pm.Normal("a", mu=mu, sigma=1, observed=[0.0, 1.0]) - idata = pm.sample(tune=10, draws=10, compute_convergence_checks=False) - - base_test_dict = { - "posterior": ["mu", "~a"], - "sample_stats": ["diverging", "lp"], - "log_likelihood": ["a"], - "observed_data": ["a"], - } - test_dict = {"~posterior_predictive": [], "~predictions": [], **base_test_dict} - fails = check_multiple_attrs(test_dict, idata) - assert not fails - - # extending idata with in-sample ppc - with model: - pm.sample_posterior_predictive(idata, extend_inferencedata=True) - # test addition - test_dict = {"posterior_predictive": ["a"], "~predictions": [], **base_test_dict} - fails = check_multiple_attrs(test_dict, idata) - assert not fails - - # extending idata with out-of-sample ppc - with model: - pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=True) - # test addition - test_dict = {"posterior_predictive": ["a"], "predictions": ["a"], **base_test_dict} - fails = check_multiple_attrs(test_dict, idata) - assert not fails - - @pytest.mark.parametrize("multitrace", [False, True]) - def test_deterministics_out_of_idata(self, multitrace): - draws = 10 - chains = 2 - coords = {"draw": range(draws), "chain": range(chains)} - ds = xr.Dataset( - { - 
"a": xr.DataArray( - [[0] * draws] * chains, - coords=coords, - dims=["chain", "draw"], - ) - }, - coords=coords, - ) - with pm.Model() as m: - a = pm.Normal("a") - if multitrace: - straces = [] - for chain in ds.chain: - strace = pm.backends.NDArray(model=m, vars=[a]) - strace.setup(len(ds.draw), int(chain)) - strace.values = {"a": ds.a.sel(chain=chain).data} - strace.draw_idx = len(ds.draw) - straces.append(strace) - trace = MultiTrace(straces) - else: - trace = ds - - d = pm.Deterministic("d", a - 4) - pm.Normal("c", d, sigma=0.01) - ppc = pm.sample_posterior_predictive(trace, var_names="c", return_inferencedata=True) - assert np.all(np.abs(ppc.posterior_predictive.c + 4) <= 0.1) - - def test_logging_sampled_basic_rvs_prior(self, caplog): - with pm.Model() as m: - x = pm.Normal("x") - y = pm.Deterministic("y", x + 1) - z = pm.Normal("z", y, observed=0) - - with m: - pm.sample_prior_predictive(samples=1) - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [x, z]")] - caplog.clear() - - with m: - pm.sample_prior_predictive(samples=1, var_names=["x"]) - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [x]")] - caplog.clear() - - def test_logging_sampled_basic_rvs_posterior(self, caplog): - with pm.Model() as m: - x = pm.Normal("x") - x_det = pm.Deterministic("x_det", x + 1) - y = pm.Normal("y", x_det) - z = pm.Normal("z", y, observed=0) - - idata = az_from_dict(posterior={"x": np.zeros(5), "x_det": np.ones(5), "y": np.ones(5)}) - with m: - pm.sample_posterior_predictive(idata) - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [z]")] - caplog.clear() - - with m: - pm.sample_posterior_predictive(idata, var_names=["y", "z"]) - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [y, z]")] - caplog.clear() - - # Resampling `x` will force resampling of `y`, even if it is in trace - with m: - pm.sample_posterior_predictive(idata, var_names=["x", "z"]) - assert caplog.record_tuples == [("pymc", logging.INFO, 
"Sampling: [x, y, z]")] - caplog.clear() - - # Missing deterministic `x_det` does not show in the log, even if it is being - # recomputed, only `y` RV shows - idata = az_from_dict(posterior={"x": np.zeros(5)}) - with m: - pm.sample_posterior_predictive(idata) - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [y, z]")] - caplog.clear() - - # Missing deterministic `x_det` does not cause recomputation of downstream `y` RV - idata = az_from_dict(posterior={"x": np.zeros(5), "y": np.ones(5)}) - with m: - pm.sample_posterior_predictive(idata) - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [z]")] - caplog.clear() - - # Missing `x` causes sampling of downstream `y` RV, even if it is present in trace - idata = az_from_dict(posterior={"y": np.ones(5)}) - with m: - pm.sample_posterior_predictive(idata) - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [x, y, z]")] - caplog.clear() - - def test_logging_sampled_basic_rvs_posterior_deterministic(self, caplog): - with pm.Model() as m: - x = pm.Normal("x") - x_det = pm.Deterministic("x_det", x + 1) - y = pm.Normal("y", x_det) - z = pm.Normal("z", y, observed=0) - - # Explicit resampling a deterministic will lead to resampling of downstream RV `y` - # This behavior could change in the future as the posterior of `y` is still valid - idata = az_from_dict(posterior={"x": np.zeros(5), "x_det": np.ones(5), "y": np.ones(5)}) - with m: - pm.sample_posterior_predictive(idata, var_names=["x_det", "z"]) - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [y, z]")] - caplog.clear() - - @staticmethod - def make_mock_model(): - rng = np.random.default_rng(seed=42) - data = rng.normal(loc=1, scale=0.2, size=(10, 3)) - with pm.Model() as model: - model.add_coord("name", ["A", "B", "C"], mutable=True) - model.add_coord("obs", list(range(10, 20)), mutable=True) - offsets = pm.MutableData("offsets", rng.normal(0, 1, size=(10,))) - a = pm.Normal("a", mu=0, sigma=1, 
dims=["name"]) - b = pm.Normal("b", mu=offsets, sigma=1) - mu = pm.Deterministic("mu", a + b[..., None], dims=["obs", "name"]) - sigma = pm.HalfNormal("sigma", sigma=1, dims=["name"]) - - data = pm.MutableData( - "y_obs", - data, - dims=["obs", "name"], - ) - pm.Normal("y", mu=mu, sigma=sigma, observed=data, dims=["obs", "name"]) - return model - - @pytest.fixture(scope="class") - def mock_multitrace(self): - with self.make_mock_model(): - trace = pm.sample( - draws=10, - tune=10, - chains=2, - progressbar=False, - compute_convergence_checks=False, - return_inferencedata=False, - random_seed=42, - ) - return trace - - @pytest.fixture(scope="class", params=["MultiTrace", "InferenceData", "Dataset"]) - def mock_sample_results(self, request, mock_multitrace): - kind = request.param - trace = mock_multitrace - # We rebuild the class to ensure that all dimensions, data and coords start out - # the same across params values - model = self.make_mock_model() - if kind == "MultiTrace": - return kind, trace, model - else: - idata = pm.to_inference_data( - trace, - save_warmup=False, - model=model, - log_likelihood=False, - ) - if kind == "Dataset": - return kind, idata.posterior, model - else: - return kind, idata, model - - def test_logging_sampled_basic_rvs_posterior_mutable(self, mock_sample_results, caplog): - kind, samples, model = mock_sample_results - with model: - pm.sample_posterior_predictive(samples) - if kind == "MultiTrace": - # MultiTrace will only have the actual MCMC posterior samples but no information on - # the MutableData and mutable coordinate values, so it will always assume they are volatile - # and resample their descendants - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")] - caplog.clear() - elif kind == "InferenceData": - # InferenceData has all MCMC posterior samples and the values for both coordinates and - # data containers. 
This enables it to see that no data has changed and it should only - # resample the observed variable - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [y]")] - caplog.clear() - elif kind == "Dataset": - # Dataset has all MCMC posterior samples and the values of the coordinates. This - # enables it to see that the coordinates have not changed, but the MutableData is - # assumed volatile by default - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [b, y]")] - caplog.clear() - - original_offsets = model["offsets"].get_value() - with model: - # Changing the MutableData values. This will only be picked up by InferenceData - pm.set_data({"offsets": original_offsets + 1}) - pm.sample_posterior_predictive(samples) - if kind == "MultiTrace": - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")] - caplog.clear() - elif kind == "InferenceData": - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [b, y]")] - caplog.clear() - elif kind == "Dataset": - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [b, y]")] - caplog.clear() - - with model: - # Changing the mutable coordinates. This will be picked up by InferenceData and Dataset - model.set_dim("name", new_length=4, coord_values=["D", "E", "F", "G"]) - pm.set_data({"offsets": original_offsets, "y_obs": np.zeros((10, 4))}) - pm.sample_posterior_predictive(samples) - if kind == "MultiTrace": - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")] - caplog.clear() - elif kind == "InferenceData": - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, sigma, y]")] - caplog.clear() - elif kind == "Dataset": - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")] - caplog.clear() - - with model: - # Changing the mutable coordinate values, but not shape, and also changing MutableData. 
- # This will trigger resampling of all variables - model.set_dim("name", new_length=3, coord_values=["A", "B", "D"]) - pm.set_data({"offsets": original_offsets + 1, "y_obs": np.zeros((10, 3))}) - pm.sample_posterior_predictive(samples) - if kind == "MultiTrace": - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")] - caplog.clear() - elif kind == "InferenceData": - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")] - caplog.clear() - elif kind == "Dataset": - assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")] - caplog.clear() - - def check_exec_nuts_init(method): with pm.Model() as model: pm.Normal("a", mu=0, sigma=1, size=2) @@ -1244,404 +667,6 @@ def test_init_jitter(initval, jitter_max_retries, expectation): m.check_start_vals(start) -@pytest.fixture(scope="class") -def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]: - with pm.Model() as pmodel: - n = pm.Normal("n") - trace = pm.sample(return_inferencedata=False) - - with pmodel: - d = pm.Deterministic("d", n * 4) - return pmodel, trace - - -class TestSamplePriorPredictive(SeededTest): - def test_ignores_observed(self): - observed = np.random.normal(10, 1, size=200) - with pm.Model(): - # Use a prior that's way off to show we're ignoring the observed variables - observed_data = pm.MutableData("observed_data", observed) - mu = pm.Normal("mu", mu=-100, sigma=1) - positive_mu = pm.Deterministic("positive_mu", np.abs(mu)) - z = -1 - positive_mu - pm.Normal("x_obs", mu=z, sigma=1, observed=observed_data) - prior = pm.sample_prior_predictive(return_inferencedata=False) - - assert "observed_data" not in prior - assert (prior["mu"] < -90).all() - assert (prior["positive_mu"] > 90).all() - assert (prior["x_obs"] < -90).all() - assert prior["x_obs"].shape == (500, 200) - npt.assert_array_almost_equal(prior["positive_mu"], np.abs(prior["mu"]), decimal=4) - - def test_respects_shape(self): - 
for shape in (2, (2,), (10, 2), (10, 10)): - with pm.Model(): - mu = pm.Gamma("mu", 3, 1, size=1) - goals = pm.Poisson("goals", mu, size=shape) - trace1 = pm.sample_prior_predictive( - 10, return_inferencedata=False, var_names=["mu", "mu", "goals"] - ) - trace2 = pm.sample_prior_predictive( - 10, return_inferencedata=False, var_names=["mu", "goals"] - ) - if shape == 2: # want to test shape as an int - shape = (2,) - assert trace1["goals"].shape == (10,) + shape - assert trace2["goals"].shape == (10,) + shape - - def test_multivariate(self): - with pm.Model(): - m = pm.Multinomial("m", n=5, p=np.array([0.25, 0.25, 0.25, 0.25])) - trace = pm.sample_prior_predictive(10) - - assert trace.prior["m"].shape == (1, 10, 4) - - def test_multivariate2(self): - # Added test for issue #3271 - mn_data = np.random.multinomial(n=100, pvals=[1 / 6.0] * 6, size=10) - with pm.Model() as dm_model: - probs = pm.Dirichlet("probs", a=np.ones(6)) - obs = pm.Multinomial("obs", n=100, p=probs, observed=mn_data) - with aesara.config.change_flags(mode=fast_unstable_sampling_mode): - burned_trace = pm.sample( - tune=10, - draws=20, - chains=1, - return_inferencedata=False, - compute_convergence_checks=False, - ) - sim_priors = pm.sample_prior_predictive( - return_inferencedata=False, samples=20, model=dm_model - ) - sim_ppc = pm.sample_posterior_predictive( - burned_trace, return_inferencedata=False, model=dm_model - ) - assert sim_priors["probs"].shape == (20, 6) - assert sim_priors["obs"].shape == (20,) + mn_data.shape - assert sim_ppc["obs"].shape == (1, 20) + mn_data.shape - - def test_layers(self): - with pm.Model() as model: - a = pm.Uniform("a", lower=0, upper=1, size=10) - b = pm.Binomial("b", n=1, p=a, size=10) - - b_sampler = compile_pymc([], b, mode="FAST_RUN", random_seed=232093) - avg = np.stack([b_sampler() for i in range(10000)]).mean(0) - npt.assert_array_almost_equal(avg, 0.5 * np.ones((10,)), decimal=2) - - def test_transformed(self): - n = 18 - at_bats = 45 * np.ones(n, 
dtype=int) - hits = np.random.randint(1, 40, size=n, dtype=int) - draws = 50 - - with pm.Model() as model: - phi = pm.Beta("phi", alpha=1.0, beta=1.0) - - kappa_log = pm.Exponential("logkappa", lam=5.0) - kappa = pm.Deterministic("kappa", at.exp(kappa_log)) - - thetas = pm.Beta("thetas", alpha=phi * kappa, beta=(1.0 - phi) * kappa, size=n) - - y = pm.Binomial("y", n=at_bats, p=thetas, observed=hits) - gen = pm.sample_prior_predictive(draws) - - assert gen.prior["phi"].shape == (1, draws) - assert gen.prior_predictive["y"].shape == (1, draws, n) - assert "thetas" in gen.prior.data_vars - - def test_shared(self): - n1 = 10 - obs = shared(np.random.rand(n1) < 0.5) - draws = 50 - - with pm.Model() as m: - p = pm.Beta("p", 1.0, 1.0) - y = pm.Bernoulli("y", p, observed=obs) - o = pm.Deterministic("o", obs) - gen1 = pm.sample_prior_predictive(draws) - - assert gen1.prior_predictive["y"].shape == (1, draws, n1) - assert gen1.prior["o"].shape == (1, draws, n1) - - n2 = 20 - obs.set_value(np.random.rand(n2) < 0.5) - with m: - gen2 = pm.sample_prior_predictive(draws) - - assert gen2.prior_predictive["y"].shape == (1, draws, n2) - assert gen2.prior["o"].shape == (1, draws, n2) - - def test_density_dist(self): - obs = np.random.normal(-1, 0.1, size=10) - with pm.Model(): - mu = pm.Normal("mu", 0, 1) - sigma = pm.HalfNormal("sigma", 1e-6) - a = pm.DensityDist( - "a", - mu, - sigma, - random=lambda mu, sigma, rng=None, size=None: rng.normal( - loc=mu, scale=sigma, size=size - ), - observed=obs, - ) - prior = pm.sample_prior_predictive(return_inferencedata=False) - - npt.assert_almost_equal((prior["a"] - prior["mu"][..., None]).mean(), 0, decimal=3) - - def test_shape_edgecase(self): - with pm.Model(): - mu = pm.Normal("mu", size=5) - sigma = pm.Uniform("sigma", lower=2, upper=3) - x = pm.Normal("x", mu=mu, sigma=sigma, size=5) - prior = pm.sample_prior_predictive(10) - assert prior.prior["mu"].shape == (1, 10, 5) - - def test_zeroinflatedpoisson(self): - with pm.Model(): - mu = 
pm.Beta("mu", alpha=1, beta=1) - psi = pm.HalfNormal("psi", sigma=1) - pm.ZeroInflatedPoisson("suppliers", psi=psi, mu=mu, size=20) - gen_data = pm.sample_prior_predictive(samples=5000) - assert gen_data.prior["mu"].shape == (1, 5000) - assert gen_data.prior["psi"].shape == (1, 5000) - assert gen_data.prior["suppliers"].shape == (1, 5000, 20) - - def test_potentials_warning(self): - warning_msg = "The effect of Potentials on other parameters is ignored during" - with pm.Model() as m: - a = pm.Normal("a", 0, 1) - p = pm.Potential("p", a + 1) - - with m: - with pytest.warns(UserWarning, match=warning_msg): - pm.sample_prior_predictive(samples=5) - - def test_transformed_vars(self): - # Test that prior predictive returns transformation of RVs when these are - # passed explicitly in `var_names` - - def ub_interval_forward(x, ub): - # Interval transform assuming lower bound is zero - return np.log(x - 0) - np.log(ub - x) - - with pm.Model() as model: - ub = pm.HalfNormal("ub", 10) - x = pm.Uniform("x", 0, ub) - - prior = pm.sample_prior_predictive( - var_names=["ub", "ub_log__", "x", "x_interval__"], - samples=10, - random_seed=123, - ) - - # Check values are correct - assert np.allclose(prior.prior["ub_log__"].data, np.log(prior.prior["ub"].data)) - assert np.allclose( - prior.prior["x_interval__"].data, - ub_interval_forward(prior.prior["x"].data, prior.prior["ub"].data), - ) - - # Check that it works when the original RVs are not mentioned in var_names - with pm.Model() as model_transformed_only: - ub = pm.HalfNormal("ub", 10) - x = pm.Uniform("x", 0, ub) - - prior_transformed_only = pm.sample_prior_predictive( - var_names=["ub_log__", "x_interval__"], - samples=10, - random_seed=123, - ) - assert ( - "ub" not in prior_transformed_only.prior.data_vars - and "x" not in prior_transformed_only.prior.data_vars - ) - assert np.allclose( - prior.prior["ub_log__"].data, prior_transformed_only.prior["ub_log__"].data - ) - assert np.allclose( - prior.prior["x_interval__"], 
prior_transformed_only.prior["x_interval__"].data - ) - - def test_issue_4490(self): - # Test that samples do not depend on var_name order or, more fundamentally, - # that they do not depend on the set order used inside `sample_prior_predictive` - seed = 4490 - with pm.Model() as m1: - a = pm.Normal("a") - b = pm.Normal("b") - c = pm.Normal("c") - d = pm.Normal("d") - prior1 = pm.sample_prior_predictive( - samples=1, var_names=["a", "b", "c", "d"], random_seed=seed - ) - - with pm.Model() as m2: - a = pm.Normal("a") - b = pm.Normal("b") - c = pm.Normal("c") - d = pm.Normal("d") - prior2 = pm.sample_prior_predictive( - samples=1, var_names=["b", "a", "d", "c"], random_seed=seed - ) - - assert prior1.prior["a"] == prior2.prior["a"] - assert prior1.prior["b"] == prior2.prior["b"] - assert prior1.prior["c"] == prior2.prior["c"] - assert prior1.prior["d"] == prior2.prior["d"] - - def test_aesara_function_kwargs(self): - sharedvar = aesara.shared(0) - with pm.Model() as m: - x = pm.DiracDelta("x", 0) - y = pm.Deterministic("y", x + sharedvar) - - prior = pm.sample_prior_predictive( - samples=5, - return_inferencedata=False, - compile_kwargs=dict( - mode=Mode("py"), - updates={sharedvar: sharedvar + 1}, - ), - ) - - assert np.all(prior["y"] == np.arange(5)) - - -class TestSamplePosteriorPredictive: - def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture): - pmodel, trace = point_list_arg_bug_fixture - with pmodel: - pp = pm.sample_posterior_predictive( - [trace[15]], return_inferencedata=False, var_names=["d"] - ) - - def test_sample_from_xarray_prior(self, point_list_arg_bug_fixture): - pmodel, trace = point_list_arg_bug_fixture - - with pmodel: - prior = pm.sample_prior_predictive( - samples=20, - return_inferencedata=False, - ) - idat = pm.to_inference_data(trace, prior=prior) - - with pmodel: - pp = pm.sample_posterior_predictive( - idat.prior, return_inferencedata=False, var_names=["d"] - ) - - def test_sample_from_xarray_posterior(self, 
point_list_arg_bug_fixture): - pmodel, trace = point_list_arg_bug_fixture - with pmodel: - idat = pm.to_inference_data(trace) - pp = pm.sample_posterior_predictive(idat.posterior, var_names=["d"]) - - def test_aesara_function_kwargs(self): - sharedvar = aesara.shared(0) - with pm.Model() as m: - x = pm.DiracDelta("x", 0.0) - y = pm.Deterministic("y", x + sharedvar) - - pp = pm.sample_posterior_predictive( - trace=az_from_dict({"x": np.arange(5)}), - var_names=["y"], - return_inferencedata=False, - compile_kwargs=dict( - mode=Mode("py"), - updates={sharedvar: sharedvar + 1}, - ), - ) - - assert np.all(pp["y"] == np.arange(5) * 2) - - def test_sample_dims(self, point_list_arg_bug_fixture): - pmodel, trace = point_list_arg_bug_fixture - with pmodel: - post = pm.to_inference_data(trace).posterior.stack(sample=["chain", "draw"]) - pp = pm.sample_posterior_predictive(post, var_names=["d"], sample_dims=["sample"]) - assert "sample" in pp.posterior_predictive - assert len(pp.posterior_predictive["sample"]) == len(post["sample"]) - post = post.expand_dims(pred_id=5) - pp = pm.sample_posterior_predictive( - post, var_names=["d"], sample_dims=["sample", "pred_id"] - ) - assert "sample" in pp.posterior_predictive - assert "pred_id" in pp.posterior_predictive - assert len(pp.posterior_predictive["sample"]) == len(post["sample"]) - assert len(pp.posterior_predictive["pred_id"]) == 5 - - -class TestDraw(SeededTest): - def test_univariate(self): - with pm.Model(): - x = pm.Normal("x") - - x_draws = pm.draw(x) - assert x_draws.shape == () - - (x_draws,) = pm.draw([x]) - assert x_draws.shape == () - - x_draws = pm.draw(x, draws=10) - assert x_draws.shape == (10,) - - (x_draws,) = pm.draw([x], draws=10) - assert x_draws.shape == (10,) - - def test_multivariate(self): - with pm.Model(): - mln = pm.Multinomial("mln", n=5, p=np.array([0.25, 0.25, 0.25, 0.25])) - - mln_draws = pm.draw(mln, draws=1) - assert mln_draws.shape == (4,) - - (mln_draws,) = pm.draw([mln], draws=1) - assert 
mln_draws.shape == (4,) - - mln_draws = pm.draw(mln, draws=10) - assert mln_draws.shape == (10, 4) - - (mln_draws,) = pm.draw([mln], draws=10) - assert mln_draws.shape == (10, 4) - - def test_multiple_variables(self): - with pm.Model(): - x = pm.Normal("x") - y = pm.Normal("y", shape=10) - z = pm.Uniform("z", shape=5) - w = pm.Dirichlet("w", a=[1, 1, 1]) - - num_draws = 100 - draws = pm.draw((x, y, z, w), draws=num_draws) - assert draws[0].shape == (num_draws,) - assert draws[1].shape == (num_draws, 10) - assert draws[2].shape == (num_draws, 5) - assert draws[3].shape == (num_draws, 3) - - def test_draw_different_samples(self): - with pm.Model(): - x = pm.Normal("x") - - x_draws_1 = pm.draw(x, 100) - x_draws_2 = pm.draw(x, 100) - assert not np.all(np.isclose(x_draws_1, x_draws_2)) - - def test_draw_aesara_function_kwargs(self): - sharedvar = aesara.shared(0) - x = pm.DiracDelta.dist(0.0) - y = x + sharedvar - draws = pm.draw( - y, - draws=5, - mode=Mode("py"), - updates={sharedvar: sharedvar + 1}, - ) - assert np.all(draws == np.arange(5)) - - def test_step_args(): with pm.Model() as model: a = pm.Normal("a") @@ -1804,679 +829,6 @@ def test_no_init_nuts_compound(caplog): assert "Initializing NUTS" not in caplog.text -class TestCompileForwardSampler: - @staticmethod - def get_function_roots(function): - return [ - var - for var in aesara.graph.basic.graph_inputs(function.maker.fgraph.outputs) - if var.name - ] - - @staticmethod - def get_function_inputs(function): - return {i for i in function.maker.fgraph.inputs if not isinstance(i, SharedVariable)} - - def test_linear_model(self): - with pm.Model() as model: - x = pm.MutableData("x", np.linspace(0, 1, 10)) - y = pm.MutableData("y", np.ones(10)) - - alpha = pm.Normal("alpha", 0, 0.1) - beta = pm.Normal("beta", 0, 0.1) - mu = pm.Deterministic("mu", alpha + beta * x) - sigma = pm.HalfNormal("sigma", 0.1) - obs = pm.Normal("obs", mu, sigma, observed=y, shape=x.shape) - - f, volatile_rvs = 
compile_forward_sampling_function( - [obs], - vars_in_trace=[alpha, beta, sigma, mu], - basic_rvs=model.basic_RVs, - ) - assert volatile_rvs == {obs} - assert {i.name for i in self.get_function_inputs(f)} == {"alpha", "beta", "sigma"} - assert {i.name for i in self.get_function_roots(f)} == {"x", "alpha", "beta", "sigma"} - - with pm.Model() as model: - x = pm.ConstantData("x", np.linspace(0, 1, 10)) - y = pm.MutableData("y", np.ones(10)) - - alpha = pm.Normal("alpha", 0, 0.1) - beta = pm.Normal("beta", 0, 0.1) - mu = pm.Deterministic("mu", alpha + beta * x) - sigma = pm.HalfNormal("sigma", 0.1) - obs = pm.Normal("obs", mu, sigma, observed=y, shape=x.shape) - - f, volatile_rvs = compile_forward_sampling_function( - [obs], - vars_in_trace=[alpha, beta, sigma, mu], - basic_rvs=model.basic_RVs, - ) - assert volatile_rvs == {obs} - assert {i.name for i in self.get_function_inputs(f)} == {"alpha", "beta", "sigma", "mu"} - assert {i.name for i in self.get_function_roots(f)} == {"mu", "sigma"} - - def test_nested_observed_model(self): - with pm.Model() as model: - p = pm.ConstantData("p", np.array([0.25, 0.5, 0.25])) - x = pm.MutableData("x", np.zeros(10)) - y = pm.MutableData("y", np.ones(10)) - - category = pm.Categorical("category", p, observed=x) - beta = pm.Normal("beta", 0, 0.1, size=p.shape) - mu = pm.Deterministic("mu", beta[category]) - sigma = pm.HalfNormal("sigma", 0.1) - obs = pm.Normal("obs", mu, sigma, observed=y, shape=mu.shape) - - f, volatile_rvs = compile_forward_sampling_function( - outputs=model.observed_RVs, - vars_in_trace=[beta, mu, sigma], - basic_rvs=model.basic_RVs, - ) - assert volatile_rvs == {category, obs} - assert {i.name for i in self.get_function_inputs(f)} == {"beta", "sigma"} - assert {i.name for i in self.get_function_roots(f)} == {"x", "p", "beta", "sigma"} - - f, volatile_rvs = compile_forward_sampling_function( - outputs=model.observed_RVs, - vars_in_trace=[beta, mu, sigma], - basic_rvs=model.basic_RVs, - givens_dict={category: 
np.zeros(10, dtype=category.dtype)}, - ) - assert volatile_rvs == {obs} - assert {i.name for i in self.get_function_inputs(f)} == {"beta", "sigma"} - assert {i.name for i in self.get_function_roots(f)} == { - "x", - "p", - "category", - "beta", - "sigma", - } - - def test_volatile_parameters(self): - with pm.Model() as model: - y = pm.MutableData("y", np.ones(10)) - mu = pm.Normal("mu", 0, 1) - nested_mu = pm.Normal("nested_mu", mu, 1, size=10) - sigma = pm.HalfNormal("sigma", 1) - obs = pm.Normal("obs", nested_mu, sigma, observed=y, shape=nested_mu.shape) - - f, volatile_rvs = compile_forward_sampling_function( - outputs=model.observed_RVs, - vars_in_trace=[nested_mu, sigma], # mu isn't in the trace and will be deemed volatile - basic_rvs=model.basic_RVs, - ) - assert volatile_rvs == {mu, nested_mu, obs} - assert {i.name for i in self.get_function_inputs(f)} == {"sigma"} - assert {i.name for i in self.get_function_roots(f)} == {"sigma"} - - f, volatile_rvs = compile_forward_sampling_function( - outputs=model.observed_RVs, - vars_in_trace=[mu, nested_mu, sigma], - basic_rvs=model.basic_RVs, - givens_dict={ - mu: np.array(1.0) - }, # mu will be considered volatile because it's in givens - ) - assert volatile_rvs == {nested_mu, obs} - assert {i.name for i in self.get_function_inputs(f)} == {"sigma"} - assert {i.name for i in self.get_function_roots(f)} == {"mu", "sigma"} - - def test_mixture(self): - with pm.Model() as model: - w = pm.Dirichlet("w", a=np.ones(3), size=(5, 3)) - - mu = pm.Normal("mu", mu=np.arange(3), sigma=1) - - components = pm.Normal.dist(mu=mu, sigma=1, size=w.shape) - mix_mu = pm.Mixture("mix_mu", w=w, comp_dists=components) - obs = pm.Normal("obs", mix_mu, 1, observed=np.ones((5, 3))) - - f, volatile_rvs = compile_forward_sampling_function( - outputs=[obs], - vars_in_trace=[mix_mu, mu, w], - basic_rvs=model.basic_RVs, - ) - assert volatile_rvs == {obs} - assert {i.name for i in self.get_function_inputs(f)} == {"w", "mu", "mix_mu"} - assert 
{i.name for i in self.get_function_roots(f)} == {"mix_mu"} - - f, volatile_rvs = compile_forward_sampling_function( - outputs=[obs], - vars_in_trace=[mu, w], - basic_rvs=model.basic_RVs, - ) - assert volatile_rvs == {mix_mu, obs} - assert {i.name for i in self.get_function_inputs(f)} == {"w", "mu"} - assert {i.name for i in self.get_function_roots(f)} == {"w", "mu"} - - f, volatile_rvs = compile_forward_sampling_function( - outputs=[obs], - vars_in_trace=[mix_mu, mu], - basic_rvs=model.basic_RVs, - ) - assert volatile_rvs == {w, mix_mu, obs} - assert {i.name for i in self.get_function_inputs(f)} == {"mu"} - assert {i.name for i in self.get_function_roots(f)} == {"mu"} - - def test_censored(self): - with pm.Model() as model: - latent_mu = pm.Normal("latent_mu", mu=np.arange(3), sigma=1) - mu = pm.Censored("mu", pm.Normal.dist(mu=latent_mu, sigma=1), lower=-1, upper=1) - obs = pm.Normal("obs", mu, 1, observed=np.ones((10, 3))) - - f, volatile_rvs = compile_forward_sampling_function( - outputs=[obs], - vars_in_trace=[latent_mu, mu], - basic_rvs=model.basic_RVs, - ) - assert volatile_rvs == {obs} - assert {i.name for i in self.get_function_inputs(f)} == {"latent_mu", "mu"} - assert {i.name for i in self.get_function_roots(f)} == {"mu"} - - f, volatile_rvs = compile_forward_sampling_function( - outputs=[obs], - vars_in_trace=[mu], - basic_rvs=model.basic_RVs, - ) - assert volatile_rvs == {latent_mu, mu, obs} - assert {i.name for i in self.get_function_inputs(f)} == set() - assert {i.name for i in self.get_function_roots(f)} == set() - - def test_lkj_cholesky_cov(self): - with pm.Model() as model: - mu = np.zeros(3) - sd_dist = pm.Exponential.dist(1.0, size=3) - chol, corr, stds = pm.LKJCholeskyCov( # pylint: disable=unpacking-non-sequence - "chol_packed", n=3, eta=2, sd_dist=sd_dist, compute_corr=True - ) - chol_packed = model["chol_packed"] - chol = pm.Deterministic("chol", chol) - obs = pm.MvNormal("obs", mu=mu, chol=chol, observed=np.zeros(3)) - - f, volatile_rvs = 
compile_forward_sampling_function( - outputs=[obs], - vars_in_trace=[chol_packed, chol], - basic_rvs=model.basic_RVs, - ) - assert volatile_rvs == {obs} - assert {i.name for i in self.get_function_inputs(f)} == {"chol_packed", "chol"} - assert {i.name for i in self.get_function_roots(f)} == {"chol"} - - f, volatile_rvs = compile_forward_sampling_function( - outputs=[obs], - vars_in_trace=[chol_packed], - basic_rvs=model.basic_RVs, - ) - assert volatile_rvs == {obs} - assert {i.name for i in self.get_function_inputs(f)} == {"chol_packed"} - assert {i.name for i in self.get_function_roots(f)} == {"chol_packed"} - - f, volatile_rvs = compile_forward_sampling_function( - outputs=[obs], - vars_in_trace=[chol], - basic_rvs=model.basic_RVs, - ) - assert volatile_rvs == {chol_packed, obs} - assert {i.name for i in self.get_function_inputs(f)} == set() - assert {i.name for i in self.get_function_roots(f)} == set() - - def test_non_random_model_variable(self): - with pm.Model() as model: - # A user may register non-pure RandomVariables that can nevertheless be - # sampled, as long as a custom logprob is dispatched or Aeppl can infer - # its logprob (which is the case for `clip`) - y = at.clip(pm.Normal.dist(), -1, 1) - y = model.register_rv(y, name="y") - y_abs = pm.Deterministic("y_abs", at.abs(y)) - obs = pm.Normal("obs", y_abs, observed=np.zeros(10)) - - # y_abs should be resampled even if in the trace, because the source y is missing - f, volatile_rvs = compile_forward_sampling_function( - outputs=[obs], - vars_in_trace=[y_abs], - basic_rvs=model.basic_RVs, - ) - assert volatile_rvs == {y, obs} - assert {i.name for i in self.get_function_inputs(f)} == set() - assert {i.name for i in self.get_function_roots(f)} == set() - - def test_mutable_coords_volatile(self): - rng = np.random.default_rng(seed=42) - data = rng.normal(loc=1, scale=0.2, size=(10, 3)) - with pm.Model() as model: - model.add_coord("name", ["A", "B", "C"], mutable=True) - model.add_coord("obs", 
list(range(10, 20)), mutable=True) - offsets = pm.MutableData("offsets", rng.normal(0, 1, size=(10,))) - a = pm.Normal("a", mu=0, sigma=1, dims=["name"]) - b = pm.Normal("b", mu=offsets, sigma=1) - mu = pm.Deterministic("mu", a + b[..., None], dims=["obs", "name"]) - sigma = pm.HalfNormal("sigma", sigma=1, dims=["name"]) - - data = pm.MutableData( - "y_obs", - data, - dims=["obs", "name"], - ) - y = pm.Normal("y", mu=mu, sigma=sigma, observed=data, dims=["obs", "name"]) - - # When no constant_data and constant_coords, all the dependent nodes will be volatile and - # resampled - f, volatile_rvs = compile_forward_sampling_function( - outputs=[y], - vars_in_trace=[a, b, mu, sigma], - basic_rvs=model.basic_RVs, - ) - assert volatile_rvs == {y, a, b, sigma} - assert {i.name for i in self.get_function_inputs(f)} == set() - assert {i.name for i in self.get_function_roots(f)} == {"name", "obs", "offsets"} - - # When the constant data has the same values as the shared data, offsets wont be volatile - f, volatile_rvs = compile_forward_sampling_function( - outputs=[y], - vars_in_trace=[a, b, mu, sigma], - basic_rvs=model.basic_RVs, - constant_data={"offsets": offsets.get_value()}, - ) - assert volatile_rvs == {y, a, sigma} - assert {i.name for i in self.get_function_inputs(f)} == {"b"} - assert {i.name for i in self.get_function_roots(f)} == {"b", "name", "obs"} - - # When we declare constant_coords, the shared variables with matching names wont be volatile - f, volatile_rvs = compile_forward_sampling_function( - outputs=[y], - vars_in_trace=[a, b, mu, sigma], - basic_rvs=model.basic_RVs, - constant_coords={"name", "obs"}, - ) - assert volatile_rvs == {y, b} - assert {i.name for i in self.get_function_inputs(f)} == {"a", "sigma"} - assert {i.name for i in self.get_function_roots(f)} == { - "a", - "sigma", - "name", - "obs", - "offsets", - } - - # When we have both constant_data and constant_coords, only y will be volatile - f, volatile_rvs = compile_forward_sampling_function( 
- outputs=[y], - vars_in_trace=[a, b, mu, sigma], - basic_rvs=model.basic_RVs, - constant_data={"offsets": offsets.get_value()}, - constant_coords={"name", "obs"}, - ) - assert volatile_rvs == {y} - assert {i.name for i in self.get_function_inputs(f)} == {"a", "b", "mu", "sigma"} - assert {i.name for i in self.get_function_roots(f)} == {"mu", "sigma", "name", "obs"} - - # When constant_data has different values than the shared variable, then - # offsets will be volatile - f, volatile_rvs = compile_forward_sampling_function( - outputs=[y], - vars_in_trace=[a, b, mu, sigma], - basic_rvs=model.basic_RVs, - constant_data={"offsets": offsets.get_value() + 1}, - constant_coords={"name", "obs"}, - ) - assert volatile_rvs == {y, b} - assert {i.name for i in self.get_function_inputs(f)} == {"a", "sigma"} - assert {i.name for i in self.get_function_roots(f)} == { - "a", - "sigma", - "name", - "obs", - "offsets", - } - - -def test_get_seeds_per_chain(): - ret = _get_seeds_per_chain(None, chains=1) - assert len(ret) == 1 and isinstance(ret[0], int) - - ret = _get_seeds_per_chain(None, chains=2) - assert len(ret) == 2 and isinstance(ret[0], int) - - ret = _get_seeds_per_chain(5, chains=1) - assert ret == (5,) - - ret = _get_seeds_per_chain(5, chains=3) - assert len(ret) == 3 and isinstance(ret[0], int) and not any(r == 5 for r in ret) - - rng = np.random.default_rng(123) - expected_ret = rng.integers(2**30, dtype=np.int64, size=1) - rng = np.random.default_rng(123) - ret = _get_seeds_per_chain(rng, chains=1) - assert ret == expected_ret - - rng = np.random.RandomState(456) - expected_ret = rng.randint(2**30, dtype=np.int64, size=2) - rng = np.random.RandomState(456) - ret = _get_seeds_per_chain(rng, chains=2) - assert np.all(ret == expected_ret) - - for expected_ret in ([0, 1, 2], (0, 1, 2, 3), np.arange(5)): - ret = _get_seeds_per_chain(expected_ret, chains=len(expected_ret)) - assert ret is expected_ret - - with pytest.raises(ValueError, match="does not match the number of 
chains"): - _get_seeds_per_chain(expected_ret, chains=len(expected_ret) + 1) - - with pytest.raises(ValueError, match=re.escape("The `seeds` must be array-like")): - _get_seeds_per_chain({1: 1, 2: 2}, 2) - - -def test_distinct_rvs(): - """Make sure `RandomVariable`s generated using a `Model`'s default RNG state all have distinct states.""" - - with pm.Model() as model: - X_rv = pm.Normal("x") - Y_rv = pm.Normal("y") - - pp_samples = pm.sample_prior_predictive( - samples=2, return_inferencedata=False, random_seed=npr.RandomState(2023532) - ) - - assert X_rv.owner.inputs[0] != Y_rv.owner.inputs[0] - - with pm.Model(): - X_rv = pm.Normal("x") - Y_rv = pm.Normal("y") - - pp_samples_2 = pm.sample_prior_predictive( - samples=2, return_inferencedata=False, random_seed=npr.RandomState(2023532) - ) - - assert np.array_equal(pp_samples["y"], pp_samples_2["y"]) - - -class TestNestedRandom(SeededTest): - def build_model(self, distribution, shape, nested_rvs_info): - with pm.Model() as model: - nested_rvs = {} - for rv_name, info in nested_rvs_info.items(): - try: - value, nested_shape = info - loc = 0.0 - except ValueError: - value, nested_shape, loc = info - if value is None: - nested_rvs[rv_name] = pm.Uniform( - rv_name, - 0 + loc, - 1 + loc, - shape=nested_shape, - ) - else: - nested_rvs[rv_name] = value * np.ones(nested_shape) - rv = distribution( - "target", - shape=shape, - **nested_rvs, - ) - return model, rv, nested_rvs - - def sample_prior(self, distribution, shape, nested_rvs_info, prior_samples): - model, rv, nested_rvs = self.build_model( - distribution, - shape, - nested_rvs_info, - ) - with model: - return pm.sample_prior_predictive(prior_samples, return_inferencedata=False) - - @pytest.mark.parametrize( - ["prior_samples", "shape", "mu", "alpha"], - [ - [10, (3,), (None, tuple()), (None, (3,))], - [10, (3,), (None, (3,)), (None, tuple())], - [ - 10, - ( - 4, - 3, - ), - (None, (3,)), - (None, (3,)), - ], - [ - 10, - ( - 4, - 3, - ), - (None, (3,)), - (None, (4, 
3)), - ], - ], - ids=str, - ) - def test_NegativeBinomial( - self, - prior_samples, - shape, - mu, - alpha, - ): - prior = self.sample_prior( - distribution=pm.NegativeBinomial, - shape=shape, - nested_rvs_info=dict(mu=mu, alpha=alpha), - prior_samples=prior_samples, - ) - assert prior["target"].shape == (prior_samples,) + shape - - @pytest.mark.parametrize( - ["prior_samples", "shape", "psi", "mu", "alpha"], - [ - [10, (3,), (0.5, tuple()), (None, tuple()), (None, (3,))], - [10, (3,), (0.5, (3,)), (None, tuple()), (None, (3,))], - [10, (3,), (0.5, tuple()), (None, (3,)), (None, tuple())], - [10, (3,), (0.5, (3,)), (None, (3,)), (None, tuple())], - [ - 10, - ( - 4, - 3, - ), - (0.5, (3,)), - (None, (3,)), - (None, (3,)), - ], - [ - 10, - ( - 4, - 3, - ), - (0.5, (3,)), - (None, (3,)), - (None, (4, 3)), - ], - ], - ids=str, - ) - def test_ZeroInflatedNegativeBinomial( - self, - prior_samples, - shape, - psi, - mu, - alpha, - ): - prior = self.sample_prior( - distribution=pm.ZeroInflatedNegativeBinomial, - shape=shape, - nested_rvs_info=dict(psi=psi, mu=mu, alpha=alpha), - prior_samples=prior_samples, - ) - assert prior["target"].shape == (prior_samples,) + shape - - @pytest.mark.parametrize( - ["prior_samples", "shape", "nu", "sigma"], - [ - [10, (3,), (None, tuple()), (None, (3,))], - [10, (3,), (None, tuple()), (None, (3,))], - [10, (3,), (None, (3,)), (None, tuple())], - [10, (3,), (None, (3,)), (None, tuple())], - [ - 10, - ( - 4, - 3, - ), - (None, (3,)), - (None, (3,)), - ], - [ - 10, - ( - 4, - 3, - ), - (None, (3,)), - (None, (4, 3)), - ], - ], - ids=str, - ) - def test_Rice( - self, - prior_samples, - shape, - nu, - sigma, - ): - prior = self.sample_prior( - distribution=pm.Rice, - shape=shape, - nested_rvs_info=dict(nu=nu, sigma=sigma), - prior_samples=prior_samples, - ) - assert prior["target"].shape == (prior_samples,) + shape - - @pytest.mark.parametrize( - ["prior_samples", "shape", "mu", "sigma", "lower", "upper"], - [ - [10, (3,), (None, tuple()), 
(1.0, tuple()), (None, tuple(), -1), (None, (3,))], - [10, (3,), (None, tuple()), (1.0, tuple()), (None, tuple(), -1), (None, (3,))], - [10, (3,), (None, tuple()), (1.0, tuple()), (None, (3,), -1), (None, tuple())], - [10, (3,), (None, tuple()), (1.0, tuple()), (None, (3,), -1), (None, tuple())], - [ - 10, - ( - 4, - 3, - ), - (None, (3,)), - (1.0, tuple()), - (None, (3,), -1), - (None, (3,)), - ], - [ - 10, - ( - 4, - 3, - ), - (None, (3,)), - (1.0, tuple()), - (None, (3,), -1), - (None, (4, 3)), - ], - [10, (3,), (0.0, tuple()), (None, tuple()), (None, tuple(), -1), (None, (3,))], - [10, (3,), (0.0, tuple()), (None, tuple()), (None, tuple(), -1), (None, (3,))], - [10, (3,), (0.0, tuple()), (None, tuple()), (None, (3,), -1), (None, tuple())], - [10, (3,), (0.0, tuple()), (None, tuple()), (None, (3,), -1), (None, tuple())], - [ - 10, - ( - 4, - 3, - ), - (0.0, tuple()), - (None, (3,)), - (None, (3,), -1), - (None, (3,)), - ], - [ - 10, - ( - 4, - 3, - ), - (0.0, tuple()), - (None, (3,)), - (None, (3,), -1), - (None, (4, 3)), - ], - ], - ids=str, - ) - def test_TruncatedNormal( - self, - prior_samples, - shape, - mu, - sigma, - lower, - upper, - ): - prior = self.sample_prior( - distribution=pm.TruncatedNormal, - shape=shape, - nested_rvs_info=dict(mu=mu, sigma=sigma, lower=lower, upper=upper), - prior_samples=prior_samples, - ) - assert prior["target"].shape == (prior_samples,) + shape - - @pytest.mark.parametrize( - ["prior_samples", "shape", "c", "lower", "upper"], - [ - [10, (3,), (None, tuple()), (-1.0, (3,)), (2, tuple())], - [10, (3,), (None, tuple()), (-1.0, tuple()), (None, tuple(), 1)], - [10, (3,), (None, (3,)), (-1.0, tuple()), (None, tuple(), 1)], - [ - 10, - ( - 4, - 3, - ), - (None, (3,)), - (-1.0, tuple()), - (None, (3,), 1), - ], - [ - 10, - ( - 4, - 3, - ), - (None, (3,)), - (None, tuple(), -1), - (None, (3,), 1), - ], - ], - ids=str, - ) - def test_Triangular( - self, - prior_samples, - shape, - c, - lower, - upper, - ): - prior = 
self.sample_prior( - distribution=pm.Triangular, - shape=shape, - nested_rvs_info=dict(c=c, lower=lower, upper=upper), - prior_samples=prior_samples, - ) - assert prior["target"].shape == (prior_samples,) + shape - - class TestAssignStepMethods: def test_bernoulli(self): """Test bernoulli distribution is assigned binary gibbs metropolis method""" @@ -2633,24 +985,3 @@ def test_sample(self): np.testing.assert_allclose( x_pred, pp_trace1.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1 ) - - -def test_get_vars_in_point_list(): - with pm.Model() as modelA: - pm.Normal("a", 0, 1) - pm.Normal("b", 0, 1) - with pm.Model() as modelB: - a = pm.Normal("a", 0, 1) - pm.Normal("c", 0, 1) - - point_list = [{"a": 0, "b": 0}] - vars_in_trace = get_vars_in_point_list(point_list, modelB) - assert set(vars_in_trace) == {a} - - strace = pm.backends.NDArray(model=modelB, vars=modelA.free_RVs) - strace.setup(1, 1) - strace.values = point_list[0] - strace.draw_idx = 1 - trace = MultiTrace([strace]) - vars_in_trace = get_vars_in_point_list(trace, modelB) - assert set(vars_in_trace) == {a} diff --git a/pymc/tests/test_sampling_predictive.py b/pymc/tests/test_sampling_predictive.py new file mode 100644 index 00000000000..f0edd10c283 --- /dev/null +++ b/pymc/tests/test_sampling_predictive.py @@ -0,0 +1,1240 @@ +# Copyright 2022 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +import warnings + +from typing import Tuple + +import aesara +import aesara.tensor as at +import numpy as np +import numpy.random as npr +import numpy.testing as npt +import pytest +import xarray as xr + +from aesara import Mode, shared +from arviz import InferenceData +from arviz import from_dict as az_from_dict +from arviz.tests.helpers import check_multiple_attrs +from scipy import stats + +import pymc as pm + +from pymc.aesaraf import compile_pymc +from pymc.backends.base import MultiTrace +from pymc.tests.helpers import SeededTest, fast_unstable_sampling_mode + + +class TestSamplePPC(SeededTest): + def test_normal_scalar(self): + nchains = 2 + ndraws = 500 + with pm.Model() as model: + mu = pm.Normal("mu", 0.0, 1.0) + a = pm.Normal("a", mu=mu, sigma=1, observed=0.0) + trace = pm.sample( + draws=ndraws, + chains=nchains, + ) + + with model: + # test list input + ppc0 = pm.sample_posterior_predictive( + 10 * [model.initial_point()], return_inferencedata=False + ) + assert "a" in ppc0 + assert len(ppc0["a"][0]) == 10 + # test empty ppc + ppc = pm.sample_posterior_predictive(trace, var_names=[], return_inferencedata=False) + assert len(ppc) == 0 + + # test keep_size parameter + ppc = pm.sample_posterior_predictive(trace, return_inferencedata=False) + assert ppc["a"].shape == (nchains, ndraws) + + # test default case + random_state = self.get_random_state() + idata_ppc = pm.sample_posterior_predictive( + trace, var_names=["a"], random_seed=random_state + ) + ppc = idata_ppc.posterior_predictive + assert "a" in ppc + assert ppc["a"].shape == (nchains, ndraws) + # mu's standard deviation may have changed thanks to a's observed + _, pval = stats.kstest( + (ppc["a"] - trace.posterior["mu"]).values.flatten(), stats.norm(loc=0, scale=1).cdf + ) + assert pval > 0.001 + + def test_normal_scalar_idata(self): + nchains = 2 + ndraws = 500 + with pm.Model() as model: + mu = pm.Normal("mu", 0.0, 1.0) + a = pm.Normal("a", mu=mu, sigma=1, observed=0.0) + with 
warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Tuning samples will be included.*", UserWarning) + trace = pm.sample( + draws=ndraws, + chains=nchains, + return_inferencedata=False, + discard_tuned_samples=False, + ) + + assert not isinstance(trace, InferenceData) + + with model: + # test keep_size parameter and idata input + idata = pm.to_inference_data(trace) + assert isinstance(idata, InferenceData) + + ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False) + assert ppc["a"].shape == (nchains, ndraws) + + def test_normal_vector(self): + with pm.Model() as model: + mu = pm.Normal("mu", 0.0, 1.0) + a = pm.Normal("a", mu=mu, sigma=1, observed=np.array([0.5, 0.2])) + trace = pm.sample(return_inferencedata=False, draws=12, chains=1) + + with model: + # test list input + ppc0 = pm.sample_posterior_predictive( + 10 * [model.initial_point()], + return_inferencedata=False, + ) + ppc = pm.sample_posterior_predictive(trace, return_inferencedata=False, var_names=[]) + assert len(ppc) == 0 + + ppc = pm.sample_posterior_predictive(trace, return_inferencedata=False) + assert ppc["a"].shape == (trace.nchains, len(trace), 2) + assert ppc0["a"].shape == (1, 10, 2) + + def test_normal_vector_idata(self): + with pm.Model() as model: + mu = pm.Normal("mu", 0.0, 1.0) + a = pm.Normal("a", mu=mu, sigma=1, observed=np.array([0.5, 0.2])) + trace = pm.sample(return_inferencedata=False) + + assert not isinstance(trace, InferenceData) + + with model: + # test keep_size parameter with inference data as input... 
+ idata = pm.to_inference_data(trace) + assert isinstance(idata, InferenceData) + + ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False) + assert ppc["a"].shape == (trace.nchains, len(trace), 2) + + def test_exceptions(self): + with pm.Model() as model: + mu = pm.Normal("mu", 0.0, 1.0) + a = pm.Normal("a", mu=mu, sigma=1, observed=np.array([0.5, 0.2])) + idata = pm.sample(idata_kwargs={"log_likelihood": False}) + + with model: + # test wrong type argument + bad_trace = {"mu": stats.norm.rvs(size=1000)} + with pytest.raises(TypeError, match="type for `trace`"): + ppc = pm.sample_posterior_predictive(bad_trace) + + def test_sum_normal(self): + with pm.Model() as model: + a = pm.Normal("a", sigma=0.2) + b = pm.Normal("b", mu=a) + idata = pm.sample(draws=1000, chains=1) + + with model: + # test list input + ppc0 = pm.sample_posterior_predictive( + 10 * [model.initial_point()], return_inferencedata=False + ) + assert ppc0 == {} + ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False, var_names=["b"]) + assert len(ppc) == 1 + assert ppc["b"].shape == ( + 1, + 1000, + ) + scale = np.sqrt(1 + 0.2**2) + _, pval = stats.kstest(ppc["b"].flatten(), stats.norm(scale=scale).cdf) + assert pval > 0.001 + + def test_model_not_drawable_prior(self): + data = np.random.poisson(lam=10, size=200) + model = pm.Model() + with model: + mu = pm.HalfFlat("sigma") + pm.Poisson("foo", mu=mu, observed=data) + with aesara.config.change_flags(mode=fast_unstable_sampling_mode): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) + idata = pm.sample(tune=10, draws=40, chains=1) + + with model: + with pytest.raises(NotImplementedError) as excinfo: + pm.sample_prior_predictive(50) + assert "Cannot sample" in str(excinfo.value) + samples = pm.sample_posterior_predictive(idata, return_inferencedata=False) + assert samples["foo"].shape == (1, 40, 200) + + def test_model_shared_variable(self): + rng = 
np.random.RandomState(9832) + + x = rng.randn(100) + y = x > 0 + x_shared = aesara.shared(x) + y_shared = aesara.shared(y) + samples = 100 + with pm.Model() as model: + coeff = pm.Normal("x", mu=0, sigma=1) + logistic = pm.Deterministic("p", pm.math.sigmoid(coeff * x_shared)) + + obs = pm.Bernoulli("obs", p=logistic, observed=y_shared) + trace = pm.sample( + samples, + chains=1, + return_inferencedata=False, + compute_convergence_checks=False, + random_seed=rng, + ) + + x_shared.set_value([-1, 0, 1.0]) + y_shared.set_value([0, 0, 0]) + + with model: + post_pred = pm.sample_posterior_predictive( + trace, return_inferencedata=False, var_names=["p", "obs"] + ) + + expected_p = np.array([[logistic.eval({coeff: val}) for val in trace["x"][:samples]]]) + assert post_pred["obs"].shape == (1, samples, 3) + npt.assert_allclose(post_pred["p"], expected_p) + + def test_deterministic_of_observed(self): + rng = np.random.RandomState(8442) + + meas_in_1 = pm.aesaraf.floatX(2 + 4 * rng.randn(10)) + meas_in_2 = pm.aesaraf.floatX(5 + 4 * rng.randn(10)) + nchains = 2 + with pm.Model() as model: + mu_in_1 = pm.Normal("mu_in_1", 0, 2) + sigma_in_1 = pm.HalfNormal("sd_in_1", 1) + mu_in_2 = pm.Normal("mu_in_2", 0, 2) + sigma_in_2 = pm.HalfNormal("sd__in_2", 1) + + in_1 = pm.Normal("in_1", mu_in_1, sigma_in_1, observed=meas_in_1) + in_2 = pm.Normal("in_2", mu_in_2, sigma_in_2, observed=meas_in_2) + out_diff = in_1 + in_2 + pm.Deterministic("out", out_diff) + + with aesara.config.change_flags(mode=fast_unstable_sampling_mode): + trace = pm.sample( + tune=100, + draws=100, + chains=nchains, + step=pm.Metropolis(), + return_inferencedata=False, + compute_convergence_checks=False, + random_seed=rng, + ) + + rtol = 1e-5 if aesara.config.floatX == "float64" else 1e-4 + + ppc = pm.sample_posterior_predictive( + return_inferencedata=False, + model=model, + trace=trace, + random_seed=0, + var_names=[var.name for var in (model.deterministics + model.basic_RVs)], + ) + + 
npt.assert_allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol) + + def test_deterministic_of_observed_modified_interface(self): + rng = np.random.RandomState(4982) + + meas_in_1 = pm.aesaraf.floatX(2 + 4 * rng.randn(100)) + meas_in_2 = pm.aesaraf.floatX(5 + 4 * rng.randn(100)) + with pm.Model() as model: + mu_in_1 = pm.Normal("mu_in_1", 0, 1, initval=0) + sigma_in_1 = pm.HalfNormal("sd_in_1", 1, initval=1) + mu_in_2 = pm.Normal("mu_in_2", 0, 1, initval=0) + sigma_in_2 = pm.HalfNormal("sd__in_2", 1, initval=1) + + in_1 = pm.Normal("in_1", mu_in_1, sigma_in_1, observed=meas_in_1) + in_2 = pm.Normal("in_2", mu_in_2, sigma_in_2, observed=meas_in_2) + out_diff = in_1 + in_2 + pm.Deterministic("out", out_diff) + + with aesara.config.change_flags(mode=fast_unstable_sampling_mode): + trace = pm.sample( + tune=100, + draws=100, + step=pm.Metropolis(), + return_inferencedata=False, + compute_convergence_checks=False, + random_seed=rng, + ) + varnames = [v for v in trace.varnames if v != "out"] + ppc_trace = [ + dict(zip(varnames, row)) for row in zip(*(trace.get_values(v) for v in varnames)) + ] + ppc = pm.sample_posterior_predictive( + return_inferencedata=False, + model=model, + trace=ppc_trace, + var_names=[x.name for x in (model.deterministics + model.basic_RVs)], + ) + + rtol = 1e-5 if aesara.config.floatX == "float64" else 1e-3 + npt.assert_allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol) + + def test_variable_type(self): + with pm.Model() as model: + mu = pm.HalfNormal("mu", 1) + a = pm.Normal("a", mu=mu, sigma=2, observed=np.array([1, 2])) + b = pm.Poisson("b", mu, observed=np.array([1, 2])) + with aesara.config.change_flags(mode=fast_unstable_sampling_mode): + trace = pm.sample( + tune=10, draws=10, compute_convergence_checks=False, return_inferencedata=False + ) + + with model: + ppc = pm.sample_posterior_predictive(trace, return_inferencedata=False) + assert ppc["a"].dtype.kind == "f" + assert ppc["b"].dtype.kind == "i" + + def 
test_potentials_warning(self): + warning_msg = "The effect of Potentials on other parameters is ignored during" + with pm.Model() as m: + a = pm.Normal("a", 0, 1) + p = pm.Potential("p", a + 1) + obs = pm.Normal("obs", a, 1, observed=5) + + trace = az_from_dict({"a": np.random.rand(5)}) + with m: + with pytest.warns(UserWarning, match=warning_msg): + pm.sample_posterior_predictive(trace) + + def test_idata_extension(self): + """Testing if sample_posterior_predictive() extends inferenceData""" + + with pm.Model() as model: + mu = pm.Normal("mu", 0.0, 1.0) + a = pm.Normal("a", mu=mu, sigma=1, observed=[0.0, 1.0]) + idata = pm.sample(tune=10, draws=10, compute_convergence_checks=False) + + base_test_dict = { + "posterior": ["mu", "~a"], + "sample_stats": ["diverging", "lp"], + "log_likelihood": ["a"], + "observed_data": ["a"], + } + test_dict = {"~posterior_predictive": [], "~predictions": [], **base_test_dict} + fails = check_multiple_attrs(test_dict, idata) + assert not fails + + # extending idata with in-sample ppc + with model: + pm.sample_posterior_predictive(idata, extend_inferencedata=True) + # test addition + test_dict = {"posterior_predictive": ["a"], "~predictions": [], **base_test_dict} + fails = check_multiple_attrs(test_dict, idata) + assert not fails + + # extending idata with out-of-sample ppc + with model: + pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=True) + # test addition + test_dict = {"posterior_predictive": ["a"], "predictions": ["a"], **base_test_dict} + fails = check_multiple_attrs(test_dict, idata) + assert not fails + + @pytest.mark.parametrize("multitrace", [False, True]) + def test_deterministics_out_of_idata(self, multitrace): + draws = 10 + chains = 2 + coords = {"draw": range(draws), "chain": range(chains)} + ds = xr.Dataset( + { + "a": xr.DataArray( + [[0] * draws] * chains, + coords=coords, + dims=["chain", "draw"], + ) + }, + coords=coords, + ) + with pm.Model() as m: + a = pm.Normal("a") + if 
multitrace: + straces = [] + for chain in ds.chain: + strace = pm.backends.NDArray(model=m, vars=[a]) + strace.setup(len(ds.draw), int(chain)) + strace.values = {"a": ds.a.sel(chain=chain).data} + strace.draw_idx = len(ds.draw) + straces.append(strace) + trace = MultiTrace(straces) + else: + trace = ds + + d = pm.Deterministic("d", a - 4) + pm.Normal("c", d, sigma=0.01) + ppc = pm.sample_posterior_predictive(trace, var_names="c", return_inferencedata=True) + assert np.all(np.abs(ppc.posterior_predictive.c + 4) <= 0.1) + + def test_logging_sampled_basic_rvs_prior(self, caplog): + with pm.Model() as m: + x = pm.Normal("x") + y = pm.Deterministic("y", x + 1) + z = pm.Normal("z", y, observed=0) + + with m: + pm.sample_prior_predictive(samples=1) + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [x, z]")] + caplog.clear() + + with m: + pm.sample_prior_predictive(samples=1, var_names=["x"]) + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [x]")] + caplog.clear() + + def test_logging_sampled_basic_rvs_posterior(self, caplog): + with pm.Model() as m: + x = pm.Normal("x") + x_det = pm.Deterministic("x_det", x + 1) + y = pm.Normal("y", x_det) + z = pm.Normal("z", y, observed=0) + + idata = az_from_dict(posterior={"x": np.zeros(5), "x_det": np.ones(5), "y": np.ones(5)}) + with m: + pm.sample_posterior_predictive(idata) + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [z]")] + caplog.clear() + + with m: + pm.sample_posterior_predictive(idata, var_names=["y", "z"]) + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [y, z]")] + caplog.clear() + + # Resampling `x` will force resampling of `y`, even if it is in trace + with m: + pm.sample_posterior_predictive(idata, var_names=["x", "z"]) + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [x, y, z]")] + caplog.clear() + + # Missing deterministic `x_det` does not show in the log, even if it is being + # recomputed, only `y` RV shows + idata = 
az_from_dict(posterior={"x": np.zeros(5)}) + with m: + pm.sample_posterior_predictive(idata) + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [y, z]")] + caplog.clear() + + # Missing deterministic `x_det` does not cause recomputation of downstream `y` RV + idata = az_from_dict(posterior={"x": np.zeros(5), "y": np.ones(5)}) + with m: + pm.sample_posterior_predictive(idata) + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [z]")] + caplog.clear() + + # Missing `x` causes sampling of downstream `y` RV, even if it is present in trace + idata = az_from_dict(posterior={"y": np.ones(5)}) + with m: + pm.sample_posterior_predictive(idata) + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [x, y, z]")] + caplog.clear() + + def test_logging_sampled_basic_rvs_posterior_deterministic(self, caplog): + with pm.Model() as m: + x = pm.Normal("x") + x_det = pm.Deterministic("x_det", x + 1) + y = pm.Normal("y", x_det) + z = pm.Normal("z", y, observed=0) + + # Explicit resampling a deterministic will lead to resampling of downstream RV `y` + # This behavior could change in the future as the posterior of `y` is still valid + idata = az_from_dict(posterior={"x": np.zeros(5), "x_det": np.ones(5), "y": np.ones(5)}) + with m: + pm.sample_posterior_predictive(idata, var_names=["x_det", "z"]) + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [y, z]")] + caplog.clear() + + @staticmethod + def make_mock_model(): + rng = np.random.default_rng(seed=42) + data = rng.normal(loc=1, scale=0.2, size=(10, 3)) + with pm.Model() as model: + model.add_coord("name", ["A", "B", "C"], mutable=True) + model.add_coord("obs", list(range(10, 20)), mutable=True) + offsets = pm.MutableData("offsets", rng.normal(0, 1, size=(10,))) + a = pm.Normal("a", mu=0, sigma=1, dims=["name"]) + b = pm.Normal("b", mu=offsets, sigma=1) + mu = pm.Deterministic("mu", a + b[..., None], dims=["obs", "name"]) + sigma = pm.HalfNormal("sigma", sigma=1, 
dims=["name"]) + + data = pm.MutableData( + "y_obs", + data, + dims=["obs", "name"], + ) + pm.Normal("y", mu=mu, sigma=sigma, observed=data, dims=["obs", "name"]) + return model + + @pytest.fixture(scope="class") + def mock_multitrace(self): + with self.make_mock_model(): + trace = pm.sample( + draws=10, + tune=10, + chains=2, + progressbar=False, + compute_convergence_checks=False, + return_inferencedata=False, + random_seed=42, + ) + return trace + + @pytest.fixture(scope="class", params=["MultiTrace", "InferenceData", "Dataset"]) + def mock_sample_results(self, request, mock_multitrace): + kind = request.param + trace = mock_multitrace + # We rebuild the class to ensure that all dimensions, data and coords start out + # the same across params values + model = self.make_mock_model() + if kind == "MultiTrace": + return kind, trace, model + else: + idata = pm.to_inference_data( + trace, + save_warmup=False, + model=model, + log_likelihood=False, + ) + if kind == "Dataset": + return kind, idata.posterior, model + else: + return kind, idata, model + + def test_logging_sampled_basic_rvs_posterior_mutable(self, mock_sample_results, caplog): + kind, samples, model = mock_sample_results + with model: + pm.sample_posterior_predictive(samples) + if kind == "MultiTrace": + # MultiTrace will only have the actual MCMC posterior samples but no information on + # the MutableData and mutable coordinate values, so it will always assume they are volatile + # and resample their descendants + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")] + caplog.clear() + elif kind == "InferenceData": + # InferenceData has all MCMC posterior samples and the values for both coordinates and + # data containers. 
This enables it to see that no data has changed and it should only + # resample the observed variable + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [y]")] + caplog.clear() + elif kind == "Dataset": + # Dataset has all MCMC posterior samples and the values of the coordinates. This + # enables it to see that the coordinates have not changed, but the MutableData is + # assumed volatile by default + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [b, y]")] + caplog.clear() + + original_offsets = model["offsets"].get_value() + with model: + # Changing the MutableData values. This will only be picked up by InferenceData + pm.set_data({"offsets": original_offsets + 1}) + pm.sample_posterior_predictive(samples) + if kind == "MultiTrace": + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")] + caplog.clear() + elif kind == "InferenceData": + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [b, y]")] + caplog.clear() + elif kind == "Dataset": + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [b, y]")] + caplog.clear() + + with model: + # Changing the mutable coordinates. This will be picked up by InferenceData and Dataset + model.set_dim("name", new_length=4, coord_values=["D", "E", "F", "G"]) + pm.set_data({"offsets": original_offsets, "y_obs": np.zeros((10, 4))}) + pm.sample_posterior_predictive(samples) + if kind == "MultiTrace": + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")] + caplog.clear() + elif kind == "InferenceData": + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, sigma, y]")] + caplog.clear() + elif kind == "Dataset": + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")] + caplog.clear() + + with model: + # Changing the mutable coordinate values, but not shape, and also changing MutableData. 
+ # This will trigger resampling of all variables + model.set_dim("name", new_length=3, coord_values=["A", "B", "D"]) + pm.set_data({"offsets": original_offsets + 1, "y_obs": np.zeros((10, 3))}) + pm.sample_posterior_predictive(samples) + if kind == "MultiTrace": + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")] + caplog.clear() + elif kind == "InferenceData": + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")] + caplog.clear() + elif kind == "Dataset": + assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")] + caplog.clear() + + +@pytest.fixture(scope="class") +def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]: + with pm.Model() as pmodel: + n = pm.Normal("n") + trace = pm.sample(return_inferencedata=False) + + with pmodel: + d = pm.Deterministic("d", n * 4) + return pmodel, trace + + +class TestSamplePriorPredictive(SeededTest): + def test_ignores_observed(self): + observed = np.random.normal(10, 1, size=200) + with pm.Model(): + # Use a prior that's way off to show we're ignoring the observed variables + observed_data = pm.MutableData("observed_data", observed) + mu = pm.Normal("mu", mu=-100, sigma=1) + positive_mu = pm.Deterministic("positive_mu", np.abs(mu)) + z = -1 - positive_mu + pm.Normal("x_obs", mu=z, sigma=1, observed=observed_data) + prior = pm.sample_prior_predictive(return_inferencedata=False) + + assert "observed_data" not in prior + assert (prior["mu"] < -90).all() + assert (prior["positive_mu"] > 90).all() + assert (prior["x_obs"] < -90).all() + assert prior["x_obs"].shape == (500, 200) + npt.assert_array_almost_equal(prior["positive_mu"], np.abs(prior["mu"]), decimal=4) + + def test_respects_shape(self): + for shape in (2, (2,), (10, 2), (10, 10)): + with pm.Model(): + mu = pm.Gamma("mu", 3, 1, size=1) + goals = pm.Poisson("goals", mu, size=shape) + trace1 = pm.sample_prior_predictive( + 10, 
return_inferencedata=False, var_names=["mu", "mu", "goals"] + ) + trace2 = pm.sample_prior_predictive( + 10, return_inferencedata=False, var_names=["mu", "goals"] + ) + if shape == 2: # want to test shape as an int + shape = (2,) + assert trace1["goals"].shape == (10,) + shape + assert trace2["goals"].shape == (10,) + shape + + def test_multivariate(self): + with pm.Model(): + m = pm.Multinomial("m", n=5, p=np.array([0.25, 0.25, 0.25, 0.25])) + trace = pm.sample_prior_predictive(10) + + assert trace.prior["m"].shape == (1, 10, 4) + + def test_multivariate2(self): + # Added test for issue #3271 + mn_data = np.random.multinomial(n=100, pvals=[1 / 6.0] * 6, size=10) + with pm.Model() as dm_model: + probs = pm.Dirichlet("probs", a=np.ones(6)) + obs = pm.Multinomial("obs", n=100, p=probs, observed=mn_data) + with aesara.config.change_flags(mode=fast_unstable_sampling_mode): + burned_trace = pm.sample( + tune=10, + draws=20, + chains=1, + return_inferencedata=False, + compute_convergence_checks=False, + ) + sim_priors = pm.sample_prior_predictive( + return_inferencedata=False, samples=20, model=dm_model + ) + sim_ppc = pm.sample_posterior_predictive( + burned_trace, return_inferencedata=False, model=dm_model + ) + assert sim_priors["probs"].shape == (20, 6) + assert sim_priors["obs"].shape == (20,) + mn_data.shape + assert sim_ppc["obs"].shape == (1, 20) + mn_data.shape + + def test_layers(self): + with pm.Model() as model: + a = pm.Uniform("a", lower=0, upper=1, size=10) + b = pm.Binomial("b", n=1, p=a, size=10) + + b_sampler = compile_pymc([], b, mode="FAST_RUN", random_seed=232093) + avg = np.stack([b_sampler() for i in range(10000)]).mean(0) + npt.assert_array_almost_equal(avg, 0.5 * np.ones((10,)), decimal=2) + + def test_transformed(self): + n = 18 + at_bats = 45 * np.ones(n, dtype=int) + hits = np.random.randint(1, 40, size=n, dtype=int) + draws = 50 + + with pm.Model() as model: + phi = pm.Beta("phi", alpha=1.0, beta=1.0) + + kappa_log = 
pm.Exponential("logkappa", lam=5.0) + kappa = pm.Deterministic("kappa", at.exp(kappa_log)) + + thetas = pm.Beta("thetas", alpha=phi * kappa, beta=(1.0 - phi) * kappa, size=n) + + y = pm.Binomial("y", n=at_bats, p=thetas, observed=hits) + gen = pm.sample_prior_predictive(draws) + + assert gen.prior["phi"].shape == (1, draws) + assert gen.prior_predictive["y"].shape == (1, draws, n) + assert "thetas" in gen.prior.data_vars + + def test_shared(self): + n1 = 10 + obs = shared(np.random.rand(n1) < 0.5) + draws = 50 + + with pm.Model() as m: + p = pm.Beta("p", 1.0, 1.0) + y = pm.Bernoulli("y", p, observed=obs) + o = pm.Deterministic("o", obs) + gen1 = pm.sample_prior_predictive(draws) + + assert gen1.prior_predictive["y"].shape == (1, draws, n1) + assert gen1.prior["o"].shape == (1, draws, n1) + + n2 = 20 + obs.set_value(np.random.rand(n2) < 0.5) + with m: + gen2 = pm.sample_prior_predictive(draws) + + assert gen2.prior_predictive["y"].shape == (1, draws, n2) + assert gen2.prior["o"].shape == (1, draws, n2) + + def test_density_dist(self): + obs = np.random.normal(-1, 0.1, size=10) + with pm.Model(): + mu = pm.Normal("mu", 0, 1) + sigma = pm.HalfNormal("sigma", 1e-6) + a = pm.DensityDist( + "a", + mu, + sigma, + random=lambda mu, sigma, rng=None, size=None: rng.normal( + loc=mu, scale=sigma, size=size + ), + observed=obs, + ) + prior = pm.sample_prior_predictive(return_inferencedata=False) + + npt.assert_almost_equal((prior["a"] - prior["mu"][..., None]).mean(), 0, decimal=3) + + def test_shape_edgecase(self): + with pm.Model(): + mu = pm.Normal("mu", size=5) + sigma = pm.Uniform("sigma", lower=2, upper=3) + x = pm.Normal("x", mu=mu, sigma=sigma, size=5) + prior = pm.sample_prior_predictive(10) + assert prior.prior["mu"].shape == (1, 10, 5) + + def test_zeroinflatedpoisson(self): + with pm.Model(): + mu = pm.Beta("mu", alpha=1, beta=1) + psi = pm.HalfNormal("psi", sigma=1) + pm.ZeroInflatedPoisson("suppliers", psi=psi, mu=mu, size=20) + gen_data = 
pm.sample_prior_predictive(samples=5000) + assert gen_data.prior["mu"].shape == (1, 5000) + assert gen_data.prior["psi"].shape == (1, 5000) + assert gen_data.prior["suppliers"].shape == (1, 5000, 20) + + def test_potentials_warning(self): + warning_msg = "The effect of Potentials on other parameters is ignored during" + with pm.Model() as m: + a = pm.Normal("a", 0, 1) + p = pm.Potential("p", a + 1) + + with m: + with pytest.warns(UserWarning, match=warning_msg): + pm.sample_prior_predictive(samples=5) + + def test_transformed_vars(self): + # Test that prior predictive returns transformation of RVs when these are + # passed explicitly in `var_names` + + def ub_interval_forward(x, ub): + # Interval transform assuming lower bound is zero + return np.log(x - 0) - np.log(ub - x) + + with pm.Model() as model: + ub = pm.HalfNormal("ub", 10) + x = pm.Uniform("x", 0, ub) + + prior = pm.sample_prior_predictive( + var_names=["ub", "ub_log__", "x", "x_interval__"], + samples=10, + random_seed=123, + ) + + # Check values are correct + assert np.allclose(prior.prior["ub_log__"].data, np.log(prior.prior["ub"].data)) + assert np.allclose( + prior.prior["x_interval__"].data, + ub_interval_forward(prior.prior["x"].data, prior.prior["ub"].data), + ) + + # Check that it works when the original RVs are not mentioned in var_names + with pm.Model() as model_transformed_only: + ub = pm.HalfNormal("ub", 10) + x = pm.Uniform("x", 0, ub) + + prior_transformed_only = pm.sample_prior_predictive( + var_names=["ub_log__", "x_interval__"], + samples=10, + random_seed=123, + ) + assert ( + "ub" not in prior_transformed_only.prior.data_vars + and "x" not in prior_transformed_only.prior.data_vars + ) + assert np.allclose( + prior.prior["ub_log__"].data, prior_transformed_only.prior["ub_log__"].data + ) + assert np.allclose( + prior.prior["x_interval__"], prior_transformed_only.prior["x_interval__"].data + ) + + def test_issue_4490(self): + # Test that samples do not depend on var_name order or, more 
fundamentally, + # that they do not depend on the set order used inside `sample_prior_predictive` + seed = 4490 + with pm.Model() as m1: + a = pm.Normal("a") + b = pm.Normal("b") + c = pm.Normal("c") + d = pm.Normal("d") + prior1 = pm.sample_prior_predictive( + samples=1, var_names=["a", "b", "c", "d"], random_seed=seed + ) + + with pm.Model() as m2: + a = pm.Normal("a") + b = pm.Normal("b") + c = pm.Normal("c") + d = pm.Normal("d") + prior2 = pm.sample_prior_predictive( + samples=1, var_names=["b", "a", "d", "c"], random_seed=seed + ) + + assert prior1.prior["a"] == prior2.prior["a"] + assert prior1.prior["b"] == prior2.prior["b"] + assert prior1.prior["c"] == prior2.prior["c"] + assert prior1.prior["d"] == prior2.prior["d"] + + def test_aesara_function_kwargs(self): + sharedvar = aesara.shared(0) + with pm.Model() as m: + x = pm.DiracDelta("x", 0) + y = pm.Deterministic("y", x + sharedvar) + + prior = pm.sample_prior_predictive( + samples=5, + return_inferencedata=False, + compile_kwargs=dict( + mode=Mode("py"), + updates={sharedvar: sharedvar + 1}, + ), + ) + + assert np.all(prior["y"] == np.arange(5)) + + +class TestSamplePosteriorPredictive: + def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture): + pmodel, trace = point_list_arg_bug_fixture + with pmodel: + pp = pm.sample_posterior_predictive( + [trace[15]], return_inferencedata=False, var_names=["d"] + ) + + def test_sample_from_xarray_prior(self, point_list_arg_bug_fixture): + pmodel, trace = point_list_arg_bug_fixture + + with pmodel: + prior = pm.sample_prior_predictive( + samples=20, + return_inferencedata=False, + ) + idat = pm.to_inference_data(trace, prior=prior) + + with pmodel: + pp = pm.sample_posterior_predictive( + idat.prior, return_inferencedata=False, var_names=["d"] + ) + + def test_sample_from_xarray_posterior(self, point_list_arg_bug_fixture): + pmodel, trace = point_list_arg_bug_fixture + with pmodel: + idat = pm.to_inference_data(trace) + pp = 
pm.sample_posterior_predictive(idat.posterior, var_names=["d"]) + + def test_aesara_function_kwargs(self): + sharedvar = aesara.shared(0) + with pm.Model() as m: + x = pm.DiracDelta("x", 0.0) + y = pm.Deterministic("y", x + sharedvar) + + pp = pm.sample_posterior_predictive( + trace=az_from_dict({"x": np.arange(5)}), + var_names=["y"], + return_inferencedata=False, + compile_kwargs=dict( + mode=Mode("py"), + updates={sharedvar: sharedvar + 1}, + ), + ) + + assert np.all(pp["y"] == np.arange(5) * 2) + + def test_sample_dims(self, point_list_arg_bug_fixture): + pmodel, trace = point_list_arg_bug_fixture + with pmodel: + post = pm.to_inference_data(trace).posterior.stack(sample=["chain", "draw"]) + pp = pm.sample_posterior_predictive(post, var_names=["d"], sample_dims=["sample"]) + assert "sample" in pp.posterior_predictive + assert len(pp.posterior_predictive["sample"]) == len(post["sample"]) + post = post.expand_dims(pred_id=5) + pp = pm.sample_posterior_predictive( + post, var_names=["d"], sample_dims=["sample", "pred_id"] + ) + assert "sample" in pp.posterior_predictive + assert "pred_id" in pp.posterior_predictive + assert len(pp.posterior_predictive["sample"]) == len(post["sample"]) + assert len(pp.posterior_predictive["pred_id"]) == 5 + + +def test_distinct_rvs(): + """Make sure `RandomVariable`s generated using a `Model`'s default RNG state all have distinct states.""" + + with pm.Model() as model: + X_rv = pm.Normal("x") + Y_rv = pm.Normal("y") + + pp_samples = pm.sample_prior_predictive( + samples=2, return_inferencedata=False, random_seed=npr.RandomState(2023532) + ) + + assert X_rv.owner.inputs[0] != Y_rv.owner.inputs[0] + + with pm.Model(): + X_rv = pm.Normal("x") + Y_rv = pm.Normal("y") + + pp_samples_2 = pm.sample_prior_predictive( + samples=2, return_inferencedata=False, random_seed=npr.RandomState(2023532) + ) + + assert np.array_equal(pp_samples["y"], pp_samples_2["y"]) + + +class TestNestedRandom(SeededTest): + def build_model(self, distribution, 
shape, nested_rvs_info): + with pm.Model() as model: + nested_rvs = {} + for rv_name, info in nested_rvs_info.items(): + try: + value, nested_shape = info + loc = 0.0 + except ValueError: + value, nested_shape, loc = info + if value is None: + nested_rvs[rv_name] = pm.Uniform( + rv_name, + 0 + loc, + 1 + loc, + shape=nested_shape, + ) + else: + nested_rvs[rv_name] = value * np.ones(nested_shape) + rv = distribution( + "target", + shape=shape, + **nested_rvs, + ) + return model, rv, nested_rvs + + def sample_prior(self, distribution, shape, nested_rvs_info, prior_samples): + model, rv, nested_rvs = self.build_model( + distribution, + shape, + nested_rvs_info, + ) + with model: + return pm.sample_prior_predictive(prior_samples, return_inferencedata=False) + + @pytest.mark.parametrize( + ["prior_samples", "shape", "mu", "alpha"], + [ + [10, (3,), (None, tuple()), (None, (3,))], + [10, (3,), (None, (3,)), (None, tuple())], + [ + 10, + ( + 4, + 3, + ), + (None, (3,)), + (None, (3,)), + ], + [ + 10, + ( + 4, + 3, + ), + (None, (3,)), + (None, (4, 3)), + ], + ], + ids=str, + ) + def test_NegativeBinomial( + self, + prior_samples, + shape, + mu, + alpha, + ): + prior = self.sample_prior( + distribution=pm.NegativeBinomial, + shape=shape, + nested_rvs_info=dict(mu=mu, alpha=alpha), + prior_samples=prior_samples, + ) + assert prior["target"].shape == (prior_samples,) + shape + + @pytest.mark.parametrize( + ["prior_samples", "shape", "psi", "mu", "alpha"], + [ + [10, (3,), (0.5, tuple()), (None, tuple()), (None, (3,))], + [10, (3,), (0.5, (3,)), (None, tuple()), (None, (3,))], + [10, (3,), (0.5, tuple()), (None, (3,)), (None, tuple())], + [10, (3,), (0.5, (3,)), (None, (3,)), (None, tuple())], + [ + 10, + ( + 4, + 3, + ), + (0.5, (3,)), + (None, (3,)), + (None, (3,)), + ], + [ + 10, + ( + 4, + 3, + ), + (0.5, (3,)), + (None, (3,)), + (None, (4, 3)), + ], + ], + ids=str, + ) + def test_ZeroInflatedNegativeBinomial( + self, + prior_samples, + shape, + psi, + mu, + alpha, + ): + 
prior = self.sample_prior( + distribution=pm.ZeroInflatedNegativeBinomial, + shape=shape, + nested_rvs_info=dict(psi=psi, mu=mu, alpha=alpha), + prior_samples=prior_samples, + ) + assert prior["target"].shape == (prior_samples,) + shape + + @pytest.mark.parametrize( + ["prior_samples", "shape", "nu", "sigma"], + [ + [10, (3,), (None, tuple()), (None, (3,))], + [10, (3,), (None, tuple()), (None, (3,))], + [10, (3,), (None, (3,)), (None, tuple())], + [10, (3,), (None, (3,)), (None, tuple())], + [ + 10, + ( + 4, + 3, + ), + (None, (3,)), + (None, (3,)), + ], + [ + 10, + ( + 4, + 3, + ), + (None, (3,)), + (None, (4, 3)), + ], + ], + ids=str, + ) + def test_Rice( + self, + prior_samples, + shape, + nu, + sigma, + ): + prior = self.sample_prior( + distribution=pm.Rice, + shape=shape, + nested_rvs_info=dict(nu=nu, sigma=sigma), + prior_samples=prior_samples, + ) + assert prior["target"].shape == (prior_samples,) + shape + + @pytest.mark.parametrize( + ["prior_samples", "shape", "mu", "sigma", "lower", "upper"], + [ + [10, (3,), (None, tuple()), (1.0, tuple()), (None, tuple(), -1), (None, (3,))], + [10, (3,), (None, tuple()), (1.0, tuple()), (None, tuple(), -1), (None, (3,))], + [10, (3,), (None, tuple()), (1.0, tuple()), (None, (3,), -1), (None, tuple())], + [10, (3,), (None, tuple()), (1.0, tuple()), (None, (3,), -1), (None, tuple())], + [ + 10, + ( + 4, + 3, + ), + (None, (3,)), + (1.0, tuple()), + (None, (3,), -1), + (None, (3,)), + ], + [ + 10, + ( + 4, + 3, + ), + (None, (3,)), + (1.0, tuple()), + (None, (3,), -1), + (None, (4, 3)), + ], + [10, (3,), (0.0, tuple()), (None, tuple()), (None, tuple(), -1), (None, (3,))], + [10, (3,), (0.0, tuple()), (None, tuple()), (None, tuple(), -1), (None, (3,))], + [10, (3,), (0.0, tuple()), (None, tuple()), (None, (3,), -1), (None, tuple())], + [10, (3,), (0.0, tuple()), (None, tuple()), (None, (3,), -1), (None, tuple())], + [ + 10, + ( + 4, + 3, + ), + (0.0, tuple()), + (None, (3,)), + (None, (3,), -1), + (None, (3,)), + ], + [ + 
10, + ( + 4, + 3, + ), + (0.0, tuple()), + (None, (3,)), + (None, (3,), -1), + (None, (4, 3)), + ], + ], + ids=str, + ) + def test_TruncatedNormal( + self, + prior_samples, + shape, + mu, + sigma, + lower, + upper, + ): + prior = self.sample_prior( + distribution=pm.TruncatedNormal, + shape=shape, + nested_rvs_info=dict(mu=mu, sigma=sigma, lower=lower, upper=upper), + prior_samples=prior_samples, + ) + assert prior["target"].shape == (prior_samples,) + shape + + @pytest.mark.parametrize( + ["prior_samples", "shape", "c", "lower", "upper"], + [ + [10, (3,), (None, tuple()), (-1.0, (3,)), (2, tuple())], + [10, (3,), (None, tuple()), (-1.0, tuple()), (None, tuple(), 1)], + [10, (3,), (None, (3,)), (-1.0, tuple()), (None, tuple(), 1)], + [ + 10, + ( + 4, + 3, + ), + (None, (3,)), + (-1.0, tuple()), + (None, (3,), 1), + ], + [ + 10, + ( + 4, + 3, + ), + (None, (3,)), + (None, tuple(), -1), + (None, (3,), 1), + ], + ], + ids=str, + ) + def test_Triangular( + self, + prior_samples, + shape, + c, + lower, + upper, + ): + prior = self.sample_prior( + distribution=pm.Triangular, + shape=shape, + nested_rvs_info=dict(c=c, lower=lower, upper=upper), + prior_samples=prior_samples, + ) + assert prior["target"].shape == (prior_samples,) + shape diff --git a/pymc/tests/test_sampling_utils.py b/pymc/tests/test_sampling_utils.py new file mode 100644 index 00000000000..a625b70886b --- /dev/null +++ b/pymc/tests/test_sampling_utils.py @@ -0,0 +1,484 @@ +# Copyright 2022 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import re + +import aesara +import aesara.tensor as at +import numpy as np +import pytest + +from aesara import Mode +from aesara.compile import SharedVariable + +import pymc as pm + +from pymc.backends.base import MultiTrace +from pymc.sampling_utils import ( + _get_seeds_per_chain, + compile_forward_sampling_function, + get_vars_in_point_list, +) +from pymc.tests.helpers import SeededTest + + +class TestDraw(SeededTest): + def test_univariate(self): + with pm.Model(): + x = pm.Normal("x") + + x_draws = pm.draw(x) + assert x_draws.shape == () + + (x_draws,) = pm.draw([x]) + assert x_draws.shape == () + + x_draws = pm.draw(x, draws=10) + assert x_draws.shape == (10,) + + (x_draws,) = pm.draw([x], draws=10) + assert x_draws.shape == (10,) + + def test_multivariate(self): + with pm.Model(): + mln = pm.Multinomial("mln", n=5, p=np.array([0.25, 0.25, 0.25, 0.25])) + + mln_draws = pm.draw(mln, draws=1) + assert mln_draws.shape == (4,) + + (mln_draws,) = pm.draw([mln], draws=1) + assert mln_draws.shape == (4,) + + mln_draws = pm.draw(mln, draws=10) + assert mln_draws.shape == (10, 4) + + (mln_draws,) = pm.draw([mln], draws=10) + assert mln_draws.shape == (10, 4) + + def test_multiple_variables(self): + with pm.Model(): + x = pm.Normal("x") + y = pm.Normal("y", shape=10) + z = pm.Uniform("z", shape=5) + w = pm.Dirichlet("w", a=[1, 1, 1]) + + num_draws = 100 + draws = pm.draw((x, y, z, w), draws=num_draws) + assert draws[0].shape == (num_draws,) + assert draws[1].shape == (num_draws, 10) + assert draws[2].shape == (num_draws, 5) + assert draws[3].shape == (num_draws, 3) + + def test_draw_different_samples(self): + with pm.Model(): + x = pm.Normal("x") + + x_draws_1 = pm.draw(x, 100) + x_draws_2 = pm.draw(x, 100) + assert not np.all(np.isclose(x_draws_1, x_draws_2)) + + def test_draw_aesara_function_kwargs(self): + sharedvar = aesara.shared(0) + x = 
pm.DiracDelta.dist(0.0) + y = x + sharedvar + draws = pm.draw( + y, + draws=5, + mode=Mode("py"), + updates={sharedvar: sharedvar + 1}, + ) + assert np.all(draws == np.arange(5)) + + +class TestCompileForwardSampler: + @staticmethod + def get_function_roots(function): + return [ + var + for var in aesara.graph.basic.graph_inputs(function.maker.fgraph.outputs) + if var.name + ] + + @staticmethod + def get_function_inputs(function): + return {i for i in function.maker.fgraph.inputs if not isinstance(i, SharedVariable)} + + def test_linear_model(self): + with pm.Model() as model: + x = pm.MutableData("x", np.linspace(0, 1, 10)) + y = pm.MutableData("y", np.ones(10)) + + alpha = pm.Normal("alpha", 0, 0.1) + beta = pm.Normal("beta", 0, 0.1) + mu = pm.Deterministic("mu", alpha + beta * x) + sigma = pm.HalfNormal("sigma", 0.1) + obs = pm.Normal("obs", mu, sigma, observed=y, shape=x.shape) + + f, volatile_rvs = compile_forward_sampling_function( + [obs], + vars_in_trace=[alpha, beta, sigma, mu], + basic_rvs=model.basic_RVs, + ) + assert volatile_rvs == {obs} + assert {i.name for i in self.get_function_inputs(f)} == {"alpha", "beta", "sigma"} + assert {i.name for i in self.get_function_roots(f)} == {"x", "alpha", "beta", "sigma"} + + with pm.Model() as model: + x = pm.ConstantData("x", np.linspace(0, 1, 10)) + y = pm.MutableData("y", np.ones(10)) + + alpha = pm.Normal("alpha", 0, 0.1) + beta = pm.Normal("beta", 0, 0.1) + mu = pm.Deterministic("mu", alpha + beta * x) + sigma = pm.HalfNormal("sigma", 0.1) + obs = pm.Normal("obs", mu, sigma, observed=y, shape=x.shape) + + f, volatile_rvs = compile_forward_sampling_function( + [obs], + vars_in_trace=[alpha, beta, sigma, mu], + basic_rvs=model.basic_RVs, + ) + assert volatile_rvs == {obs} + assert {i.name for i in self.get_function_inputs(f)} == {"alpha", "beta", "sigma", "mu"} + assert {i.name for i in self.get_function_roots(f)} == {"mu", "sigma"} + + def test_nested_observed_model(self): + with pm.Model() as model: + p = 
pm.ConstantData("p", np.array([0.25, 0.5, 0.25])) + x = pm.MutableData("x", np.zeros(10)) + y = pm.MutableData("y", np.ones(10)) + + category = pm.Categorical("category", p, observed=x) + beta = pm.Normal("beta", 0, 0.1, size=p.shape) + mu = pm.Deterministic("mu", beta[category]) + sigma = pm.HalfNormal("sigma", 0.1) + obs = pm.Normal("obs", mu, sigma, observed=y, shape=mu.shape) + + f, volatile_rvs = compile_forward_sampling_function( + outputs=model.observed_RVs, + vars_in_trace=[beta, mu, sigma], + basic_rvs=model.basic_RVs, + ) + assert volatile_rvs == {category, obs} + assert {i.name for i in self.get_function_inputs(f)} == {"beta", "sigma"} + assert {i.name for i in self.get_function_roots(f)} == {"x", "p", "beta", "sigma"} + + f, volatile_rvs = compile_forward_sampling_function( + outputs=model.observed_RVs, + vars_in_trace=[beta, mu, sigma], + basic_rvs=model.basic_RVs, + givens_dict={category: np.zeros(10, dtype=category.dtype)}, + ) + assert volatile_rvs == {obs} + assert {i.name for i in self.get_function_inputs(f)} == {"beta", "sigma"} + assert {i.name for i in self.get_function_roots(f)} == { + "x", + "p", + "category", + "beta", + "sigma", + } + + def test_volatile_parameters(self): + with pm.Model() as model: + y = pm.MutableData("y", np.ones(10)) + mu = pm.Normal("mu", 0, 1) + nested_mu = pm.Normal("nested_mu", mu, 1, size=10) + sigma = pm.HalfNormal("sigma", 1) + obs = pm.Normal("obs", nested_mu, sigma, observed=y, shape=nested_mu.shape) + + f, volatile_rvs = compile_forward_sampling_function( + outputs=model.observed_RVs, + vars_in_trace=[nested_mu, sigma], # mu isn't in the trace and will be deemed volatile + basic_rvs=model.basic_RVs, + ) + assert volatile_rvs == {mu, nested_mu, obs} + assert {i.name for i in self.get_function_inputs(f)} == {"sigma"} + assert {i.name for i in self.get_function_roots(f)} == {"sigma"} + + f, volatile_rvs = compile_forward_sampling_function( + outputs=model.observed_RVs, + vars_in_trace=[mu, nested_mu, sigma], + 
basic_rvs=model.basic_RVs, + givens_dict={ + mu: np.array(1.0) + }, # mu will be considered volatile because it's in givens + ) + assert volatile_rvs == {nested_mu, obs} + assert {i.name for i in self.get_function_inputs(f)} == {"sigma"} + assert {i.name for i in self.get_function_roots(f)} == {"mu", "sigma"} + + def test_mixture(self): + with pm.Model() as model: + w = pm.Dirichlet("w", a=np.ones(3), size=(5, 3)) + + mu = pm.Normal("mu", mu=np.arange(3), sigma=1) + + components = pm.Normal.dist(mu=mu, sigma=1, size=w.shape) + mix_mu = pm.Mixture("mix_mu", w=w, comp_dists=components) + obs = pm.Normal("obs", mix_mu, 1, observed=np.ones((5, 3))) + + f, volatile_rvs = compile_forward_sampling_function( + outputs=[obs], + vars_in_trace=[mix_mu, mu, w], + basic_rvs=model.basic_RVs, + ) + assert volatile_rvs == {obs} + assert {i.name for i in self.get_function_inputs(f)} == {"w", "mu", "mix_mu"} + assert {i.name for i in self.get_function_roots(f)} == {"mix_mu"} + + f, volatile_rvs = compile_forward_sampling_function( + outputs=[obs], + vars_in_trace=[mu, w], + basic_rvs=model.basic_RVs, + ) + assert volatile_rvs == {mix_mu, obs} + assert {i.name for i in self.get_function_inputs(f)} == {"w", "mu"} + assert {i.name for i in self.get_function_roots(f)} == {"w", "mu"} + + f, volatile_rvs = compile_forward_sampling_function( + outputs=[obs], + vars_in_trace=[mix_mu, mu], + basic_rvs=model.basic_RVs, + ) + assert volatile_rvs == {w, mix_mu, obs} + assert {i.name for i in self.get_function_inputs(f)} == {"mu"} + assert {i.name for i in self.get_function_roots(f)} == {"mu"} + + def test_censored(self): + with pm.Model() as model: + latent_mu = pm.Normal("latent_mu", mu=np.arange(3), sigma=1) + mu = pm.Censored("mu", pm.Normal.dist(mu=latent_mu, sigma=1), lower=-1, upper=1) + obs = pm.Normal("obs", mu, 1, observed=np.ones((10, 3))) + + f, volatile_rvs = compile_forward_sampling_function( + outputs=[obs], + vars_in_trace=[latent_mu, mu], + basic_rvs=model.basic_RVs, + ) + 
assert volatile_rvs == {obs} + assert {i.name for i in self.get_function_inputs(f)} == {"latent_mu", "mu"} + assert {i.name for i in self.get_function_roots(f)} == {"mu"} + + f, volatile_rvs = compile_forward_sampling_function( + outputs=[obs], + vars_in_trace=[mu], + basic_rvs=model.basic_RVs, + ) + assert volatile_rvs == {latent_mu, mu, obs} + assert {i.name for i in self.get_function_inputs(f)} == set() + assert {i.name for i in self.get_function_roots(f)} == set() + + def test_lkj_cholesky_cov(self): + with pm.Model() as model: + mu = np.zeros(3) + sd_dist = pm.Exponential.dist(1.0, size=3) + chol, corr, stds = pm.LKJCholeskyCov( # pylint: disable=unpacking-non-sequence + "chol_packed", n=3, eta=2, sd_dist=sd_dist, compute_corr=True + ) + chol_packed = model["chol_packed"] + chol = pm.Deterministic("chol", chol) + obs = pm.MvNormal("obs", mu=mu, chol=chol, observed=np.zeros(3)) + + f, volatile_rvs = compile_forward_sampling_function( + outputs=[obs], + vars_in_trace=[chol_packed, chol], + basic_rvs=model.basic_RVs, + ) + assert volatile_rvs == {obs} + assert {i.name for i in self.get_function_inputs(f)} == {"chol_packed", "chol"} + assert {i.name for i in self.get_function_roots(f)} == {"chol"} + + f, volatile_rvs = compile_forward_sampling_function( + outputs=[obs], + vars_in_trace=[chol_packed], + basic_rvs=model.basic_RVs, + ) + assert volatile_rvs == {obs} + assert {i.name for i in self.get_function_inputs(f)} == {"chol_packed"} + assert {i.name for i in self.get_function_roots(f)} == {"chol_packed"} + + f, volatile_rvs = compile_forward_sampling_function( + outputs=[obs], + vars_in_trace=[chol], + basic_rvs=model.basic_RVs, + ) + assert volatile_rvs == {chol_packed, obs} + assert {i.name for i in self.get_function_inputs(f)} == set() + assert {i.name for i in self.get_function_roots(f)} == set() + + def test_non_random_model_variable(self): + with pm.Model() as model: + # A user may register non-pure RandomVariables that can nevertheless be + # sampled, as 
long as a custom logprob is dispatched or Aeppl can infer + # its logprob (which is the case for `clip`) + y = at.clip(pm.Normal.dist(), -1, 1) + y = model.register_rv(y, name="y") + y_abs = pm.Deterministic("y_abs", at.abs(y)) + obs = pm.Normal("obs", y_abs, observed=np.zeros(10)) + + # y_abs should be resampled even if in the trace, because the source y is missing + f, volatile_rvs = compile_forward_sampling_function( + outputs=[obs], + vars_in_trace=[y_abs], + basic_rvs=model.basic_RVs, + ) + assert volatile_rvs == {y, obs} + assert {i.name for i in self.get_function_inputs(f)} == set() + assert {i.name for i in self.get_function_roots(f)} == set() + + def test_mutable_coords_volatile(self): + rng = np.random.default_rng(seed=42) + data = rng.normal(loc=1, scale=0.2, size=(10, 3)) + with pm.Model() as model: + model.add_coord("name", ["A", "B", "C"], mutable=True) + model.add_coord("obs", list(range(10, 20)), mutable=True) + offsets = pm.MutableData("offsets", rng.normal(0, 1, size=(10,))) + a = pm.Normal("a", mu=0, sigma=1, dims=["name"]) + b = pm.Normal("b", mu=offsets, sigma=1) + mu = pm.Deterministic("mu", a + b[..., None], dims=["obs", "name"]) + sigma = pm.HalfNormal("sigma", sigma=1, dims=["name"]) + + data = pm.MutableData( + "y_obs", + data, + dims=["obs", "name"], + ) + y = pm.Normal("y", mu=mu, sigma=sigma, observed=data, dims=["obs", "name"]) + + # When no constant_data and constant_coords, all the dependent nodes will be volatile and + # resampled + f, volatile_rvs = compile_forward_sampling_function( + outputs=[y], + vars_in_trace=[a, b, mu, sigma], + basic_rvs=model.basic_RVs, + ) + assert volatile_rvs == {y, a, b, sigma} + assert {i.name for i in self.get_function_inputs(f)} == set() + assert {i.name for i in self.get_function_roots(f)} == {"name", "obs", "offsets"} + + # When the constant data has the same values as the shared data, offsets wont be volatile + f, volatile_rvs = compile_forward_sampling_function( + outputs=[y], + vars_in_trace=[a, 
b, mu, sigma], + basic_rvs=model.basic_RVs, + constant_data={"offsets": offsets.get_value()}, + ) + assert volatile_rvs == {y, a, sigma} + assert {i.name for i in self.get_function_inputs(f)} == {"b"} + assert {i.name for i in self.get_function_roots(f)} == {"b", "name", "obs"} + + # When we declare constant_coords, the shared variables with matching names wont be volatile + f, volatile_rvs = compile_forward_sampling_function( + outputs=[y], + vars_in_trace=[a, b, mu, sigma], + basic_rvs=model.basic_RVs, + constant_coords={"name", "obs"}, + ) + assert volatile_rvs == {y, b} + assert {i.name for i in self.get_function_inputs(f)} == {"a", "sigma"} + assert {i.name for i in self.get_function_roots(f)} == { + "a", + "sigma", + "name", + "obs", + "offsets", + } + + # When we have both constant_data and constant_coords, only y will be volatile + f, volatile_rvs = compile_forward_sampling_function( + outputs=[y], + vars_in_trace=[a, b, mu, sigma], + basic_rvs=model.basic_RVs, + constant_data={"offsets": offsets.get_value()}, + constant_coords={"name", "obs"}, + ) + assert volatile_rvs == {y} + assert {i.name for i in self.get_function_inputs(f)} == {"a", "b", "mu", "sigma"} + assert {i.name for i in self.get_function_roots(f)} == {"mu", "sigma", "name", "obs"} + + # When constant_data has different values than the shared variable, then + # offsets will be volatile + f, volatile_rvs = compile_forward_sampling_function( + outputs=[y], + vars_in_trace=[a, b, mu, sigma], + basic_rvs=model.basic_RVs, + constant_data={"offsets": offsets.get_value() + 1}, + constant_coords={"name", "obs"}, + ) + assert volatile_rvs == {y, b} + assert {i.name for i in self.get_function_inputs(f)} == {"a", "sigma"} + assert {i.name for i in self.get_function_roots(f)} == { + "a", + "sigma", + "name", + "obs", + "offsets", + } + + +def test_get_seeds_per_chain(): + ret = _get_seeds_per_chain(None, chains=1) + assert len(ret) == 1 and isinstance(ret[0], int) + + ret = _get_seeds_per_chain(None, 
chains=2) + assert len(ret) == 2 and isinstance(ret[0], int) + + ret = _get_seeds_per_chain(5, chains=1) + assert ret == (5,) + + ret = _get_seeds_per_chain(5, chains=3) + assert len(ret) == 3 and isinstance(ret[0], int) and not any(r == 5 for r in ret) + + rng = np.random.default_rng(123) + expected_ret = rng.integers(2**30, dtype=np.int64, size=1) + rng = np.random.default_rng(123) + ret = _get_seeds_per_chain(rng, chains=1) + assert ret == expected_ret + + rng = np.random.RandomState(456) + expected_ret = rng.randint(2**30, dtype=np.int64, size=2) + rng = np.random.RandomState(456) + ret = _get_seeds_per_chain(rng, chains=2) + assert np.all(ret == expected_ret) + + for expected_ret in ([0, 1, 2], (0, 1, 2, 3), np.arange(5)): + ret = _get_seeds_per_chain(expected_ret, chains=len(expected_ret)) + assert ret is expected_ret + + with pytest.raises(ValueError, match="does not match the number of chains"): + _get_seeds_per_chain(expected_ret, chains=len(expected_ret) + 1) + + with pytest.raises(ValueError, match=re.escape("The `seeds` must be array-like")): + _get_seeds_per_chain({1: 1, 2: 2}, 2) + + +def test_get_vars_in_point_list(): + with pm.Model() as modelA: + pm.Normal("a", 0, 1) + pm.Normal("b", 0, 1) + with pm.Model() as modelB: + a = pm.Normal("a", 0, 1) + pm.Normal("c", 0, 1) + + point_list = [{"a": 0, "b": 0}] + vars_in_trace = get_vars_in_point_list(point_list, modelB) + assert set(vars_in_trace) == {a} + + strace = pm.backends.NDArray(model=modelB, vars=modelA.free_RVs) + strace.setup(1, 1) + strace.values = point_list[0] + strace.draw_idx = 1 + trace = MultiTrace([strace]) + vars_in_trace = get_vars_in_point_list(trace, modelB) + assert set(vars_in_trace) == {a} diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 747b5582f45..a7cb0ff1119 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -70,7 +70,7 @@ from pymc.blocking import DictToArrayBijection from pymc.initial_point import make_initial_point_fn from 
pymc.model import modelcontext -from pymc.sampling import RandomState, _get_seeds_per_chain +from pymc.sampling_utils import RandomState, _get_seeds_per_chain from pymc.util import WithMemoization, locally_cachedmethod from pymc.variational.updates import adagrad_window from pymc.vartypes import discrete_types diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 8366a8b4707..1d57fa95c11 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -52,6 +52,8 @@ pymc/parallel_sampling.py pymc/plots/__init__.py pymc/sampling.py +pymc/sampling_predictive.py +pymc/sampling_utils.py pymc/smc/__init__.py pymc/smc/sampling.py pymc/smc/kernels.py