🐞 bug report
Affected Rule
n/a
Is this a regression?
No. Indeed, there are several similar issues already open. See e.g. #2156, pytorch/pytorch#117350, pytorch/pytorch#101314. I also know that the venv/site-packages layout is still experimental and thus this is not really a bug. I still wanted to share the information in case someone finds it useful (and also to provide a log for myself to come back to this next time I have time to look into it).
Description
I tried my luck in installing the nvidia CUDA libraries (nvidia-cuda-runtime-cu12 and friends) to see if I could get them to work out of the box with the new venv/site-packages layout (#2156). The TLDR is: no, they didn't work out of the box. I've added some data from my initial investigation below.
🔬 Minimal Reproduction
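Roughly, the repro is a py_test built with the experimental venvs_site_packages layout that imports jax and the nvidia wheels and asks for the GPU backend. Below is a minimal sketch of the test script, reconstructed from the traceback and log output that follow; the mujoco import and exact formatting are guesses from the log, and the BUILD setup is not shown.

```python
# test.py -- minimal sketch reconstructed from the log output below; the real
# repro may differ in details.
import jax
import mujoco  # appears in the log output, so the real test imports it too
import nvidia  # namespace package provided by the nvidia-*-cu12 wheels


def main():
    # Show where the packages resolve from inside the Bazel-built venv.
    print(f"Hello, {nvidia=}!")
    print(f"Hello, {jax=}!")
    print(f"{mujoco=}!")
    # This is the call that fails: the CUDA plugin cannot be initialized,
    # so only the cpu platform is available.
    print(f"{jax.devices('gpu')=}")


if __name__ == "__main__":
    main()
```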
🔥 Exception or Error
Details
bazel run //:test
Starting local Bazel server (8.2.1) and connecting to it...
INFO: Analyzed target //:test (138 packages loaded, 9986 targets configured).
INFO: Found 1 target...
Target //:test up-to-date:
bazel-bin/test
INFO: Elapsed time: 34.576s, Critical Path: 23.99s
INFO: 22 processes: 63 action cache hit, 22 internal.
INFO: Build completed successfully, 22 total actions
INFO: Running command line: external/bazel_tools/tools/test/test-setup.sh ./test
exec ${PAGER:-/usr/bin/less} "$0" || exit 1
Executing tests from //:test
-----------------------------------------------------------------------------
ERROR:2025-09-01 19:24:47,302:jax._src.xla_bridge:487: Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda12.initialize()
Traceback (most recent call last):
File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax_plugins/xla_cuda12/__init__.py", line 201, in _version_check
version = get_version()
^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:81: operation cusparseGetProperty(MAJOR_VERSION, &major) failed: The cuSPARSE library was not found.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 485, in discover_pjrt_plugins
plugin_module.initialize()
File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax_plugins/xla_cuda12/__init__.py", line 328, in initialize
_check_cuda_versions(raise_on_first_error=True)
File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax_plugins/xla_cuda12/__init__.py", line 266, in _check_cuda_versions
_version_check("cuSPARSE", cuda_versions.cusparse_get_version,
File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax_plugins/xla_cuda12/__init__.py", line 205, in _version_check
raise RuntimeError(err_msg) from e
RuntimeError: Unable to load cuSPARSE. Is it installed?
WARNING:2025-09-01 19:24:47,311:jax._src.xla_bridge:864: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Hello, nvidia=<module 'nvidia' (namespace) from ['/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/nvidia']>!
Hello, jax=<module 'jax' from '/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax/__init__.py'>!
mujoco=<module 'mujoco' from '/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/mujoco/__init__.py'>!
Traceback (most recent call last):
File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test_stage2_bootstrap.py", line 474, in <module>
main()
File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test_stage2_bootstrap.py", line 468, in main
_run_py_path(main_filename, args=sys.argv[1:])
File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test_stage2_bootstrap.py", line 284, in _run_py_path
runpy.run_path(main_filename, run_name="__main__")
File "<frozen runpy>", line 287, in run_path
File "<frozen runpy>", line 98, in _run_module_code
File "<frozen runpy>", line 88, in _run_code
File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/test.py", line 20, in <module>
main()
File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/test.py", line 16, in main
print(f"{jax.devices('gpu')=}")
^^^^^^^^^^^^^^^^^^
File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 1010, in devices
return get_backend(backend).devices()
^^^^^^^^^^^^^^^^^^^^
File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 944, in get_backend
return _get_backend_uncached(platform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 925, in _get_backend_uncached
platform = canonicalize_platform(platform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 728, in canonicalize_platform
raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. Platforms present are: cpu
🌍 Your Environment
Operating System:
$ lsb_release -a
No LSB modules are available.
Distributor ID: Ubuntu
Description: Ubuntu 24.04.2 LTS
Release: 24.04
Codename: noble
Output of bazel version:
$ bazel --version
bazel 8.2.1
Rules_python version:
Additional Information
I first diffed uv venv-generated site-packages against a py_binary target built with venvs_site_packages=yes. The differences are minor: the Bazel environment just has missing RECORD files and different INSTALLER files. The INSTALLER file diffs are expected; the missing RECORD files I'm not sure about. They are missing from all of the following packages:
Details
nvidia_cublas_cu12-12.9.1.4.dist-info.diff
nvidia_cuda_cupti_cu12-12.9.79.dist-info.diff
nvidia_cuda_nvcc_cu12-12.9.86.dist-info.diff
nvidia_cuda_nvrtc_cu12-12.9.86.dist-info.diff
nvidia_cuda_runtime_cu12-12.9.79.dist-info.diff
nvidia_cudnn_cu12-9.12.0.46.dist-info.diff
nvidia_cufft_cu12-11.4.1.4.dist-info.diff
nvidia_cusolver_cu12-11.7.5.82.dist-info.diff
nvidia_cusparse_cu12-12.5.10.65.dist-info.diff
nvidia_nccl_cu12-2.27.7.dist-info.diff
nvidia_nvjitlink_cu12-12.9.86.dist-info.diff
nvidia_nvshmem_cu12-3.3.24.dist-info.diff
Here's an example diff for one of them, nvidia_cublas_cu12-12.9.1.4.dist-info.diff:
Details
diff -ruN ./.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.9.1.4.dist-info/INSTALLER ./bazel-bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.9.1.4.dist-info/INSTALLER
--- ./.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.9.1.4.dist-info/INSTALLER 2025-09-01 18:19:24.134031041 +0000
+++ ./bazel-bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.9.1.4.dist-info/INSTALLER 2025-08-28 20:28:52.097396422 +0000
@@ -1 +1 @@
-uv
\ No newline at end of file
+https://github.com/bazel-contrib/rules_python
\ No newline at end of file
diff -ruN ./.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.9.1.4.dist-info/RECORD ./bazel-bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.9.1.4.dist-info/RECORD
--- ./.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.9.1.4.dist-info/RECORD 2025-09-01 18:19:24.141031054 +0000
+++ ./bazel-bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.9.1.4.dist-info/RECORD 1970-01-01 00:00:00.000000000 +0000
@@ -1,16 +0,0 @@
-nvidia/cublas/include/cublas.h,sha256=a0lLqy-k47NuwyDjuueC3W0Mpc908MTU7o5sMJqE-1w,41246
-nvidia/cublas/include/cublasLt.h,sha256=Zo_7r-ZRHWkohyAf1jLotjx0KBGWYh7vV2m6fLM7TFo,104268
-nvidia/cublas/include/cublasXt.h,sha256=CW9dyXYGSUW1wEXrVVyhU6OxBK1PUvMoYdVGlQT7L9A,37380
-nvidia/cublas/include/cublas_api.h,sha256=1dtyy6TIQB-ymoRH5k5rnBJaMN4sPjhSQt5lf-TpbL4,375691
-nvidia/cublas/include/cublas_v2.h,sha256=qxMdB5jb97luEfw61LEAB-Wlr8A9DLBvO4rRypDCNKw,15460
-nvidia/cublas/include/nvblas.h,sha256=dXCLR-2oUiJFzLsDtIAK09m42ct4G0HWdYzBUuDPXpc,23341
-nvidia/cublas/lib/libcublas.so.12,sha256=1jRW2U8bjiQIqKlTjIdA60dMnrpIqE9CCxXi4Q-80Eg,105140976
-nvidia/cublas/lib/libcublasLt.so.12,sha256=lt30iRyk9O0UdZZng6kYXfKr1kqBw8xxO5bQNmlvxUE,749205904
-nvidia/cublas/lib/libnvblas.so.12,sha256=9BKnB5ZDn2pEtI5RHlJ9Ht90GOcz7bpFGJJpqA6WIxE,753824
-nvidia_cublas_cu12-12.9.1.4.dist-info/INSTALLER,sha256=5hhM4Q4mYTT9z6QB6PGpUAW81PGNFrYrdXMj4oM_6ak,2
-nvidia_cublas_cu12-12.9.1.4.dist-info/METADATA,sha256=9QYFRxRORca_jQAL9Q8kKFFOzH7d0swlXbn4mDUr0UY,1707
-nvidia_cublas_cu12-12.9.1.4.dist-info/RECORD,,
-nvidia_cublas_cu12-12.9.1.4.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-nvidia_cublas_cu12-12.9.1.4.dist-info/WHEEL,sha256=pQkYcE_zO8O3IUE-PXGJ2qj02_9aLe-AqqPPB3Nn1Xg,109
-nvidia_cublas_cu12-12.9.1.4.dist-info/licenses/License.txt,sha256=rW9YU_ugyg0VnQ9Y1JrkmDDC-Mk_epJki5zpCttMbM0,59262
-nvidia_cublas_cu12-12.9.1.4.dist-info/top_level.txt,sha256=fTkAtiFuL16nUrB9ytDDtpytz2t0B4NvYTnRzwAhO14,7
The nvidia package, which holds the actual implementation files, is identical between the uv venv and the rules_python venv, modulo the files being symlinked in the latter.
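As a side note, the symlink layout is easy to inspect from inside the Bazel-built venv; a small sketch (not part of the original diffing, and assuming the nvidia namespace package is importable):

```python
# Sketch: list which files under the nvidia namespace package are symlinks
# and where they actually point, to compare against a plain `uv venv` install.
import os

import nvidia

site_nvidia = list(nvidia.__path__)[0]  # .../site-packages/nvidia
for root, _dirs, files in os.walk(site_nvidia):
    for name in files:
        path = os.path.join(root, name)
        if os.path.islink(path):
            print(f"{path} -> {os.path.realpath(path)}")
```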
Despite the near-identical site-packages structure, the py_binary target is unable to load the shared nvidia libraries out of the box. After poking around a bit with strace, I see that libcusparse.so.12 exists, but loading it fails because its transitive dependency libnvJitLink.so.12 cannot be found.¹ My guess is that the dynamic linker resolves paths relative to the real file locations, which differ from the symlinked directory structure in the Bazel-generated site-packages, breaking the RUNPATH lookups baked into the C libraries.
As pointed out by @aignas (https://bazelbuild.slack.com/archives/CA306CEV6/p1756127831137709?thread_ts=1756088512.734139&cid=CA306CEV6), it would be interesting to debug this by copying the files instead of symlinking them to verify if it's really the symlinking that breaks things.
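Another way to poke at the hypothesis without patching anything is to load the libraries directly with ctypes: loading libcusparse on its own should fail with the same missing-libnvJitLink error, while preloading libnvJitLink first should make it succeed. A sketch along those lines (the nvidia/<pkg>/lib layout is assumed from the RECORD contents above):

```python
# Sketch: check that libcusparse fails to load only because its transitive
# dependency libnvJitLink.so.12 is not found via the RUNPATH of the symlinked
# file, and that preloading the dependency fixes it.
import ctypes
import glob
import os

import nvidia

site_nvidia = list(nvidia.__path__)[0]


def find_lib(subdir, pattern):
    # e.g. .../site-packages/nvidia/cusparse/lib/libcusparse.so.12
    return glob.glob(os.path.join(site_nvidia, subdir, "lib", pattern))[0]


cusparse = find_lib("cusparse", "libcusparse.so.*")
nvjitlink = find_lib("nvjitlink", "libnvJitLink.so.*")

try:
    ctypes.CDLL(cusparse)  # expected to fail: libnvJitLink.so.12 not found
except OSError as err:
    print(f"direct load failed: {err}")

# Preload the dependency into the process, then retry.
ctypes.CDLL(nvjitlink, mode=ctypes.RTLD_GLOBAL)
ctypes.CDLL(cusparse)
print("libcusparse loaded after preloading libnvJitLink")
```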
Questions and Things to Do
- Is this something that can or could be fixed on the rules_python side? Or should the nvidia packages do something different to make the linker lookups work with symlinks?
- Is my diagnosis even correct that the files exist but the symlinks make the linker upset? Could something else be going on? I should try raw copies instead of symlinks; a sketch of such a test follows this list.
Workaround for Jax
For what it's worth, I'm using Jax rather than PyTorch, which most others seem to use. I couldn't find any existing examples of patching Jax to preload the packages. Here's how I got around this issue:
Details
modified MODULE.bazel
@@ -67,6 +67,15 @@ pip.parse(
"requirements.macos_arm64.txt": "osx_*",
},
)
+pip.override(
+ # file = "jax_cuda12_pjrt-0.7.1-py3-none-manylinux_2_27_x86_64.whl",
+ file = "jax_cuda12_pjrt-0.7.0-py3-none-manylinux2014_x86_64.whl",
+ patch_strip = 0,
+ patches = [
+ "//:jax_cuda12_pjrt.patch",
+ "//:jax_cuda12_pjrt_record.patch",
+ ],
+)
use_repo(pip, "pypi")
pip_onshape = use_extension("@rules_python//python/extensions:pip.bzl", "pip")
new file jax_cuda12_pjrt.patch
@@ -0,0 +1,10 @@
+--- jax_plugins/xla_cuda12/__init__.py
++++ jax_plugins/xla_cuda12/__init__.py
+@@ -123,6 +123,7 @@
+ We prefer the Python packages, if present. If not, we fall back to loading
+ them from LD_LIBRARY_PATH. By loading the libraries here, later lookups will
+ find these copies."""
++ _load("nvjitlink", ["libnvJitLink.so.12"])
+ _load("cuda_runtime", ["libcudart.so.12"])
+ _load("cu13", ["libcudart.so.13"])
+ # cuda_nvrtc isn't directly a dependency of JAX, but CUDNN appears to need it
new file jax_cuda12_pjrt_record.patch
@@ -0,0 +1,8 @@
+--- jax_cuda12_pjrt-0.7.0.dist-info/RECORD
++++ jax_cuda12_pjrt-0.7.0.dist-info/RECORD
+@@ -1,4 +1,4 @@
+-jax_plugins/xla_cuda12/__init__.py,sha256=5MWu6cM-YrC38VX6Qkay5IdWRn8Ekp7PuzW2LJoikaY,13182
++jax_plugins/xla_cuda12/__init__.py,sha256=cPm09fPeZlIf0ypdDr5g_1-rV3Ke9097eVW__lKDl48,13227
+ jax_plugins/xla_cuda12/version.py,sha256=dXnNpN9dnyGzJlW3C2fWP189bYNDvP0dHRXriwNP0bY,6733
+ jax_plugins/xla_cuda12/xla_cuda_plugin.so,sha256=VdGXGhi7-IMH7fAIIPIrRKLrl0AZwC2UIs5jQURXOyE,329428392
+ jax_cuda12_pjrt-0.7.0.dist-info/METADATA,sha256=KL0Di3q1g9gILfBxZDrpz9IcW5PSldd1Htwp8ddN1_w,579
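For reference, the preloading approach from the gist and the PyTorch PR mentioned in footnote 1 can also be done in application code instead of patching the wheel, roughly like this (the library list and the nvidia/<pkg>/lib layout are assumptions; the real set of libraries JAX needs is longer):

```python
# Sketch: preload NVIDIA shared libraries from site-packages before the first
# jax import, so that later dlopen() calls find them without any wheel patch.
import ctypes
import glob
import os

import nvidia

_SITE_NVIDIA = list(nvidia.__path__)[0]

# Dependencies (nvJitLink) must come before their dependents (cuSPARSE).
_PRELOAD = [
    ("nvjitlink", "libnvJitLink.so.*"),
    ("cuda_runtime", "libcudart.so.*"),
    ("cusparse", "libcusparse.so.*"),
]

for subdir, pattern in _PRELOAD:
    for path in glob.glob(os.path.join(_SITE_NVIDIA, subdir, "lib", pattern)):
        ctypes.CDLL(path, mode=ctypes.RTLD_GLOBAL)

import jax  # noqa: E402  # import only after preloading

print(jax.devices("gpu"))
```

The difference from the wheel patch above is only where the preload happens; the mechanism is the same as the _load("nvjitlink", ...) line added in jax_cuda12_pjrt.patch.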
Footnotes
1. This seems like a common issue and can be worked around by preloading the packages. See e.g. https://gist.github.com/qxcv/183c2d6cd81f7028b802b232d6a9dd62 and https://github.com/pytorch/pytorch/pull/137059. What I'm trying to understand here is how to get these to work out of the box without having to patch the packages. ↩