DRAFT: allow for multiple hardware backends #1077
Changes from 18 commits
@@ -0,0 +1,11 @@
```python
from ..cextension import lib
from ._base import COOSparseTensor
from .nvidia import CudaBackend

_backend = CudaBackend(lib) if lib else None
```
Review comment: I think the initialization of the backend should happen on-demand, not eagerly and implicitly during import time. This is exacerbated by the re-export in `bitsandbytes/__init__.py`. I think

```python
_backend: Backend | None = None

def get_backend() -> Backend:
    if not _backend:
        _backend = CudaBackend()
```

would be the better API. (Note that I also think …)
```python
# TODO: this should actually be done in `cextension.py` and potentially with .get_instance()
# for now this is just a simplifying assumption
#
# Notes from Tim:
# backend = CUDABackend.get_instance()
# -> CUDASetup -> lib -> backend.clib = lib
```
@@ -0,0 +1,143 @@
```python
import torch


class COOSparseTensor:
    def __init__(self, rows, cols, nnz, rowidx, colidx, values):
```
Review comment: Missing types?
```python
        assert rowidx.dtype == torch.int32
        assert colidx.dtype == torch.int32
        assert values.dtype == torch.float16
        assert values.numel() == nnz
        assert rowidx.numel() == nnz
        assert colidx.numel() == nnz
```

Review comment on lines +6 to +11: If these are critical to happen during runtime, they shouldn't be `assert`s, since asserts are stripped when Python runs with `-O`.
```python
        self.rows = rows
        self.cols = cols
        self.nnz = nnz
        self.rowidx = rowidx
        self.colidx = colidx
        self.values = values
```

Review comment on lines +4 to +18: This looks like it should be a `dataclass`.
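As a sketch of that suggestion (my composition, not part of the PR), the same class could be a dataclass with the runtime checks moved into `__post_init__`, which also addresses the point above about `assert`s:

```python
from dataclasses import dataclass

import torch


@dataclass
class COOSparseTensor:
    rows: int
    cols: int
    nnz: int
    rowidx: torch.Tensor
    colidx: torch.Tensor
    values: torch.Tensor

    def __post_init__(self):
        # Explicit raises survive `python -O`, unlike bare asserts.
        if self.rowidx.dtype != torch.int32 or self.colidx.dtype != torch.int32:
            raise TypeError("rowidx and colidx must be int32 tensors")
        if self.values.dtype != torch.float16:
            raise TypeError("values must be a float16 tensor")
        if not (self.rowidx.numel() == self.colidx.numel() == self.values.numel() == self.nnz):
            raise ValueError("rowidx, colidx and values must each have nnz elements")
```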
```python


class BackendInterface:
```

Review comment: This doesn't smell like an interface, it smells like a base class, so:

Suggested change:
```diff
-class BackendInterface:
+class Backend:
```
```python
    _instance = None

    def __new__(cls, lib=None):
        if cls._instance is None:
            if lib is None:
                raise ValueError(
                    "A 'lib' binary must be provided during the first initialization of BackendInterface."
                )
            cls._instance = super().__new__(cls)
            # Set the binary name during the first and only instantiation.
            cls._instance.lib = lib
```
Review comment on lines +24 to +33: It's not super clear what exactly `lib` is supposed to be here.
```python
        else:
            if lib is not None:
                raise ValueError(
                    "The BackendInterface singleton has already been initialized with a 'lib' value. Re-initialization with a new 'lib' value is not allowed."
                )
        return cls._instance
```
Review comment on lines +22 to +39: I really do think this is unnecessary and will complicate e.g. testing. The canonical way to get at the backend instance would be the above `get_backend()` function.
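To make the testing concern concrete (my illustration; the module path and the `FakeBackend` stub are hypothetical, and it assumes the `get_backend()` pattern sketched earlier):

```python
from unittest import mock


class FakeBackend:
    """Minimal stand-in used only by tests."""


with mock.patch("bitsandbytes.backends._backend", FakeBackend()):
    backend = get_backend()  # returns the injected FakeBackend
```

With the `__new__`-based singleton, a test would instead have to reach into `BackendInterface._instance`, and the first `lib` passed in sticks for the life of the process.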
```python
    def check_matmul(
        self,
        A,
        B,
        out=None,
        transposed_A=False,
        transposed_B=False,
        expected_type=torch.int8,
    ):
        """
        Checks if the matrix multiplication between A and B can be performed, considering their shapes,
        whether they are transposed, and their data types. It also determines the shape of the output tensor.

        Parameters:
        - A (torch.Tensor): The first matrix in the multiplication.
        - B (torch.Tensor): The second matrix in the multiplication.
        - out (torch.Tensor, optional): The output tensor to store the result of the multiplication. Default is None.
        - transposed_A (bool, optional): Indicates if matrix A is transposed. Default is False.
        - transposed_B (bool, optional): Indicates if matrix B is transposed. Default is False.
        - expected_type (torch.dtype, optional): The expected data type of matrices A and B. Default is torch.int8.

        Returns:
        - tuple: The shape of the output tensor resulting from the matrix multiplication.

        Raises:
        - TypeError: If the data types of A or B do not match the expected type.
        - ValueError: If the dimensions of A and B are not compatible for matrix multiplication.
        """
        raise NotImplementedError
```
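For orientation (my sketch, not from the PR), a minimal 2-D-only implementation consistent with this docstring could look like:

```python
def check_matmul(self, A, B, out=None, transposed_A=False, transposed_B=False, expected_type=torch.int8):
    if A.dtype != expected_type or B.dtype != expected_type:
        raise TypeError(f"expected {expected_type}, got {A.dtype} and {B.dtype}")
    # Effective shapes after the requested transpositions (2-D case only;
    # a real implementation would also handle batched inputs and `out`).
    m = A.shape[1] if transposed_A else A.shape[0]
    k_a = A.shape[0] if transposed_A else A.shape[1]
    k_b = B.shape[1] if transposed_B else B.shape[0]
    n = B.shape[0] if transposed_B else B.shape[1]
    if k_a != k_b:
        raise ValueError(f"inner dimensions do not match: {k_a} != {k_b}")
    return (m, n)
```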
```python
    # 8-bit matmul interface
    def coo_zeros(self, rows, cols, nnz, device, dtype=torch.half):
```

Review comment on lines +71 to +72: Is this expected to be actually overridden by a specific backend type? If not, maybe it should be a `classmethod` on `COOSparseTensor` instead:

```python
    @classmethod
    def zeros(cls, rows, cols, nnz, device, dtype=torch.half) -> "COOSparseTensor":
```

```python
        rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
        colidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
        values = torch.zeros((nnz,), dtype=dtype, device=device)

        return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
```
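Combining the reviewer's suggested signature with the body above (my composition, not code from the PR):

```python
class COOSparseTensor:
    ...

    @classmethod
    def zeros(cls, rows, cols, nnz, device, dtype=torch.half) -> "COOSparseTensor":
        rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
        colidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
        values = torch.zeros((nnz,), dtype=dtype, device=device)
        return cls(rows, cols, nnz, rowidx, colidx, values)
```

Using `cls(...)` rather than naming the class keeps the constructor working for subclasses.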
```python
    def get_colrow_absmax(
        self, A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0
    ):
        raise NotImplementedError

    def double_quant(
        self,
        A,
        col_stats=None,
        row_stats=None,
        out_col=None,
        out_row=None,
        threshold=0.0,
    ):
        raise NotImplementedError

    def extract_outliers(self, *args, **kwargs):
        raise NotImplementedError

    def igemmlt(self, *args, **kwargs):
        raise NotImplementedError

    def mm_dequant(self, *args, **kwargs):
        raise NotImplementedError
```
Review comment on lines +95 to +102: There should be no APIs defined with just `*args, **kwargs`. All in all, now that we have the opportunity, the API shape should be really thought out; I, for one, am a great proponent of making some functions at least partially kwarg-only, so it's impossible to accidentally pass things in the wrong order etc.

Reply: Totally agree, I also like locking APIs with at least the `*` marker. Yes, the `*args, **kwargs` are only scaffolding to make the point that these methods are still to come and will be part of the API (based on Tim's input).
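As an illustration of the kwarg-only point (my sketch, reusing the `double_quant` signature already in this diff): a bare `*` in the parameter list forces callers to name the optional buffers:

```python
def double_quant(
    self,
    A,
    *,  # everything after this must be passed by keyword
    col_stats=None,
    row_stats=None,
    out_col=None,
    out_row=None,
    threshold=0.0,
):
    ...

# backend.double_quant(A, rs, cs)                      # now a TypeError
# backend.double_quant(A, row_stats=rs, col_stats=cs)  # explicit and safe
```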
```python
    # k-bit quantization interface
    def create_quant_map(self, interface, quant_name):
        """
        Below functions should be abstracted into a general method
        "create_quant_map(interface, "quant_name")", so we can call e.g.
        create_quant_map(..., quant_name='normal'):
        - 'create_dynamic_map'
        - 'create_fp8_map'
        - 'create_linear_map'
        - 'create_normal_map'
        - 'create_quantile_map'
        """
        raise NotImplementedError
```
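One way to realize the abstraction described in this docstring (a hypothetical sketch; it assumes the existing `create_*_map` helpers from `bitsandbytes.functional` are in scope):

```python
def create_quant_map(self, quant_name, **kwargs):
    # Dispatch table built from the helpers listed in the docstring.
    creators = {
        "dynamic": create_dynamic_map,
        "fp8": create_fp8_map,
        "linear": create_linear_map,
        "normal": create_normal_map,
        "quantile": create_quantile_map,
    }
    try:
        return creators[quant_name](**kwargs)
    except KeyError:
        raise ValueError(f"Unknown quant map: {quant_name!r}") from None
```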
```python
    def estimate_quantiles(self, *args, **kwargs):
        raise NotImplementedError

    def dequantize_blockwise(self, *args, **kwargs):
        raise NotImplementedError

    def quantize_blockwise(self, *args, **kwargs):
        raise NotImplementedError

    # 4-bit matmul interface
    def dequantize_4bit(self, *args, **kwargs):
        raise NotImplementedError

    def quantize_4bit(self, *args, **kwargs):
        raise NotImplementedError

    def gemv_4bit(self, *args, **kwargs):
        raise NotImplementedError

    # 8-bit optimizer interface
    def optimizer_update_32bit(self, *args, **kwargs):
        """This is needed for tests"""
        raise NotImplementedError("Subclasses must implement 'optimizer_update_32bit'.")

    def optimizer_update_8bit_blockwise(self, *args, **kwargs):
        raise NotImplementedError
```
@@ -0,0 +1,54 @@
```python
import ctypes
from typing import Optional

import torch
```
```python
def pre_call(device):
    prev_device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    return prev_device


def post_call(prev_device):
    torch.cuda.set_device(prev_device)
```
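A typical call pattern for this pair (my illustration; `lib.some_kernel` is a hypothetical native entry point, and `get_ptr` is defined below) saves and restores the active CUDA device around a kernel launch:

```python
prev_device = pre_call(A.device)   # switch to the input tensor's device
try:
    lib.some_kernel(get_ptr(A))    # hypothetical C call on that device
finally:
    post_call(prev_device)         # always restore the previous device
```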
```python
def get_ptr(A: Optional[torch.Tensor]) -> Optional[ctypes.c_void_p]:
    """
    Get the ctypes pointer from a PyTorch Tensor.

    Parameters
    ----------
    A : torch.Tensor
        The PyTorch tensor.

    Returns
    -------
    ctypes.c_void_p
    """
    if A is None:
        return None
    else:
        return ctypes.c_void_p(A.data.data_ptr())
```
```python
def is_on_gpu(tensors):
```

Review comment: Types? This seems like it should be called something other than `is_on_gpu`, given that it raises instead of just returning a boolean.
```python
    on_gpu = True
    gpu_ids = set()
    for t in tensors:
        if t is None:
            continue  # NULL pointers are fine
        is_paged = getattr(t, "is_paged", False)
        on_gpu &= t.device.type == "cuda" or is_paged
        if not is_paged:
            gpu_ids.add(t.device.index)
    if not on_gpu:
        raise TypeError(
            f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}"
        )
    if len(gpu_ids) > 1:
        raise TypeError(
            f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}"
        )
    return on_gpu
```
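Addressing the "Types?" question, a typed signature (my sketch) might read:

```python
from typing import Iterable, Optional

def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]) -> bool:
    ...
```

Since the function raises on failure and otherwise always returns `True`, a checker-style name returning `None` (e.g. an `ensure_`-prefixed one, as the reviewer hints) would arguably fit better than a boolean predicate.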
Review comment: I don't think this renaming re-export should be here. Is there an excellent reason for lay users of the library to be able to do `from bitsandbytes import backend`?