diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index f7dd01bdf..fa8ca65c0 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -6,3 +6,13 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848 # Remove f-prefix from strings that don't use formatting 7727fa4c8c6c1ef2b109120aff4196a0a6bf3ed6 + +# format bitsandbytes/cextension.py +04f691ef3061e6659aa0a741ca97d00d031618c4 + +# whitespace in pyproject.toml +f7b791863083429ba79dc00f925a041beab63297 + +# format bitsandbytes/functional.py +64ad928224ab1134dff416feee5e7ca663331bc0 +01327aa0119fa503ea16322dd72f69f202e4502e diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index e54e933d9..a8948d807 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import cuda_setup, research, utils +from . import research, utils from .autograd._functions import ( MatmulLtState, bmm_cublas, @@ -12,6 +12,7 @@ matmul_cublas, mm_cublas, ) +from .backends import _backend as backend from .cextension import COMPILED_WITH_CUDA from .nn import modules diff --git a/bitsandbytes/__main__.py b/bitsandbytes/__main__.py index 61b42e78f..d8ba54500 100644 --- a/bitsandbytes/__main__.py +++ b/bitsandbytes/__main__.py @@ -59,7 +59,7 @@ def main(): generate_bug_report_information() from . import COMPILED_WITH_CUDA - from .cuda_setup.main import get_compute_capabilities + from .device_setup.cuda.main import get_compute_capabilities print_header("OTHER") print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}") diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py new file mode 100644 index 000000000..5da58f25e --- /dev/null +++ b/bitsandbytes/backends/__init__.py @@ -0,0 +1,11 @@ +from ..cextension import lib +from ._base import COOSparseTensor +from .nvidia import CudaBackend + +_backend = CudaBackend(lib) if lib else None +# TODO: this should actually be done in `cextension.py` and potentially with .get_instance() +# for now this is just a simplifying assumption +# +# Notes from Tim: +# backend = CUDABackend.get_instance() +# -> CUDASetup -> lib -> backend.clib = lib diff --git a/bitsandbytes/backends/_base.py b/bitsandbytes/backends/_base.py new file mode 100644 index 000000000..d31f721e4 --- /dev/null +++ b/bitsandbytes/backends/_base.py @@ -0,0 +1,143 @@ +import torch + + +class COOSparseTensor: + def __init__(self, rows, cols, nnz, rowidx, colidx, values): + assert rowidx.dtype == torch.int32 + assert colidx.dtype == torch.int32 + assert values.dtype == torch.float16 + assert values.numel() == nnz + assert rowidx.numel() == nnz + assert colidx.numel() == nnz + + self.rows = rows + self.cols = cols + self.nnz = nnz + self.rowidx = rowidx + self.colidx = colidx + self.values = values + + +class BackendInterface: + _instance = None + + def __new__(cls, lib=None): + if cls._instance is None: + if lib is None: + raise ValueError( + "A 'lib' binary must be provided during the first initialization of BackendInterface." + ) + cls._instance = super().__new__(cls) + cls._instance.lib = ( + lib # Set the binary name during the first and only instantiation + ) + else: + if lib is not None: + raise ValueError( + "The BackendInterface singleton has already been initialized with a 'lib' value. Re-initialization with a new 'lib' value is not allowed." 
+ ) + return cls._instance + + def check_matmul( + self, + A, + B, + out=None, + transposed_A=False, + transposed_B=False, + expected_type=torch.int8, + ): + """ + Checks if the matrix multiplication between A and B can be performed, considering their shapes, + whether they are transposed, and their data types. It also determines the shape of the output tensor. + + Parameters: + - A (torch.Tensor): The first matrix in the multiplication. + - B (torch.Tensor): The second matrix in the multiplication. + - out (torch.Tensor, optional): The output tensor to store the result of the multiplication. Default is None. + - transposed_A (bool, optional): Indicates if matrix A is transposed. Default is False. + - transposed_B (bool, optional): Indicates if matrix B is transposed. Default is False. + - expected_type (torch.dtype, optional): The expected data type of matrices A and B. Default is torch.int8. + + Returns: + - tuple: The shape of the output tensor resulting from the matrix multiplication. + + Raises: + - TypeError: If the data types of A or B do not match the expected type. + - ValueError: If the dimensions of A and B are not compatible for matrix multiplication. + """ + raise NotImplementedError + + # 8-bit matmul interface + def coo_zeros(self, rows, cols, nnz, device, dtype=torch.half): + rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device) + colidx = torch.zeros((nnz,), dtype=torch.int32, device=device) + values = torch.zeros((nnz,), dtype=dtype, device=device) + + return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) + + def get_colrow_absmax( + self, A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 + ): + raise NotImplementedError + + def double_quant( + self, + A, + col_stats=None, + row_stats=None, + out_col=None, + out_row=None, + threshold=0.0, + ): + raise NotImplementedError + + def extract_outliers(self, *args, **kwargs): + raise NotImplementedError + + def igemmlt(self, *args, **kwargs): + raise NotImplementedError + + def mm_dequant(self, *args, **kwargs): + raise NotImplementedError + + # k-bit quantization interface + def create_quant_map(self, interface, quant_name): + """ + Below functions should be abstracted into a general method + "create_quant_map(interface, "quant_name")", so we can call e.g. 
+ create_quant_map(..., quant_name='normal'): + - 'create_dynamic_map' + - 'create_fp8_map' + - 'create_linear_map' + - 'create_normal_map' + - 'create_quantile_map' + """ + raise NotImplementedError + + def estimate_quantiles(self, *args, **kwargs): + raise NotImplementedError + + def dequantize_blockwise(self, *args, **kwargs): + raise NotImplementedError + + def quantize_blockwise(self, *args, **kwargs): + raise NotImplementedError + + # 4-bit matmul interface + def dequantize_4bit(self, *args, **kwargs): + raise NotImplementedError + + def quantize_4bit(self, *args, **kwargs): + raise NotImplementedError + + def gemv_4bit(self, *args, **kwargs): + raise NotImplementedError + + # 8-bit optimizer interface + def optimizer_update_32bit(self, *args, **kwargs): + """This is needed for tests""" + raise NotImplementedError("Subclasses must implement 'optimizer_update_32bit'.") + + def optimizer_update_8bit_blockwise(self, *args, **kwargs): + raise NotImplementedError diff --git a/bitsandbytes/backends/_helpers.py b/bitsandbytes/backends/_helpers.py new file mode 100644 index 000000000..adfc2a1c2 --- /dev/null +++ b/bitsandbytes/backends/_helpers.py @@ -0,0 +1,54 @@ +import ctypes +from typing import Optional + +import torch + + +def pre_call(device): + prev_device = torch.cuda.current_device() + torch.cuda.set_device(device) + return prev_device + + +def post_call(prev_device): + torch.cuda.set_device(prev_device) + + +def get_ptr(A: Optional[torch.Tensor]) -> Optional[ctypes.c_void_p]: + """ + Get the ctypes pointer from a PyTorch Tensor. + + Parameters + ---------- + A : torch.tensor + The PyTorch tensor. + + Returns + ------- + ctypes.c_void_p + """ + if A is None: + return None + else: + return ctypes.c_void_p(A.data.data_ptr()) + + +def is_on_gpu(tensors): + on_gpu = True + gpu_ids = set() + for t in tensors: + if t is None: + continue # NULL pointers are fine + is_paged = getattr(t, "is_paged", False) + on_gpu &= t.device.type == "cuda" or is_paged + if not is_paged: + gpu_ids.add(t.device.index) + if not on_gpu: + raise TypeError( + f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}" + ) + if len(gpu_ids) > 1: + raise TypeError( + f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}" + ) + return on_gpu diff --git a/bitsandbytes/cuda_setup/__init__.py b/bitsandbytes/backends/amd.py similarity index 100% rename from bitsandbytes/cuda_setup/__init__.py rename to bitsandbytes/backends/amd.py diff --git a/bitsandbytes/backends/apple.py b/bitsandbytes/backends/apple.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/backends/intel.py b/bitsandbytes/backends/intel.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/backends/nvidia.py b/bitsandbytes/backends/nvidia.py new file mode 100644 index 000000000..f4c14cfab --- /dev/null +++ b/bitsandbytes/backends/nvidia.py @@ -0,0 +1,293 @@ +import ctypes + +import torch + +from ._base import BackendInterface +from ._helpers import get_ptr, is_on_gpu, post_call, pre_call + + +class CudaBackend(BackendInterface): + def check_matmul( + self, A, B, out, transposed_A, transposed_B, expected_type=torch.int8 + ): + if not torch.cuda.is_initialized(): + torch.cuda.init() + if A.dtype != expected_type or B.dtype != expected_type: + raise TypeError( + f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}" + ) + + sA 
= A.shape
+        sB = B.shape
+        tA = transposed_A
+        tB = transposed_B
+
+        correct = True
+
+        if len(sA) == 2 and len(sB) == 2:
+            if not tA and not tB and A.shape[1] != B.shape[0]:
+                correct = False
+            elif tA and not tB and A.shape[0] != B.shape[0]:
+                correct = False
+            elif tA and tB and A.shape[0] != B.shape[1]:
+                correct = False
+            elif not tA and tB and A.shape[1] != B.shape[1]:
+                correct = False
+        elif len(sA) == 3 and len(sB) == 2:
+            if not tA and not tB and A.shape[2] != B.shape[0]:
+                correct = False
+            elif tA and not tB and A.shape[1] != B.shape[0]:
+                correct = False
+            elif tA and tB and A.shape[1] != B.shape[1]:
+                correct = False
+            elif not tA and tB and A.shape[2] != B.shape[1]:
+                correct = False
+        elif len(sA) == 3 and len(sB) == 3:
+            if not tA and not tB and A.shape[2] != B.shape[1]:
+                correct = False
+            elif tA and not tB and A.shape[1] != B.shape[1]:
+                correct = False
+            elif tA and tB and A.shape[1] != B.shape[2]:
+                correct = False
+            elif not tA and tB and A.shape[2] != B.shape[2]:
+                correct = False
+
+        if out is not None:
+            sout = out.shape
+            # special case common in backprop
+            if not correct and len(sA) == 3 and len(sB) == 3:
+                if (
+                    sout[0] == sA[2]
+                    and sout[1] == sB[2]
+                    and sA[0] == sB[0]
+                    and sA[1] == sB[1]
+                ):
+                    correct = True
+        else:
+            if len(sA) == 2 and len(sB) == 2:
+                if not tA and not tB:
+                    sout = (sA[0], sB[1])
+                elif tA and tB:
+                    sout = (sA[1], sB[0])
+                elif tA and not tB:
+                    sout = (sA[1], sB[1])
+                elif not tA and tB:
+                    sout = (sA[0], sB[0])
+            elif len(sA) == 3 and len(sB) == 2:
+                if not tA and not tB:
+                    sout = (sA[0], sA[1], sB[1])
+                elif tA and tB:
+                    sout = (sA[0], sA[2], sB[0])
+                elif tA and not tB:
+                    sout = (sA[0], sA[2], sB[1])
+                elif not tA and tB:
+                    sout = (sA[0], sA[1], sB[0])
+            elif len(sA) == 3 and len(sB) == 3:
+                if not tA and not tB:
+                    sout = (sA[0], sA[1], sB[2])
+                elif tA and tB:
+                    sout = (sA[0], sA[2], sB[1])
+                elif tA and not tB:
+                    sout = (sA[0], sA[2], sB[2])
+                elif not tA and tB:
+                    sout = (sA[0], sA[1], sB[1])
+
+        if not correct:
+            raise ValueError(
+                f"Tensor dimensions incorrect for matrix multiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}."
+ ) + + return sout + + def get_colrow_absmax( + self, A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 + ): + assert A.dtype == torch.float16 + device = A.device + + cols = A.shape[-1] + if len(A.shape) == 3: + rows = A.shape[0] * A.shape[1] + else: + rows = A.shape[0] + + col_tiles = (cols + 255) // 256 + tiled_rows = ((rows + 15) // 16) * 16 + if row_stats is None: + row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_( + -50000.0 + ) + if col_stats is None: + col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_( + -50000.0 + ) + + if nnz_block_ptr is None and threshold > 0.0: + nnz_block_ptr = torch.zeros( + ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device + ) + + ptrA = get_ptr(A) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + ptrNnzrows = get_ptr(nnz_block_ptr) + rows = ctypes.c_int32(rows) + cols = ctypes.c_int32(cols) + + prev_device = pre_call(A.device) + is_on_gpu([A, row_stats, col_stats, nnz_block_ptr]) + self.lib.cget_col_row_stats( + ptrA, + ptrRowStats, + ptrColStats, + ptrNnzrows, + ctypes.c_float(threshold), + rows, + cols, + ) + post_call(prev_device) + + if threshold > 0.0: + nnz_block_ptr.cumsum_(0) + + return row_stats, col_stats, nnz_block_ptr + + def double_quant( + self, + A, + col_stats=None, + row_stats=None, + out_col=None, + out_row=None, + threshold=0.0, + ): + device = A.device + assert A.dtype == torch.half + assert device.type == "cuda" + prev_device = pre_call(A.device) + + cols = A.shape[-1] + if len(A.shape) == 3: + rows = A.shape[0] * A.shape[1] + else: + rows = A.shape[0] + + if row_stats is None or col_stats is None: + row_stats, col_stats, nnz_row_ptr = self.get_colrow_absmax( + A, threshold=threshold + ) + + if out_col is None: + out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) + if out_row is None: + out_row = torch.zeros(A.shape, device=device, dtype=torch.int8) + + coo_tensor = None + ptrA = get_ptr(A) + ptrColStats = get_ptr(col_stats) + ptrRowStats = get_ptr(row_stats) + ptrOutCol = get_ptr(out_col) + ptrOutRow = get_ptr(out_row) + + is_on_gpu([A, col_stats, row_stats, out_col, out_row]) + if threshold > 0.0: + nnz = nnz_row_ptr[-1].item() + if nnz > 0: + coo_tensor = self.coo_zeros( + A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device + ) + ptrRowIdx = get_ptr(coo_tensor.rowidx) + ptrColIdx = get_ptr(coo_tensor.colidx) + ptrVal = get_ptr(coo_tensor.values) + ptrRowPtr = get_ptr(nnz_row_ptr) + + self.lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + ptrRowIdx, + ptrColIdx, + ptrVal, + ptrRowPtr, + ctypes.c_float(threshold), + ctypes.c_int32(rows), + ctypes.c_int32(cols), + ) + val, idx = torch.sort(coo_tensor.rowidx) + coo_tensor.rowidx = val + coo_tensor.colidx = coo_tensor.colidx[idx] + coo_tensor.values = coo_tensor.values[idx] + else: + self.lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + None, + None, + None, + None, + ctypes.c_float(0.0), + ctypes.c_int32(rows), + ctypes.c_int32(cols), + ) + else: + self.lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + None, + None, + None, + None, + ctypes.c_float(threshold), + ctypes.c_int32(rows), + ctypes.c_int32(cols), + ) + post_call(prev_device) + + return out_row, out_col, row_stats, col_stats, coo_tensor + + """ + # CUDA specific interface (do not include in general interface): + 'CUBLAS_Context' + 'Cusparse_Context' + 'GlobalPageManager' + '_mul' 
+ 'arange' + 'dtype2bytes' + 'elementwise_func' + 'fill' + 'get_paged' + 'get_4bit_type' + 'get_ptr' + 'get_special_format_str' + 'get_transform_buffer' + 'get_transform_func' + 'is_on_gpu' + 'nvidia_transform' + 'transform' + + ## Deprecate these: + 'optimizer_update_8bit' + 'dequant_min_max' + 'dequantize' + 'dequantize_no_absmax' + 'igemm' + 'quantize' + 'spmm_coo' + 'spmm_coo_very_sparse' + 'vectorwise_dequant' + 'vectorwise_mm_dequant' + 'vectorwise_quant' + 'CSCSparseTensor' + 'CSRSparseTensor' + 'coo2csc' + 'coo2csr' + 'histogram_scatter_add_2d' + """ diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 858365f02..1636d06b0 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -3,7 +3,7 @@ import torch -from bitsandbytes.cuda_setup.main import CUDASetup +from bitsandbytes.device_setup.cuda.main import CUDASetup setup = CUDASetup.get_instance() if setup.initialized != True: @@ -32,8 +32,3 @@ "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.") COMPILED_WITH_CUDA = False print(str(ex)) - - -# print the setup details after checking for errors so we do not print twice -#if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': - #setup.print_log_stack() diff --git a/bitsandbytes/device_setup/__init__.py b/bitsandbytes/device_setup/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/device_setup/cuda/__init__.py b/bitsandbytes/device_setup/cuda/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/cuda_setup/env_vars.py b/bitsandbytes/device_setup/cuda/env_vars.py similarity index 100% rename from bitsandbytes/cuda_setup/env_vars.py rename to bitsandbytes/device_setup/cuda/env_vars.py diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/device_setup/cuda/main.py similarity index 99% rename from bitsandbytes/cuda_setup/main.py rename to bitsandbytes/device_setup/cuda/main.py index cd0d94cd7..a37cbf36a 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/device_setup/cuda/main.py @@ -142,7 +142,7 @@ def run_cuda_setup(self): self.binary_name = binary_name self.manual_override() - package_dir = Path(__file__).parent.parent + package_dir = Path(__file__).parent.parent.parent binary_path = package_dir / self.binary_name try: @@ -278,7 +278,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]: 1. active conda env 2. LD_LIBRARY_PATH 3. any other env vars, while ignoring those that - - are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`) + - are known to be unrelated (see `bnb.device_setup.cuda.env_vars.to_be_ignored`) - don't contain the path separator `/` If multiple libraries are found in part 3, we optimistically try one, diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index f0de962e1..0841d11be 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -3,10 +3,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
import ctypes as ct -from functools import reduce # Required in Python 3 +from functools import ( + reduce, # Required in Python 3 + wraps, +) import itertools import operator from typing import Any, Dict, Optional, Tuple +import warnings import numpy as np import torch @@ -14,6 +18,8 @@ from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict +from .backends import _backend as backend +from .backends._helpers import get_ptr, is_on_gpu, post_call, pre_call from .cextension import COMPILED_WITH_CUDA, lib @@ -21,6 +27,39 @@ def prod(iterable): return reduce(operator.mul, iterable, 1) + +def deprecated(_func=None, *, new_func_name=None): + """ + A decorator to mark functions as deprecated. It issues a warning when the decorated function is called, + advising to use a specified new function instead. + + Parameters: + - _func (callable, optional): The function to be deprecated. This is for internal use when the decorator is applied without parentheses. + - new_func_name (str, optional): The name of the new function to use instead of the deprecated one. Defaults to 'bitsandbytes.backend.'. + + Usage: + @deprecated + def old_function(): + ... + + @deprecated(new_func_name='module.new_function') + def another_old_function(): + ... + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + replacement = new_func_name or f"bitsandbytes.backend.{func.__name__}" + warning_message = f"'{func.__name__}' is deprecated and will be removed in a future version. Use '{replacement}' instead." + warnings.warn(warning_message, DeprecationWarning, stacklevel=2) + return func(*args, **kwargs) + + return wrapper + + return decorator if _func is None else decorator(_func) + + name2qmap = {} if COMPILED_WITH_CUDA: @@ -396,52 +435,6 @@ def get_special_format_str(): return "col_turing" - -def is_on_gpu(tensors): - on_gpu = True - gpu_ids = set() - for t in tensors: - if t is None: continue # NULL pointers are fine - is_paged = getattr(t, 'is_paged', False) - on_gpu &= (t.device.type == 'cuda' or is_paged) - if not is_paged: - gpu_ids.add(t.device.index) - if not on_gpu: - raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}') - if len(gpu_ids) > 1: - raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}') - return on_gpu - - -def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: - """ - Get the ctypes pointer from a PyTorch Tensor. - - Parameters - ---------- - A : torch.tensor - The PyTorch tensor. 
- - Returns - ------- - ctypes.c_void_p - """ - if A is None: - return None - else: - return ct.c_void_p(A.data.data_ptr()) - - -def pre_call(device): - prev_device = torch.cuda.current_device() - torch.cuda.set_device(device) - return prev_device - - -def post_call(prev_device): - torch.cuda.set_device(prev_device) - - def get_transform_func(dtype, orderA, orderOut, transpose=False): name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}' if not hasattr(lib, name): @@ -2022,67 +2015,9 @@ def mm_dequant( return out -def get_colrow_absmax( - A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 -): - assert A.dtype == torch.float16 - device = A.device - - cols = A.shape[-1] - if len(A.shape) == 3: - rows = A.shape[0] * A.shape[1] - else: - rows = A.shape[0] - - col_tiles = (cols + 255) // 256 - tiled_rows = ((rows + 15) // 16) * 16 - if row_stats is None: - row_stats = torch.empty( - (rows,), dtype=torch.float32, device=device - ).fill_(-50000.0) - if col_stats is None: - col_stats = torch.empty( - (cols,), dtype=torch.float32, device=device - ).fill_(-50000.0) - - if nnz_block_ptr is None and threshold > 0.0: - nnz_block_ptr = torch.zeros( - ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device - ) - - ptrA = get_ptr(A) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - ptrNnzrows = get_ptr(nnz_block_ptr) - rows = ct.c_int32(rows) - cols = ct.c_int32(cols) - - prev_device = pre_call(A.device) - is_on_gpu([A, row_stats, col_stats, nnz_block_ptr]) - lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols) - post_call(prev_device) - - if threshold > 0.0: - nnz_block_ptr.cumsum_(0) - - return row_stats, col_stats, nnz_block_ptr - - -class COOSparseTensor: - def __init__(self, rows, cols, nnz, rowidx, colidx, values): - assert rowidx.dtype == torch.int32 - assert colidx.dtype == torch.int32 - assert values.dtype == torch.float16 - assert values.numel() == nnz - assert rowidx.numel() == nnz - assert colidx.numel() == nnz - - self.rows = rows - self.cols = cols - self.nnz = nnz - self.rowidx = rowidx - self.colidx = colidx - self.values = values +@deprecated +def get_colrow_absmax(*args, **kwargs): + return backend.get_colrow_absmax(*args, **kwargs) class CSRSparseTensor: @@ -2147,108 +2082,14 @@ def coo2csc(cooA): cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values ) +@deprecated +def coo_zeros(*args, **kwargs): + return backend.coo_zeros(*args, **kwargs) -def coo_zeros(rows, cols, nnz, device, dtype=torch.half): - rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device) - colidx = torch.zeros((nnz,), dtype=torch.int32, device=device) - values = torch.zeros((nnz,), dtype=dtype, device=device) - return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) - - -def double_quant( - A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 -): - device = A.device - assert A.dtype == torch.half - assert device.type == "cuda" - prev_device = pre_call(A.device) - - cols = A.shape[-1] - if len(A.shape) == 3: - rows = A.shape[0] * A.shape[1] - else: - rows = A.shape[0] - - if row_stats is None or col_stats is None: - row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( - A, threshold=threshold - ) - - if out_col is None: - out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) - if out_row is None: - out_row = torch.zeros(A.shape, device=device, dtype=torch.int8) - - coo_tensor = None - ptrA = get_ptr(A) - ptrColStats = 
get_ptr(col_stats) - ptrRowStats = get_ptr(row_stats) - ptrOutCol = get_ptr(out_col) - ptrOutRow = get_ptr(out_row) - - is_on_gpu([A, col_stats, row_stats, out_col, out_row]) - if threshold > 0.0: - nnz = nnz_row_ptr[-1].item() - if nnz > 0: - coo_tensor = coo_zeros( - A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device - ) - ptrRowIdx = get_ptr(coo_tensor.rowidx) - ptrColIdx = get_ptr(coo_tensor.colidx) - ptrVal = get_ptr(coo_tensor.values) - ptrRowPtr = get_ptr(nnz_row_ptr) - - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - ptrRowIdx, - ptrColIdx, - ptrVal, - ptrRowPtr, - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - ) - val, idx = torch.sort(coo_tensor.rowidx) - coo_tensor.rowidx = val - coo_tensor.colidx = coo_tensor.colidx[idx] - coo_tensor.values = coo_tensor.values[idx] - else: - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - None, - None, - None, - None, - ct.c_float(0.0), - ct.c_int32(rows), - ct.c_int32(cols), - ) - else: - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - None, - None, - None, - None, - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - ) - post_call(prev_device) - return out_row, out_col, row_stats, col_stats, coo_tensor +@deprecated +def double_quant(*args, **kwargs): + return backend.double_quant(*args, **kwargs) def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): diff --git a/pyproject.toml b/pyproject.toml index f74750720..9a2072af6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,22 @@ src = [ "tests", "benchmarking" ] + +target-version = "py38" + +[tool.ruff.lint] +ignore = [ + "B007", # Loop control variable not used within the loop body (TODO: enable) + "B028", # Warning without stacklevel (TODO: enable) + "E501", # Supress line-too-long warnings: trust yapf's judgement on this one. + "E701", # Multiple statements on one line (TODO: enable) + "E712", # Allow using if x == False, as it's not always equivalent to if x. + "E731", # Do not use lambda + "F841", # Local assigned but not used (TODO: enable, these are likely bugs) + "RUF012", # Mutable class attribute annotations + "ISC001", # String concatination warning: may cause conflicts when used with the formatter +] +ignore-init-module-imports = true # allow to expose in __init__.py via imports select = [ "B", # bugbear: security warnings "E", # pycodestyle @@ -17,20 +33,8 @@ select = [ "UP", # alert you when better syntax is available in your python version "RUF", # the ruff developer's own rules ] -target-version = "py38" -ignore = [ - "B007", # Loop control variable not used within the loop body (TODO: enable) - "B028", # Warning without stacklevel (TODO: enable) - "E501", # Supress line-too-long warnings: trust yapf's judgement on this one. - "E701", # Multiple statements on one line (TODO: enable) - "E712", # Allow using if x == False, as it's not always equivalent to if x. 
- "E731", # Do not use lambda - "F841", # Local assigned but not used (TODO: enable, these are likely bugs) - "RUF012", # Mutable class attribute annotations -] -ignore-init-module-imports = true # allow to expose in __init__.py via imports -[tool.ruff.extend-per-file-ignores] +[tool.ruff.lint.extend-per-file-ignores] "**/__init__.py" = ["F401"] # allow unused imports in __init__.py "{benchmarking,tests}/**/*.py" = [ "B007", @@ -42,7 +46,7 @@ ignore-init-module-imports = true # allow to expose in __init__.py via imports "UP030", ] -[tool.ruff.isort] +[tool.ruff.lint.isort] combine-as-imports = true detect-same-package = true force-sort-within-sections = true diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 189aa75b5..e3620bf41 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -17,5 +17,5 @@ def test_manual_override(requires_cuda): os.environ['BNB_CUDA_VERSION']='122' #assert str(manual_cuda_path) in os.environ['LD_LIBRARY_PATH'] import bitsandbytes as bnb - loaded_lib = bnb.cuda_setup.main.CUDASetup.get_instance().binary_name + loaded_lib = bnb.device_setup.cuda.main.CUDASetup.get_instance().binary_name #assert loaded_lib == 'libbitsandbytes_cuda122.so'