414: Long running time of sample_posterior_predictive and eventual death by OOM #436

212 changes: 184 additions & 28 deletions src/hssm/hssm.py
@@ -49,8 +49,8 @@
_get_alias_dict,
_print_prior,
_process_param_in_kwargs,
_random_sample,
_rearrange_data,
_split_array,
)

from . import plotting
@@ -404,6 +404,7 @@ def __init__(
self.model, self._parent_param, self.response_c, self.response_str
)
self.set_alias(self._aliases)
# _logger.info(self.pymc_model.initial_point())
Collaborator:
Should we remove debug comments?

Collaborator (Author):
Stylistically, eventually yes, but right now I think it can sometimes still help future PRs that interact with this code. Here I literally have the next PR that I need to work on in mind. So in general I agree, but let's skip it here :)

self._postprocess_initvals_deterministic(initval_settings=INITVAL_SETTINGS)
self._jitter_initvals(
jitter_epsilon=INITVAL_JITTER_SETTINGS["jitter_epsilon"], vector_only=True
@@ -510,24 +511,41 @@ def sample(
+ "The jitter argument will be ignored."
)
del kwargs["jitter"]
else:
pass

self._inference_obj = self.model.fit(
inference_method=sampler, init=init, **kwargs
)
if "include_mean" not in kwargs:
# If not specified, include the mean prediction in
# kwargs to be passed to the model.fit() method
kwargs["include_mean"] = True
idata = self.model.fit(inference_method=sampler, init=init, **kwargs)

if self._inference_obj is None:
self._inference_obj = idata
elif isinstance(self._inference_obj, az.InferenceData):
self._inference_obj.extend(idata)
else:
raise ValueError(
"The model has an attached inference object under"
+ " self._inference_obj, but it is not an InferenceData object."
)

# The parent was previously not part of deterministics --> compute it via
# posterior_predictive (works because it acts as the 'mu' parameter
# in the GLM as far as bambi is concerned)
if self._inference_obj is not None:
if self._parent not in self._inference_obj.posterior.data_vars.keys():
self.model.predict(self._inference_obj, kind="mean", inplace=True)
# self.model.predict(self._inference_obj, kind="mean", inplace=True)
Collaborator:
Should we remove debug comments?

# rename 'rt,response_mean' to 'v' so in the traces everything
# looks the way it should
self._inference_obj.rename_vars(
{"rt,response_mean": self._parent}, inplace=True
)
elif (
self._parent in self._inference_obj.posterior.data_vars
and "rt,response_mean" in self._inference_obj.posterior.data_vars
):
# drop redundant 'rt,response_mean' variable,
# if parent already in posterior
del self._inference_obj.posterior["rt,response_mean"]
return self.traces
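
A minimal usage sketch of the new sample() behavior, assuming the standard HSSM quickstart API (hssm.simulate_data as documented); the parameter values and sample sizes are illustrative:

    import hssm

    # Simulate a small DDM dataset (v, a, z, t; values are illustrative).
    data = hssm.simulate_data(model="ddm", theta=[0.5, 1.5, 0.5, 0.3], size=500)
    model = hssm.HSSM(data=data, model="ddm")

    # fit() now runs with include_mean=True by default, and the resulting
    # 'rt,response_mean' variable is renamed to the parent parameter, so the
    # traces show it under the expected name ('v' for the DDM).
    idata = model.sample(draws=500, tune=500)
    assert "v" in idata.posterior.data_vars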

def sample_posterior_predictive(
@@ -537,7 +555,8 @@ def sample_posterior_predictive(
inplace: bool = True,
include_group_specific: bool = True,
kind: Literal["pps", "mean"] = "pps",
n_samples: int | float | None = None,
draws: int | float | list[int] | np.ndarray | None = None,
safe_mode: bool = True,
) -> az.InferenceData | None:
"""Perform posterior predictive sampling from the HSSM model.

@@ -563,7 +582,7 @@
latter returns the draws from the posterior predictive distribution
(i.e. the posterior probability distribution for a new observation).
Defaults to `"pps"`.
n_samples: optional
draws: optional
The number of samples to draw from the posterior predictive distribution
from each chain.
When it's an integer >= 1, the number of samples to be extracted from the
@@ -574,6 +593,9 @@
posterior predictive sampling. If this proportion is very
small, at least one sample will be used. When a list or NumPy array of
integers, those specific draw indices will be used. When None, all posterior
samples will be used. Defaults to None.
safe_mode: bool
If True, the function will split the draws into chunks of 10 to avoid memory
issues. Defaults to True.

Raises
------
@@ -592,35 +614,123 @@
+ "Please either provide an idata object or sample the model first."
)
idata = self._inference_obj
_logger.info(
"idata=None, we use the traces assigned to the HSSM object as idata."
)

if idata is not None:
if "posterior_predictive" in idata.groups():
del idata["posterior_predictive"]
_logger.warning(
"pre-existing posterior_predictive group deleted from idata. \n"
)

if self._check_extra_fields(data):
self._update_extra_fields(data)

if n_samples is not None:
# Make a copy of idata, set the `posterior` group to be a random sub-sample
# of the original (draw dimension gets sub-sampled)
idata_copy = idata.copy()
idata_random_sample = _random_sample(
idata_copy["posterior"], n_samples=n_samples
if isinstance(draws, np.ndarray):
draws = draws.astype(int)
elif isinstance(draws, list):
draws = np.array(draws).astype(int)
elif isinstance(draws, int | float):
draws = np.arange(int(draws))
elif draws is None:
draws = idata["posterior"].draw.values
else:
raise ValueError(
"draws must be an integer, " + "a list of integers, or a numpy array."
)
delattr(idata_copy, "posterior")
idata_copy.add_groups(posterior=idata_random_sample)

# If the user specifies an inplace operation, we need to modify the original
if inplace:
self.model.predict(idata_copy, kind, data, True, include_group_specific)
idata.add_groups(
posterior_predictive=idata_copy["posterior_predictive"]

assert isinstance(draws, np.ndarray)

# Make a copy of idata, set the `posterior` group to be a random sub-sample
# of the original (draw dimension gets sub-sampled)

idata_copy = idata.copy()

if (draws.shape != idata["posterior"].draw.values.shape) or (
(draws.shape == idata["posterior"].draw.values.shape)
and not np.allclose(draws, idata["posterior"].draw.values)
):
# Reassign posterior to sub-sampled version
setattr(idata_copy, "posterior", idata["posterior"].isel(draw=draws))
Collaborator:
Are there any differences between setattr() and idata.add_groups()?

Collaborator (Author):
To be honest, I don't know... let me look into that independently to understand it properly.

Collaborator (Author):
Actually, at least as used semantically here, add_groups() is for new groups, while setattr() reassigns a pre-existing group.


if kind == "pps":
# If we run kind == 'pps' we actually run the observation RV
if safe_mode:
# safe mode splits the draws into chunks of 10 to avoid
# memory issues (TODO: Figure out the source of memory issues)
split_draws = _split_array(
idata_copy["posterior"].draw.values, divisor=10
)

return None
posterior_predictive_list = []
for samples_tmp in split_draws:
tmp_posterior = idata["posterior"].sel(draw=samples_tmp)
setattr(idata_copy, "posterior", tmp_posterior)
self.model.predict(
idata_copy, kind, data, True, include_group_specific
)
posterior_predictive_list.append(idata_copy["posterior_predictive"])

if inplace:
idata.add_groups(
posterior_predictive=xr.concat(
posterior_predictive_list, dim="draw"
)
)
# for inplace, we don't return anything
return None
else:
# Reassign original posterior to idata_copy
setattr(idata_copy, "posterior", idata["posterior"])
# Add new posterior predictive group to idata_copy
del idata_copy["posterior_predictive"]
idata_copy.add_groups(
posterior_predictive=xr.concat(
posterior_predictive_list, dim="draw"
)
)
return idata_copy
else:
if inplace:
# If not safe-mode
# We call .predict() directly without any
# chunking of data.

# .predict() is called on the copy of idata
# since we still subsampled (or assigned) the draws
self.model.predict(
idata_copy, kind, data, True, include_group_specific
)

# posterior predictive group added to idata
idata.add_groups(
posterior_predictive=idata_copy["posterior_predictive"]
)
# don't return anything if inplace
return None
else:
# Not safe mode and not inplace
# Function acts as very thin wrapper around
# .predict(). It just operates on the
# idata_copy object
return self.model.predict(
idata_copy, kind, data, False, include_group_specific
)
elif kind == "mean":
# If kind == 'mean', we don't need to run the RV directly,
# there shouldn't really be any significant memory issues here,
# we can simply ignore settings, since the computational overhead
# should be very small --> nudges user towards good outputs.
_logger.warning(
"The kind argument is set to 'mean', but 'draws' argument "
+ "is not None: The draws argument will be ignored!"
)
return self.model.predict(
idata_copy, kind, data, False, include_group_specific
idata, kind, data, inplace, include_group_specific
)

Collaborator:
Add an else clause here to throw an error whenever other values are specified?

return self.model.predict(idata, kind, data, inplace, include_group_specific)
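
For reference, a usage sketch of the reworked signature; `model` is assumed to be an already-fitted HSSM instance:

    # Use the first 100 posterior draws per chain; safe_mode additionally
    # chunks the computation into groups of 10 draws to bound memory use.
    pps = model.sample_posterior_predictive(draws=100, safe_mode=True, inplace=False)

    # Explicit draw indices are also accepted:
    pps = model.sample_posterior_predictive(draws=[0, 10, 20, 30], inplace=False)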

def plot_posterior_predictive(self, **kwargs) -> mpl.axes.Axes | sns.FacetGrid:
"""Produce a posterior predictive plot.

Expand Down Expand Up @@ -678,7 +788,24 @@ def sample_prior_predictive(
``InferenceData`` object with the groups ``prior``, ``prior_predictive`` and
``observed_data``.
"""
return self.model.prior_predictive(draws, var_names, omit_offsets, random_seed)
prior_predictive = self.model.prior_predictive(
draws, var_names, omit_offsets, random_seed
)

prior_predictive.add_groups(posterior=prior_predictive.prior)
self.model.predict(prior_predictive, kind="mean", inplace=True)

# clean up: move the temporary 'posterior' group back to 'prior'
setattr(prior_predictive, "prior", prior_predictive["posterior"])
del prior_predictive["posterior"]

if self._inference_obj is None:
self._inference_obj = prior_predictive
else:
self._inference_obj.extend(prior_predictive)

# rename `rt,response_mean` to the parent parameter (e.g. `v`)
return self._drop_parent_str_from_idata(idata=self._inference_obj)
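
A short sketch of the updated prior-predictive flow, continuing the hypothetical model from above:

    # Prior predictive draws now also populate the parent parameter in the
    # 'prior' group, and the result is merged into the model's inference object.
    prior_idata = model.sample_prior_predictive(draws=200)
    assert "v" in prior_idata.prior.data_vars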

@property
def pymc_model(self) -> pm.Model:
@@ -1353,6 +1480,35 @@ def _get_deterministic_var_names(self, idata) -> list[str]:

return var_names

def _drop_parent_str_from_idata(
self, idata: az.InferenceData | None
) -> az.InferenceData:
"""Drop the parent_str variable from an InferenceData object.

Parameters
----------
idata
The InferenceData object to be modified.

Returns
-------
az.InferenceData
The modified InferenceData object.
"""
if idata is None:
raise ValueError("Please provide an InferenceData object.")
else:
for group in idata.groups():
if ("rt,response_mean" in idata[group].data_vars) and (
self._parent not in idata[group].data_vars
):
setattr(
idata,
group,
idata[group].rename({"rt,response_mean": self._parent}),
)
return idata

def _handle_missing_data_and_deadline(self):
"""Handle missing data and deadline."""
if not self.missing_data and not self.deadline:
2 changes: 1 addition & 1 deletion src/hssm/plotting/utils.py
@@ -386,7 +386,7 @@ def _use_traces_or_sample(
idata=idata,
data=data,
inplace=True,
n_samples=n_samples,
draws=n_samples,
)
idata = model.traces
sampled = True
5 changes: 5 additions & 0 deletions src/hssm/utils.py
@@ -452,3 +452,8 @@ def _rearrange_data(data: pd.DataFrame | np.ndarray) -> pd.DataFrame | np.ndarray
split_not_missing = data[~missing_indices, :]

return np.concatenate([split_missing, split_not_missing])


def _split_array(data: np.ndarray | list[int], divisor: int) -> list[np.ndarray]:
    """Split `data` into roughly equal chunks of at most `divisor` elements."""
    # Ceiling division: add one extra split for any remainder.
    num_splits = len(data) // divisor + (1 if len(data) % divisor != 0 else 0)
    return [tmp.astype(int) for tmp in np.array_split(data, num_splits)]
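
A quick illustration of the chunking behavior (values are illustrative):

    import numpy as np

    draws = np.arange(25)
    chunks = _split_array(draws, divisor=10)
    # 25 // 10 + 1 = 3 splits; np.array_split balances them as [9, 8, 8],
    # keeping every chunk at or under the divisor.
    print([len(c) for c in chunks])  # [9, 8, 8]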