From 0cc5c95efc213ffc8ed64ad34376fa4e4a3114ac Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 7 Oct 2024 13:05:20 -0400 Subject: [PATCH 01/65] Start of int8 refactor: remove col32/col_ampere/col_turing transforms in new igemmlt implementation --- bitsandbytes/autograd/_functions.py | 83 ++++---- bitsandbytes/functional.py | 165 +++++++-------- bitsandbytes/nn/modules.py | 7 +- bitsandbytes/research/autograd/_functions.py | 33 +-- csrc/ops.cu | 103 +++++++++- csrc/ops.cuh | 2 +- csrc/pythonInterface.cpp | 19 ++ tests/test_autograd.py | 12 +- tests/test_functional.py | 199 +++++++++++-------- tests/test_modules.py | 2 +- 10 files changed, 382 insertions(+), 243 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index d33dd1bc5..01845a131 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -245,11 +245,11 @@ class MatmulLtState: _tile_indices: Optional[torch.Tensor] = None force_no_igemmlt: bool = False CB = None - CxB = None + CxB = None # TODO: Deprecate/remove SB = None SCB = None - CxBt = None + CxBt = None # TODO: Deprecate/remove SBt = None CBt = None @@ -263,7 +263,7 @@ class MatmulLtState: has_fp16_weights = True memory_efficient_backward = False use_pool = False - formatB = F.get_special_format_str() + formatB = "row" # F.get_special_format_str() TODO: Deprecate/remove def reset_grads(self): self.CB = None @@ -283,9 +283,6 @@ def tile_indices(self): class MatMul8bitLt(torch.autograd.Function): - # forward is the same, but we added the fallback for pre-turing GPUs - # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") - @staticmethod def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt @@ -306,7 +303,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): # 3. Matmul # 4. Mixed-precision decomposition matmul # 5. Save state - formatB = state.formatB input_shape = A.shape if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance() @@ -328,14 +324,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): subA = A[:, idx] state.subB = B[:, idx].t().contiguous() state.idx = idx - else: - if state.CxB is None and using_igemmlt: - # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions - # we also need to convert it to the turing/ampere format - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: - if not state.has_fp16_weights and state.CxB is None and using_igemmlt: - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) subA = None # 2. Quantize B @@ -345,19 +334,17 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): if is_transposed: B = B.contiguous() - if (state.is_training and not has_grad) or state.CxB is None: + if (state.is_training and not has_grad) or state.CB is None: state.reset_grads() + + # quantize... ( - CB, + state.CB, state.CBt, state.SCB, state.SCBt, coo_tensorB, ) = F.double_quant(B.to(torch.float16)) - if using_igemmlt: - state.CxB, state.SB = F.transform(CB, to_order=formatB) - else: - state.CB = CB else: has_grad = False @@ -372,17 +359,18 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) # else: # state.idx = outlier_idx - if state.CxB is not None: - outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) - else: - outliers = state.CB[:, state.idx.long()].clone() + + # if state.CxB is not None: + # outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) + # else: + outliers = state.CB[:, state.idx.long()].clone() state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 subA = A[:, state.idx.long()] - shapeB = state.SB[0] if state.SB else B.shape + shapeB = state.CB.shape if len(input_shape) == 3: output_shape = (input_shape[0], input_shape[1], shapeB[0]) @@ -391,13 +379,14 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): # 3. Matmul if using_igemmlt: - C32A, SA = F.transform(CA, "col32") - out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) + out32, Sout32 = F.igemmlt(CA, state.CB) + if bias is None or bias.dtype == torch.float16: # we apply the fused bias here output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) output = output.to(A.dtype) else: # apply bias separately + # TODO: Fused bias for fp32/bf16? output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) output = output.to(A.dtype).add_(bias) @@ -417,7 +406,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): # 5. Save state ctx.state = state - ctx.formatB = formatB ctx.grad_shape = input_shape ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype @@ -437,10 +425,10 @@ def backward(ctx, grad_output): if ctx.is_empty: bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None + req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad CAt, subA, A = ctx.tensors SCAt, idx = ctx.tensor_states - formatB = ctx.formatB state = ctx.state grad_A = grad_B = grad_bias = None @@ -454,33 +442,39 @@ def backward(ctx, grad_output): Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) if req_gradB: - CxAt, SAt = F.transform(CAt, formatB, transpose=True) - C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) - gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) + # CxAt, SAt = F.transform(CAt, formatB, transpose=True) + # C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) + # gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) + # grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) + gradB32, SgradB32 = F.igemmlt( + Cgradt.t(), CAt.t() + ) # issue here in test_linear_serialization w/ has fp16 weights grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) if state.threshold > 0.0 and subA is not None: grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: if state.CBt is not None: - C32grad, Sgrad = F.transform(Cgrad, "col32") - if state.CxBt is None: - state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) - gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + # C32grad, Sgrad = F.transform(Cgrad, "col32") + # if state.CxBt is None: + # state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) + # gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + # grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) + gradA32, SgradA32 = F.igemmlt(Cgradt, state.CBt.t()) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) - elif state.CxB is not None: - CB = ( - undo_layout(state.CxB, state.tile_indices) - .to(ctx.dtype_A) - .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) - ) - grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) + # elif state.CxB is not None: + # CB = ( + # undo_layout(state.CxB, state.tile_indices) + # .to(ctx.dtype_A) + # .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) + # ) + # grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) else: - raise Exception("State must contain either CBt or CB or CxB matrix for backward") + raise Exception("State must contain either CBt or CB matrix for backward") return grad_A, grad_B, None, grad_bias, None @@ -564,6 +558,7 @@ def matmul_4bit( bias=None, ): assert quant_state is not None + if A.numel() == A.shape[-1] and A.requires_grad == False: if A.shape[-1] % quant_state.blocksize != 0: warn( diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 34b3c0293..7f07778ef 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -3,9 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import ctypes as ct -from functools import reduce # Required in Python 3 import itertools -import operator +from math import prod from typing import Any, Dict, Optional, Tuple import numpy as np @@ -16,12 +15,6 @@ from .cextension import lib - -# math.prod not compatible with python < 3.8 -def prod(iterable): - return reduce(operator.mul, iterable, 1) - - name2qmap = {} if lib and lib.compiled_with_cuda: @@ -421,15 +414,9 @@ def create_quantile_map(A, total_bits=8): return q +# TODO: Deprecate def get_special_format_str(): - if not torch.cuda.is_available(): - return "col_turing" - major, _minor = torch.cuda.get_device_capability() - if major <= 7: - return "col_turing" - if major == 8: - return "col_ampere" - return "col_turing" + return "row" def is_on_gpu(tensors): @@ -2302,84 +2289,68 @@ def batched_igemm( return out -def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): - shapeA = SA[0] - shapeB = SB[0] - dimsA = len(shapeA) - dimsB = len(shapeB) - assert dimsB == 2, "Only two dimensional matrices are supported for argument B" - if dimsA == 2: - m = shapeA[0] - elif dimsA == 3: - m = shapeA[0] * shapeA[1] - - rows = n = shapeB[0] - assert prod(list(shapeA)) > 0, f"Input tensor dimensions need to be > 0: {shapeA}" - - # if the tensor is empty, return a transformed empty tensor with the right dimensions - if shapeA[0] == 0 and dimsA == 2: - return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16) - elif shapeA[1] == 0 and dimsA == 3: - return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) - - if dimsA == 2 and out is None: - out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row") - elif dimsA == 3 and out is None: - out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row") - - assert dimsB != 3, "len(B.shape)==3 not supported" +def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32): + # + # To use the IMMA tensor core kernels without special Turing/Ampere layouts, + # cublasLt has some rules, namely: A must be transposed, B must not be transposed. + # The C++ API will calculate `C = A.T @ B` in with A, B, C in col-major. + # This will typically be used with row-major tensors to efficiently + # calculate the linear layer with `C = B @ A.T` without any transformations. + # We will swap A and B in the API invocation, so that we get `C = A @ B.T`. + # + # Quick explanation: + # With row-major A and B tensors, `C = A.T.T @ B.T = A @ B.T`. + # To get row-major output, `C.T = (A @ B.T).T = B @ A.T`. + # + A, B = B, A + + shapeA = A.shape + shapeB = B.shape + dimsA = A.ndim + dimsB = B.ndim + assert A.device.type == "cuda" assert B.device.type == "cuda" assert A.dtype == torch.int8 assert B.dtype == torch.int8 + assert dimsA == 2, "Only two dimensional matrices are supported for argument B" + assert dimsB in [2, 3], "Only two or three dimensional matrices are supported for argument A" + assert prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}" + + shapeC = (*shapeB[:-1], shapeA[0]) + Sout = (shapeC, "row") + + if out is None: + out = torch.empty(shapeC, device=A.device, dtype=dtype) + assert out.dtype == dtype - assert SA[1] == "col32" - assert SB[1] in ["col_turing", "col_ampere"] - assert Sout[1] == "col32" - assert ( - shapeA[-1] == shapeB[-1] - ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}" - formatB = SB[1] + k, m = shapeA + n = prod(shapeB[:-1]) + lda = shapeA[-1] # Weights (outputs, inputs) + ldb = shapeB[-1] # Activations (batch, tokens, inputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) + + assert lda == ldb, f"igemmlt only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ B={shapeA}" + prev_device = A.device torch.cuda.set_device(A.device) - ptr = CUBLAS_Context.get_instance().get_context(A.device) + ctx = CUBLAS_Context.get_instance().get_context(A.device) ptrA = get_ptr(A) ptrB = get_ptr(B) ptrC = get_ptr(out) - - k = shapeA[-1] - lda = ct.c_int32(m * 32) - if formatB == "col_turing": - # turing: tiles with rows filled up to multiple of 8 rows by 32 columns - # n = rows - ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32) - else: - # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns - # n = rows - ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32) - - ldc = ct.c_int32(m * 32) - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - - has_error = 0 ptrRowScale = get_ptr(None) + m, n, k, lda, ldb, ldc = map(ct.c_int32, (m, n, k, lda, ldb, ldc)) + is_on_gpu([A, B, out]) - if formatB == "col_turing": - if dtype == torch.int32: - has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) - else: - has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) - elif formatB == "col_ampere": - if dtype == torch.int32: - has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) - else: - has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + + if dtype == torch.int32: + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + else: + has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)") + raise NotImplementedError("igemmlt not implemented!") if has_error: print(f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}") @@ -2392,6 +2363,26 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None): assert A.dtype == torch.int32 + + compute_dtype = torch.float32 + + A_calc = A.view(-1, A.shape[-1]).to(compute_dtype) + row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype) + col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype) + + # TODO support out != None + + out = A_calc * (row_stats * col_stats) * 6.200124e-5 # .to(torch.float16) + + if bias is not None: + # assert bias.dtype == torch.float16 + out.add_(bias) + + return out.to(torch.float16) + + +def mm_dequant_old(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None): + assert A.dtype == torch.int32 if bias is not None: assert bias.dtype == torch.float16 out_shape = quant_state[0] @@ -2553,6 +2544,21 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): + # TODO: Optimize/write CUDA kernel for this. Currently vectorwise_quant will recalculate row/col stats. + # TODO: Support threshold + + # if out_col is None: + # out_col = torch.zeros(A.shape, device=A.device, dtype=torch.int8) + # if out_row is None: + # out_row = torch.zeros(A.shape, device=A.device, dtype=torch.int8) + + out_col, Scol = vectorwise_quant(A, dim=0) + out_row, Srow = vectorwise_quant(A, dim=1) + + return out_row, out_col, Srow.flatten().float(), Scol.flatten().float(), None # coo_tensor + + +def double_quant_old(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): device = A.device assert A.dtype == torch.half assert device.type == "cuda" @@ -2949,6 +2955,7 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): def extract_outliers(A, SA, idx): + # TODO: Implement for row-major shapeA = SA[0] formatA = SA[1] assert formatA in ["col_turing", "col_ampere"] diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 6c78494aa..1e5a334ee 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -1009,11 +1009,8 @@ def forward(self, x: torch.Tensor): out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) if not self.state.has_fp16_weights: - if self.state.CB is not None and self.state.CxB is not None: - # we converted 8-bit row major to turing/ampere format in the first inference pass - # we no longer need the row-major weight - del self.state.CB - self.weight.data = self.state.CxB + if self.state.CB is not None: + self.weight.data = self.state.CB return out diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index b194b8777..5f8b2c437 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -204,7 +204,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 # 3. Matmul # 4. Mixed-precision decomposition matmul # 5. Save state - formatB = state.formatB input_shape = A.shape if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance() @@ -227,14 +226,11 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 state.subB = B[:, idx].t().contiguous() state.idx = idx else: - if state.CxB is None: - # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions - # we also need to convert it to the turing/ampere format - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + if state.SB is None: + state.SB = (state.CB.shape, "row") else: - # print('A shape', A.shape) - if not state.has_fp16_weights and state.CxB is None: - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + if not state.has_fp16_weights and state.SB is None: + state.SB = (state.CB.shape, "row") subA = None # 2. Quantize B @@ -245,16 +241,16 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 if is_transposed: B = B.contiguous() - if (state.is_training and not has_grad) or state.CxB is None: + if (state.is_training and not has_grad) or state.SB is None: state.reset_grads() ( - CB, + state.CB, state.CBt, state.SCB, state.SCBt, coo_tensorB, ) = F.double_quant(B.to(torch.float16)) - state.CxB, state.SB = F.transform(CB, to_order=formatB) + state.SB = (state.CB.shape, "row") else: has_grad = False @@ -269,7 +265,8 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) # else: # state.idx = outlier_idx - outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) + # outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) + outliers = state.CB[:, state.idx.long()].clone() state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 @@ -283,8 +280,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 output_shape = (input_shape[0], shapeB[0]) # 3. Matmul - C32A, SA = F.transform(CA, "col32") - out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) + out32, Sout32 = F.igemmlt(CA, state.CB) # we apply the fused bias here if bias is None or bias.dtype == torch.float16: @@ -301,7 +297,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 # 5. Save state ctx.state = state - ctx.formatB = formatB ctx.grad_shape = input_shape ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype @@ -324,7 +319,6 @@ def backward(ctx, grad_output): req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad CAt, subA, A = ctx.tensors SCAt, idx = ctx.tensor_states - formatB = ctx.formatB state = ctx.state grad_A = grad_B = grad_bias = None @@ -345,12 +339,7 @@ def backward(ctx, grad_output): if req_gradA: if state.CBt is not None: - C32grad, Sgrad = F.transform(Cgrad, "col32") - if state.CxBt is None: - state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) - # print('back B shape', state.CxBt.shape) - # print('back grad shape', C32grad.shape) - gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + gradA32, SgradA32 = F.igemmlt(Cgradt, state.CBt.t()) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: diff --git a/csrc/ops.cu b/csrc/ops.cu index 7ca854baf..1f259d67f 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -422,6 +422,101 @@ template void trans #endif } +template int igemmlt( + cublasLtHandle_t ltHandle, + int m, int n, int k, + const int8_t * A, + const int8_t * B, + void * C, + float * row_scale, + int lda, int ldb, int ldc +) { + + // Calculate C = A^T @ B, in col-major layout. + // + // Use the IMMA kernels requires: + // * A must be transposed and B must be non-transposed. + // * All leading dimensions must be multiples of 4. + // * Dimensions m and k must be multiples of 4. + // * All pointers must be 4-byte aligned; 16-byte alignment preferred. + // + + + int has_error = 0; + + // this is the default + cublasLtOrder_t col_major = CUBLASLT_ORDER_COL; + + cublasLtMatmulDesc_t matmulDesc; + cublasLtMatrixLayout_t aDesc, bDesc, cDesc; + cublasOperation_t opT = CUBLAS_OP_T; + + cudaDataType_t outType = DTYPE_OUT == 32 ? CUDA_R_32I : CUDA_R_8I; + cudaDataType_t scaleType = DTYPE_OUT == 32 ? CUDA_R_32I : CUDA_R_32F; + + cublasLtPointerMode_t pointerMode = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&aDesc, CUDA_R_8I, m, k, lda)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&bDesc, CUDA_R_8I, m, n, ldb)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&cDesc, outType, k, n, ldc)); + + // Default layout order is col major + + has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, scaleType)); + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT))); + + if (DTYPE_OUT == 32) { + int alpha = 1, beta = 0; + has_error |= checkCublasStatus(cublasLtMatmul( + ltHandle, matmulDesc, + &alpha, A, aDesc, + B, bDesc, &beta, + (int32_t*)C, cDesc, + (int32_t*)C, cDesc, + NULL, NULL, 0, 0 + )); + } else { + if (!SCALE_ROWS) { + float alpha = 1.0f, beta = 0.0f; + has_error |= checkCublasStatus(cublasLtMatmul( + ltHandle, matmulDesc, + &alpha, A, aDesc, + B, bDesc, &beta, + (int8_t*)C, cDesc, + (int8_t*)C, cDesc, + NULL, NULL, 0, 0 + )); + } else { + cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST; + float beta = 0.0f; + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute( + matmulDesc, + CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointerMode, + sizeof(alphaVec) + )); + has_error |= checkCublasStatus(cublasLtMatmul( + ltHandle, matmulDesc, + row_scale, A, aDesc, + B, bDesc, &beta, + (int8_t*)C, cDesc, + (int8_t*)C, cDesc, + NULL, NULL, 0, 0 + )); + } + } + + has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(cDesc)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(bDesc)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(aDesc)); + has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc)); + + if(has_error == 1) + printf("error detected"); + + return has_error; +} + template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { #ifdef NO_CUBLASLT @@ -729,8 +824,8 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) { - int num_blocks = (m+3)/4; - kgemm_4bit_inference_naive<<< num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + int num_blocks = (m+7)/8; + kgemm_4bit_inference_naive<<< num_blocks, 256, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -772,6 +867,10 @@ template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); + template void transformRowToFormat(char * A, char *out, int rows, int cols); template void transformRowToFormat(char * A, char *out, int rows, int cols); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index b0ecc4622..ab0185242 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -175,7 +175,7 @@ void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, i void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount); - +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index f0ee84c29..09b9b62a9 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -175,6 +175,15 @@ void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRo void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } + int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } + int igemmlt_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } + int igemmlt_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } @@ -316,6 +325,16 @@ extern "C" Context *get_context(){ return new Context(); } ContextCusparse *get_cusparse(){ return new ContextCusparse(); } + int cigemmlt_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } + int cigemmlt_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } + int cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } + int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { return igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } //{ (cublasLtHandle_t)context->m_handle; return 0; } diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 9da665a2d..89dce644b 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -198,10 +198,14 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool assert (idx == 0).sum().item() < n * 0.02 -@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) -@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) +# @pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) +# @pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) +# @pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) +# @pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2")) # [64,0] +@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", [48], ids=id_formatter("dim4")) @pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp")) @pytest.mark.parametrize( "funcs", diff --git a/tests/test_functional.py b/tests/test_functional.py index 1cca04511..522af516c 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -570,10 +570,14 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans torch.testing.assert_close(A, out2) -@pytest.mark.parametrize("dim1", get_test_dims(1, 256, n=1), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", get_test_dims(32, 512, n=1), ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", get_test_dims(32, 1024, n=1), ids=id_formatter("dim3")) -@pytest.mark.parametrize("dim4", get_test_dims(32, 1024, n=1), ids=id_formatter("dim4")) +# @pytest.mark.parametrize("dim1", get_test_dims(1, 256, n=1), ids=id_formatter("dim1")) +# @pytest.mark.parametrize("dim2", get_test_dims(32, 512, n=1), ids=id_formatter("dim2")) +# @pytest.mark.parametrize("dim3", get_test_dims(32, 1024, n=1), ids=id_formatter("dim3")) +# @pytest.mark.parametrize("dim4", get_test_dims(32, 1024, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", [512], ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) @pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): @@ -585,20 +589,17 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) C1 = torch.matmul(A.float(), B.t().float()) - A2, SA = F.transform(A, "col32") - B2, SB = F.transform(B, "col_turing") - C2, SC = F.igemmlt(A2, B2, SA, SB) - C3, S = F.nvidia_transform(C2, "row", state=SC) - torch.testing.assert_close(C1, C3.float()) + C2, SC = F.igemmlt(A, B) + torch.testing.assert_close(C1, C2.float()) # transpose - B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8) - C1 = torch.matmul(A.float(), B.float()) + # B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8) + # C1 = torch.matmul(A.float(), B.float()) - B2t, SBt = F.transform(B, "col_turing", transpose=True) - C2, SC = F.igemmlt(A2, B2t, SA, SBt) - C3, S = F.nvidia_transform(C2, "row", state=SC) - torch.testing.assert_close(C1, C3.float()) + # B2t, SBt = F.transform(B, "col", transpose=True) + # C2, SC = F.igemmlt(A2, B2t, SA, SBt) #B2t, A2, SBt, SA) + # C3, S = F.nvidia_transform(C2, "row", state=SC) + # torch.testing.assert_close(C1, C2.float()) @pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) @@ -622,17 +623,15 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B) - C32A, SA = F.transform(CA, "col32") - CxB, SB = F.transform(CB, to_order=formatB) - out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB) - output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt) + out1_32, Sout1_32 = F.igemmlt(CA, CB) + output = F.mm_dequant(out1_32, Sout1_32, statsA, statsB) # print('') # print(output.flatten()[:10]) # print(C1.flatten()[:10]) # print(C2.flatten()[:10]) - # torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) + torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) # transpose # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8) @@ -801,17 +800,18 @@ def test_bench_8bit_training(batch, seq, model, hidden): # print(t8) -@pytest.mark.parametrize("dim1", get_test_dims(64, 256, n=2), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim4", get_test_dims(64, 1024, n=2), ids=id_formatter("dim4")) +# @pytest.mark.parametrize("dim1", get_test_dims(64, 256, n=2), ids=id_formatter("dim1")) +# @pytest.mark.parametrize("dim4", get_test_dims(64, 1024, n=2), ids=id_formatter("dim4")) +@pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) -@pytest.mark.parametrize("formatB", ["col_turing", "col_ampere"], ids=id_formatter("formatB")) @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) -def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): - inner = torch.randint(1, 128, size=(1,)).item() +def test_dequant_mm(dim1, dim4, dims, has_bias): + inner = 128 # torch.randint(1, 128, size=(1,)).item() bias = None if has_bias: bias = torch.randn(dim4, device="cuda", dtype=torch.float16) - formatB = F.get_special_format_str() + for i in range(1): A = torch.randn(dim1, inner, device="cuda") B = torch.randn(dim4, inner, device="cuda") @@ -822,12 +822,9 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): A1, maxA = F.vectorwise_quant(A, dim=1) B1, maxB = F.vectorwise_quant(B, dim=1) - A2, SA = F.nvidia_transform(A1, "col32") - B2, SB = F.nvidia_transform(B1, formatB) - C2, SC = F.igemmlt(A2, B2, SA, SB) + C2, SC = F.igemmlt(A1, B1) - C3, S = F.nvidia_transform(C2, "row", state=SC) - C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) + C4 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t()) if has_bias: C4 += bias @@ -840,8 +837,9 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): # assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06)) # assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" - C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias) - # torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1) + C5 = F.mm_dequant(C2, SC, maxA, maxB, bias=bias) + C5 /= std + torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1) n = C5.numel() assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n)) @@ -890,8 +888,10 @@ def test_colrow_absmax(dim1, dim2, dims): assert nnz_block_ptr2 is None -@pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) +# @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) +# @pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2")) def test_double_quant(dim1, dim2): for i in range(k): A = torch.randn(dim1, dim2, device="cuda").half() @@ -926,9 +926,12 @@ def test_double_quant(dim1, dim2): ( pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}") for (dim1, dim4, inner) in zip( - get_test_dims(1, 4 * 1024, n=4), - get_test_dims(1, 4 * 1024, n=4), - get_test_dims(1, 4 * 1024, n=4), + (1, 8, 2048, 4096), + (2, 128, 2048, 4096), + (4, 256, 512, 4096), + # get_test_dims(1, 4 * 1024, n=4), + # get_test_dims(1, 4 * 1024, n=4), + # get_test_dims(1, 4 * 1024, n=4), ) ), ) @@ -949,17 +952,11 @@ def test_integrated_igemmlt(dim1, dim4, inner): torch.testing.assert_close(C1a, A1, rtol=0, atol=1) torch.testing.assert_close(C2a, B1, rtol=0, atol=1) - A2, SA = F.nvidia_transform(C1a, "col32") - B2, SB = F.nvidia_transform(C2a, "col_turing") - outC32, SC = F.igemmlt(A2, B2, SA, SB) - out2 = F.mm_dequant(outC32, SC, stats1a, stats2a) + out2, SC = F.igemmlt(A1, B1) - A2, SA = F.nvidia_transform(A1, "col32") - B2, SB = F.nvidia_transform(B1, "col_turing") - C2, SC = F.igemmlt(A2, B2, SA, SB) + C2, SC = F.igemmlt(A1, B1) - C3, S = F.nvidia_transform(C2, "row", state=SC) - out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) + out3 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t()) err1 = torch.abs(out1 - out2).mean().item() err2 = torch.abs(out1 - out3).mean().item() @@ -999,7 +996,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner): c = 10.0 * inner * scale row_scale = torch.ones_like(maxA) / c - outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) + outC32, SC = F.igemmlt(A2, B2, dtype=torch.int8, row_scale=row_scale) C3, S = F.nvidia_transform(outC32, "row", state=SC) maxval = torch.abs(C3).max() if maxval == 127: @@ -1012,7 +1009,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner): C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) B2, SB = F.nvidia_transform(C2a, formatB) - outC32, SC = F.igemmlt(A2, B2, SA, SB) + outC32, SC = F.igemmlt(A2, B2) out2 = F.mm_dequant(outC32, SC, stats1a, stats2a) CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector") @@ -1080,7 +1077,7 @@ def test_row_scale_bench(dim1, dim4, inner): torch.cuda.synchronize() t0 = time.time() for i in range(k): - outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) + outC32, SC = F.igemmlt(A2, B2, dtype=torch.int8, row_scale=row_scale) torch.cuda.synchronize() print("row-wise", time.time() - t0) @@ -1089,7 +1086,7 @@ def test_row_scale_bench(dim1, dim4, inner): torch.cuda.synchronize() t0 = time.time() for i in range(k): - outC32, SC = F.igemmlt(A2, B2, SA, SB) + outC32, SC = F.igemmlt(A2, B2) torch.cuda.synchronize() print("vector-wise", time.time() - t0) @@ -1132,10 +1129,11 @@ def test_overflow(): a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1) b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1) - Ca, Sa = F.nvidia_transform(a, "col32") - Cb, Sb = F.nvidia_transform(b, formatB) + # Ca, Sa = F.nvidia_transform(a, "col32") + # Cb, Sb = F.nvidia_transform(b, formatB) - c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8) + # c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8) + c = F.igemmlt(a, b) c2 = torch.matmul(a.float(), b.float().t()) @@ -1238,25 +1236,21 @@ def test_spmm_bench(): @pytest.mark.parametrize("dim2", get_test_dims(256, 1024, n=2), ids=id_formatter("dim2")) def test_integrated_sparse_decomp(dim1, dim2): threshold = 3.0 - formatB = "col_turing" + # formatB = "col_turing" for i in range(k): A = torch.randn(dim1, dim2).cuda().half() w1 = torch.randn(dim1, dim2).cuda().half() out1 = torch.matmul(A, w1.t()) Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) - CTw1, Sw1 = F.transform(Cw1, formatB) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) - C32A, SA = F.transform(CA, "col32") - out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) + out1_32, Sout1_32 = F.igemmlt(CA, Cw1) out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) - C32A, SA = F.transform(CA, "col32") - out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) + out1_32, Sout1_32 = F.igemmlt(CA, Cw1) out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) assert coo_tensor is not None @@ -1484,7 +1478,12 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): @pytest.mark.parametrize( ("batch", "seq", "model", "hidden"), - [pytest.param(1, 1, 6656, 4 * 6656, id="batch=1, seq=1, model=6656, hidden=26k")], + [ + # pytest.param(1, 128, 6656, 4 * 6656, id="batch=1, seq=128, model=6656, hidden=26k"), + # pytest.param(2, 128, 6656, 4 * 6656, id="batch=2, seq=128, model=6656, hidden=26k"), + # pytest.param(4, 128, 6656, 4 * 6656, id="batch=4, seq=128, model=6656, hidden=26k"), + pytest.param(16, 256, 6656, 4 * 6656, id="batch=16, seq=256, model=6656, hidden=26k") + ], ) @pytest.mark.benchmark def test_bench_matmul(batch, seq, model, hidden): @@ -1557,19 +1556,45 @@ def test_bench_matmul(batch, seq, model, hidden): torch.cuda.synchronize() print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): - # bnb.matmul(A, B) - # torch.cuda.synchronize() - # print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + bnb.matmul(A, B) + torch.cuda.synchronize() + print( + f"B -> CB (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + bnb.matmul(A, B, threshold=6.0) + torch.cuda.synchronize() + print( + f"B -> CB + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) + + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) + CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) + out32, Sout32 = F.igemmlt(CA, CB) + torch.cuda.synchronize() + print( + f"no overhead igemmlt [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) + + # C32A, SA = F.transform(CA, "col32") + + # CxB, SB = F.transform(CB, to_order=formatB) # torch.cuda.synchronize() # t0 = time.time() # for i in range(iters): - # bnb.matmul(A, B, threshold=6.0) + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) # torch.cuda.synchronize() - # print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) # C32A, SA = F.transform(CA, "col32") @@ -1610,21 +1635,25 @@ def test_bench_matmul(batch, seq, model, hidden): # torch.cuda.synchronize() # print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - # linear8bit(A) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): - # linear8bit(A) - # torch.cuda.synchronize() - # print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + linear8bit(A) + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + linear8bit(A) + torch.cuda.synchronize() + print( + f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) - # linearMixedBit(A) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): - # linearMixedBit(A) - # torch.cuda.synchronize() - # print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + linearMixedBit(A) + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + linearMixedBit(A) + torch.cuda.synchronize() + print( + f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) # linear8bit_train(A) # torch.cuda.synchronize() @@ -2144,7 +2173,7 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): c = int(C1.numel() * 0.0014 * (dim / 256)) + 1 - c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) + c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=0, throw=False) err1 = sum(errs1) / len(errs1) / math.sqrt(dim) err2 = sum(errs2) / len(errs2) / math.sqrt(dim) err3 = sum(errs3) / len(errs3) / math.sqrt(dim) diff --git a/tests/test_modules.py b/tests/test_modules.py index 2176f1d48..d5c968395 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -310,7 +310,7 @@ def test_linear8bitlt_inference(threshold): b1 = torch.randn(16, 8, 32, device="cuda").half() o1 = l1(b1) if i == 1: - assert l1.state.CxB is not None + assert l1.state.CB is not None def test_linear8bitlt_accumulated_gradient(): From 0f2dc347448948280e6d774edf1a9399e44f6a7f Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 8 Oct 2024 11:24:45 -0400 Subject: [PATCH 02/65] Fix unintended change --- csrc/ops.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/ops.cu b/csrc/ops.cu index 1f259d67f..8c72b22b4 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -824,8 +824,8 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) { - int num_blocks = (m+7)/8; - kgemm_4bit_inference_naive<<< num_blocks, 256, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + int num_blocks = (m+3)/4; + kgemm_4bit_inference_naive<<< num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } From 50fe50ebec81a155cc31ac28e0deac9a8e1006a6 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 9 Oct 2024 10:50:53 -0400 Subject: [PATCH 03/65] New naive mm_dequant kernel for row-major; cleanup --- bitsandbytes/autograd/_functions.py | 28 +---- bitsandbytes/functional.py | 111 ++++++++++++------ csrc/kernels.cu | 170 ++++++++++------------------ csrc/kernels.cuh | 4 +- csrc/ops.cu | 27 ++--- csrc/ops.cuh | 3 +- csrc/pythonInterface.cpp | 60 ++-------- tests/test_functional.py | 58 +++++----- tests/test_modules.py | 4 +- 9 files changed, 189 insertions(+), 276 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 01845a131..e32763f56 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -284,7 +284,9 @@ def tile_indices(self): class MatMul8bitLt(torch.autograd.Function): @staticmethod - def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): + def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None): + state = state or MatmulLtState() + using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt # default of pytorch behavior if inputs are empty ctx.is_empty = False @@ -417,8 +419,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) - clone_func = torch.clone if len(output_shape) == 3 else lambda x: x - return clone_func(output.view(output_shape)) + return output.reshape(output_shape) @staticmethod def backward(ctx, grad_output): @@ -442,37 +443,18 @@ def backward(ctx, grad_output): Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) if req_gradB: - # CxAt, SAt = F.transform(CAt, formatB, transpose=True) - # C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) - # gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) - # grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) - gradB32, SgradB32 = F.igemmlt( - Cgradt.t(), CAt.t() - ) # issue here in test_linear_serialization w/ has fp16 weights + gradB32, SgradB32 = F.igemmlt(Cgradt.t(), CAt.t()) grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) if state.threshold > 0.0 and subA is not None: grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: if state.CBt is not None: - # C32grad, Sgrad = F.transform(Cgrad, "col32") - # if state.CxBt is None: - # state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) - # gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) - # grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) gradA32, SgradA32 = F.igemmlt(Cgradt, state.CBt.t()) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) - elif state.CB is not None: CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) - # elif state.CxB is not None: - # CB = ( - # undo_layout(state.CxB, state.tile_indices) - # .to(ctx.dtype_A) - # .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) - # ) - # grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) else: raise Exception("State must contain either CBt or CB matrix for backward") diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 7f07778ef..d59fc8778 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2330,7 +2330,7 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32): ldb = shapeB[-1] # Activations (batch, tokens, inputs) ldc = shapeC[-1] # Output (batch, tokens, outputs) - assert lda == ldb, f"igemmlt only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ B={shapeA}" + assert lda == ldb, f"igemmlt only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}" prev_device = A.device torch.cuda.set_device(A.device) @@ -2361,18 +2361,25 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32): return out, Sout -def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None): +def mm_dequant_torch( + A: torch.Tensor, + quant_state: Optional[Tuple[torch.Size, str]], # TODO: deprecate. (shape, format) + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + new_row_stats=None, # TODO: unused + new_col_stats=None, # TODO: unused + bias: Optional[torch.Tensor] = None, +): assert A.dtype == torch.int32 - compute_dtype = torch.float32 - - A_calc = A.view(-1, A.shape[-1]).to(compute_dtype) - row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype) - col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype) + A_calc = A.view(-1, A.shape[-1]) + row_stats = row_stats.reshape(-1).unsqueeze(-1) + col_stats = col_stats.reshape(-1).unsqueeze(0) # TODO support out != None - out = A_calc * (row_stats * col_stats) * 6.200124e-5 # .to(torch.float16) + out = A_calc * (row_stats * col_stats) * 6.200124e-5 if bias is not None: # assert bias.dtype == torch.float16 @@ -2381,42 +2388,40 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non return out.to(torch.float16) -def mm_dequant_old(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None): +def mm_dequant( + A: torch.Tensor, + quant_state: Optional[Tuple[torch.Size, str]], # TODO: deprecate. (shape, format) + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + new_row_stats=None, # TODO: unused + new_col_stats=None, # TODO: unused + bias: Optional[torch.Tensor] = None, +): assert A.dtype == torch.int32 + if bias is not None: assert bias.dtype == torch.float16 - out_shape = quant_state[0] - if len(out_shape) == 3: - out_shape = (out_shape[0] * out_shape[1], out_shape[2]) if out is None: - out = torch.empty(out_shape, dtype=torch.float16, device=A.device) - if new_row_stats is None: - new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) - if new_col_stats is None: - new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) - assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}" - assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}" + out = torch.empty_like(A, dtype=torch.float16) - prev_device = pre_call(A.device) ptrA = get_ptr(A) ptrOut = get_ptr(out) ptrRowStats = get_ptr(row_stats) ptrColStats = get_ptr(col_stats) - ptrNewRowStats = get_ptr(new_row_stats) - ptrNewColStats = get_ptr(new_col_stats) ptrBias = get_ptr(bias) - numRows = ct.c_int32(out_shape[0]) - numCols = ct.c_int32(out_shape[1]) + numRows = ct.c_int32(prod(A.shape[:-1])) + numCols = ct.c_int32(A.shape[-1]) + + is_on_gpu([A, row_stats, col_stats, out, bias]) - is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) + prev_device = pre_call(A.device) lib.cdequant_mm_int32_fp16( ptrA, ptrRowStats, ptrColStats, ptrOut, - ptrNewRowStats, - ptrNewColStats, ptrBias, numRows, numCols, @@ -2426,7 +2431,33 @@ def mm_dequant_old(A, quant_state, row_stats, col_stats, out=None, new_row_stats return out -def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0): +def get_colrow_absmax( + A: torch.Tensor, + row_stats: torch.Tensor = None, + col_stats: torch.Tensor = None, + nnz_block_ptr: torch.Tensor = None, + threshold=0.0, +): + # Note: prior impl only works with fp16 + assert A.is_floating_point() + + if row_stats is None or col_stats is None: + absA = A.abs().view(-1, A.shape[-1]) # view as 2D + if row_stats is None: + # shape [rows]; unsqueeze(-1) gives [rows,1] + row_stats = absA.amax(dim=1, keepdim=False).float() + if col_stats is None: + # shape [cols]; unsqueeze(0) gives [1,cols] + col_stats = absA.amax(dim=0, keepdim=False).float() + + # TODO: threshold support + if nnz_block_ptr is None and threshold > 0.0: + nnz_block_ptr = torch.zeros_like(A, dtype=torch.int32) + + return row_stats, col_stats, nnz_block_ptr + + +def get_colrow_absmax_old(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0): assert A.dtype == torch.float16 device = A.device @@ -2543,19 +2574,27 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) +@torch.compile def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): - # TODO: Optimize/write CUDA kernel for this. Currently vectorwise_quant will recalculate row/col stats. + # TODO: Optimize/write CUDA kernel for this # TODO: Support threshold - # if out_col is None: - # out_col = torch.zeros(A.shape, device=A.device, dtype=torch.int8) - # if out_row is None: - # out_row = torch.zeros(A.shape, device=A.device, dtype=torch.int8) + if row_stats is None or col_stats is None: + row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) + + scaled_A = A.mul(C) + + # quant_row = torch.round(A * (C / row_stats.unsqueeze(-1))).to(torch.int8) + # quant_col = torch.round(A * (C / col_stats.unsqueeze(0))).to(torch.int8) + quant_row = torch.round(scaled_A / row_stats.unsqueeze(-1)).to(torch.int8) + quant_col = torch.round(scaled_A / col_stats.unsqueeze(0)).to(torch.int8) - out_col, Scol = vectorwise_quant(A, dim=0) - out_row, Srow = vectorwise_quant(A, dim=1) + if out_row is not None: + quant_row = out_row.copy_(quant_row) + if out_col is not None: + quant_col = out_col.copy_(quant_col) - return out_row, out_col, Srow.flatten().float(), Scol.flatten().float(), None # coo_tensor + return quant_row, quant_col, row_stats.flatten().float(), col_stats.flatten().float(), None def double_quant_old(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): diff --git a/csrc/kernels.cu b/csrc/kernels.cu index be7779de1..5bdcb1a41 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -219,7 +219,7 @@ __device__ half dhDequantizeNF4(unsigned char val) } -__device__ float dDequantizeNF4(unsigned char val) +__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { // the values for this tree was generated by test_normal_map_tree @@ -722,7 +722,7 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs __shared__ typename LoadChar::TempStorage loadchar; __shared__ typename StoreT::TempStorage storet; - for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) + for (int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) { if(DATA_TYPE > 0) { @@ -734,7 +734,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i; valid_items_store = n - i > TILE_SIZE ? TILE_SIZE : n - i; } - local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]); + local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH) >> (31 - __clz(blocksize))]); + //local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]); __syncthreads(); LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); @@ -2291,128 +2292,68 @@ template __global__ void kgetColRowStats(half * __rest #define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) -template __global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n) -{ +template +__global__ void kdequant_mm_int32_fp16( + int* __restrict__ const A, + float *__restrict__ const rowStats, + float *__restrict__ const colStats, + half *out, + half *__restrict__ const bias, + const int numRows, + const int numCols, + const int n +) { + const int n_out = numRows * numCols; - // Strategy: To dequantize we need to load col/row statistics. This can be very expensive - // since different row/col stats need to be loaded with each thread. - // (1, bad algorithm) Loading 32 items per thread would only occur 1 row load, but this increases register pressure - // and would lead to low global load utilization. - // (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do lot of row loads - // for each thread and this is duplicated by a factor of 32/num-cols-per-thread. - // (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock. - // This allows for efficient row/col loading from shared memory within the tile. - // We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has - // the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts - // we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the - // shared memory loads. - - // data is in 32 column-tile major with tile width 32 columns and numRows rows - // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. - // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) - // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127 into register)) - // C2. Compute normalization values and store col values in register - // S1. Store C1 into 16-bit output - // S2. Store col/row statistics of new buffer in shared memory - - // We allow for sub-tiles to span multiple col32 tiles. This is okay - // since the items per thread only rely on a single column statistic. - - - const int n_out = numRows*numCols; - - int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1); - // we have tiles of size numRows*32, thus col only increases every numRows - // num_row_tiles is the tiles after which the column increases by 32 - // blockIdx.x is the index of the current tile - int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32)); - // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached - int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS); - - // SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS - // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD - // Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROW*32/4 threads. - // For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have - // 1024*1024/(128*32) = 256 tiles - // 256 tiles are 256*128*32/4 = 256*1024 threads - - // 1. Figure out how index relates to the start of the sub-tile - // 2. Each thread < SUBTILE_ROWS calculates row index - // 3. Load striped and store in shared memory + int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD; + int thread_offset = threadIdx.x * ITEMS_PER_THREAD; int local_values[ITEMS_PER_THREAD]; half local_output[ITEMS_PER_THREAD]; + float local_rowStats[ITEMS_PER_THREAD]; - __shared__ float smem_rowStats[SUBTILE_ROWS]; + float local_colStats[ITEMS_PER_THREAD]; + float local_biasValue[ITEMS_PER_THREAD]; - typedef cub::BlockLoad LoadInt32; - typedef cub::BlockExchange ExchangeInt32; + typedef cub::BlockLoad LoadInt32; __shared__ typename LoadInt32::TempStorage loadint32; - __shared__ typename ExchangeInt32::TempStorage exchangeint32; + int row_idx, col_idx; - // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. - float colStat = col >= numCols ? 0.0f : colStats[col]; - float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]); - // no block loads for rows for now -- keep it simple - for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x) - { - // todo: is this global mem access slow due to overlaps or does the L1 cache work well here? - int row = (base_row+j) % numRows; // wrap around - // each warp accesses the same element, for four consequitive elements - // todo: update description about striped shared memory, it is not needed - // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements - smem_rowStats[j] = rowStats[row]; - } - __syncthreads(); - + #pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; ++j) { - // each block processes SUBTILE_ROWS*32 elements - const int items_per_load = THREADS*ITEMS_PER_THREAD; - const int rows_per_load = items_per_load/32; + row_idx = (block_offset + thread_offset + j) / numCols; + col_idx = (block_offset + thread_offset + j) % numCols; - int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile - int row_offset = 0; - // subtile_idx starts at the base_row*32 + the total offset for a full numRow*32 tile is passed - int subtile_start = (blockIdx.x/num_row_tiles)*(numRows*32) + (base_row*32); - for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load) - { - int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset); - int valid_items = valid_rows*32; - if(valid_items <= 0) // the sub-tile might have more elements than the tile itself - break; + local_colStats[j] = col_idx >= numCols ? 0.0f : colStats[col_idx]; + local_rowStats[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx]; + local_biasValue[j] = ((bias == nullptr) || col_idx >= numCols) ? 0.0f : __half2float(bias[col_idx]); + } - // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) - LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); - ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); + // Each block loads THREADS * ITEMS_PER_THREAD values from A + int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out + ? THREADS * ITEMS_PER_THREAD + : n_out - block_offset; + __syncthreads(); + LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0); - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j]; + #pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; ++j) { + local_output[j] = __float2half( + fmaf(local_values[j] * local_rowStats[j] * local_colStats[j], MM_DEQUANT_CONST, local_biasValue[j]) + ); + } - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue); - //absmax_col = fmax(fabsf(local_output[j]), absmax_col); - - // we store data in row major - // to store data efficiently, we want to use block exchange: [0, 32, 64, 92] -> [0, 1, 2, 3] - // so that each thread holds ITEMS_PER_THREAD consecutive items for each row - // this way throughput into storage is increased by a factor of ~2x - // for now we use a simple store - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - { - int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols); - if(outIdx< n_out && col < numCols) - out[outIdx] = local_output[j]; + #pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; j++) { + int outIdx = block_offset + thread_offset + j; + if (outIdx < n_out) { + out[outIdx] = local_output[j]; } - - row_offset += rows_per_load; } } - template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols) { // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD @@ -3525,17 +3466,20 @@ template __global__ void kgemm_4bit_inferenc __shared__ T quant_map[16]; T local_absmax = T(0.0f); - for(int i = threadIdx.x; i < 16; i++) - quant_map[i] = T(datatype[i]); + if (threadIdx.x < 16) + quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x])); + //for(int i = threadIdx.x; i < 16; i++) + //quant_map[i] = T(__ldg(&datatype[i])); __syncthreads(); // A: [1, K] // B: [N, K] for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit) { - int inner_idx_halved = inner_idx/2; - int offset_B = ldb*row_B; - int absidx = ((2*offset_B)+inner_idx)/blocksize; + const int inner_idx_halved = inner_idx/2; + const int offset_B = ldb*row_B; + const int absidx = ((2*offset_B)+inner_idx) >> (31 - __clz(blocksize)); + //int absidx = ((2*offset_B)+inner_idx)/blocksize; local_absmax = __ldg(&(absmax[absidx])); if(row_B < M) @@ -3810,7 +3754,7 @@ template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>( template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); -template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); +template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index ec6daebe5..1e094dbd2 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -112,9 +112,9 @@ __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kdequant_mm_int32_fp16( +template __global__ void kdequant_mm_int32_fp16( int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, - half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); + half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); diff --git a/csrc/ops.cu b/csrc/ops.cu index 8c72b22b4..f3d349a41 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -584,19 +584,15 @@ int fill_up_to_nearest_multiple(int value, int multiple) return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); } -void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols) +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half *bias, int numRows, int numCols) { - int threads = 512; - int tileCols = fill_up_to_nearest_multiple(numCols, 32); - int n = numRows*tileCols; - int subtile_rows = 128; - int tilesize = 32*subtile_rows; - int num_blocks = numRows/subtile_rows; - num_blocks += (numRows % subtile_rows == 0) ? 0 : 1; - num_blocks = num_blocks*(tileCols/32); - assert(threads <= tilesize); - - kdequant_mm_int32_fp16<4, 128, 512><<>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n); + const int threads = 512; + const int num_per_thread = 4; + const int num_per_block = threads * num_per_thread; + const int n = numRows*numCols; + const int num_blocks = (n + num_per_block - 1) / num_per_block; + + kdequant_mm_int32_fp16<<>>(A, rowStats, colStats, out, bias, numRows, numCols, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -861,17 +857,10 @@ template void extractOutliers(char * A, int *idx, char *out, int idx template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); - template void transformRowToFormat(char * A, char *out, int rows, int cols); template void transformRowToFormat(char * A, char *out, int rows, int cols); template void transformRowToFormat(char * A, char *out, int rows, int cols); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index ab0185242..9ecb93bf2 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -176,11 +176,10 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i long long int strideA, long long int strideB, long long int strideC, int batchCount); template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); -void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols); +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols); void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols); void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 09b9b62a9..0034db262 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -5,6 +5,7 @@ #if BUILD_CUDA #include +uint abc; #endif #if BUILD_MPS // #include @@ -175,32 +176,15 @@ void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRo void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } - int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { - return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); - } - int igemmlt_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { - return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); - } - int igemmlt_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { - return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); - } - int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int igemmlt_turing_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int igemmlt_turing_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int igemmlt_ampere_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int igemmlt_ampere_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int igemmlt_ampere_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } +int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); +} +int igemmlt_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); +} +int igemmlt_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); +} void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { spmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } @@ -335,26 +319,6 @@ extern "C" return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - //{ (cublasLtHandle_t)context->m_handle; return 0; } - //{ return 0; }//igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int cigemmlt_turing_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_turing_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int cigemmlt_turing_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_turing_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int cigemmlt_ampere_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_ampere_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int cigemmlt_ampere_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_ampere_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - #define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \ { \ @@ -370,8 +334,8 @@ extern "C" MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8) MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32) - void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols) - { dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols); } + void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols) + { dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols); } void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) { getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); } diff --git a/tests/test_functional.py b/tests/test_functional.py index 522af516c..5052909e7 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -847,45 +847,41 @@ def test_dequant_mm(dim1, dim4, dims, has_bias): @pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) -def test_colrow_absmax(dim1, dim2, dims): +@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp")) +def test_colrow_absmax(dim1, dim2, dims, threshold): for i in range(k): - threshold = 3.0 A = torch.randn(dim1, dim2, device="cuda").half() - A_truncated = A.clone() - A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0 - if dims == 2: - row_stats1, _ = torch.abs(A.float()).max(1) - col_stats1, _ = torch.abs(A.float()).max(0) + + assert dims == 2 + + row_stats1, _ = torch.abs(A.float()).max(1) + col_stats1, _ = torch.abs(A.float()).max(0) + + if threshold > 0.0: + A_truncated = A.clone() + A_truncated[torch.abs(A_truncated) >= threshold] = 0.0 row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1) col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0) - else: - assert False - row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold) + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold) - A_blocked = einops.rearrange( - torch.abs(A), - "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size", - row_tiles=16, - block_size=64 * 4, - ) - nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten() - nnz_block_ptr1 = torch.zeros( - nnz_rows1_counts.shape[0] + 1, - dtype=nnz_rows1_counts.dtype, - device=nnz_rows1_counts.device, - ) - nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0) - - torch.testing.assert_close(col_stats1_trunc, col_stats2) - torch.testing.assert_close(row_stats1_trunc, row_stats2) - torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2) + nnz_rows1_counts = (torch.abs(A) >= threshold).sum(1).flatten() + nnz_block_ptr1 = torch.zeros( + nnz_rows1_counts.shape[0] + 1, + dtype=nnz_rows1_counts.dtype, + device=nnz_rows1_counts.device, + ) + nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0) - row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0) + torch.testing.assert_close(col_stats1_trunc, col_stats2) + torch.testing.assert_close(row_stats1_trunc, row_stats2) + torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2) + else: + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0) + assert nnz_block_ptr2 is None torch.testing.assert_close(col_stats1, col_stats2) torch.testing.assert_close(row_stats1, row_stats2) - assert nnz_block_ptr2 is None # @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) @@ -1480,9 +1476,9 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): ("batch", "seq", "model", "hidden"), [ # pytest.param(1, 128, 6656, 4 * 6656, id="batch=1, seq=128, model=6656, hidden=26k"), - # pytest.param(2, 128, 6656, 4 * 6656, id="batch=2, seq=128, model=6656, hidden=26k"), + pytest.param(1, 1, 3584, 512, id="batch=1, seq=128, model=3584, hidden=19k"), # pytest.param(4, 128, 6656, 4 * 6656, id="batch=4, seq=128, model=6656, hidden=26k"), - pytest.param(16, 256, 6656, 4 * 6656, id="batch=16, seq=256, model=6656, hidden=26k") + # pytest.param(16, 256, 6656, 4 * 6656, id="batch=16, seq=256, model=6656, hidden=26k") ], ) @pytest.mark.benchmark diff --git a/tests/test_modules.py b/tests/test_modules.py index d5c968395..7369bb1cf 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -335,8 +335,8 @@ def test_linear8bitlt_accumulated_gradient(): loss1.backward() loss2.backward() if i == 2: - assert l1[0].state.CxB is not None - assert l1[1].state.CxB is not None + assert l1[0].state.CB is not None + assert l1[1].state.CB is not None if i > 0 and i % acc_steps == 0: opt1.step() From 57e6427c1259c8f18d4959ad13cee9a7c4b97f36 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 9 Oct 2024 11:02:48 -0400 Subject: [PATCH 04/65] fix --- csrc/pythonInterface.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 0034db262..b03b0650c 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -5,7 +5,6 @@ #if BUILD_CUDA #include -uint abc; #endif #if BUILD_MPS // #include From ca372f2b2784332bbce488ec640941ad93e1ff80 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 14 Oct 2024 14:43:52 -0400 Subject: [PATCH 05/65] int8 refactor: initial sparse decomp, cleanup --- bitsandbytes/autograd/_functions.py | 142 ++--- bitsandbytes/cextension.py | 3 + bitsandbytes/functional.py | 578 +++++++++---------- bitsandbytes/research/autograd/_functions.py | 8 +- csrc/kernels.cu | 30 +- csrc/ops.cu | 6 - tests/test_functional.py | 31 +- tests/test_linear8bitlt.py | 11 +- tests/test_modules.py | 15 +- 9 files changed, 390 insertions(+), 434 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index e32763f56..bc7a51113 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,6 +1,5 @@ from dataclasses import dataclass -from functools import reduce # Required in Python 3 -import operator +from math import prod from typing import Callable, Optional, Tuple import warnings from warnings import warn @@ -9,12 +8,6 @@ import bitsandbytes.functional as F - -# math.prod not compatible with python < 3.8 -def prod(iterable): - return reduce(operator.mul, iterable, 1) - - # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov: # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py @@ -284,10 +277,16 @@ def tile_indices(self): class MatMul8bitLt(torch.autograd.Function): @staticmethod - def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None): - state = state or MatmulLtState() + def forward( + ctx: torch.autograd.function.FunctionCtx, + A: torch.Tensor, + B: torch.Tensor, + out=None, + bias: Optional[torch.Tensor] = None, + state=MatmulLtState, + ): + # state = state or MatmulLtState() - using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt # default of pytorch behavior if inputs are empty ctx.is_empty = False if prod(A.shape) == 0: @@ -300,14 +299,7 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None): else: return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device) - # 1. Quantize A - # 2. Quantize B - # 3. Matmul - # 4. Mixed-precision decomposition matmul - # 5. Save state input_shape = A.shape - if state.outlier_pool is None: - state.outlier_pool = GlobalOutlierPooler.get_instance() # Cast A to fp16 if A.dtype != torch.float16: @@ -318,20 +310,10 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None): A = A.reshape(-1, A.shape[-1]) CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) - if state.threshold > 0.0 and coo_tensorA is not None: - if state.has_fp16_weights: - idx = torch.unique(coo_tensorA.colidx).long() - CA[:, idx] = 0 - CAt[:, idx] = 0 - subA = A[:, idx] - state.subB = B[:, idx].t().contiguous() - state.idx = idx - else: - subA = None + has_grad = False - # 2. Quantize B if state.has_fp16_weights: - has_grad = True if (getattr(B, "grad", None) is not None) else False + has_grad = getattr(B, "grad", None) is not None is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) if is_transposed: B = B.contiguous() @@ -339,71 +321,46 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None): if (state.is_training and not has_grad) or state.CB is None: state.reset_grads() - # quantize... + # 2. Quantize B ( state.CB, state.CBt, state.SCB, state.SCBt, - coo_tensorB, + _, ) = F.double_quant(B.to(torch.float16)) - else: - has_grad = False - - if coo_tensorA is not None and not state.has_fp16_weights: - # extract outliers - - outlier_idx = torch.unique(coo_tensorA.colidx) - state.idx = outlier_idx - # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: - # # do not use pool for 2nd FFN layer - # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - # else: - # state.idx = outlier_idx - - # if state.CxB is not None: - # outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) - # else: - outliers = state.CB[:, state.idx.long()].clone() - - state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) - CA[:, state.idx.long()] = 0 - CAt[:, state.idx.long()] = 0 - subA = A[:, state.idx.long()] - - shapeB = state.CB.shape - - if len(input_shape) == 3: - output_shape = (input_shape[0], input_shape[1], shapeB[0]) - else: - output_shape = (input_shape[0], shapeB[0]) - # 3. Matmul - if using_igemmlt: - out32, Sout32 = F.igemmlt(CA, state.CB) + if state.threshold > 0.0 and coo_tensorA is not None: + state.idx = torch.unique(coo_tensorA._indices()[1]).long() + + # Zero out the outliers in the int8 inputs + CA[:, state.idx] = 0 + CAt[:, state.idx] = 0 - if bias is None or bias.dtype == torch.float16: - # we apply the fused bias here - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) - output = output.to(A.dtype) - else: # apply bias separately - # TODO: Fused bias for fp32/bf16? - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) - output = output.to(A.dtype).add_(bias) + # Extract the input outliers in original precision + subA = A[:, state.idx] + # Extract the corresponding weights + if state.has_fp16_weights: + state.subB = B[:, state.idx].t().contiguous() + else: + outliers = state.CB[:, state.idx].clone() + state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) else: - A_wo_outliers = A.clone() - if state.idx is not None: - A_wo_outliers[:, state.idx.long()] = 0 - output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype)) - output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0)) - if bias is not None: - output = output.add_(bias) + subA = state.subB = None + + # 3. Int8 Matmul + out32, Sout32 = F.igemmlt(CA, state.CB) + if bias is None or bias.dtype == torch.float16: + # we apply the fused bias here + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias).to(A.dtype) + else: # apply bias separately + # TODO: Fused bias for fp32/bf16? + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None).to(A.dtype).add_(bias) # 4. Mixed-precision decomposition matmul - if coo_tensorA is not None and subA is not None: - output += torch.matmul(subA, state.subB) + if subA is not None and state.subB is not None: + output += torch.matmul(subA, state.subB.to(subA.dtype)) # 5. Save state ctx.state = state @@ -419,7 +376,8 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) - return output.reshape(output_shape) + output_shape = (*input_shape[:-1], state.CB.shape[0]) + return output.reshape(output_shape).clone() @staticmethod def backward(ctx, grad_output): @@ -441,16 +399,24 @@ def backward(ctx, grad_output): if len(grad_output.shape) == 3: grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) + Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16)) if req_gradB: - gradB32, SgradB32 = F.igemmlt(Cgradt.t(), CAt.t()) - grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) + # grad_output.T @ A + # grad_weight = grad_output.t().mm(A) + grad_B = torch.matmul(grad_output.t(), A) if state.threshold > 0.0 and subA is not None: grad_B[:, idx] += torch.matmul(grad_output.t(), subA) + # if req_gradB: + # + # gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t()) + # grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) + # if state.threshold > 0.0 and subA is not None: + # grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: + # grad_output @ B.T if state.CBt is not None: - gradA32, SgradA32 = F.igemmlt(Cgradt, state.CBt.t()) + gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t()) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 45573538e..b7522334c 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -67,6 +67,9 @@ def __init__(self, lib: ct.CDLL): def __getattr__(self, item): return getattr(self._lib, item) + def __getitem__(self, item): + return getattr(self._lib, item) + class CudaBNBNativeLibrary(BNBNativeLibrary): compiled_with_cuda = True diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index d59fc8778..8d7226b2c 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -5,7 +5,7 @@ import ctypes as ct import itertools from math import prod -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch @@ -845,8 +845,7 @@ def quantize_blockwise( if absmax is None: n = A.numel() - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 + blocks = -(n // -blocksize) absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) if out is None: @@ -854,40 +853,31 @@ def quantize_blockwise( if A.device.type != "cpu": assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] - cblocksize = ct.c_int32(blocksize) - prev_device = pre_call(A.device) + code = code.to(A.device) is_on_gpu([code, A, out, absmax]) - if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - cblocksize, - ct.c_int(A.numel()), - ) - elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - cblocksize, - ct.c_int(A.numel()), - ) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16( + + fn_map = { + torch.float32: "cquantize_blockwise_fp32", + torch.bfloat16: "cquantize_blockwise_bf16", + torch.float16: "cquantize_blockwise_fp16", + } + + if A.dtype not in fn_map.keys(): + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + + fn = fn_map[A.dtype] + + with torch.cuda.device_of(A): + lib[fn]( get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), - cblocksize, + ct.c_int32(blocksize), ct.c_int(A.numel()), ) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - post_call(A.device) + else: # cpu code = code.cpu() @@ -972,47 +962,34 @@ def dequantize_blockwise( out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device) if A.device.type != "cpu": - device = pre_call(A.device) code = quant_state.code.to(A.device) if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: raise ValueError( f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]", ) is_on_gpu([A, absmax, out]) - stream = get_tensor_stream(A) - if out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32( - get_ptr(quant_state.code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(A.numel()), - stream, # Used the _as_parameter_ attribute of torch.cuda.Stream, Similarly for the following - ) - elif out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16( - get_ptr(quant_state.code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(A.numel()), - stream, - ) - elif out.dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16( + + fn_map = { + torch.float32: "cdequantize_blockwise_fp32", + torch.bfloat16: "cdequantize_blockwise_bf16", + torch.float16: "cdequantize_blockwise_fp16", + } + + if out.dtype not in fn_map.keys(): + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") + + fn = fn_map[out.dtype] + + with torch.cuda.device_of(A): + lib[fn]( get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()), - stream, + get_tensor_stream(A), ) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - post_call(A.device) else: code = quant_state.code.cpu() lib.cdequantize_blockwise_cpu_fp32( @@ -1174,8 +1151,7 @@ def quantize_4bit( input_shape = A.shape if absmax is None: - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 + blocks = -(n // -blocksize) absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) if out is None: @@ -1184,68 +1160,72 @@ def quantize_4bit( assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] - prev_device = pre_call(A.device) is_on_gpu([A, out, absmax]) if A.dtype == torch.float32: if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + with torch.cuda.device_of(A): + lib.cquantize_blockwise_fp32_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: - lib.cquantize_blockwise_fp32_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + with torch.cuda.device_of(A): + lib.cquantize_blockwise_fp32_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) elif A.dtype == torch.float16: if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + with torch.cuda.device_of(A): + lib.cquantize_blockwise_fp16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: - lib.cquantize_blockwise_fp16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + with torch.cuda.device_of(A): + lib.cquantize_blockwise_fp16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) elif A.dtype == torch.bfloat16: if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + with torch.cuda.device_of(A): + lib.cquantize_blockwise_bf16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: - lib.cquantize_blockwise_bf16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + with torch.cuda.device_of(A): + lib.cquantize_blockwise_bf16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - post_call(A.device) code = get_4bit_type(quant_type, device=A.device) @@ -1363,77 +1343,80 @@ def dequantize_4bit( n = out.numel() - device = pre_call(A.device) is_on_gpu([A, absmax, out]) stream = get_tensor_stream(A) if out.dtype == torch.float32: if quant_state.quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) + with torch.cuda.device_of(A): + lib.cdequantize_blockwise_fp32_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + stream, + ) else: - lib.cdequantize_blockwise_fp32_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) + with torch.cuda.device_of(A): + lib.cdequantize_blockwise_fp32_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + stream, + ) elif out.dtype == torch.float16: if quant_state.quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) + with torch.cuda.device_of(A): + lib.cdequantize_blockwise_fp16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + stream, + ) else: - lib.cdequantize_blockwise_fp16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) + with torch.cuda.device_of(A): + lib.cdequantize_blockwise_fp16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + stream, + ) elif out.dtype == torch.bfloat16: - if quant_state.quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) - else: - lib.cdequantize_blockwise_bf16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) + with torch.cuda.device_of(A): + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + stream, + ) + else: + lib.cdequantize_blockwise_bf16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + stream, + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - post_call(A.device) - is_transposed = True if A.shape[0] == 1 else False + is_transposed = A.shape[0] == 1 if is_transposed: return out.t() else: @@ -1995,10 +1978,9 @@ def gemv_4bit( transposed_B=False, state=None, ): - prev_device = pre_call(A.device) # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: - raise ValueError("state cannot None. gem_4bit( ) requires the state from quantize_4bit( )") + raise ValueError("state cannot None. gemv_4bit() requires the state from quantize_4bit()") if A.numel() != A.shape[-1]: raise ValueError( @@ -2032,62 +2014,64 @@ def gemv_4bit( ldb = ct.c_int32(ldb) ldc = ct.c_int32(ldc) stream = get_tensor_stream(A) - if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: - if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - stream, - ) - elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - stream, - ) - elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - stream, - ) + + with torch.cuda.device_of(A): + if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: + if A.dtype == torch.float16: + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(state.code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(state.blocksize), + stream, + ) + elif A.dtype == torch.bfloat16: + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(state.code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(state.blocksize), + stream, + ) + elif A.dtype == torch.float32: + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(state.code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(state.blocksize), + stream, + ) + else: + raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") + else: raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") - else: - raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") - - post_call(prev_device) + # post_call(prev_device) return out @@ -2332,62 +2316,35 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32): assert lda == ldb, f"igemmlt only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}" - prev_device = A.device - torch.cuda.set_device(A.device) - - ctx = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrRowScale = get_ptr(None) - m, n, k, lda, ldb, ldc = map(ct.c_int32, (m, n, k, lda, ldb, ldc)) - is_on_gpu([A, B, out]) - if dtype == torch.int32: - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) - else: - has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + with torch.cuda.device_of(A): + ctx = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrRowScale = get_ptr(None) + m, n, k, lda, ldb, ldc = map(ct.c_int32, (m, n, k, lda, ldb, ldc)) + + if dtype == torch.int32: + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + else: + has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` raise NotImplementedError("igemmlt not implemented!") if has_error: - print(f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}") - raise Exception("cublasLt ran into an error!") - - torch.cuda.set_device(prev_device) + raise RuntimeError( + f"cublasLt ran into an error!\n" + f"\tA: {shapeA}, B: {shapeB}, C: {Sout[0]}\n" + f"\t(lda, ldb, ldc): {(lda, ldb, ldc)}\n" + f"\t(m, n, k): {(m, n, k)}" + ) return out, Sout -def mm_dequant_torch( - A: torch.Tensor, - quant_state: Optional[Tuple[torch.Size, str]], # TODO: deprecate. (shape, format) - row_stats: torch.Tensor, - col_stats: torch.Tensor, - out: Optional[torch.Tensor] = None, - new_row_stats=None, # TODO: unused - new_col_stats=None, # TODO: unused - bias: Optional[torch.Tensor] = None, -): - assert A.dtype == torch.int32 - - A_calc = A.view(-1, A.shape[-1]) - row_stats = row_stats.reshape(-1).unsqueeze(-1) - col_stats = col_stats.reshape(-1).unsqueeze(0) - - # TODO support out != None - - out = A_calc * (row_stats * col_stats) * 6.200124e-5 - - if bias is not None: - # assert bias.dtype == torch.float16 - out.add_(bias) - - return out.to(torch.float16) - - def mm_dequant( A: torch.Tensor, quant_state: Optional[Tuple[torch.Size, str]], # TODO: deprecate. (shape, format) @@ -2416,17 +2373,16 @@ def mm_dequant( is_on_gpu([A, row_stats, col_stats, out, bias]) - prev_device = pre_call(A.device) - lib.cdequant_mm_int32_fp16( - ptrA, - ptrRowStats, - ptrColStats, - ptrOut, - ptrBias, - numRows, - numCols, - ) - post_call(prev_device) + with torch.cuda.device_of(A): + lib.cdequant_mm_int32_fp16( + ptrA, + ptrRowStats, + ptrColStats, + ptrOut, + ptrBias, + numRows, + numCols, + ) return out @@ -2441,8 +2397,21 @@ def get_colrow_absmax( # Note: prior impl only works with fp16 assert A.is_floating_point() + outlier_mask = None + if row_stats is None or col_stats is None: - absA = A.abs().view(-1, A.shape[-1]) # view as 2D + # view as 2D + absA = A.abs().view(-1, A.shape[-1]) + + if threshold > 0.0: + # Filter outliers from stats when enabled + outlier_mask = absA >= threshold + absA.masked_fill_(outlier_mask, 0.0) + + # For parity with tests build nnz_block_ptr. + nnz_block_ptr = torch.zeros(absA.shape[0] + 1, dtype=torch.int64, device=A.device) + nnz_block_ptr[1:] = outlier_mask.sum(1).cumsum(0) + if row_stats is None: # shape [rows]; unsqueeze(-1) gives [rows,1] row_stats = absA.amax(dim=1, keepdim=False).float() @@ -2450,11 +2419,7 @@ def get_colrow_absmax( # shape [cols]; unsqueeze(0) gives [1,cols] col_stats = absA.amax(dim=0, keepdim=False).float() - # TODO: threshold support - if nnz_block_ptr is None and threshold > 0.0: - nnz_block_ptr = torch.zeros_like(A, dtype=torch.int32) - - return row_stats, col_stats, nnz_block_ptr + return row_stats, col_stats, outlier_mask def get_colrow_absmax_old(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0): @@ -2496,7 +2461,9 @@ def get_colrow_absmax_old(A, row_stats=None, col_stats=None, nnz_block_ptr=None, class COOSparseTensor: - def __init__(self, rows, cols, nnz, rowidx, colidx, values): + def __init__( + self, rows: int, cols: int, nnz: int, rowidx: torch.Tensor, colidx: torch.Tensor, values: torch.Tensor + ): assert rowidx.dtype == torch.int32 assert colidx.dtype == torch.int32 assert values.dtype == torch.float16 @@ -2574,16 +2541,26 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) -@torch.compile -def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): +# @torch.compile +def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): # TODO: Optimize/write CUDA kernel for this - # TODO: Support threshold if row_stats is None or col_stats is None: - row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) + row_stats, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold) - scaled_A = A.mul(C) + if threshold > 0.0: + # Extract outliers to COO tensor: + # 1. Zero out all of the non-outliers, convert to COO. + # 2. Zero out the outliers in the dense tensor. + # TODO we could improve perf of this + # is_outlier = A.abs() >= threshold + coo_tensor = A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo() + A = A.masked_fill(outlier_mask, 0.0) + else: + coo_tensor = None + # Quantize + scaled_A = A.mul(C) # quant_row = torch.round(A * (C / row_stats.unsqueeze(-1))).to(torch.int8) # quant_col = torch.round(A * (C / col_stats.unsqueeze(0))).to(torch.int8) quant_row = torch.round(scaled_A / row_stats.unsqueeze(-1)).to(torch.int8) @@ -2594,7 +2571,7 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, if out_col is not None: quant_col = out_col.copy_(quant_col) - return quant_row, quant_col, row_stats.flatten().float(), col_stats.flatten().float(), None + return quant_row, quant_col, row_stats.flatten().float(), col_stats.flatten().float(), coo_tensor def double_quant_old(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): @@ -2735,7 +2712,22 @@ def transform(A, to_order, from_order="row", out=None, transpose=False, state=No return out, new_state -def spmm_coo(cooA, B, out=None): +def spmm_coo(cooA: Union[COOSparseTensor, torch.Tensor], B: torch.Tensor, out: torch.Tensor = None): + if not isinstance(cooA, COOSparseTensor): + assert ( + cooA.is_sparse and cooA.layout == torch.sparse_coo + ), "Tensor must be `COOSparseTensor or a PyTorch COO tensor." + + # Convert to custom COOSparseTensor + cooA = COOSparseTensor( + rows=cooA.shape[0], + cols=cooA.shape[1], + nnz=cooA._nnz(), + rowidx=cooA.indices()[0].int(), + colidx=cooA.indices()[1].int(), + values=cooA.values(), + ) + if out is None: out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype) nnz = cooA.nnz diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index 5f8b2c437..3e807d6e1 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -219,7 +219,8 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 if state.threshold > 0.0 and coo_tensorA is not None: if state.has_fp16_weights: - idx = torch.unique(coo_tensorA.colidx).long() + # idx = torch.unique(coo_tensorA.colidx).long() + idx = torch.unique(coo_tensorA._indices()[1]).long() CA[:, idx] = 0 CAt[:, idx] = 0 subA = A[:, idx] @@ -257,7 +258,8 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 if coo_tensorA is not None and not state.has_fp16_weights: # extract outliers - outlier_idx = torch.unique(coo_tensorA.colidx) + # outlier_idx = torch.unique(coo_tensorA.colidx) + outlier_idx = torch.unique(coo_tensorA._indices()[1]).long() state.idx = outlier_idx # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: @@ -339,7 +341,7 @@ def backward(ctx, grad_output): if req_gradA: if state.CBt is not None: - gradA32, SgradA32 = F.igemmlt(Cgradt, state.CBt.t()) + gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t()) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 5bdcb1a41..34de9d5ca 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -627,7 +627,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float for(int i = threadIdx.x; i < 256; i+=blockDim.x) smem_code[i] = code[i]; - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + for (int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) { valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; local_abs_max = -FLT_MAX; @@ -645,19 +645,13 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items); - if(threadIdx.x == 0) - smem_absmax_value[0] = local_abs_max; - + if (threadIdx.x == 0) { + smem_absmax_value[0] = 1.0f / local_abs_max; + absmax[i / BLOCK_SIZE] = local_abs_max; + } __syncthreads(); - if(threadIdx.x == 0) - absmax[i/BLOCK_SIZE] = local_abs_max; - else - local_abs_max = smem_absmax_value[0]; - - __syncwarp(); - - local_abs_max = 1.0f/local_abs_max; + local_abs_max = smem_absmax_value[0]; if(STOCHASTIC) { @@ -724,15 +718,15 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs for (int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) { - if(DATA_TYPE > 0) + if (DATA_TYPE > 0) { - valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i; - valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2; + valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i); + valid_items_store = min(TILE_SIZE * 2, n - i * 2); } else { - valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i; - valid_items_store = n - i > TILE_SIZE ? TILE_SIZE : n - i; + valid_items_load = min(TILE_SIZE, n - i); + valid_items_store = valid_items_load; } local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH) >> (31 - __clz(blocksize))]); //local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]); @@ -740,7 +734,7 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs __syncthreads(); LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); - switch(DATA_TYPE) + switch (DATA_TYPE) { case General8bit: // load code through read-only cache via __ldg diff --git a/csrc/ops.cu b/csrc/ops.cu index f3d349a41..089a30cc1 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -436,17 +436,11 @@ template int igemmlt( // // Use the IMMA kernels requires: // * A must be transposed and B must be non-transposed. - // * All leading dimensions must be multiples of 4. // * Dimensions m and k must be multiples of 4. // * All pointers must be 4-byte aligned; 16-byte alignment preferred. - // - int has_error = 0; - // this is the default - cublasLtOrder_t col_major = CUBLASLT_ORDER_COL; - cublasLtMatmulDesc_t matmulDesc; cublasLtMatrixLayout_t aDesc, bDesc, cDesc; cublasOperation_t opT = CUBLAS_OP_T; diff --git a/tests/test_functional.py b/tests/test_functional.py index 5052909e7..9b7004946 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -875,13 +875,12 @@ def test_colrow_absmax(dim1, dim2, dims, threshold): torch.testing.assert_close(col_stats1_trunc, col_stats2) torch.testing.assert_close(row_stats1_trunc, row_stats2) - torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2) + # torch.testing.assert_close(nnz_block_ptr1, nnz_block_ptr2) else: row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0) assert nnz_block_ptr2 is None - - torch.testing.assert_close(col_stats1, col_stats2) - torch.testing.assert_close(row_stats1, row_stats2) + torch.testing.assert_close(col_stats1, col_stats2) + torch.testing.assert_close(row_stats1, row_stats2) # @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) @@ -1122,32 +1121,32 @@ def test_overflow(): formatB = F.get_special_format_str() print(formatB) for i in range(2): - a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1) - b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1) + a = torch.arange(0, 16).cuda().to(torch.int8).view(-1, 4).contiguous() + b = torch.arange(0, 16).cuda().to(torch.int8).view(-1, 4).contiguous() # Ca, Sa = F.nvidia_transform(a, "col32") # Cb, Sb = F.nvidia_transform(b, formatB) # c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8) - c = F.igemmlt(a, b) + c = F.igemmlt(a, b, dtype=torch.int8) c2 = torch.matmul(a.float(), b.float().t()) -@pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) +# @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) +# @pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_double_quant(dim1, dim2): threshold = 3.00 for i in range(k): A = torch.randn(dim1, dim2, device="cuda").half() idx = torch.abs(A) >= threshold - CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) if coo_tensor is not None: A1 = A * idx - A2 = torch.zeros_like(A) - A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values + A2 = coo_tensor.to_dense() torch.testing.assert_close(A1, A2) A1 = A * (idx == 0) @@ -1228,8 +1227,10 @@ def test_spmm_bench(): print(tsp / t8) -@pytest.mark.parametrize("dim1", get_test_dims(256, 1024, n=2), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", get_test_dims(256, 1024, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [256, 1024], ids=id_formatter("dim2")) +# @pytest.mark.parametrize("dim1", get_test_dims(256, 1024, n=2), ids=id_formatter("dim1")) +# @pytest.mark.parametrize("dim2", get_test_dims(256, 1024, n=2), ids=id_formatter("dim2")) def test_integrated_sparse_decomp(dim1, dim2): threshold = 3.0 # formatB = "col_turing" @@ -1252,6 +1253,8 @@ def test_integrated_sparse_decomp(dim1, dim2): assert coo_tensor is not None out4 = F.spmm_coo(coo_tensor, w1.t()) + # idx = torch.unique(coo_tensor._indices()[1]).long() + # out4 = torch.matmul(A, w1.t()) out5 = out3 + out4 err1 = torch.abs(out1 - out2).mean().item() diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 9b7923312..149d9a93c 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -79,14 +79,13 @@ def test_linear_no_igemmlt(): @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward")) @pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda")) -@pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt")) +# @pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt")) @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward")) @pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda")) def test_linear_serialization( has_fp16_weights, serialize_before_forward, deserialize_before_cuda, - force_no_igemmlt, save_before_forward, load_before_cuda, ): @@ -100,8 +99,8 @@ def test_linear_serialization( has_fp16_weights=has_fp16_weights, threshold=6.0, ) - if force_no_igemmlt: - linear_custom.state.force_no_igemmlt = True + # if force_no_igemmlt: + # linear_custom.state.force_no_igemmlt = True linear_custom.weight = bnb.nn.Int8Params( linear.weight.data.clone(), @@ -147,8 +146,8 @@ def test_linear_serialization( has_fp16_weights=has_fp16_weights, threshold=6.0, ) - if force_no_igemmlt: - new_linear_custom.state.force_no_igemmlt = True + # if force_no_igemmlt: + # new_linear_custom.state.force_no_igemmlt = True if deserialize_before_cuda: with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError): diff --git a/tests/test_modules.py b/tests/test_modules.py index 7369bb1cf..1f1b17584 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -528,15 +528,17 @@ def test_linear_kbit_fp32_bias(module): @pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys()) def test_kbit_backprop(module): - b = 17 - dim1 = 37 - dim2 = 83 + b = 16 + dim1 = 32 + dim2 = 48 + # dim1 = 37 + # dim2 = 83 - ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)]) + ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 16)]) ref[1].weight.requires_grad = False torch.nn.init.kaiming_normal_(ref[0].weight) torch.nn.init.kaiming_normal_(ref[1].weight) - kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)]) + kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 16)]) kbit[0].weight.detach().copy_(ref[0].weight) kbit[1].weight.detach().copy_(ref[1].weight) kbit[0].bias.detach().copy_(ref[0].bias) @@ -570,7 +572,8 @@ def test_kbit_backprop(module): relerrs1.append(relerr1.mean().item()) relerrs2.append(relerr2.mean().item()) - if isinstance(module, bnb.nn.Linear8bitLt): + # if isinstance(module, bnb.nn.Linear8bitLt): + if module == bnb.nn.Linear8bitLt: assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1) torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05) else: From 510a8808542064b16ab06ef2a7e973c91ed3c9dd Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 14 Oct 2024 17:20:08 -0400 Subject: [PATCH 06/65] Int8 refactoring: remove separate NO_CUBLASLT build; more cleanup --- .github/scripts/build-cuda.sh | 30 ++++++++++++++--------------- CMakeLists.txt | 14 +------------- bitsandbytes/autograd/_functions.py | 30 +++++++++++++++-------------- bitsandbytes/cextension.py | 6 +----- bitsandbytes/cuda_specs.py | 2 +- bitsandbytes/diagnostics/cuda.py | 4 ++-- csrc/ops.cu | 15 --------------- tests/conftest.py | 4 ---- tests/test_cuda_setup_evaluator.py | 20 ------------------- tests/test_linear8bitlt.py | 8 +++++--- tests/test_modules.py | 14 +++++++------- 11 files changed, 48 insertions(+), 99 deletions(-) diff --git a/.github/scripts/build-cuda.sh b/.github/scripts/build-cuda.sh index 0f9b8d726..26a7075b0 100644 --- a/.github/scripts/build-cuda.sh +++ b/.github/scripts/build-cuda.sh @@ -8,21 +8,21 @@ build_capability="50;52;60;61;70;75;80;86;89;90" [[ "${cuda_version}" == 11.7.* ]] && build_capability=${build_capability%??????} [[ "${cuda_version}" == 11.8.* ]] && build_capability=${build_capability%???} [[ "${build_os}" = windows-* ]] && python3 -m pip install ninja -for NO_CUBLASLT in ON OFF; do - if [ "${build_os:0:6}" == ubuntu ]; then - image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04 - echo "Using image $image" - docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \ - "apt-get update \ - && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \ - && cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" -DNO_CUBLASLT=${NO_CUBLASLT} . \ - && cmake --build ." - else - pip install cmake==3.28.3 - cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DNO_CUBLASLT=${NO_CUBLASLT} -DCMAKE_BUILD_TYPE=Release -S . - cmake --build . --config Release - fi -done + +if [ "${build_os:0:6}" == ubuntu ]; then + image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04 + echo "Using image $image" + docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \ + "apt-get update \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \ + && cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" . \ + && cmake --build ." +else + pip install cmake==3.28.3 + cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DCMAKE_BUILD_TYPE=Release -S . + cmake --build . --config Release +fi + output_dir="output/${build_os}/${build_arch}" mkdir -p "${output_dir}" diff --git a/CMakeLists.txt b/CMakeLists.txt index d305e5a3e..ce3962ff7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,6 @@ # For MSVC: `cmake -B build . && cmake --build build --config Release` # You can also use the following options and variables # - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend -# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support # - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version # is whatever CMake finds on your path. # - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC. @@ -47,10 +46,8 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda") if(APPLE) message(FATAL_ERROR "CUDA is not supported on macOS" ) endif() - option(NO_CUBLASLT "Disable CUBLAS" OFF) set(BUILD_CUDA ON) set(BUILD_MPS OFF) - message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}") elseif(${COMPUTE_BACKEND} STREQUAL "mps") if(NOT APPLE) message(FATAL_ERROR "MPS is only supported on macOS" ) @@ -166,9 +163,6 @@ if(BUILD_CUDA) list(APPEND SRC_FILES ${CUDA_FILES}) string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") - if(NO_CUBLASLT) - string(APPEND BNB_OUTPUT_NAME "_nocublaslt") - endif() add_compile_definitions(BUILD_CUDA) elseif(BUILD_MPS) if(NOT APPLE) @@ -212,13 +206,7 @@ target_include_directories(bitsandbytes PUBLIC csrc include) if(BUILD_CUDA) target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) - target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse) - if(NO_CUBLASLT) - target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT) - else() - target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt) - endif() - + target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cusparse) set_target_properties(bitsandbytes PROPERTIES CUDA_SEPARABLE_COMPILATION ON diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index bc7a51113..03e3add4a 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -283,9 +283,9 @@ def forward( B: torch.Tensor, out=None, bias: Optional[torch.Tensor] = None, - state=MatmulLtState, + state: MatmulLtState = None, ): - # state = state or MatmulLtState() + state = state or MatmulLtState() # default of pytorch behavior if inputs are empty ctx.is_empty = False @@ -318,7 +318,7 @@ def forward( if is_transposed: B = B.contiguous() - if (state.is_training and not has_grad) or state.CB is None: + if (state.is_training and not has_grad) or state.SCB is None: state.reset_grads() # 2. Quantize B @@ -347,7 +347,7 @@ def forward( outliers = state.CB[:, state.idx].clone() state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) else: - subA = state.subB = None + subA = None # 3. Int8 Matmul out32, Sout32 = F.igemmlt(CA, state.CB) @@ -377,7 +377,11 @@ def forward( ctx.save_for_backward(None, None) output_shape = (*input_shape[:-1], state.CB.shape[0]) - return output.reshape(output_shape).clone() + + if len(input_shape) == 3: + return output.view(output_shape).clone() + else: + return output @staticmethod def backward(ctx, grad_output): @@ -400,18 +404,16 @@ def backward(ctx, grad_output): grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16)) - if req_gradB: - # grad_output.T @ A - # grad_weight = grad_output.t().mm(A) - grad_B = torch.matmul(grad_output.t(), A) - if state.threshold > 0.0 and subA is not None: - grad_B[:, idx] += torch.matmul(grad_output.t(), subA) # if req_gradB: - # - # gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t()) - # grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) + + # grad_B = torch.matmul(grad_output.t(), A) # if state.threshold > 0.0 and subA is not None: # grad_B[:, idx] += torch.matmul(grad_output.t(), subA) + if req_gradB: + gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t()) + grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) + if state.threshold > 0.0 and subA is not None: + grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: # grad_output @ B.T diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index b7522334c..5bed7fba4 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -37,11 +37,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: The library is not guaranteed to exist at the returned path. """ - library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}" - if not cuda_specs.has_cublaslt: - # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt - library_name += "_nocublaslt" - library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}" + library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}" override_value = os.environ.get("BNB_CUDA_VERSION") if override_value: diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index ed19795a0..e72d57590 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -11,7 +11,7 @@ class CUDASpecs: cuda_version_tuple: Tuple[int, int] @property - def has_cublaslt(self) -> bool: + def has_imma(self) -> bool: return self.highest_compute_capability >= (7, 5) diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py index 8974c6400..45dc98dea 100644 --- a/bitsandbytes/diagnostics/cuda.py +++ b/bitsandbytes/diagnostics/cuda.py @@ -134,8 +134,8 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}") - # 7.5 is the minimum CC for cublaslt - if not cuda_specs.has_cublaslt: + # 7.5 is the minimum CC for int8 tensor cores + if not cuda_specs.has_imma: print_dedented( """ WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU! diff --git a/csrc/ops.cu b/csrc/ops.cu index 089a30cc1..e2eddc7ab 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -314,8 +314,6 @@ int roundoff(int v, int d) { } -#ifdef NO_CUBLASLT -#else template cublasLtOrder_t get_order() { switch(ORDER) @@ -347,7 +345,6 @@ template cublasLtOrder_t get_order(); template cublasLtOrder_t get_order(); template cublasLtOrder_t get_order(); template cublasLtOrder_t get_order(); -#endif template int get_leading_dim(int dim1, int dim2) @@ -379,8 +376,6 @@ template int get_leading_dim(int dim1, int dim2) template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) { -#ifdef NO_CUBLASLT -#else cublasLtOrder_t orderA = get_order(); cublasLtOrder_t orderOut = get_order(); int ldA = get_leading_dim(dim1, dim2); @@ -419,7 +414,6 @@ template void trans if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc)); if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc)); if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc)); -#endif } template int igemmlt( @@ -513,9 +507,6 @@ template int igemmlt( template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { -#ifdef NO_CUBLASLT - return ERR_NOT_IMPLEMENTED; -#else int has_error = 0; cublasLtMatmulDesc_t matmulDesc = NULL; cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; @@ -570,7 +561,6 @@ template int igemmlt(cublasLtHandle printf("error detected"); return has_error; -#endif // NO_CUBLASLT } int fill_up_to_nearest_multiple(int value, int multiple) @@ -681,10 +671,6 @@ template void transformRowToFormat(char * A, char *o void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) { - -#ifdef NO_CUBLASLT -#else - cusparseSpMatDescr_t descA; cusparseDnMatDescr_t descB, descC; @@ -731,7 +717,6 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val CHECK_CUSPARSE( cusparseDestroyDnMat(descB) ); CHECK_CUSPARSE( cusparseDestroyDnMat(descC) ); CUDA_CHECK_RETURN( cudaFree(dBuffer) ); -#endif } template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) diff --git a/tests/conftest.py b/tests/conftest.py index 59146963d..c029c3cb5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,10 +7,6 @@ def pytest_runtest_call(item): try: item.runtest() - except NotImplementedError as nie: - if "NO_CUBLASLT" in str(nie): - pytest.skip("CUBLASLT not available") - raise except AssertionError as ae: if str(ae) == "Torch not compiled with CUDA enabled": pytest.skip("Torch not compiled with CUDA enabled") diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index b13f8b6c6..79406472e 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -13,15 +13,6 @@ def cuda120_spec() -> CUDASpecs: ) -@pytest.fixture -def cuda111_noblas_spec() -> CUDASpecs: - return CUDASpecs( - cuda_version_string="111", - highest_compute_capability=(7, 2), - cuda_version_tuple=(11, 1), - ) - - def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec): monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120" @@ -31,14 +22,3 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): monkeypatch.setenv("BNB_CUDA_VERSION", "110") assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110" assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning? - - -def test_get_cuda_bnb_library_path_override_nocublaslt(monkeypatch, cuda111_noblas_spec, caplog): - monkeypatch.setenv("BNB_CUDA_VERSION", "125") - assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda125_nocublaslt" - assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning? - - -def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec): - monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) - assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt" diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 149d9a93c..48c3a9ea8 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -69,11 +69,13 @@ def test_linear_no_igemmlt(): fx_ours = linear_custom(x_ours).float() (fx_ours * grad_proj).mean().backward() + + assert linear_custom.state.CB is not None + assert not linear_custom.state.has_fp16_weights assert torch.allclose(fx_ref, fx_ours, atol=0.02) assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01) - assert not linear_custom.state.has_fp16_weights - assert linear_custom.state.CB is not None - assert linear_custom.state.CxB is None + + # assert linear_custom.state.CxB is None @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) diff --git a/tests/test_modules.py b/tests/test_modules.py index 1f1b17584..c84ffa42a 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -529,16 +529,16 @@ def test_linear_kbit_fp32_bias(module): @pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys()) def test_kbit_backprop(module): b = 16 - dim1 = 32 - dim2 = 48 + dim1 = 36 + dim2 = 56 # dim1 = 37 # dim2 = 83 - ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 16)]) - ref[1].weight.requires_grad = False + ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 128)]) + # ref[1].weight.requires_grad = False torch.nn.init.kaiming_normal_(ref[0].weight) torch.nn.init.kaiming_normal_(ref[1].weight) - kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 16)]) + kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)]) kbit[0].weight.detach().copy_(ref[0].weight) kbit[1].weight.detach().copy_(ref[1].weight) kbit[0].bias.detach().copy_(ref[0].bias) @@ -572,8 +572,8 @@ def test_kbit_backprop(module): relerrs1.append(relerr1.mean().item()) relerrs2.append(relerr2.mean().item()) - # if isinstance(module, bnb.nn.Linear8bitLt): - if module == bnb.nn.Linear8bitLt: + if isinstance(module, bnb.nn.Linear8bitLt): + # if module == bnb.nn.Linear8bitLt: assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1) torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05) else: From 0ab14fece671df5788c54985d24a9314bcc9bc76 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Fri, 18 Oct 2024 14:05:55 -0400 Subject: [PATCH 07/65] int8: inference optimizations, some cleanup --- .gitignore | 2 + bitsandbytes/autograd/_functions.py | 20 +-- bitsandbytes/functional.py | 173 ++++--------------- bitsandbytes/nn/modules.py | 6 +- bitsandbytes/research/autograd/_functions.py | 30 ++-- csrc/common.cuh | 48 +++++ csrc/kernels.cu | 109 +++++++++++- csrc/kernels.cuh | 3 + csrc/ops.cu | 75 ++------ csrc/ops.cuh | 6 +- csrc/pythonInterface.cpp | 7 +- tests/test_modules.py | 2 +- 12 files changed, 246 insertions(+), 235 deletions(-) create mode 100644 csrc/common.cuh diff --git a/.gitignore b/.gitignore index 22f5a6cd6..aca1983d3 100644 --- a/.gitignore +++ b/.gitignore @@ -22,9 +22,11 @@ CMakeFiles/ bitsandbytes.dir/ Debug/ Release/ +cmake-build-*/ # IDE local files .vs/ +.idea/ # Distribution / packaging .Python diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 03e3add4a..133f9e066 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -335,17 +335,17 @@ def forward( # Zero out the outliers in the int8 inputs CA[:, state.idx] = 0 - CAt[:, state.idx] = 0 + # CAt[:, state.idx] = 0 # Extract the input outliers in original precision subA = A[:, state.idx] # Extract the corresponding weights if state.has_fp16_weights: - state.subB = B[:, state.idx].t().contiguous() + state.subB = B[:, state.idx].t() # .contiguous() else: - outliers = state.CB[:, state.idx].clone() - state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) + outliers = state.CB[:, state.idx] # .clone() + state.subB = (7.874016e-3 * outliers * state.SCB.view(-1, 1)).t().to(A.dtype) else: subA = None @@ -372,14 +372,14 @@ def forward( ctx.tensors = (CAt, subA, A) ctx.tensor_states = (SCAt, state.idx) else: - ctx.tensors = [None, None, A] + ctx.tensors = [None, None, None] # A] ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) output_shape = (*input_shape[:-1], state.CB.shape[0]) if len(input_shape) == 3: - return output.view(output_shape).clone() + return output.reshape(output_shape) # .clone() else: return output @@ -417,10 +417,10 @@ def backward(ctx, grad_output): if req_gradA: # grad_output @ B.T - if state.CBt is not None: - gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t()) - grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) - elif state.CB is not None: + # if state.CBt is not None: + # gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t()) + # grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) + if state.CB is not None: CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) else: diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 8d7226b2c..6c8ffe3d1 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2400,7 +2400,6 @@ def get_colrow_absmax( outlier_mask = None if row_stats is None or col_stats is None: - # view as 2D absA = A.abs().view(-1, A.shape[-1]) if threshold > 0.0: @@ -2408,13 +2407,10 @@ def get_colrow_absmax( outlier_mask = absA >= threshold absA.masked_fill_(outlier_mask, 0.0) - # For parity with tests build nnz_block_ptr. - nnz_block_ptr = torch.zeros(absA.shape[0] + 1, dtype=torch.int64, device=A.device) - nnz_block_ptr[1:] = outlier_mask.sum(1).cumsum(0) - if row_stats is None: # shape [rows]; unsqueeze(-1) gives [rows,1] - row_stats = absA.amax(dim=1, keepdim=False).float() + row_stats = get_row_absmax(A, threshold) + if col_stats is None: # shape [cols]; unsqueeze(0) gives [1,cols] col_stats = absA.amax(dim=0, keepdim=False).float() @@ -2422,42 +2418,20 @@ def get_colrow_absmax( return row_stats, col_stats, outlier_mask -def get_colrow_absmax_old(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0): +def get_row_absmax(A, threshold=0.0): assert A.dtype == torch.float16 - device = A.device + rows = prod(A.shape[:-1]) cols = A.shape[-1] - if len(A.shape) == 3: - rows = A.shape[0] * A.shape[1] - else: - rows = A.shape[0] - col_tiles = (cols + 255) // 256 - tiled_rows = ((rows + 15) // 16) * 16 - if row_stats is None: - row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0) - if col_stats is None: - col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0) - - if nnz_block_ptr is None and threshold > 0.0: - nnz_block_ptr = torch.zeros(((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device) - - ptrA = get_ptr(A) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - ptrNnzrows = get_ptr(nnz_block_ptr) - rows = ct.c_int32(rows) - cols = ct.c_int32(cols) + row_stats = torch.empty((rows,), dtype=torch.float32, device=A.device) - prev_device = pre_call(A.device) - is_on_gpu([A, row_stats, col_stats, nnz_block_ptr]) - lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols) - post_call(prev_device) + is_on_gpu([A, row_stats]) - if threshold > 0.0: - nnz_block_ptr.cumsum_(0) + with torch.cuda.device_of(A): + lib.cget_row_stats(get_ptr(A), get_ptr(row_stats), ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols)) - return row_stats, col_stats, nnz_block_ptr + return row_stats class COOSparseTensor: @@ -2541,127 +2515,48 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) -# @torch.compile +def extract_outliers_new(A: torch.Tensor, threshold: float): + outlier_mask = A.abs() >= threshold + return A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo() + + def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): - # TODO: Optimize/write CUDA kernel for this + assert A.dtype == torch.half - if row_stats is None or col_stats is None: - row_stats, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold) + rows = prod(A.shape[:-1]) + cols = A.shape[-1] + + row_stats = torch.empty((rows,), device=A.device, dtype=torch.float32) + + out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) if threshold > 0.0: # Extract outliers to COO tensor: # 1. Zero out all of the non-outliers, convert to COO. # 2. Zero out the outliers in the dense tensor. # TODO we could improve perf of this - # is_outlier = A.abs() >= threshold - coo_tensor = A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo() - A = A.masked_fill(outlier_mask, 0.0) + # outlier_mask = A.abs() >= threshold + # coo_tensor = A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo() + # A = A.masked_fill(outlier_mask, 0.0) + coo_tensor = extract_outliers_new(A, threshold) else: coo_tensor = None - # Quantize - scaled_A = A.mul(C) - # quant_row = torch.round(A * (C / row_stats.unsqueeze(-1))).to(torch.int8) - # quant_col = torch.round(A * (C / col_stats.unsqueeze(0))).to(torch.int8) - quant_row = torch.round(scaled_A / row_stats.unsqueeze(-1)).to(torch.int8) - quant_col = torch.round(scaled_A / col_stats.unsqueeze(0)).to(torch.int8) - - if out_row is not None: - quant_row = out_row.copy_(quant_row) - if out_col is not None: - quant_col = out_col.copy_(quant_col) - - return quant_row, quant_col, row_stats.flatten().float(), col_stats.flatten().float(), coo_tensor + is_on_gpu([A, row_stats]) - -def double_quant_old(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): - device = A.device - assert A.dtype == torch.half - assert device.type == "cuda" - prev_device = pre_call(A.device) - - cols = A.shape[-1] - if len(A.shape) == 3: - rows = A.shape[0] * A.shape[1] - else: - rows = A.shape[0] - - if row_stats is None or col_stats is None: - row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) - - if out_col is None: - out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) - if out_row is None: - out_row = torch.zeros(A.shape, device=device, dtype=torch.int8) - - coo_tensor = None - ptrA = get_ptr(A) - ptrColStats = get_ptr(col_stats) - ptrRowStats = get_ptr(row_stats) - ptrOutCol = get_ptr(out_col) - ptrOutRow = get_ptr(out_row) - - is_on_gpu([A, col_stats, row_stats, out_col, out_row]) - if threshold > 0.0: - nnz = nnz_row_ptr[-1].item() - if nnz > 0: - coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device) - ptrRowIdx = get_ptr(coo_tensor.rowidx) - ptrColIdx = get_ptr(coo_tensor.colidx) - ptrVal = get_ptr(coo_tensor.values) - ptrRowPtr = get_ptr(nnz_row_ptr) - - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - ptrRowIdx, - ptrColIdx, - ptrVal, - ptrRowPtr, - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - ) - val, idx = torch.sort(coo_tensor.rowidx) - coo_tensor.rowidx = val - coo_tensor.colidx = coo_tensor.colidx[idx] - coo_tensor.values = coo_tensor.values[idx] - else: - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - None, - None, - None, - None, - ct.c_float(0.0), - ct.c_int32(rows), - ct.c_int32(cols), - ) - else: - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - None, - None, - None, - None, + with torch.cuda.device_of(A): + lib.cint8_vector_quant( + get_ptr(A), + get_ptr(out_row), + get_ptr(row_stats), ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols), ) - post_call(prev_device) - return out_row, out_col, row_stats, col_stats, coo_tensor + # TODO: col_stats + + return out_row, None, row_stats, None, coo_tensor # coo_tensor #col_stats.flatten().float(), coo_tensor def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None): diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 1e5a334ee..fee15b000 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -1008,9 +1008,9 @@ def forward(self, x: torch.Tensor): out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) - if not self.state.has_fp16_weights: - if self.state.CB is not None: - self.weight.data = self.state.CB + if not self.state.has_fp16_weights and self.state.CB is not None: + self.weight.data = self.state.CB + return out diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index 3e807d6e1..dd4de5df6 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -186,7 +186,9 @@ class SwitchBackBnb(torch.autograd.Function): @staticmethod # TODO: the B008 on the line below is a likely bug; the current implementation will # have each SwitchBackBnb instance share a single MatmulLtState instance!!! - def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B008 + def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None): + state = state or MatmulLtState() + # default to pytorch behavior if inputs are empty ctx.is_empty = False if prod(A.shape) == 0: @@ -222,7 +224,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 # idx = torch.unique(coo_tensorA.colidx).long() idx = torch.unique(coo_tensorA._indices()[1]).long() CA[:, idx] = 0 - CAt[:, idx] = 0 + # CAt[:, idx] = 0 subA = A[:, idx] state.subB = B[:, idx].t().contiguous() state.idx = idx @@ -249,7 +251,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 state.CBt, state.SCB, state.SCBt, - coo_tensorB, + _, ) = F.double_quant(B.to(torch.float16)) state.SB = (state.CB.shape, "row") else: @@ -257,21 +259,13 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 if coo_tensorA is not None and not state.has_fp16_weights: # extract outliers + state.idx = torch.unique(coo_tensorA._indices()[1]).long() - # outlier_idx = torch.unique(coo_tensorA.colidx) - outlier_idx = torch.unique(coo_tensorA._indices()[1]).long() - state.idx = outlier_idx - # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: - # # do not use pool for 2nd FFN layer - # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - # else: - # state.idx = outlier_idx # outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) outliers = state.CB[:, state.idx.long()].clone() state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) CA[:, state.idx.long()] = 0 - CAt[:, state.idx.long()] = 0 + # CAt[:, state.idx.long()] = 0 subA = A[:, state.idx.long()] shapeB = state.SB[0] @@ -318,6 +312,7 @@ def backward(ctx, grad_output): if ctx.is_empty: bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None + req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad CAt, subA, A = ctx.tensors SCAt, idx = ctx.tensor_states @@ -340,11 +335,10 @@ def backward(ctx, grad_output): grad_B = torch.matmul(grad_output.t(), A) if req_gradA: - if state.CBt is not None: - gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t()) - grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) - - elif state.CB is not None: + # if state.CBt is not None: + # gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t()) + # grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) + if state.CB is not None: CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) else: diff --git a/csrc/common.cuh b/csrc/common.cuh new file mode 100644 index 000000000..8c85accfd --- /dev/null +++ b/csrc/common.cuh @@ -0,0 +1,48 @@ +#pragma once + +// TODO: Let's make some of these constexpr and put in a namespace. + +#define BNB_CC_MAXWELL 500 +#define BNB_CC_MAXWELL2 520 +#define BNB_CC_MAXWELL2_X1 530 +#define BNB_CC_PASCAL 600 +#define BNB_CC_PASCAL_X2 620 +#define BNB_CC_VOLTA 700 +#define BNB_CC_VOLTA_XAVIER 720 +#define BNB_CC_TURING 750 +#define BNB_CC_AMPERE 800 +#define BNB_CC_AMPERE2 860 +#define BNB_CC_AMPERE2_ORIN 870 +#define BNB_CC_ADA 890 +#define BNB_CC_HOPPER 900 +#define BNB_CC_BLACKWELL 1000 + +#define BNB_FP16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_MAXWELL2_X1) +#define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA) +#define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER) +#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE) +#define BNB_FP8_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_ADA) + +#define BNB_WARP_SIZE 32 + +// The maximum number of resident threads per SM varies by arch. +// For A100/H100 and all prior to Turing, it is 2048, which allows +// for 2 full blocks of 1024 threads per SM. +// Reference: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability +#if __CUDA_ARCH__ == 750 +#define BNB_MAX_THREADS_PER_SM 1024 +#elif __CUDA_ARCH__ >= 860 && __CUDA_ARCH__ <= 890 +#define BNB_MAX_THREADS_PER_SM 1536 +#else +#define BNB_MAX_THREADS_PER_SM 2048 +#endif + +// Maximum resident warps per SM is always directly related to the number of threads. +#define BNB_MAX_WARPS_PER_SM ((BNB_MAX_THREADS_PER_SM) / (BNB_WARP_SIZE)) + +// Maximum resident blocks per SM may vary. +#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 +#define BNB_MAX_BLOCKS_PER_SM 16 +#else +#define BNB_MAX_BLOCKS_PER_SM ((BNB_MAX_WARPS_PER_SM) / 2) +#endif diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 34de9d5ca..d0ea0d270 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3,7 +3,8 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -#include +#include "kernels.cuh" +#include "common.cuh" #include #include #include @@ -2129,6 +2130,106 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char } } +// Inputs: +// A [rows, cols] +// Outputs: +// rowStats [rows] +// out [rows, cols] +template +__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) +__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) { + using BlockReduceT = cub::BlockReduce; + + // One block per row. + // Threads load column values in a striped arrangement. + // e.g. t0 reads row[0], row[0+nthreads], .. + // and t1 reads row[1], row[1+nthreads], .. + // Each thread will determine its local absmax. + // We then do a blockwise reduction to determine the row's absmax. + + __shared__ typename BlockReduceT::TempStorage temp_storage; + __shared__ float smem_row_absmax; + + const int row_id = blockIdx.x; + const T* __restrict__ row_data = A + (row_id * cols); + + // Threads will read the row values in a striped access pattern and find a local absmax. + float row_local_absmax = -FLT_MIN; + for (int i = threadIdx.x; i < cols; i += THREADS) { + const float absval = fabsf(__ldg(&(row_data[i]))); + + // For sparse decomposition, values outside of the threshold are not to be + // included when calculating the row's absmax. + if constexpr (SPARSE_DECOMP) { + row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax); + } else { + row_local_absmax = fmaxf(row_local_absmax, absval); + } + } + + // Reduce thread-local absmax across the block. + // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY + const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols); + if (threadIdx.x == 0) { + // Save our block's absmax to shared memory for the quantization step. + rowStats[row_id] = smem_row_absmax = row_absmax; + } + __syncthreads(); + + // Quantize row-wise. + const float scale = __fdividef(127.0f, smem_row_absmax); + for (int i = threadIdx.x; i < cols; i += THREADS) { + if constexpr (SPARSE_DECOMP) { + // For sparse decomposition, we do not want to quantize the outliers. + // Instead they're zeroed out. + float val = row_data[i]; + out[row_id * cols + i] = fabs(val) < threshold ? __float2int_rn(val * scale) : 0; + } else { + out[row_id * cols + i] = __float2int_rn(float(row_data[i]) * scale); + } + } +} + +template +__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) +__global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) { + using BlockReduceT = cub::BlockReduce; + + // One block per row. + // Threads load column values in a striped arrangement. + // e.g. t0 reads row[0], row[0+nthreads], .. + // and t1 reads row[1], row[1+nthreads], .. + // Each thread will determine its local absmax. + // We then do a blockwise reduction to determine the row's absmax. + + __shared__ typename BlockReduceT::TempStorage temp_storage; + + const int row_id = blockIdx.x; + const T* __restrict__ row_data = A + (row_id * cols); + + // Threads will read the row values in a striped access pattern and find a local absmax. + float row_local_absmax = -FLT_MIN; + for (int i = threadIdx.x; i < cols; i += THREADS) { + const float absval = fabsf(row_data[i]); + + // For sparse decomposition, values outside of the threshold are not to be + // included when calculating the row's absmax. + if constexpr (SPARSE_DECOMP) { + row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax); + } else { + row_local_absmax = fmaxf(row_local_absmax, absval); + } + } + + // Reduce thread-local absmax across the block. + // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY + const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols); + if (threadIdx.x == 0) { + // Save our block's absmax to shared memory for the quantization step. + rowStats[row_id] = row_absmax; + } +} + template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols) { // 0. reset stats to -FLT_MAX @@ -2283,6 +2384,12 @@ template(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); template __global__ void kgetColRowStats(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); +template __global__ void kgetRowStats(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols); +template __global__ void kgetRowStats(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols); + +template __global__ void kInt8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); +template __global__ void kInt8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); + #define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 1e094dbd2..f17bfe4de 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -117,6 +117,9 @@ template __global__ void kdequant_mm_int32_fp half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); +template __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols); +template __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); + template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); diff --git a/csrc/ops.cu b/csrc/ops.cu index e2eddc7ab..df5ec01da 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -505,64 +505,6 @@ template int igemmlt( return has_error; } -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) -{ - int has_error = 0; - cublasLtMatmulDesc_t matmulDesc = NULL; - cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; - cublasOperation_t opT = CUBLAS_OP_T; - cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; - cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32; - cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C; - cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4; - - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda)); - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb)); - - has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); - if(FORMATB == COL_TURING) - has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing))); - else - has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere))); - - if(DTYPE_OUT == 32) - { - has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I)); - has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc)); - has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); - int alpha = 1, beta = 0; - has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0)); - } - else - { - has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F)); - has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc)); - has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); - if(!SCALE_ROWS) - { - float alpha = 1.0f, beta = 0.0f; - has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0)); - } - else - { - has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec))); - has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0)); - } - } - - - if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc)); - if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc)); - if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc)); - if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc)); - if(has_error == 1) - printf("error detected"); - - return has_error; -} - int fill_up_to_nearest_multiple(int value, int multiple) { return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); @@ -580,6 +522,15 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, CUDA_CHECK_RETURN(cudaPeekAtLastError()); } +void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols) { + if (threshold == 0.0) { + kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); + } else { + kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); + } + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + #define STATS_THREADS 64 #define STATS_ITEMS 4 #define STATS_ROWS 16 @@ -602,6 +553,14 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r } +void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols) { + if (threshold == 0.0) + kgetRowStats<<>>(A, rowStats, threshold, rows, cols); + else + kgetRowStats<<>>(A, rowStats, threshold, rows, cols); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols) { int threads = 64; diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 9ecb93bf2..558d93008 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -29,7 +29,6 @@ exit(1); \ } } -#define THREADS_PER_BLOCKS (512) #define CHECK_CUSPARSE(value) { \ cusparseStatus_t _m_cudaStat = value; \ @@ -40,9 +39,6 @@ } } -#define THREADS_PER_BLOCKS (512) - - inline void checkCudaStatus(cudaError_t status) { if (status != cudaSuccess) { printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status)); @@ -181,8 +177,10 @@ template void trans void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols); void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols); +void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols); void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols); +void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); template void transformRowToFormat(char * A, char *out, int rows, int cols); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index b03b0650c..0400d9b48 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -337,7 +337,12 @@ extern "C" { dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols); } void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) { getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); } - + void cget_row_stats(half *A, float *rowStats, float threshold, int rows, int cols) { + getRowStats(A, rowStats, threshold, rows, cols); + } + void cint8_vector_quant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols) { + int8VectorQuant(A, out, rowStats, threshold, rows, cols); + } void cdouble_rowcol_quant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_row_ptr, float threshold, int rows, int cols) { doubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_row_ptr, threshold, rows, cols); } diff --git a/tests/test_modules.py b/tests/test_modules.py index c84ffa42a..51fb21178 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -530,7 +530,7 @@ def test_linear_kbit_fp32_bias(module): def test_kbit_backprop(module): b = 16 dim1 = 36 - dim2 = 56 + dim2 = 84 # dim1 = 37 # dim2 = 83 From fdf474562051671027c87e6ca9d96b4f1d216d0b Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Fri, 18 Oct 2024 18:24:30 -0400 Subject: [PATCH 08/65] int8: more tests passing, cleanup --- bitsandbytes/autograd/_functions.py | 46 ++++++++++++++++++----------- bitsandbytes/functional.py | 36 ++++++++++++++-------- csrc/kernels.cu | 4 +-- tests/test_autograd.py | 17 ++++++----- tests/test_functional.py | 30 +++++++++++++++---- tests/test_linear8bitlt.py | 7 +++-- 6 files changed, 93 insertions(+), 47 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 133f9e066..41555a450 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -308,7 +308,14 @@ def forward( # 1. Quantize A if len(A.shape) == 3: A = A.reshape(-1, A.shape[-1]) - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) + + if ctx.needs_input_grad[1]: + # Slower path + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) + else: + # Fast path + CA, SCA, coo_tensorA = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold) + CAt = SCAt = None has_grad = False @@ -322,20 +329,24 @@ def forward( state.reset_grads() # 2. Quantize B - ( - state.CB, - state.CBt, - state.SCB, - state.SCBt, - _, - ) = F.double_quant(B.to(torch.float16)) + state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16)) + + # ( + # state.CB, + # state.CBt, + # state.SCB, + # state.SCBt, + # _, + # ) = F.double_quant(B.to(torch.float16)) if state.threshold > 0.0 and coo_tensorA is not None: state.idx = torch.unique(coo_tensorA._indices()[1]).long() # Zero out the outliers in the int8 inputs CA[:, state.idx] = 0 - # CAt[:, state.idx] = 0 + + if CAt is not None: + CAt[:, state.idx] = 0 # Extract the input outliers in original precision subA = A[:, state.idx] @@ -372,7 +383,7 @@ def forward( ctx.tensors = (CAt, subA, A) ctx.tensor_states = (SCAt, state.idx) else: - ctx.tensors = [None, None, None] # A] + ctx.tensors = [None, None, None] ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) @@ -403,17 +414,16 @@ def backward(ctx, grad_output): if len(grad_output.shape) == 3: grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16)) - # if req_gradB: - - # grad_B = torch.matmul(grad_output.t(), A) - # if state.threshold > 0.0 and subA is not None: - # grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradB: - gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t()) - grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) + grad_B = torch.matmul(grad_output.t(), A) if state.threshold > 0.0 and subA is not None: grad_B[:, idx] += torch.matmul(grad_output.t(), subA) + # Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16)) + # if req_gradB: + # gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t()) + # grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) + # if state.threshold > 0.0 and subA is not None: + # grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: # grad_output @ B.T diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 6c8ffe3d1..f4ff3eafa 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2409,6 +2409,7 @@ def get_colrow_absmax( if row_stats is None: # shape [rows]; unsqueeze(-1) gives [rows,1] + # We have a CUDA kernel for row max, but not yet for cols. row_stats = get_row_absmax(A, threshold) if col_stats is None: @@ -2521,29 +2522,42 @@ def extract_outliers_new(A: torch.Tensor, threshold: float): def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): + # TODO: Optimize/write CUDA kernel for this? + # Note: for inference, use the new int8_vectorwise_quant. + + # Use CUDA kernel for rowwise and COO tensor + quant_row, row_stats, coo_tensor = int8_vectorwise_quant(A, threshold=threshold) + + # PyTorch impl for colwise + _, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold) + if threshold > 0.0 and outlier_mask is not None: + A = A.masked_fill(outlier_mask, 0.0) + quant_col = torch.round(A.mul(C) / col_stats.unsqueeze(0)).to(torch.int8) + + if out_row is not None: + quant_row = out_row.copy_(quant_row) + if out_col is not None: + quant_col = out_col.copy_(quant_col) + + return quant_row, quant_col, row_stats, col_stats.flatten().float(), coo_tensor + + +def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): assert A.dtype == torch.half + is_on_gpu([A]) rows = prod(A.shape[:-1]) cols = A.shape[-1] row_stats = torch.empty((rows,), device=A.device, dtype=torch.float32) - out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) if threshold > 0.0: - # Extract outliers to COO tensor: - # 1. Zero out all of the non-outliers, convert to COO. - # 2. Zero out the outliers in the dense tensor. # TODO we could improve perf of this - # outlier_mask = A.abs() >= threshold - # coo_tensor = A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo() - # A = A.masked_fill(outlier_mask, 0.0) coo_tensor = extract_outliers_new(A, threshold) else: coo_tensor = None - is_on_gpu([A, row_stats]) - with torch.cuda.device_of(A): lib.cint8_vector_quant( get_ptr(A), @@ -2554,9 +2568,7 @@ def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None, ct.c_int32(cols), ) - # TODO: col_stats - - return out_row, None, row_stats, None, coo_tensor # coo_tensor #col_stats.flatten().float(), coo_tensor + return out_row, row_stats, coo_tensor # coo_tensor #col_stats.flatten().float(), coo_tensor def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None): diff --git a/csrc/kernels.cu b/csrc/kernels.cu index d0ea0d270..45ee0a3ed 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3612,7 +3612,7 @@ template __global__ void kgemm_4bit_inferenc #pragma unroll for(int k = 0; k < num_values_8bit/4; k++) { - #if __CUDA_ARCH__ >= 800 + #if BNB_BF16_AVAILABLE local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax; local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax; #else @@ -3649,7 +3649,7 @@ template __global__ void kgemm_4bit_inferenc #pragma unroll for(int k = 0; k < num_values_4bit/4; k++) { - #if __CUDA_ARCH__ >= 800 + #if BNB_BF16_AVAILABLE local_C += (float)(local_A[k]*local_B[k]); #else // bf16 multipliation not supported diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 89dce644b..3717a9572 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -253,13 +253,16 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec if not has_fp16_weights: if not transpose[0] and not transpose[1]: B2 = B2.t().contiguous() - ( - state.CB, - CBt, - state.SCB, - SCBt, - coo_tensorB, - ) = bnb.functional.double_quant(B2.to(torch.float16)) + + state.CB, state.SCB, _ = bnb.functional.int8_vectorwise_quant(B2.to(torch.float16)) + + # ( + # state.CB, + # CBt, + # state.SCB, + # SCBt, + # coo_tensorB, + # ) = bnb.functional.double_quant(B2.to(torch.float16)) B2 = state.CB if not transpose[0] and transpose[1]: diff --git a/tests/test_functional.py b/tests/test_functional.py index 9b7004946..34dbf56fd 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1132,17 +1132,37 @@ def test_overflow(): c2 = torch.matmul(a.float(), b.float().t()) +@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) +def test_coo_double_quant(dim1, dim2): + threshold = 2.00 + for i in range(k): + A = torch.randn(dim1, dim2, device="cuda").half() + + idx = torch.abs(A) >= threshold + CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold) + + if coo_tensor is not None: + A1 = A * idx + A2 = coo_tensor.to_dense() + torch.testing.assert_close(A1, A2) + + A1 = A * (idx == 0) + A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() + torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) + + # @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) # @pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) -def test_coo_double_quant(dim1, dim2): +def test_coo_int8_vectorwise_quant(dim1, dim2): threshold = 3.00 for i in range(k): A = torch.randn(dim1, dim2, device="cuda").half() idx = torch.abs(A) >= threshold - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) + CA, statsA, coo_tensor = F.int8_vectorwise_quant(A, threshold=threshold) if coo_tensor is not None: A1 = A * idx @@ -1239,13 +1259,13 @@ def test_integrated_sparse_decomp(dim1, dim2): w1 = torch.randn(dim1, dim2).cuda().half() out1 = torch.matmul(A, w1.t()) - Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + Cw1, statsw1, coo_tensor = F.int8_vectorwise_quant(w1) + CA, statsA, coo_tensor = F.int8_vectorwise_quant(A) out1_32, Sout1_32 = F.igemmlt(CA, Cw1) out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) + CA, statsA, coo_tensor = F.int8_vectorwise_quant(A, threshold=threshold) out1_32, Sout1_32 = F.igemmlt(CA, Cw1) out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 48c3a9ea8..3f80beacf 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -72,10 +72,11 @@ def test_linear_no_igemmlt(): assert linear_custom.state.CB is not None assert not linear_custom.state.has_fp16_weights - assert torch.allclose(fx_ref, fx_ours, atol=0.02) - assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01) - # assert linear_custom.state.CxB is None + idx = torch.isclose(fx_ref, fx_ours, atol=0.02, rtol=1e-5) + assert (idx == 0).sum().item() < fx_ref.numel() * 2.5e-4 + torch.testing.assert_close(fx_ref, fx_ours, atol=0.03, rtol=1e-5) + torch.testing.assert_close(x_ref.grad, x_ours.grad, atol=0.01, rtol=1e-5) @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) From d231db75e34026b665b9b829ae04d7363c8b407d Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 21 Oct 2024 16:23:36 -0400 Subject: [PATCH 09/65] int8 - more cleanup, most tests passing --- bitsandbytes/autograd/_functions.py | 22 ++++++++++---------- bitsandbytes/functional.py | 32 ++++++++++++++--------------- bitsandbytes/nn/modules.py | 7 ++----- tests/test_autograd.py | 4 +++- tests/test_linear8bitlt.py | 4 +++- tests/test_modules.py | 4 ++-- 6 files changed, 36 insertions(+), 37 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 41555a450..e4a740301 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -305,7 +305,7 @@ def forward( if A.dtype != torch.float16: warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") - # 1. Quantize A + # 1. Quantize A. Note that as a side-effect, outliers are suppressed. if len(A.shape) == 3: A = A.reshape(-1, A.shape[-1]) @@ -342,9 +342,7 @@ def forward( if state.threshold > 0.0 and coo_tensorA is not None: state.idx = torch.unique(coo_tensorA._indices()[1]).long() - # Zero out the outliers in the int8 inputs - CA[:, state.idx] = 0 - + # Zero out the outliers in the transposed 8bit inputs. if CAt is not None: CAt[:, state.idx] = 0 @@ -414,16 +412,18 @@ def backward(ctx, grad_output): if len(grad_output.shape) == 3: grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() + # if req_gradB: + # grad_B = torch.matmul(grad_output.t(), A) + # if state.threshold > 0.0 and subA is not None: + # grad_B[:, idx] += torch.matmul(grad_output.t(), subA) + # Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16)) if req_gradB: - grad_B = torch.matmul(grad_output.t(), A) + Cgrad, _, _, SCgradt, _ = F.double_quant(grad_output.to(torch.float16)) + + gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t()) + grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) if state.threshold > 0.0 and subA is not None: grad_B[:, idx] += torch.matmul(grad_output.t(), subA) - # Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16)) - # if req_gradB: - # gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t()) - # grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) - # if state.threshold > 0.0 and subA is not None: - # grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: # grad_output @ B.T diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index f4ff3eafa..daeb37810 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -5,7 +5,7 @@ import ctypes as ct import itertools from math import prod -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Tuple, Union import numpy as np import torch @@ -419,22 +419,23 @@ def get_special_format_str(): return "row" -def is_on_gpu(tensors): +def is_on_gpu(tensors: Iterable[torch.Tensor]): on_gpu = True gpu_ids = set() + for t in tensors: - if t is None: - continue # NULL pointers are fine - is_paged = getattr(t, "is_paged", False) - on_gpu &= t.device.type == "cuda" or is_paged - if not is_paged: + # NULL pointers and paged tensors are OK. + if t is not None and not getattr(t, "is_paged", False): + on_gpu &= t.is_cuda gpu_ids.add(t.device.index) + if not on_gpu: - raise TypeError( + raise RuntimeError( f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}", ) + if len(gpu_ids) > 1: - raise TypeError( + raise RuntimeError( f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}", ) return on_gpu @@ -2290,15 +2291,11 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32): shapeA = A.shape shapeB = B.shape - dimsA = A.ndim - dimsB = B.ndim - assert A.device.type == "cuda" - assert B.device.type == "cuda" assert A.dtype == torch.int8 assert B.dtype == torch.int8 - assert dimsA == 2, "Only two dimensional matrices are supported for argument B" - assert dimsB in [2, 3], "Only two or three dimensional matrices are supported for argument A" + assert A.ndim == 2, "Only two dimensional matrices are supported for argument B" + assert B.ndim in [2, 3], "Only two or three dimensional matrices are supported for argument A" assert prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}" shapeC = (*shapeB[:-1], shapeA[0]) @@ -2308,6 +2305,7 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32): out = torch.empty(shapeC, device=A.device, dtype=dtype) assert out.dtype == dtype + k, m = shapeA n = prod(shapeB[:-1]) lda = shapeA[-1] # Weights (outputs, inputs) @@ -2427,7 +2425,7 @@ def get_row_absmax(A, threshold=0.0): row_stats = torch.empty((rows,), dtype=torch.float32, device=A.device) - is_on_gpu([A, row_stats]) + is_on_gpu([A]) with torch.cuda.device_of(A): lib.cget_row_stats(get_ptr(A), get_ptr(row_stats), ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols)) @@ -2568,7 +2566,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): ct.c_int32(cols), ) - return out_row, row_stats, coo_tensor # coo_tensor #col_stats.flatten().float(), coo_tensor + return out_row, row_stats, coo_tensor def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None): diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index fee15b000..66b671510 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -588,12 +588,9 @@ def cuda(self, device): if self.has_fp16_weights: return super().cuda(device) else: - # we store the 8-bit rows-major weight - # we convert this weight to the turning/ampere weight during the first inference pass + # We quantize the weight and store in 8bit row-major B = self.data.contiguous().half().cuda(device) - CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) - del CBt - del SCBt + CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B) self.data = CB self.CB = CB self.SCB = SCB diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 3717a9572..a5ed3f823 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -320,11 +320,13 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec else: assert torch.abs(gradB1).sum() == 0.0 assert torch.abs(gradB2).sum() == 0.0 + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + assert (idx == 0).sum().item() <= n * 0.10 - assert (idx == 0).sum().item() <= n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() <= n * 0.02 + torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3) if req_grad[2]: diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 3f80beacf..51e273897 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -93,7 +93,9 @@ def test_linear_serialization( load_before_cuda, ): linear = torch.nn.Linear(32, 96) - x = torch.randn(3, 32, dtype=torch.half) + # TODO: Fallback for bad shapes + x = torch.randn(4, 32, dtype=torch.half) + # x = torch.randn(3, 32, dtype=torch.half) linear_custom = Linear8bitLt( linear.in_features, diff --git a/tests/test_modules.py b/tests/test_modules.py index 51fb21178..9e16b5e2d 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -351,8 +351,8 @@ def test_linear8bitlt_accumulated_gradient(): l1[0].bias.data.copy_(l2[0].bias.data) l1[1].bias.data.copy_(l2[1].bias.data) else: - torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3) - torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, rtol=1.05, atol=0.04) + torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.00, atol=0.02) @pytest.mark.parametrize("threshold", [0.0, 2.0]) From dfc466868f5abe067a0e3a33e4f123344418f6ed Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 22 Oct 2024 16:40:55 -0400 Subject: [PATCH 10/65] int8: specify CUDA stream for int8 ops --- bitsandbytes/functional.py | 38 ++++++++++++++++++++++---------------- bitsandbytes/nn/modules.py | 4 +--- csrc/kernels.cu | 2 +- csrc/ops.cu | 31 ++++++++++++++++--------------- csrc/ops.cuh | 8 ++++---- csrc/pythonInterface.cpp | 36 ++++++++++++++++++------------------ 6 files changed, 62 insertions(+), 57 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index daeb37810..681e06f08 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -442,8 +442,7 @@ def is_on_gpu(tensors: Iterable[torch.Tensor]): def get_tensor_stream(tensor: Tensor) -> torch.cuda.Stream: - stream = torch.cuda.current_stream(tensor.device) - return stream + return torch.cuda.current_stream(tensor.device) def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: @@ -461,8 +460,8 @@ def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: """ if A is None: return None - else: - return ct.c_void_p(A.data.data_ptr()) + + return ct.c_void_p(A.data_ptr()) def pre_call(device): @@ -2323,11 +2322,12 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32): ptrC = get_ptr(out) ptrRowScale = get_ptr(None) m, n, k, lda, ldb, ldc = map(ct.c_int32, (m, n, k, lda, ldb, ldc)) + stream = get_tensor_stream(A) if dtype == torch.int32: - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) else: - has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` raise NotImplementedError("igemmlt not implemented!") @@ -2373,13 +2373,7 @@ def mm_dequant( with torch.cuda.device_of(A): lib.cdequant_mm_int32_fp16( - ptrA, - ptrRowStats, - ptrColStats, - ptrOut, - ptrBias, - numRows, - numCols, + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, get_tensor_stream(A) ) return out @@ -2428,7 +2422,14 @@ def get_row_absmax(A, threshold=0.0): is_on_gpu([A]) with torch.cuda.device_of(A): - lib.cget_row_stats(get_ptr(A), get_ptr(row_stats), ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols)) + lib.cget_row_stats( + get_ptr(A), + get_ptr(row_stats), + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + get_tensor_stream(A), + ) return row_stats @@ -2547,12 +2548,16 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): rows = prod(A.shape[:-1]) cols = A.shape[-1] - row_stats = torch.empty((rows,), device=A.device, dtype=torch.float32) + row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) if threshold > 0.0: # TODO we could improve perf of this - coo_tensor = extract_outliers_new(A, threshold) + + # A.masked_fill(A.abs() < threshold, 0.0).to_sparse_coo() + # coo_tensor = extract_outliers_new(A, threshold) + coo_tensor = torch.masked_fill(A, A.abs() < threshold, 0.0).to_sparse_coo() + else: coo_tensor = None @@ -2564,6 +2569,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols), + get_tensor_stream(A), ) return out_row, row_stats, coo_tensor diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 66b671510..1ab5cacaf 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -481,10 +481,8 @@ def forward(self, x: torch.Tensor): x = x.to(self.compute_dtype) bias = None if self.bias is None else self.bias.to(self.compute_dtype) - out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state) - - out = out.to(inp_dtype) + out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype) return out diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 45ee0a3ed..0f7150b92 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3558,6 +3558,7 @@ template __global__ void kgemm_4bit_inferenc const int warp_idx = threadIdx.x / 32; const int warp_lane = threadIdx.x % 32; const int row_B = (THREADS/32)*blockIdx.x + warp_idx; + const int offset_B = ldb*row_B; const int num_values_8bit = num_values_4bit/2; float local_C = 0.0f; @@ -3578,7 +3579,6 @@ template __global__ void kgemm_4bit_inferenc for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit) { const int inner_idx_halved = inner_idx/2; - const int offset_B = ldb*row_B; const int absidx = ((2*offset_B)+inner_idx) >> (31 - __clz(blocksize)); //int absidx = ((2*offset_B)+inner_idx)/blocksize; local_absmax = __ldg(&(absmax[absidx])); diff --git a/csrc/ops.cu b/csrc/ops.cu index df5ec01da..88e829bef 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -423,7 +423,8 @@ template int igemmlt( const int8_t * B, void * C, float * row_scale, - int lda, int ldb, int ldc + int lda, int ldb, int ldc, + cudaStream_t stream ) { // Calculate C = A^T @ B, in col-major layout. @@ -461,7 +462,7 @@ template int igemmlt( B, bDesc, &beta, (int32_t*)C, cDesc, (int32_t*)C, cDesc, - NULL, NULL, 0, 0 + NULL, NULL, 0, stream )); } else { if (!SCALE_ROWS) { @@ -472,7 +473,7 @@ template int igemmlt( B, bDesc, &beta, (int8_t*)C, cDesc, (int8_t*)C, cDesc, - NULL, NULL, 0, 0 + NULL, NULL, 0, stream )); } else { cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST; @@ -489,7 +490,7 @@ template int igemmlt( B, bDesc, &beta, (int8_t*)C, cDesc, (int8_t*)C, cDesc, - NULL, NULL, 0, 0 + NULL, NULL, 0, stream )); } } @@ -510,7 +511,7 @@ int fill_up_to_nearest_multiple(int value, int multiple) return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); } -void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half *bias, int numRows, int numCols) +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half *bias, int numRows, int numCols, cudaStream_t stream) { const int threads = 512; const int num_per_thread = 4; @@ -518,15 +519,15 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, const int n = numRows*numCols; const int num_blocks = (n + num_per_block - 1) / num_per_block; - kdequant_mm_int32_fp16<<>>(A, rowStats, colStats, out, bias, numRows, numCols, n); + kdequant_mm_int32_fp16<<>>(A, rowStats, colStats, out, bias, numRows, numCols, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols) { +void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) { if (threshold == 0.0) { - kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); + kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); } else { - kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); + kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); } CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -553,11 +554,11 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r } -void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols) { +void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) { if (threshold == 0.0) - kgetRowStats<<>>(A, rowStats, threshold, rows, cols); + kgetRowStats<<>>(A, rowStats, threshold, rows, cols); else - kgetRowStats<<>>(A, rowStats, threshold, rows, cols); + kgetRowStats<<>>(A, rowStats, threshold, rows, cols); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -795,9 +796,9 @@ template void extractOutliers(char * A, int *idx, char *out, int idx template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); -template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); +template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); +template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); template void transformRowToFormat(char * A, char *out, int rows, int cols); template void transformRowToFormat(char * A, char *out, int rows, int cols); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 558d93008..5f60051df 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -171,16 +171,16 @@ void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, i void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); -void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols); +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream); void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols); -void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols); +void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream); void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols); -void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); +void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream); template void transformRowToFormat(char * A, char *out, int rows, int cols); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 0400d9b48..441d3adef 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -175,14 +175,14 @@ void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRo void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } -int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { - return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); +int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) { + return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); } -int igemmlt_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { - return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); +int igemmlt_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) { + return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); } -int igemmlt_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { - return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); +int igemmlt_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) { + return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); } void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) @@ -308,14 +308,14 @@ extern "C" Context *get_context(){ return new Context(); } ContextCusparse *get_cusparse(){ return new ContextCusparse(); } - int cigemmlt_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { - return igemmlt_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + int cigemmlt_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) { + return igemmlt_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); } - int cigemmlt_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { - return igemmlt_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + int cigemmlt_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) { + return igemmlt_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); } - int cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { - return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + int cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) { + return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); } #define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ @@ -333,15 +333,15 @@ extern "C" MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8) MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32) - void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols) - { dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols); } + void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream) + { dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream); } void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) { getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); } - void cget_row_stats(half *A, float *rowStats, float threshold, int rows, int cols) { - getRowStats(A, rowStats, threshold, rows, cols); + void cget_row_stats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) { + getRowStats(A, rowStats, threshold, rows, cols, stream); } - void cint8_vector_quant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols) { - int8VectorQuant(A, out, rowStats, threshold, rows, cols); + void cint8_vector_quant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) { + int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream); } void cdouble_rowcol_quant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_row_ptr, float threshold, int rows, int cols) { doubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_row_ptr, threshold, rows, cols); } From 01bf54eaa41b1b5c3d433321fd7a697411a71e3f Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 23 Oct 2024 22:09:45 -0400 Subject: [PATCH 11/65] perf: reduce overhead from getting cudaStream ptr --- bitsandbytes/functional.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 681e06f08..1907cd0f0 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -441,8 +441,9 @@ def is_on_gpu(tensors: Iterable[torch.Tensor]): return on_gpu -def get_tensor_stream(tensor: Tensor) -> torch.cuda.Stream: - return torch.cuda.current_stream(tensor.device) +def get_tensor_stream(tensor: Tensor) -> ct.c_void_p: + # We use the raw stream for performance reasons. + return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index)) def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: From 32979b49bbeda130e14dff6cc737eb22f1d743d2 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Thu, 24 Oct 2024 17:03:39 -0400 Subject: [PATCH 12/65] Mark some functions for deprecation. --- bitsandbytes/autograd/_functions.py | 16 +++++----- bitsandbytes/cextension.py | 2 +- bitsandbytes/functional.py | 49 ++++++++++++++++++++++++----- setup.py | 4 +-- 4 files changed, 53 insertions(+), 18 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index e4a740301..c654a0254 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -5,6 +5,7 @@ from warnings import warn import torch +from typing_extensions import deprecated import bitsandbytes.functional as F @@ -97,6 +98,10 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) - return outputs.reshape(rows, cols).contiguous() +@deprecated( + "MatMul8bit is deprecated and will be removed in a future release. Please use MatMul8bitLt instead.", + category=FutureWarning, +) class MatMul8bit(torch.autograd.Function): @staticmethod def forward(ctx, A, B, out=None, quant_type="vector", precision=None): @@ -208,6 +213,7 @@ def backward(ctx, grad_output): matmul_cublas = MatMul8bit.apply +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def supports_igemmlt(device: torch.device) -> bool: """check if this device supports the optimized int8 kernel""" if torch.cuda.get_device_capability(device=device) < (7, 5): @@ -219,6 +225,7 @@ def supports_igemmlt(device: torch.device) -> bool: return True +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def _get_tile_size(format): assert format in ( "col_turing", @@ -227,6 +234,7 @@ def _get_tile_size(format): return (8, 32) if format == "col_turing" else (32, 32) +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def get_tile_inds(format, device): transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device) with torch.no_grad(): @@ -331,14 +339,6 @@ def forward( # 2. Quantize B state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16)) - # ( - # state.CB, - # state.CBt, - # state.SCB, - # state.SCBt, - # _, - # ) = F.double_quant(B.to(torch.float16)) - if state.threshold > 0.0 and coo_tensorA is not None: state.idx = torch.unique(coo_tensorA._indices()[1]).long() diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 5bed7fba4..8b0a0c91d 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -113,6 +113,6 @@ def get_native_library() -> BNBNativeLibrary: Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes -and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues +and open an issue at: https://github.com/bitsandbytes-foundation/bitsandbytes/issues """, ) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 1907cd0f0..33a5d72eb 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -10,6 +10,7 @@ import numpy as np import torch from torch import Tensor +from typing_extensions import deprecated from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict @@ -244,10 +245,12 @@ def fill(A, value, device=None, prefetch=True): elementwise_func("fill", A, None, value) +@deprecated("Function will be removed in a future release.", category=FutureWarning) def arange(A, device=None): elementwise_func("arange", A, None, 0) +@deprecated("Function will be removed in a future release.", category=FutureWarning) def _mul(A, B, device=None): elementwise_func("_mul", A, B, 0) @@ -414,7 +417,7 @@ def create_quantile_map(A, total_bits=8): return q -# TODO: Deprecate +@deprecated("This function is deprecated and will be removed in a future version.", category=FutureWarning) def get_special_format_str(): return "row" @@ -475,6 +478,10 @@ def post_call(prev_device): torch.cuda.set_device(prev_device) +@deprecated( + "The layout transformation operations will be removed in a future release. Please use row-major layout only.", + category=FutureWarning, +) def get_transform_func(dtype, orderA, orderOut, transpose=False): name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}' if not hasattr(lib, name): @@ -486,6 +493,10 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False): return getattr(lib, name) +@deprecated( + "The layout transformation operations will be removed in a future release. Please use row-major layout only.", + category=FutureWarning, +) def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False): # init_func = torch.empty init_func = torch.zeros @@ -525,6 +536,10 @@ def get_transform_buffer(shape, dtype, device, to_order, from_order="row", trans raise NotImplementedError(f"To_order not supported: {to_order}") +@deprecated( + "The layout transformation operations will be removed in a future release. Please use row-major layout only.", + category=FutureWarning, +) def nvidia_transform( A, to_order, @@ -1424,6 +1439,7 @@ def dequantize_4bit( return out +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def quantize( A: Tensor, code: Optional[torch.Tensor] = None, @@ -1443,6 +1459,7 @@ def quantize( return out, (absmax, code) +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def dequantize( A: Tensor, state: Optional[Tuple[Tensor, Tensor]] = None, @@ -1463,6 +1480,7 @@ def dequantize( return out * state[0] +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: """ Quantizes input tensor to 8-bit. @@ -1493,6 +1511,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No return out +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: """ Dequantizes the 8-bit tensor to 32-bit. @@ -1627,6 +1646,11 @@ def optimizer_update_32bit( post_call(prev_device) +@deprecated( + "This function is deprecated and will be removed in a future release. " + "Please use optimizer_update_8bit_blockwise instead. ", + category=FutureWarning, +) def optimizer_update_8bit( optimizer_name: str, g: Tensor, @@ -1827,6 +1851,7 @@ def optimizer_update_8bit_blockwise( post_call(prev_device) +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5): """Applies percentile clipping @@ -2516,11 +2541,6 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) -def extract_outliers_new(A: torch.Tensor, threshold: float): - outlier_mask = A.abs() >= threshold - return A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo() - - def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): # TODO: Optimize/write CUDA kernel for this? # Note: for inference, use the new int8_vectorwise_quant. @@ -2576,6 +2596,10 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): return out_row, row_stats, coo_tensor +@deprecated( + "The layout transformation operations will be removed in a future release. Please use row-major layout only.", + category=FutureWarning, +) def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None): prev_device = pre_call(A.device) if state is None: @@ -2772,6 +2796,11 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): C = 127.0 +@deprecated( + "This function is deprecated and will be removed in a future release. " + "Consider using `int8_vectorwise_quant` instead.", + category=FutureWarning, +) def vectorwise_quant(x, dim=1, quant_type="vector"): if quant_type == "linear": max1 = torch.abs(x).max().float() @@ -2816,6 +2845,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"): return None +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def vectorwise_dequant(xq, max1, quant_type="vector"): if quant_type == "vector": x = (xq / C * max1).to(torch.float32) @@ -2824,6 +2854,10 @@ def vectorwise_dequant(xq, max1, quant_type="vector"): return None +@deprecated( + "This function is deprecated and will be removed in a future release. Consider using `mm_dequant` instead.", + category=FutureWarning, +) def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"): if quant_type == "linear": norm = S1 * S2 / (C * C) @@ -2883,6 +2917,7 @@ def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"): return None +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): offset = B.float().t().sum(0) * (SA[0] + SA[1]) x = xq.float() @@ -2898,7 +2933,6 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): def extract_outliers(A, SA, idx): - # TODO: Implement for row-major shapeA = SA[0] formatA = SA[1] assert formatA in ["col_turing", "col_ampere"] @@ -2923,6 +2957,7 @@ def extract_outliers(A, SA, idx): return out +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def pipeline_test(A, batch_size): out = torch.zeros_like(A) lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) diff --git a/setup.py b/setup.py index 3a1bcb574..1d5b6f839 100644 --- a/setup.py +++ b/setup.py @@ -31,10 +31,10 @@ def has_ext_modules(self): description="k-bit optimizers and matrix multiplication routines.", license="MIT", keywords="gpu optimizers optimization 8-bit quantization compression", - url="https://github.com/TimDettmers/bitsandbytes", + url="https://github.com/bitsandbytes-foundation/bitsandbytes", packages=find_packages(), package_data={"": libs}, - install_requires=["torch", "numpy"], + install_requires=["torch", "numpy", "typing_extensions>=4.8.0"], extras_require={ "benchmark": ["pandas", "matplotlib"], "test": ["scipy", "lion_pytorch"], From 521da0c8c52723c17c4513d761b8e2f5f04fd18c Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 30 Oct 2024 10:21:48 -0400 Subject: [PATCH 13/65] int8 sparse decomp: small perf improvement --- bitsandbytes/autograd/_functions.py | 8 ++++---- bitsandbytes/functional.py | 14 +++++--------- bitsandbytes/research/autograd/_functions.py | 15 +++++++-------- 3 files changed, 16 insertions(+), 21 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index c654a0254..5d9983545 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -319,10 +319,10 @@ def forward( if ctx.needs_input_grad[1]: # Slower path - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) + CA, CAt, SCA, SCAt, outlier_cols = F.double_quant(A.to(torch.float16), threshold=state.threshold) else: # Fast path - CA, SCA, coo_tensorA = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold) + CA, SCA, outlier_cols = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold) CAt = SCAt = None has_grad = False @@ -339,8 +339,8 @@ def forward( # 2. Quantize B state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16)) - if state.threshold > 0.0 and coo_tensorA is not None: - state.idx = torch.unique(coo_tensorA._indices()[1]).long() + if state.threshold > 0.0 and outlier_cols is not None: + state.idx = outlier_cols # Zero out the outliers in the transposed 8bit inputs. if CAt is not None: diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 33a5d72eb..15402d7d4 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2546,7 +2546,7 @@ def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None, # Note: for inference, use the new int8_vectorwise_quant. # Use CUDA kernel for rowwise and COO tensor - quant_row, row_stats, coo_tensor = int8_vectorwise_quant(A, threshold=threshold) + quant_row, row_stats, outlier_cols = int8_vectorwise_quant(A, threshold=threshold) # PyTorch impl for colwise _, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold) @@ -2559,7 +2559,7 @@ def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None, if out_col is not None: quant_col = out_col.copy_(quant_col) - return quant_row, quant_col, row_stats, col_stats.flatten().float(), coo_tensor + return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): @@ -2574,13 +2574,9 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): if threshold > 0.0: # TODO we could improve perf of this - - # A.masked_fill(A.abs() < threshold, 0.0).to_sparse_coo() - # coo_tensor = extract_outliers_new(A, threshold) - coo_tensor = torch.masked_fill(A, A.abs() < threshold, 0.0).to_sparse_coo() - + outlier_cols = torch.argwhere((A.abs() >= threshold).any(dim=0)).view(-1) else: - coo_tensor = None + outlier_cols = None with torch.cuda.device_of(A): lib.cint8_vector_quant( @@ -2593,7 +2589,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): get_tensor_stream(A), ) - return out_row, row_stats, coo_tensor + return out_row, row_stats, outlier_cols @deprecated( diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index dd4de5df6..f8fe14c48 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -217,12 +217,11 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None): # 1. Quantize A if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) + CA, CAt, SCA, SCAt, outlier_cols = F.double_quant(A.to(torch.float16), threshold=state.threshold) - if state.threshold > 0.0 and coo_tensorA is not None: + if state.threshold > 0.0 and outlier_cols is not None: if state.has_fp16_weights: - # idx = torch.unique(coo_tensorA.colidx).long() - idx = torch.unique(coo_tensorA._indices()[1]).long() + idx = outlier_cols CA[:, idx] = 0 # CAt[:, idx] = 0 subA = A[:, idx] @@ -257,9 +256,9 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None): else: has_grad = False - if coo_tensorA is not None and not state.has_fp16_weights: + if outlier_cols is not None and not state.has_fp16_weights: # extract outliers - state.idx = torch.unique(coo_tensorA._indices()[1]).long() + state.idx = outlier_cols # outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) outliers = state.CB[:, state.idx.long()].clone() @@ -287,7 +286,7 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None): output = output.to(A.dtype).add_(bias) # 4. Mixed-precision decomposition matmul - if coo_tensorA is not None and subA is not None: + if outlier_cols is not None and subA is not None: output += torch.matmul(subA, state.subB) # 5. Save state @@ -327,7 +326,7 @@ def backward(ctx, grad_output): if len(grad_output.shape) == 3: grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) + Cgrad, Cgradt, SCgrad, SCgradt, outlier_cols = F.double_quant(grad_output.to(torch.float16)) if req_gradB: # print('back A shape', A.shape) From 217cf8e65d3d423fba2733b282c39ab89f620f49 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 30 Oct 2024 12:49:09 -0400 Subject: [PATCH 14/65] update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1d5b6f839..096b434fb 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ def read(fname): - return open(os.path.join(os.path.dirname(__file__), fname)).read() + return open(os.path.join(os.path.dirname(__file__), fname), encoding="utf8").read() # Tested with wheel v0.29.0 From b9cb5c93f88fcba1c62ce7af3ae09eb6dcb81197 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Thu, 31 Oct 2024 10:39:03 -0400 Subject: [PATCH 15/65] Update bitsandbytes/autograd/_functions.py Co-authored-by: Aarni Koskela --- bitsandbytes/autograd/_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 5d9983545..d0963a1e9 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -291,7 +291,7 @@ def forward( B: torch.Tensor, out=None, bias: Optional[torch.Tensor] = None, - state: MatmulLtState = None, + state: Optional[MatmulLtState] = None, ): state = state or MatmulLtState() From 6fa79050d32fa76c94739adca78221edc97f3abe Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Thu, 31 Oct 2024 10:45:07 -0400 Subject: [PATCH 16/65] Update bitsandbytes/functional.py Co-authored-by: Aarni Koskela --- bitsandbytes/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index d09463a64..d07162219 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -422,7 +422,7 @@ def get_special_format_str(): return "row" -def is_on_gpu(tensors: Iterable[torch.Tensor]): +def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]): on_gpu = True gpu_ids = set() From e929df032f380709b131d5ce8b2c00f5780ca5b2 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Thu, 31 Oct 2024 11:11:16 -0400 Subject: [PATCH 17/65] Update bitsandbytes/functional.py Co-authored-by: Aarni Koskela --- bitsandbytes/functional.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index d07162219..9aadfcada 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2347,7 +2347,12 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32): ptrB = get_ptr(B) ptrC = get_ptr(out) ptrRowScale = get_ptr(None) - m, n, k, lda, ldb, ldc = map(ct.c_int32, (m, n, k, lda, ldb, ldc)) + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) stream = get_tensor_stream(A) if dtype == torch.int32: From c7b31df34502d8c28a7e5d81f39ed85284a8930f Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Thu, 31 Oct 2024 11:40:47 -0400 Subject: [PATCH 18/65] Update bitsandbytes/research/autograd/_functions.py Co-authored-by: Aarni Koskela --- bitsandbytes/research/autograd/_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index f8fe14c48..5b31cca09 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -186,7 +186,7 @@ class SwitchBackBnb(torch.autograd.Function): @staticmethod # TODO: the B008 on the line below is a likely bug; the current implementation will # have each SwitchBackBnb instance share a single MatmulLtState instance!!! - def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None): + def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = None): state = state or MatmulLtState() # default to pytorch behavior if inputs are empty From 57300e74a8cf70b7ccf7b44829de62cd0b26231f Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 4 Nov 2024 12:25:43 -0500 Subject: [PATCH 19/65] int8 - perf improvement for sparse decomposition inference; deprecate get_tensor_stream() in favor of new private fn --- bitsandbytes/functional.py | 68 ++++++++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 24 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 9aadfcada..536786f41 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -444,7 +444,12 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]): return on_gpu -def get_tensor_stream(tensor: Tensor) -> ct.c_void_p: +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) +def get_tensor_stream(tensor: Tensor) -> torch.cuda.Stream: + return torch.cuda.current_stream(tensor.device) + + +def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p: # We use the raw stream for performance reasons. return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index)) @@ -468,12 +473,14 @@ def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: return ct.c_void_p(A.data_ptr()) +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def pre_call(device): prev_device = torch.cuda.current_device() torch.cuda.set_device(device) return prev_device +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def post_call(prev_device): torch.cuda.set_device(prev_device) @@ -1004,7 +1011,7 @@ def dequantize_blockwise( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()), - get_tensor_stream(A), + _get_tensor_stream(A), ) else: code = quant_state.code.cpu() @@ -1360,7 +1367,7 @@ def dequantize_4bit( n = out.numel() is_on_gpu([A, absmax, out]) - stream = get_tensor_stream(A) + stream = _get_tensor_stream(A) if out.dtype == torch.float32: if quant_state.quant_type == "fp4": with torch.cuda.device_of(A): @@ -1537,7 +1544,7 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = if out is None: out = torch.zeros_like(A, dtype=torch.float32) is_on_gpu([code, A, out]) - stream = get_tensor_stream(A) + stream = _get_tensor_stream(A) lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()), stream) post_call(prev_device) return out @@ -2039,7 +2046,7 @@ def gemv_4bit( lda = ct.c_int32(lda) ldb = ct.c_int32(ldb) ldc = ct.c_int32(ldc) - stream = get_tensor_stream(A) + stream = _get_tensor_stream(A) with torch.cuda.device_of(A): if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: @@ -2097,8 +2104,6 @@ def gemv_4bit( else: raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") - # post_call(prev_device) - return out @@ -2299,7 +2304,7 @@ def batched_igemm( return out -def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32): +def igemmlt(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32): # # To use the IMMA tensor core kernels without special Turing/Ampere layouts, # cublasLt has some rules, namely: A must be transposed, B must not be transposed. @@ -2322,15 +2327,13 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32): assert A.ndim == 2, "Only two dimensional matrices are supported for argument B" assert B.ndim in [2, 3], "Only two or three dimensional matrices are supported for argument A" assert prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}" + assert out is None or out.dtype == dtype shapeC = (*shapeB[:-1], shapeA[0]) - Sout = (shapeC, "row") if out is None: out = torch.empty(shapeC, device=A.device, dtype=dtype) - assert out.dtype == dtype - k, m = shapeA n = prod(shapeB[:-1]) lda = shapeA[-1] # Weights (outputs, inputs) @@ -2353,7 +2356,7 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32): lda = ct.c_int32(lda) ldb = ct.c_int32(ldb) ldc = ct.c_int32(ldc) - stream = get_tensor_stream(A) + stream = _get_tensor_stream(A) if dtype == torch.int32: has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) @@ -2366,22 +2369,19 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32): if has_error: raise RuntimeError( f"cublasLt ran into an error!\n" - f"\tA: {shapeA}, B: {shapeB}, C: {Sout[0]}\n" + f"\tA: {shapeA}, B: {shapeB}, C: {shapeC}\n" f"\t(lda, ldb, ldc): {(lda, ldb, ldc)}\n" f"\t(m, n, k): {(m, n, k)}" ) - return out, Sout + return out -def mm_dequant( +def int8_mm_dequant( A: torch.Tensor, - quant_state: Optional[Tuple[torch.Size, str]], # TODO: deprecate. (shape, format) row_stats: torch.Tensor, col_stats: torch.Tensor, out: Optional[torch.Tensor] = None, - new_row_stats=None, # TODO: unused - new_col_stats=None, # TODO: unused bias: Optional[torch.Tensor] = None, ): assert A.dtype == torch.int32 @@ -2404,12 +2404,26 @@ def mm_dequant( with torch.cuda.device_of(A): lib.cdequant_mm_int32_fp16( - ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, get_tensor_stream(A) + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) ) return out +@deprecated("mm_dequant is deprecated. Please use int8_mm_dequant() instead.", category=FutureWarning) +def mm_dequant( + A: torch.Tensor, + quant_state: Optional[Tuple[torch.Size, str]], # Not used + row_stats: torch.Tensor, + col_stats: torch.Tensor, + out: Optional[torch.Tensor] = None, + new_row_stats=None, # Not used + new_col_stats=None, # Not used + bias: Optional[torch.Tensor] = None, +): + return int8_mm_dequant(A, row_stats, col_stats, out, bias) + + def get_colrow_absmax( A: torch.Tensor, row_stats: torch.Tensor = None, @@ -2459,7 +2473,7 @@ def get_row_absmax(A, threshold=0.0): ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols), - get_tensor_stream(A), + _get_tensor_stream(A), ) return row_stats @@ -2577,11 +2591,16 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) + outlier_cols = None + if threshold > 0.0: # TODO we could improve perf of this - outlier_cols = torch.argwhere((A.abs() >= threshold).any(dim=0)).view(-1) - else: - outlier_cols = None + outliers = A.abs() >= threshold + + # argwhere needs host/device sync, so we skip when + # there aren't actually any outliers. + if outliers.any(): + outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) with torch.cuda.device_of(A): lib.cint8_vector_quant( @@ -2591,7 +2610,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols), - get_tensor_stream(A), + _get_tensor_stream(A), ) return out_row, row_stats, outlier_cols @@ -2933,6 +2952,7 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): return x.to(dtype) +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def extract_outliers(A, SA, idx): shapeA = SA[0] formatA = SA[1] From 0460d2e0d20adba321794b98d343105412a444b4 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 4 Nov 2024 12:28:23 -0500 Subject: [PATCH 20/65] int8 cleanup --- bitsandbytes/autograd/_functions.py | 16 +++++---- bitsandbytes/nn/modules.py | 17 ++------- bitsandbytes/research/autograd/_functions.py | 9 +++-- tests/test_functional.py | 37 ++++++++++---------- 4 files changed, 35 insertions(+), 44 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index d0963a1e9..d40b3f706 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -289,7 +289,7 @@ def forward( ctx: torch.autograd.function.FunctionCtx, A: torch.Tensor, B: torch.Tensor, - out=None, + out: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, state: Optional[MatmulLtState] = None, ): @@ -339,7 +339,9 @@ def forward( # 2. Quantize B state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16)) - if state.threshold > 0.0 and outlier_cols is not None: + # Handle sparse decomposition. In some instances, we may have not found any + # outlier columns at all. In that case, we'll skip this part completely. + if state.threshold > 0.0 and outlier_cols is not None and outlier_cols.numel(): state.idx = outlier_cols # Zero out the outliers in the transposed 8bit inputs. @@ -359,13 +361,13 @@ def forward( subA = None # 3. Int8 Matmul - out32, Sout32 = F.igemmlt(CA, state.CB) + out32 = F.igemmlt(CA, state.CB) if bias is None or bias.dtype == torch.float16: # we apply the fused bias here - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias).to(A.dtype) + output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype) else: # apply bias separately # TODO: Fused bias for fp32/bf16? - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None).to(A.dtype).add_(bias) + output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=None).to(A.dtype).add_(bias) # 4. Mixed-precision decomposition matmul if subA is not None and state.subB is not None: @@ -420,8 +422,8 @@ def backward(ctx, grad_output): if req_gradB: Cgrad, _, _, SCgradt, _ = F.double_quant(grad_output.to(torch.float16)) - gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t()) - grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) + gradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t()) + grad_B = F.int8_mm_dequant(gradB32, SCgradt, SCAt) if state.threshold > 0.0 and subA is not None: grad_B[:, idx] += torch.matmul(grad_output.t(), subA) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 1ab5cacaf..17af8bdf5 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -16,7 +16,6 @@ from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.utils import ( INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, - LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer, ) @@ -923,29 +922,19 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): param_from_weight = getattr(self.weight, scb_name) # case 2: self.init_8bit_state was called, SCB is in self.state param_from_state = getattr(self.state, scb_name) - # case 3: SCB is in self.state, weight layout reordered after first forward() - layout_reordered = self.state.CxB is not None key_name = prefix + f"{scb_name}" + + # We now only save in row-major. This format information is stored for backwards compatibility. format_name = prefix + "weight_format" if not self.state.has_fp16_weights: if param_from_weight is not None: destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach() destination[format_name] = torch.tensor(0, dtype=torch.uint8) - elif param_from_state is not None and not layout_reordered: - destination[key_name] = param_from_state if keep_vars else param_from_state.detach() - destination[format_name] = torch.tensor(0, dtype=torch.uint8) elif param_from_state is not None: destination[key_name] = param_from_state if keep_vars else param_from_state.detach() - weights_format = self.state.formatB - # At this point `weights_format` is an str - if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING: - raise ValueError(f"Unrecognized weights format {weights_format}") - - weights_format = LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format] - - destination[format_name] = torch.tensor(weights_format, dtype=torch.uint8) + destination[format_name] = torch.tensor(0, dtype=torch.uint8) def _load_from_state_dict( self, diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index f8fe14c48..390f54fed 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -275,15 +275,14 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None): output_shape = (input_shape[0], shapeB[0]) # 3. Matmul - out32, Sout32 = F.igemmlt(CA, state.CB) + out32 = F.igemmlt(CA, state.CB) # we apply the fused bias here if bias is None or bias.dtype == torch.float16: - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) - output = output.to(A.dtype) + output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype) else: # apply bias separately - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) - output = output.to(A.dtype).add_(bias) + output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=None).to(A.dtype) + output.add_(bias) # 4. Mixed-precision decomposition matmul if outlier_cols is not None and subA is not None: diff --git a/tests/test_functional.py b/tests/test_functional.py index 34dbf56fd..47e8c2b75 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -589,7 +589,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) C1 = torch.matmul(A.float(), B.t().float()) - C2, SC = F.igemmlt(A, B) + C2 = F.igemmlt(A, B) torch.testing.assert_close(C1, C2.float()) # transpose @@ -623,8 +623,8 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B) - out1_32, Sout1_32 = F.igemmlt(CA, CB) - output = F.mm_dequant(out1_32, Sout1_32, statsA, statsB) + out1_32 = F.igemmlt(CA, CB) + output = F.int8_mm_dequant(out1_32, statsA, statsB) # print('') # print(output.flatten()[:10]) @@ -822,7 +822,7 @@ def test_dequant_mm(dim1, dim4, dims, has_bias): A1, maxA = F.vectorwise_quant(A, dim=1) B1, maxB = F.vectorwise_quant(B, dim=1) - C2, SC = F.igemmlt(A1, B1) + C2 = F.igemmlt(A1, B1) C4 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t()) if has_bias: @@ -837,7 +837,7 @@ def test_dequant_mm(dim1, dim4, dims, has_bias): # assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06)) # assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" - C5 = F.mm_dequant(C2, SC, maxA, maxB, bias=bias) + C5 = F.int8_mm_dequant(C2, maxA, maxB, bias=bias) C5 /= std torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1) n = C5.numel() @@ -947,9 +947,9 @@ def test_integrated_igemmlt(dim1, dim4, inner): torch.testing.assert_close(C1a, A1, rtol=0, atol=1) torch.testing.assert_close(C2a, B1, rtol=0, atol=1) - out2, SC = F.igemmlt(A1, B1) + out2 = F.igemmlt(A1, B1) - C2, SC = F.igemmlt(A1, B1) + C2 = F.igemmlt(A1, B1) out3 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t()) @@ -991,8 +991,9 @@ def test_igemmlt_row_scale(dim1, dim4, inner): c = 10.0 * inner * scale row_scale = torch.ones_like(maxA) / c - outC32, SC = F.igemmlt(A2, B2, dtype=torch.int8, row_scale=row_scale) - C3, S = F.nvidia_transform(outC32, "row", state=SC) + outC32 = F.igemmlt(A2, B2, dtype=torch.int8, row_scale=row_scale) + # C3, S = F.nvidia_transform(outC32, "row", state=SC) + C3 = outC32 maxval = torch.abs(C3).max() if maxval == 127: scale = 1.5 @@ -1004,8 +1005,8 @@ def test_igemmlt_row_scale(dim1, dim4, inner): C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) B2, SB = F.nvidia_transform(C2a, formatB) - outC32, SC = F.igemmlt(A2, B2) - out2 = F.mm_dequant(outC32, SC, stats1a, stats2a) + outC32 = F.igemmlt(A2, B2) + out2 = F.int8_mm_dequant(outC32, stats1a, stats2a) CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector") CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear") @@ -1072,7 +1073,7 @@ def test_row_scale_bench(dim1, dim4, inner): torch.cuda.synchronize() t0 = time.time() for i in range(k): - outC32, SC = F.igemmlt(A2, B2, dtype=torch.int8, row_scale=row_scale) + outC32 = F.igemmlt(A2, B2, dtype=torch.int8, row_scale=row_scale) torch.cuda.synchronize() print("row-wise", time.time() - t0) @@ -1081,7 +1082,7 @@ def test_row_scale_bench(dim1, dim4, inner): torch.cuda.synchronize() t0 = time.time() for i in range(k): - outC32, SC = F.igemmlt(A2, B2) + outC32 = F.igemmlt(A2, B2) torch.cuda.synchronize() print("vector-wise", time.time() - t0) @@ -1262,13 +1263,13 @@ def test_integrated_sparse_decomp(dim1, dim2): Cw1, statsw1, coo_tensor = F.int8_vectorwise_quant(w1) CA, statsA, coo_tensor = F.int8_vectorwise_quant(A) - out1_32, Sout1_32 = F.igemmlt(CA, Cw1) - out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) + out1_32 = F.igemmlt(CA, Cw1) + out2 = F.int8_mm_dequant(out1_32, statsA, statsw1) CA, statsA, coo_tensor = F.int8_vectorwise_quant(A, threshold=threshold) - out1_32, Sout1_32 = F.igemmlt(CA, Cw1) - out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) + out1_32 = F.igemmlt(CA, Cw1) + out3 = F.int8_mm_dequant(out1_32, statsA, statsw1) assert coo_tensor is not None @@ -1599,7 +1600,7 @@ def test_bench_matmul(batch, seq, model, hidden): t0 = time.time() for i in range(iters): # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) - out32, Sout32 = F.igemmlt(CA, CB) + out32 = F.igemmlt(CA, CB) torch.cuda.synchronize() print( f"no overhead igemmlt [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" From 762daf49e3d9ec4df64d08266e0bba81ce818b8f Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 4 Nov 2024 12:29:55 -0500 Subject: [PATCH 21/65] Ignore ruff rule ISC001 (incompatible with formatter) --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 271edc84e..a51fc0c0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ ignore = [ "E731", # Do not use lambda "F841", # Local assigned but not used (TODO: enable, these are likely bugs) "RUF012", # Mutable class attribute annotations + "ISC001", # single-line-implicit-string-concatenation incompatible with formatter ] [tool.ruff.lint.extend-per-file-ignores] From 875414ea63bc572358e07699058e9dc18343773a Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 4 Nov 2024 12:33:32 -0500 Subject: [PATCH 22/65] add comment --- bitsandbytes/autograd/_functions.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index d40b3f706..7b1cccb23 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -353,9 +353,13 @@ def forward( # Extract the corresponding weights if state.has_fp16_weights: - state.subB = B[:, state.idx].t() # .contiguous() + state.subB = B[:, state.idx].t() else: - outliers = state.CB[:, state.idx] # .clone() + outliers = state.CB[:, state.idx] + + # To dequantize our weights associated with the input outliers, + # we want to divide by 127. It's however more performant to multiply + # by the reciprocal. state.subB = (7.874016e-3 * outliers * state.SCB.view(-1, 1)).t().to(A.dtype) else: subA = None From 0aefeb0b9da3c71aced0c15061d488acbe8023d0 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 4 Nov 2024 12:39:21 -0500 Subject: [PATCH 23/65] int8 more cleanup --- bitsandbytes/autograd/_functions.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 7b1cccb23..d24d8aac4 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -264,7 +264,7 @@ class MatmulLtState: has_fp16_weights = True memory_efficient_backward = False use_pool = False - formatB = "row" # F.get_special_format_str() TODO: Deprecate/remove + formatB = "row" # TODO: Deprecate/remove def reset_grads(self): self.CB = None @@ -394,9 +394,9 @@ def forward( output_shape = (*input_shape[:-1], state.CB.shape[0]) if len(input_shape) == 3: - return output.reshape(output_shape) # .clone() - else: - return output + return output.reshape(output_shape) + + return output @staticmethod def backward(ctx, grad_output): @@ -418,11 +418,6 @@ def backward(ctx, grad_output): if len(grad_output.shape) == 3: grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - # if req_gradB: - # grad_B = torch.matmul(grad_output.t(), A) - # if state.threshold > 0.0 and subA is not None: - # grad_B[:, idx] += torch.matmul(grad_output.t(), subA) - # Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16)) if req_gradB: Cgrad, _, _, SCgradt, _ = F.double_quant(grad_output.to(torch.float16)) @@ -432,15 +427,11 @@ def backward(ctx, grad_output): grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: - # grad_output @ B.T - # if state.CBt is not None: - # gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t()) - # grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) if state.CB is not None: CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) else: - raise Exception("State must contain either CBt or CB matrix for backward") + raise Exception("State must contain CB matrix for backward") return grad_A, grad_B, None, grad_bias, None From bfb42d1b72f5aa32ea2f1387020542774474dca7 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 4 Nov 2024 12:39:38 -0500 Subject: [PATCH 24/65] Update bitsandbytes/functional.py Co-authored-by: Aarni Koskela --- bitsandbytes/functional.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 536786f41..2a0f38510 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1439,11 +1439,9 @@ def dequantize_4bit( else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - is_transposed = A.shape[0] == 1 - if is_transposed: + if A.shape[0] == 1: # is transposed, transpose back return out.t() - else: - return out + return out @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) From bf002dbece3dac702dc91e4ed009351b27064d8d Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 4 Nov 2024 13:14:14 -0500 Subject: [PATCH 25/65] int8: rename / deprecate old fn signatures --- bitsandbytes/autograd/_functions.py | 4 +- bitsandbytes/functional.py | 23 +++++- bitsandbytes/research/autograd/_functions.py | 2 +- tests/test_functional.py | 76 ++++++-------------- 4 files changed, 48 insertions(+), 57 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index d24d8aac4..79619cf74 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -365,7 +365,7 @@ def forward( subA = None # 3. Int8 Matmul - out32 = F.igemmlt(CA, state.CB) + out32 = F.int8_linear_matmul(CA, state.CB) if bias is None or bias.dtype == torch.float16: # we apply the fused bias here output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype) @@ -421,7 +421,7 @@ def backward(ctx, grad_output): if req_gradB: Cgrad, _, _, SCgradt, _ = F.double_quant(grad_output.to(torch.float16)) - gradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t()) + gradB32 = F.int8_linear_matmul(Cgrad.t().contiguous(), CAt.t()) grad_B = F.int8_mm_dequant(gradB32, SCgradt, SCAt) if state.threshold > 0.0 and subA is not None: grad_B[:, idx] += torch.matmul(grad_output.t(), subA) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 2a0f38510..5840e38be 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2302,7 +2302,28 @@ def batched_igemm( return out -def igemmlt(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32): +@deprecated( + "igemmlt is deprecated and will be removed in a future release. " "Please use int8_linear_matmul instead.", + category=FutureWarning, +) +def igemmlt( + A: torch.Tensor, + B: torch.Tensor, + SA: Tuple[torch.Size, str], + SB: Tuple[torch.Size, str], + out: Optional[torch.Tensor] = None, + Sout: Optional[Tuple[torch.Size, str]] = None, + dtype=torch.int32, +): + if SA is not None and SA[1] != "row": + raise NotImplementedError(f"Only row-major format inputs are supported, but got format `{SA[1]}`") + if SB is not None and SB[1] != "row": + raise NotImplementedError(f"Only row-major format is supported for matrix B, but got format `{SB[1]}`") + result = int8_linear_matmul(A, B, out=out, dtype=dtype) + return result, (result.shape, "row") + + +def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32): # # To use the IMMA tensor core kernels without special Turing/Ampere layouts, # cublasLt has some rules, namely: A must be transposed, B must not be transposed. diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index 0a4b2a0fd..44fe79d82 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -275,7 +275,7 @@ def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = Non output_shape = (input_shape[0], shapeB[0]) # 3. Matmul - out32 = F.igemmlt(CA, state.CB) + out32 = F.int8_linear_matmul(CA, state.CB) # we apply the fused bias here if bias is None or bias.dtype == torch.float16: diff --git a/tests/test_functional.py b/tests/test_functional.py index 47e8c2b75..ecc2261fa 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -570,17 +570,13 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans torch.testing.assert_close(A, out2) -# @pytest.mark.parametrize("dim1", get_test_dims(1, 256, n=1), ids=id_formatter("dim1")) -# @pytest.mark.parametrize("dim2", get_test_dims(32, 512, n=1), ids=id_formatter("dim2")) -# @pytest.mark.parametrize("dim3", get_test_dims(32, 1024, n=1), ids=id_formatter("dim3")) -# @pytest.mark.parametrize("dim4", get_test_dims(32, 1024, n=1), ids=id_formatter("dim4")) @pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [512], ids=id_formatter("dim3")) @pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) @pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) -def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): +def test_int8_linear_matmul(dim1, dim2, dim3, dim4, dims, ldb): for i in range(k): if dims == 2: A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8) @@ -589,26 +585,16 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) C1 = torch.matmul(A.float(), B.t().float()) - C2 = F.igemmlt(A, B) + C2 = F.int8_linear_matmul(A, B) torch.testing.assert_close(C1, C2.float()) - # transpose - # B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8) - # C1 = torch.matmul(A.float(), B.float()) - - # B2t, SBt = F.transform(B, "col", transpose=True) - # C2, SC = F.igemmlt(A2, B2t, SA, SBt) #B2t, A2, SBt, SA) - # C3, S = F.nvidia_transform(C2, "row", state=SC) - # torch.testing.assert_close(C1, C2.float()) - @pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) @pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) -def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): - formatB = F.get_special_format_str() +def test_int8_linear_matmul_half(dim1, dim2, dim3, dim4, dims): for i in range(k): if dims == 2: A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half() @@ -617,31 +603,15 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): B = torch.randn((dim4, dim3), device="cuda").half() torch.nn.init.xavier_uniform_(B) C1 = torch.matmul(A, B.t()) - C2 = bnb.matmul(A, B.t()) A = A.view(-1, A.shape[-1]) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) - CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B) - out1_32 = F.igemmlt(CA, CB) - output = F.int8_mm_dequant(out1_32, statsA, statsB) - - # print('') - # print(output.flatten()[:10]) - # print(C1.flatten()[:10]) - # print(C2.flatten()[:10]) + CA, _, statsA, _, _ = F.double_quant(A) + CB, _, statsB, _, _ = F.int8_vectorwise_quant(B) + output = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), statsA, statsB) torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) - # transpose - # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8) - # C1 = torch.matmul(A.float(), B.float()) - - # B2t, SBt = F.transform2(B, 'col_turing', transpose=True) - # C2, SC = F.igemmlt(A2, B2t, SA, SBt) - # C3, S = F.transform(C2, 'row', state=SC) - # torch.testing.assert_close(C1, C3.float()) - @pytest.mark.parametrize( ("batch", "seq", "model", "hidden"), @@ -822,7 +792,7 @@ def test_dequant_mm(dim1, dim4, dims, has_bias): A1, maxA = F.vectorwise_quant(A, dim=1) B1, maxB = F.vectorwise_quant(B, dim=1) - C2 = F.igemmlt(A1, B1) + C2 = F.int8_linear_matmul(A1, B1) C4 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t()) if has_bias: @@ -930,15 +900,15 @@ def test_double_quant(dim1, dim2): ) ), ) -def test_integrated_igemmlt(dim1, dim4, inner): +def test_integrated_int8_linear_matmul(dim1, dim4, inner): for i in range(k): A = torch.randn(dim1, inner, device="cuda").half() B = torch.randn(dim4, inner, device="cuda").half() out1 = torch.matmul(A.half(), B.t().half()) - C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A) - C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) + C1a, stats1a, _ = F.int8_vectorwise_quant(A) + C2a, stats2a, _ = F.int8_vectorwise_quant(B) A1, maxA = F.vectorwise_quant(A, dim=1) B1, maxB = F.vectorwise_quant(B, dim=1) @@ -947,9 +917,9 @@ def test_integrated_igemmlt(dim1, dim4, inner): torch.testing.assert_close(C1a, A1, rtol=0, atol=1) torch.testing.assert_close(C2a, B1, rtol=0, atol=1) - out2 = F.igemmlt(A1, B1) + out2 = F.int8_linear_matmul(A1, B1) - C2 = F.igemmlt(A1, B1) + C2 = F.int8_linear_matmul(A1, B1) out3 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t()) @@ -991,7 +961,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner): c = 10.0 * inner * scale row_scale = torch.ones_like(maxA) / c - outC32 = F.igemmlt(A2, B2, dtype=torch.int8, row_scale=row_scale) + outC32 = F.int8_linear_matmul(A2, B2, dtype=torch.int8, row_scale=row_scale) # C3, S = F.nvidia_transform(outC32, "row", state=SC) C3 = outC32 maxval = torch.abs(C3).max() @@ -1005,7 +975,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner): C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) B2, SB = F.nvidia_transform(C2a, formatB) - outC32 = F.igemmlt(A2, B2) + outC32 = F.int8_linear_matmul(A2, B2) out2 = F.int8_mm_dequant(outC32, stats1a, stats2a) CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector") @@ -1073,7 +1043,7 @@ def test_row_scale_bench(dim1, dim4, inner): torch.cuda.synchronize() t0 = time.time() for i in range(k): - outC32 = F.igemmlt(A2, B2, dtype=torch.int8, row_scale=row_scale) + outC32 = F.int8_linear_matmul(A2, B2, dtype=torch.int8, row_scale=row_scale) torch.cuda.synchronize() print("row-wise", time.time() - t0) @@ -1082,7 +1052,7 @@ def test_row_scale_bench(dim1, dim4, inner): torch.cuda.synchronize() t0 = time.time() for i in range(k): - outC32 = F.igemmlt(A2, B2) + outC32 = F.int8_linear_matmul(A2, B2) torch.cuda.synchronize() print("vector-wise", time.time() - t0) @@ -1129,7 +1099,7 @@ def test_overflow(): # Cb, Sb = F.nvidia_transform(b, formatB) # c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8) - c = F.igemmlt(a, b, dtype=torch.int8) + c = F.int8_linear_matmul(a, b, dtype=torch.int8) c2 = torch.matmul(a.float(), b.float().t()) @@ -1263,12 +1233,12 @@ def test_integrated_sparse_decomp(dim1, dim2): Cw1, statsw1, coo_tensor = F.int8_vectorwise_quant(w1) CA, statsA, coo_tensor = F.int8_vectorwise_quant(A) - out1_32 = F.igemmlt(CA, Cw1) + out1_32 = F.int8_linear_matmul(CA, Cw1) out2 = F.int8_mm_dequant(out1_32, statsA, statsw1) CA, statsA, coo_tensor = F.int8_vectorwise_quant(A, threshold=threshold) - out1_32 = F.igemmlt(CA, Cw1) + out1_32 = F.int8_linear_matmul(CA, Cw1) out3 = F.int8_mm_dequant(out1_32, statsA, statsw1) assert coo_tensor is not None @@ -1594,16 +1564,16 @@ def test_bench_matmul(batch, seq, model, hidden): f"B -> CB + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) - CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) + CA, SCA, _ = F.int8_vectorwise_quant(A, threshold=0.0) + CB, SCB, _ = F.int8_vectorwise_quant(B) torch.cuda.synchronize() t0 = time.time() for i in range(iters): # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) - out32 = F.igemmlt(CA, CB) + out32 = F.int8_linear_matmul(CA, CB) torch.cuda.synchronize() print( - f"no overhead igemmlt [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + f"no overhead int8 [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) # C32A, SA = F.transform(CA, "col32") From 7f6fb60dc17a4e26d35b029be55da5a000a7be3d Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 4 Nov 2024 13:15:45 -0500 Subject: [PATCH 26/65] Update bitsandbytes/functional.py Co-authored-by: Aarni Koskela --- bitsandbytes/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 5840e38be..8cdac1225 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2011,7 +2011,7 @@ def gemv_4bit( ): # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: - raise ValueError("state cannot None. gemv_4bit() requires the state from quantize_4bit()") + raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()") if A.numel() != A.shape[-1]: raise ValueError( From 135b336c06fca95e25b14f5755d04b9bc36c8878 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 4 Nov 2024 13:18:55 -0500 Subject: [PATCH 27/65] type annotation --- bitsandbytes/functional.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 8cdac1225..388df51ce 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2579,7 +2579,14 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) -def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): +def double_quant( + A: torch.Tensor, + col_stats: Optional[torch.Tensor] = None, + row_stats: Optional[torch.Tensor] = None, + out_col: Optional[torch.Tensor] = None, + out_row: Optional[torch.Tensor] = None, + threshold=0.0, +): # TODO: Optimize/write CUDA kernel for this? # Note: for inference, use the new int8_vectorwise_quant. From 5388877ef2811748d3f8e35877edd961c48e9597 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 4 Nov 2024 13:25:19 -0500 Subject: [PATCH 28/65] format update --- bitsandbytes/nn/modules.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 17af8bdf5..f5c7f6a7c 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -481,8 +481,7 @@ def forward(self, x: torch.Tensor): bias = None if self.bias is None else self.bias.to(self.compute_dtype) - out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype) - return out + return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype) class LinearFP4(Linear4bit): From be2e98fb96a70f4ba32160b910b0f06d304aac5b Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 4 Nov 2024 13:26:46 -0500 Subject: [PATCH 29/65] Update bitsandbytes/research/autograd/_functions.py Co-authored-by: Aarni Koskela --- bitsandbytes/research/autograd/_functions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index 44fe79d82..ef2048d77 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -184,8 +184,6 @@ def backward(ctx, grad_output): class SwitchBackBnb(torch.autograd.Function): @staticmethod - # TODO: the B008 on the line below is a likely bug; the current implementation will - # have each SwitchBackBnb instance share a single MatmulLtState instance!!! def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = None): state = state or MatmulLtState() From 4c849bb909c7a8c9782386289bb2ceaf173406f3 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 4 Nov 2024 13:32:02 -0500 Subject: [PATCH 30/65] cleanup --- bitsandbytes/research/autograd/_functions.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index 44fe79d82..8ac90aa3f 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -223,7 +223,6 @@ def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = Non if state.has_fp16_weights: idx = outlier_cols CA[:, idx] = 0 - # CAt[:, idx] = 0 subA = A[:, idx] state.subB = B[:, idx].t().contiguous() state.idx = idx @@ -264,7 +263,7 @@ def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = Non outliers = state.CB[:, state.idx.long()].clone() state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) CA[:, state.idx.long()] = 0 - # CAt[:, state.idx.long()] = 0 + subA = A[:, state.idx.long()] shapeB = state.SB[0] From b9544741aaf0129e1fa0023aa008e40e5bdddbc8 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 4 Nov 2024 13:55:47 -0500 Subject: [PATCH 31/65] Add comment to explain division optimization --- csrc/kernels.cu | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index cf6bb1d8e..74b9f7c47 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -729,8 +729,11 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs valid_items_load = min(TILE_SIZE, n - i); valid_items_store = valid_items_load; } + + // Since blocksize will always be a power-of-2, we avoid more expensive + // division by the blocksize and instead use a shift operation. + // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH) >> (31 - __clz(blocksize))]); - //local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]); __syncthreads(); LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); @@ -3579,9 +3582,13 @@ template __global__ void kgemm_4bit_inferenc for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit) { const int inner_idx_halved = inner_idx/2; + + // Since blocksize will always be a power-of-2, we avoid more expensive + // division by the blocksize and instead use a shift operation. + // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. const int absidx = ((2*offset_B)+inner_idx) >> (31 - __clz(blocksize)); - //int absidx = ((2*offset_B)+inner_idx)/blocksize; - local_absmax = __ldg(&(absmax[absidx])); + + local_absmax = __ldg(&(absmax[absidx])); if(row_B < M) { From 35dbb2eb0bad404efae3492b6de4dc6d13f18280 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 4 Nov 2024 13:58:28 -0500 Subject: [PATCH 32/65] more cleanup --- bitsandbytes/functional.py | 6 +++--- bitsandbytes/research/autograd/_functions.py | 2 -- tests/test_autograd.py | 8 -------- 3 files changed, 3 insertions(+), 13 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 388df51ce..60c7a1931 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2388,9 +2388,9 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten if has_error: raise RuntimeError( f"cublasLt ran into an error!\n" - f"\tA: {shapeA}, B: {shapeB}, C: {shapeC}\n" - f"\t(lda, ldb, ldc): {(lda, ldb, ldc)}\n" - f"\t(m, n, k): {(m, n, k)}" + f"\t{shapeA=}, {shapeB=}, {shapeC=}\n" + f"\t{(lda, ldb, ldc)=}\n" + f"\t{(m, n, k)=}" ) return out diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index 73450b2bb..34eeb02de 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -256,8 +256,6 @@ def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = Non if outlier_cols is not None and not state.has_fp16_weights: # extract outliers state.idx = outlier_cols - - # outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) outliers = state.CB[:, state.idx.long()].clone() state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) CA[:, state.idx.long()] = 0 diff --git a/tests/test_autograd.py b/tests/test_autograd.py index a5ed3f823..3422550f8 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -255,14 +255,6 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec B2 = B2.t().contiguous() state.CB, state.SCB, _ = bnb.functional.int8_vectorwise_quant(B2.to(torch.float16)) - - # ( - # state.CB, - # CBt, - # state.SCB, - # SCBt, - # coo_tensorB, - # ) = bnb.functional.double_quant(B2.to(torch.float16)) B2 = state.CB if not transpose[0] and transpose[1]: From b36003fcad3eb92b52d4906902d54d1a2131b520 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 5 Nov 2024 09:23:26 -0500 Subject: [PATCH 33/65] Update bitsandbytes/functional.py Co-authored-by: Aarni Koskela --- bitsandbytes/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 60c7a1931..79721212c 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2303,7 +2303,7 @@ def batched_igemm( @deprecated( - "igemmlt is deprecated and will be removed in a future release. " "Please use int8_linear_matmul instead.", + "igemmlt is deprecated and will be removed in a future release. Please use int8_linear_matmul instead.", category=FutureWarning, ) def igemmlt( From 03a19633b5a30730d64dfb28a2fd8df60446df52 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 5 Nov 2024 09:27:53 -0500 Subject: [PATCH 34/65] Update bitsandbytes/functional.py Co-authored-by: Aarni Koskela --- bitsandbytes/functional.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 79721212c..378b3941f 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2694,7 +2694,11 @@ def transform(A, to_order, from_order="row", out=None, transpose=False, state=No return out, new_state -def spmm_coo(cooA: Union[COOSparseTensor, torch.Tensor], B: torch.Tensor, out: torch.Tensor = None): +def spmm_coo( + cooA: Union[COOSparseTensor, torch.Tensor], + B: torch.Tensor, + out: Optional[torch.Tensor] = None, +): if not isinstance(cooA, COOSparseTensor): assert ( cooA.is_sparse and cooA.layout == torch.sparse_coo From 980279f2109fc7bbcc4ff22c86ba2333329df8a1 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 5 Nov 2024 14:36:58 -0500 Subject: [PATCH 35/65] Update bitsandbytes/functional.py Co-authored-by: Aarni Koskela --- bitsandbytes/functional.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 378b3941f..fc2e6651e 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2445,9 +2445,9 @@ def mm_dequant( def get_colrow_absmax( A: torch.Tensor, - row_stats: torch.Tensor = None, - col_stats: torch.Tensor = None, - nnz_block_ptr: torch.Tensor = None, + row_stats: Optional[torch.Tensor] = None, + col_stats: Optional[torch.Tensor] = None, + nnz_block_ptr: Optional[torch.Tensor] = None, threshold=0.0, ): # Note: prior impl only works with fp16 From a72c463c7f4cc73b21dce227c1b5d8954dd02a4d Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 5 Nov 2024 15:04:41 -0500 Subject: [PATCH 36/65] cleanup --- bitsandbytes/functional.py | 264 +++++++++++++------------------------ 1 file changed, 92 insertions(+), 172 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 60c7a1931..05495fe5b 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -828,13 +828,13 @@ def __eq__(self, other): def quantize_blockwise( - A: Tensor, + A: torch.Tensor, code: Optional[torch.Tensor] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=4096, nested=False, -) -> Tuple[Tensor, QuantState]: +) -> Tuple[torch.Tensor, QuantState]: """ Quantize tensor A in blocks of size 4096 values. @@ -878,21 +878,11 @@ def quantize_blockwise( assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] code = code.to(A.device) - is_on_gpu([code, A, out, absmax]) - - fn_map = { - torch.float32: "cquantize_blockwise_fp32", - torch.bfloat16: "cquantize_blockwise_bf16", - torch.float16: "cquantize_blockwise_fp16", - } - if A.dtype not in fn_map.keys(): - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - - fn = fn_map[A.dtype] + is_on_gpu([A, out, absmax]) with torch.cuda.device_of(A): - lib[fn]( + args = ( get_ptr(code), get_ptr(A), get_ptr(absmax), @@ -901,6 +891,15 @@ def quantize_blockwise( ct.c_int(A.numel()), ) + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + else: # cpu code = code.cpu() @@ -932,14 +931,14 @@ def quantize_blockwise( def dequantize_blockwise( - A: Tensor, + A: torch.Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, code: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 4096, nested=False, -) -> Tensor: +) -> torch.Tensor: """ Dequantizes blockwise quantized values. @@ -986,25 +985,15 @@ def dequantize_blockwise( if A.device.type != "cpu": code = quant_state.code.to(A.device) - if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: + if quant_state.blocksize not in [4096, 2048, 1024, 512, 256, 128, 64]: raise ValueError( - f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]", + f"The blocksize of {quant_state.blocksize} is not supported. Supported values: [4096, 2048, 1024, 512, 256, 128, 64]", ) - is_on_gpu([A, absmax, out]) - fn_map = { - torch.float32: "cdequantize_blockwise_fp32", - torch.bfloat16: "cdequantize_blockwise_bf16", - torch.float16: "cdequantize_blockwise_fp16", - } - - if out.dtype not in fn_map.keys(): - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") - - fn = fn_map[out.dtype] + is_on_gpu([A, absmax, out]) with torch.cuda.device_of(A): - lib[fn]( + args = ( get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), @@ -1013,6 +1002,15 @@ def dequantize_blockwise( ct.c_int(A.numel()), _get_tensor_stream(A), ) + + if out.dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif out.dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif out.dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") else: code = quant_state.code.cpu() lib.cdequantize_blockwise_cpu_fp32( @@ -1110,7 +1108,7 @@ def get_4bit_type(typename, device=None, blocksize=64): def quantize_fp4( - A: Tensor, + A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, @@ -1121,7 +1119,7 @@ def quantize_fp4( def quantize_nf4( - A: Tensor, + A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, @@ -1132,14 +1130,14 @@ def quantize_nf4( def quantize_4bit( - A: Tensor, + A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_type="fp4", quant_storage=torch.uint8, -) -> Tuple[Tensor, QuantState]: +) -> Tuple[torch.Tensor, QuantState]: """ Quantize tensor A in blocks of 4-bit values. @@ -1184,71 +1182,34 @@ def quantize_4bit( assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] is_on_gpu([A, out, absmax]) - if A.dtype == torch.float32: - if quant_type == "fp4": - with torch.cuda.device_of(A): - lib.cquantize_blockwise_fp32_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - else: - with torch.cuda.device_of(A): - lib.cquantize_blockwise_fp32_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - elif A.dtype == torch.float16: - if quant_type == "fp4": - with torch.cuda.device_of(A): - lib.cquantize_blockwise_fp16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - else: - with torch.cuda.device_of(A): - lib.cquantize_blockwise_fp16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - elif A.dtype == torch.bfloat16: - if quant_type == "fp4": - with torch.cuda.device_of(A): - lib.cquantize_blockwise_bf16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + + with torch.cuda.device_of(A): + args = ( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + + if A.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4(*args) + else: + lib.cquantize_blockwise_bf16_nf4(*args) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4(*args) + else: + lib.cquantize_blockwise_fp16_nf4(*args) + elif A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4(*args) + else: + lib.cquantize_blockwise_fp32_nf4(*args) else: - with torch.cuda.device_of(A): - lib.cquantize_blockwise_bf16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") code = get_4bit_type(quant_type, device=A.device) @@ -1281,33 +1242,33 @@ def quantize_4bit( def dequantize_fp4( - A: Tensor, + A: torch.Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, -) -> Tensor: +) -> torch.Tensor: return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") def dequantize_nf4( - A: Tensor, + A: torch.Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, -) -> Tensor: +) -> torch.Tensor: return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") def dequantize_4bit( - A: Tensor, + A: torch.Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type="fp4", -) -> Tensor: +) -> torch.Tensor: """ Dequantizes FP4 blockwise quantized values. @@ -1368,76 +1329,35 @@ def dequantize_4bit( is_on_gpu([A, absmax, out]) stream = _get_tensor_stream(A) - if out.dtype == torch.float32: - if quant_state.quant_type == "fp4": - with torch.cuda.device_of(A): - lib.cdequantize_blockwise_fp32_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) - else: - with torch.cuda.device_of(A): - lib.cdequantize_blockwise_fp32_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) - elif out.dtype == torch.float16: - if quant_state.quant_type == "fp4": - with torch.cuda.device_of(A): - lib.cdequantize_blockwise_fp16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) - else: - with torch.cuda.device_of(A): - lib.cdequantize_blockwise_fp16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) - elif out.dtype == torch.bfloat16: - with torch.cuda.device_of(A): + + with torch.cuda.device_of(A): + args = ( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + stream, + ) + + if out.dtype == torch.bfloat16: if quant_state.quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) + lib.cdequantize_blockwise_bf16_fp4(*args) else: - lib.cdequantize_blockwise_bf16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") if A.shape[0] == 1: # is transposed, transpose back return out.t() From b1c4adc4cf8941183129d66d20dcc42210ab2e17 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 5 Nov 2024 15:44:22 -0500 Subject: [PATCH 37/65] Type annotations, cleanup --- bitsandbytes/autograd/_functions.py | 37 +++++++++++--------- bitsandbytes/cextension.py | 18 ---------- bitsandbytes/functional.py | 12 ++++--- bitsandbytes/nn/modules.py | 10 +++--- bitsandbytes/research/autograd/_functions.py | 3 -- 5 files changed, 31 insertions(+), 49 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 79619cf74..2927b0574 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -244,25 +244,26 @@ def get_tile_inds(format, device): @dataclass class MatmulLtState: _tile_indices: Optional[torch.Tensor] = None + force_no_igemmlt: bool = False - CB = None - CxB = None # TODO: Deprecate/remove - SB = None - SCB = None - CxBt = None # TODO: Deprecate/remove - SBt = None - CBt = None + CB: Optional[torch.Tensor] = None + CxB: Optional[torch.Tensor] = None # TODO: Deprecate/remove + SB: Optional[torch.Tensor] = None + SCB: Optional[torch.Tensor] = None + + CxBt: Optional[torch.Tensor] = None # TODO: Deprecate/remove + SBt: Optional[torch.Tensor] = None + CBt: Optional[torch.Tensor] = None - subB = None + subB: Optional[torch.Tensor] = None - outlier_pool = None + outlier_pool: Optional[GlobalOutlierPooler] = None has_accumulated_gradients = False threshold = 0.0 - idx = None + idx: Optional[torch.Tensor] = None is_training = True has_fp16_weights = True - memory_efficient_backward = False use_pool = False formatB = "row" # TODO: Deprecate/remove @@ -313,10 +314,10 @@ def forward( if A.dtype != torch.float16: warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") - # 1. Quantize A. Note that as a side-effect, outliers are suppressed. if len(A.shape) == 3: A = A.reshape(-1, A.shape[-1]) + # 1. Quantize A. Note that as a side-effect, outliers are suppressed in CA/CAt. if ctx.needs_input_grad[1]: # Slower path CA, CAt, SCA, SCAt, outlier_cols = F.double_quant(A.to(torch.float16), threshold=state.threshold) @@ -366,6 +367,8 @@ def forward( # 3. Int8 Matmul out32 = F.int8_linear_matmul(CA, state.CB) + + # Dequantize matmul result if bias is None or bias.dtype == torch.float16: # we apply the fused bias here output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype) @@ -375,7 +378,7 @@ def forward( # 4. Mixed-precision decomposition matmul if subA is not None and state.subB is not None: - output += torch.matmul(subA, state.subB.to(subA.dtype)) + output += torch.matmul(subA, state.subB) # 5. Save state ctx.state = state @@ -399,7 +402,7 @@ def forward( return output @staticmethod - def backward(ctx, grad_output): + def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor): if ctx.is_empty: bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None @@ -407,7 +410,7 @@ def backward(ctx, grad_output): req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad CAt, subA, A = ctx.tensors SCAt, idx = ctx.tensor_states - state = ctx.state + state: MatmulLtState = ctx.state grad_A = grad_B = grad_bias = None if req_gradBias: @@ -499,7 +502,7 @@ def matmul( out: Optional[torch.Tensor] = None, state: Optional[MatmulLtState] = None, threshold=0.0, - bias=None, + bias: Optional[torch.Tensor] = None, ): state = state or MatmulLtState() if threshold > 0.0: @@ -512,7 +515,7 @@ def matmul_4bit( B: torch.Tensor, quant_state: F.QuantState, out: Optional[torch.Tensor] = None, - bias=None, + bias: Optional[torch.Tensor] = None, ): assert quant_state is not None diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 8b0a0c91d..0019ad9a4 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -1,21 +1,3 @@ -""" -extract factors the build is dependent on: -[X] compute capability - [ ] TODO: Q - What if we have multiple GPUs of different makes? -- CUDA version -- Software: - - CPU-only: only CPU quantization functions (no optimizer, no matrix multiple) - - CuBLAS-LT: full-build 8-bit optimizer - - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) - -evaluation: - - if paths faulty, return meaningful error - - else: - - determine CUDA version - - determine capabilities - - based on that set the default path -""" - import ctypes as ct import logging import os diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 12b529e78..e74f0bf53 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2279,7 +2279,9 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten ldb = shapeB[-1] # Activations (batch, tokens, inputs) ldc = shapeC[-1] # Output (batch, tokens, outputs) - assert lda == ldb, f"igemmlt only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}" + assert ( + lda == ldb + ), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}" is_on_gpu([A, B, out]) @@ -2288,7 +2290,7 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten ptrA = get_ptr(A) ptrB = get_ptr(B) ptrC = get_ptr(out) - ptrRowScale = get_ptr(None) + ptrRowScale = None m = ct.c_int32(m) n = ct.c_int32(n) k = ct.c_int32(k) @@ -2303,7 +2305,7 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - raise NotImplementedError("igemmlt not implemented!") + raise NotImplementedError("int8_linear_matmul not implemented!") if has_error: raise RuntimeError( @@ -2369,7 +2371,7 @@ def get_colrow_absmax( col_stats: Optional[torch.Tensor] = None, nnz_block_ptr: Optional[torch.Tensor] = None, threshold=0.0, -): +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: # Note: prior impl only works with fp16 assert A.is_floating_point() @@ -2395,7 +2397,7 @@ def get_colrow_absmax( return row_stats, col_stats, outlier_mask -def get_row_absmax(A, threshold=0.0): +def get_row_absmax(A: torch.Tensor, threshold=0.0): assert A.dtype == torch.float16 rows = prod(A.shape[:-1]) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index f5c7f6a7c..e63cd8db9 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -566,11 +566,11 @@ def __init__( class Int8Params(torch.nn.Parameter): def __new__( cls, - data=None, + data: Optional[torch.Tensor] = None, requires_grad=True, has_fp16_weights=False, - CB=None, - SCB=None, + CB: Optional[torch.Tensor] = None, + SCB: Optional[torch.Tensor] = None, ): if data is None: data = torch.empty(0) @@ -881,7 +881,6 @@ def __init__( output_features: int, bias=True, has_fp16_weights=True, - memory_efficient_backward=False, threshold=0.0, index=None, device=None, @@ -898,13 +897,12 @@ def __init__( Whether the linear class uses the bias term as well. """ super().__init__(input_features, output_features, bias, device) - assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" self.state = bnb.MatmulLtState() self.index = index self.state.threshold = threshold self.state.has_fp16_weights = has_fp16_weights - self.state.memory_efficient_backward = memory_efficient_backward + if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index 34eeb02de..abe56d27a 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -328,9 +328,6 @@ def backward(ctx, grad_output): grad_B = torch.matmul(grad_output.t(), A) if req_gradA: - # if state.CBt is not None: - # gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t()) - # grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) if state.CB is not None: CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) From ed922b8d5c5d57d708741335e2c97b5ff5d001a5 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 5 Nov 2024 15:58:07 -0500 Subject: [PATCH 38/65] remove unused kernels; improved type annotations --- csrc/kernels.cu | 266 +-------------------------------------- csrc/kernels.cuh | 3 - csrc/ops.cu | 47 +------ csrc/ops.cuh | 3 - csrc/pythonInterface.cpp | 4 - 5 files changed, 4 insertions(+), 319 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 74b9f7c47..b92bbc6ea 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2233,160 +2233,6 @@ __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshol } } -template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols) -{ - // 0. reset stats to -FLT_MAX - // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) - // 2. compute col max (per thread); store in smem due to register pressure - // 3. compute row max (per block); store in smem to accumulate full global mem transation - // 4. store data via atomicMax - - // each block loads TILE_COLs columns and TILE_ROW rows - // after reading a tile the row counter increase by TILE_ROWS - // the col counter reset after reading TILE_COL elements - const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; - // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached - const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; - const int base_idx = (base_row*cols) + base_col; - const int items_per_load = ITEMS_PER_THREAD*THREADS; - - typedef cub::BlockLoad LoadT; - typedef cub::BlockReduce BlockRowReduce; - typedef cub::BlockReduce BlockRowSum; - typedef cub::BlockExchange BlockExchange; - - __shared__ union { - typename BlockExchange::TempStorage exchange; - typename BlockRowReduce::TempStorage rowreduce; - typename BlockRowSum::TempStorage rowsum; - typename LoadT::TempStorage loadt; - } temp_storage; - - __shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS]; - __shared__ int smem_row_nnz_values[TILE_ROWS]; - - half local_data[ITEMS_PER_THREAD]; - float local_data_fp32[ITEMS_PER_THREAD]; - float local_col_absmax_values[ITEMS_PER_THREAD]; - int local_row_nnz_count = 0; - float row_absmax = -FLT_MAX; - - // 0. reset stats to -FLT_MAX - for(int j = 0; j < ITEMS_PER_THREAD; j++) - { - //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX; - smem_row_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX; - // smem_row_nnz_values[threadIdx.x + (j*THREADS)] = 0; - } - - #pragma unroll TILE_ROWS - for (int j = 0; j < TILE_ROWS; j++) { - smem_row_nnz_values[j] = 0; - } - - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - local_col_absmax_values[j] = -FLT_MAX; - - __syncthreads(); - - int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col; - int i = base_idx; - // we load row after row from the base_position - // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) - for(int row = 0; row < TILE_ROWS; row++) - { - if(base_row+row >= rows){ break; } - local_row_nnz_count = 0; - i = base_idx + ((row)*cols); - // each thread gets data from the same column - __syncthreads(); - LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, __float2half(0.0f)); - - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - local_data[j] = fabsf(local_data[j]); - - - if(SPARSE_DECOMP) - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - { - if((float)local_data[j] >= nnz_threshold) - { - local_row_nnz_count += 1; - local_data[j] = 0.0f; - } - } - - // 2. compute col max (per thread); store in smem due to register pressure - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - // take the col max for this row - // we use shared memory because register pressure is too high if we do this locally - //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = fmaxf(smem_col_absmax_values[threadIdx.x + (j*THREADS)], __half2float(local_data[j])); - local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j])); - - // 3. compute row max (per block); store in smem to accumulate full global mem transation - - // this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units) - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - local_data_fp32[j] = local_data[j]; - - __syncthreads(); - - row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max()); - if(SPARSE_DECOMP) - { - __syncthreads(); - local_row_nnz_count = BlockRowSum(temp_storage.rowsum).Sum(local_row_nnz_count); - } - // we store the data temporarily in shared memory so we - // can execute a full atomic block transaction into global memory later - // we use a striped arrangement [0, 8, 16, 24, ..] for t0 for faster stores - if(threadIdx.x == 0) - { - smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax; - // each blockIdx.x process 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block - smem_row_nnz_values[row] = local_row_nnz_count; - } - - __syncthreads(); - - } - - // 4. store data via atomicMax - // to store col data efficiently we need to rewrite the smem blocked data [0, 1, 2, 3...] for t0 - // into a striped arrangement: [0, 8, 16, 24, ..] for t0 - __syncthreads(); - BlockExchange(temp_storage.exchange).BlockedToStriped(local_col_absmax_values); - - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - if(base_col+threadIdx.x+(j*THREADS) < cols) - { - float val = colStats[base_col+(threadIdx.x+(j*THREADS))]; - if(val < local_col_absmax_values[j]) - atomicMax(&colStats[base_col+(threadIdx.x+(j*THREADS))], local_col_absmax_values[j]); - } - - for(int j = 0; j < ITEMS_PER_THREAD; j++) - if(base_row+threadIdx.x+(j*THREADS) < rows) - { - float val = rowStats[base_row+(threadIdx.x+(j*THREADS))]; - if(val < smem_row_absmax_values[threadIdx.x+(j*THREADS)]) - atomicMax(&rowStats[base_row+(threadIdx.x+(j*THREADS))], smem_row_absmax_values[threadIdx.x+(j*THREADS)]); - } - - if(SPARSE_DECOMP) - if(threadIdx.x < TILE_ROWS) - nnz_count_row[blockIdx.x*TILE_ROWS+threadIdx.x+1] = smem_row_nnz_values[threadIdx.x]; - -} - -template __global__ void kgetColRowStats(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); -template __global__ void kgetColRowStats(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); template __global__ void kgetRowStats(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols); template __global__ void kgetRowStats(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols); @@ -2430,8 +2276,8 @@ __global__ void kdequant_mm_int32_fp16( row_idx = (block_offset + thread_offset + j) / numCols; col_idx = (block_offset + thread_offset + j) % numCols; - local_colStats[j] = col_idx >= numCols ? 0.0f : colStats[col_idx]; - local_rowStats[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx]; + local_colStats[j] = col_idx >= numCols ? 0.0f : __ldg(&colStats[col_idx]); + local_rowStats[j] = row_idx >= numRows ? 0.0f : __ldg(&rowStats[row_idx]); local_biasValue[j] = ((bias == nullptr) || col_idx >= numCols) ? 0.0f : __half2float(bias[col_idx]); } @@ -2439,7 +2285,6 @@ __global__ void kdequant_mm_int32_fp16( int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out ? THREADS * ITEMS_PER_THREAD : n_out - block_offset; - __syncthreads(); LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0); #pragma unroll ITEMS_PER_THREAD @@ -2458,110 +2303,6 @@ __global__ void kdequant_mm_int32_fp16( } } -template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols) -{ - // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD - // Each thread reads the same column but multiple rows - // Rows are loaded in shared memory and access is shared across the threadblock (broadcast) - - // 0. Load row stats data into shared memory; load col stat (1 fixed per thread) - // 1. Load data row by row (should be at least with TILE_SIZE = 512) - // 2. quantize data with row/col stats - // 3. Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance) - - // each block loads TILE_COLs columns and TILE_ROW rows - // after reading a tile the row counter increase by TILE_ROWS - // the col counter reset after reading TILE_COL elements - const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; - // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached - const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; - const int base_idx = (base_row*cols) + base_col; - const int items_per_load = ITEMS_PER_THREAD*THREADS; - - typedef cub::BlockLoad LoadHalf; - __shared__ typename LoadHalf::TempStorage loadhalf; - typedef cub::BlockStore StoreInt8; - __shared__ typename StoreInt8::TempStorage storeint8; - - __shared__ float smem_row_stats[TILE_ROWS]; - __shared__ unsigned int smem_nnz_row_idx[TILE_ROWS]; - - half local_data[ITEMS_PER_THREAD]; - float local_col_stats[ITEMS_PER_THREAD]; - char local_quantized_data[ITEMS_PER_THREAD]; - - // 0. Load row stats data into shared memory; load col stat (1 fixed per thread) - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - if(base_col+(threadIdx.x*ITEMS_PER_THREAD) + j < cols) - local_col_stats[j] = __fdividef(127.0f, colStats[base_col+(threadIdx.x*ITEMS_PER_THREAD)+j]); - - for(int i = threadIdx.x; i < TILE_ROWS; i+=blockDim.x) - { - if(base_row + i < rows) - smem_row_stats[i] = rowStats[base_row+i]; - - if(SPARSE_DECOMP) - smem_nnz_row_idx[i] = nnz_block_ptr[(TILE_ROWS*blockIdx.x) + i]; - } - __syncthreads(); - - // we load row after row from the base_position - // 1. Load data row by row (should be at least with TILE_SIZE = 512) - for(int row = 0; row < TILE_ROWS; row++) - { - if(base_row + row >= rows){ break; } - int i = base_idx + (row*cols); - int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col; - - - LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f); - float row_stat = __fdividef(127.0f, smem_row_stats[row]); - - // 2. quantize data with row/col stats - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - { - // we already pre-normalized the col/row stat: - // what this does is float/absmax*127 = int8 - if(SPARSE_DECOMP) - { - if(fabsf((float)local_data[j]) >= threshold) - { - local_quantized_data[j] = 0; - - int old_idx = atomicInc(&smem_nnz_row_idx[row], UINT_MAX); - - rowidx[old_idx] = base_row+row; - colidx[old_idx] = base_col+(threadIdx.x*ITEMS_PER_THREAD)+j; - val[old_idx] = local_data[j]; - } - else - { - local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat)); - } - } - else - local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat)); - } - - StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items); - - // 2. quantize data with row/col stats - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - { - // we already pre-normalized the col/row stat: - // what this does is float/absmax*127 = int8 - local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*local_col_stats[j])); - } - - __syncthreads(); - StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items); - - } -} - template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols) { @@ -3864,9 +3605,6 @@ template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>( template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); -template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); -template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); - template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index f17bfe4de..18017c4d2 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -116,12 +116,9 @@ template __global__ void kdequant_mm_int32_fp int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); -template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); template __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols); template __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); -template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); - template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); diff --git a/csrc/ops.cu b/csrc/ops.cu index 88e829bef..afe1eb275 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -465,6 +465,8 @@ template int igemmlt( NULL, NULL, 0, stream )); } else { + // This path is unlikely to be used, as 8-bit accumulation can lead to likely overflows. + if (!SCALE_ROWS) { float alpha = 1.0f, beta = 0.0f; has_error |= checkCublasStatus(cublasLtMatmul( @@ -532,28 +534,6 @@ void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -#define STATS_THREADS 64 -#define STATS_ITEMS 4 -#define STATS_ROWS 16 -void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) -{ - int tile_cols = STATS_THREADS*STATS_ITEMS; - int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); - int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS); - int row_tiles = (tiledRows/STATS_ROWS); - int col_tiles = (tiledCols/tile_cols); - row_tiles = row_tiles > 0 ? row_tiles : 1; - col_tiles = col_tiles > 0 ? col_tiles : 1; - int num_blocks = row_tiles * col_tiles; - - if(nnz_threshold == 0.0) - kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); - else if(nnz_threshold != 0.0) - kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - -} - void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) { if (threshold == 0.0) kgetRowStats<<>>(A, rowStats, threshold, rows, cols); @@ -562,29 +542,6 @@ void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols) -{ - int threads = 64; - int items_per_thread = 4; - int tile_cols = threads*items_per_thread; - int tile_rows = 16; - int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); - int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); - int row_tiles = (tiledRows/tile_rows); - int col_tiles = (tiledCols/tile_cols); - row_tiles = row_tiles > 0 ? row_tiles : 1; - col_tiles = col_tiles > 0 ? col_tiles : 1; - int num_blocks = row_tiles * col_tiles; - - - if(threshold > 0.0f) - kDoubleRowColQuant<64, 4, 16, 64*4, 1><<>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); - else - kDoubleRowColQuant<64, 4, 16, 64*4, 0><<>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); - - CUDA_CHECK_RETURN(cudaPeekAtLastError()); -} - template void transformRowToFormat(char * A, char *out, int rows, int cols) { int threads = 256; diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 5f60051df..1170237e1 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -176,10 +176,7 @@ template int igemmlt(cublasLtHandle_t ltHandle, template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream); -void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols); void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream); -void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, - int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols); void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream); template void transformRowToFormat(char * A, char *out, int rows, int cols); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 441d3adef..0ced0394c 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -335,16 +335,12 @@ extern "C" void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream) { dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream); } - void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) - { getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); } void cget_row_stats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) { getRowStats(A, rowStats, threshold, rows, cols, stream); } void cint8_vector_quant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) { int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream); } - void cdouble_rowcol_quant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_row_ptr, float threshold, int rows, int cols) - { doubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_row_ptr, threshold, rows, cols); } void ctransform_row2col32(char * A, char *out, int rows, int cols) { transform_row2col32(A, out, rows, cols); } From a93b91ff19e0a5e0347310508f1149e5fd89ced1 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 5 Nov 2024 16:58:49 -0500 Subject: [PATCH 39/65] small perf optimization for single-GPU systems --- bitsandbytes/functional.py | 121 ++++++++++++++++++++----------------- 1 file changed, 64 insertions(+), 57 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index e74f0bf53..d802dc64a 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -191,6 +191,16 @@ def get_instance(cls): FIRST_CUDA_DEVICE = torch.device("cuda", index=0) +if torch.cuda.device_count() > 1: + + def _cuda_device_of(a: torch.Tensor): + return torch.cuda.device_of(a) +else: + import contextlib + + def _cuda_device_of(a: torch.Tensor): + return contextlib.nullcontext() + def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE): num_bytes = dtype2bytes[dtype] * prod(shape) @@ -881,7 +891,7 @@ def quantize_blockwise( is_on_gpu([A, out, absmax]) - with torch.cuda.device_of(A): + with _cuda_device_of(A): args = ( get_ptr(code), get_ptr(A), @@ -992,7 +1002,7 @@ def dequantize_blockwise( is_on_gpu([A, absmax, out]) - with torch.cuda.device_of(A): + with _cuda_device_of(A): args = ( get_ptr(quant_state.code), get_ptr(A), @@ -1183,7 +1193,7 @@ def quantize_4bit( is_on_gpu([A, out, absmax]) - with torch.cuda.device_of(A): + with _cuda_device_of(A): args = ( get_ptr(None), get_ptr(A), @@ -1330,7 +1340,7 @@ def dequantize_4bit( is_on_gpu([A, absmax, out]) stream = _get_tensor_stream(A) - with torch.cuda.device_of(A): + with _cuda_device_of(A): args = ( get_ptr(None), get_ptr(A), @@ -1547,28 +1557,28 @@ def optimizer_update_32bit( ) is_on_gpu([g, p, state1, state2, unorm_vec]) - prev_device = pre_call(g.device) - optim_func( - get_ptr(g), - get_ptr(p), - get_ptr(state1), - get_ptr(state2), - get_ptr(unorm_vec), - ct.c_float(max_unorm), - ct.c_float(param_norm), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(beta3), - ct.c_float(alpha), - ct.c_float(eps), - ct.c_float(weight_decay), - ct.c_int32(step), - ct.c_float(lr), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) - post_call(prev_device) + + with _cuda_device_of(g): + optim_func( + get_ptr(g), + get_ptr(p), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(beta3), + ct.c_float(alpha), + ct.c_float(eps), + ct.c_float(weight_decay), + ct.c_int32(step), + ct.c_float(lr), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) @deprecated( @@ -1731,8 +1741,7 @@ def optimizer_update_8bit_blockwise( skip_zeros=False, ) -> None: optim_func = None - prev_device = pre_call(g.device) - is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) + if g.dtype == torch.float32 and state1.dtype == torch.uint8: optim_func = str2optimizer8bit_blockwise[optimizer_name][0] elif g.dtype == torch.float16 and state1.dtype == torch.uint8: @@ -1747,33 +1756,31 @@ def optimizer_update_8bit_blockwise( raise ValueError( f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", ) - post_call(prev_device) is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) - prev_device = pre_call(g.device) - optim_func( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(beta3), - ct.c_float(alpha), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(absmax1), - get_ptr(absmax2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) - post_call(prev_device) + with _cuda_device_of(g): + optim_func( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(beta3), + ct.c_float(alpha), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) @@ -1966,7 +1973,7 @@ def gemv_4bit( ldc = ct.c_int32(ldc) stream = _get_tensor_stream(A) - with torch.cuda.device_of(A): + with _cuda_device_of(A): if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: if A.dtype == torch.float16: lib.cgemm_4bit_inference_naive_fp16( @@ -2285,7 +2292,7 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten is_on_gpu([A, B, out]) - with torch.cuda.device_of(A): + with _cuda_device_of(A): ctx = CUBLAS_Context.get_instance().get_context(A.device) ptrA = get_ptr(A) ptrB = get_ptr(B) @@ -2343,7 +2350,7 @@ def int8_mm_dequant( is_on_gpu([A, row_stats, col_stats, out, bias]) - with torch.cuda.device_of(A): + with _cuda_device_of(A): lib.cdequant_mm_int32_fp16( ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) ) @@ -2407,7 +2414,7 @@ def get_row_absmax(A: torch.Tensor, threshold=0.0): is_on_gpu([A]) - with torch.cuda.device_of(A): + with _cuda_device_of(A): lib.cget_row_stats( get_ptr(A), get_ptr(row_stats), @@ -2550,7 +2557,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): if outliers.any(): outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) - with torch.cuda.device_of(A): + with _cuda_device_of(A): lib.cint8_vector_quant( get_ptr(A), get_ptr(out_row), From 4bced868ae3937379a620cd605e326e703e4cc4e Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 5 Nov 2024 17:00:19 -0500 Subject: [PATCH 40/65] small perf optimization for single-GPU systems --- bitsandbytes/functional.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index d802dc64a..b7ed0d1c3 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -191,6 +191,10 @@ def get_instance(cls): FIRST_CUDA_DEVICE = torch.device("cuda", index=0) +# When multiple GPUs are present, we use a context manager to +# switch to the correct device of a tensor before invoking our CUDA +# kernels in the C++ library. However, when there's only one device +# there is no need to incur the overhead of cudaGetDevice/cudaSetDevice. if torch.cuda.device_count() > 1: def _cuda_device_of(a: torch.Tensor): From f61d8bc2cb0eb5d819aac5542db8aba8af36d9d6 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 18 Nov 2024 11:25:19 -0500 Subject: [PATCH 41/65] update docstrings --- bitsandbytes/functional.py | 115 ++++++++++++++++++++++++++++++++++++- 1 file changed, 112 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b7ed0d1c3..7ddd01613 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -437,6 +437,20 @@ def get_special_format_str(): def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]): + """Verifies that the input tensors are all on the same device. + + An input tensor may also be marked as `paged`, in which case the device placement is ignored. + + Args: + tensors (Iterable[Optional[torch.Tensor]]): A list of tensors to verify. + + Raises: + `RuntimeError`: Raised when the verification fails. + + Returns: + `Literal[True]` + """ + on_gpu = True gpu_ids = set() @@ -1199,7 +1213,7 @@ def quantize_4bit( with _cuda_device_of(A): args = ( - get_ptr(None), + None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -1346,7 +1360,7 @@ def dequantize_4bit( with _cuda_device_of(A): args = ( - get_ptr(None), + None, get_ptr(A), get_ptr(absmax), get_ptr(out), @@ -2255,6 +2269,25 @@ def igemmlt( def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32): + """Performs an 8-bit integer matrix multiplication. + + A linear transformation is applied such that `out = A @ B.T`. When possible, integer tensor core hardware is + utilized to accelerate the operation. + + Args: + A (`torch.Tensor`): The first matrix operand with the data type `torch.int8`. + B (`torch.Tensor`): The second matrix operand with the data type `torch.int8`. + out (`torch.Tensor, *optional*): A pre-allocated tensor used to store the result. + dtype (`torch.dtype`, *optional*): The expected data type of the output. Defaults to `torch.int32`. + + Raises: + `NotImplementedError`: The operation is not supported in the current environment. + `RuntimeError`: Raised when the cannot be completed for any other reason. + + Returns: + `torch.Tensor`: The result of the operation. + """ + # # To use the IMMA tensor core kernels without special Turing/Ampere layouts, # cublasLt has some rules, namely: A must be transposed, B must not be transposed. @@ -2336,6 +2369,19 @@ def int8_mm_dequant( out: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, ): + """Performs dequantization on the result of a quantized int8 matrix multiplication. + + Args: + A (`torch.Tensor` with dtype `torch.int32`): The result of a quantized int8 matrix multiplication. + row_stats (`torch.Tensor`): The row-wise quantization statistics for the lhs operand of the matrix multiplication. + col_stats (`torch.Tensor`): The column-wise quantization statistics for the rhs operand of the matrix multiplication. + out (`torch.Tensor], *optional*): A pre-allocated tensor to store the output of the operation. + bias (`torch.Tensor`, *optional*): An optional bias vector to add to the result. + + Returns: + `torch.Tensor`: The dequantized result with an optional bias, with dtype `torch.float16`. + """ + assert A.dtype == torch.int32 if bias is not None: @@ -2409,6 +2455,20 @@ def get_colrow_absmax( def get_row_absmax(A: torch.Tensor, threshold=0.0): + """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm. + + For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339). + + Args: + A (`torch.Tensor` with dtype `torch.float16`): The input matrix. + threshold (`float`, *optional*): + An optional threshold for sparse decomposition of outlier features. + No outliers are held back when 0.0. Defaults to 0.0. + + Returns: + `torch.Tensor` with dtype `torch.float32`: The absolute maximum value for each row, with outliers ignored. + """ + assert A.dtype == torch.float16 rows = prod(A.shape[:-1]) @@ -2520,6 +2580,37 @@ def double_quant( out_row: Optional[torch.Tensor] = None, threshold=0.0, ): + """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm. + + The statistics are determined both row-wise and column-wise (transposed). + + For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339). + + + This function is useful for training, but for inference it is advised to use [`int8_vectorwise_quant`] instead. + This implementation performs additional column-wise transposed calculations which are not optimized. + + + Args: + A (`torch.Tensor` with dtype `torch.float16`): The input matrix. + col_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantization scales. + row_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantization scales. + out_col (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantized data. + out_row (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantized data. + threshold (`float`, *optional*): + An optional threshold for sparse decomposition of outlier features. + + No outliers are held back when 0.0. Defaults to 0.0. + + Returns: + `Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics. + - `torch.Tensor` with dtype `torch.int8`: The row-wise quantized data. + - `torch.Tensor` with dtype `torch.int8`: The column-wise quantized data. + - `torch.Tensor` with dtype `torch.float32`: The row-wise quantization scales. + - `torch.Tensor` with dtype `torch.float32`: The column-wise quantization scales. + - `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features. + """ + # TODO: Optimize/write CUDA kernel for this? # Note: for inference, use the new int8_vectorwise_quant. @@ -2541,6 +2632,24 @@ def double_quant( def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): + """Quantizes a tensor with dtype `torch.float16` to `torch.int8` in accordance to the `LLM.int8()` algorithm. + + For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339). + + Args: + A (`torch.Tensor` with dtype `torch.float16`): The input tensor. + threshold (`float`, *optional*): + An optional threshold for sparse decomposition of outlier features. + + No outliers are held back when 0.0. Defaults to 0.0. + + Returns: + `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics. + - `torch.Tensor` with dtype `torch.int8`: The quantized data. + - `torch.Tensor` with dtype `torch.float32`: The quantization scales. + - `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features. + """ + assert A.dtype == torch.half is_on_gpu([A]) @@ -2838,7 +2947,7 @@ def vectorwise_dequant(xq, max1, quant_type="vector"): @deprecated( - "This function is deprecated and will be removed in a future release. Consider using `mm_dequant` instead.", + "This function is deprecated and will be removed in a future release. Consider using `int8_mm_dequant` instead.", category=FutureWarning, ) def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"): From eed9c3cf65f661dd853da21e5be03aeb9878009c Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 18 Nov 2024 14:08:13 -0500 Subject: [PATCH 42/65] Improve docs and tests --- bitsandbytes/autograd/_functions.py | 4 +- bitsandbytes/functional.py | 72 ++++++++++++++++++-- bitsandbytes/research/autograd/_functions.py | 6 +- tests/test_functional.py | 33 ++++----- tests/test_modules.py | 2 - 5 files changed, 88 insertions(+), 29 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 2927b0574..ad7624c38 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -320,7 +320,7 @@ def forward( # 1. Quantize A. Note that as a side-effect, outliers are suppressed in CA/CAt. if ctx.needs_input_grad[1]: # Slower path - CA, CAt, SCA, SCAt, outlier_cols = F.double_quant(A.to(torch.float16), threshold=state.threshold) + CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A.to(torch.float16), threshold=state.threshold) else: # Fast path CA, SCA, outlier_cols = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold) @@ -422,7 +422,7 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() if req_gradB: - Cgrad, _, _, SCgradt, _ = F.double_quant(grad_output.to(torch.float16)) + Cgrad, _, _, SCgradt, _ = F.int8_double_quant(grad_output.to(torch.float16)) gradB32 = F.int8_linear_matmul(Cgrad.t().contiguous(), CAt.t()) grad_B = F.int8_mm_dequant(gradB32, SCgradt, SCAt) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 7ddd01613..f9ccdc2e1 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -442,7 +442,7 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]): An input tensor may also be marked as `paged`, in which case the device placement is ignored. Args: - tensors (Iterable[Optional[torch.Tensor]]): A list of tensors to verify. + tensors (`Iterable[Optional[torch.Tensor]]`): A list of tensors to verify. Raises: `RuntimeError`: Raised when the verification fails. @@ -2572,6 +2572,7 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) +@deprecated("This function is deprecated. Please use `int8_double_quant` instead.", category=FutureWarning) def double_quant( A: torch.Tensor, col_stats: Optional[torch.Tensor] = None, @@ -2579,6 +2580,72 @@ def double_quant( out_col: Optional[torch.Tensor] = None, out_row: Optional[torch.Tensor] = None, threshold=0.0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[COOSparseTensor]]: + """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm. + + The statistics are determined both row-wise and column-wise (transposed). + + For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339). + + + This function exists for backwards compatibility only. It is advised to use [`int8_double_quant`] instead. + The difference is that this function will return a [`COOSparseTensor`] for outliers instead of a column index. + + + Args: + A (`torch.Tensor` with dtype `torch.float16`): The input matrix. + col_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantization scales. + row_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantization scales. + out_col (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantized data. + out_row (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantized data. + threshold (`float`, *optional*): + An optional threshold for sparse decomposition of outlier features. + + No outliers are held back when 0.0. Defaults to 0.0. + + Returns: + `Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics. + - `torch.Tensor` with dtype `torch.int8`: The row-wise quantized data. + - `torch.Tensor` with dtype `torch.int8`: The column-wise quantized data. + - `torch.Tensor` with dtype `torch.float32`: The row-wise quantization scales. + - `torch.Tensor` with dtype `torch.float32`: The column-wise quantization scales. + - `COOSparseTensor`, *optional*: A structure representing the outlier values from the input tensor. + """ + + coo_tensor = None + quant_row, quant_col, row_stats, col_stats, _ = int8_double_quant( + A, + col_stats, + row_stats, + out_col, + out_row, + threshold=threshold, + ) + + if threshold > 0.0: + # Build COO tensor for any outliers. + outlier_mask = A.abs() >= threshold + outlier_locations = outlier_mask.nonzero() + outliers = A[outlier_mask] + coo_tensor = COOSparseTensor( + A.shape[0], + A.shape[1], + outliers.numel(), + outlier_locations[:, 0].int(), + outlier_locations[:, 1].int(), + outliers, + ) + + return quant_row, quant_col, row_stats, col_stats.flatten().float(), coo_tensor + + +def int8_double_quant( + A: torch.Tensor, + col_stats: Optional[torch.Tensor] = None, + row_stats: Optional[torch.Tensor] = None, + out_col: Optional[torch.Tensor] = None, + out_row: Optional[torch.Tensor] = None, + threshold=0.0, ): """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm. @@ -2612,7 +2679,6 @@ def double_quant( """ # TODO: Optimize/write CUDA kernel for this? - # Note: for inference, use the new int8_vectorwise_quant. # Use CUDA kernel for rowwise and COO tensor quant_row, row_stats, outlier_cols = int8_vectorwise_quant(A, threshold=threshold) @@ -2665,8 +2731,6 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): # TODO we could improve perf of this outliers = A.abs() >= threshold - # argwhere needs host/device sync, so we skip when - # there aren't actually any outliers. if outliers.any(): outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index abe56d27a..d9718382b 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -215,7 +215,7 @@ def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = Non # 1. Quantize A if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() - CA, CAt, SCA, SCAt, outlier_cols = F.double_quant(A.to(torch.float16), threshold=state.threshold) + CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A.to(torch.float16), threshold=state.threshold) if state.threshold > 0.0 and outlier_cols is not None: if state.has_fp16_weights: @@ -248,7 +248,7 @@ def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = Non state.SCB, state.SCBt, _, - ) = F.double_quant(B.to(torch.float16)) + ) = F.int8_double_quant(B.to(torch.float16)) state.SB = (state.CB.shape, "row") else: has_grad = False @@ -320,7 +320,7 @@ def backward(ctx, grad_output): if len(grad_output.shape) == 3: grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - Cgrad, Cgradt, SCgrad, SCgradt, outlier_cols = F.double_quant(grad_output.to(torch.float16)) + Cgrad, Cgradt, SCgrad, SCgradt, outlier_cols = F.int8_double_quant(grad_output.to(torch.float16)) if req_gradB: # print('back A shape', A.shape) diff --git a/tests/test_functional.py b/tests/test_functional.py index ecc2261fa..948ba5e20 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -606,8 +606,8 @@ def test_int8_linear_matmul_half(dim1, dim2, dim3, dim4, dims): A = A.view(-1, A.shape[-1]) - CA, _, statsA, _, _ = F.double_quant(A) - CB, _, statsB, _, _ = F.int8_vectorwise_quant(B) + CA, _, statsA, _, _ = F.int8_double_quant(A) + CB, statsB, _ = F.int8_vectorwise_quant(B) output = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), statsA, statsB) torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) @@ -863,7 +863,7 @@ def test_double_quant(dim1, dim2): out_col1, Scol = F.vectorwise_quant(A, dim=0) out_row1, Srow = F.vectorwise_quant(A, dim=1) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + CA, CAt, statsA, statsAt, coo_tensor = F.int8_double_quant(A) # max difference is 1 due to rounding differences torch.testing.assert_close(CA, out_row1, atol=1, rtol=0) @@ -953,7 +953,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner): out1 = torch.matmul(A.half(), B.t().half()) - C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A) + C1a, C1b, stats1a, stats1b, coo_tensor = F.int8_double_quant(A) CB, absmaxB = F.vectorwise_quant(B, quant_type="linear") A2, SA = F.nvidia_transform(C1a, "col32") B2, SB = F.nvidia_transform(CB, formatB) @@ -1032,7 +1032,7 @@ def test_row_scale_bench(dim1, dim4, inner): torch.cuda.synchronize() print("16", time.time() - t0) - C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A) + C1a, C1b, stats1a, stats1b, coo_tensor = F.int8_double_quant(A) CB, absmaxB = F.vectorwise_quant(B, quant_type="linear") A2, SA = F.nvidia_transform(C1a, "col32") B2, SB = F.nvidia_transform(CB, formatB) @@ -1047,7 +1047,7 @@ def test_row_scale_bench(dim1, dim4, inner): torch.cuda.synchronize() print("row-wise", time.time() - t0) - C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) + C2a, C2b, stats2a, stats2b, coo_tensor = F.int8_double_quant(B) B2, SB = F.nvidia_transform(C2a, formatB) torch.cuda.synchronize() t0 = time.time() @@ -1115,7 +1115,8 @@ def test_coo_double_quant(dim1, dim2): if coo_tensor is not None: A1 = A * idx - A2 = coo_tensor.to_dense() + A2 = torch.zeros_like(A) + A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values torch.testing.assert_close(A1, A2) A1 = A * (idx == 0) @@ -1133,14 +1134,9 @@ def test_coo_int8_vectorwise_quant(dim1, dim2): A = torch.randn(dim1, dim2, device="cuda").half() idx = torch.abs(A) >= threshold - CA, statsA, coo_tensor = F.int8_vectorwise_quant(A, threshold=threshold) + CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold) - if coo_tensor is not None: - A1 = A * idx - A2 = coo_tensor.to_dense() - torch.testing.assert_close(A1, A2) - - A1 = A * (idx == 0) + if outlier_cols is not None: A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) @@ -1230,13 +1226,14 @@ def test_integrated_sparse_decomp(dim1, dim2): w1 = torch.randn(dim1, dim2).cuda().half() out1 = torch.matmul(A, w1.t()) - Cw1, statsw1, coo_tensor = F.int8_vectorwise_quant(w1) - CA, statsA, coo_tensor = F.int8_vectorwise_quant(A) + Cw1, statsw1, _ = F.int8_vectorwise_quant(w1) + CA, statsA, _ = F.int8_vectorwise_quant(A) out1_32 = F.int8_linear_matmul(CA, Cw1) out2 = F.int8_mm_dequant(out1_32, statsA, statsw1) - CA, statsA, coo_tensor = F.int8_vectorwise_quant(A, threshold=threshold) + # CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold) + CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold) out1_32 = F.int8_linear_matmul(CA, Cw1) out3 = F.int8_mm_dequant(out1_32, statsA, statsw1) @@ -1377,7 +1374,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): torch.nn.init.xavier_uniform_(B) Bt = B.t().contiguous() - CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B) + CB, CBt, statsB, statsBt, coo_tensor = F.int8_double_quant(B) rowidx = torch.randint(0, A.shape[-1], size=(15,)) diff --git a/tests/test_modules.py b/tests/test_modules.py index 9e16b5e2d..9e6a708b9 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -356,7 +356,6 @@ def test_linear8bitlt_accumulated_gradient(): @pytest.mark.parametrize("threshold", [0.0, 2.0]) -@pytest.mark.parametrize("memory_efficient_backward", [False]) def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): l1 = ( bnb.nn.Linear8bitLt( @@ -364,7 +363,6 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): 64, threshold=threshold, has_fp16_weights=False, - memory_efficient_backward=memory_efficient_backward, ) .cuda() .half() From 6e0a4b3304012309b27b17162884b73a2c4f869d Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 18 Nov 2024 14:40:51 -0500 Subject: [PATCH 43/65] Update docstring --- bitsandbytes/functional.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index f9ccdc2e1..7d7547130 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2429,7 +2429,32 @@ def get_colrow_absmax( nnz_block_ptr: Optional[torch.Tensor] = None, threshold=0.0, ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # Note: prior impl only works with fp16 + """ "Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm. + + The row-wise and column-wise absmax values are determined. + + For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339). + + + This function is useful for training, but for inference it is advised to use [`get_row_absmax`] instead. + The column-wise quantization scales are not typically needed in inference scenarios. + + + Args: + A (`torch.Tensor` with dtype `torch.float16`): Input tensor. + row_stats (`torch.Tensor`, *optional*): If provided, calculation of row statistics is skipped. + col_stats (`torch.Tensor`, *optional*): If provided, calculation of column statistics is skipped. + nnz_block_ptr (`torch.Tensor`, *optional*): Not used. + threshold (`float`, `optional`): + An optional threshold for sparse decomposition of outlier features. + No outliers are held back when 0.0. Defaults to 0.0. + + Returns: + `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing quantization statistics. + - `torch.Tensor` with dtype `torch.float32`: The row-wise quantization statistics. + - `torch.Tensor` with dtype `torch.float32`: The column-wise quantization statistics. + - `torch.Tensor` with dtype `torch.bool`, *optional*: A mask indicating the locations of outliers in the input tensor. + """ assert A.is_floating_point() outlier_mask = None From 161c1949fdaeaabc272d409f0478a6548a77918d Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 18 Nov 2024 16:01:32 -0500 Subject: [PATCH 44/65] Update test --- tests/test_modules.py | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/tests/test_modules.py b/tests/test_modules.py index 9e6a708b9..239c7d3a6 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -17,20 +17,18 @@ def __init__(self, initial_data): class MLP8bit(torch.nn.Module): - def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0): + def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0): super().__init__() self.fc1 = bnb.nn.Linear8bitLt( dim1, dim2, has_fp16_weights=has_fp16_weights, - memory_efficient_backward=memory_efficient_backward, threshold=threshold, ) self.fc2 = bnb.nn.Linear8bitLt( dim2, dim1, has_fp16_weights=has_fp16_weights, - memory_efficient_backward=memory_efficient_backward, threshold=threshold, ) @@ -326,7 +324,7 @@ def test_linear8bitlt_accumulated_gradient(): acc_steps = 10 - for i in range(10): + for i in range(15): b1 = torch.randn(16, 8, 32, device="cuda").half() o1 = l1(b1) o2 = l2(b1) @@ -356,7 +354,7 @@ def test_linear8bitlt_accumulated_gradient(): @pytest.mark.parametrize("threshold", [0.0, 2.0]) -def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): +def test_linear8bitlt_no_fp16_weights(threshold): l1 = ( bnb.nn.Linear8bitLt( 32, @@ -420,7 +418,6 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): 64, threshold=threshold, has_fp16_weights=False, - memory_efficient_backward=memory_efficient_backward, ) .half() .to("cuda") @@ -444,7 +441,6 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): 64, threshold=threshold, has_fp16_weights=False, - memory_efficient_backward=memory_efficient_backward, ) w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda() # grab weights before quantization, mlp = mlp.cuda().half() # and this line triggers quantization @@ -463,21 +459,20 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert mlp.fc1.weight.device.type == "cuda" assert mlp.fc2.weight.device.type == "cuda" - if memory_efficient_backward: - b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half) - o1 = mlp(b1) - assert o1.dtype == torch.float16 - assert o1.requires_grad - grad_proj = torch.randn_like(o1) + b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half) + o1 = mlp(b1) + assert o1.dtype == torch.float16 + assert o1.requires_grad + grad_proj = torch.randn_like(o1) - mlp.zero_grad() - (o1 * grad_proj).sum().backward() - grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half() - scale = grad_ref.abs().mean() + mlp.zero_grad() + (o1 * grad_proj).sum().backward() + grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half() + scale = grad_ref.abs().mean() - torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale) - idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1) - assert (idx == 0).sum().item() <= b1.numel() * 0.005 + torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale) + idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1) + assert (idx == 0).sum().item() <= b1.numel() * 0.005 @pytest.mark.parametrize( From e3051fa8a2c0fd088bdb6bebbf752fb2197381de Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 20 Nov 2024 13:48:21 -0500 Subject: [PATCH 45/65] add benchmarking script --- benchmarking/int8/int8_benchmark.py | 70 +++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 benchmarking/int8/int8_benchmark.py diff --git a/benchmarking/int8/int8_benchmark.py b/benchmarking/int8/int8_benchmark.py new file mode 100644 index 000000000..541cf6def --- /dev/null +++ b/benchmarking/int8/int8_benchmark.py @@ -0,0 +1,70 @@ +""" +Basic benchmark for text generation. + +Usage: python benchmarking/int8/int8_benchmark.py +""" + +import time + +import torch +from torch.profiler import ProfilerActivity, profile +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig + +MAX_NEW_TOKENS = 128 +model_name = "meta-llama/Llama-3.1-8B" + +text = "Below is a question. I need an answer.\n\nExplain machine learning: " +tokenizer = AutoTokenizer.from_pretrained(model_name) +input_ids = tokenizer([text] * 8, return_tensors="pt").input_ids.to(0) + +max_memory = f"{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB" + +model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="auto", + quantization_config=BitsAndBytesConfig( + load_in_8bit=True, + llm_int8_threshold=6.0, + ), + attn_implementation="sdpa", + torch_dtype=torch.float16, +) + +print(model) + +# warmup +print("Warmup...") +for i in range(3): + generated_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS) + +print("Profiler starting...") +with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + with_modules=True, + with_stack=True, +) as prof: + model.generate(input_ids, max_new_tokens=1) + +print( + prof.key_averages().table( + sort_by="cpu_time_total", + max_name_column_width=50, + top_level_events_only=True, + row_limit=50, + ) +) + +torch.cuda.synchronize() + + +print("Generating...") +num = 0 +time_1 = time.time() +for i in range(5): + generated_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS) + num += len(generated_ids[0]) + +print("=" * 40) +print(f"Example:\n{tokenizer.decode(generated_ids[0])}") +print("=" * 40) +print(f"Speed: {num/(time.time() - time_1)}token/s") From 56abdc2ec492a9251f96827d6ce48f70b0206c16 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 20 Nov 2024 17:35:58 -0500 Subject: [PATCH 46/65] test cleanup: add deprecated marker, move benchmarks out --- benchmarking/int8/int8_benchmark.py | 2 - benchmarking/int8/row_scale_benchmark.py | 70 +++ benchmarking/int8/training_benchmark.py | 171 +++++ benchmarking/matmul_benchmark.py | 213 +++++++ pytest.ini | 1 + tests/test_autograd.py | 7 +- tests/test_functional.py | 754 +++++------------------ tests/test_linear8bitlt.py | 5 - tests/test_modules.py | 5 - 9 files changed, 612 insertions(+), 616 deletions(-) create mode 100644 benchmarking/int8/row_scale_benchmark.py create mode 100644 benchmarking/int8/training_benchmark.py create mode 100644 benchmarking/matmul_benchmark.py diff --git a/benchmarking/int8/int8_benchmark.py b/benchmarking/int8/int8_benchmark.py index 541cf6def..b91e5f76f 100644 --- a/benchmarking/int8/int8_benchmark.py +++ b/benchmarking/int8/int8_benchmark.py @@ -17,8 +17,6 @@ tokenizer = AutoTokenizer.from_pretrained(model_name) input_ids = tokenizer([text] * 8, return_tensors="pt").input_ids.to(0) -max_memory = f"{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB" - model = AutoModelForCausalLM.from_pretrained( model_name, device_map="auto", diff --git a/benchmarking/int8/row_scale_benchmark.py b/benchmarking/int8/row_scale_benchmark.py new file mode 100644 index 000000000..98d2496de --- /dev/null +++ b/benchmarking/int8/row_scale_benchmark.py @@ -0,0 +1,70 @@ +""" +Extracted from tests/test_functional.py + +Note: This feature is currently unused! It is kept here for archival purposes. + +Usage: pytest benchmarking/int8/row_scale_benchmark.py +""" + +import time + +import pytest +import torch + +from bitsandbytes import functional as F + +k = 20 +torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000) + + +@pytest.mark.parametrize( + ("dim1", "dim4", "inner"), + [ + pytest.param(1024, 12288 * 4, 12288, id="1024, 12288*4, 12288"), + pytest.param(2048, 4096 * 4, 4096, id="2048, 4096*4, 4096"), + ], +) +@pytest.mark.skip("Row scale has some bugs for ampere") +@pytest.mark.benchmark +def test_row_scale_bench(dim1, dim4, inner): + formatB = F.get_special_format_str() + err1, err2, err3 = [], [], [] + relerr1, relerr2 = [], [] + scale = 1 + A = torch.randn(dim1, inner, device="cuda").half() + B = torch.randn(dim4, inner, device="cuda").half() + torch.nn.init.xavier_uniform_(B) + # warmpup + for i in range(k): + C1 = torch.matmul(A, B.t()) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + C1 = torch.matmul(A, B.t()) + torch.cuda.synchronize() + print("16", time.time() - t0) + + C1a, C1b, stats1a, stats1b, coo_tensor = F.int8_double_quant(A) + CB, absmaxB = F.vectorwise_quant(B, quant_type="linear") + A2, SA = F.nvidia_transform(C1a, "col32") + B2, SB = F.nvidia_transform(CB, formatB) + A1, maxA = F.vectorwise_quant(A, dim=1) + + c = 10.0 * inner * scale + row_scale = maxA / c + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + outC32 = F.int8_linear_matmul(A2, B2, dtype=torch.int8, row_scale=row_scale) + torch.cuda.synchronize() + print("row-wise", time.time() - t0) + + C2a, C2b, stats2a, stats2b, coo_tensor = F.int8_double_quant(B) + B2, SB = F.nvidia_transform(C2a, formatB) + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + outC32 = F.int8_linear_matmul(A2, B2) + torch.cuda.synchronize() + print("vector-wise", time.time() - t0) diff --git a/benchmarking/int8/training_benchmark.py b/benchmarking/int8/training_benchmark.py new file mode 100644 index 000000000..32060afde --- /dev/null +++ b/benchmarking/int8/training_benchmark.py @@ -0,0 +1,171 @@ +""" +Extracted from tests/test_functional.py + +Usage: pytest benchmarking/int8/training_benchmark.py +""" + +import time + +import pytest +import torch + +from bitsandbytes import functional as F + +k = 20 + + +@pytest.mark.parametrize( + ("batch", "seq", "model", "hidden"), + [ + pytest.param(2, 512, 4 * 1024, 3 * 4 * 1024, id="batch=2, seq=512, model=4k, hidden=12k"), + pytest.param(2, 512, 5120, 3 * 5120, id="batch=2, seq=512, model=5k, hidden=15k"), + pytest.param(2, 512, 12 * 1024, 4 * 12 * 1024, id="batch=2, seq=512, model=12k, hidden=48k"), + ], +) +@pytest.mark.benchmark +def test_bench_8bit_training(batch, seq, model, hidden): + formatB = F.get_special_format_str() + A = torch.randn(batch, seq, model, device="cuda").half() + grad = torch.randn(batch, seq, model, device="cuda").half() + w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half() + w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half() + print("") + + # torch.cuda.synchronize() + ## warmup + # for i in range(100): + # torch.matmul(A, w1.t()) + # torch.cuda.synchronize() + + dtype = torch.int8 + A = A.view(-1, A.shape[-1]).contiguous() + grad = grad.view(-1, grad.shape[-1]).contiguous() + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + out1 = torch.matmul(A, w1.t()) # fc1 + # out2 = torch.matmul(out1, w2.t())# fc2 + + # d1 = torch.matmul(grad, w2) # delta1 + # d2 = torch.matmul(d1, w1) # delta2 + + # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2 + # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1 + + torch.cuda.synchronize() + t16 = time.time() - t0 + print(t16) + + # torch.cuda.empty_cache() + + # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + + # CTw1, Sw1 = F.transform2(Cw1, formatB) + # CTw2, Sw2 = F.transform2(Cw2, formatB) + # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) + # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) + + # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + # C32A, SA = F.transform2(CA, 'col32') + ## fc1 + # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype) + ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t) + + ## fc2 + # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1) + # C32out1, Sout1 = F.transform2(Cout1, 'col32') + # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype) + ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t) + + ## delta1 + # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad) + # C32grad, Sgrad = F.transform2(Cgrad, 'col32') + ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype) + ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2) + + ## delta2 + # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1) + # C32d1, Sd1 = F.transform2(Cd1, 'col32') + ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype) + ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1) + + ## grad1 + # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True) + # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True) + ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype) + ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad) + + ## grad2 + # C32At, SAt = F.transform2(CAt, 'col32', transpose=True) + # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True) + ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype) + ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1) + + # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + + # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + + # CTw1, Sw1 = F.transform2(Cw1, formatB) + # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) + # CTw2, Sw2 = F.transform2(Cw2, formatB) + # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(k): + # #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + # #CTw1, Sw1 = F.transform2(Cw1, formatB) + # #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + # #CTw1, Sw1 = F.transform2(Cw1, formatB) + + # #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5) + # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + # #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) + # #CTw2, Sw2 = F.transform2(Cw2, formatB) + # #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) + + # C32A, SA = F.transform2(CA, 'col32') + + # # fc1 + # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype) + # #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) + + # #print(coo_tensor.nnz) + # #out1sp = F.spmm_coo(coo_tensor, w1.t()) + # #print(w1.t().shape) + # #out1 = out1dn + out1sp + + # # fc2 + # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1) + # C32out1, Sout1 = F.transform2(Cout1, 'col32') + # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype) + # #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2) + + # # delta1 + # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad) + # C32grad, Sgrad = F.transform2(Cgrad, 'col32') + # d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype) + # #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t) + + # # delta2 + # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1) + # C32d1, Sd1 = F.transform2(Cd1, 'col32') + # d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype) + # #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t) + + # # grad1 + # #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True) + # #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True) + # #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype) + # #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt) + + # ## grad2 + # #C32At, SAt = F.transform2(CAt, 'col32', transpose=True) + # #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True) + # #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype) + # #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t) + + # torch.cuda.synchronize() + # t8 = time.time() - t0 + # print(t8) diff --git a/benchmarking/matmul_benchmark.py b/benchmarking/matmul_benchmark.py new file mode 100644 index 000000000..89b3dfb8a --- /dev/null +++ b/benchmarking/matmul_benchmark.py @@ -0,0 +1,213 @@ +""" +Extracted from tests/test_functional.py + +Usage: pytest benchmarking/matmul_benchmark.py +""" + +import time + +import pytest +import torch + +import bitsandbytes as bnb +from bitsandbytes import functional as F + +k = 20 + +torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000) + + +@pytest.mark.parametrize( + ("batch", "seq", "model", "hidden"), + [ + # pytest.param(1, 128, 6656, 4 * 6656, id="batch=1, seq=128, model=6656, hidden=26k"), + pytest.param(1, 1, 3584, 512, id="batch=1, seq=128, model=3584, hidden=19k"), + # pytest.param(4, 128, 6656, 4 * 6656, id="batch=4, seq=128, model=6656, hidden=26k"), + # pytest.param(16, 256, 6656, 4 * 6656, id="batch=16, seq=256, model=6656, hidden=26k") + ], +) +@pytest.mark.benchmark +def test_bench_matmul(batch, seq, model, hidden): + iters = 1000 + formatB = F.get_special_format_str() + + A = torch.randn(batch, seq, model, device="cuda").half() + B = torch.empty(hidden, model, dtype=torch.float16, device="cuda") + torch.nn.init.xavier_uniform_(B) + + B_fp4, state = F.quantize_fp4(B) + B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True) + + B_nf4, state_nf4 = F.quantize_nf4(B) + B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True) + + linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half() + linear8bit.eval() + + outliers = torch.randint(0, model, size=(5,)).cuda() + A[:, :, outliers] = 8.0 + + linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half() + # linearMixedBit.eval() + + linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() + linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() + bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) + + # warmup + for i in range(iters): + torch.matmul(A, B.t()) + torch.cuda.synchronize() + print("") + + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + torch.matmul(A, B.t()) + torch.cuda.synchronize() + print( + f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s", + ) + + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): + # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state) + # torch.cuda.synchronize() + # print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): + # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c) + # torch.cuda.synchronize() + # print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) + torch.cuda.synchronize() + print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c) + torch.cuda.synchronize() + print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + bnb.matmul(A, B) + torch.cuda.synchronize() + print( + f"B -> CB (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + bnb.matmul(A, B, threshold=6.0) + torch.cuda.synchronize() + print( + f"B -> CB + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) + + CA, SCA, _ = F.int8_vectorwise_quant(A, threshold=0.0) + CB, SCB, _ = F.int8_vectorwise_quant(B) + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) + out32 = F.int8_linear_matmul(CA, CB) + torch.cuda.synchronize() + print( + f"no overhead int8 [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) + + # C32A, SA = F.transform(CA, "col32") + + # CxB, SB = F.transform(CB, to_order=formatB) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + # torch.cuda.synchronize() + # print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) + # C32A, SA = F.transform(CA, "col32") + # CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) + # CxB, SB = F.transform(CB, to_order=formatB) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + # torch.cuda.synchronize() + # print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + # BA, statsB = F.vectorwise_quant(B, dim=1) + # CxB, SB = F.nvidia_transform(CB, to_order=formatB) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): + # A2 = A.view(-1, A.shape[-1]).contiguous() + # CA, statsA = F.vectorwise_quant(A2, dim=1) + # C32A, SA = F.nvidia_transform(CA, "col32") + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) + # F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) + # torch.cuda.synchronize() + # print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + # BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") + # CxB, SB = F.nvidia_transform(CB, to_order=formatB) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): + # A2 = A.view(-1, A.shape[-1]).contiguous() + # CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear") + # C32A, SA = F.nvidia_transform(CA, "col32") + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) + # out = Cout * statsB * statsA * (1.0 / (127 * 127)) + # torch.cuda.synchronize() + # print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + linear8bit(A) + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + linear8bit(A) + torch.cuda.synchronize() + print( + f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) + + linearMixedBit(A) + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + linearMixedBit(A) + torch.cuda.synchronize() + print( + f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) + + # linear8bit_train(A) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): + # linear8bit_train(A) + # torch.cuda.synchronize() + # print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + # linear8bit_train_thresh(A) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): + # linear8bit_train(A) + # torch.cuda.synchronize() + # print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") diff --git a/pytest.ini b/pytest.ini index ac6d72e63..0090e0ca7 100644 --- a/pytest.ini +++ b/pytest.ini @@ -11,3 +11,4 @@ log_file = logs/pytest.log markers = benchmark: mark test as benchmark slow: mark test as slow + deprecated: mark test as covering a deprecated feature diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 3422550f8..ae2529542 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -28,6 +28,7 @@ @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype) @pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad")) @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) +@pytest.mark.deprecated def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool], transpose: Tuple[bool, bool]): if dim2 > 0: dim2 = dim2 - (dim2 % 16) @@ -198,12 +199,8 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool assert (idx == 0).sum().item() < n * 0.02 -# @pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) -# @pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) -# @pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) -# @pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) @pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2")) # [64,0] +@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) @pytest.mark.parametrize("dim4", [48], ids=id_formatter("dim4")) @pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp")) diff --git a/tests/test_functional.py b/tests/test_functional.py index 948ba5e20..27e12953f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -114,54 +114,11 @@ def test_estimate_quantiles(dtype): assert (diff > 5e-02).sum().item() == 0 -def test_quantile_quantization(): - for i in range(100): - A1 = torch.randn(1024, 1024, device="cuda") - code = F.estimate_quantiles(A1) - C = F.quantize_no_absmax(A1, code) - A2 = F.dequantize_no_absmax(C, code) - diff = torch.abs(A1 - A2).mean().item() - assert diff < 0.0075 - - A1 = torch.rand(1024, 1024, device="cuda") - code = F.estimate_quantiles(A1) - C = F.quantize_no_absmax(A1, code) - A2 = F.dequantize_no_absmax(C, code) - diff = torch.abs(A1 - A2).mean().item() - torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0) - assert diff < 0.001 - - -def test_dynamic_quantization(): - diffs = [] - reldiffs = [] - for i in range(100): - A1 = torch.randn(1024, 1024, device="cuda") - C, S = F.quantize(A1) - A2 = F.dequantize(C, S) - diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) - diffs.append(diff.mean().item()) - reldiffs.append(reldiff.mean().item()) - assert diff.mean().item() < 0.0135 - print(sum(diffs) / len(diffs)) - print(sum(reldiffs) / len(reldiffs)) - - for i in range(100): - A1 = torch.rand(1024, 1024, device="cuda") - C, S = F.quantize(A1) - A2 = F.dequantize(C, S) - diff = torch.abs(A1 - A2).mean().item() - torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) - assert diff < 0.004 - - @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): - # print('') diffs = [] reldiffs = [] for i in range(100): @@ -204,33 +161,6 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) -@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"]) -def test_percentile_clipping(gtype): - gnorm_vec1 = torch.zeros(100, device="cuda") - gnorm_vec2 = torch.zeros(100, device="cuda") - n = 4 - step = 0 - percentile = 5 - for i in range(k): - step += 1 - g = torch.randn(n, n, dtype=gtype, device="cuda") - gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile) - assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1 - - gnorm2 = torch.norm(g.float()) - if step == 1: - gnorm_vec1[:] = gnorm2 - else: - gnorm_vec1[step % 100] = gnorm2 - - vals, idx = torch.sort(gnorm_vec1) - clip1 = vals[percentile] - - torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2)) - torch.testing.assert_close(clip1, clip2) - torch.testing.assert_close(gnorm1, gnorm2) - - def quant(x): max1 = torch.abs(x).max() x = torch.round(x / max1 * 127) @@ -495,81 +425,6 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): torch.testing.assert_close(out.float(), out2.float()) -@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3")) -def test_vector_quant(dim1, dim2, dim3): - dim2 = dim2 - (dim2 % 16) - dim3 = dim3 - (dim3 % 16) - for i in range(k): - A = torch.randn(size=(dim2, dim3), device="cuda") - qA, SA = F.vectorwise_quant(A, dim=0) - A1 = F.vectorwise_dequant(qA, SA) - n = A1.numel() - assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002)) - - -@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3")) -@pytest.mark.parametrize("dtype", [torch.int8, torch.int32], ids=describe_dtype) -@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) -@pytest.mark.parametrize("orderOut", ["col", "row", "col32"], ids=id_formatter("orderOut")) -@pytest.mark.parametrize("transpose", [False], ids=id_formatter("transpose")) -@pytest.mark.parametrize("dims", [2, 3], ids=id_formatter("dims")) -def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): - if dims == 3 and orderOut != "col32": - return - if dtype == torch.int32 and orderOut != "col32": - return - try: - func = F.get_transform_func(dtype, orderA, orderOut, transpose) - except ValueError as ve: - pytest.skip(str(ve)) # skip if not supported - - if dims == 2: - A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype) - elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype) - - out, S = F.nvidia_transform(A, to_order=orderOut) - - if orderOut == "row": - torch.testing.assert_close(A.flatten(), out.flatten()) - elif orderOut == "col": - torch.testing.assert_close(A.t().flatten(), out.flatten()) - elif orderOut == "col32": - if dims == 2: - n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32))) - elif dims == 3: - n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32))) - assert out.numel() == n - elif orderOut == "col_turing": - # 32 col 8 row tiles - n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32))) - assert out.numel() == n - total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0) - for row in range(A.shape[0]): - for col in range(A.shape[1]): - i = row * A.shape[1] - j = col - - coltile = (col // 32) + (1 if col % 32 != 0 else 0) - rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile - offset = 32 * 8 * (rowtile + coltile) - col2 = col % 32 - row2 = (row % 8) * 32 - - assert A.flatten()[i + j] == A[row, col] - # assert A.flatten()[i+j] == out.flatten()[row2+col2] - # torch.testing.assert_close(A.flatten()[i+j], A[row, col]) - # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) - - if orderOut == "col32": - out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S) - torch.testing.assert_close(A, out2) - - @pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [512], ids=id_formatter("dim3")) @@ -613,171 +468,12 @@ def test_int8_linear_matmul_half(dim1, dim2, dim3, dim4, dims): torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) -@pytest.mark.parametrize( - ("batch", "seq", "model", "hidden"), - [ - pytest.param(2, 512, 4 * 1024, 3 * 4 * 1024, id="batch=2, seq=512, model=4k, hidden=12k"), - pytest.param(2, 512, 5120, 3 * 5120, id="batch=2, seq=512, model=5k, hidden=15k"), - pytest.param(2, 512, 12 * 1024, 4 * 12 * 1024, id="batch=2, seq=512, model=12k, hidden=48k"), - ], -) -@pytest.mark.benchmark -def test_bench_8bit_training(batch, seq, model, hidden): - formatB = F.get_special_format_str() - A = torch.randn(batch, seq, model, device="cuda").half() - grad = torch.randn(batch, seq, model, device="cuda").half() - w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half() - w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half() - print("") - - # torch.cuda.synchronize() - ## warmup - # for i in range(100): - # torch.matmul(A, w1.t()) - # torch.cuda.synchronize() - - dtype = torch.int8 - A = A.view(-1, A.shape[-1]).contiguous() - grad = grad.view(-1, grad.shape[-1]).contiguous() - torch.cuda.synchronize() - t0 = time.time() - for i in range(k): - out1 = torch.matmul(A, w1.t()) # fc1 - # out2 = torch.matmul(out1, w2.t())# fc2 - - # d1 = torch.matmul(grad, w2) # delta1 - # d2 = torch.matmul(d1, w1) # delta2 - - # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2 - # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1 - - torch.cuda.synchronize() - t16 = time.time() - t0 - print(t16) - - # torch.cuda.empty_cache() - - # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) - # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) - - # CTw1, Sw1 = F.transform2(Cw1, formatB) - # CTw2, Sw2 = F.transform2(Cw2, formatB) - # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) - # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) - - # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) - # C32A, SA = F.transform2(CA, 'col32') - ## fc1 - # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype) - ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t) - - ## fc2 - # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1) - # C32out1, Sout1 = F.transform2(Cout1, 'col32') - # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype) - ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t) - - ## delta1 - # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad) - # C32grad, Sgrad = F.transform2(Cgrad, 'col32') - ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype) - ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2) - - ## delta2 - # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1) - # C32d1, Sd1 = F.transform2(Cd1, 'col32') - ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype) - ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1) - - ## grad1 - # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True) - # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True) - ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype) - ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad) - - ## grad2 - # C32At, SAt = F.transform2(CAt, 'col32', transpose=True) - # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True) - ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype) - ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1) - - # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) - - # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) - # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) - - # CTw1, Sw1 = F.transform2(Cw1, formatB) - # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) - # CTw2, Sw2 = F.transform2(Cw2, formatB) - # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(k): - # #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) - # #CTw1, Sw1 = F.transform2(Cw1, formatB) - # #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) - # #CTw1, Sw1 = F.transform2(Cw1, formatB) - - # #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5) - # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) - # #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) - # #CTw2, Sw2 = F.transform2(Cw2, formatB) - # #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) - - # C32A, SA = F.transform2(CA, 'col32') - - # # fc1 - # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype) - # #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) - - # #print(coo_tensor.nnz) - # #out1sp = F.spmm_coo(coo_tensor, w1.t()) - # #print(w1.t().shape) - # #out1 = out1dn + out1sp - - # # fc2 - # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1) - # C32out1, Sout1 = F.transform2(Cout1, 'col32') - # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype) - # #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2) - - # # delta1 - # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad) - # C32grad, Sgrad = F.transform2(Cgrad, 'col32') - # d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype) - # #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t) - - # # delta2 - # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1) - # C32d1, Sd1 = F.transform2(Cd1, 'col32') - # d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype) - # #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t) - - # # grad1 - # #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True) - # #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True) - # #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype) - # #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt) - - # ## grad2 - # #C32At, SAt = F.transform2(CAt, 'col32', transpose=True) - # #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True) - # #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype) - # #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t) - - # torch.cuda.synchronize() - # t8 = time.time() - t0 - # print(t8) - - -# @pytest.mark.parametrize("dim1", get_test_dims(64, 256, n=2), ids=id_formatter("dim1")) -# @pytest.mark.parametrize("dim4", get_test_dims(64, 1024, n=2), ids=id_formatter("dim4")) @pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) def test_dequant_mm(dim1, dim4, dims, has_bias): - inner = 128 # torch.randint(1, 128, size=(1,)).item() + inner = 128 bias = None if has_bias: bias = torch.randn(dim4, device="cuda", dtype=torch.float16) @@ -853,17 +549,15 @@ def test_colrow_absmax(dim1, dim2, dims, threshold): torch.testing.assert_close(row_stats1, row_stats2) -# @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) -# @pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) @pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2")) -def test_double_quant(dim1, dim2): +def test_int8_double_quant(dim1, dim2): for i in range(k): A = torch.randn(dim1, dim2, device="cuda").half() out_col1, Scol = F.vectorwise_quant(A, dim=0) out_row1, Srow = F.vectorwise_quant(A, dim=1) - CA, CAt, statsA, statsAt, coo_tensor = F.int8_double_quant(A) + CA, CAt, statsA, statsAt, _ = F.int8_double_quant(A) # max difference is 1 due to rounding differences torch.testing.assert_close(CA, out_row1, atol=1, rtol=0) @@ -894,9 +588,6 @@ def test_double_quant(dim1, dim2): (1, 8, 2048, 4096), (2, 128, 2048, 4096), (4, 256, 512, 4096), - # get_test_dims(1, 4 * 1024, n=4), - # get_test_dims(1, 4 * 1024, n=4), - # get_test_dims(1, 4 * 1024, n=4), ) ), ) @@ -1004,59 +695,6 @@ def test_igemmlt_row_scale(dim1, dim4, inner): print(sum(err3) / len(err3)) -@pytest.mark.parametrize( - ("dim1", "dim4", "inner"), - [ - pytest.param(1024, 12288 * 4, 12288, id="1024, 12288*4, 12288"), - pytest.param(2048, 4096 * 4, 4096, id="2048, 4096*4, 4096"), - ], -) -@pytest.mark.skip("Row scale has some bugs for ampere") -@pytest.mark.benchmark -def test_row_scale_bench(dim1, dim4, inner): - formatB = F.get_special_format_str() - err1, err2, err3 = [], [], [] - relerr1, relerr2 = [], [] - scale = 1 - A = torch.randn(dim1, inner, device="cuda").half() - B = torch.randn(dim4, inner, device="cuda").half() - torch.nn.init.xavier_uniform_(B) - # warmpup - for i in range(k): - C1 = torch.matmul(A, B.t()) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(k): - C1 = torch.matmul(A, B.t()) - torch.cuda.synchronize() - print("16", time.time() - t0) - - C1a, C1b, stats1a, stats1b, coo_tensor = F.int8_double_quant(A) - CB, absmaxB = F.vectorwise_quant(B, quant_type="linear") - A2, SA = F.nvidia_transform(C1a, "col32") - B2, SB = F.nvidia_transform(CB, formatB) - A1, maxA = F.vectorwise_quant(A, dim=1) - - c = 10.0 * inner * scale - row_scale = maxA / c - torch.cuda.synchronize() - t0 = time.time() - for i in range(k): - outC32 = F.int8_linear_matmul(A2, B2, dtype=torch.int8, row_scale=row_scale) - torch.cuda.synchronize() - print("row-wise", time.time() - t0) - - C2a, C2b, stats2a, stats2b, coo_tensor = F.int8_double_quant(B) - B2, SB = F.nvidia_transform(C2a, formatB) - torch.cuda.synchronize() - t0 = time.time() - for i in range(k): - outC32 = F.int8_linear_matmul(A2, B2) - torch.cuda.synchronize() - print("vector-wise", time.time() - t0) - - @pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3")) @@ -1065,6 +703,7 @@ def test_row_scale_bench(dim1, dim4, inner): @pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) @pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut")) @pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) +@pytest.mark.deprecated def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): for i in range(k): if dims == 2: @@ -1088,21 +727,6 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): torch.testing.assert_close(out1, out2) -def test_overflow(): - formatB = F.get_special_format_str() - print(formatB) - for i in range(2): - a = torch.arange(0, 16).cuda().to(torch.int8).view(-1, 4).contiguous() - b = torch.arange(0, 16).cuda().to(torch.int8).view(-1, 4).contiguous() - - # Ca, Sa = F.nvidia_transform(a, "col32") - # Cb, Sb = F.nvidia_transform(b, formatB) - - # c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8) - c = F.int8_linear_matmul(a, b, dtype=torch.int8) - c2 = torch.matmul(a.float(), b.float().t()) - - @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_double_quant(dim1, dim2): @@ -1124,8 +748,6 @@ def test_coo_double_quant(dim1, dim2): torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) -# @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) -# @pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_int8_vectorwise_quant(dim1, dim2): @@ -1216,12 +838,9 @@ def test_spmm_bench(): @pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [256, 1024], ids=id_formatter("dim2")) -# @pytest.mark.parametrize("dim1", get_test_dims(256, 1024, n=2), ids=id_formatter("dim1")) -# @pytest.mark.parametrize("dim2", get_test_dims(256, 1024, n=2), ids=id_formatter("dim2")) def test_integrated_sparse_decomp(dim1, dim2): threshold = 3.0 - # formatB = "col_turing" - for i in range(k): + for _ in range(k): A = torch.randn(dim1, dim2).cuda().half() w1 = torch.randn(dim1, dim2).cuda().half() out1 = torch.matmul(A, w1.t()) @@ -1463,202 +1082,6 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): print("partial matmul", time.time() - t0) -@pytest.mark.parametrize( - ("batch", "seq", "model", "hidden"), - [ - # pytest.param(1, 128, 6656, 4 * 6656, id="batch=1, seq=128, model=6656, hidden=26k"), - pytest.param(1, 1, 3584, 512, id="batch=1, seq=128, model=3584, hidden=19k"), - # pytest.param(4, 128, 6656, 4 * 6656, id="batch=4, seq=128, model=6656, hidden=26k"), - # pytest.param(16, 256, 6656, 4 * 6656, id="batch=16, seq=256, model=6656, hidden=26k") - ], -) -@pytest.mark.benchmark -def test_bench_matmul(batch, seq, model, hidden): - iters = 1000 - formatB = F.get_special_format_str() - - A = torch.randn(batch, seq, model, device="cuda").half() - B = torch.empty(hidden, model, dtype=torch.float16, device="cuda") - torch.nn.init.xavier_uniform_(B) - - B_fp4, state = F.quantize_fp4(B) - B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True) - - B_nf4, state_nf4 = F.quantize_nf4(B) - B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True) - - linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half() - linear8bit.eval() - - outliers = torch.randint(0, model, size=(5,)).cuda() - A[:, :, outliers] = 8.0 - - linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half() - # linearMixedBit.eval() - - linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() - linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() - bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) - - # warmup - for i in range(iters): - torch.matmul(A, B.t()) - torch.cuda.synchronize() - print("") - - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - torch.matmul(A, B.t()) - torch.cuda.synchronize() - print( - f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s", - ) - - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): - # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state) - # torch.cuda.synchronize() - # print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): - # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c) - # torch.cuda.synchronize() - # print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) - torch.cuda.synchronize() - print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c) - torch.cuda.synchronize() - print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - bnb.matmul(A, B) - torch.cuda.synchronize() - print( - f"B -> CB (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" - ) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - bnb.matmul(A, B, threshold=6.0) - torch.cuda.synchronize() - print( - f"B -> CB + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" - ) - - CA, SCA, _ = F.int8_vectorwise_quant(A, threshold=0.0) - CB, SCB, _ = F.int8_vectorwise_quant(B) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) - out32 = F.int8_linear_matmul(CA, CB) - torch.cuda.synchronize() - print( - f"no overhead int8 [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" - ) - - # C32A, SA = F.transform(CA, "col32") - - # CxB, SB = F.transform(CB, to_order=formatB) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): - # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - # torch.cuda.synchronize() - # print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) - # C32A, SA = F.transform(CA, "col32") - # CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) - # CxB, SB = F.transform(CB, to_order=formatB) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): - # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - # torch.cuda.synchronize() - # print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - # BA, statsB = F.vectorwise_quant(B, dim=1) - # CxB, SB = F.nvidia_transform(CB, to_order=formatB) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): - # A2 = A.view(-1, A.shape[-1]).contiguous() - # CA, statsA = F.vectorwise_quant(A2, dim=1) - # C32A, SA = F.nvidia_transform(CA, "col32") - # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) - # F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) - # torch.cuda.synchronize() - # print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - # BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") - # CxB, SB = F.nvidia_transform(CB, to_order=formatB) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): - # A2 = A.view(-1, A.shape[-1]).contiguous() - # CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear") - # C32A, SA = F.nvidia_transform(CA, "col32") - # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) - # out = Cout * statsB * statsA * (1.0 / (127 * 127)) - # torch.cuda.synchronize() - # print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - linear8bit(A) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - linear8bit(A) - torch.cuda.synchronize() - print( - f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" - ) - - linearMixedBit(A) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - linearMixedBit(A) - torch.cuda.synchronize() - print( - f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" - ) - - # linear8bit_train(A) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): - # linear8bit_train(A) - # torch.cuda.synchronize() - # print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - # linear8bit_train_thresh(A) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): - # linear8bit_train(A) - # torch.cuda.synchronize() - # print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - def test_zeropoint(): def quant_zp(x): dtype = x.dtype @@ -1745,6 +1168,7 @@ def quant_zp(x): print(err1, err2, err3, err4, err5, err6) +@pytest.mark.deprecated def test_extract_outliers(): for i in range(k): shapeA = (4096, 4096 * 4) @@ -2245,23 +1669,6 @@ def test_managed(): assert (A == 17 * (2**3)).sum().item() == n * n -# F.prefetch_tensor(A) -# F.prefetch_tensor(B) - - -# F.fill(B2, 17.0) -# F._mul(A, B2) - -# F.prefetch_tensor(A, to_cpu=True) -# F.prefetch_tensor(B, to_cpu=True) -# F.prefetch_tensor(B2, to_cpu=True) -# torch.cuda.synchronize() - -# assert (A==17).sum().item() == n*n - -# torch.testing.assert_close(A, torch.ones(A.shape)*289) - - @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) @@ -2286,3 +1693,152 @@ def test_gemv_eye_4bit(storage_type, dtype, double_quant): torch.testing.assert_close(A, C2) # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001) # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080) + + +@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3")) +@pytest.mark.deprecated +def test_vector_quant(dim1, dim2, dim3): + dim2 = dim2 - (dim2 % 16) + dim3 = dim3 - (dim3 % 16) + for i in range(k): + A = torch.randn(size=(dim2, dim3), device="cuda") + qA, SA = F.vectorwise_quant(A, dim=0) + A1 = F.vectorwise_dequant(qA, SA) + n = A1.numel() + assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002)) + + +@pytest.mark.deprecated +def test_quantile_quantization(): + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + code = F.estimate_quantiles(A1) + C = F.quantize_no_absmax(A1, code) + A2 = F.dequantize_no_absmax(C, code) + diff = torch.abs(A1 - A2).mean().item() + assert diff < 0.0075 + + A1 = torch.rand(1024, 1024, device="cuda") + code = F.estimate_quantiles(A1) + C = F.quantize_no_absmax(A1, code) + A2 = F.dequantize_no_absmax(C, code) + diff = torch.abs(A1 - A2).mean().item() + torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0) + assert diff < 0.001 + + +@pytest.mark.deprecated +def test_dynamic_quantization(): + diffs = [] + reldiffs = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + C, S = F.quantize(A1) + A2 = F.dequantize(C, S) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + assert diff.mean().item() < 0.0135 + print(sum(diffs) / len(diffs)) + print(sum(reldiffs) / len(reldiffs)) + + for i in range(100): + A1 = torch.rand(1024, 1024, device="cuda") + C, S = F.quantize(A1) + A2 = F.dequantize(C, S) + diff = torch.abs(A1 - A2).mean().item() + torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) + assert diff < 0.004 + + +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"]) +@pytest.mark.deprecated +def test_percentile_clipping(gtype): + gnorm_vec1 = torch.zeros(100, device="cuda") + gnorm_vec2 = torch.zeros(100, device="cuda") + n = 4 + step = 0 + percentile = 5 + for i in range(k): + step += 1 + g = torch.randn(n, n, dtype=gtype, device="cuda") + gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile) + assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1 + + gnorm2 = torch.norm(g.float()) + if step == 1: + gnorm_vec1[:] = gnorm2 + else: + gnorm_vec1[step % 100] = gnorm2 + + vals, idx = torch.sort(gnorm_vec1) + clip1 = vals[percentile] + + torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2)) + torch.testing.assert_close(clip1, clip2) + torch.testing.assert_close(gnorm1, gnorm2) + + +@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dtype", [torch.int8, torch.int32], ids=describe_dtype) +@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) +@pytest.mark.parametrize("orderOut", ["col", "row", "col32"], ids=id_formatter("orderOut")) +@pytest.mark.parametrize("transpose", [False], ids=id_formatter("transpose")) +@pytest.mark.parametrize("dims", [2, 3], ids=id_formatter("dims")) +@pytest.mark.deprecated +def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): + if dims == 3 and orderOut != "col32": + return + if dtype == torch.int32 and orderOut != "col32": + return + try: + func = F.get_transform_func(dtype, orderA, orderOut, transpose) + except ValueError as ve: + pytest.skip(str(ve)) # skip if not supported + + if dims == 2: + A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype) + elif dims == 3: + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype) + + out, S = F.nvidia_transform(A, to_order=orderOut) + + if orderOut == "row": + torch.testing.assert_close(A.flatten(), out.flatten()) + elif orderOut == "col": + torch.testing.assert_close(A.t().flatten(), out.flatten()) + elif orderOut == "col32": + if dims == 2: + n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32))) + elif dims == 3: + n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32))) + assert out.numel() == n + elif orderOut == "col_turing": + # 32 col 8 row tiles + n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32))) + assert out.numel() == n + total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0) + for row in range(A.shape[0]): + for col in range(A.shape[1]): + i = row * A.shape[1] + j = col + + coltile = (col // 32) + (1 if col % 32 != 0 else 0) + rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile + offset = 32 * 8 * (rowtile + coltile) + col2 = col % 32 + row2 = (row % 8) * 32 + + assert A.flatten()[i + j] == A[row, col] + # assert A.flatten()[i+j] == out.flatten()[row2+col2] + # torch.testing.assert_close(A.flatten()[i+j], A[row, col]) + # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) + + if orderOut == "col32": + out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S) + torch.testing.assert_close(A, out2) diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 51e273897..bc9e2600f 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -82,7 +82,6 @@ def test_linear_no_igemmlt(): @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward")) @pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda")) -# @pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt")) @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward")) @pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda")) def test_linear_serialization( @@ -104,8 +103,6 @@ def test_linear_serialization( has_fp16_weights=has_fp16_weights, threshold=6.0, ) - # if force_no_igemmlt: - # linear_custom.state.force_no_igemmlt = True linear_custom.weight = bnb.nn.Int8Params( linear.weight.data.clone(), @@ -151,8 +148,6 @@ def test_linear_serialization( has_fp16_weights=has_fp16_weights, threshold=6.0, ) - # if force_no_igemmlt: - # new_linear_custom.state.force_no_igemmlt = True if deserialize_before_cuda: with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError): diff --git a/tests/test_modules.py b/tests/test_modules.py index 239c7d3a6..278add87f 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -566,7 +566,6 @@ def test_kbit_backprop(module): relerrs2.append(relerr2.mean().item()) if isinstance(module, bnb.nn.Linear8bitLt): - # if module == bnb.nn.Linear8bitLt: assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1) torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05) else: @@ -577,10 +576,6 @@ def test_kbit_backprop(module): assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0 assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0 - # print('out', sum(errs1)/len(errs1)) - # print('grad', sum(errs2)/len(errs2)) - # print('rel out', sum(relerrs1)/len(relerrs1)) - # print('rel grad', sum(relerrs2)/len(relerrs2)) def test_fp8linear(): From df941ec723bcfe92eaa9586612c6a47bbf428d13 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 25 Nov 2024 10:32:46 -0500 Subject: [PATCH 47/65] Add int8 dequant function; misc improvements --- benchmarking/int8/training_benchmark.py | 2 + bitsandbytes/autograd/_functions.py | 9 ++-- bitsandbytes/functional.py | 19 +++++++- csrc/kernels.cu | 3 +- tests/test_functional.py | 64 ++++++++++++------------- 5 files changed, 57 insertions(+), 40 deletions(-) diff --git a/benchmarking/int8/training_benchmark.py b/benchmarking/int8/training_benchmark.py index 32060afde..e9641235f 100644 --- a/benchmarking/int8/training_benchmark.py +++ b/benchmarking/int8/training_benchmark.py @@ -13,6 +13,8 @@ k = 20 +torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000) + @pytest.mark.parametrize( ("batch", "seq", "model", "hidden"), diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index ad7624c38..e524e0203 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -350,18 +350,17 @@ def forward( CAt[:, state.idx] = 0 # Extract the input outliers in original precision - subA = A[:, state.idx] + subA = A[:, state.idx].contiguous() # Extract the corresponding weights if state.has_fp16_weights: state.subB = B[:, state.idx].t() else: - outliers = state.CB[:, state.idx] - # To dequantize our weights associated with the input outliers, # we want to divide by 127. It's however more performant to multiply # by the reciprocal. - state.subB = (7.874016e-3 * outliers * state.SCB.view(-1, 1)).t().to(A.dtype) + outliers = state.CB[:, state.idx] + state.subB = (outliers.t() * state.SCB * 7.874015718698502e-3).to(A.dtype) else: subA = None @@ -378,7 +377,7 @@ def forward( # 4. Mixed-precision decomposition matmul if subA is not None and state.subB is not None: - output += torch.matmul(subA, state.subB) + output = output.addmm(subA, state.subB) # 5. Save state ctx.state = state diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 7d7547130..d79c39c41 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2722,6 +2722,20 @@ def int8_double_quant( return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols +def int8_vectorwise_dequant(A: torch.Tensor, stats: torch.Tensor): + """Dequantizes a tensor with dtype `torch.int8` to `torch.float32`. + + Args: + A (`torch.Tensor` with dtype `torch.int8`): The quantized int8 tensor. + stats (`torch.Tensor` with dtype `torch.float32`): The row-wise quantization statistics. + + Returns: + `torch.Tensor` with dtype `torch.float32`: The dequantized tensor. + """ + # To dequantize we divide by 127, or multiply by the reciprocal. + return A * stats.view(-1, 1) * 7.874015718698502e-3 + + def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): """Quantizes a tensor with dtype `torch.float16` to `torch.int8` in accordance to the `LLM.int8()` algorithm. @@ -3026,7 +3040,10 @@ def vectorwise_quant(x, dim=1, quant_type="vector"): return None -@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) +@deprecated( + "This function is deprecated and will be removed in a future release. Consider using `int8_vectorwise_dequant` instead.", + category=FutureWarning, +) def vectorwise_dequant(xq, max1, quant_type="vector"): if quant_type == "vector": x = (xq / C * max1).to(torch.float32) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index b92bbc6ea..4056ffbcf 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2159,7 +2159,7 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat // Threads will read the row values in a striped access pattern and find a local absmax. float row_local_absmax = -FLT_MIN; for (int i = threadIdx.x; i < cols; i += THREADS) { - const float absval = fabsf(__ldg(&(row_data[i]))); + const float absval = fabsf(__ldcs(&(row_data[i]))); // For sparse decomposition, values outside of the threshold are not to be // included when calculating the row's absmax. @@ -2171,7 +2171,6 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat } // Reduce thread-local absmax across the block. - // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols); if (threadIdx.x == 0) { // Save our block's absmax to shared memory for the quantization step. diff --git a/tests/test_functional.py b/tests/test_functional.py index 27e12953f..3adeb1a96 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -695,38 +695,6 @@ def test_igemmlt_row_scale(dim1, dim4, inner): print(sum(err3) / len(err3)) -@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3")) -@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims")) -@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype) -@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) -@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut")) -@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) -@pytest.mark.deprecated -def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): - for i in range(k): - if dims == 2: - A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype) - elif dims == 3: - A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype) - - A.view(-1)[-1] = -1 - if transpose: - At = A.t().contiguous() - out1, S1 = F.nvidia_transform(At, to_order=orderOut) - else: - out1, S1 = F.nvidia_transform(A, to_order=orderOut) - out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose) - - assert S1[0][0] == S2[0][0] - assert S1[0][1] == S2[0][1] - # print(out1) - # print(out2) - - torch.testing.assert_close(out1, out2) - - @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_double_quant(dim1, dim2): @@ -1782,6 +1750,38 @@ def test_percentile_clipping(gtype): torch.testing.assert_close(gnorm1, gnorm2) +@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3")) +@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims")) +@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype) +@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) +@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut")) +@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) +@pytest.mark.deprecated +def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): + for i in range(k): + if dims == 2: + A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype) + elif dims == 3: + A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype) + + A.view(-1)[-1] = -1 + if transpose: + At = A.t().contiguous() + out1, S1 = F.nvidia_transform(At, to_order=orderOut) + else: + out1, S1 = F.nvidia_transform(A, to_order=orderOut) + out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose) + + assert S1[0][0] == S2[0][0] + assert S1[0][1] == S2[0][1] + # print(out1) + # print(out2) + + torch.testing.assert_close(out1, out2) + + @pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3")) From 73f02e864838f0f52dfe763d26240aa4a7e73e65 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 25 Nov 2024 12:33:42 -0500 Subject: [PATCH 48/65] int8 matmul fallback for inner dims not divisible by 4 --- bitsandbytes/functional.py | 15 ++++++++++++--- tests/test_functional.py | 2 +- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index d79c39c41..11874b739 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2314,9 +2314,6 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten shapeC = (*shapeB[:-1], shapeA[0]) - if out is None: - out = torch.empty(shapeC, device=A.device, dtype=dtype) - k, m = shapeA n = prod(shapeB[:-1]) lda = shapeA[-1] # Weights (outputs, inputs) @@ -2327,6 +2324,18 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten lda == ldb ), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}" + # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. + # We'll fall back to a slower fp32 calculation in this circumstance. + # Fortunately, this should not be very common. + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + if out is not None: + result = out.copy_(result) + return result + + if out is None: + out = torch.empty(shapeC, device=A.device, dtype=dtype) + is_on_gpu([A, B, out]) with _cuda_device_of(A): diff --git a/tests/test_functional.py b/tests/test_functional.py index 3adeb1a96..20375a02e 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -427,7 +427,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): @pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", [512], ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3")) @pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) @pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) From ebb67970018628a022d8de273b3f8df0c07a92a7 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 27 Nov 2024 09:07:18 -0500 Subject: [PATCH 49/65] improve register usage of kInt8VectorQuant - especially for A100/H100 --- csrc/kernels.cu | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 4056ffbcf..453dcd7cd 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -5,6 +5,7 @@ #include "kernels.cuh" #include "common.cuh" +#include #include #include #include @@ -2141,7 +2142,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char template __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) { - using BlockReduceT = cub::BlockReduce; + using BlockReduceT = cub::BlockReduce; // One block per row. // Threads load column values in a striped arrangement. @@ -2151,27 +2152,27 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat // We then do a blockwise reduction to determine the row's absmax. __shared__ typename BlockReduceT::TempStorage temp_storage; - __shared__ float smem_row_absmax; + __shared__ T smem_row_absmax; const int row_id = blockIdx.x; - const T* __restrict__ row_data = A + (row_id * cols); + const T* row_data = A + (row_id * cols); // Threads will read the row values in a striped access pattern and find a local absmax. - float row_local_absmax = -FLT_MIN; + T row_local_absmax = -FLT_MIN; for (int i = threadIdx.x; i < cols; i += THREADS) { - const float absval = fabsf(__ldcs(&(row_data[i]))); + const T absval = fabsf(__ldcs(&(row_data[i]))); // For sparse decomposition, values outside of the threshold are not to be // included when calculating the row's absmax. if constexpr (SPARSE_DECOMP) { - row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax); + row_local_absmax = fmaxf(row_local_absmax, absval < T(threshold) ? absval : row_local_absmax); } else { row_local_absmax = fmaxf(row_local_absmax, absval); } } // Reduce thread-local absmax across the block. - const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols); + const T row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols); if (threadIdx.x == 0) { // Save our block's absmax to shared memory for the quantization step. rowStats[row_id] = smem_row_absmax = row_absmax; @@ -2181,13 +2182,14 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat // Quantize row-wise. const float scale = __fdividef(127.0f, smem_row_absmax); for (int i = threadIdx.x; i < cols; i += THREADS) { + float val = row_data[i]; + if constexpr (SPARSE_DECOMP) { // For sparse decomposition, we do not want to quantize the outliers. // Instead they're zeroed out. - float val = row_data[i]; out[row_id * cols + i] = fabs(val) < threshold ? __float2int_rn(val * scale) : 0; } else { - out[row_id * cols + i] = __float2int_rn(float(row_data[i]) * scale); + out[row_id * cols + i] = __float2int_rn(val * scale); } } } From 196c8e00718a4523ae1bf20ee893c21b015c9b54 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 27 Nov 2024 09:44:00 -0500 Subject: [PATCH 50/65] disable fail-fast for package build --- .github/workflows/python-package.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 560741edb..4ede1c9f2 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -60,6 +60,7 @@ jobs: ## build-shared-libs-cuda: strategy: + fail-fast: false matrix: os: [ubuntu-latest, windows-latest] arch: [x86_64, aarch64] From fa6f59720c69946f277b312392b4411b99e37aa1 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 27 Nov 2024 10:55:26 -0500 Subject: [PATCH 51/65] maxwell compat --- csrc/kernels.cu | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 453dcd7cd..a655a78b2 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2142,7 +2142,16 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char template __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) { - using BlockReduceT = cub::BlockReduce; + + // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32. + // Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped. +#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR >= 2) || BNB_FP16_AVAILABLE && __CUDACC__ + using TReduction = T; +#else + using TReduction = float; +#endif + + using BlockReduceT = cub::BlockReduce; // One block per row. // Threads load column values in a striped arrangement. @@ -2152,27 +2161,27 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat // We then do a blockwise reduction to determine the row's absmax. __shared__ typename BlockReduceT::TempStorage temp_storage; - __shared__ T smem_row_absmax; + __shared__ TReduction smem_row_absmax; const int row_id = blockIdx.x; const T* row_data = A + (row_id * cols); // Threads will read the row values in a striped access pattern and find a local absmax. - T row_local_absmax = -FLT_MIN; + TReduction row_local_absmax = -FLT_MIN; for (int i = threadIdx.x; i < cols; i += THREADS) { - const T absval = fabsf(__ldcs(&(row_data[i]))); + const TReduction absval = fabsf(__ldcs(&(row_data[i]))); // For sparse decomposition, values outside of the threshold are not to be // included when calculating the row's absmax. if constexpr (SPARSE_DECOMP) { - row_local_absmax = fmaxf(row_local_absmax, absval < T(threshold) ? absval : row_local_absmax); + row_local_absmax = fmaxf(row_local_absmax, absval < TReduction(threshold) ? absval : row_local_absmax); } else { row_local_absmax = fmaxf(row_local_absmax, absval); } } // Reduce thread-local absmax across the block. - const T row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols); + const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols); if (threadIdx.x == 0) { // Save our block's absmax to shared memory for the quantization step. rowStats[row_id] = smem_row_absmax = row_absmax; From 498d8de59a911db0570c5a654a884bc76dbf4efb Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 27 Nov 2024 12:00:10 -0500 Subject: [PATCH 52/65] ptxas verbose --- .github/scripts/build-cuda.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/scripts/build-cuda.sh b/.github/scripts/build-cuda.sh index 26a7075b0..4f616a7c9 100644 --- a/.github/scripts/build-cuda.sh +++ b/.github/scripts/build-cuda.sh @@ -15,7 +15,7 @@ if [ "${build_os:0:6}" == ubuntu ]; then docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \ "apt-get update \ && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \ - && cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" . \ + && cmake -DPTXAS_VERBOSE=1 -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" . \ && cmake --build ." else pip install cmake==3.28.3 From a2ee1c44c901e063e4e4e6822f0d46aeff1918a2 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Fri, 29 Nov 2024 12:55:47 -0500 Subject: [PATCH 53/65] docs update --- docs/source/algorithms.mdx | 2 +- docs/source/installation.mdx | 31 ++++++++++++++++--------------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/docs/source/algorithms.mdx b/docs/source/algorithms.mdx index d9db5cb04..65e5567a4 100644 --- a/docs/source/algorithms.mdx +++ b/docs/source/algorithms.mdx @@ -5,7 +5,7 @@ This is an overview of the `bnb.functional` API in `bitsandbytes` that we think ## Using Int8 Matrix Multiplication -For straight Int8 matrix multiplication with mixed precision decomposition you can use ``bnb.matmul(...)``. To enable mixed precision decomposition, use the threshold parameter: +For straight Int8 matrix multiplication without mixed precision decomposition you can use ``bnb.matmul(...)``. To enable mixed precision decomposition, use the threshold parameter: ```py bnb.matmul(..., threshold=6.0) diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index d1acb2cd6..d846c35a5 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -19,29 +19,30 @@ Welcome to the installation guide for the `bitsandbytes` library! This document ## CUDA[[cuda]] -`bitsandbytes` is currently only supported on CUDA GPUs for CUDA versions **11.0 - 12.5**. However, there's an ongoing multi-backend effort under development, which is currently in alpha. If you're interested in providing feedback or testing, check out [the multi-backend section below](#multi-backend). +`bitsandbytes` is currently only supported on CUDA GPUs for CUDA versions **11.0 - 12.6**. However, there's an ongoing multi-backend effort under development, which is currently in alpha. If you're interested in providing feedback or testing, check out [the multi-backend section below](#multi-backend). ### Supported CUDA Configurations[[cuda-pip]] -The latest version of `bitsandbytes` builds on the following configurations: +The latest version of the distributed `bitsandbytes` package is built with the following configurations: -| **OS** | **CUDA Version** | **Compiler** | +| **OS** | **CUDA Toolkit** | **Host Compiler** | |-------------|------------------|----------------------| | **Linux** | 11.7 - 12.3 | GCC 11.4 | -| | 12.4+ | GCC 13.2 | -| **Windows** | 11.7 - 12.4 | MSVC 19.38+ (VS2022) | +| | 12.4 - 12.6 | GCC 13.2 | +| **Windows** | 11.7 - 12.6 | MSVC 19.42+ (VS2022) | -For Linux systems, ensure your hardware meets the following requirements: +For CUDA systems, ensure your hardware meets the following requirements: -| **Feature** | **Hardware Requirement** | -|---------------------------------|--------------------------------------------------------------------| -| LLM.int8() | NVIDIA Turing (RTX 20 series, T4) or Ampere (RTX 30 series, A4-A100) GPUs | -| 8-bit optimizers/quantization | NVIDIA Kepler (GTX 780 or newer) | +| **Feature** | **Minimum Hardware Requirement** | +|---------------------------------|---------------------------------------------------------------| +| LLM.int8() | NVIDIA Turing (RTX 20 series, T4) or newer GPUs | +| 8-bit optimizers/quantization | NVIDIA Maxwell (GTX 900 series, TITAN X, M40) or newer GPUs * | +| NF4/FP4 quantization | NVIDIA Maxwell (GTX 900 series, TITAN X, M40) or newer GPUs * | > [!WARNING] -> `bitsandbytes >= 0.39.1` no longer includes Kepler binaries in pip installations. This requires [manual compilation using](#cuda-compile) the `cuda11x_nomatmul_kepler` configuration. - -To install from PyPI. +> `bitsandbytes >= 0.45.0` no longer supports Kepler GPUs. +> +> Support for Maxwell GPUs is deprecated and will be removed in a future release. For the best results, a Turing generation device or newer is recommended. ```bash pip install bitsandbytes @@ -79,7 +80,7 @@ For Linux and Windows systems, compiling from source allows you to customize the -To compile from source, you need CMake >= **3.22.1** and Python >= **3.8** installed. Make sure you have a compiler installed to compile C++ (`gcc`, `make`, headers, etc.). +To compile from source, you need CMake >= **3.22.1** and Python >= **3.9** installed. Make sure you have a compiler installed to compile C++ (`gcc`, `make`, headers, etc.). For example, to install a compiler and CMake on Ubuntu: @@ -115,7 +116,7 @@ pip install -e . # `-e` for "editable" install, when developing BNB (otherwise Windows systems require Visual Studio with C++ support as well as an installation of the CUDA SDK. -To compile from source, you need CMake >= **3.22.1** and Python >= **3.8** installed. You should also install CUDA Toolkit by following the [CUDA Installation Guide for Windows](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) guide from NVIDIA. +To compile from source, you need CMake >= **3.22.1** and Python >= **3.9** installed. You should also install CUDA Toolkit by following the [CUDA Installation Guide for Windows](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) guide from NVIDIA. Refer to the following table if you're using another CUDA Toolkit version. From 15f1661559da5568562fcaa1b3c2ef534dc3f2d1 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 2 Dec 2024 11:13:19 -0500 Subject: [PATCH 54/65] doc update --- docs/source/index.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 5943e7d1d..064420cf7 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -3,7 +3,7 @@ bitsandbytes enables accessible large language models via k-bit quantization for PyTorch. bitsandbytes provides three main features for dramatically reducing memory consumption for inference and training: * 8-bit optimizers uses block-wise quantization to maintain 32-bit performance at a small fraction of the memory cost. -* LLM.Int() or 8-bit quantization enables large language model inference with only half the required memory and without any performance degradation. This method is based on vector-wise quantization to quantize most features to 8-bits and separately treating outliers with 16-bit matrix multiplication. +* LLM.int8() or 8-bit quantization enables large language model inference with only half the required memory and without any performance degradation. This method is based on vector-wise quantization to quantize most features to 8-bits and separately treating outliers with 16-bit matrix multiplication. * QLoRA or 4-bit quantization enables large language model training with several memory-saving techniques that don't compromise performance. This method quantizes a model to 4-bits and inserts a small set of trainable low-rank adaptation (LoRA) weights to allow training. # License From 5d536c6987dab23ff85acf526025187bc06c6c0f Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 2 Dec 2024 16:46:15 -0500 Subject: [PATCH 55/65] backward fix --- bitsandbytes/autograd/_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index e524e0203..a1c9e9b28 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -431,7 +431,7 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor if req_gradA: if state.CB is not None: CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) - grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) + grad_A = torch.matmul(grad_output.to(ctx.dtype_A), CB).view(ctx.grad_shape) else: raise Exception("State must contain CB matrix for backward") From 5b2348bf2d6880933d5b5cba4f4c22c823b796d1 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 2 Dec 2024 23:10:09 -0500 Subject: [PATCH 56/65] Bugfix sparse decomp --- bitsandbytes/functional.py | 5 +++++ csrc/kernels.cu | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 11874b739..644aaf864 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2793,6 +2793,11 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): _get_tensor_stream(A), ) + # Zero out values from outlier columns across all rows. + # The kernel will handle this for outliers themselves, so we can optimize for rows=1. + if rows > 1 and outlier_cols is not None: + out_row[:, outlier_cols] = 0 + return out_row, row_stats, outlier_cols diff --git a/csrc/kernels.cu b/csrc/kernels.cu index a655a78b2..6cd330079 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2145,7 +2145,7 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32. // Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped. -#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR >= 2) || BNB_FP16_AVAILABLE && __CUDACC__ +#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR >= 2) || BNB_FP16_AVAILABLE using TReduction = T; #else using TReduction = float; From bbb7063278af6c1198c0f0ab164f2537ebc5db90 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 3 Dec 2024 09:43:18 -0500 Subject: [PATCH 57/65] Int8 fix for PEFT OLoRA init --- bitsandbytes/autograd/_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index a1c9e9b28..f66cdf68d 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -328,13 +328,13 @@ def forward( has_grad = False - if state.has_fp16_weights: + if state.has_fp16_weights or state.CB is None: has_grad = getattr(B, "grad", None) is not None is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) if is_transposed: B = B.contiguous() - if (state.is_training and not has_grad) or state.SCB is None: + if (state.is_training and not has_grad) or state.CB is None or state.SCB is None: state.reset_grads() # 2. Quantize B From d25ebb44ac58806f2e14450b30f354fc2bc841ec Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 3 Dec 2024 11:35:20 -0500 Subject: [PATCH 58/65] Fix test for deprecated spmm_coo --- bitsandbytes/functional.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 644aaf864..674fee142 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2647,7 +2647,7 @@ def double_quant( """ coo_tensor = None - quant_row, quant_col, row_stats, col_stats, _ = int8_double_quant( + quant_row, quant_col, row_stats, col_stats, outlier_cols = int8_double_quant( A, col_stats, row_stats, @@ -2657,16 +2657,15 @@ def double_quant( ) if threshold > 0.0: - # Build COO tensor for any outliers. - outlier_mask = A.abs() >= threshold - outlier_locations = outlier_mask.nonzero() - outliers = A[outlier_mask] + # Build a COO tensor including all of the outlier columns. + outlier_rows = torch.arange(0, A.shape[0], device=A.device, dtype=torch.int32) + outliers = A[:, outlier_cols] coo_tensor = COOSparseTensor( A.shape[0], A.shape[1], outliers.numel(), - outlier_locations[:, 0].int(), - outlier_locations[:, 1].int(), + outlier_rows.repeat_interleave(outliers.size(1)), + outlier_cols.repeat(outliers.size(0)).int(), outliers, ) From 3d595f138e467f6b9ce935dd58a1fa04ca93531a Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 3 Dec 2024 16:55:25 -0500 Subject: [PATCH 59/65] test improvement --- bitsandbytes/functional.py | 2 +- tests/test_functional.py | 12 ++++++------ tests/test_modules.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 674fee142..d1c5d1d2e 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2656,7 +2656,7 @@ def double_quant( threshold=threshold, ) - if threshold > 0.0: + if threshold > 0.0 and outlier_cols is not None: # Build a COO tensor including all of the outlier columns. outlier_rows = torch.arange(0, A.shape[0], device=A.device, dtype=torch.int32) outliers = A[:, outlier_cols] diff --git a/tests/test_functional.py b/tests/test_functional.py index 20375a02e..c8ac20896 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -703,17 +703,16 @@ def test_coo_double_quant(dim1, dim2): A = torch.randn(dim1, dim2, device="cuda").half() idx = torch.abs(A) >= threshold - CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold) + CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold) - if coo_tensor is not None: + if outlier_cols is not None: A1 = A * idx - A2 = torch.zeros_like(A) - A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values + A2 = torch.zeros_like(A) + A1 torch.testing.assert_close(A1, A2) - A1 = A * (idx == 0) + A[:, outlier_cols] = 0 A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() - torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) + torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @@ -728,6 +727,7 @@ def test_coo_int8_vectorwise_quant(dim1, dim2): if outlier_cols is not None: A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() + A[:, outlier_cols] = 0 torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) diff --git a/tests/test_modules.py b/tests/test_modules.py index 278add87f..c2583550d 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -349,8 +349,8 @@ def test_linear8bitlt_accumulated_gradient(): l1[0].bias.data.copy_(l2[0].bias.data) l1[1].bias.data.copy_(l2[1].bias.data) else: - torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, rtol=1.05, atol=0.04) - torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.00, atol=0.02) + assert_all_approx_close(l1[0].weight.grad, l2[0].weight.grad, rtol=1.05, atol=0.04, count=1) + assert_all_approx_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.05, atol=0.04, count=1) @pytest.mark.parametrize("threshold", [0.0, 2.0]) From 03fcabd6078b899ea6ce69d874dbda62cd888490 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 3 Dec 2024 22:15:17 -0500 Subject: [PATCH 60/65] doc update --- docs/source/_toctree.yml | 2 ++ docs/source/explanations/resources.mdx | 2 +- docs/source/reference/functional.mdx | 48 +++++++++++++++++++++++++ docs/source/reference/nn/linear8bit.mdx | 5 +-- 4 files changed, 54 insertions(+), 3 deletions(-) create mode 100644 docs/source/reference/functional.mdx diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 77ea3ceff..629c6d0f8 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -32,6 +32,8 @@ title: Papers, resources & how to cite - title: API reference sections: + - title: Functional + local: reference/functional - title: Optimizers sections: - local: reference/optim/optim_overview diff --git a/docs/source/explanations/resources.mdx b/docs/source/explanations/resources.mdx index 56330175a..92bbdf947 100644 --- a/docs/source/explanations/resources.mdx +++ b/docs/source/explanations/resources.mdx @@ -49,7 +49,7 @@ Authors: Tim Dettmers, Luke Zettlemoyer } ``` -## [LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale (Nov 2022)](https://arxiv.org/abs/2208.07339) +## [LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale (Nov 2022)](https://arxiv.org/abs/2208.07339) [[llm-int8]] Authors: Tim Dettmers, Mike Lewis, Younes Belkada, Luke Zettlemoyer - [LLM.int8() Blog Post](https://huggingface.co/blog/hf-bitsandbytes-integration) diff --git a/docs/source/reference/functional.mdx b/docs/source/reference/functional.mdx new file mode 100644 index 000000000..e9e800779 --- /dev/null +++ b/docs/source/reference/functional.mdx @@ -0,0 +1,48 @@ +# Overview +The `bitsandbytes.functional` API provides the low-level building blocks for the library's features. + +## When to Use `bitsandbytes.functional` + +* When you need direct control over quantized operations and their parameters. +* To build custom layers or operations leveraging low-bit arithmetic. +* To integrate with other ecosystem tooling. +* For experimental or research purposes requiring non-standard quantization or performance optimizations. + +## LLM.int8() +[[autodoc]] functional.int8_double_quant + +[[autodoc]] functional.int8_linear_matmul + +[[autodoc]] functional.int8_mm_dequant + +[[autodoc]] functional.int8_vectorwise_deqant + +[[autodoc]] functional.int8_vectorwise_quant + + +## 4-bit +[[autodoc]] functional.dequantize_4bit + +[[autodoc]] functional.dequantize_fp4 + +[[autodoc]] functional.dequantize_nf4 + +[[autodoc]] functional.gemv_4bit + +[[autodoc]] functional.quantize_4bit + +[[autodoc]] functional.quantize_fp4 + +[[autodoc]] functional.quantize_nf4 + +[[autodoc]] functional.QuantState + +## General Quantization +[[autodoc]] functional.dequantize_blockwise + +[[autodoc]] functional.quantize_blockwise + +## Utility +[[autodoc]] functional.get_ptr + +[[autodoc]] functional.is_on_gpu diff --git a/docs/source/reference/nn/linear8bit.mdx b/docs/source/reference/nn/linear8bit.mdx index 73254fe67..d1cfd67d5 100644 --- a/docs/source/reference/nn/linear8bit.mdx +++ b/docs/source/reference/nn/linear8bit.mdx @@ -1,6 +1,7 @@ -# 8-bit quantization +# LLM.int8() +[LLM.int8()](https://hf.co/papers/2208.07339) is a quantization method that aims to make large language model inference more accessible without significant degradation. Unlike naive 8-bit quantization, which can result in loss of critical information and accuracy, LLM.int8() dynamically adapts to ensure sensitive components of the computation retain higher precision when needed. The key is to extract the outliers from the inputs and weights and multiply them in 16-bit. All other values are multiplied in 8-bit before being dequantized back to 16-bits. The outputs from the 16-bit and 8-bit multiplication are combined to produce the final output. -[LLM.int8()](https://hf.co/papers/2208.07339) is a quantization method that doesn't degrade performance which makes large model inference more accessible. The key is to extract the outliers from the inputs and weights and multiply them in 16-bit. All other values are multiplied in 8-bit and quantized to Int8 before being dequantized back to 16-bits. The outputs from the 16-bit and 8-bit multiplication are combined to produce the final output. +[Further Resources](../../explanations/resources#llm-int8) ## Linear8bitLt From 582bf229fcfea75d7dafe3c72da30e5ed891d557 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 3 Dec 2024 22:21:45 -0500 Subject: [PATCH 61/65] typo --- docs/source/reference/functional.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/reference/functional.mdx b/docs/source/reference/functional.mdx index e9e800779..0384ccfc8 100644 --- a/docs/source/reference/functional.mdx +++ b/docs/source/reference/functional.mdx @@ -15,7 +15,7 @@ The `bitsandbytes.functional` API provides the low-level building blocks for the [[autodoc]] functional.int8_mm_dequant -[[autodoc]] functional.int8_vectorwise_deqant +[[autodoc]] functional.int8_vectorwise_dequant [[autodoc]] functional.int8_vectorwise_quant From 1ae7c6b2ede9eda1e81887ca282121c253631381 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 3 Dec 2024 23:26:54 -0500 Subject: [PATCH 62/65] doc cleanup --- bitsandbytes/functional.py | 194 +++++++++++++++++++------------------ docs/source/_toctree.yml | 2 +- 2 files changed, 100 insertions(+), 96 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index d1c5d1d2e..a5cc4a9f0 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -483,17 +483,13 @@ def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p: def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: - """ - Get the ctypes pointer from a PyTorch Tensor. + """Gets the memory address of the first element of a tenso - Parameters - ---------- - A : torch.tensor - The PyTorch tensor. + Args: + A (`Optional[Tensor]`): A PyTorch tensor. - Returns - ------- - ctypes.c_void_p + Returns: + `Optional[ct.c_void_p]`: A pointer to the underlying tensor data. """ if A is None: return None @@ -863,30 +859,31 @@ def quantize_blockwise( blocksize=4096, nested=False, ) -> Tuple[torch.Tensor, QuantState]: - """ - Quantize tensor A in blocks of size 4096 values. + """Quantize a tensor in blocks of values. - Quantizes tensor A by dividing it into blocks of 4096 values. - Then the absolute maximum value within these blocks is calculated - for the non-linear quantization. + The input tensor is quantized by dividing it into blocks of `blocksize` values. + The the absolute maximum value within these blocks is calculated for scaling + the non-linear quantization. - Parameters - ---------- - A : torch.Tensor - The input tensor. - code : torch.Tensor - The quantization map. - absmax : torch.Tensor - The absmax values. - out : torch.Tensor - The output tensor (8-bit). + Args: + A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes. + code (`torch.Tensor`, *optional*): + A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. + For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. + absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values. + out (`torch.Tensor`, *optional*): A tensor to use to store the result. + blocksize (`int`, *optional*): + The size of the blocks. Defaults to 4096. + Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. - Returns - ------- - torch.Tensor: - The 8-bit tensor. - tuple(torch.Tensor, torch.Tensor): - The quantization state to undo the quantization. + Raises: + ValueError: Raised when the input data type is not supported. + + Returns: + `Tuple[torch.Tensor, QuantState]`: A tuple containing the quantization results. + - `torch.Tensor`: The quantized tensor. + - [`QuantState`]: The state object used to undo the quantization. """ if code is None: @@ -967,31 +964,38 @@ def dequantize_blockwise( blocksize: int = 4096, nested=False, ) -> torch.Tensor: - """ - Dequantizes blockwise quantized values. + """Dequantize a tensor in blocks of values. - Dequantizes the tensor A with maximum absolute values absmax in - blocks of size 4096. + The input tensor is dequantized by dividing it into blocks of `blocksize` values. + The the absolute maximum value within these blocks is used for scaling + the non-linear dequantization. - Parameters - ---------- - A : torch.Tensor - The input 8-bit tensor. - quant_state : QuantState - Object with code, absmax and other quantization state components. - absmax : torch.Tensor - The absmax values. - code : torch.Tensor - The quantization map. - out : torch.Tensor - Dequantized output tensor (default: float32) + Args: + A (`torch.Tensor`): The quantized input tensor. + quant_state ([`QuantState`], *optional*): + The quantization state as returned by [`quantize_blockwise`]. + Required if `absmax` is not provided. + absmax (`torch.Tensor`, *optional*): + A tensor containing the scaling values. + Required if `quant_state` is not provided and ignored otherwise. + code (`torch.Tensor`, *optional*): + A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. + For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. + Ignored when `quant_state` is provided. + out (`torch.Tensor`, *optional*): A tensor to use to store the result. + blocksize (`int`, *optional*): + The size of the blocks. Defaults to 4096. + Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + Ignored when `quant_state` is provided. + Raises: + ValueError: Raised when the input data type is not supported. - Returns - ------- - torch.Tensor: - Dequantized tensor (default: float32) + Returns: + `torch.Tensor`: + The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`. """ + assert quant_state is not None or absmax is not None if code is None and quant_state is None: if "dynamic" not in name2qmap: @@ -1166,31 +1170,30 @@ def quantize_4bit( quant_type="fp4", quant_storage=torch.uint8, ) -> Tuple[torch.Tensor, QuantState]: - """ - Quantize tensor A in blocks of 4-bit values. + """Quantize tensor A in blocks of 4-bit values. - Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. + Quantizes tensor A by dividing it into blocks which are independently quantized. - Parameters - ---------- - A : torch.Tensor - The input tensor. - absmax : torch.Tensor - The absmax values. - out : torch.Tensor - The output tensor. - blocksize : int - The blocksize used in quantization. - quant_type : str - The 4-bit quantization data type {fp4, nf4} + Args: + A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes. + absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values. + out (`torch.Tensor`, *optional*): A tensor to use to store the result. + blocksize (`int`, *optional*): + The size of the blocks. Defaults to 64. + Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. + quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`. + quant_storage (`torch.dtype`, *optional*): The dtype of the tensor used to store the result. Defaults to `torch.uint8`. - Returns - ------- - torch.Tensor: - Tensor with packed 4-bit values. - tuple(torch.Tensor, torch.Size, torch.dtype, int): - The quantization state to undo the quantization. + Raises: + ValueError: Raised when the input data type is not supported. + + Returns: + Tuple[`torch.Tensor`, `QuantState`]: A tuple containing the quantization results. + - `torch.Tensor`: The quantized tensor with packed 4-bit values. + - [`QuantState`]: The state object used to undo the quantization. """ + if A.device.type != "cuda": raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}") if quant_type not in ["fp4", "nf4"]: @@ -1297,32 +1300,33 @@ def dequantize_4bit( blocksize: int = 64, quant_type="fp4", ) -> torch.Tensor: - """ - Dequantizes FP4 blockwise quantized values. + """Dequantizes a packed 4-bit quantized tensor. - Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. + The input tensor is dequantized by dividing it into blocks of `blocksize` values. + The the absolute maximum value within these blocks is used for scaling + the non-linear dequantization. - Parameters - ---------- - A : torch.Tensor - The input tensor (packed 4-bit values). - quant_state : QuantState - object with quantisation stats, incl. absmax values, original tensor shape and original dtype. - absmax : torch.Tensor - The absmax values. - out : torch.Tensor - Dequantized output tensor. - blocksize : int - The blocksize used in quantization. - quant_type : str - The 4-bit quantization data type {fp4, nf4} + Args: + A (`torch.Tensor`): The quantized input tensor. + quant_state ([`QuantState`], *optional*): + The quantization state as returned by [`quantize_4bit`]. + Required if `absmax` is not provided. + absmax (`torch.Tensor`, *optional*): + A tensor containing the scaling values. + Required if `quant_state` is not provided and ignored otherwise. + out (`torch.Tensor`, *optional*): A tensor to use to store the result. + blocksize (`int`, *optional*): + The size of the blocks. Defaults to 64. + Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`. + Raises: + ValueError: Raised when the input data type or blocksize is not supported. - Returns - ------- - torch.Tensor: - Dequantized tensor. + Returns: + `torch.Tensor`: The dequantized tensor. """ + if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: raise ValueError( f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]", @@ -2277,7 +2281,7 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten Args: A (`torch.Tensor`): The first matrix operand with the data type `torch.int8`. B (`torch.Tensor`): The second matrix operand with the data type `torch.int8`. - out (`torch.Tensor, *optional*): A pre-allocated tensor used to store the result. + out (`torch.Tensor`, *optional*): A pre-allocated tensor used to store the result. dtype (`torch.dtype`, *optional*): The expected data type of the output. Defaults to `torch.int32`. Raises: @@ -2384,7 +2388,7 @@ def int8_mm_dequant( A (`torch.Tensor` with dtype `torch.int32`): The result of a quantized int8 matrix multiplication. row_stats (`torch.Tensor`): The row-wise quantization statistics for the lhs operand of the matrix multiplication. col_stats (`torch.Tensor`): The column-wise quantization statistics for the rhs operand of the matrix multiplication. - out (`torch.Tensor], *optional*): A pre-allocated tensor to store the output of the operation. + out (`torch.Tensor`, *optional*): A pre-allocated tensor to store the output of the operation. bias (`torch.Tensor`, *optional*): An optional bias vector to add to the result. Returns: @@ -2454,7 +2458,7 @@ def get_colrow_absmax( row_stats (`torch.Tensor`, *optional*): If provided, calculation of row statistics is skipped. col_stats (`torch.Tensor`, *optional*): If provided, calculation of column statistics is skipped. nnz_block_ptr (`torch.Tensor`, *optional*): Not used. - threshold (`float`, `optional`): + threshold (`float`, *optional*): An optional threshold for sparse decomposition of outlier features. No outliers are held back when 0.0. Defaults to 0.0. diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 629c6d0f8..5fa353d6d 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -59,7 +59,7 @@ - title: k-bit quantizers sections: - local: reference/nn/linear8bit - title: 8-bit quantizer + title: LLM.int8() - local: reference/nn/linear4bit title: 4-bit quantizer - local: reference/nn/embeddings From 213b10bb39fbec9cd70a2fb66a277fb540050a83 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 3 Dec 2024 23:39:27 -0500 Subject: [PATCH 63/65] docs --- docs/source/reference/functional.mdx | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/source/reference/functional.mdx b/docs/source/reference/functional.mdx index 0384ccfc8..a666f2442 100644 --- a/docs/source/reference/functional.mdx +++ b/docs/source/reference/functional.mdx @@ -37,7 +37,12 @@ The `bitsandbytes.functional` API provides the low-level building blocks for the [[autodoc]] functional.QuantState -## General Quantization +## Dynamic 8-bit Quantization + +Primitives used in the 8-bit optimizer quantization. + +For more details see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561] + [[autodoc]] functional.dequantize_blockwise [[autodoc]] functional.quantize_blockwise From ca6fd44ec228fe18da0fd0b5cbe847dadbbb6f6e Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 4 Dec 2024 12:59:39 -0500 Subject: [PATCH 64/65] add inference benchmark script --- benchmarking/inference_benchmark.py | 134 ++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 benchmarking/inference_benchmark.py diff --git a/benchmarking/inference_benchmark.py b/benchmarking/inference_benchmark.py new file mode 100644 index 000000000..61ac570f2 --- /dev/null +++ b/benchmarking/inference_benchmark.py @@ -0,0 +1,134 @@ +""" +Inference benchmarking tool. + +Requirements: + transformers + accelerate + bitsandbytes + optimum-benchmark + +Usage: python inference_benchmark.py model_id + +options: + -h, --help show this help message and exit + --configs {bf16,fp16,nf4,nf4-dq,int8,int8-decomp} [{bf16,fp16,nf4,nf4-dq,int8,int8-decomp} ...] + --bf16 + --fp16 + --nf4 + --nf4-dq + --int8 + --int8-decomp + --batches BATCHES [BATCHES ...] + --input-length INPUT_LENGTH + --out-dir OUT_DIR +""" + +import argparse +from pathlib import Path + +from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig +from optimum_benchmark.logging_utils import setup_logging +import torch + +BFLOAT16_SUPPORT = torch.cuda.get_device_capability()[0] >= 8 + +WEIGHTS_CONFIGS = { + "fp16": {"torch_dtype": "float16", "quantization_scheme": None, "quantization_config": {}}, + "bf16": {"torch_dtype": "bfloat16", "quantization_scheme": None, "quantization_config": {}}, + "nf4": { + "torch_dtype": "bfloat16" if BFLOAT16_SUPPORT else "float16", + "quantization_scheme": "bnb", + "quantization_config": { + "load_in_4bit": True, + "bnb_4bit_quant_type": "nf4", + "bnb_4bit_use_double_quant": False, + "bnb_4bit_compute_dtype": torch.bfloat16 if BFLOAT16_SUPPORT else "float16", + }, + }, + "nf4-dq": { + "torch_dtype": "bfloat16" if BFLOAT16_SUPPORT else "float16", + "quantization_scheme": "bnb", + "quantization_config": { + "load_in_4bit": True, + "bnb_4bit_quant_type": "nf4", + "bnb_4bit_use_double_quant": True, + "bnb_4bit_compute_dtype": torch.bfloat16 if BFLOAT16_SUPPORT else "float16", + }, + }, + "int8-decomp": { + "torch_dtype": "float16", + "quantization_scheme": "bnb", + "quantization_config": { + "load_in_8bit": True, + "llm_int8_threshold": 6.0, + }, + }, + "int8": { + "torch_dtype": "float16", + "quantization_scheme": "bnb", + "quantization_config": { + "load_in_8bit": True, + "llm_int8_threshold": 0.0, + }, + }, +} + +if __name__ == "__main__": + setup_logging(level="INFO") + + parser = argparse.ArgumentParser(description="bitsandbytes inference benchmark tool") + + parser.add_argument("model_id", type=str, help="The model checkpoint to use.") + + parser.add_argument( + "--configs", + nargs="+", + choices=["bf16", "fp16", "nf4", "nf4-dq", "int8", "int8-decomp"], + default=["nf4", "int8", "int8-decomp"], + ) + parser.add_argument("--bf16", dest="configs", action="append_const", const="bf16") + parser.add_argument("--fp16", dest="configs", action="append_const", const="fp16") + parser.add_argument("--nf4", dest="configs", action="append_const", const="nf4") + parser.add_argument("--nf4-dq", dest="configs", action="append_const", const="nf4-dq") + parser.add_argument("--int8", dest="configs", action="append_const", const="int8") + parser.add_argument("--int8-decomp", dest="configs", action="append_const", const="int8-decomp") + + parser.add_argument("--batches", nargs="+", type=int, default=[1, 8, 16, 32]) + parser.add_argument("--input-length", type=int, default=64) + + parser.add_argument("--out-dir", type=str, default="reports") + + args = parser.parse_args() + + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + for batch_size in args.batches: + print(f"Benchmarking batch size: {batch_size}") + for config in args.configs: + launcher_config = ProcessConfig(device_isolation=True, start_method="spawn") + scenario_config = InferenceConfig( + latency=True, + memory=True, + input_shapes={"batch_size": batch_size, "sequence_length": args.input_length}, + ) + backend_config = PyTorchConfig( + device="cuda", + device_ids="0", + device_map="auto", + no_weights=False, + model=args.model_id, + **WEIGHTS_CONFIGS[config], + ) + benchmark_config = BenchmarkConfig( + name=f"benchmark-{config}-bsz{batch_size}", + scenario=scenario_config, + launcher=launcher_config, + backend=backend_config, + ) + + out_path = out_dir / f"benchmark_{config}_bsz{batch_size}.json" + + benchmark_report = Benchmark.launch(benchmark_config) + benchmark_report.log() + benchmark_report.save_json(out_path) From b8c736b5bde445a8b2af37a71d03086fed7a4355 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 4 Dec 2024 13:53:25 -0500 Subject: [PATCH 65/65] Add benchmarks, doc update --- benchmarking/README.md | 159 +++++++++++++++++++++++++++ docs/source/reference/functional.mdx | 2 +- 2 files changed, 160 insertions(+), 1 deletion(-) create mode 100644 benchmarking/README.md diff --git a/benchmarking/README.md b/benchmarking/README.md new file mode 100644 index 000000000..ebd2bcf56 --- /dev/null +++ b/benchmarking/README.md @@ -0,0 +1,159 @@ +# Benchmarking + +## Inference +End-to-end inference benchmarking can be performed using the 🤗 [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark) library. + +See the example script in +[inference_benchmark.py](inference_benchmark.py). + +### Results (as of v0.45.0) + +Our overall benchmarking results compared with v0.44.1 provide the following insights: +#### LLM.int8() +* **Turing/Ampere/Ada**: The observed per-token throughput is improved by 60-85%, while latency is decreased by 40-45%. +* **H100**: With our benchmarking of Llama 3.1 70B, we observed the new LLM.int8() to consistently outperform NF4 at batch size >= 8. + +#### NF4/FP4 +* **Turing/Ampere/Ada**: With batch size of 1, per-token throughput is _improved by 10-25%_ and per-token latency is _decreased by 10-20%_. +* **H100**: Across all batch sizes, per-token throughput is _improved by up to 28%_ and per-token latency is _decreased by up to 22%_. + +Summaries with the benchmarking results are provided below. + +#### NVIDIA T4 16GB +
+Qwen 2.5 3B Instruct + +| | Batch Size | Mean Latency (s) v0.45.0.dev | Throughput v0.45.0.dev | Mean Latency (s) v0.44.1 | Latency Improvement | Throughput v0.44.1 | Throughput Improvement | +|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------| +| FP16 | 1 | 0.0390 | 25.66 | 0.0390 | 1.00 | 25.66 | 1.000x | +| NF4 | 1 | 0.0608 | 16.45 | 0.0710 | 1.14 | 14.08 | 1.168x | +| NF4+DQ | 1 | 0.0736 | 13.58 | 0.0905 | 1.19 | 11.05 | 1.229x | +| INT8 | 1 | 0.0902 | 11.08 | 0.1609 | 1.44 | 6.21 | 1.784x | +| INT8+Decomp | 1 | 0.1672 | 5.98 | 0.2994 | 1.44 | 3.34 | 1.790x | +| FP16 | 8 | 0.0422 | 189.56 | 0.0422 | 1.00 | 189.56 | 1.000x | +| NF4 | 8 | 0.0960 | 83.37 | 0.1010 | 1.05 | 79.17 | 1.053x | +| NF4+DQ | 8 | 0.1042 | 76.80 | 0.1156 | 1.10 | 69.18 | 1.110x | +| INT8 | 8 | 0.0919 | 87.01 | 0.1640 | 1.44 | 48.78 | 1.784x | +| INT8+Decomp | 8 | 0.1812 | 44.15 | 0.3296 | 1.45 | 24.28 | 1.818x | +| FP16 | 32 | 0.0601 | 532.30 | 0.0601 | 1.00 | 532.30 | 1.000x | +| NF4 | 32 | 0.1150 | 278.32 | 0.1182 | 1.03 | 270.71 | 1.028x | +| NF4+DQ | 32 | 0.1215 | 263.36 | 0.1297 | 1.06 | 246.76 | 1.067x | +| INT8 | 32 | 0.0943 | 339.21 | 0.1640 | 1.42 | 195.14 | 1.738x | +| INT8+Decomp | 32 | 0.1912 | 167.37 | 0.3413 | 1.44 | 93.75 | 1.785x | +
+ +#### NVIDIA RTX 4090 24GB +
+Llama 3.1 8B + +| | Batch Size | Mean Latency (s) v0.45.0.dev | Throughput v0.45.0.dev | Mean Latency (s) v0.44.1 | Latency Improvement | Throughput v0.44.1 | Throughput Improvement | +|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------| +| BF16 | 1 | 0.0211 | 47.46 | 0.0211 | 1.00 | 47.46 | 1.000x | +| NF4 | 1 | 0.0148 | 67.71 | 0.0164 | 1.10 | 61.08 | 1.109x | +| NF4+DQ | 1 | 0.0175 | 57.08 | 0.0208 | 1.16 | 48.15 | 1.185x | +| INT8 | 1 | 0.0220 | 45.39 | 0.0395 | 1.44 | 25.32 | 1.793x | +| INT8+Decomp | 1 | 0.0449 | 22.26 | 0.0743 | 1.40 | 13.45 | 1.655x | +| BF16 | 8 | 0.0239 | 334.64 | 0.0239 | 1.00 | 334.64 | 1.000x | +| NF4 | 8 | 0.0425 | 188.08 | 0.0422 | 0.99 | 189.50 | 0.993x | +| NF4+DQ | 8 | 0.0443 | 180.68 | 0.0437 | 0.99 | 183.02 | 0.987x | +| INT8 | 8 | 0.0221 | 361.61 | 0.0389 | 1.43 | 205.82 | 1.757x | +| INT8+Decomp | 8 | 0.0478 | 164.55 | 0.0777 | 1.38 | 103.01 | 1.597x | +| BF16 | 32 | 0.0304 | 1054.35 | 0.0304 | 1.00 | 1054.35 | 1.000x | +| NF4 | 32 | 0.0461 | 694.60 | 0.0466 | 1.01 | 686.90 | 1.011x | +| NF4+DQ | 32 | 0.0471 | 678.73 | 0.0480 | 1.02 | 666.33 | 1.019x | +| INT8 | 32 | 0.0230 | 1390.54 | 0.0390 | 1.41 | 819.99 | 1.696x | +| INT8+Decomp | 32 | 0.0512 | 624.94 | 0.0835 | 1.39 | 383.18 | 1.631x | +
+ +
+Qwen 2.5 14B Instruct + +| | Batch Size | Mean Latency (s) v0.45.0.dev | Throughput v0.45.0.dev | Mean Latency (s) v0.44.1 | Latency Improvement | Throughput v0.44.1 | Throughput Improvement | +|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------| +| NF4 | 1 | 0.0214 | 46.74 | 0.0256 | 1.16 | 39.10 | 1.195x | +| NF4+DQ | 1 | 0.0256 | 39.03 | 0.0318 | 1.19 | 31.46 | 1.241x | +| INT8 | 1 | 0.0326 | 30.68 | 0.0596 | 1.45 | 16.79 | 1.827x | +| INT8+Decomp | 1 | 0.0648 | 15.44 | 0.1105 | 1.41 | 9.05 | 1.706x | +| NF4 | 8 | 0.0696 | 114.95 | 0.0697 | 1.00 | 114.78 | 1.001x | +| NF4+DQ | 8 | 0.0719 | 111.29 | 0.0723 | 1.01 | 110.70 | 1.005x | +| INT8 | 8 | 0.0325 | 246.22 | 0.0596 | 1.45 | 134.21 | 1.835x | +| INT8+Decomp | 8 | 0.0721 | 110.95 | 0.1201 | 1.40 | 66.62 | 1.665x | +
+ + +#### NVIDIA H100 80GB SXM +
+Llama 3.1 8B + +| | Batch Size | Mean Latency (s) v0.45.0.dev | Throughput v0.45.0.dev | Mean Latency (s) v0.44.1 | Latency Improvement | Throughput v0.44.1 | Throughput Improvement | +|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------| +| BF16 | 1 | 0.0244 | 40.99 | 0.0244 | 1.00 | 40.99 | 1.000x | +| NF4 | 1 | 0.0331 | 30.14 | 0.0391 | 1.15 | 25.60 | 1.177x | +| NF4+DQ | 1 | 0.0411 | 24.34 | 0.0528 | 1.22 | 18.92 | 1.286x | +| INT8 | 1 | 0.0522 | 19.17 | N/A | N/A | N/A | N/A | +| INT8+Decomp | 1 | 0.0817 | 12.24 | N/A | N/A | N/A | N/A | +| BF16 | 8 | 0.0255 | 313.90 | 0.0255 | 1.00 | 313.90 | 1.000x | +| NF4 | 8 | 0.0476 | 168.05 | 0.0551 | 1.14 | 145.13 | 1.158x | +| NF4+DQ | 8 | 0.0566 | 141.27 | 0.0663 | 1.15 | 120.67 | 1.171x | +| INT8 | 8 | 0.0515 | 155.44 | N/A | N/A | N/A | N/A | +| INT8+Decomp | 8 | 0.0853 | 93.79 | N/A | N/A | N/A | N/A | +| BF16 | 32 | 0.0261 | 1227.96 | 0.0261 | 1.00 | 1227.96 | 1.000x | +| NF4 | 32 | 0.0486 | 658.65 | 0.0546 | 1.11 | 585.91 | 1.124x | +| NF4+DQ | 32 | 0.0577 | 555.06 | 0.0665 | 1.13 | 481.04 | 1.154x | +| INT8 | 32 | 0.0545 | 586.26 | N/A | N/A | N/A | N/A | +| INT8+Decomp | 32 | 0.0864 | 370.51 | N/A | N/A | N/A | N/A | +
+ +
+Qwen 2.5 32B Instruct + +| | Batch Size | Mean Latency (s) v0.45.0.dev | Throughput v0.45.0.dev | +|-------------|------------|-----------------------------------------|-----------------------------------| +| BF16 | 1 | 0.0508 | 19.67 | +| NF4 | 1 | 0.0707 | 14.14 | +| NF4+DQ | 1 | 0.0860 | 11.63 | +| INT8 | 1 | 0.1031 | 9.70 | +| INT8+Decomp | 1 | 0.1820 | 5.49 | +| BF16 | 8 | 0.0525 | 152.50 | +| NF4 | 8 | 0.1154 | 69.35 | +| NF4+DQ | 8 | 0.1209 | 66.19 | +| INT8 | 8 | 0.1078 | 74.24 | +| INT8+Decomp | 8 | 0.1958 | 40.87 | +| BF16 | 32 | 0.0547 | 584.54 | +| NF4 | 32 | 0.1246 | 256.84 | +| NF4+DQ | 32 | 0.1298 | 246.47 | +| INT8 | 32 | 0.1056 | 302.96 | +| INT8+Decomp | 32 | 0.2027 | 157.83 | +
+ +
+Llama 3.1 70B + +| | Batch Size | Mean Latency (s) v0.45.0.dev | Throughput v0.45.0.dev | +|-------------|------------|-----------------------------------------|-----------------------------------| +| NF4 | 1 | 0.0833 | 12.00 | +| NF4+DQ | 1 | 0.1052 | 9.50 | +| INT8 | 1 | 0.1294 | 7.73 | +| INT8+Decomp | 1 | 0.1985 | 5.04 | +| NF4 | 8 | 0.2348 | 34.07 | +| NF4+DQ | 8 | 0.2423 | 33.01 | +| INT8 | 8 | 0.1313 | 60.94 | +| INT8+Decomp | 8 | 0.2052 | 38.99 | +| NF4 | 32 | 0.2491 | 128.46 | +| NF4+DQ | 32 | 0.2580 | 124.04 | +| INT8 | 32 | 0.1314 | 243.45 | +| INT8+Decomp | 32 | 0.2189 | 146.19 | +
+ +#### Software Configuration +We focus on the default PyTorch CUDA backend in 🤗 [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark). We used commit [`6e6b1036`](https://github.com/huggingface/optimum-benchmark/commit/6e6b10363f3ac65926881f2c6a6113b6cefc06cd). + +For all hardware configurations, we used the following dependencies: +* `transformers==4.46.3` +* `accelerate==1.1.1` +* `tokenizers==0.20.3` +* `torch==2.5.1` +* `bitsandbytes==0.44.1` +* `bitsandbytes==0.45.0.dev` + +In the RTX 4090 setting, the CUDA 12.4 build of PyTorch is used. In the other settings we used the CUDA 12.1 build. diff --git a/docs/source/reference/functional.mdx b/docs/source/reference/functional.mdx index a666f2442..dbbe21794 100644 --- a/docs/source/reference/functional.mdx +++ b/docs/source/reference/functional.mdx @@ -41,7 +41,7 @@ The `bitsandbytes.functional` API provides the low-level building blocks for the Primitives used in the 8-bit optimizer quantization. -For more details see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561] +For more details see [8-Bit Approximations for Parallelism in Deep Learning](https://arxiv.org/abs/1511.04561) [[autodoc]] functional.dequantize_blockwise