
ptxas : Unsupported .version 8.4; current version is '8.2' with jaxlib 0.4.34 #25344

Closed · ccoulombe opened this issue Dec 9, 2024 · 5 comments
Labels: bug (Something isn't working)

ccoulombe commented Dec 9, 2024

Description

Running the inference stage of AlphaFold 3, some users are running into the following error:

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CustomCall failed: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-ng30101.narval.calcul.quebec-48db254e-784897-628d6ac249c58, line 5; fatal   : Unsupported .version 8.4; current version is '8.2'
ptxas fatal   : Ptx assembly aborted due to errors

It is my understanding, per the JAX documentation, that jaxlib 0.4.34 was built with CUDA 12.3 but is compatible with CUDA 12.1+.

Also, according to NVIDIA's ptxas documentation, CUDA 12.3 corresponds to PTX ISA 8.3, yet the error reports .version 8.4, which corresponds to CUDA 12.4.

Hence, which CUDA version was actually used to build jaxlib and its plugins?

Possible alternative solution: use CUDA 12.4+.
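
As a quick sanity check, the versions in play can be compared like this (a minimal sketch; jax.print_environment_info() is available in recent JAX releases, and this assumes a ptxas is on PATH):

import subprocess

import jax

# Report the jax/jaxlib versions and the CUDA toolkit/driver that JAX sees.
jax.print_environment_info()

# Compare against whichever ptxas is first on PATH; it reports a CUDA
# release such as 12.2, whose ptxas accepts PTX ISA up to 8.2.
print(subprocess.run(["ptxas", "--version"], capture_output=True, text=True).stdout)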

System info (python version, jaxlib version, accelerator, etc.)

jaxlib: v0.4.34 (from PyPI, patched and tested to work with our installed CUDA 12.2)
CUDA version: 12.2
GPU device: A100
NVIDIA driver: 550.127.08

ccoulombe added the bug label on Dec 9, 2024
ccoulombe (Author) commented:

A similar issue on AF3: google-deepmind/alphafold3#68

Rick0827 commented Dec 9, 2024

I think it's a problem in jax-triton, which I have encountered before. At the time, I modified jax-triton's triton_lib.py to set the option ptx_version=82 (a version the current compiler supports), which solved it. Maybe that will work in this case too.
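
Roughly, the change looked like the following; this is only a sketch, and the exact name and location of the options dict inside triton_lib.py vary across jax-triton versions:

# Hypothetical excerpt from jax_triton/triton_lib.py: cap the PTX ISA
# version Triton emits at one the installed ptxas understands.
opts = {
    # ... existing Triton compiler options ...
    "ptx_version": 82,  # PTX ISA 8.2, i.e. what CUDA 12.2's ptxas supports
}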

ccoulombe (Author) commented Dec 9, 2024

@Rick0827 Yes, you are right!! Thanks

I can reproduce this using the add example from jax-triton (a sketch of it appears below). Passing the debug option, I can see that the generated PTX declares ISA 8.4:

$ python add.py

...
//
// Generated by LLVM NVPTX Back-End
//

.version 8.4
.target sm_80
.address_size 64

...

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CustomCall failed: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-ng30103.narval.calcul.quebec-5adaa54d-2604203-628d86ee2519d, line 5; fatal   : Unsupported .version 8.4; current version is '8.2'
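
For reference, add.py is essentially the add example from the jax-triton README; a sketch of it (the block size and exact signatures may differ from the installed version):

import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, block_size: tl.constexpr):
    # Each program instance handles one contiguous block of elements.
    pid = tl.program_id(axis=0)
    offsets = pid * block_size + tl.arange(0, block_size)
    x = tl.load(x_ptr + offsets)
    y = tl.load(y_ptr + offsets)
    tl.store(output_ptr + offsets, x + y)

def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
    block_size = 8
    grid = (triton.cdiv(x.size, block_size),)
    return jt.triton_call(x, y, kernel=add_kernel, out_shape=out_shape,
                          grid=grid, block_size=block_size)

x = jnp.arange(8.0)
print(add(x, x))  # triggers Triton compilation, and hence ptxas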

And adding "ptx_version": 82 to the opts dict in triton_lib allows it to work:

//
// Generated by LLVM NVPTX Back-End
//

.version 8.2
.target sm_80
.address_size 64

This originally comes from the installed triton 3.1.0, which bundles LLVM with the NVPTX plugin and was built with CUDA 12.4:

$ $VIRTUAL_ENV/lib/python3.11/site-packages/triton/backends/nvidia/bin/ptxas --version
ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Tue_Feb_27_16:15:50_PST_2024
Cuda compilation tools, release 12.4, V12.4.99
Build cuda_12.4.r12.4/compiler.33961263_0

Solutions:

  • Patch jax-triton to request a ptx_version that the currently available ptxas supports.
  • Alternatively, build Triton against the installed CUDA.
  • Removing the ptxas binary from the Triton wheel, so that the already-available ptxas is picked up, also seems to work.
  • Setting TRITON_PTXAS_PATH=$CUDA_HOME/bin/ptxas works as well (see the snippet below).
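
For the last workaround, the variable just needs to be set before anything triggers Triton compilation; for example, from Python (assuming CUDA_HOME points at the local CUDA 12.2 install):

import os

# Point Triton at the system ptxas instead of the one bundled in the wheel.
# This must run before the first Triton kernel is compiled.
os.environ["TRITON_PTXAS_PATH"] = os.path.join(
    os.environ["CUDA_HOME"], "bin", "ptxas"
)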

Thanks

ccoulombe (Author) commented:

Closing, as the solutions above work and this is more of a jax-triton/triton issue.

yifan-hou commented:
I got the same error when using JAX with CUDA 12.2. I read @ccoulombe's solution, but I don't even have jax-triton installed, so this might still be an unsolved issue.

How to reproduce

The error can be reproduced in a clean env:

mamba create -n test python=3.12
mamba activate test
pip install -U "jax[cuda12]"

Then in Python, the following code

import jax.numpy as jnp
x = jnp.arange(5.0)

caused the following error:

Error message

>>> x = jnp.arange(5.0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/yifanhou/miniforge3/envs/test/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 6767, in arange
    output = _arange(start, stop=stop, step=step, dtype=dtype)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yifanhou/miniforge3/envs/test/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 6807, in _arange
    return lax.iota(dtype, start)  # type: ignore[arg-type]
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yifanhou/miniforge3/envs/test/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 1782, in iota
    return broadcasted_iota(dtype, (size,), 0)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yifanhou/miniforge3/envs/test/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 1796, in broadcasted_iota
    return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yifanhou/miniforge3/envs/test/lib/python3.12/site-packages/jax/_src/core.py", line 463, in bind
    return self.bind_with_trace(prev_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yifanhou/miniforge3/envs/test/lib/python3.12/site-packages/jax/_src/core.py", line 468, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yifanhou/miniforge3/envs/test/lib/python3.12/site-packages/jax/_src/core.py", line 941, in process_primitive
    return primitive.impl(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yifanhou/miniforge3/envs/test/lib/python3.12/site-packages/jax/_src/dispatch.py", line 90, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-yifan-LambdaV1-74fb99cb76e95530-12139-62a776876b297, line 5; fatal   : Unsupported .version 8.3; current version is '8.2'
ptxas fatal   : Ptx assembly aborted due to errors

Also, my LD_LIBRARY_PATH is empty.

System info

GPU: RTX4090
Driver Version: 535.161.07
CUDA Version: 12.2

The most bizarre thing to me is that I cannot find the complained-about ptxas version 8.2 or 8.3 anywhere on my system. The ptxas that ships with the CUDA install and the one in the virtual env both report v12.xx. Any idea what I can do to further debug?
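
(I suspect the '8.2'/'8.3' in the error are PTX ISA versions rather than ptxas release numbers, since ptxas --version reports a CUDA release like 12.x.) A sketch of one way to enumerate every ptxas visible to the environment:

import glob
import subprocess
import sysconfig

# Find every ptxas on the Python path (e.g. the one bundled by the
# nvidia-cuda-nvcc-cu12 wheel that "jax[cuda12]" pulls in) and under the
# usual CUDA install locations, and print each one's release line.
site_packages = sysconfig.get_paths()["purelib"]
candidates = glob.glob(f"{site_packages}/**/ptxas", recursive=True)
candidates += glob.glob("/usr/local/cuda*/bin/ptxas")

for path in candidates:
    out = subprocess.run([path, "--version"], capture_output=True, text=True)
    print(path, "->", out.stdout.strip().splitlines()[-1])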
