-
-
Notifications
You must be signed in to change notification settings - Fork 126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for numpyro and blackjax PyMC samplers #526
Conversation
OK so I've hit a snag I don't quite understand. My personal project works fine but my test example fails: import arviz as az
import bambi as bmb
import numpy as np
import pandas as pd
import time
az.style.use("arviz-darkgrid")
rng = np.random.default_rng(0)
size = 1000
x = rng.normal(size=size)
print(x)
data = pd.DataFrame(
{
"x": x,
"y": rng.normal(loc=x, size=size)
}
)
print(data)
bmb_model = bmb.Model("y ~ x", data)
bmb_model_numpyro = bmb.Model("y ~ x", data)
bmb_model_blackjax = bmb.Model("y ~ x", data)
t0 = time.time()
idata = bmb_model.fit()
t1 = time.time()
idata_numpyro = bmb_model_numpyro.fit(chains=4, tune=1000, draws=1000, sampler_backend="numpyro", chain_method="vectorized")
t2 = time.time()
idata_blackjax = bmb_model_blackjax.fit(chains=4, tune=1000, draws=1000, sampler_backend="blackjax", chain_method="vectorized")
t3 = time.time()
print(f"Default: {t1-t0} Numpyro: {t2-t1} Blackjax: {t3-t2}") It fails on line Line 322 in 762f30a
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@markgoodhead Thanks for your contribution. I have a few comments.
Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com>
This is weird, this runs for me with both numpyro and blackjax |
…rt of the method argument
…lackjax # Conflicts: # bambi/models.py
How odd! Perhaps my environment isn't setup correctly and I'm behind on the latest versions. What versions of pymc/jax/arviz/xarray etc are you using? I have modified my example script to work now with the new method arg approach (and fixed a bug in chains handling I spotted). Note: I'm not actually sure if I need to construct 3 models - call it paranoia at ruling out bugs in case bmb.Model was stateful between fits 😂 import arviz as az
import bambi as bmb
import numpy as np
import pandas as pd
import time
az.style.use("arviz-darkgrid")
rng = np.random.default_rng(0)
size = 1000
x = rng.normal(size=size)
data = pd.DataFrame(
{
"x": x,
"y": rng.normal(loc=x, size=size)
}
)
bmb_model = bmb.Model("y ~ x", data)
bmb_model_numpyro = bmb.Model("y ~ x", data)
bmb_model_blackjax = bmb.Model("y ~ x", data)
t0 = time.time()
idata = bmb_model.fit()
t1 = time.time()
idata_numpyro = bmb_model_numpyro.fit(method="nuts_numpyro", chain_method="vectorized")
t2 = time.time()
idata_blackjax = bmb_model_blackjax.fit(method="nuts_blackjax", chain_method="vectorized")
t3 = time.time()
print(f"Default: {t1-t0} Numpyro: {t2-t1} Blackjax: {t3-t2}") |
Vecrorized does work for both blackjack and numpyro(
https://www.pymc.io/projects/docs/en/stable/api/samplers.html) but I found
vectorized was slower on cpu (and my gpu won’t work since I have an M1).
…On Wed, Jun 8, 2022 at 08:44 Osvaldo A Martin ***@***.***> wrote:
chain_method="vectorized"
This is weird, this runs for me with both numpyro and blackjax
—
Reply to this email directly, view it on GitHub
<#526 (comment)>, or
unsubscribe
<https://github.com/notifications/unsubscribe-auth/AH3QQV3MUOM6VTU6MFJSXATVOCIUBANCNFSM5YGBJLIA>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
Yes vectorized is generally faster if you're on a single GPU, otherwise for multiple GPUs or multiple CPU cores I expect parallel would be better. |
Looks good, the only missing part is a test |
Hmm so I just tried modifying existing tests to also run the new fit methods, e.g. def test_group_specific_categorical_interaction(crossed_data):
crossed_data["fourcats"] = sum([[x] * 10 for x in ["a", "b", "c", "d"]], list()) * 3
model = Model("Y ~ continuous + (threecats:fourcats|site)", crossed_data)
model.fit(tune=10, draws=10)
model.fit(tune=10, draws=10, method="nuts_numpyro") However I again get an import error on |
Do you mind adding the test anyway? |
Tests added... Fingers crossed they actually work! |
Be sure to run black and pylint https://github.com/bambinos/bambi/blob/main/CONTRIBUTING.md#pull-request-checklist |
Done 👍 |
One small issue is that pylint isn't happy with the import within the code itself - I assume you're happy to ignore the error here? |
You could add |
I just tried updating my version of xarray (which was 0.21.1 before) to the latest on pip (2022.3.0) and I still get the same xarray error... otherwise my versions are all compatible with the pymc 4.0.0 release on pip. Does anyone else get this error? If not, what versions of xarray etc are you using? |
…ar imports and this is needed for tests to work
Tests look to be failing due to Jax not being installed ( |
We can add jax, numpyro, blackjax and any other necessary requirement for jax-based samplers to https://github.com/bambinos/bambi/blob/main/requirements-dev.txt |
I think we could have something like Below, you will need to add another line saying bambi/.github/workflows/test.yml Lines 35 to 36 in ecfdbf2
|
Agreed, that's cleaner. |
Please also ensure the optional dependencies in |
I've added the optional requirements files and hopefully done the setup.py changes @canyon289 requested correctly (all a bit new to me so I could well have done it wrong!). I wasn't sure what versions to specify in the file so I tried to find the equivalents in pymc to align with what they have... and was a bit surprised when I couldn't find any! Perhaps something similar should be added to pymc and then by depending on a specific pymc version this would flow naturally upstream to bambi? Another thing to note here is that if a user installs jax via this version I believe they won't get CUDA support by default - further downstream libraries like numpyro look to sort of copy the Jax installation instructions in their setup optional structure. I think the best solution overall would be for each part of the library hierarchy to depend on the correct optional install in the sub-library they depend on, e.g. bambi[gpu] would end up calling pymc[gpu] which would call numpyro[gpu] etc... perhaps this is a bit out of scope for this PR though as it requires a lot of co-ordination with other repos and this current solution is a reasonable intermediate step? |
OK it looks like the tests are failing for the same reason my local environment doesn't work which I've no idea how to fix! Anyone got any advice what I should try/do here? 2022-06-09T09:04:21.8572705Z bambi/models.py:265: in fit
2022-06-09T09:04:21.8572923Z return self.backend.run(
2022-06-09T09:04:21.8573156Z bambi/backend/pymc.py:91: in run
2022-06-09T09:04:21.8573373Z result = self._run_mcmc(
2022-06-09T09:04:21.8573593Z bambi/backend/pymc.py:288: in _run_mcmc
2022-06-09T09:04:21.8573880Z idata = self._clean_mcmc_results(idata, omit_offsets, include_mean)
2022-06-09T09:04:21.8574228Z bambi/backend/pymc.py:363: in _clean_mcmc_results
2022-06-09T09:04:21.8574631Z idata.posterior[intercept_name] -= np.dot(X.mean(0), coefs).reshape(shape)
2022-06-09T09:04:21.8575105Z /usr/share/miniconda/envs/test/lib/python3.8/site-packages/xarray/core/_typed_ops.py:290: in __isub__
2022-06-09T09:04:21.8575466Z return self._inplace_binary_op(other, operator.isub)
2022-06-09T09:04:21.8575943Z /usr/share/miniconda/envs/test/lib/python3.8/site-packages/xarray/core/dataarray.py:3121: in _inplace_binary_op
2022-06-09T09:04:21.8576273Z f(self.variable, other_variable)
2022-06-09T09:04:21.8576690Z /usr/share/miniconda/envs/test/lib/python3.8/site-packages/xarray/core/_typed_ops.py:480: in __isub__
2022-06-09T09:04:21.8577039Z return self._inplace_binary_op(other, operator.isub) |
# Conflicts: # bambi/tests/test_built_models.py
Codecov Report
@@ Coverage Diff @@
## main #526 +/- ##
==========================================
+ Coverage 86.69% 86.84% +0.14%
==========================================
Files 32 32
Lines 2586 2622 +36
==========================================
+ Hits 2242 2277 +35
- Misses 344 345 +1
Continue to review full report at Codecov.
|
**kwargs, | ||
) | ||
else: | ||
raise |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please raise a specific exception with a helpful message
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah this was the code before I changed this function, it's just been moved around. To be honest I wondered about removing this whole error handling because I've seen pymc do the same thing internally anyway but I thought that might be out of scope for this PR - I'll do whatever is the consensus here 😄
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If youre willing editing here would be helpful, but youre right if you just moved the code it can be out of scope! My ask is just open an issue ticket to track and reference this discussion :)
model.fit(method="nuts_blackjax", chain_method="vectorized") | ||
|
||
|
||
def test_regression_blackjax(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: These two tests test_regression_blackjax
and test_regression_nunpyro
could be parameterized to reduce amount of code that needs to be read or maintained
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with @canyon289 here, but if you want @markgoodhead you can open an issue fix this later.
@markgoodhead thanks for doing this! this is a great capability add for bambi |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thank you @markgoodhead
This is to address #522 and #525 inspired by @zwelitunyiswa's example
I decided to add a single new value to the fit() method which allows switching in of numpyro/blackjax samplers instead of the pymc default. I decided against some cpu/gpu flags because it's mostly decided by whatever Jax can find and the methods I saw to disable GPUs are quite hacky involving playing with your environment variables which I felt is out of scope for a library to be fiddling with so I've just noted this in the documentation instead.
I've tested the samplers locally and they work on one of my personal projects, but I'll try and knock up a simple example shortly which demonstrates them all.
One note: The PyMC 4 release blog post says:
So we should expect the implementation here to change pretty soon, so I think it's worth keeping the implementation in bambi simple so it's easy to port-over when this happens.