Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make forward sampling functions return InferenceData #5073

Merged
merged 22 commits into from
Oct 15, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions pymc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def sample(
idata_kwargs: dict = None,
mp_ctx=None,
**kwargs,
):
) -> Union[InferenceData, MultiTrace]:
r"""Draw samples from the posterior using the given step methods.

Multiple step methods are supported via compound step methods.
Expand Down Expand Up @@ -338,7 +338,7 @@ def sample(
init methods.
return_inferencedata : bool, default=True
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False)
Defaults to `False`, but we'll switch to `True` in an upcoming release.
Defaults to `True`.
idata_kwargs : dict, optional
Keyword arguments for :func:`pymc.to_inference_data`
mp_ctx : multiprocessing.context.BaseContent
Expand Down Expand Up @@ -1893,7 +1893,9 @@ def sample_prior_predictive(
var_names: Optional[Iterable[str]] = None,
random_seed=None,
mode: Optional[Union[str, Mode]] = None,
) -> Dict[str, np.ndarray]:
return_inferencedata=None,
michaelosthege marked this conversation as resolved.
Show resolved Hide resolved
idata_kwargs: dict = None,
) -> Union[InferenceData, Dict[str, np.ndarray]]:
"""Generate samples from the prior predictive distribution.

Parameters
Expand All @@ -1909,14 +1911,21 @@ def sample_prior_predictive(
Seed for the random number generator.
mode:
The mode used by ``aesara.function`` to compile the graph.
return_inferencedata : bool, default=True
Whether to return an :class:`arviz:arviz.InferenceData` (True) object or a dictionary (False).
Defaults to `True`.
idata_kwargs : dict, optional
Keyword arguments for :func:`pymc.to_inference_data`
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
dict
Dictionary with variable names as keys. The values are numpy arrays of prior
samples.
arviz.InferenceData or Dict
An ArviZ ``InferenceData`` object containing the prior and prior predictive samples (default),
or a dictionary with variable names as keys and samples as numpy arrays.
"""
model = modelcontext(model)
if return_inferencedata is None:
return_inferencedata = True

if model.potentials:
warnings.warn(
Expand Down Expand Up @@ -1980,7 +1989,15 @@ def sample_prior_predictive(
for var_name in vars_:
if var_name in data:
prior[var_name] = data[var_name]
return prior

if not return_inferencedata:
return prior

ikwargs = dict(model=model)
if idata_kwargs:
ikwargs.update(idata_kwargs)

return pm.to_inference_data(prior=prior, **ikwargs)


def _init_jitter(model, point, chains, jitter_max_retries):
Expand Down