
Wrong determinant results for large batch #24843

Closed
ChenAo-Phys opened this issue Nov 11, 2024 · 5 comments

@ChenAo-Phys

Description

The bug can be reproduced using the following code:

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

# 1.5 million random 20x20 matrices (6e8 elements in total)
a = jr.normal(jr.key(0), (1500000, 20, 20))
d = jnp.linalg.det(a)
plt.plot(d)
plt.show()

[Figure: det(a) plotted against batch index; the values near the end of the batch are clearly wrong]

The values at the end are obviously incorrect.

I have tested different batch and matrix sizes. The bug only happens when the total number of elements (1500000 * 20 * 20 in this example) exceeds 2**29. As the figure also shows, the determinant values are wrong for batch indices greater than 2**29 / (20 * 20).
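
In terms of batch index, that threshold works out as follows (a quick sanity check, assuming float32 data):

# 2**29 elements divided by 20*20 elements per matrix gives the batch index
# at which the results start to go wrong (~1.34M out of 1.5M matrices).
print(2**29 // (20 * 20))   # 1342177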

I also tested on different devices: the bug occurs on several types of GPUs but not on the CPU.
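
A possible workaround until this is fixed (just a sketch, and slower since it runs on the host): commit the array to the CPU before calling det, so the unaffected CPU path is used.

import jax
import jax.numpy as jnp
import jax.random as jr

a = jr.normal(jr.key(0), (1500000, 20, 20))
# Committing the array to a CPU device makes the determinant run on the CPU,
# which does not show this bug.
a_cpu = jax.device_put(a, jax.devices("cpu")[0])
d = jnp.linalg.det(a_cpu)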

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.35
jaxlib: 0.4.35
numpy:  2.0.2
python: 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0]
device info: NVIDIA A100 80GB PCIe-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='alcc145', release='5.15.0-94-generic', version='#104-Ubuntu SMP Tue Jan 9 15:25:40 UTC 2024', machine='x86_64')
ChenAo-Phys added the bug label on Nov 11, 2024
@jakevdp
Collaborator

jakevdp commented Nov 11, 2024

Thanks for the report – I can also repro this on a GPU but not on the CPU. It looks like the issue has something to do with a failure of the batched LU decomposition in this regime (LU is used to compute det) – you can see how the statistics of the LU output change where the erroneous results crop up:

import jax.numpy as jnp
import jax.random as jr
from jax import lax
import matplotlib.pyplot as plt

a = jr.normal(jr.key(0), (1500000, 20, 20))
d = jnp.linalg.det(a)
lu, _, _ = lax.linalg.lu(a)

fig, ax = plt.subplots(2, sharex=True)
ax[0].plot(d)
ax[0].set_ylabel('det(a)')
ax[1].plot(lu.std(axis=[1, 2]))
ax[1].set_ylabel('std(lu(a))');

[Figure: two panels sharing the x axis, det(a) on top and std(lu(a)) below; both change behavior at the same batch index]

The block of high standard deviation values also suggests something else is amiss with outputs in other parts of the batch.
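
Continuing from the snippet above, one way to localize the affected range (a rough sketch; the tolerance is deliberately loose) is to compare against a host-side NumPy reference and see where the two start to disagree:

import numpy as np

# Compare the GPU determinants against a NumPy reference computed on the host.
d_ref = np.linalg.det(np.asarray(a))
bad = np.flatnonzero(~np.isclose(np.asarray(d), d_ref, rtol=1e-2))
print(bad.size, bad[:5] if bad.size else "no mismatches")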

dfm self-assigned this on Nov 11, 2024
copybara-service bot pushed a commit that referenced this issue Nov 11, 2024
As reported in #24843, our LU decomposition on GPU hits overflow errors when the batch size approaches int32 max. This was caused by an issue in how we were constructing the batched pointers used by cuBLAS.

PiperOrigin-RevId: 695490133
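
For context, a back-of-the-envelope illustration of how the pointer-offset overflow described above lines up with the 2**29-element threshold from the original report (a sketch only; it assumes float32 data and signed 32-bit byte offsets, and the real pointer construction lives in the GPU plugin, not in Python):

INT32_MAX = 2**31 - 1
stride_bytes = 20 * 20 * 4             # 1600 bytes per float32 matrix
print((INT32_MAX + 1) / stride_bytes)  # ~1342177.3: per-matrix byte offsets overflow past here
print(2**29 / (20 * 20))               # ~1342177.3: the threshold seen in the plots above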
@dfm
Collaborator

dfm commented Nov 12, 2024

Thanks for this bug report and for the simple reproduction! I've fixed (at least part of) this issue over in #24846, which should be available in a day or two in the nightly version of jaxlib. I've added a test for the failures towards the end of the batch, but I'm interested to understand if there are other numerical issues here that we can see in Jake's plots. I'll take a closer look ASAP!

copybara-service bot pushed a commit that referenced this issue Nov 12, 2024
As reported in #24843, our LU decomposition on GPU hits overflow errors when the batch size approaches int32 max. This was caused by an issue in how we were constructing the batched pointers used by cuBLAS.

PiperOrigin-RevId: 695694648
@ChenAo-Phys
Author

@jakevdp @dfm, thanks for checking and fixing the bug. Regarding the high standard deviation values, I also found something weird. It can be reproduced with the following code:

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

a = jr.normal(jr.key(0), (1500000, 20, 20))
d = jnp.linalg.det(a)
std = jnp.std(a, axis=(1, 2))
plt.plot(std)
plt.show()

[Figure: std of a over axes (1, 2); a block of anomalous values appears at the same batch indices as the wrong determinants]

But it works normally without the d = jnp.linalg.det(a) line:
[Figure: the same std plot without the det call; the values are uniform as expected]

Probably it's also related to lax.linalg.lu.
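
A hypothetical check for whether the input buffer itself is being clobbered by the det/LU call (just a sketch; I haven't confirmed this is actually what happens):

import numpy as np
import jax.numpy as jnp
import jax.random as jr

a = jr.normal(jr.key(0), (1500000, 20, 20))
a_before = np.asarray(a)                      # host copy taken before calling det
d = jnp.linalg.det(a).block_until_ready()
# If the buggy batched LU writes outside its own buffers, the device contents
# of `a` may no longer match the copy taken beforehand.
print(np.array_equal(np.asarray(a), a_before))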

@dfm
Collaborator

dfm commented Nov 12, 2024

Whoa! Yeah, I expect that is the same nasty overflow bug, but I'll double check. Thanks for digging!

@dfm
Collaborator

dfm commented Nov 12, 2024

I have confirmed that #24846 also fixes this issue. Thanks again for the bug report!

For reference, I introduced this bug in #23054, so this has been a problem for JAX versions 0.4.33-0.4.35. The next nightly and next official release will include this fix!

I'm going to close this issue as fixed, but please comment or open a new issue if you run into any further problems.

dfm closed this as completed on Nov 12, 2024
yliu120 pushed a commit to yliu120/jax that referenced this issue Nov 16, 2024
barnesjoseph pushed a commit to barnesjoseph/jax that referenced this issue Nov 21, 2024