414 long running time of sample posterior predictive and eventual death by oom #436
…-hddm Merging main.
…d to _mean prediction consistently
I will try to add a few more tests to this before merging.
Looks good! Two higher-level comments:
- Since the only call to `simulator` is done here:
  HSSM/src/hssm/distribution_utils/dist.py, lines 309 to 315 in 19f786d

      sim_out = simulator(
          theta=theta,
          model=model_name,
          n_samples=n_samples,
          random_state=seed,
          **kwargs,
      )

- The `InferenceData` object does not come with attributes like `posterior` or `posterior_predictive` by default, so the type checker complains. The square bracket notation is preferred. Or, if this is too annoying, we can disable this check (`attr-defined`) globally in the `mypy` section of `pyproject.toml`, but that can be a bit risky.
src/hssm/hssm.py
Outdated
    if "posterior_predictive" in idata.groups():
        del idata.posterior_predictive
        print("pre-existing posterior_predictive group deleted from idata. \n")
This should be a warning
done
Looks awesome! Just some style suggestions at this point. Feel free to merge after the fixes :)
src/hssm/hssm.py
Outdated
    if "posterior_predictive" in idata.groups():
        if idata is not None:
Should the order be the other way around?
@digicosmos86 changed. This was useless to begin with, just an artifact appeasing mypy...
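For reference, the ordering question comes down to short-circuit evaluation: the `None` check only protects the `.groups()` call if it runs first. The sketch below uses a hypothetical `FakeIData` class and helper name, not the actual HSSM code.

```python
class FakeIData:
    """Hypothetical stand-in that exposes only a groups() method."""

    def __init__(self, groups):
        self._groups = list(groups)

    def groups(self):
        return self._groups


def has_posterior_predictive(idata) -> bool:
    # `and` short-circuits: when idata is None, .groups() is never called,
    # so no AttributeError can be raised.
    return idata is not None and "posterior_predictive" in idata.groups()
```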
src/hssm/hssm.py
Outdated
    @@ -10,7 +10,7 @@
     from copy import deepcopy
     from inspect import isclass
     from os import PathLike
    -from typing import Any, Callable, Literal
    +from typing import Any, Callable, Literal, Union
We don't use `Union` any more. Now that we have Python 3.10, we use the `|` operator instead.
changed
    @@ -404,6 +404,7 @@ def __init__(
                 self.model, self._parent_param, self.response_c, self.response_str
             )
             self.set_alias(self._aliases)
    +        # _logger.info(self.pymc_model.initial_point())
Should we remove debug comments?
Stylistically, eventually yes, but right now I think it can still help future PRs that interact with this code. Here I literally have the next PR that I need to work on in mind. So in general I agree, but let's skip it here :)
    # The parent was previously not part of deterministics --> compute it via
    # posterior_predictive (works because it acts as the 'mu' parameter
    # in the GLM as far as bambi is concerned)
    if self._inference_obj is not None:
        if self._parent not in self._inference_obj.posterior.data_vars.keys():
            self.model.predict(self._inference_obj, kind="mean", inplace=True)
            # self.model.predict(self._inference_obj, kind="mean", inplace=True)
Should we remove debug comments?
src/hssm/hssm.py
Outdated
    self._parent in self._inference_obj.posterior.data_vars.keys()
    and "rt,response_mean" in self._inference_obj.posterior.data_vars.keys()
`data_vars` is dict-like (a mapping), so the Python 3 style is to not use `keys()` for membership tests.
changed
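To illustrate the point about `keys()`, here is a sketch with a plain dict standing in for `posterior.data_vars` (the variable values are made up). Membership tests on a mapping check its keys directly, so the two forms give the same result.

```python
# Plain dict as a hypothetical stand-in for posterior.data_vars.
data_vars = {"v": [0.5], "a": [1.2], "rt,response_mean": [0.5]}
parent = "v"

# Verbose form flagged in the review.
verbose = parent in data_vars.keys() and "rt,response_mean" in data_vars.keys()

# Idiomatic form: `in` on a mapping already tests the keys.
idiomatic = parent in data_vars and "rt,response_mean" in data_vars
```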
        and not np.allclose(draws, idata["posterior"].draw.values)
    ):
        # Reassign posterior to sub-sampled version
        setattr(idata_copy, "posterior", idata["posterior"].isel(draw=draws))
Are there any differences between `setattr()` and `idata.add_groups()`?
To be honest, I don't know... let me look into that independently to understand it properly.
Actually, at least as used semantically here, `add_groups` is about adding new groups, while `setattr` is about reassigning a pre-existing group.
src/hssm/hssm.py
Outdated
    if safe_mode:
        # safe mode splits the draws into chunks of 10 to avoid
        # memory issues (TODO: Figure out the source of memory issues)
        split_draws = _split_array(
            idata_copy["posterior"].draw.values, divisor=10
        )

        posterior_predictive_list = []
        for samples_tmp in split_draws:
            tmp_posterior = idata["posterior"].sel(draw=samples_tmp)
            setattr(idata_copy, "posterior", tmp_posterior)
            self.model.predict(
                idata_copy, kind, data, True, include_group_specific
            )
            posterior_predictive_list.append(idata_copy["posterior_predictive"])

        if inplace:
            idata.add_groups(
                posterior_predictive=xr.concat(
                    posterior_predictive_list, dim="draw"
                )
            )
            # for inplace, we don't return anything
            return None
        else:
            # Reassign original posterior to idata_copy
            setattr(idata_copy, "posterior", idata["posterior"])
            # Add new posterior predictive group to idata_copy
            del idata_copy["posterior_predictive"]
            idata_copy.add_groups(
                posterior_predictive=xr.concat(
                    posterior_predictive_list, dim="draw"
                )
            )
            return idata_copy
    elif inplace:
        # If not safe-mode
        # We call .predict() directly without any
        # chunking of data.

        # .predict() is called on the copy of idata
        # since we still subsampled (or assigned) the draws
        self.model.predict(idata_copy, kind, data, True, include_group_specific)

        # posterior predictive group added to idata
        idata.add_groups(
            posterior_predictive=idata_copy["posterior_predictive"]
        )

        # don't return anything if inplace
        return None

    else:
        # Not safe mode and not inplace
        # Function acts as very thin wrapper around
        # .predict(). It just operates on the
        # idata_copy object
        return self.model.predict(
            idata_copy, kind, data, inplace, include_group_specific
        )
This if block looks slightly confusing. I think I understand what you mean, but would

    if safe_mode:
        if inplace:
            ...
        else:
            ...
    else:
        if inplace:
            ...
        else:
            ...

be more readable?
changed
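The chunking idea behind `safe_mode` in the hunk above can be sketched independently of HSSM. This is only an illustration under stated assumptions: `predict_in_chunks` and `toy_predict` are hypothetical names, NumPy arrays stand in for the posterior and posterior predictive groups, and `np.array_split` is used in place of the PR's `_split_array` helper, whose exact behavior is not shown here.

```python
import numpy as np


def predict_in_chunks(draws, predict_fn, n_chunks=10, safe_mode=True):
    """Apply predict_fn to draw indices, optionally chunk by chunk.

    Chunking bounds peak memory: only one chunk's intermediate results
    exist at a time, and the per-chunk outputs are concatenated at the end.
    """
    draws = np.asarray(draws)
    if not safe_mode or draws.size == 0:
        # Thin wrapper: a single call over all draws.
        return predict_fn(draws)
    # array_split tolerates sizes not divisible by n_chunks; drop empty pieces.
    chunks = [c for c in np.array_split(draws, n_chunks) if c.size > 0]
    return np.concatenate([predict_fn(c) for c in chunks], axis=0)


def toy_predict(draw_idx):
    # Toy stand-in for model prediction: three values per draw.
    return np.outer(draw_idx, np.ones(3))
```

Because each draw is processed independently in this toy setup, the chunked and unchunked paths return the same array; only the peak memory footprint differs.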
    return self.model.predict(
        idata_copy, kind, data, False, include_group_specific
        idata, kind, data, inplace, include_group_specific
    )
Add an `else` clause here to throw an error whenever other values are specified?
src/hssm/hssm.py
Outdated
    @@ -1353,6 +1477,35 @@ def _get_deterministic_var_names(self, idata) -> list[str]:

             return var_names

    +    def _drop_parent_str_from_idata(
    +        self, idata: Union[az.InferenceData, None]
    -        self, idata: Union[az.InferenceData, None]
    +        self, idata: az.InferenceData | None
fixed that
- `safe_mode` that chunks computations
- `n_samples` argument was renamed to `draws`, and one can pass `None | int | list | np.ndarray`
- `kind='mean'`, the posterior naming cleans up `rt,response_mean` --> `v`
- `.traces` now, and naming is also cleaned up
- `sample_prior_predictive()` will include the parent parameter as well now via an internal call to `.predict()`
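As a rough illustration of how an argument typed `None | int | list | np.ndarray` might be normalized into draw indices, here is a hedged sketch. The helper name `resolve_draws`, its `seed` parameter, and the chosen semantics (`None` keeps all draws, an `int` picks that many at random without replacement, a list or array is taken as explicit indices) are assumptions for this example and may not match what the PR actually does.

```python
import numpy as np


def resolve_draws(draws, n_available, seed=None):
    """Turn a flexible `draws` argument into an array of draw indices."""
    all_draws = np.arange(n_available)
    if draws is None:
        # Assumed meaning: use every available draw.
        return all_draws
    if isinstance(draws, (int, np.integer)):
        if not 0 < draws <= n_available:
            raise ValueError("draws must be between 1 and n_available")
        # Assumed meaning: random subsample without replacement.
        rng = np.random.default_rng(seed)
        return np.sort(rng.choice(all_draws, size=draws, replace=False))
    if isinstance(draws, (list, np.ndarray)):
        idx = np.asarray(draws, dtype=int)
        if idx.size and (idx.min() < 0 or idx.max() >= n_available):
            raise ValueError("draw index out of range")
        return idx
    # Explicit failure for anything else, rather than silently ignoring it.
    raise TypeError(f"unsupported type for draws: {type(draws).__name__}")
```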