jnp.cumsum is incorrect when sharding over summing axis on GPUs #21403

Closed
wilson1yan opened this issue May 23, 2024 · 3 comments · Fixed by #21423
Labels
bug Something isn't working

Comments


wilson1yan commented May 23, 2024

Description

I have the simple script below, which tests jnp.cumsum when the input is sharded along the same axis it is summing over. Tested on a machine with 8x 40GB A100s.

# test.py
import argparse
import jax
import jax.numpy as jnp
import jax.lax
from jax.sharding import PartitionSpec as PS, Mesh
from jax.experimental.pjit import pjit
import numpy as np


parser = argparse.ArgumentParser()
parser.add_argument('--sp', type=int, default=2)
args = parser.parse_args()

def f(mask):
    # Shard the input over both mesh axes, including 'sp' on the axis
    # that the cumsum below runs over.
    mask = jax.lax.with_sharding_constraint(mask, PS('dp', 'sp'))
    idxs = jnp.cumsum(mask, axis=-1)
    return idxs
f = pjit(f, in_shardings=PS(), out_shardings=PS())

mesh_shape = (jax.device_count() // args.sp, args.sp)
print(f"Mesh shape: {mesh_shape}")
mesh = Mesh(np.array(jax.devices()).reshape(mesh_shape), ('dp', 'sp'))

B, L = 8, 1024
with mesh:
    mask = np.ones((B, L), dtype=np.int32)
    out = jax.device_get(f(mask))
    # cumsum of all-ones along the last axis should give 1..L in every row.
    expected = np.arange(L, dtype=np.int32)[None].repeat(B, axis=0) + 1
    print('Output:', out)
    print('Expected:', expected)
    assert np.allclose(out, expected)
    print('Passed')

On TPUs, this code works fine for every mesh sharding I tried.

On GPUs, it produces incorrect output whenever args.sp > 1, i.e. when sharding over the axis being summed.

It works when the summed axis is not sharded (sp = 1):

> python test.py --sp 1
Mesh shape: (8, 1)
2024-05-23 18:27:20.884869: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Output: [[   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 ...
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]]
Expected: [[   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 ...
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]]
Passed

It breaks when the summed axis is sharded (sp = 2, 4, or 8, up to all 8 GPUs):

> python test.py --sp 2
Mesh shape: (4, 2)
Output: [[   1    2    4 ... 1278 1279 1024]
 [   1    2    4 ... 1278 1279 1024]
 [   1    2    4 ... 1278 1279 1024]
 ...
 [   1    2    4 ... 1278 1279 1024]
 [   1    2    4 ... 1278 1279 1024]
 [   1    2    4 ... 1278 1279 1024]]
Expected: [[   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 ...
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]]
Traceback (most recent call last):
  File "/home/wilsonyan/test.py", line 31, in <module>
    assert np.allclose(out, expected)
AssertionError

One other bizarre thing is that if you insert mask = jnp.ones_like(mask) like so:

...
def f(mask):
    mask = jax.lax.with_sharding_constraint(mask, PS('dp', 'sp'))
    mask = jnp.ones_like(mask) # <<< new inserted line
    idxs = jnp.cumsum(mask, axis=-1)
    return idxs
...

With that change, the code works again on GPUs.

Any clue what might be going on?
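
In the meantime, a possible workaround (a rough sketch along the lines of the script above, untested beyond this repro; f_unsharded is just an illustrative name, and it assumes replicating axis -1 is acceptable memory-wise) is to keep the summed axis unsharded:

import jax
import jax.numpy as jnp
import jax.lax
from jax.sharding import PartitionSpec as PS
from jax.experimental.pjit import pjit

# Constrain only the 'dp' axis so the cumsum never has to cross shard
# boundaries; axis -1 stays replicated on every device.
def f_unsharded(mask):
    mask = jax.lax.with_sharding_constraint(mask, PS('dp', None))
    return jnp.cumsum(mask, axis=-1)

f_unsharded = pjit(f_unsharded, in_shardings=PS(), out_shardings=PS())

This would still need to be called under the same mesh context as f above.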

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', release='6.8.0-1007-gcp', version='#7-Ubuntu SMP Sat Apr 20 00:58:31 UTC 2024', machine='x86_64')

$ nvidia-smi
Thu May 23 18:28:53 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.67                 Driver Version: 550.67         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   34C    P0             59W /  400W |     425MiB /  40960MiB |      3%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-40GB          Off |   00000000:00:05.0 Off |                    0 |
| N/A   33C    P0             69W /  400W |     425MiB /  40960MiB |      2%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A100-SXM4-40GB          Off |   00000000:00:06.0 Off |                    0 |
| N/A   31C    P0             61W /  400W |     425MiB /  40960MiB |      2%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A100-SXM4-40GB          Off |   00000000:00:07.0 Off |                    0 |
| N/A   32C    P0             59W /  400W |     425MiB /  40960MiB |      2%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA A100-SXM4-40GB          Off |   00000000:80:00.0 Off |                    0 |
| N/A   32C    P0             61W /  400W |     425MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA A100-SXM4-40GB          Off |   00000000:80:01.0 Off |                    0 |
| N/A   33C    P0             59W /  400W |     425MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA A100-SXM4-40GB          Off |   00000000:80:02.0 Off |                    0 |
| N/A   32C    P0             60W /  400W |     425MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA A100-SXM4-40GB          Off |   00000000:80:03.0 Off |                    0 |
| N/A   34C    P0             66W /  400W |     425MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    345969      C   python                                        416MiB |
|    1   N/A  N/A    345969      C   python                                        416MiB |
|    2   N/A  N/A    345969      C   python                                        416MiB |
|    3   N/A  N/A    345969      C   python                                        416MiB |
|    4   N/A  N/A    345969      C   python                                        416MiB |
|    5   N/A  N/A    345969      C   python                                        416MiB |
|    6   N/A  N/A    345969      C   python                                        416MiB |
|    7   N/A  N/A    345969      C   python                                        416MiB |
+-----------------------------------------------------------------------------------------+
wilson1yan added the bug label May 23, 2024
hawkinsp (Collaborator) commented May 24, 2024

It appears this bug is already fixed at head by openxla/xla#12476.

Try with a nightly jaxlib and jax release (https://jax.readthedocs.io/en/latest/installation.html#jax-nightly-installation)?
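
Once a nightly is installed, one quick sanity check is to rerun the environment dump (this should produce the same kind of block as the "System info" section above, with the newer version numbers):

# Print jax/jaxlib versions, devices, and platform info to confirm the
# nightly build is the one actually being imported.
import jax
jax.print_environment_info()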

hawkinsp (Collaborator) commented:

Thanks for the clear reproduction!

wilson1yan (Author) commented:

No problem - just tested the nightly and it seems to work. Thanks!
