Skip to content

Commit

Permalink
Log sampled basic_RVs in sample_*_predictive functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Sep 21, 2022
1 parent afa7bbf commit feb72ae
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 37 deletions.
48 changes: 28 additions & 20 deletions pymc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1622,7 +1622,7 @@ def compile_forward_sampling_function(
basic_rvs: Optional[List[Variable]] = None,
givens_dict: Optional[Dict[Variable, Any]] = None,
**kwargs,
) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
) -> Tuple[Callable[..., Union[np.ndarray, List[np.ndarray]]], Set[Variable]]:
"""Compile a function to draw samples, conditioned on the values of some variables.
The goal of this function is to walk the aesara computational graph from the list
Expand All @@ -1635,13 +1635,10 @@ def compile_forward_sampling_function(
- Variables in the outputs list
- ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``
- Basic RVs that are not in the ``vars_in_trace`` list
- Variables that are in the `basic_rvs` list but not in the ``vars_in_trace`` list
- Variables that are keys in the ``givens_dict``
- Variables that have volatile inputs
Where by basic RVs we mean ``Variable`` instances produced by a ``RandomVariable`` ``Op``
that are in the ``basic_rvs`` list.
Concretely, this function can be used to compile a function to sample from the
posterior predictive distribution of a model that has variables that are conditioned
on ``MutableData`` instances. The variables that depend on the mutable data will be
Expand Down Expand Up @@ -1670,12 +1667,19 @@ def compile_forward_sampling_function(
output of ``model.basic_RVs``) should have a reference to the variables that should
be considered as random variable instances. This includes variables that have
a ``RandomVariable`` owner op, but also impure random variables like Mixtures, or
Censored distributions. If ``None``, only pure random variables will be considered
as potential random variables.
Censored distributions.
givens_dict : Optional[Dict[aesara.graph.basic.Variable, Any]]
A dictionary that maps tensor variables to the values that should be used to replace them
in the compiled function. The types of the key and value should match or an error will be
raised during compilation.
Returns
-------
function: Callable
Compiled forward sampling Aesara function
volatile_basic_rvs: Set of Variable
Set of all basic_rvs that were considered volatile and will be resampled when
the function is evaluated
"""
if givens_dict is None:
givens_dict = {}
Expand Down Expand Up @@ -1741,7 +1745,10 @@ def expand(node):
for node, value in givens_dict.items()
]

return compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs)
return (
compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs),
set(basic_rvs) & (volatile_nodes - set(givens_dict)), # Basic RVs that will be resampled
)


def sample_posterior_predictive(
Expand Down Expand Up @@ -1900,7 +1907,6 @@ def sample_posterior_predictive(
vars_ = model.observed_RVs + model.auto_deterministics

indices = np.arange(samples)

if progressbar:
indices = progress_bar(indices, total=samples, display=progressbar)

Expand All @@ -1923,17 +1929,17 @@ def sample_posterior_predictive(
compile_kwargs.setdefault("allow_input_downcast", True)
compile_kwargs.setdefault("accept_inplace", True)

sampler_fn = point_wrapper(
compile_forward_sampling_function(
outputs=vars_to_sample,
vars_in_trace=vars_in_trace,
basic_rvs=model.basic_RVs,
givens_dict=None,
random_seed=random_seed,
**compile_kwargs,
)
_sampler_fn, volatile_basic_rvs = compile_forward_sampling_function(
outputs=vars_to_sample,
vars_in_trace=vars_in_trace,
basic_rvs=model.basic_RVs,
givens_dict=None,
random_seed=random_seed,
**compile_kwargs,
)

sampler_fn = point_wrapper(_sampler_fn)
# All model variables have a name, but mypy does not know this
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
ppc_trace_t = _DefaultTrace(samples)
try:
if isinstance(_trace, MultiTrace):
Expand Down Expand Up @@ -2242,7 +2248,7 @@ def sample_prior_predictive(
compile_kwargs.setdefault("allow_input_downcast", True)
compile_kwargs.setdefault("accept_inplace", True)

sampler_fn = compile_forward_sampling_function(
sampler_fn, volatile_basic_rvs = compile_forward_sampling_function(
vars_to_sample,
vars_in_trace=[],
basic_rvs=model.basic_RVs,
Expand All @@ -2251,6 +2257,8 @@ def sample_prior_predictive(
**compile_kwargs,
)

# All model variables have a name, but mypy does not know this
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
values = zip(*(sampler_fn() for i in range(samples)))

data = {k: np.stack(v) for k, v in zip(names, values)}
Expand Down
Loading

0 comments on commit feb72ae

Please sign in to comment.