From c9854d8d01fc2bb82b67ec02452d442f1e3f72b0 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Mon, 27 Nov 2023 21:31:17 -0800 Subject: [PATCH] Adding int4 tensor subclass Summary: Adding int4 tensor subclass support, also refactoring tensor subclass code to be easier to use with multiple subclasses. This subclass uses the tinygemm int4 mixed dtype gemm that was added to pytroch as _weight_int4pack_mm and _convert_weight_to_int4pack. Also added support for .to for tensor subclasses to get the save/loading of meta tensors working for int4. Test Plan: python test/test.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/test.py | 143 +++++++--- torchao/quantization/__init__.py | 8 +- torchao/quantization/quant_api.py | 34 ++- torchao/quantization/quant_primitives.py | 109 ++++++++ torchao/quantization/subclass.py | 332 +++++++++++++++++------ torchao/quantization/utils.py | 6 + 6 files changed, 493 insertions(+), 139 deletions(-) diff --git a/test/test.py b/test/test.py index d57519ba28..b2b95a1495 100644 --- a/test/test.py +++ b/test/test.py @@ -11,7 +11,7 @@ import torch import torch.nn as nn from torch._inductor.utils import run_and_get_code - +from torch._dynamo import config from torch.ao.quantization import MinMaxObserver, QConfigMapping from torchao.quantization.dynamic_quant import ( @@ -21,7 +21,8 @@ apply_dynamic_quant, apply_weight_only_int8_quant, change_linear_weights_to_dqtensors, - change_linear_weights_to_woqtensors, + change_linear_weights_to_int8woqtensors, + change_linear_weights_to_int4woqtensors, _replace_with_custom_fn_if_matches_filter, ) from torchao.quantization.quant_primitives import ( @@ -42,8 +43,9 @@ swap_linear_with_smooth_fq_linear, ) from torchao.quantization.subclass import ( - DynamicallyQuantizedLinearWeight, - WeightOnlyQuantizedLinearWeight + Int8DynamicallyQuantizedLinearWeight, + Int8WeightOnlyQuantizedLinearWeight, + Int4WeightOnlyQuantizedLinearWeight ) from torchao.quantization.utils import ( apply_logging_hook, @@ -59,6 +61,7 @@ import os torch.manual_seed(0) +config.cache_size_limit = 100 class SmoothquantUnitTest(unittest.TestCase): @@ -788,62 +791,108 @@ def test_qlinear_per_channel_numerics_cuda(self): class TestSubclass(unittest.TestCase): + def _test_dequantize_impl( + self, + test_subclass_from_float, + min_sqnr=35, + test_dtype=torch.bfloat16, + test_shape=[32, 64, 64], + ): + m, k, n = test_shape + lin = torch.nn.Linear(k, n, device="cuda").to(test_dtype) + w = lin.weight.detach() + lin.weight = torch.nn.Parameter( + test_subclass_from_float(lin.weight), requires_grad=False + ) + self.assertGreater(SQNR(w, lin.weight.dequantize()), min_sqnr, f"{lin.weight.__class__.__name__} failed dtype={test_dtype}") + self.assertGreater(SQNR(w.t(), lin.weight.t().dequantize()), min_sqnr, f"{lin.weight.__class__.__name__} failed transpose on dtype={test_dtype}") + + def test_dequantize_int8_dynamic_quant_subclass(self): + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: + self._test_dequantize_impl(Int8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype) + + def test_dequantize_int8_weight_only_quant_subclass(self): + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: + self._test_dequantize_impl(Int8WeightOnlyQuantizedLinearWeight.from_float, 35, test_dtype) + + def test_dequantize_int4_weight_only_quant_subclass(self): + self._test_dequantize_impl(Int4WeightOnlyQuantizedLinearWeight.from_float, 15, test_shape=[1, 1024, 8]) + for groupsize in [256, 128]: + for inner_k_tiles in [8, 2]: + for m in [1, 256]: + self._test_dequantize_impl(lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float(w, groupsize, inner_k_tiles), 15, test_shape=[m, 256, 8]) + def _test_lin_weight_subclass_impl(self, - test_subclass, + test_subclass_from_float, min_sqnr=35, - test_dtypes=[torch.float32, torch.float16, torch.bfloat16], - test_shape=[32, 64, 32] + test_dtype=torch.bfloat16, + test_shape=[32, 64, 32], ): - for test_dtype in test_dtypes: - m, k, n = test_shape - x = torch.randn(m, k, device="cuda", dtype=test_dtype) - lin = torch.nn.Linear(k, n, device="cuda").to(test_dtype) - ref_f = lin(x) - - lin.weight = torch.nn.Parameter( - test_subclass.from_float(lin.weight), requires_grad=False - ) - test = lin(x) - self.assertGreater(SQNR(ref_f, test), min_sqnr, f"{test_subclass.__name__} failed, no compile, dtype={test_dtype}, (m, k, n)={test_shape}") - lin_comp = torch.compile(lin, mode='max-autotune') - test_comp = lin_comp(x) - self.assertGreater(SQNR(ref_f, test_comp), min_sqnr, f"{test_subclass.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}") + m, k, n = test_shape + x = torch.randn(m, k, device="cuda", dtype=test_dtype) + lin = torch.nn.Linear(k, n, device="cuda").to(test_dtype) + ref_f = lin(x) + + lin.weight = torch.nn.Parameter( + test_subclass_from_float(lin.weight), requires_grad=False + ) + test = lin(x) + self.assertGreater(SQNR(ref_f, test), min_sqnr, f"{lin.weight.__class__.__name__} failed, no compile, dtype={test_dtype}, (m, k, n)={test_shape}") + lin_comp = torch.compile(lin, mode='max-autotune') + test_comp = lin_comp(x) + self.assertGreater(SQNR(ref_f, test_comp), min_sqnr, f"{lin.weight.__class__.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}") def test_int8_dynamic_quant_subclass(self): - self._test_lin_weight_subclass_impl(DynamicallyQuantizedLinearWeight, 35) + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: + self._test_lin_weight_subclass_impl(Int8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype) def test_int8_weight_only_quant_subclass(self): - self._test_lin_weight_subclass_impl(WeightOnlyQuantizedLinearWeight, 40) + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: + self._test_lin_weight_subclass_impl(Int8WeightOnlyQuantizedLinearWeight.from_float, 40, test_dtype) + + def test_int4_weight_only_quant_subclass(self): + self._test_lin_weight_subclass_impl(Int4WeightOnlyQuantizedLinearWeight.from_float, 10, test_shape=[1, 1024, 8]) + for groupsize in [128, 64]: + for inner_k_tiles in [4, 2]: + for m in [1, 256]: + self._test_lin_weight_subclass_impl(lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float(w, groupsize, inner_k_tiles), 10, test_shape=[m, 256, 8]) @torch.no_grad() def _test_lin_weight_subclass_api_impl( self, api, min_sqnr=35, - test_dtypes=[torch.float32, torch.float16, torch.bfloat16], + test_dtype=torch.bfloat16, test_shape=[32, 64, 32] ): - for test_dtype in test_dtypes: - m, k, n = test_shape - x = torch.randn(m, k, device="cuda", dtype=test_dtype) - mod = nn.Sequential( - nn.Linear(k, n, device="cuda"), nn.ReLU(), nn.Linear(n, n, device="cuda") - ).to(test_dtype) - ref_f = mod(x) - api(mod) - test = mod(x) - self.assertGreater(SQNR(ref_f, test), min_sqnr, f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}") - - mod_qc = torch.compile(mod, mode="max-autotune") - test_comp = mod_qc(x) - self.assertGreater(SQNR(ref_f, test_comp), min_sqnr, f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}") + m, k, n = test_shape + x = torch.randn(m, k, device="cuda", dtype=test_dtype) + mod = nn.Sequential( + nn.Linear(k, n, device="cuda"), nn.ReLU(), nn.Linear(n, n, device="cuda") + ).to(test_dtype) + ref_f = mod(x) + api(mod) + test = mod(x) + self.assertGreater(SQNR(ref_f, test), min_sqnr, f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}") + mod_qc = torch.compile(mod, mode="max-autotune") + test_comp = mod_qc(x) + self.assertGreater(SQNR(ref_f, test_comp), min_sqnr, f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}") def test_int8_dynamic_quant_subclass_api(self): - self._test_lin_weight_subclass_api_impl(change_linear_weights_to_dqtensors, 35) + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: + self._test_lin_weight_subclass_api_impl(change_linear_weights_to_dqtensors, 35) def test_int8_weight_only_quant_subclass_api(self): - self._test_lin_weight_subclass_api_impl(change_linear_weights_to_woqtensors, 40) + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: + self._test_lin_weight_subclass_api_impl(change_linear_weights_to_int8woqtensors, 40) + + def test_int4_weight_only_quant_subclass_api(self): + self._test_lin_weight_subclass_api_impl(change_linear_weights_to_int4woqtensors, 15, test_shape=[1, 1024, 256]) + for groupsize in [64, 32]: + for inner_k_tiles in [4, 2]: + kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles} + self._test_lin_weight_subclass_api_impl(lambda mod: change_linear_weights_to_int4woqtensors(mod, **kwargs), 15, test_shape=[256, 256, 8]) class TestDynamicQuant(unittest.TestCase): def test_dynamic_quant(self): @@ -906,7 +955,7 @@ def test_weight_only_quant_use_mixed_mm(self): class TestSaveLoadMeta(unittest.TestCase): @torch.no_grad() - def _test_handle_save_load_meta_impl(self, api): + def _test_handle_save_load_meta_impl(self, api, min_sqnr=35): m, k, n = 32, 64, 32 class test_model(nn.Module): def __init__(self): @@ -934,7 +983,7 @@ def forward(self, x): model_qc = torch.compile(model, mode="max-autotune") ref_q = model_qc(x).detach() - assert SQNR(ref_f, ref_q) > 35 + assert SQNR(ref_f, ref_q) > min_sqnr # load model structure with torch.device('meta'): @@ -951,7 +1000,7 @@ def forward(self, x): model_qc = torch.compile(model, mode="max-autotune") test = model_qc(x).detach() - assert SQNR(ref_f, test) > 35 + assert SQNR(ref_f, test) > min_sqnr self.assertTrue(torch.equal(ref_q, test)) @torch.no_grad() @@ -959,8 +1008,12 @@ def test_save_load_dqtensors(self): self._test_handle_save_load_meta_impl(change_linear_weights_to_dqtensors) @torch.no_grad() - def test_save_load_woqtensors(self): - self._test_handle_save_load_meta_impl(change_linear_weights_to_woqtensors) + def test_save_load_int8woqtensors(self): + self._test_handle_save_load_meta_impl(change_linear_weights_to_int8woqtensors) + + @torch.no_grad() + def test_save_load_int4woqtensors(self): + self._test_handle_save_load_meta_impl(change_linear_weights_to_int4woqtensors, 20) class TorchCompileUnitTest(unittest.TestCase): def test_fullgraph(self): diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index f2912e1ce5..217edae19d 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -16,7 +16,8 @@ "apply_weight_only_int8_quant", "apply_dynamic_quant", "change_linear_weights_to_dqtensors", - "change_linear_weights_to_woqtensors", + "change_linear_weights_to_int8woqtensors", + "change_linear_weights_to_int4woqtensors", "insert_subclass", "safe_int_mm", "dynamically_quantize_per_tensor", @@ -34,8 +35,9 @@ "swap_linear_with_smooth_fq_linear", "smooth_fq_linear_to_inference", "set_smooth_fq_attribute", - "DynamicallyQuantizedLinearWeight", - "WeightOnlyQuantizedLinearWeight", + "Int8DynamicallyQuantizedLinearWeight", + "Int8WeightOnlyQuantizedLinearWeight", + "Int4WeightOnlyQuantizedLinearWeight", "log_with_rank", "clear_logs", "compute_error", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index c117146112..04779e014f 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -20,8 +20,9 @@ DynamicallyPerAxisQuantizedLinear, ) from .subclass import ( - DynamicallyQuantizedLinearWeight, - WeightOnlyQuantizedLinearWeight, + Int8DynamicallyQuantizedLinearWeight, + Int8WeightOnlyQuantizedLinearWeight, + Int4WeightOnlyQuantizedLinearWeight, ) from .weight_only import ( WeightOnlyInt8QuantLinear, @@ -31,7 +32,8 @@ "apply_weight_only_int8_quant", "apply_dynamic_quant", "change_linear_weights_to_dqtensors", - "change_linear_weights_to_woqtensors", + "change_linear_weights_to_int8woqtensors", + "change_linear_weights_to_int4woqtensors", ] @@ -77,34 +79,46 @@ def apply_dynamic_quant(model): lambda mod, fqn: isinstance(mod, torch.nn.Linear), ) -def _get_subclass_inserter(cls): +def _get_subclass_inserter(cls, **kwargs): def insert_subclass(lin): lin.weight = torch.nn.Parameter( - cls.from_float(lin.weight), requires_grad=False + cls.from_float(lin.weight, **kwargs), requires_grad=False ) return lin return insert_subclass def change_linear_weights_to_dqtensors(model): """ - Converts all linear weight tensors to the `DynamicallyQuantizedLinearWeight` + Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight` Tensor subclass, effectively applying the same form of quantization as apply_dynamic_quant while not modifying the linear modules. """ _replace_with_custom_fn_if_matches_filter( model, - _get_subclass_inserter(DynamicallyQuantizedLinearWeight), + _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight), lambda mod, fqn: isinstance(mod, torch.nn.Linear) ) -def change_linear_weights_to_woqtensors(model): +def change_linear_weights_to_int8woqtensors(model): """ - Converts all linear weight tensors to the `WeightOnlyQuantizedLinearWeight` + Converts all linear weight tensors to the `Int8WeightOnlyQuantizedLinearWeight` Tensor subclass, effectively applying the same form of quantization as apply_dynamic_quant while not modifying the linear modules. """ _replace_with_custom_fn_if_matches_filter( model, - _get_subclass_inserter(WeightOnlyQuantizedLinearWeight), + _get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight), + lambda mod, fqn: isinstance(mod, torch.nn.Linear) + ) + +def change_linear_weights_to_int4woqtensors(model, **kwargs): + """ + Converts all linear weight tensors to the `Int4WeightOnlyQuantizedLinearWeight` + Tensor subclass, effectively applying the same form of quantization + as apply_dynamic_quant while not modifying the linear modules. + """ + _replace_with_custom_fn_if_matches_filter( + model, + _get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, **kwargs), lambda mod, fqn: isinstance(mod, torch.nn.Linear) ) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index be54751f68..1cb9a631c8 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -19,6 +19,13 @@ "quant_int8_matmul", "quant_int8_dynamic_per_token_linear", "quant_int8_per_token_matmul", + "get_groupwise_affine_qparams", + "pack_tinygemm_scales_and_zeros", + "unpack_tinygemm_scales_and_zeros", + "groupwise_affine_quantize_tensor_from_qparams", + "groupwise_affine_dequantize_tensor_from_qparams", + "groupwise_affine_quantize_tensor", + "groupwise_affine_dequantize_tensor", ] @@ -375,3 +382,105 @@ def quant_int8_per_token_matmul( # can downcast only at the very end y = y.to(output_dtype) return y + +def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128): + """ + + """ + if groupsize > w.shape[-1]: + groupsize = w.shape[-1] + assert groupsize > 1 + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + # assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + zeros = min_val + scales * (2 ** (n_bit - 1)) + return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to( + torch.bfloat16 + ).reshape(w.shape[0], -1) + +def pack_tinygemm_scales_and_zeros(scales, zeros): + assert scales.shape == zeros.shape + assert scales.dtype == torch.bfloat16 + assert zeros.dtype == torch.bfloat16 + return ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + +def unpack_tinygemm_scales_and_zeros(scales_and_zeros): + assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 + assert scales_and_zeros.dtype == torch.float + return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) + +def groupwise_affine_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128): + assert groupsize > 1 + # needed for GPTQ single column quantize + if groupsize > w.shape[-1] and scales.shape[-1] == 1: + groupsize = w.shape[-1] + + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + # assert torch.isnan(to_quant).sum() == 0 + + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + min_val = zeros - scales * (2 ** (n_bit - 1)) + max_int = 2**n_bit - 1 + min_int = 0 + w_int4x8 = ( + to_quant.sub(min_val) + .div(scales) + .round() + .clamp_(min_int, max_int) + .to(torch.int32) + .reshape_as(w) + ) + + return w_int4x8 + +def groupwise_affine_dequantize_tensor_from_qparams( + w_int4x8, scales, zeros, n_bit=4, groupsize=128 +): + assert groupsize > 1 + # needed for GPTQ single column dequantize + if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1: + groupsize = w_int4x8.shape[-1] + assert w_int4x8.shape[-1] % groupsize == 0 + assert w_int4x8.dim() == 2 + + w_int4x8_grouped = w_int4x8.reshape(-1, groupsize) + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + + w_dq = ( + w_int4x8_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int4x8) + ) + return w_dq + +def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128): + scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize) + w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize) + scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) + return w_int4x8, scales_and_zeros + +def groupwise_affine_dequantize_tensor(w_int4x8, scales_and_zeros, n_bit=4, groupsize=128): + scales, zeros = unpack_tinygemm_scales_and_zeros(scales_and_zeros) + return groupwise_affine_dequantize_tensor_from_qparams( + w_int4x8, scales, zeros, n_bit, groupsize + ) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 3649b8e029..11a6aea752 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -9,48 +9,44 @@ dequantize_per_channel, dynamically_quantize_per_channel, quant_int8_dynamic_per_token_linear, + groupwise_affine_quantize_tensor, + unpack_tinygemm_scales_and_zeros, ) +from .utils import find_multiple from torch.utils._python_dispatch import return_and_correct_aliasing __all__ = [ - "DynamicallyQuantizedLinearWeight", - "WeightOnlyQuantizedLinearWeight" + "Int8DynamicallyQuantizedLinearWeight", + "Int8WeightOnlyQuantizedLinearWeight", + "Int4WeightOnlyQuantizedLinearWeight", ] - -class Int8QuantizedLinearWeightBase(torch.Tensor): +class QuantizedLinearWeightBase(torch.Tensor): """ - Base Quantized Tensor subclass for int8 quantized Linear weights. The weight - is quantized symmetrically per-channel. When the float_float method is used, - to create an instance of any Int8QuantizedLinearWeightBase, we assume the input + Base quantized tensor subclass for quantized linear weights. When the from_float method is used, + to create an instance of any QuantizedLinearWeightBase, we assume the input weight is oriented the way it is in a normal linear op, i.e. out-channels x in-channels. - Subclasses which inherit from this class need to implement the _quantized_op method. + + The shape and dtype of the tensor subclass represent how the tensor subclass looks externally, + regardless of the internal representation's type or orientation. """ @staticmethod - def __new__(cls, int_data, q_scales, transposed=False, **kwargs): - # The `transposed` argument indicates that the int_data (attribute or argument) - # is transposed compared to how we'd like the external representation - # of the shape to be. - # This is needed so we don't have to mutate the int_data when it gets - # transposed/detached, instead we can just pass the int_data to the - # new instance and alter the transposed flag where needed. + def __new__(cls, int_data, transposed, shape, *args, **kwargs): kwargs["device"] = int_data.device - kwargs["dtype"] = kwargs.get("dtype", q_scales.dtype) - size = int_data.shape[::-1] if transposed else int_data.shape kwargs["layout"] = ( kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout ) + assert "dtype" in kwargs assert not kwargs.get("requires_grad", False) kwargs["requires_grad"] = False - return torch.Tensor._make_wrapper_subclass(cls, size, **kwargs) # type: ignore[attr-defined] + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - def __init__(self, int_data, q_scales, transposed=False): - self._transposed = transposed + def __init__(self, int_data, transposed, *args, **kwargs): self.int_data = int_data - self.q_scales = q_scales + self.transposed = transposed @staticmethod - def _quantized_op(act_mat, int_w_mat, q_scales, bias): + def _quantized_op(act_mat, w_qtensor, bias): pass def __repr__(self): @@ -59,47 +55,43 @@ def __repr__(self): f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" ) - def dequantize(self, dtype=None): - """ - Obtain the dequantized version of the quantized tensor subclass - """ - dq_t = dequantize_per_channel( - self.int_data.t(), self.q_scales, 0, self.dtype if dtype is None else dtype - ) - # note: data was already transposed to calculate out - return dq_t if self._transposed else dq_t.t() + def dequantize(self): + pass def int_repr(self): - """ - Get the internal integer representation of the quantized tensor - """ - return self.int_data.t() if self._transposed else self.int_data + pass - def q_scales(self): - """ - Get the quantization scales for the quantized tensor - """ - return self.q_scales + def q_params(self): + pass + + def _get_to_kwargs(self, *args, **kwargs): + device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) + device = self.device if device is None else device + dtype = self.dtype if dtype is None else dtype + memory_format = memory_format if memory_format is not None else torch.preserve_format + kwargs = { + "device": device, + "dtype": dtype, + "memory_format": memory_format, + } + return kwargs def _detach(self): - return self.__class__( - self.int_data, self.q_scales, transposed=self._transposed - ) + pass def _transpose(self): - return self.__class__( - self.int_data, self.q_scales, transposed=(not self._transposed) - ) + pass def __tensor_flatten__(self): - return ["int_data", "q_scales"], self._transposed + pass @classmethod - def __tensor_unflatten__(cls, tensor_data, transposed): - int_data, q_scales = tensor_data["int_data"], tensor_data["q_scales"] - return cls( - int_data, q_scales, transposed=transposed - ) + def __tensor_unflatten__(cls, tensor_data, tensor_attributes): + pass + + @classmethod + def from_float(cls, input_float): + pass __torch_function__ = torch._C._disabled_torch_function_impl @@ -121,10 +113,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs): f"need mat1 shape: {args[1].shape} final" f"dim to match mat2 shape: {args[2].shape} first dim " ) - mat1, mat2, q_scales, bias = ( + mat1, w_qtensor, bias = ( args[1], - args[2].int_data, - args[2].q_scales, + args[2], args[0], ) else: @@ -134,16 +125,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs): f"need mat1 shape: {args[0].shape} final dim" f"to match mat2 shape: {args[1].shape} first dim" ) - mat1, mat2, q_scales, bias = ( + mat1, w_qtensor, bias = ( args[0], - args[1].int_data, - args[1].q_scales, + args[1], None, ) # call the quantized op for the specific type # of quantized tensor subclass return cls._quantized_op( - mat1, mat2, q_scales, bias + mat1, w_qtensor, bias ) if func is torch.ops.aten.detach.default: @@ -156,49 +146,229 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._transpose() ) + +class Int8DynamicallyQuantizedLinearWeight(QuantizedLinearWeightBase): + """ + A Tensor subclass that when applied to a weight used in a linear op/module, changes the + linear op to a dynamically quantized linear op with symmetric per-token and per-channel + quantization on the activation and weight respectively. + """ + @staticmethod + def __new__(cls, int_data, q_scales, transposed, shape, **kwargs): + kwargs["dtype"] = kwargs.get("dtype", q_scales.dtype) + return super().__new__(cls, int_data, transposed, shape, **kwargs) # type: ignore[attr-defined] + + def __init__(self, int_data, q_scales, transposed, shape, **kwargs): + self.q_scales = q_scales + super().__init__(int_data, transposed) + + @staticmethod + def _quantized_op(act_mat, w_qtensor, bias): + return quant_int8_dynamic_per_token_linear( + act_mat, w_qtensor.int_data, w_qtensor.q_scales, bias, act_mat.dtype + ) + + def dequantize(self, dtype=None): + """ + Obtain the dequantized version of the quantized tensor subclass + """ + dq_t = dequantize_per_channel( + self.int_data.t(), self.q_scales, 0, self.dtype if dtype is None else dtype + ).to(self.dtype) + # data was transposed to dequantize so make sure shape is correct + return dq_t if not self.transposed else dq_t.t() + + def int_repr(self): + """ + Get the internal integer representation of the quantized tensor + """ + return self.int_data if self.transposed else self.int_data.t() + + def q_params(self): + """ + Get the quantization scales for the quantized tensor + """ + return {"q_scales": self.q_scales} + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + return self.__class__( + self.int_data.to(kwargs["device"]), + self.q_scales.to(kwargs["device"]), + self.transposed, + self.shape, + **kwargs + ) + + def _detach(self): + return self.__class__( + self.int_data, self.q_scales, self.transposed, self.shape, dtype=self.dtype + ) + + def _transpose(self): + return self.__class__( + self.int_data, self.q_scales, not self.transposed, self.shape[::-1], dtype=self.dtype + ) + + def __tensor_flatten__(self): + return ["int_data", "q_scales"], [self.transposed, self.shape, self.dtype] + @classmethod - def from_float(cls, input_float, qmin=-128, qmax=127, dtype=torch.int8): + def __tensor_unflatten__(cls, tensor_data, tensor_attributes): + int_data, q_scales = tensor_data["int_data"], tensor_data["q_scales"] + transposed, shape, dtype = tensor_attributes + return cls( + int_data, q_scales, transposed, shape, dtype=dtype + ) + + @classmethod + def from_float(cls, input_float, qmin=-128, qmax=127): """ Method used to convert a linear weight tensor to an instance of the - desired Tensor subclass. + Int8DynamicallyQuantizedLinearWeight subclass. Example usage:: - model.lin_mod.weight = DynamicallyQuantizedLinearWeight.from_float(model.lin_mod.weight) + model.lin_mod.weight = Int8DynamicallyQuantizedLinearWeight.from_float(model.lin_mod.weight) """ w_int_repr, w_scales, _ = dynamically_quantize_per_channel( - input_float, qmin, qmax, dtype + input_float, qmin, qmax, torch.int8 ) # the desired representation shape for fast quantized matmul is # transposed compared to how it's stored as a linear weight, - # i.e. we want in_channels is dim=0 and out_channels (and quantized axis) is dim=1 - return cls(w_int_repr.contiguous().t(), w_scales, transposed=True) + # i.e. we want in_channels as dim=0 and out_channels (and quantized axis) as dim=1 + # however the external representation of our tensor will maintain the correct + # shape attribute which needs to be tracked directly. + int_data = w_int_repr.contiguous().t() + return cls(int_data, w_scales, False, input_float.shape, dtype=input_float.dtype) -class DynamicallyQuantizedLinearWeight(Int8QuantizedLinearWeightBase, torch.Tensor): +class Int8WeightOnlyQuantizedLinearWeight(Int8DynamicallyQuantizedLinearWeight): """ - A Tensor subclass that when applied to a weight used in a linear op/module, changes the - linear op to a dynamically quantized linear op with symmetric per-token and per-channel - quantization on the activation and weight respectively. + A Tensor subclass that when applied to a weight used in a linear op/module, + changes the linear op to a weight-only quantized linear op with symmetric + per-channel quantization on the weight. """ @staticmethod - def _quantized_op(act_mat, int_w_mat, q_scales, bias): - return quant_int8_dynamic_per_token_linear( - act_mat, int_w_mat, q_scales, bias, act_mat.dtype - ) + def _quantized_op(act_mat, w_qtensor, bias): + act_mat = act_mat.view(-1, act_mat.shape[-1]) + orig_dtype = act_mat.dtype + y = torch.mm(act_mat, w_qtensor.int_data.to(act_mat.dtype)) * w_qtensor.q_scales + y = y.reshape(*act_mat.shape[:-1], -1) + if bias is not None: + y += bias + return y.to(orig_dtype) -class WeightOnlyQuantizedLinearWeight(Int8QuantizedLinearWeightBase, torch.Tensor): +class Int4WeightOnlyQuantizedLinearWeight(QuantizedLinearWeightBase): """ A Tensor subclass that when applied to a weight used in a linear op/module, - changes the linear op to a weight-only quantized linear op with symmetric - per-channel quantization on the weight. + changes that linear op to a weight-only int4 quantized linear op with groupwise + affine quantization on the weight. """ @staticmethod - def _quantized_op(act_mat, int_w_mat, q_scales, bias): - act_mat = act_mat.view(-1, act_mat.shape[-1]) - y = torch.mm(act_mat, int_w_mat.to(act_mat.dtype)) * q_scales - y = y.reshape(*act_mat.shape[:-1], -1) + def __new__(cls, int_data, scales_and_zeros, transposed, shape, groupsize=128, inner_k_tiles=8, **kwargs): + kwargs["dtype"] = kwargs.get("dtype", scales_and_zeros.dtype) + return super().__new__(cls, int_data, transposed, shape, **kwargs) # type: ignore[attr-defined] + + def __init__(self, int_data, scales_and_zeros, transposed, shape, groupsize, inner_k_tiles, **kwargs): + # the transposed flag tracks whether the tensor subclass has been transposed relative + # to how a weight is normally stored in a linear i.e. [out_features, in_features]. + # tracking both transposed and shape is slightly redundant but corner cases like + # square matrices can cause issues otherwise + self.scales_and_zeros = scales_and_zeros + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + super().__init__(int_data, transposed) + + @staticmethod + def _quantized_op(act_mat, w_qtensor, bias): + orig_act_size = act_mat.size() + orig_dtype = act_mat.dtype + + # reshape and pad activation + act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16) + pad_size = find_multiple(act_mat.shape[1], 1024) + act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[1])) + + # matmul + y = torch.ops.aten._weight_int4pack_mm(act_mat, w_qtensor.int_data, w_qtensor.groupsize, w_qtensor.scales_and_zeros) + + y = y.reshape(*orig_act_size[:-1], -1) if bias is not None: y += bias - return y + return y.to(orig_dtype) + + def dequantize(self): + eye_shape = self.shape[1] if not self.transposed else self.shape[0] + w_dq = self._quantized_op(torch.eye(eye_shape, device=self.device, dtype=self.dtype), self, None) + # we dequantized using linear with the identity matrix, output has shape [in_channels, out_channels] + # so we need to transpose back to get the original shape unless self.transposed is set. + w_dq = w_dq if self.transposed else w_dq.t() + return w_dq.to(self.dtype) + + def int_repr(self): + return self.int_data + + def q_params(self): + scales, zero_points = unpack_tinygemm_scales_and_zeros(self.scales_and_zeros, ) + return {"q_scales": scales, "q_zero_points": zero_points} + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + return self.__class__( + self.int_data.to(kwargs["device"]), + self.scales_and_zeros.to(kwargs["device"]), + self.transposed, + self.shape, + self.groupsize, + self.inner_k_tiles, + **kwargs + ) + + def _detach(self): + return self.__class__( + self.int_data, self.scales_and_zeros, self.transposed, self.shape, self.groupsize, self.inner_k_tiles, dtype=self.dtype + ) + + def _transpose(self): + return self.__class__( + self.int_data, self.scales_and_zeros, not self.transposed, self.shape[::-1], self.groupsize, self.inner_k_tiles, dtype=self.dtype + ) + + def __tensor_flatten__(self): + return ["int_data", "scales_and_zeros"], (self.transposed, self.shape, self.groupsize, self.inner_k_tiles, self.dtype) + + @classmethod + def __tensor_unflatten__(cls, tensor_data, attributes): + int_data, scales_and_zeros = tensor_data["int_data"], tensor_data["scales_and_zeros"] + transposed, shape, groupsize, inner_k_tiles, dtype = attributes + return cls( + int_data, scales_and_zeros, transposed, shape, groupsize, inner_k_tiles, dtype=dtype + ) + + @classmethod + def from_float(cls, input_float, groupsize=128, inner_k_tiles=8): + """ + Method used to convert a linear weight tensor to an instance of the + Int4WeightOnlyQuantizedLinearWeight subclass. + + Example usage:: + + model.lin_mod.weight = Int4WeightOnlyQuantizedLinearWeight.from_float(model.lin_mod.weight) + """ + assert groupsize in [256, 128, 64, 32] + assert inner_k_tiles in [8, 4, 2] + orig_shape = input_float.shape + out_features, orig_in_features = input_float.shape + assert out_features % 8 == 0, "require out_features % 8 == 0" + + # padding + in_features = find_multiple(orig_in_features, 1024) + input_float = torch.nn.functional.pad(input_float, (0, in_features - orig_in_features)) + + # quantization and packing + input_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor(input_float, 4, groupsize) + int_data = torch.ops.aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles) + + return cls(int_data, scales_and_zeros, False, orig_shape, groupsize, inner_k_tiles, dtype=input_float.dtype) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 6c162ba76c..236a5db5ea 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -11,6 +11,7 @@ from torch.utils._python_dispatch import TorchDispatchMode __all__ = [ + "find_multiple", "log_with_rank", "clear_logs", "compute_error", @@ -18,6 +19,11 @@ "get_model_size_in_bytes", ] +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + def log_with_rank(*args): # append