-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
414 long running time of sample posterior predictive and eventual death by oom #436
Changes from all commits
70c57f2
927c3b3
aee9d2e
cef7726
c290409
2a06b14
85321ee
07d1140
3094391
c965923
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,8 +49,8 @@ | |
_get_alias_dict, | ||
_print_prior, | ||
_process_param_in_kwargs, | ||
_random_sample, | ||
_rearrange_data, | ||
_split_array, | ||
) | ||
|
||
from . import plotting | ||
|
@@ -404,6 +404,7 @@ def __init__( | |
self.model, self._parent_param, self.response_c, self.response_str | ||
) | ||
self.set_alias(self._aliases) | ||
# _logger.info(self.pymc_model.initial_point()) | ||
self._postprocess_initvals_deterministic(initval_settings=INITVAL_SETTINGS) | ||
self._jitter_initvals( | ||
jitter_epsilon=INITVAL_JITTER_SETTINGS["jitter_epsilon"], vector_only=True | ||
|
@@ -510,24 +511,41 @@ def sample( | |
+ "The jitter argument will be ignored." | ||
) | ||
del kwargs["jitter"] | ||
else: | ||
pass | ||
|
||
self._inference_obj = self.model.fit( | ||
inference_method=sampler, init=init, **kwargs | ||
) | ||
if "include_mean" not in kwargs: | ||
# If not specified, include the mean prediction in | ||
# kwargs to be passed to the model.fit() method | ||
kwargs["include_mean"] = True | ||
idata = self.model.fit(inference_method=sampler, init=init, **kwargs) | ||
|
||
if self._inference_obj is None: | ||
self._inference_obj = idata | ||
elif isinstance(self._inference_obj, az.InferenceData): | ||
self._inference_obj.extend(idata) | ||
else: | ||
raise ValueError( | ||
"The model has an attached inference object under" | ||
+ " self._inference_obj, but it is not an InferenceData object." | ||
) | ||
|
||
# The parent was previously not part of deterministics --> compute it via | ||
# posterior_predictive (works because it acts as the 'mu' parameter | ||
# in the GLM as far as bambi is concerned) | ||
if self._inference_obj is not None: | ||
if self._parent not in self._inference_obj.posterior.data_vars.keys(): | ||
self.model.predict(self._inference_obj, kind="mean", inplace=True) | ||
# self.model.predict(self._inference_obj, kind="mean", inplace=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we remove debug comments? |
||
# rename 'rt,response_mean' to 'v' so in the traces everything | ||
# looks the way it should | ||
self._inference_obj.rename_vars( | ||
{"rt,response_mean": self._parent}, inplace=True | ||
) | ||
elif ( | ||
self._parent in self._inference_obj.posterior.data_vars | ||
and "rt,response_mean" in self._inference_obj.posterior.data_vars | ||
): | ||
# drop redundant 'rt,response_mean' variable, | ||
# if parent already in posterior | ||
del self._inference_obj.posterior["rt,response_mean"] | ||
return self.traces | ||
|
||
def sample_posterior_predictive( | ||
|
@@ -537,7 +555,8 @@ def sample_posterior_predictive( | |
inplace: bool = True, | ||
include_group_specific: bool = True, | ||
kind: Literal["pps", "mean"] = "pps", | ||
n_samples: int | float | None = None, | ||
draws: int | float | list[int] | np.ndarray | None = None, | ||
safe_mode: bool = True, | ||
digicosmos86 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) -> az.InferenceData | None: | ||
"""Perform posterior predictive sampling from the HSSM model. | ||
|
||
|
@@ -563,7 +582,7 @@ def sample_posterior_predictive( | |
latter returns the draws from the posterior predictive distribution | ||
(i.e. the posterior probability distribution for a new observation). | ||
Defaults to `"pps"`. | ||
n_samples: optional | ||
draws: optional | ||
The number of samples to draw from the posterior predictive distribution | ||
from each chain. | ||
When it's an integer >= 1, the number of samples to be extracted from the | ||
|
@@ -574,6 +593,9 @@ def sample_posterior_predictive( | |
posterior predictive sampling.. If this proportion is very | ||
small, at least one sample will be used. When None, all posterior samples | ||
will be used. Defaults to None. | ||
safe_mode: bool | ||
If True, the function will split the draws into chunks of 10 to avoid memory | ||
issues. Defaults to True. | ||
|
||
Raises | ||
------ | ||
|
@@ -592,35 +614,123 @@ def sample_posterior_predictive( | |
+ "Please either provide an idata object or sample the model first." | ||
) | ||
idata = self._inference_obj | ||
_logger.info( | ||
"idata=None, we use the traces assigned to the HSSM object as idata." | ||
) | ||
|
||
if idata is not None: | ||
if "posterior_predictive" in idata.groups(): | ||
del idata["posterior_predictive"] | ||
_logger.warning( | ||
"pre-existing posterior_predictive group deleted from idata. \n" | ||
) | ||
|
||
if self._check_extra_fields(data): | ||
self._update_extra_fields(data) | ||
|
||
if n_samples is not None: | ||
# Make a copy of idata, set the `posterior` group to be a random sub-sample | ||
# of the original (draw dimension gets sub-sampled) | ||
idata_copy = idata.copy() | ||
idata_random_sample = _random_sample( | ||
idata_copy["posterior"], n_samples=n_samples | ||
if isinstance(draws, np.ndarray): | ||
draws = draws.astype(int) | ||
elif isinstance(draws, list): | ||
draws = np.array(draws).astype(int) | ||
elif isinstance(draws, int | float): | ||
draws = np.arange(int(draws)) | ||
elif draws is None: | ||
draws = idata["posterior"].draw.values | ||
else: | ||
raise ValueError( | ||
"draws must be an integer, " + "a list of integers, or a numpy array." | ||
) | ||
delattr(idata_copy, "posterior") | ||
idata_copy.add_groups(posterior=idata_random_sample) | ||
|
||
# If the user specifies an inplace operation, we need to modify the original | ||
if inplace: | ||
self.model.predict(idata_copy, kind, data, True, include_group_specific) | ||
idata.add_groups( | ||
posterior_predictive=idata_copy["posterior_predictive"] | ||
|
||
assert isinstance(draws, np.ndarray) | ||
|
||
# Make a copy of idata, set the `posterior` group to be a random sub-sample | ||
# of the original (draw dimension gets sub-sampled) | ||
|
||
idata_copy = idata.copy() | ||
|
||
if (draws.shape != idata["posterior"].draw.values.shape) or ( | ||
(draws.shape == idata["posterior"].draw.values.shape) | ||
and not np.allclose(draws, idata["posterior"].draw.values) | ||
): | ||
# Reassign posterior to sub-sampled version | ||
setattr(idata_copy, "posterior", idata["posterior"].isel(draw=draws)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are there any differences between There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. to be honest I don't know... let me look into that independently to understand it properly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, at least used somewhat semantically here, |
||
|
||
if kind == "pps": | ||
# If we run kind == 'pps' we actually run the observation RV | ||
if safe_mode: | ||
# safe mode splits the draws into chunks of 10 to avoid | ||
# memory issues (TODO: Figure out the source of memory issues) | ||
split_draws = _split_array( | ||
idata_copy["posterior"].draw.values, divisor=10 | ||
) | ||
|
||
return None | ||
posterior_predictive_list = [] | ||
for samples_tmp in split_draws: | ||
tmp_posterior = idata["posterior"].sel(draw=samples_tmp) | ||
setattr(idata_copy, "posterior", tmp_posterior) | ||
self.model.predict( | ||
idata_copy, kind, data, True, include_group_specific | ||
) | ||
posterior_predictive_list.append(idata_copy["posterior_predictive"]) | ||
|
||
if inplace: | ||
idata.add_groups( | ||
posterior_predictive=xr.concat( | ||
posterior_predictive_list, dim="draw" | ||
) | ||
) | ||
# for inplace, we don't return anything | ||
return None | ||
else: | ||
# Reassign original posterior to idata_copy | ||
setattr(idata_copy, "posterior", idata["posterior"]) | ||
# Add new posterior predictive group to idata_copy | ||
del idata_copy["posterior_predictive"] | ||
idata_copy.add_groups( | ||
posterior_predictive=xr.concat( | ||
posterior_predictive_list, dim="draw" | ||
) | ||
) | ||
return idata_copy | ||
else: | ||
if inplace: | ||
# If not safe-mode | ||
# We call .predict() directly without any | ||
# chunking of data. | ||
|
||
# .predict() is called on the copy of idata | ||
# since we still subsampled (or assigned) the draws | ||
self.model.predict( | ||
idata_copy, kind, data, True, include_group_specific | ||
) | ||
|
||
# posterior predictive group added to idata | ||
idata.add_groups( | ||
posterior_predictive=idata_copy["posterior_predictive"] | ||
) | ||
# don't return anything if inplace | ||
return None | ||
else: | ||
# Not safe mode and not inplace | ||
# Function acts as very thin wrapper around | ||
# .predict(). It just operates on the | ||
# idata_copy object | ||
return self.model.predict( | ||
idata_copy, kind, data, False, include_group_specific | ||
) | ||
elif kind == "mean": | ||
# If kind == 'mean', we don't need to run the RV directly, | ||
# there shouldn't really be any significant memory issues here, | ||
# we can simply ignore settings, since the computational overhead | ||
# should be very small --> nudges user towards good outputs. | ||
_logger.warning( | ||
"The kind argument is set to 'mean', but 'draws' argument " | ||
+ "is not None: The draws argument will be ignored!" | ||
) | ||
return self.model.predict( | ||
idata_copy, kind, data, False, include_group_specific | ||
idata, kind, data, inplace, include_group_specific | ||
) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add an |
||
return self.model.predict(idata, kind, data, inplace, include_group_specific) | ||
|
||
def plot_posterior_predictive(self, **kwargs) -> mpl.axes.Axes | sns.FacetGrid: | ||
"""Produce a posterior predictive plot. | ||
|
||
|
@@ -678,7 +788,24 @@ def sample_prior_predictive( | |
``InferenceData`` object with the groups ``prior``, ``prior_predictive`` and | ||
``observed_data``. | ||
""" | ||
return self.model.prior_predictive(draws, var_names, omit_offsets, random_seed) | ||
prior_predictive = self.model.prior_predictive( | ||
draws, var_names, omit_offsets, random_seed | ||
) | ||
|
||
prior_predictive.add_groups(posterior=prior_predictive.prior) | ||
self.model.predict(prior_predictive, kind="mean", inplace=True) | ||
|
||
# clean | ||
setattr(prior_predictive, "prior", prior_predictive["posterior"]) | ||
del prior_predictive["posterior"] | ||
|
||
if self._inference_obj is None: | ||
self._inference_obj = prior_predictive | ||
else: | ||
self._inference_obj.extend(prior_predictive) | ||
|
||
# clean up `rt,response_mean` to `v` | ||
return self._drop_parent_str_from_idata(idata=self._inference_obj) | ||
|
||
@property | ||
def pymc_model(self) -> pm.Model: | ||
|
@@ -1353,6 +1480,35 @@ def _get_deterministic_var_names(self, idata) -> list[str]: | |
|
||
return var_names | ||
|
||
def _drop_parent_str_from_idata( | ||
self, idata: az.InferenceData | None | ||
) -> az.InferenceData: | ||
"""Drop the parent_str variable from an InferenceData object. | ||
|
||
Parameters | ||
---------- | ||
idata | ||
The InferenceData object to be modified. | ||
|
||
Returns | ||
------- | ||
xr.Dataset | ||
The modified InferenceData object. | ||
""" | ||
if idata is None: | ||
raise ValueError("Please provide an InferenceData object.") | ||
else: | ||
for group in idata.groups(): | ||
if ("rt,response_mean" in idata[group].data_vars) and ( | ||
self._parent not in idata[group].data_vars | ||
): | ||
setattr( | ||
idata, | ||
group, | ||
idata[group].rename({"rt,response_mean": self._parent}), | ||
) | ||
return idata | ||
|
||
def _handle_missing_data_and_deadline(self): | ||
"""Handle missing data and deadline.""" | ||
if not self.missing_data and not self.deadline: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we remove debug comments?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Stylistically eventually yes, but rn, I think it can sometimes still help future PRs that interact with this code. Here I literally have the next PR that I need to work on in mind. So in general agree, but let's skip here :)