DRAFT: allow for multiple hardware backends #1077

Closed
10 changes: 10 additions & 0 deletions .git-blame-ignore-revs
@@ -6,3 +6,13 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848

# Remove f-prefix from strings that don't use formatting
7727fa4c8c6c1ef2b109120aff4196a0a6bf3ed6

# format bitsandbytes/cextension.py
04f691ef3061e6659aa0a741ca97d00d031618c4

# whitespace in pyproject.toml
f7b791863083429ba79dc00f925a041beab63297

# format bitsandbytes/functional.py
64ad928224ab1134dff416feee5e7ca663331bc0
01327aa0119fa503ea16322dd72f69f202e4502e
3 changes: 2 additions & 1 deletion bitsandbytes/__init__.py
@@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import cuda_setup, research, utils
from . import research, utils
from .autograd._functions import (
    MatmulLtState,
    bmm_cublas,
@@ -12,6 +12,7 @@
    matmul_cublas,
    mm_cublas,
)
from .backends import _backend as backend
Contributor:
I don't think this renaming re-export should be here. Is there an excellent reason for lay users of the library to be able to do from bitsandbytes import backend?

from .cextension import COMPILED_WITH_CUDA
from .nn import modules

2 changes: 1 addition & 1 deletion bitsandbytes/__main__.py
@@ -59,7 +59,7 @@ def main():
    generate_bug_report_information()

    from . import COMPILED_WITH_CUDA
    from .cuda_setup.main import get_compute_capabilities
    from .device_setup.cuda.main import get_compute_capabilities

    print_header("OTHER")
    print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}")
11 changes: 11 additions & 0 deletions bitsandbytes/backends/__init__.py
@@ -0,0 +1,11 @@
from ..cextension import lib
from ._base import COOSparseTensor
from .nvidia import CudaBackend

_backend = CudaBackend(lib) if lib else None
Contributor:
I think the initialization of the backend should happen on-demand, not eagerly and implicitly during import time. This is exacerbated by the re-export in __init__.py.

I think

_backend: Backend | None = None

def get_backend() -> Backend:
    global _backend
    if _backend is None:
        _backend = CudaBackend()
    return _backend

would be the better API.

(Note that I also think lib should be something the backend knows about, not something that gets passed in.)

# TODO: this should actually be done in `cextension.py` and potentially with .get_instance()
# for now this is just a simplifying assumption
#
# Notes from Tim:
# backend = CUDABackend.get_instance()
# -> CUDASetup -> lib -> backend.clib = lib
143 changes: 143 additions & 0 deletions bitsandbytes/backends/_base.py
@@ -0,0 +1,143 @@
import torch


class COOSparseTensor:
    def __init__(self, rows, cols, nnz, rowidx, colidx, values):
Contributor:
Missing types?

        assert rowidx.dtype == torch.int32
        assert colidx.dtype == torch.int32
        assert values.dtype == torch.float16
        assert values.numel() == nnz
        assert rowidx.numel() == nnz
        assert colidx.numel() == nnz
Comment on lines +6 to +11
Contributor:
If these are critical to happen at runtime, they shouldn't be asserts (but just if ...: raise ...), since someone may be running this library with python -O, which disables asserts.
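For illustration only, a minimal sketch of what explicit validation could look like here (the error types and messages are placeholders, not part of this PR):

# hypothetical replacement for the asserts above; still runs under python -O
if rowidx.dtype != torch.int32 or colidx.dtype != torch.int32:
    raise TypeError("rowidx and colidx must be int32 tensors")
if values.dtype != torch.float16:
    raise TypeError("values must be a float16 tensor")
if not (values.numel() == rowidx.numel() == colidx.numel() == nnz):
    raise ValueError("rowidx, colidx and values must each have exactly nnz elements")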


        self.rows = rows
        self.cols = cols
        self.nnz = nnz
        self.rowidx = rowidx
        self.colidx = colidx
        self.values = values
Comment on lines +4 to +18
Contributor:
This looks like it should be a @dataclasses.dataclass.
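As a rough sketch of that idea (not part of the diff), the same class as a dataclass, with the runtime checks moved into __post_init__, might look like:

from dataclasses import dataclass

import torch


@dataclass
class COOSparseTensor:
    rows: int
    cols: int
    nnz: int
    rowidx: torch.Tensor
    colidx: torch.Tensor
    values: torch.Tensor

    def __post_init__(self):
        # same invariants as the asserts above, raised explicitly
        if self.rowidx.dtype != torch.int32 or self.colidx.dtype != torch.int32:
            raise TypeError("rowidx and colidx must be int32 tensors")
        if self.values.dtype != torch.float16:
            raise TypeError("values must be a float16 tensor")
        if not (self.values.numel() == self.rowidx.numel() == self.colidx.numel() == self.nnz):
            raise ValueError("rowidx, colidx and values must each have exactly nnz elements")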



class BackendInterface:
Contributor:
This doesn't smell like an interface, it smells like a base class, so

Suggested change:
- class BackendInterface:
+ class Backend:

    _instance = None

    def __new__(cls, lib=None):
        if cls._instance is None:
            if lib is None:
                raise ValueError(
                    "A 'lib' binary must be provided during the first initialization of BackendInterface."
                )
            cls._instance = super().__new__(cls)
            cls._instance.lib = (
                lib  # Set the binary name during the first and only instantiation
            )
Comment on lines +24 to +33
Member:
It's not super clear what exactly lib is supposed to be, both from the lack of typing information and the naming. Plus, I'm not sure every backend would need some kind of binary library implementation?

        else:
            if lib is not None:
                raise ValueError(
                    "The BackendInterface singleton has already been initialized with a 'lib' value. Re-initialization with a new 'lib' value is not allowed."
                )
        return cls._instance
Comment on lines +22 to +39
Contributor:
I really do think this is unnecessary and will complicate e.g. testing. The canonical way to get at the backend instance would be the above get_backend(); if someone needs to construct a backend by hand from the class, we should assume they know exactly what they're doing.
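For example (hypothetical, assuming the get_backend()/_backend pair sketched above actually lands), a test could then swap in a fake backend without fighting a singleton:

import bitsandbytes.backends as backends


class FakeBackend:
    def igemmlt(self, *args, **kwargs):
        return "fake result"


def test_uses_injected_backend(monkeypatch):
    # replace the module-level backend; get_backend() returns whatever is set
    monkeypatch.setattr(backends, "_backend", FakeBackend())
    assert backends.get_backend().igemmlt() == "fake result"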


    def check_matmul(
        self,
        A,
        B,
        out=None,
        transposed_A=False,
        transposed_B=False,
        expected_type=torch.int8,
    ):
        """
        Checks if the matrix multiplication between A and B can be performed, considering their shapes,
        whether they are transposed, and their data types. It also determines the shape of the output tensor.

        Parameters:
        - A (torch.Tensor): The first matrix in the multiplication.
        - B (torch.Tensor): The second matrix in the multiplication.
        - out (torch.Tensor, optional): The output tensor to store the result of the multiplication. Default is None.
        - transposed_A (bool, optional): Indicates if matrix A is transposed. Default is False.
        - transposed_B (bool, optional): Indicates if matrix B is transposed. Default is False.
        - expected_type (torch.dtype, optional): The expected data type of matrices A and B. Default is torch.int8.

        Returns:
        - tuple: The shape of the output tensor resulting from the matrix multiplication.

        Raises:
        - TypeError: If the data types of A or B do not match the expected type.
        - ValueError: If the dimensions of A and B are not compatible for matrix multiplication.
        """
        raise NotImplementedError

    # 8-bit matmul interface
    def coo_zeros(self, rows, cols, nnz, device, dtype=torch.half):
Comment on lines +71 to +72
Contributor:
Is this expected to be actually overridden by a specific backend type? If not, maybe it should be a @classmethod on COOSparseTensor instead:

@classmethod
def zeros(cls, rows, cols, nnz, device, dtype=torch.half) -> "COOSparseTensor":
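Filled in, that suggestion could look roughly like this (a sketch that simply reuses the coo_zeros body below, not code from this PR):

@classmethod
def zeros(cls, rows, cols, nnz, device, dtype=torch.half) -> "COOSparseTensor":
    rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
    colidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
    values = torch.zeros((nnz,), dtype=dtype, device=device)
    return cls(rows, cols, nnz, rowidx, colidx, values)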

        rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
        colidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
        values = torch.zeros((nnz,), dtype=dtype, device=device)

        return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)

    def get_colrow_absmax(
        self, A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0
    ):
        raise NotImplementedError

    def double_quant(
        self,
        A,
        col_stats=None,
        row_stats=None,
        out_col=None,
        out_row=None,
        threshold=0.0,
    ):
        raise NotImplementedError

    def extract_outliers(self, *args, **kwargs):
        raise NotImplementedError

    def igemmlt(self, *args, **kwargs):
        raise NotImplementedError

    def mm_dequant(self, *args, **kwargs):
        raise NotImplementedError
Comment on lines +95 to +102
Contributor:
There should be no APIs defined with just *args, **kwargs; I'll assume this is because this is a draft :)

All in all, now that we have the opportunity, the API shape should be really thought-out; I, for one, am a great proponent of making some functions at least partially kwarg-only, so it's impossible to accidentally pass in things in the wrong order etc.

Collaborator (PR author):
"am a great proponent of making some functions at least partially kwarg-only"

Totally agree, I also like locking APIs with at least the *, like def fn(positional_only, /, keyword_and_positional, *, keyword_only): ...

Yes, the *args, **kwargs are only scaffolding to make the point that these methods are still to come and will be part of the API (based on Tim's input).
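As an illustration of that idea (hypothetical parameter names and defaults, not a signature defined in this PR), a backend method could be locked down like:

from typing import Optional

import torch


def quantize_blockwise(
    A: torch.Tensor,
    /,
    *,
    code: Optional[torch.Tensor] = None,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize: int = 4096,
) -> torch.Tensor:
    # A can only be passed positionally; everything after * is keyword-only,
    # so call sites cannot silently swap arguments of the same type.
    raise NotImplementedError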


    # k-bit quantization interface
    def create_quant_map(self, interface, quant_name):
        """
        The functions below should be abstracted into a general method
        create_quant_map(interface, quant_name), so we can call e.g.
        create_quant_map(..., quant_name='normal'):
        - 'create_dynamic_map'
        - 'create_fp8_map'
        - 'create_linear_map'
        - 'create_normal_map'
        - 'create_quantile_map'
        """
        raise NotImplementedError

    def estimate_quantiles(self, *args, **kwargs):
        raise NotImplementedError

    def dequantize_blockwise(self, *args, **kwargs):
        raise NotImplementedError

    def quantize_blockwise(self, *args, **kwargs):
        raise NotImplementedError

    # 4-bit matmul interface
    def dequantize_4bit(self, *args, **kwargs):
        raise NotImplementedError

    def quantize_4bit(self, *args, **kwargs):
        raise NotImplementedError

    def gemv_4bit(self, *args, **kwargs):
        raise NotImplementedError

    # 8-bit optimizer interface
    def optimizer_update_32bit(self, *args, **kwargs):
        """This is needed for tests"""
        raise NotImplementedError("Subclasses must implement 'optimizer_update_32bit'.")

    def optimizer_update_8bit_blockwise(self, *args, **kwargs):
        raise NotImplementedError
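To make the create_quant_map idea above concrete, one possible (entirely hypothetical) shape for the dispatch is a lookup keyed by quant_name; the helper names below mirror the list in the docstring but are not defined in this PR:

def create_quant_map(self, quant_name, **kwargs):
    # hypothetical dispatch over per-format map-creation helpers
    creators = {
        "dynamic": self.create_dynamic_map,
        "fp8": self.create_fp8_map,
        "linear": self.create_linear_map,
        "normal": self.create_normal_map,
        "quantile": self.create_quantile_map,
    }
    if quant_name not in creators:
        raise ValueError(f"Unknown quant_name: {quant_name!r}")
    return creators[quant_name](**kwargs)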
54 changes: 54 additions & 0 deletions bitsandbytes/backends/_helpers.py
@@ -0,0 +1,54 @@
import ctypes
from typing import Optional

import torch


def pre_call(device):
    prev_device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    return prev_device


def post_call(prev_device):
    torch.cuda.set_device(prev_device)


def get_ptr(A: Optional[torch.Tensor]) -> Optional[ctypes.c_void_p]:
"""
Get the ctypes pointer from a PyTorch Tensor.

Parameters
----------
A : torch.tensor
The PyTorch tensor.

Returns
-------
ctypes.c_void_p
"""
if A is None:
return None
else:
return ctypes.c_void_p(A.data.data_ptr())


def is_on_gpu(tensors):
Contributor:
Types? It seems like this should be called assert_all_on_gpu?

    on_gpu = True
    gpu_ids = set()
    for t in tensors:
        if t is None:
            continue  # NULL pointers are fine
        is_paged = getattr(t, "is_paged", False)
        on_gpu &= t.device.type == "cuda" or is_paged
        if not is_paged:
            gpu_ids.add(t.device.index)
    if not on_gpu:
        raise TypeError(
            f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}"
        )
    if len(gpu_ids) > 1:
        raise TypeError(
            f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}"
        )
    return on_gpu
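For context, these helpers are typically used to bracket a raw native call; a minimal usage sketch (with a made-up kernel name, not a real entry point in the library) could be:

def call_kernel_on(lib, A, out):
    is_on_gpu([A, out])                # validate device placement first
    prev_device = pre_call(A.device)   # switch to the tensors' CUDA device
    try:
        lib.some_kernel(get_ptr(A), get_ptr(out))  # placeholder kernel name
    finally:
        post_call(prev_device)         # always restore the previous device
    return out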
File renamed without changes.
Empty file added bitsandbytes/backends/apple.py
Empty file added bitsandbytes/backends/intel.py