Add blas_cores argument to pm.sample #7318
pymc/sampling/mcmc.py
Outdated
@@ -499,6 +504,13 @@ def sample(
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
This requires the chosen sampler to be installed.
All samplers, except "pymc", require the full model to be continuous.
blas_cores: int or "auto" or None, default = "auto"
The total number of threads blas and openmp functions should use during sampling. If set to None,
Explain the default first?
+1
Do you think we should already default to "auto", or first release something where the default is None so that this can be tested a bit more?
I think the new default makes much more sense. This often shows up in MvNormal models and it's very tricky for beginners to debug
pymc/sampling/mcmc.py
Outdated
if cores < 1:
    raise ValueError("`cores` must be larger or equal to one")

if chains < 1:
    raise ValueError("`chains` must be larger or equal to one")

if blas_cores is not None and blas_cores < 1:
    raise ValueError("`blas_cores` must be larger or equal to one")
Remove it for the sake of less code? I don't believe anybody was ever hurt by this and couldn't figure out the problem?
pre-commit failed
Description
We currently do not configure BLAS in any way. This can lead to very bad behavior if we sample in several threads:
Many BLAS implementations default to using one worker thread per hardware thread in the machine. But if we sample in parallel with multiprocessing, each chain will use an independent thread pool, so we end up starting chains*hardware_threads worker threads. Combined with the spinlocking that some BLAS implementations seem to do, this can lead to terrible performance.

This PR adds a blas_cores argument to pm.sample(), and then uses threadpoolctl to control how many worker threads we start.

If it is set to None, we don't do anything, and keep the current behavior of just using whatever the BLAS implementation uses as default. If set to "auto" (the default), we use the cores argument to guess a decent number of BLAS worker threads. If it is set to an integer, we use that number of total BLAS workers.

See for instance here for a model that shows bad behavior without this PR.
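To make the thread arithmetic concrete, here is a minimal sketch of how a total BLAS budget might be turned into a per-process limit. The helper names (unrestricted_worker_threads, resolve_blas_threads, blas_limit_context) are hypothetical and not taken from the PR, and the "auto" rule shown (budget equal to cores, split evenly across sampling processes) is an assumption about what a "decent guess" could look like, not the actual implementation. The only external call is threadpoolctl's threadpool_limits, and the sketch falls back to a no-op if threadpoolctl is not installed.

```python
from contextlib import nullcontext


def unrestricted_worker_threads(chains, hardware_threads):
    """Worker threads started when every chain's process opens a full BLAS pool."""
    return chains * hardware_threads


def resolve_blas_threads(blas_cores, cores):
    """Hypothetical helper: map `blas_cores` to a per-process BLAS thread limit.

    Returns None when the BLAS default should be left untouched.
    """
    if blas_cores is None:
        return None
    if blas_cores == "auto":
        # Assumption: "auto" budgets one BLAS thread per sampling core.
        total = cores
    else:
        total = int(blas_cores)
    # Split the total budget across sampling processes, at least one thread each.
    return max(1, total // cores)


def blas_limit_context(limit):
    """Return a context manager that caps BLAS threads, or a no-op."""
    if limit is None:
        return nullcontext()
    try:
        from threadpoolctl import threadpool_limits
    except ImportError:
        # threadpoolctl is optional in this sketch; do nothing without it.
        return nullcontext()
    return threadpool_limits(limits=limit, user_api="blas")


# 4 chains on a 16-thread machine, no limits: 64 BLAS worker threads.
print(unrestricted_worker_threads(4, 16))

# With blas_cores="auto" and cores=4, each process gets a single BLAS thread.
with blas_limit_context(resolve_blas_threads("auto", cores=4)):
    pass  # BLAS-heavy work here would run under the limit
```

Passing an integer such as blas_cores=8 with cores=4 would give each sampling process two BLAS threads under this sketch, while None leaves the library default in place.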
Type of change
📚 Documentation preview 📚: https://pymc--7318.org.readthedocs.build/en/7318/