Wrong determinant results for large batch #24843
Thanks for the report – I can also repro this on a GPU but not on CPU. It looks like the issue has something to do with a failure of the batched LU decomposition in this domain (LU is used to compute the determinant):

```python
import jax.numpy as jnp
import jax.random as jr
from jax import lax
import matplotlib.pyplot as plt

a = jr.normal(jr.key(0), (1500000, 20, 20))
d = jnp.linalg.det(a)
lu, _, _ = lax.linalg.lu(a)

fig, ax = plt.subplots(2, sharex=True)
ax[0].plot(d)
ax[0].set_ylabel('det(a)')
ax[1].plot(lu.std(axis=[1, 2]))
ax[1].set_ylabel('std(lu(a))')
```

The block of high standard-deviation values also suggests something else is amiss with outputs in other parts of the batch.
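For context on why a bad LU factorization corrupts the determinants: `det(A)` is recovered as the permutation sign times the product of U's diagonal. A minimal sketch (assuming `lax.linalg.lu`'s packed-LU and pivot outputs; this mirrors the standard computation, not necessarily JAX's exact internals):

```python
import jax.numpy as jnp
from jax import lax

# Sketch: recover det(A) from a batched LU factorization.
# det(A) = sign(permutation) * prod(diag(U)); a row swap occurs
# at step i whenever pivots[i] != i.
a = jnp.array([[[4.0, 3.0], [6.0, 3.0]],
               [[1.0, 2.0], [3.0, 4.0]]])  # batch of two 2x2 matrices
lu, pivots, _ = lax.linalg.lu(a)

diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
swaps = jnp.count_nonzero(pivots != jnp.arange(a.shape[-1]), axis=-1)
sign = jnp.where(swaps % 2 == 0, 1.0, -1.0)
det = sign * jnp.prod(diag, axis=-1)  # should match jnp.linalg.det(a)
```

So any garbage in the LU factors past some batch index shows up directly as garbage determinants there.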
As reported in #24843, our LU decomposition on GPU hits overflow errors when the batch size approaches int32 max. This was caused by an issue in how we were constructing the batched pointers used by cuBLAS. PiperOrigin-RevId: 695490133
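For background on the failure mode the commit describes: cuBLAS batched routines take an array of per-matrix device pointers. A hypothetical sketch (illustrative addresses, not JAX's actual code) of constructing those pointers with 64-bit offsets, which avoids the int32 wraparound:

```python
import numpy as np

# cuBLAS batched APIs (e.g. getrfBatched) expect one device pointer
# per matrix: base + i * n * n * itemsize. If that byte offset is
# accumulated in int32, it wraps once it passes 2**31 bytes.
base = 0x7F0000000000                   # hypothetical device base address
batch, n, itemsize = 1_500_000, 20, 4   # float32 matrices
offsets = np.arange(batch, dtype=np.int64) * (n * n * itemsize)
ptrs = base + offsets                   # int64 arithmetic: safe at any batch size
```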
Thanks for this bug report and for the simple reproduction! I've fixed (at least part of) this issue over in #24846, which should be available in a day or two in the nightly version of jaxlib. I've added a test for the failures towards the end of the batch, but I'm interested to understand if there are other numerical issues here that we can see in Jake's plots. I'll take a closer look ASAP!
As reported in #24843, our LU decomposition on GPU hits overflow errors when the batch size approaches int32 max. This was caused by an issue in how we were constructing the batched pointers used by cuBLAS. PiperOrigin-RevId: 695694648
@jakevdp @dfm Thanks for checking and fixing the bug! Regarding the high standard-deviation values, I also found something weird. It can be reproduced by the following code:

But it works normally without …. Probably it's also related to ….
Whoa! Yeah, I expect that is the same nasty overflow bug, but I'll double check. Thanks for digging!
I have confirmed that #24846 also fixes this issue. Thanks again for the bug report! For reference, I introduced this bug in #23054, so this has been a problem for JAX versions 0.4.33-0.4.35. The next nightly and next official release will include this fix! I'm going to close this issue as fixed, but please comment or open a new issue if you run into any further problems.
Description
The bug can be reproduced using the following code:
The values at the end are obviously incorrect.
I have tested different batch and matrix sizes. The bug only happens when the full matrix size (in the example, `1500000*20*20`) exceeds `2**29`. As also shown in the figure, the determinant values are wrong for batch index > `2**29 / (20*20)`. I also tested on different devices: this bug happens on different types of GPUs, but not on the CPU.
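The `2**29`-element threshold is consistent with an int32 byte-offset overflow, since `2**29` float32 elements occupy exactly `2**31` bytes. A quick arithmetic check (a sketch based only on the numbers in this report):

```python
import numpy as np

# With float32 (4 bytes), a 20x20 matrix occupies 1600 bytes, so the
# byte offset i * 1600 exceeds the int32 max (2**31 - 1) exactly when
# the data before index i exceeds 2**31 bytes = 2**29 elements.
n, itemsize = 20, 4
bytes_per_matrix = n * n * itemsize                  # 1600
i = np.arange(1_500_000, dtype=np.int64)
ok = i * bytes_per_matrix <= np.iinfo(np.int32).max
first_bad = int(np.argmin(ok))   # first index whose offset overflows
# first_bad is 1342178, i.e. roughly 2**29 / (20 * 20) ~= 1342177.3
```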
System info (python version, jaxlib version, accelerator, etc.)
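The system details were not captured in this transcript. For reference, the standard way to collect them, using JAX's built-in helper (output varies by machine):

```python
import jax
import jaxlib

# Versions 0.4.33-0.4.35 are the affected ones per this thread.
print(jax.__version__, jaxlib.__version__)
print(jax.devices())            # shows whether a GPU backend is active
jax.print_environment_info()    # python/jax/jaxlib/accelerator details
```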