Commit

Merge pull request #436 from lnccbrown/414-long-running-time-of-sample_posterior_predictive-and-eventual-death-by-oom

414: Long running time of sample_posterior_predictive and eventual death by OOM
AlexanderFengler authored May 23, 2024
2 parents 32ed463 + c965923 commit 80c5248
Showing 4 changed files with 256 additions and 44 deletions.
212 changes: 184 additions & 28 deletions src/hssm/hssm.py
@@ -49,8 +49,8 @@
_get_alias_dict,
_print_prior,
_process_param_in_kwargs,
_random_sample,
_rearrange_data,
_split_array,
)

from . import plotting
@@ -404,6 +404,7 @@ def __init__(
self.model, self._parent_param, self.response_c, self.response_str
)
self.set_alias(self._aliases)
# _logger.info(self.pymc_model.initial_point())
self._postprocess_initvals_deterministic(initval_settings=INITVAL_SETTINGS)
self._jitter_initvals(
jitter_epsilon=INITVAL_JITTER_SETTINGS["jitter_epsilon"], vector_only=True
@@ -510,24 +511,41 @@ def sample(
+ "The jitter argument will be ignored."
)
del kwargs["jitter"]
else:
pass

self._inference_obj = self.model.fit(
inference_method=sampler, init=init, **kwargs
)
if "include_mean" not in kwargs:
# If not specified, include the mean prediction in
# kwargs to be passed to the model.fit() method
kwargs["include_mean"] = True
idata = self.model.fit(inference_method=sampler, init=init, **kwargs)

if self._inference_obj is None:
self._inference_obj = idata
elif isinstance(self._inference_obj, az.InferenceData):
self._inference_obj.extend(idata)
else:
raise ValueError(
"The model has an attached inference object under"
+ " self._inference_obj, but it is not an InferenceData object."
)

# The parent was previously not part of deterministics --> compute it via
# posterior_predictive (works because it acts as the 'mu' parameter
# in the GLM as far as bambi is concerned)
if self._inference_obj is not None:
if self._parent not in self._inference_obj.posterior.data_vars.keys():
self.model.predict(self._inference_obj, kind="mean", inplace=True)
# self.model.predict(self._inference_obj, kind="mean", inplace=True)
# rename 'rt,response_mean' to 'v' so in the traces everything
# looks the way it should
self._inference_obj.rename_vars(
{"rt,response_mean": self._parent}, inplace=True
)
elif (
self._parent in self._inference_obj.posterior.data_vars
and "rt,response_mean" in self._inference_obj.posterior.data_vars
):
# drop redundant 'rt,response_mean' variable,
# if parent already in posterior
del self._inference_obj.posterior["rt,response_mean"]
return self.traces

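For context, a minimal usage sketch of the updated sample() flow. This is illustrative only: the toy data, the "ddm" model choice, and the parent name "v" are assumptions, not part of this diff.

    import numpy as np
    import pandas as pd
    import hssm

    # Toy data with the "rt" and "response" columns HSSM expects.
    rng = np.random.default_rng(0)
    data = pd.DataFrame({
        "rt": rng.uniform(0.3, 2.0, size=200),
        "response": rng.choice([-1.0, 1.0], size=200),
    })

    model = hssm.HSSM(data=data, model="ddm")
    idata = model.sample(draws=500, chains=2)

    # The parent parameter ("v" for the DDM) should now appear in the posterior
    # directly; the intermediate "rt,response_mean" variable is renamed or dropped.
    assert "v" in idata.posterior.data_vars
    assert "rt,response_mean" not in idata.posterior.data_vars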
def sample_posterior_predictive(
@@ -537,7 +555,8 @@ def sample_posterior_predictive(
inplace: bool = True,
include_group_specific: bool = True,
kind: Literal["pps", "mean"] = "pps",
n_samples: int | float | None = None,
draws: int | float | list[int] | np.ndarray | None = None,
safe_mode: bool = True,
) -> az.InferenceData | None:
"""Perform posterior predictive sampling from the HSSM model.
@@ -563,7 +582,7 @@ def sample_posterior_predictive(
latter returns the draws from the posterior predictive distribution
(i.e. the posterior probability distribution for a new observation).
Defaults to `"pps"`.
n_samples: optional
draws: optional
The number of samples to draw from the posterior predictive distribution
from each chain.
When it's an integer >= 1, the number of samples to be extracted from the
@@ -574,6 +593,9 @@
posterior predictive sampling. If this proportion is very
small, at least one sample will be used. When None, all posterior samples
will be used. Defaults to None.
safe_mode: bool
If True, the function will split the draws into chunks of 10 to avoid memory
issues. Defaults to True.
Raises
------
@@ -592,35 +614,123 @@
+ "Please either provide an idata object or sample the model first."
)
idata = self._inference_obj
_logger.info(
"idata=None, we use the traces assigned to the HSSM object as idata."
)

if idata is not None:
if "posterior_predictive" in idata.groups():
del idata["posterior_predictive"]
_logger.warning(
"pre-existing posterior_predictive group deleted from idata. \n"
)

if self._check_extra_fields(data):
self._update_extra_fields(data)

if n_samples is not None:
# Make a copy of idata, set the `posterior` group to be a random sub-sample
# of the original (draw dimension gets sub-sampled)
idata_copy = idata.copy()
idata_random_sample = _random_sample(
idata_copy["posterior"], n_samples=n_samples
if isinstance(draws, np.ndarray):
draws = draws.astype(int)
elif isinstance(draws, list):
draws = np.array(draws).astype(int)
elif isinstance(draws, int | float):
draws = np.arange(int(draws))
elif draws is None:
draws = idata["posterior"].draw.values
else:
raise ValueError(
"draws must be an integer, " + "a list of integers, or a numpy array."
)
delattr(idata_copy, "posterior")
idata_copy.add_groups(posterior=idata_random_sample)

# If the user specifies an inplace operation, we need to modify the original
if inplace:
self.model.predict(idata_copy, kind, data, True, include_group_specific)
idata.add_groups(
posterior_predictive=idata_copy["posterior_predictive"]

assert isinstance(draws, np.ndarray)

# Make a copy of idata, set the `posterior` group to be a random sub-sample
# of the original (draw dimension gets sub-sampled)

idata_copy = idata.copy()

if (draws.shape != idata["posterior"].draw.values.shape) or (
(draws.shape == idata["posterior"].draw.values.shape)
and not np.allclose(draws, idata["posterior"].draw.values)
):
# Reassign posterior to sub-sampled version
setattr(idata_copy, "posterior", idata["posterior"].isel(draw=draws))

if kind == "pps":
# If we run kind == 'pps' we actually run the observation RV
if safe_mode:
# safe mode splits the draws into chunks of 10 to avoid
# memory issues (TODO: Figure out the source of memory issues)
split_draws = _split_array(
idata_copy["posterior"].draw.values, divisor=10
)

return None
posterior_predictive_list = []
for samples_tmp in split_draws:
tmp_posterior = idata["posterior"].sel(draw=samples_tmp)
setattr(idata_copy, "posterior", tmp_posterior)
self.model.predict(
idata_copy, kind, data, True, include_group_specific
)
posterior_predictive_list.append(idata_copy["posterior_predictive"])

if inplace:
idata.add_groups(
posterior_predictive=xr.concat(
posterior_predictive_list, dim="draw"
)
)
# for inplace, we don't return anything
return None
else:
# Reassign original posterior to idata_copy
setattr(idata_copy, "posterior", idata["posterior"])
# Add new posterior predictive group to idata_copy
del idata_copy["posterior_predictive"]
idata_copy.add_groups(
posterior_predictive=xr.concat(
posterior_predictive_list, dim="draw"
)
)
return idata_copy
else:
if inplace:
# If not safe-mode
# We call .predict() directly without any
# chunking of data.

# .predict() is called on the copy of idata
# since we still subsampled (or assigned) the draws
self.model.predict(
idata_copy, kind, data, True, include_group_specific
)

# posterior predictive group added to idata
idata.add_groups(
posterior_predictive=idata_copy["posterior_predictive"]
)
# don't return anything if inplace
return None
else:
# Not safe mode and not inplace
# Function acts as very thin wrapper around
# .predict(). It just operates on the
# idata_copy object
return self.model.predict(
idata_copy, kind, data, False, include_group_specific
)
elif kind == "mean":
# If kind == 'mean', we don't need to run the RV directly,
# there shouldn't really be any significant memory issues here,
# we can simply ignore settings, since the computational overhead
# should be very small --> nudges user towards good outputs.
_logger.warning(
"The kind argument is set to 'mean', but 'draws' argument "
+ "is not None: The draws argument will be ignored!"
)
return self.model.predict(
idata_copy, kind, data, False, include_group_specific
idata, kind, data, inplace, include_group_specific
)

return self.model.predict(idata, kind, data, inplace, include_group_specific)

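Continuing the sketch above, the new draws and safe_mode arguments might be used as follows (names carried over from the earlier illustrative example):

    import numpy as np

    # Posterior predictive for the first 100 draws of each chain (an integer is
    # turned into np.arange(100) above, i.e. the first 100 draw indices). With
    # safe_mode=True, the draws are processed in roughly 10-sized chunks to
    # bound memory use.
    pps = model.sample_posterior_predictive(draws=100, safe_mode=True, inplace=False)

    # Explicit draw indices also work (every 5th draw here):
    thinned = model.sample_posterior_predictive(
        draws=np.arange(0, 500, 5), inplace=False
    )

    # With inplace=True (the default), the samples are added to model.traces under
    # the "posterior_predictive" group and nothing is returned.
    model.sample_posterior_predictive(draws=100)
    assert "posterior_predictive" in model.traces.groups()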
def plot_posterior_predictive(self, **kwargs) -> mpl.axes.Axes | sns.FacetGrid:
"""Produce a posterior predictive plot.
@@ -678,7 +788,24 @@ def sample_prior_predictive(
``InferenceData`` object with the groups ``prior``, ``prior_predictive`` and
``observed_data``.
"""
return self.model.prior_predictive(draws, var_names, omit_offsets, random_seed)
prior_predictive = self.model.prior_predictive(
draws, var_names, omit_offsets, random_seed
)

prior_predictive.add_groups(posterior=prior_predictive.prior)
self.model.predict(prior_predictive, kind="mean", inplace=True)

# clean
setattr(prior_predictive, "prior", prior_predictive["posterior"])
del prior_predictive["posterior"]

if self._inference_obj is None:
self._inference_obj = prior_predictive
else:
self._inference_obj.extend(prior_predictive)

# clean up `rt,response_mean` to `v`
return self._drop_parent_str_from_idata(idata=self._inference_obj)

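A matching sketch for the prior-predictive path, continuing the same illustrative model (the "v" variable name assumes a DDM parent):

    # Prior-predictive sampling now also exposes the parent parameter in the
    # returned idata, alongside the "prior_predictive" group.
    prior_idata = model.sample_prior_predictive(draws=200)
    assert "prior_predictive" in prior_idata.groups()
    assert "v" in prior_idata.prior.data_vars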
@property
def pymc_model(self) -> pm.Model:
@@ -1353,6 +1480,35 @@ def _get_deterministic_var_names(self, idata) -> list[str]:

return var_names

def _drop_parent_str_from_idata(
self, idata: az.InferenceData | None
) -> az.InferenceData:
"""Drop the parent_str variable from an InferenceData object.
Parameters
----------
idata
The InferenceData object to be modified.
Returns
-------
az.InferenceData
The modified InferenceData object.
"""
if idata is None:
raise ValueError("Please provide an InferenceData object.")
else:
for group in idata.groups():
if ("rt,response_mean" in idata[group].data_vars) and (
self._parent not in idata[group].data_vars
):
setattr(
idata,
group,
idata[group].rename({"rt,response_mean": self._parent}),
)
return idata

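A minimal illustration of the rename pattern this helper relies on, using a bare xarray Dataset (group handling on a real InferenceData object works the same way, applied per group):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"rt,response_mean": ("draw", np.zeros(4))})
    ds = ds.rename({"rt,response_mean": "v"})  # same rename, applied per idata group
    print(list(ds.data_vars))  # -> ['v']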
def _handle_missing_data_and_deadline(self):
"""Handle missing data and deadline."""
if not self.missing_data and not self.deadline:
2 changes: 1 addition & 1 deletion src/hssm/plotting/utils.py
@@ -386,7 +386,7 @@ def _use_traces_or_sample(
idata=idata,
data=data,
inplace=True,
n_samples=n_samples,
draws=n_samples,
)
idata = model.traces
sampled = True
5 changes: 5 additions & 0 deletions src/hssm/utils.py
@@ -452,3 +452,8 @@ def _rearrange_data(data: pd.DataFrame | np.ndarray) -> pd.DataFrame | np.ndarray:
split_not_missing = data[~missing_indices, :]

return np.concatenate([split_missing, split_not_missing])


def _split_array(data: np.ndarray | list[int], divisor: int) -> list[np.ndarray]:
num_splits = len(data) // divisor + (1 if len(data) % divisor != 0 else 0)
return [tmp.astype(int) for tmp in np.array_split(data, num_splits)]
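For reference, what _split_array does to a set of draw indices; divisor=10 matches the safe_mode chunking in sample_posterior_predictive. The snippet simply restates the helper above so it runs standalone:

    import numpy as np

    def _split_array(data, divisor):
        # Same logic as the helper added above.
        num_splits = len(data) // divisor + (1 if len(data) % divisor != 0 else 0)
        return [tmp.astype(int) for tmp in np.array_split(data, num_splits)]

    chunks = _split_array(np.arange(25), divisor=10)
    print([len(c) for c in chunks])  # -> [9, 8, 8]; np.array_split balances the chunks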
