[WIP] Reduce memory usage in log_likelihood io_pymc3 #1082
Conversation
LGTM, it would be great to run some benchmarks and also add some tests of `log_likelihood` argument behaviour (better to wait for #1080 before adding tests though, otherwise you will probably have merge conflicts).
These are the benchmarks I ran using memory_profiler on this code. The first four graphs show memory usage over time for samples with different numbers of chains. The last graph summarizes the maximum memory used for each chain count and the memory difference between the preallocated and non-preallocated code. I'll add the tests soon. Thanks to @OriolAbril for all the help!
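The benchmarks above were produced with memory_profiler. For a quick self-contained check, the stdlib `tracemalloc` module gives a comparable peak-allocation number; the helper and the workload below are illustrative stand-ins, not the actual benchmark code used for the graphs.

```python
import tracemalloc


def peak_memory_mib(func, *args, **kwargs):
    """Run func and report the peak Python-level allocation in MiB."""
    tracemalloc.start()
    try:
        result = func(*args, **kwargs)
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return result, peak / 2**20


def build_fake_log_likelihood(chains, draws):
    # Stand-in workload (hypothetical sizes): one list of floats per chain,
    # mimicking the list-building that the PR replaces with preallocation.
    return [[float(i) for i in range(draws)] for _ in range(chains)]


result, peak = peak_memory_mib(build_fake_log_likelihood, 4, 10_000)
```

Unlike memory_profiler, `tracemalloc` only sees Python-level allocations, so it will undercount memory held by C extensions such as Theano, but it is enough to compare the two accumulation strategies.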
This looks great, thanks! It would also be great if you could review @rpgoldman given that it uses

2 things left to do: Could you add some tests to check

On second thought, using pymc3's
Yes, I'm working on testing exactly these cases. Thanks!
Then would the previous one, with the code defined in utils, be preferable?
I'll try to check in the next day or so. Busy today!
Minor nit: it's
arviz/tests/test_data_pymc.py (Outdated)

    test_dict = {
        "posterior": ["x"],
        "observed_data": ["y1", "y2"],
        "log_likelihood": ["y1", "y2"],
        "sample_stats": ["diverging", "lp"],
    }
    fails = check_multiple_attrs(test_dict, inference_data)
    assert not fails
    if log_likelihood is True:
It'd be better to modify the test_dict, something like:
if not log_likelihood:
test_dict.pop("log_likelihood")
test_dict["~log_likelihood"] = []
if isinstance(log_likelihood, list):
test_dict["log_likelihood"] = ["y1", "~y2"]
Also, make sure to get your branch up to date with master!
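`check_multiple_attrs` is an ArviZ test helper; a hypothetical minimal version illustrating the `~`-prefix negation convention that the suggested `test_dict` edits rely on might look like this (this is a sketch, not ArviZ's actual implementation):

```python
def check_attrs(test_dict, groups):
    """Return a list of failed checks (empty list means all checks passed).

    A "~"-prefixed group key means the group must be absent; a "~"-prefixed
    variable name means that variable must be absent from the group.
    `groups` maps group names to their variable names.
    """
    fails = []
    for group, var_names in test_dict.items():
        if group.startswith("~"):
            # Negated group: fail if it is present at all.
            if group[1:] in groups:
                fails.append(group)
            continue
        if group not in groups:
            fails.append(group)
            continue
        present = groups[group]
        for var in var_names:
            if var.startswith("~"):
                # Negated variable: fail if it is present in the group.
                if var[1:] in present:
                    fails.append(f"{group}.{var}")
            elif var not in present:
                fails.append(f"{group}.{var}")
    return fails
```

With this convention, `test_dict["~log_likelihood"] = []` asserts the whole group is missing, and `["y1", "~y2"]` asserts `y1` is present while `y2` is not.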
Force-pushed from 3377b66 to 5f74099 (compare)
@@ -215,12 +216,19 @@ def test_multiple_observed_rv(self):
         pm.Normal("y2", x, 1, observed=y2_data)
         trace = pm.sample(100, chains=2)
-        inference_data = from_pymc3(trace=trace)
+        inference_data = from_pymc3(trace=trace, log_likelihood=log_likelihood)
It looks like the conversion is done twice.
arviz/data/io_pymc3.py (Outdated)

@@ -193,7 +202,7 @@ def sample_stats_to_xarray(self):
     @requires("model")
     def log_likelihood_to_xarray(self):
         """Extract log likelihood and log_p data from PyMC3 trace."""
-        if self.predictions:
+        if self.predictions or not self.log_likelihood:
             return None
         data = self._extract_log_likelihood()
Extra idea: we could wrap this in a try/except, something like:

try:
    data = self._extract...
except TypeError:
    warnings.warn(
        "could not compute log likelihood. log_likelihood group will be omitted. "
        "Check your model object or set log_likelihood=False"
    )
    return None
Which, if I am correct, should fix several issues such as #395 (it has a minimal example to reproduce, so it should be easy to check that the issue is fixed) or even pymc-devs/pymc#3728.
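The guard suggested above can be sketched as a small wrapper; the function name and the callable argument here are assumptions for illustration, since the real change lives inside `log_likelihood_to_xarray`:

```python
import warnings


def safe_extract_log_likelihood(extract):
    """Hypothetical sketch of the suggested guard: call the extraction
    function, and on TypeError emit a warning and skip the group instead
    of crashing the whole conversion to InferenceData."""
    try:
        return extract()
    except TypeError:
        warnings.warn(
            "could not compute log likelihood. log_likelihood group will be "
            "omitted. Check your model object or set log_likelihood=False"
        )
        return None
```

Returning `None` matches how the converter already signals a missing group, so downstream code needs no changes.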
Yes, it does fix both issues!
@rpgoldman Can you explain what the two peaks near the middle of the graphs above represent? They seem to occur at the start of allocation of
So you are asking about the two spikes that come at the start of the process, as opposed to the two near the end, correct?

Do we know if they are happening in the building of

Can you derive plots like the above where the x axis plots position in the code, instead of clock time?
Back in the day when I did some experiments on this, I found the pattern to be: the first slope corresponds to model creation and trace allocation, the plateau is posterior sampling, the gentle slope is posterior predictive sampling, and the steep slope is conversion to inference data (which in terms of memory is basically retrieving and allocating log_likelihood data). This would leave the two spikes at the beginning of posterior predictive sampling or at the end of posterior sampling.

While writing this I realized it is not difficult to identify which of the two is happening, so I ran a quick example similar to the one above with just posterior sampling, no posterior predictive sampling. I see these same peaks towards the end of sampling.
I changed posterior predictive sampling to pre-allocate memory because it used to build enormous Python lists and then translate them into numpy arrays. Now we pre-allocate the arrays and then just fill them. Posterior predictive sampling is still much slower than I would like (I have an open PR for this), but pre-allocation substantially reduces memory usage. That could be what is causing this issue.

I'm a bit confused about why this question is here on the ArviZ development repo, instead of on the `pymc3-devs` one.
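The two accumulation strategies being compared can be sketched side by side; these are simplified stand-ins for the sampling loops, not the actual PyMC3 code:

```python
import numpy as np


def sample_by_appending(draws, draw_fn):
    # Old pattern: grow a Python list of per-draw arrays, convert at the
    # end. The list, its per-element array objects, and the final combined
    # array all coexist briefly, inflating peak memory.
    samples = []
    for _ in range(draws):
        samples.append(draw_fn())
    return np.asarray(samples)


def sample_preallocated(draws, draw_shape, draw_fn):
    # New pattern: allocate the full (draws, *draw_shape) array once up
    # front, then fill rows in place. Peak memory is just the final array.
    samples = np.empty((draws,) + draw_shape)
    for i in range(draws):
        samples[i] = draw_fn()
    return samples
```

Both produce identical results; the difference is only in peak memory, which is what the benchmark graphs above measure.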
I have done a quick check between pymc3 versions: only posterior sampling, no posterior predictive sampling nor conversion to inference data. This is the result for 6 chains:

It is here because we saw this behaviour while checking that the PR did indeed reduce memory usage in

If there are no issues with
That's a relief! It can't be my fault, then, because I didn't change anything in the normal sampling, only posterior predictive. 😌
Please do!
arviz/data/io_pymc3.py (Outdated)

log_likelihood_dict = self.pymc3.sampling._DefaultTrace(  # pylint: disable=protected-access
    len(self.trace.chains)
)
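For context, `_DefaultTrace` is a private PyMC3 helper that follows the preallocation pattern discussed above. A minimal stand-in mimicking that pattern (hypothetical, the real class's API differs in details) could look like:

```python
import numpy as np


class PreallocatedTrace:
    """Minimal stand-in for the pymc3.sampling._DefaultTrace pattern:
    preallocate one array per variable on first insert, then fill rows
    in place instead of growing Python lists."""

    def __init__(self, draws):
        self.draws = draws
        self.trace_dict = {}

    def insert(self, name, value, idx):
        value = np.asarray(value)
        if name not in self.trace_dict:
            # Allocate the full (draws, *value.shape) array up front.
            self.trace_dict[name] = np.empty((self.draws,) + value.shape)
        self.trace_dict[name][idx] = value
```

Since the helper is private (note the `pylint: disable=protected-access`), it can change or disappear between PyMC3 releases, which motivates the version-guard discussion below.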
We should wrap this in a try/except block, so that when the installed pymc3 version is an old one, the error raised explains that pymc3 should be upgraded or arviz downgraded; otherwise users won't know what hit them.

There is no need to edit the requirements file though, as the pymc3 development version is already shown there.
arviz/data/io_pymc3.py (Outdated)

    )
except AttributeError:
    raise AttributeError(
        "Either upgrade PyMC3 to latest version or downgrade ArviZ for log_likelihood."
"Installed version of ArviZ requires PyMC3>=3.8. Please upgrade with `pip install pymc3>=3.8` "
"or `conda install -c conda-forge pymc3>=3.8`."
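Putting the guard and the suggested message together, the lookup could be wrapped as below; the helper name and its arguments are assumptions for illustration:

```python
def get_default_trace(pymc3_module, nchains):
    """Hypothetical guard around the private helper: old PyMC3 releases
    have no sampling._DefaultTrace, so translate the AttributeError into
    an actionable message instead of a bare attribute failure."""
    try:
        return pymc3_module.sampling._DefaultTrace(nchains)
    except AttributeError:
        raise AttributeError(
            "Installed version of ArviZ requires PyMC3>=3.8. Please upgrade "
            "with `pip install pymc3>=3.8` or "
            "`conda install -c conda-forge pymc3>=3.8`."
        )
```

Catching `AttributeError` narrowly here keeps unrelated attribute bugs visible while still giving users of old PyMC3 versions a clear upgrade path.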
Description
Related to issue #1077
Checklist