414 long running time of sample posterior predictive and eventual death by oom #436

Conversation

AlexanderFengler
Collaborator

  • Posterior predictive sampling now has a safe_mode that chunks the computation.
  • The n_samples argument was renamed to draws and accepts None | int | list | np.ndarray.
  • When running posterior predictive with kind='mean', the posterior naming is cleaned up: rt,response_mean --> v.
  • Prior predictives are now assigned to .traces, and their naming is cleaned up as well.
  • Our sample_prior_predictive() now also includes the parent parameter, via an internal call to .predict().
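
To make the None | int | list | np.ndarray options for draws concrete, here is a minimal sketch of how such an argument could be normalized into a list of draw indices. The helper name resolve_draws, its seed parameter, and the choice to treat an int as "sub-sample this many draws at random" are assumptions for illustration, not the behavior confirmed by this PR.

```python
from __future__ import annotations

import random
from collections.abc import Iterable


def resolve_draws(
    draws: int | Iterable[int] | None, n_available: int, seed: int | None = None
) -> list[int]:
    """Turn a `draws` argument into a list of draw indices (hypothetical helper)."""
    if draws is None:
        # Assumption: None means "use every available draw".
        return list(range(n_available))
    if isinstance(draws, int):
        if not 0 < draws <= n_available:
            raise ValueError(f"draws must be between 1 and {n_available}")
        # Assumption: an int means "pick this many draws without replacement".
        rng = random.Random(seed)
        return sorted(rng.sample(range(n_available), draws))
    # Lists and array-likes are taken as explicit draw indices, in the given order.
    indices = [int(i) for i in draws]
    if any(i < 0 or i >= n_available for i in indices):
        raise ValueError("draw index out of range")
    return indices
```

Normalizing once at the top of the function keeps the later chunking code free of type checks.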

@AlexanderFengler
Collaborator Author

I will try to add a few more tests to this before merging.

Collaborator

@digicosmos86 digicosmos86 left a comment


Looks good! Two higher-level comments:

  1. Since the only call to simulator is done here:

    sim_out = simulator(
    theta=theta,
    model=model_name,
    n_samples=n_samples,
    random_state=seed,
    **kwargs,
    )
     maybe we can use a for-loop over n_samples here to make the sampling safe, instead of patching the higher-level functions themselves? This way we can avoid running a lot of intermediate-level code multiple times.

  2. The InferenceData object does not come with attributes like posterior or posterior_predictive by default, so the type checker complains. Square-bracket notation is preferred. If that is too annoying, we can disable this check (attr-defined) globally in the mypy section of pyproject.toml, but that can be a bit risky.
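
The for-loop idea from the first comment could look roughly like the sketch below. It uses a toy stand-in rather than the real simulator, and the names simulate_in_chunks, toy_simulator, and chunk_size are hypothetical.

```python
from typing import Callable


def simulate_in_chunks(
    simulate: Callable[[int], list[float]], n_samples: int, chunk_size: int = 10
) -> list[float]:
    """Call `simulate` repeatedly, asking for at most `chunk_size` samples per call."""
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")
    results: list[float] = []
    for start in range(0, n_samples, chunk_size):
        # The last chunk may be smaller than chunk_size.
        n_this_call = min(chunk_size, n_samples - start)
        results.extend(simulate(n_this_call))
    return results


calls: list[int] = []


def toy_simulator(n: int) -> list[float]:
    """Stand-in for the real simulator: records the request size, returns n values."""
    calls.append(n)
    return [0.5] * n


out = simulate_in_chunks(toy_simulator, n_samples=25, chunk_size=10)
```

Because only the innermost call is chunked, the surrounding bookkeeping runs once instead of once per chunk, which is the saving the comment points to.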

src/hssm/hssm.py Outdated

if "posterior_predictive" in idata.groups():
    del idata.posterior_predictive
    print("pre-existing posterior_predictive group deleted from idata. \n")
Collaborator

This should be a warning

Collaborator Author

done
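
A minimal sketch of the requested change, using a plain dict as a stand-in for the InferenceData groups. Whether the merged code uses warnings.warn or a logger is not shown in this thread, so the choice of warnings.warn here is an assumption, as is the helper name drop_group.

```python
import warnings


def drop_group(groups: dict, name: str) -> None:
    """Delete `name` from `groups` and warn about it instead of printing."""
    if name in groups:
        del groups[name]
        # stacklevel=2 points the warning at the caller, not at this helper.
        warnings.warn(
            f"pre-existing {name} group deleted from idata.",
            UserWarning,
            stacklevel=2,
        )
```

Unlike print, a warning can be filtered, turned into an error in tests, or captured by the caller.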

Collaborator

@digicosmos86 digicosmos86 left a comment


Looks awesome! Just some style suggestions at this point. Feel free to merge after the fixes :)

src/hssm/hssm.py Outdated
Comment on lines 614 to 615
if "posterior_predictive" in idata.groups():
    if idata is not None:
Collaborator

Should the order be the other way around?

Collaborator Author

@digicosmos86 changed. This was useless to begin with, just an artifact appeasing mypy...
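
The ordering concern can be illustrated with a small sketch: the None check has to come first, otherwise .groups() is called on None. FakeIData and has_group are hypothetical stand-ins, not part of the PR.

```python
class FakeIData:
    """Tiny stand-in exposing only a groups() method."""

    def __init__(self, names):
        self._names = list(names)

    def groups(self):
        return self._names


def has_group(idata, name):
    # `and` short-circuits, so .groups() is never reached when idata is None.
    return idata is not None and name in idata.groups()
```

With the checks in the original order, passing None would raise an AttributeError before the None test is ever evaluated.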

src/hssm/hssm.py Outdated
@@ -10,7 +10,7 @@
 from copy import deepcopy
 from inspect import isclass
 from os import PathLike
-from typing import Any, Callable, Literal
+from typing import Any, Callable, Literal, Union
Collaborator

We don't use Union any more. Now that we have Python 3.10, we use the | operator instead

Collaborator Author

changed

@@ -404,6 +404,7 @@ def __init__(
     self.model, self._parent_param, self.response_c, self.response_str
 )
 self.set_alias(self._aliases)
+# _logger.info(self.pymc_model.initial_point())
Collaborator

Should we remove debug comments?

Collaborator Author

Stylistically eventually yes, but rn, I think it can sometimes still help future PRs that interact with this code. Here I literally have the next PR that I need to work on in mind. So in general agree, but let's skip here :)


# The parent was previously not part of deterministics --> compute it via
# posterior_predictive (works because it acts as the 'mu' parameter
# in the GLM as far as bambi is concerned)
if self._inference_obj is not None:
    if self._parent not in self._inference_obj.posterior.data_vars.keys():
        self.model.predict(self._inference_obj, kind="mean", inplace=True)
        # self.model.predict(self._inference_obj, kind="mean", inplace=True)
Collaborator

Should we remove debug comments?

src/hssm/hssm.py Outdated
Comment on lines 543 to 544
self._parent in self._inference_obj.posterior.data_vars.keys()
and "rt,response_mean" in self._inference_obj.posterior.data_vars.keys()
Collaborator

data_vars is dict-like, so the Python 3 style is to test membership directly instead of calling keys()

Collaborator Author

changed
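
A plain-dict sketch of the same point. The dict below only imitates data_vars, and the variable names are illustrative.

```python
# Stand-in for posterior.data_vars: any mapping supports `in` on its keys.
data_vars = {"v": [0.1, 0.2], "rt,response_mean": [0.3, 0.4]}
parent = "v"

# Membership is tested against the keys directly; .keys() adds nothing.
both_present = parent in data_vars and "rt,response_mean" in data_vars
```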

    and not np.allclose(draws, idata["posterior"].draw.values)
):
    # Reassign posterior to sub-sampled version
    setattr(idata_copy, "posterior", idata["posterior"].isel(draw=draws))
Collaborator

Are there any differences between setattr() and idata.add_groups()?

Collaborator Author

to be honest I don't know... let me look into that independently to understand it properly.

Collaborator Author

Actually, at least as used here, the difference is semantic: add_groups is for adding new groups, while setattr reassigns a pre-existing group.

src/hssm/hssm.py Outdated
Comment on lines 660 to 717
if safe_mode:
    # safe mode splits the draws into chunks of 10 to avoid
    # memory issues (TODO: Figure out the source of memory issues)
    split_draws = _split_array(
        idata_copy["posterior"].draw.values, divisor=10
    )

    posterior_predictive_list = []
    for samples_tmp in split_draws:
        tmp_posterior = idata["posterior"].sel(draw=samples_tmp)
        setattr(idata_copy, "posterior", tmp_posterior)
        self.model.predict(
            idata_copy, kind, data, True, include_group_specific
        )
        posterior_predictive_list.append(idata_copy["posterior_predictive"])

    if inplace:
        idata.add_groups(
            posterior_predictive=xr.concat(
                posterior_predictive_list, dim="draw"
            )
        )
        # for inplace, we don't return anything
        return None
    else:
        # Reassign original posterior to idata_copy
        setattr(idata_copy, "posterior", idata["posterior"])
        # Add new posterior predictive group to idata_copy
        del idata_copy["posterior_predictive"]
        idata_copy.add_groups(
            posterior_predictive=xr.concat(
                posterior_predictive_list, dim="draw"
            )
        )
        return idata_copy
elif inplace:
    # If not safe-mode
    # We call .predict() directly without any
    # chunking of data.

    # .predict() is called on the copy of idata
    # since we still subsampled (or assigned) the draws
    self.model.predict(idata_copy, kind, data, True, include_group_specific)

    # posterior predictive group added to idata
    idata.add_groups(
        posterior_predictive=idata_copy["posterior_predictive"]
    )

    # don't return anything if inplace
    return None

else:
    # Not safe mode and not inplace
    # Function acts as very thin wrapper around
    # .predict(). It just operates on the
    # idata_copy object
    return self.model.predict(
        idata_copy, kind, data, inplace, include_group_specific
    )
Collaborator

This if block looks slightly confusing. I think I understand what you mean, but would

if safe_mode:
    if inplace:
        ...
    else:
        ...
else:
    if inplace:
         ...
    else:
         ...

be more readable?

Collaborator Author

changed
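
A runnable toy version of the nested layout suggested above, with plain lists and dicts standing in for the posterior and the InferenceData object. The names split_into_chunks and toy_predict, and the chunk size of 10, are assumptions for illustration. This is a sketch of the control flow only, not the merged implementation.

```python
def split_into_chunks(values: list, chunk_size: int = 10) -> list[list]:
    """Split `values` into consecutive chunks of at most `chunk_size` items."""
    return [values[i : i + chunk_size] for i in range(0, len(values), chunk_size)]


def toy_predict(draws: list[int]) -> list[int]:
    """Stand-in for model.predict(): one prediction per posterior draw."""
    return [2 * d for d in draws]


def sample_posterior_predictive(idata: dict, safe_mode: bool = True, inplace: bool = True):
    draws = idata["posterior"]
    if safe_mode:
        # Predict chunk by chunk, then stitch the pieces back together.
        predictions: list[int] = []
        for chunk in split_into_chunks(draws):
            predictions.extend(toy_predict(chunk))
        if inplace:
            idata["posterior_predictive"] = predictions
            return None
        else:
            return {**idata, "posterior_predictive": predictions}
    else:
        # No chunking: a single call over all draws.
        predictions = toy_predict(draws)
        if inplace:
            idata["posterior_predictive"] = predictions
            return None
        else:
            return {**idata, "posterior_predictive": predictions}
```

Branching on safe_mode first and inplace second makes all four combinations visible at a glance, at the cost of some repetition in the inner branches.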

return self.model.predict(
idata_copy, kind, data, False, include_group_specific
idata, kind, data, inplace, include_group_specific
)

Collaborator

Add an else clause here to throw an error whenever other values are specified?

src/hssm/hssm.py Outdated
@@ -1353,6 +1477,35 @@ def _get_deterministic_var_names(self, idata) -> list[str]:

         return var_names

+    def _drop_parent_str_from_idata(
+        self, idata: Union[az.InferenceData, None]
Collaborator

Suggested change
- self, idata: Union[az.InferenceData, None]
+ self, idata: az.InferenceData | None

Collaborator Author

fixed that

@AlexanderFengler AlexanderFengler merged commit 80c5248 into main May 23, 2024
2 checks passed
@digicosmos86 digicosmos86 deleted the 414-long-running-time-of-sample_posterior_predictive-and-eventual-death-by-oom branch November 28, 2024 17:51
Development

Successfully merging this pull request may close these issues.

Long running time of sample_posterior_predictive() and eventual death by OOM
3 participants