Skip to content

Commit

Permalink
Bugfix: Load correct nocublaslt library variant when BNB_CUDA_VERSION…
Browse files Browse the repository at this point in the history
… override is set (bitsandbytes-foundation#1318)
  • Loading branch information
matthewdouglas authored Aug 14, 2024
1 parent 6d714a5 commit a4875fc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
9 changes: 2 additions & 7 deletions bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import logging
import os
from pathlib import Path
import re

import torch

Expand All @@ -44,13 +45,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:

override_value = os.environ.get("BNB_CUDA_VERSION")
if override_value:
library_name_stem, _, library_name_ext = library_name.rpartition(".")
# `library_name_stem` will now be e.g. `libbitsandbytes_cuda118`;
# let's remove any trailing numbers:
library_name_stem = library_name_stem.rstrip("0123456789")
# `library_name_stem` will now be e.g. `libbitsandbytes_cuda`;
# let's tack the new version number and the original extension back on.
library_name = f"{library_name_stem}{override_value}.{library_name_ext}"
library_name = re.sub("cuda\d+", f"cuda{override_value}", library_name, count=1)
logger.warning(
f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n"
"This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n"
Expand Down
6 changes: 6 additions & 0 deletions tests/test_cuda_setup_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning?


def test_get_cuda_bnb_library_path_override_nocublaslt(monkeypatch, cuda111_noblas_spec, caplog):
monkeypatch.setenv("BNB_CUDA_VERSION", "125")
assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda125_nocublaslt"
assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning?


def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"

0 comments on commit a4875fc

Please sign in to comment.