
If CUDA 12.1 is installed, pip-installed ptxas binary is not used and jax throws an error #25718

Open
takkyu2 opened this issue Jan 3, 2025 · 3 comments
Labels: bug (Something isn't working)

takkyu2 commented Jan 3, 2025

Description

Please feel free to close this issue if this is expected behavior.
If it is expected, it would be great to have an easy way to fix it on the users' side other than installing a newer CUDA version.

Summary

If CUDA 12.1 is installed on the system and its ptxas is already on the system PATH:

ptxas --version
ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Feb__7_19:30:12_PST_2023
Cuda compilation tools, release 12.1, V12.1.66
Build cuda_12.1.r12.1/compiler.32415258_0

After installing jax through pip,

python -m venv venv
source venv/bin/activate
pip install -U "jax[cuda12]"

The system-installed ptxas binary is used instead of the pip-installed one, and jax throws an error:

import jax
jax.numpy.zeros(3)
Full Error Log
E0103 16:41:38.489316 2217885 ptx_compiler_helpers.cc:87] *** WARNING *** Invoking ptxas with version 12.1.66, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x.y up to and including 12.6.2 miscompile certain edge cases around clamping. Please upgrade to CUDA 12.6.3 or newer.
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 6149, in zeros
  return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 1752, in full
  return broadcast(fill_value, shape)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 1244, in broadcast
  return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims,
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 1278, in broadcast_in_dim
  return broadcast_in_dim_p.bind(
         ^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/core.py", line 463, in bind
  return self.bind_with_trace(prev_trace, args, params)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/core.py", line 468, in bind_with_trace
  return trace.process_primitive(self, args, params)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/core.py", line 941, in process_primitive
  return primitive.impl(*args, **params)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/dispatch.py", line 90, in apply_primitive
  outs = fun(*args)
         ^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  return fun(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 337, in cache_miss
  pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 195, in _python_pjit_helper
  out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1672, in _pjit_call_impl_python
  ).compile()
    ^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2415, in compile
  executable = UnloadedMeshExecutable.from_hlo(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2923, in from_hlo
  xla_executable = _cached_compilation(
                   ^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 2729, in _cached_compilation
  xla_executable = compiler.compile_or_get_cached(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 452, in compile_or_get_cached
  return _compile_and_write_cache(
         ^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 653, in _compile_and_write_cache
  executable = backend_compile(
               ^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
  return func(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 309, in backend_compile
  raise e
File "/home/test/venv/lib/python3.12/site-packages/jax/_src/compiler.py", line 303, in backend_compile
  return backend.compile(built_c, compile_options=options)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /var/tmp/tempfile-###-8aa5641830c7ce6-2217885-62acff3e14c12, line 5; fatal   : Unsupported .version 8.3; current version is '8.1'
ptxas fatal   : Ptx assembly aborted due to errors

I set LD_LIBRARY_PATH to be empty, but still encountered this error.
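
To see which binary is actually being picked up, a quick diagnostic can compare the first ptxas on PATH against the pip-installed copy. This is just a sketch; the nvidia/cuda_nvcc/bin layout under site-packages is assumed from the workaround below:

# Sketch: compare the first ptxas on PATH with the copy installed by the jax[cuda12] wheels.
import os
import shutil
import site

system_ptxas = shutil.which("ptxas")  # first ptxas on PATH (the system CUDA 12.1 copy here)
# Assumed wheel layout, matching the workaround below:
pip_ptxas = os.path.join(site.getsitepackages()[0], "nvidia", "cuda_nvcc", "bin", "ptxas")

print("ptxas on PATH:", system_ptxas)
print("pip-installed ptxas:", pip_ptxas, "(exists)" if os.path.exists(pip_ptxas) else "(not found)")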

Related Issues

#25344: reports the same error, but in my case triton is not installed, so I believe this is a separate issue.
#18578: about the ptxas binary priority issue.

Workaround

We can manually prepend the directory containing the pip-installed ptxas binary to PATH to avoid this error:

export PATH=$(python -c "import site; print(site.getsitepackages()[0] + '/nvidia/cuda_nvcc/bin')"):$PATH
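
The same fix can also be applied from inside Python, provided it runs before jax triggers any GPU compilation. This is a sketch under the same assumed site-packages layout, not an official jax API:

# Sketch: prepend the pip-installed ptxas directory to PATH before jax needs it.
import os
import site

pip_nvcc_bin = os.path.join(site.getsitepackages()[0], "nvidia", "cuda_nvcc", "bin")
os.environ["PATH"] = pip_nvcc_bin + os.pathsep + os.environ.get("PATH", "")

import jax  # imported after adjusting PATH so the newer ptxas is found first
print(jax.numpy.zeros(3))  # should now compile without the .version error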

System info (python version, jaxlib version, accelerator, etc.)

CUDA version: 12.1

jax:    0.4.38
jaxlib: 0.4.38
numpy:  2.2.1
python: 3.12.5 (main, Aug 19 2024, 18:21:17) [GCC 9.4.0]
device info: NVIDIA H100 80GB HBM3-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='###', release='###', version='#18-Ubuntu SMP Fri Jul 26 14:21:24 UTC 2024', machine='x86_64')
MuhammadHakami commented

I would also suggest installing ptxas from conda with conda install cuda -c nvidia, since ptxas is part of the CUDA toolkit. Installing two copies of ptxas is not ideal, but it works as a temporary solution for now.
Thanks @takkyu2 for the detailed comment.

dkarkada commented Jan 6, 2025

I'm running into the same issue. The two workarounds mentioned here didn't work for me. I tried downgrading instead:

pip install --upgrade --force-reinstall -v "jax[cuda12]==0.4.31"

This fixed the problem, but I was then getting the following warning every time I loaded jax:

The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version (12.6.85). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.

I didn't know where jax was loading the PTX compiler from; apparently not from the CUDA 12.4 installation that comes with jaxlib 0.4.31. So, I tried simply downgrading my system's CUDA version to 12.4 using the installer here: https://developer.nvidia.com/cuda-12-4-0-download-archive

This didn't do anything. Just for kicks and giggles, I tried upgrading jax/jaxlib back to the most recent version:

pip install --upgrade -v jax[cuda12]

Somehow, now everything works. I don't know which of the aforementioned steps fixed the problem.

My jax version is now jax-0.4.38
My nvcc --version is

Cuda compilation tools, release 12.4, V12.4.99
Build cuda_12.4.r12.4/compiler.33961263_0

Sorry that this isn't the most organized comment, but maybe it'll help someone.

AdamScherlis commented

Similar problem; the error was Unsupported .version 8.3; current version is '7.8' (with jax 0.4.38 and CUDA 12.2, driver 535.216.01).

I tried downgrading to 0.4.31; it worked, but disabled parallel compilation as seen above.

Back to 0.4.38, same error as before.

Tried conda install cuda -c nvidia, and now everything works fine. Thanks @MuhammadHakami for the workaround!

Incidentally, ptxas wasn't available on the command line before running that last command, but it is now.

ptxas --version:

ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Tue_Oct_29_23:47:06_PDT_2024
Cuda compilation tools, release 12.6, V12.6.85
Build cuda_12.6.r12.6/compiler.35059454_0
