If CUDA 12.1 is installed, the pip-installed ptxas binary is not used and jax throws an error #25718
Comments
I would also suggest installing ptxas from conda with:
I'm running into the same issue. The two workarounds mentioned here didn't work for me. I tried downgrading instead:
This fixed the problem, but I was then getting the following warning every time I loaded jax:
I didn't know where jax was loading the PTX installer from. Apparently not from the CUDA 12.4 installation that comes with jaxlib 0.4.31. So I tried simply downgrading my system's CUDA version to 12.4 using the installer here: https://developer.nvidia.com/cuda-12-4-0-download-archive This didn't change anything. Just for kicks and giggles, I tried upgrading jax/jaxlib back to the most recent version:
Somehow, everything works now. I don't know which of the steps above fixed the problem. My jax version is now 0.4.38.
Sorry that this isn't the most organized comment, but maybe it'll help someone.
Similar problem. The error was:
I tried downgrading to 0.4.31, which worked but disabled parallel compilation, as seen above. Going back to 0.4.38 gave the same error as before. I also tried, incidentally:
Description
Please feel free to close this issue if this is expected behavior.
If it is expected, it would be great to have an easy way to fix it on our (users') side, other than installing a newer CUDA version.
Summary
If CUDA 12.1 is installed on the system and its ptxas is already on the system PATH:
After installing jax through pip:
The system-installed ptxas binary (instead of the pip-installed one) is used, and jax throws an error:
Full Error Log
I set LD_LIBRARY_PATH to be empty, but still encountered this error.
Related Issues
#25344: About the same error, but I don't have Triton installed in my case, so I think this is a separate issue.
#18578: About the ptxas binary priority issue.
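The priority problem comes down to ordinary PATH lookup order: when several directories contain an executable named ptxas, a PATH-based lookup picks the first one. Below is a minimal Python sketch of that rule, assuming a POSIX system. The helper names (`ptxas_candidates`, `make_fake_ptxas`) are hypothetical and not part of JAX. The sketch only mimics `which -a`; it is not a reproduction of how XLA actually searches for ptxas, which may also check other locations.

```python
import os
import stat


def ptxas_candidates(path_value: str, name: str = "ptxas") -> list[str]:
    """Return every executable called `name` on a PATH string, in lookup order.

    The first element is what a plain PATH lookup would pick.
    """
    found = []
    for directory in path_value.split(os.pathsep):
        # An empty entry means the current directory in POSIX shells;
        # it is skipped here for simplicity.
        if not directory:
            continue
        candidate = os.path.join(directory, name)
        if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
            found.append(candidate)
    return found


def make_fake_ptxas(directory: str) -> str:
    """Create a dummy executable named ptxas (for demonstration only)."""
    path = os.path.join(directory, "ptxas")
    with open(path, "w") as f:
        f.write("#!/bin/sh\n")
    # Mark the file as executable for its owner so os.access(..., X_OK) passes.
    os.chmod(path, os.stat(path).st_mode | stat.S_IXUSR)
    return path
```

On a real machine, `ptxas_candidates(os.environ["PATH"])` lists every ptxas visible on PATH. If the CUDA 12.1 copy comes before the pip-installed one, it is the one a PATH-based lookup finds.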
Workaround
We can manually prepend the directory containing the pip-installed ptxas binary to PATH to avoid this error:
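A minimal Python sketch of the same workaround, done from inside the script instead of the shell. It rests on two assumptions: first, that the pip CUDA wheel places the binary at nvidia/cuda_nvcc/bin/ptxas inside site-packages; second, that PATH is changed before JAX compiles anything, most simply before `import jax`. The helper names `find_pip_ptxas_dir` and `prepend_to_path` are hypothetical, not JAX APIs.

```python
import os
import site
import sysconfig

# Assumed location of ptxas inside site-packages for the pip-installed CUDA wheel.
PTXAS_SUBDIR = os.path.join("nvidia", "cuda_nvcc", "bin")


def find_pip_ptxas_dir(site_dirs=None):
    """Return the first <site-packages>/nvidia/cuda_nvcc/bin holding ptxas, or None."""
    if site_dirs is None:
        site_dirs = [
            sysconfig.get_paths()["purelib"],
            *site.getsitepackages(),
            site.getusersitepackages(),  # covers `pip install --user`
        ]
    for base in site_dirs:
        bin_dir = os.path.join(base, PTXAS_SUBDIR)
        if os.path.isfile(os.path.join(bin_dir, "ptxas")):
            return bin_dir
    return None


def prepend_to_path(directory, env=None):
    """Put `directory` first on PATH, dropping later duplicates and empty entries."""
    env = os.environ if env is None else env
    parts = [p for p in env.get("PATH", "").split(os.pathsep) if p and p != directory]
    env["PATH"] = os.pathsep.join([directory, *parts])
    return env["PATH"]
```

Typical use is to call `prepend_to_path(d)` with `d = find_pip_ptxas_dir()` (when it is not None) at the very top of the script, then `import jax`. Exporting the same directory at the front of PATH in the shell is equivalent and does not depend on import order.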
System info (python version, jaxlib version, accelerator, etc.)