
Attempt to install nvidia packages with the new site-packages layout #3228

@hartikainen

🐞 bug report

Affected Rule

n/a

Is this a regression?

No. There are already several similar issues open; see e.g. #2156, pytorch/pytorch#117350, pytorch/pytorch#101314. I'm also aware that the venv/site-packages layout is still experimental, so this is not really a bug. I still wanted to share the information in case someone finds it useful (and to leave a log for myself to come back to next time I have time to look into this).

Description

I tried installing the nvidia CUDA libraries (nvidia-cuda-runtime-cu12 and friends) to see if I could get them working out of the box with the new venv/site-packages layout (#2156). TLDR: no, they don't work out of the box. I've added some data from my initial investigation below.

🔬 Minimal Reproduction

e2d73e2

🔥 Exception or Error

Details
bazel run //:test
Starting local Bazel server (8.2.1) and connecting to it...
INFO: Analyzed target //:test (138 packages loaded, 9986 targets configured).
INFO: Found 1 target...
Target //:test up-to-date:
  bazel-bin/test
INFO: Elapsed time: 34.576s, Critical Path: 23.99s
INFO: 22 processes: 63 action cache hit, 22 internal.
INFO: Build completed successfully, 22 total actions
INFO: Running command line: external/bazel_tools/tools/test/test-setup.sh ./test
exec ${PAGER:-/usr/bin/less} "$0" || exit 1
Executing tests from //:test
-----------------------------------------------------------------------------
ERROR:2025-09-01 19:24:47,302:jax._src.xla_bridge:487: Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda12.initialize()
Traceback (most recent call last):
  File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax_plugins/xla_cuda12/__init__.py", line 201, in _version_check
    version = get_version()
              ^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:81: operation cusparseGetProperty(MAJOR_VERSION, &major) failed: The cuSPARSE library was not found.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 485, in discover_pjrt_plugins
    plugin_module.initialize()
  File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax_plugins/xla_cuda12/__init__.py", line 328, in initialize
    _check_cuda_versions(raise_on_first_error=True)
  File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax_plugins/xla_cuda12/__init__.py", line 266, in _check_cuda_versions
    _version_check("cuSPARSE", cuda_versions.cusparse_get_version,
  File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax_plugins/xla_cuda12/__init__.py", line 205, in _version_check
    raise RuntimeError(err_msg) from e
RuntimeError: Unable to load cuSPARSE. Is it installed?
WARNING:2025-09-01 19:24:47,311:jax._src.xla_bridge:864: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Hello, nvidia=<module 'nvidia' (namespace) from ['/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/nvidia']>!
Hello, jax=<module 'jax' from '/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax/__init__.py'>!
mujoco=<module 'mujoco' from '/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/mujoco/__init__.py'>!
Traceback (most recent call last):
  File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test_stage2_bootstrap.py", line 474, in <module>
    main()
  File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test_stage2_bootstrap.py", line 468, in main
    _run_py_path(main_filename, args=sys.argv[1:])
  File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test_stage2_bootstrap.py", line 284, in _run_py_path
    runpy.run_path(main_filename, run_name="__main__")
  File "<frozen runpy>", line 287, in run_path
  File "<frozen runpy>", line 98, in _run_module_code
  File "<frozen runpy>", line 88, in _run_code
  File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/test.py", line 20, in <module>
    main()
  File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/test.py", line 16, in main
    print(f"{jax.devices('gpu')=}")
             ^^^^^^^^^^^^^^^^^^
  File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 1010, in devices
    return get_backend(backend).devices()
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 944, in get_backend
    return _get_backend_uncached(platform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 925, in _get_backend_uncached
    platform = canonicalize_platform(platform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/.cache/bazel/_bazel_user/78efd52eda40769e82a38d778950bc83/execroot/_main/bazel-out/k8-fastbuild/bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 728, in canonicalize_platform
    raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. Platforms present are: cpu

🌍 Your Environment

Operating System:

$ lsb_release -a
No LSB modules are available.
Distributor ID: Ubuntu
Description:    Ubuntu 24.04.2 LTS
Release:        24.04
Codename:       noble

Output of bazel version:

$ bazel --version
bazel 8.2.1

Rules_python version:

4a422b0

Additional Information

I first diffed a uv-generated venv's site-packages against those of a py_binary target built with venvs_site_packages=yes. The differences are minor: the Bazel environment is missing the RECORD files and has different INSTALLER files. The INSTALLER diffs are expected; I'm less sure about the missing RECORD files. They are missing from all of the following packages:

Details
nvidia_cublas_cu12-12.9.1.4.dist-info.diff
nvidia_cuda_cupti_cu12-12.9.79.dist-info.diff
nvidia_cuda_nvcc_cu12-12.9.86.dist-info.diff
nvidia_cuda_nvrtc_cu12-12.9.86.dist-info.diff
nvidia_cuda_runtime_cu12-12.9.79.dist-info.diff
nvidia_cudnn_cu12-9.12.0.46.dist-info.diff
nvidia_cufft_cu12-11.4.1.4.dist-info.diff
nvidia_cusolver_cu12-11.7.5.82.dist-info.diff
nvidia_cusparse_cu12-12.5.10.65.dist-info.diff
nvidia_nccl_cu12-2.27.7.dist-info.diff
nvidia_nvjitlink_cu12-12.9.86.dist-info.diff
nvidia_nvshmem_cu12-3.3.24.dist-info.diff

Here's an example diff for one of them, nvidia_cublas_cu12-12.9.1.4.dist-info.diff:

Details
  diff -ruN ./.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.9.1.4.dist-info/INSTALLER ./bazel-bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.9.1.4.dist-info/INSTALLER
--- ./.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.9.1.4.dist-info/INSTALLER	2025-09-01 18:19:24.134031041 +0000
+++ ./bazel-bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.9.1.4.dist-info/INSTALLER	2025-08-28 20:28:52.097396422 +0000
@@ -1 +1 @@
-uv
\ No newline at end of file
+https://github.com/bazel-contrib/rules_python
\ No newline at end of file
diff -ruN ./.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.9.1.4.dist-info/RECORD ./bazel-bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.9.1.4.dist-info/RECORD
--- ./.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.9.1.4.dist-info/RECORD	2025-09-01 18:19:24.141031054 +0000
+++ ./bazel-bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.9.1.4.dist-info/RECORD	1970-01-01 00:00:00.000000000 +0000
@@ -1,16 +0,0 @@
-nvidia/cublas/include/cublas.h,sha256=a0lLqy-k47NuwyDjuueC3W0Mpc908MTU7o5sMJqE-1w,41246
-nvidia/cublas/include/cublasLt.h,sha256=Zo_7r-ZRHWkohyAf1jLotjx0KBGWYh7vV2m6fLM7TFo,104268
-nvidia/cublas/include/cublasXt.h,sha256=CW9dyXYGSUW1wEXrVVyhU6OxBK1PUvMoYdVGlQT7L9A,37380
-nvidia/cublas/include/cublas_api.h,sha256=1dtyy6TIQB-ymoRH5k5rnBJaMN4sPjhSQt5lf-TpbL4,375691
-nvidia/cublas/include/cublas_v2.h,sha256=qxMdB5jb97luEfw61LEAB-Wlr8A9DLBvO4rRypDCNKw,15460
-nvidia/cublas/include/nvblas.h,sha256=dXCLR-2oUiJFzLsDtIAK09m42ct4G0HWdYzBUuDPXpc,23341
-nvidia/cublas/lib/libcublas.so.12,sha256=1jRW2U8bjiQIqKlTjIdA60dMnrpIqE9CCxXi4Q-80Eg,105140976
-nvidia/cublas/lib/libcublasLt.so.12,sha256=lt30iRyk9O0UdZZng6kYXfKr1kqBw8xxO5bQNmlvxUE,749205904
-nvidia/cublas/lib/libnvblas.so.12,sha256=9BKnB5ZDn2pEtI5RHlJ9Ht90GOcz7bpFGJJpqA6WIxE,753824
-nvidia_cublas_cu12-12.9.1.4.dist-info/INSTALLER,sha256=5hhM4Q4mYTT9z6QB6PGpUAW81PGNFrYrdXMj4oM_6ak,2
-nvidia_cublas_cu12-12.9.1.4.dist-info/METADATA,sha256=9QYFRxRORca_jQAL9Q8kKFFOzH7d0swlXbn4mDUr0UY,1707
-nvidia_cublas_cu12-12.9.1.4.dist-info/RECORD,,
-nvidia_cublas_cu12-12.9.1.4.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-nvidia_cublas_cu12-12.9.1.4.dist-info/WHEEL,sha256=pQkYcE_zO8O3IUE-PXGJ2qj02_9aLe-AqqPPB3Nn1Xg,109
-nvidia_cublas_cu12-12.9.1.4.dist-info/licenses/License.txt,sha256=rW9YU_ugyg0VnQ9Y1JrkmDDC-Mk_epJki5zpCttMbM0,59262
-nvidia_cublas_cu12-12.9.1.4.dist-info/top_level.txt,sha256=fTkAtiFuL16nUrB9ytDDtpytz2t0B4NvYTnRzwAhO14,7

The nvidia package, which holds the actual implementation files, is identical between the uv venv and the rules_python venv, modulo the files being symlinked in the latter.

Despite the near-identical site-packages structure, the py_binary target is unable to load the shared nvidia libraries out of the box. After poking around a bit with strace, I can see that libcusparse.so.12 exists, but loading it fails because its transitive dependency libnvJitLink.so.12 cannot be found¹. My guess is that this failure is caused by the dynamic linker resolving paths from the real file locations, which differ from the symlinked directory structure in the Bazel-generated site-packages, breaking the run path lookups inside the C libraries.
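To make that concrete, here's a small sketch of mine (not part of the repro; the site-packages path is the one printed in the log above) that shows where the symlinked shared objects actually resolve to. If sibling libraries resolve into different physical directories, then any $ORIGIN-relative RUNPATH entries inside libcusparse.so.12 (assuming the nvidia wheels rely on those, which I have not verified) would no longer reach libnvJitLink.so.12:

# Sketch only: compare the symlinked venv layout against the real file locations.
# The site-packages path is the one printed in the log above; adjust as needed.
from pathlib import Path

venv_sp = Path(
    "bazel-bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages"
)

for so in sorted(venv_sp.glob("nvidia/*/lib/*.so*")):
    # If e.g. nvidia/cusparse/lib and nvidia/nvjitlink/lib resolve into
    # different physical trees, $ORIGIN-relative lookups between them break.
    print(f"{so.relative_to(venv_sp)} -> {so.resolve()}")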

As pointed out by @aignas (https://bazelbuild.slack.com/archives/CA306CEV6/p1756127831137709?thread_ts=1756088512.734139&cid=CA306CEV6), it would be interesting to debug this by copying the files instead of symlinking them, to verify whether it's really the symlinking that breaks things.
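A throwaway way to test that hypothesis without touching rules_python could be to materialize the symlinks under nvidia/ in the built venv as real copies and re-run the binary. A rough sketch (same hypothetical paths as above; the output tree may need to be made writable first):

# Sketch only: replace symlinked files under nvidia/ with real copies so the
# dynamic linker sees a single physical directory tree.
import shutil
from pathlib import Path

venv_sp = Path(
    "bazel-bin/test.runfiles/_main/_test.venv/lib/python3.12/site-packages"
)

for path in venv_sp.joinpath("nvidia").rglob("*"):
    if path.is_symlink() and path.resolve().is_file():
        target = path.resolve()
        path.unlink()
        shutil.copy2(target, path)  # materialize the file in place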

Questions and Things to Do

  • Is this something that can be fixed on the rules_python side, or should the nvidia packages do something different to make the linker lookups work with symlinks?
  • Is my diagnosis correct that the files exist but the symlinks confuse the linker, or could something else be going on? I should try raw copies instead of symlinks (see the sketch above).

Workaround for JAX

For what it's worth, I'm using JAX rather than PyTorch, which most of the linked issues are about. I couldn't find any existing examples of patching JAX to preload the packages, so here's how I got around the issue:

Details
modified   MODULE.bazel
@@ -67,6 +67,15 @@ pip.parse(
         "requirements.macos_arm64.txt": "osx_*",
     },
 )
+pip.override(
+    # file = "jax_cuda12_pjrt-0.7.1-py3-none-manylinux_2_27_x86_64.whl",
+    file = "jax_cuda12_pjrt-0.7.0-py3-none-manylinux2014_x86_64.whl",
+    patch_strip = 0,
+    patches = [
+        "//:jax_cuda12_pjrt.patch",
+        "//:jax_cuda12_pjrt_record.patch",
+    ],
+)
 use_repo(pip, "pypi")
 
 pip_onshape = use_extension("@rules_python//python/extensions:pip.bzl", "pip")
new file   jax_cuda12_pjrt.patch
@@ -0,0 +1,10 @@
+--- jax_plugins/xla_cuda12/__init__.py
++++ jax_plugins/xla_cuda12/__init__.py
+@@ -123,6 +123,7 @@
+   We prefer the Python packages, if present. If not, we fall back to loading
+   them from LD_LIBRARY_PATH. By loading the libraries here, later lookups will
+   find these copies."""
++  _load("nvjitlink", ["libnvJitLink.so.12"])
+   _load("cuda_runtime", ["libcudart.so.12"])
+   _load("cu13", ["libcudart.so.13"])
+   # cuda_nvrtc isn't directly a dependency of JAX, but CUDNN appears to need it
new file   jax_cuda12_pjrt_record.patch
@@ -0,0 +1,8 @@
+--- jax_cuda12_pjrt-0.7.0.dist-info/RECORD
++++ jax_cuda12_pjrt-0.7.0.dist-info/RECORD
+@@ -1,4 +1,4 @@
+-jax_plugins/xla_cuda12/__init__.py,sha256=5MWu6cM-YrC38VX6Qkay5IdWRn8Ekp7PuzW2LJoikaY,13182
++jax_plugins/xla_cuda12/__init__.py,sha256=cPm09fPeZlIf0ypdDr5g_1-rV3Ke9097eVW__lKDl48,13227
+ jax_plugins/xla_cuda12/version.py,sha256=dXnNpN9dnyGzJlW3C2fWP189bYNDvP0dHRXriwNP0bY,6733
+ jax_plugins/xla_cuda12/xla_cuda_plugin.so,sha256=VdGXGhi7-IMH7fAIIPIrRKLrl0AZwC2UIs5jQURXOyE,329428392
+ jax_cuda12_pjrt-0.7.0.dist-info/METADATA,sha256=KL0Di3q1g9gILfBxZDrpz9IcW5PSldd1Htwp8ddN1_w,579
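An alternative I considered but did not adopt would be to preload the missing library from the application side before importing jax, rather than patching the wheel; this is roughly what the gist in the footnote below does. A hedged sketch, assuming the nvidia-nvjitlink-cu12 wheel lays out its library as nvidia/nvjitlink/lib/libnvJitLink.so.12 (mirroring the cuBLAS layout shown earlier):

# Sketch only: preload libnvJitLink.so.12 with RTLD_GLOBAL before importing
# jax, so that the later dlopen of libcusparse.so.12 can resolve it.
import ctypes
import importlib.util
import os

spec = importlib.util.find_spec("nvidia.nvjitlink")
if spec is not None and spec.submodule_search_locations:
    pkg_dir = list(spec.submodule_search_locations)[0]
    ctypes.CDLL(
        os.path.join(pkg_dir, "lib", "libnvJitLink.so.12"),
        mode=ctypes.RTLD_GLOBAL,
    )

import jax  # noqa: E402  # imported after the preload on purpose

print(jax.devices("gpu"))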

Footnotes

  1. This seems like a common issue and can be worked around by preloading the packages; see e.g. https://gist.github.com/qxcv/183c2d6cd81f7028b802b232d6a9dd62 and https://github.com/pytorch/pytorch/pull/137059. What I'm trying to understand here is how to get these to work out of the box without having to patch the packages.
