jnp.cumsum is incorrect when sharding over summing axis on GPUs #21403
Labels: bug
Comments
It appears this bug is already fixed at head by openxla/xla#12476. Try with a nightly jaxlib and jax release (https://jax.readthedocs.io/en/latest/installation.html#jax-nightly-installation)?
Thanks for the clear reproduction!
hawkinsp added a commit to hawkinsp/jax that referenced this issue (May 24, 2024).
No problem - just tested the nightly and it seems to work. Thanks!
Description
I have a simple script (below) that tests jnp.cumsum when sharding along the same axis it is summing over, tested on a machine with 8 40GB A100s. On TPUs, this code works completely fine for different Mesh shardings.
On GPUs, the code produces incorrect output when args.sp > 1 (sharding over the summing axis): it works with no sharding (sp = 1) but breaks for any sharding over the summing axis (sp = 2, 4, or 8, on up to 8 GPUs).
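The script itself is not captured here, so the following is only a minimal sketch of the kind of reproduction described: an args.sp-style shard count, a 2D (dp, sp) mesh, and a toy mask array. These names and shapes are assumptions, not the author's exact code.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

sp = 2                                  # shards along the summed axis (stands in for args.sp)
dp = max(1, len(jax.devices()) // sp)   # remaining devices form a data-parallel axis

# Build a 2D device mesh ("dp", "sp"); the cumsum axis is sharded over "sp".
devices = np.array(jax.devices()[: dp * sp]).reshape(dp, sp)
mesh = Mesh(devices, axis_names=("dp", "sp"))

# Toy input standing in for the issue's `mask` array, sharded over both mesh axes.
mask = jnp.ones((dp * 4, sp * 1024), dtype=jnp.float32)
mask = jax.device_put(mask, NamedSharding(mesh, P("dp", "sp")))

@jax.jit
def f(x):
    # Cumulative sum along axis 1, the axis that is sharded over "sp".
    return jnp.cumsum(x, axis=1)

expected = np.cumsum(np.ones((dp * 4, sp * 1024), dtype=np.float32), axis=1)
got = np.asarray(f(mask))
print("max abs error:", np.abs(got - expected).max())  # reportedly nonzero on GPUs when sp > 1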
One other bizarre thing is that if you insert mask = jnp.ones_like(mask) into the script, the code works again on GPUs.
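In terms of the sketch above, that observation would look roughly like the following; the exact placement of the inserted line in the real script is not shown in the issue, so the placement here is an assumption.

@jax.jit
def f_ones(x):
    # Overwriting the input with ones is reported to make the GPU output correct again.
    x = jnp.ones_like(x)
    return jnp.cumsum(x, axis=1)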
Any clue what might be going on?
System info (python version, jaxlib version, accelerator, etc.)