From a50fea5e0e996fb2d3e01c081beebc54c41f43a9 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 15 Apr 2024 19:27:20 -0700 Subject: [PATCH] Refactor tensor subclass API to also use paramterization Summary: Also added tests for tensor subclass api + AOTI compilation Test Plan: python test/integration/test_integration.py -k test_aoti Reviewers: Subscribers: Tasks: Tags: --- .github/workflows/regression_test.yml | 6 +- test/integration/test_integration.py | 83 +++++++++++++++--- torchao/quantization/quant_api.py | 32 ++++--- torchao/quantization/quant_primitives.py | 24 ++++-- torchao/quantization/subclass.py | 104 +++++++++++++++++------ 5 files changed, 191 insertions(+), 58 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 85a79cd5c..a9e8ab2d9 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -31,9 +31,9 @@ jobs: torch-spec: 'torch==2.3.0' gpu-arch-type: "cuda" gpu-arch-version: "12.1" - - name: CUDA 2.4.0.dev20240421 + - name: CUDA 2.4.0.dev20240428 runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: '--pre torch==2.4.0.dev20240421+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121' + torch-spec: '--pre torch==2.4.0.dev20240428+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121' gpu-arch-type: "cuda" gpu-arch-version: "12.1" - name: CPU 2.2.2 @@ -58,6 +58,8 @@ jobs: gpu-arch-type: ${{ matrix.gpu-arch-type }} gpu-arch-version: ${{ matrix.gpu-arch-version }} script: | + conda create -n venv python=3.9 -y + conda activate venv python -m pip install --upgrade pip pip install ${{ matrix.torch-spec }} pip install -r requirements.txt diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 83211cecd..643c093be 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -67,20 +67,28 @@ from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx import os from parameterized import parameterized +import itertools +import logging from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 +logger = logging.getLogger("INFO") + torch.manual_seed(0) config.cache_size_limit = 100 -COMMON_DEVICE_DTYPE=[ - ("cpu", torch.float32), - ("cpu", torch.float16), - ("cpu", torch.bfloat16), - ("cuda", torch.float32), - ("cuda", torch.float16), - ("cuda", torch.bfloat16), +# TODO: use this to reduce the number of tests +TENSOR_SUBCLASS_APIS = [ + change_linear_weights_to_int8_dqtensors, + change_linear_weights_to_int8_woqtensors, + change_linear_weights_to_int4_woqtensors, ] +COMMON_DEVICES = ["cpu", "cuda"] + +COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] + +COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy() + def combine_parameters(a, b): new_tuples = [] for (tuple1, tuple2) in itertools.product(a, b): @@ -88,10 +96,16 @@ def combine_parameters(a, b): return new_tuples def run_supported_device_dtype(test_method): + """Assumes that the 3rd arg (args[2]) of the decorated method is device and + there is a `test_dtype` kwarg or the 4th arg (args[3]) that indicates the dtype for testing + """ def wrapper(*args, **kwargs): - if args[2] == "cuda" and not torch.cuda.is_available(): + assert len(args) >= 3, f"Not enough args. Expected more than or equal to 3, but got {len(args)}" + device = args[2] + dtype = kwargs["test_dtype"] if "test_dtype" in kwargs else args[3] + if device == "cuda" and not torch.cuda.is_available(): raise unittest.SkipTest(f"Need CUDA available.") - if args[2] == "cuda" and torch.cuda.is_available() and kwargs['test_dtype'] == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0): + if device == "cuda" and torch.cuda.is_available() and dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0): raise unittest.SkipTest("Need CUDA and SM80+ available.") return test_method(*args, **kwargs) return wrapper @@ -1148,6 +1162,7 @@ def _test_handle_save_load_meta_impl( min_sqnr=35, test_dtype=torch.bfloat16 ): + logger.info(f"TestSaveLoad: {api}, {test_device}, {test_dtype}") m, k, n = 32, 64, 32 class test_model(nn.Module): @@ -1180,7 +1195,7 @@ def forward(self, x): # load model structure with torch.device('meta'): - model = test_model() + model = test_model().to(dtype=test_dtype) api(model) # load quantized state_dict @@ -1407,5 +1422,53 @@ def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n): sqnr = SQNR(out, out2) self.assertTrue(sqnr >= 30) + +class TestAOTI(unittest.TestCase): + @parameterized.expand( + list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)), + ) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "aoti compatibility requires 2.4+.") + @torch.no_grad() + # @run_supported_device_dtype + def test_aoti(self, api, test_device, test_dtype): + logger.info(f"TestAOTI: {api}, {test_device}, {test_dtype}") + if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda": + self.skipTest(f"{api} in {test_device} is not support for aoti compilation yet") + + if test_dtype != torch.bfloat16: + self.skipTest(f"{api} in {test_dtype} is not support for aoti compilation yet") + + m, k, n = 32, 64, 32 + + class test_model(nn.Module): + def __init__(self): + super().__init__() + self.lin1 = nn.Linear(k, n) + self.relu = nn.ReLU() + self.lin2 = nn.Linear(n, n) + + def forward(self, x): + x = self.lin1(x) + x = self.relu(x) + x = self.lin2(x) + return x + + x = torch.randn(m, k, dtype=test_dtype, device=test_device) + + # get float reference + model = test_model().to(dtype=test_dtype, device=test_device).eval() + ref_f = model(x) + + kwargs = {"dtype": test_dtype} + api(model, **kwargs) + + # running model + model(x) + + # make sure it compiles + example_inputs = (x,) + torch._export.aot_compile(model, example_inputs) + + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index a830d52d7..d4dd01abe 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -20,7 +20,7 @@ import torch.nn.functional as F from .dynamic_quant import DynamicallyPerAxisQuantizedLinear -from .utils import TORCH_VERSION_AFTER_2_3 +from .utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 from .subclass import ( Int4WeightOnlyQuantizedLinearWeight, @@ -117,19 +117,27 @@ def apply_dynamic_quant(model, filter_fn=None): change_linear_weights_to_int8_dqtensors(model, filter_fn) -def _get_subclass_inserter(cls, **kwargs): - method = kwargs.pop("method", "from_float") +import torch.nn.utils.parametrize as parametrize + +def _get_subclass_inserter(cls, enable_parametrization=False, **kwargs): + constructor = kwargs.pop("constructor", "subclass_constructor") + from_float = kwargs.pop("method", "from_float") def insert_subclass(lin): - lin.weight = torch.nn.Parameter( - # cls.from_float(...) - getattr(cls, method)(lin.weight, **kwargs), requires_grad=False - ) + if enable_parametrization: + lin.weight = torch.nn.Parameter(cls.from_float(lin.weight, **kwargs), requires_grad=False) + _, args = lin.weight.__tensor_flatten__() + parametrize.register_parametrization(lin, "weight", getattr(cls, constructor)(cls, *args)) + else: + lin.weight = torch.nn.Parameter( + # cls.from_float(...) + getattr(cls, from_float)(lin.weight, **kwargs), requires_grad=False + ) return lin return insert_subclass -def change_linear_weights_to_int8_dqtensors(model, filter_fn=None): +def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): """ Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight` Tensor subclass, effectively applying the same form of quantization @@ -141,11 +149,11 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None): ) _replace_with_custom_fn_if_matches_filter( - model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight), filter_fn + model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn ) -def change_linear_weights_to_int8_woqtensors(model, filter_fn=None): +def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs): """ Converts all linear weight tensors to the `Int8WeightOnlyQuantizedLinearWeight` tensor subclass, @@ -154,7 +162,7 @@ def change_linear_weights_to_int8_woqtensors(model, filter_fn=None): """ _replace_with_custom_fn_if_matches_filter( model, - _get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight), + _get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), _is_linear if filter_fn is None else filter_fn, ) @@ -170,7 +178,7 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs): _replace_with_custom_fn_if_matches_filter( model, - _get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, **kwargs), + _get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn, ) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index bd4bcce1a..b42f6481d 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -47,6 +47,13 @@ ] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else []) +def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None): + if dtype is not None and tensor_arg.dtype != dtype: + raise ValueError("Expected Tensor argument {arg_name} to have dtype {dtype}, but got {tensor_arg.dtype} instead.") + if size is not None and tensor_arg.size() != size: + raise ValueError("Expected Tensor argument {arg_name} to have size {size}, but got {tensor_arg.size()} instead.") + + _DTYPE_TO_QVALUE_BOUNDS = { torch.uint8: (0, 255), torch.int8: (-128, 127), @@ -493,7 +500,7 @@ def quant_int8_dynamic_per_token_linear( x_vals_int8, x_scales, w_vals_int8_t, w_scales, out_dtype ) if bias is not None: - mm_out += bias + mm_out = mm_out + bias return mm_out @@ -554,7 +561,7 @@ def quant_int8_per_token_matmul( return y -def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128): +def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16): """This is tinygemm specific, we'll keep this for now""" if groupsize > w.shape[-1]: groupsize = w.shape[-1] @@ -570,15 +577,14 @@ def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128): max_int = 2**n_bit - 1 scales = (max_val - min_val).clamp(min=1e-6) / max_int zeros = min_val + scales * (2 ** (n_bit - 1)) - return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to( - torch.bfloat16 + return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to( + dtype=dtype ).reshape(w.shape[0], -1) def pack_tinygemm_scales_and_zeros(scales, zeros): - assert scales.shape == zeros.shape - assert scales.dtype == torch.bfloat16 - assert zeros.dtype == torch.bfloat16 + guard_dtype_size(scales, "scales", dtype=torch.bfloat16, size=zeros.size()) + guard_dtype_size(zeros, "zeros", dtype=torch.bfloat16) return ( torch.cat( [ @@ -661,8 +667,8 @@ def groupwise_affine_dequantize_tensor_from_qparams( return w_dq -def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128): - scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize) +def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16): + scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize, dtype) w_int4x8 = groupwise_affine_quantize_tensor_from_qparams( w, scales, zeros, n_bit, groupsize ) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index e9b532d6d..544114495 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -190,6 +190,25 @@ def __torch_dispatch__(cls, func, types, args, kwargs): args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), ) +class ConstructTensorSubclass(torch.nn.Module): + def __init__(self, tensor_subclass_ctr, *args, **kwargs): + super().__init__() + self.tensor_subclass_ctr = tensor_subclass_ctr + self.args = args + self.kwargs = kwargs + + def right_inverse(self, tensor_subclass_instance): + fields, _ = tensor_subclass_instance.__tensor_flatten__() + return [getattr(tensor_subclass_instance, field) for field in fields] + +@torch._dynamo.allow_in_graph +def from_qtensor_components_int8dyn(*args, **kwargs): + return Int8DynamicallyQuantizedLinearWeight(*args, **kwargs) + +class ConstructTensorSubclassInt8Dyn(ConstructTensorSubclass): + def forward(self, int_data, q_scales): + return from_qtensor_components_int8dyn(int_data, q_scales, *self.args, **self.kwargs) + class Int8DynamicallyQuantizedLinearWeight(QuantizedLinearWeightBase): """ @@ -197,13 +216,16 @@ class Int8DynamicallyQuantizedLinearWeight(QuantizedLinearWeightBase): linear op to a dynamically quantized linear op with symmetric per-token and per-channel quantization on the activation and weight respectively. """ + subclass_constructor = ConstructTensorSubclassInt8Dyn @staticmethod - def __new__(cls, int_data, q_scales, transposed, shape, **kwargs): - kwargs["dtype"] = kwargs.get("dtype", q_scales.dtype) + def __new__(cls, int_data, q_scales, transposed, shape, dtype=None, **kwargs): + if dtype is None: + dtype = qscales.dtype + kwargs["dtype"] = dtype return super().__new__(cls, int_data, transposed, shape, **kwargs) # type: ignore[attr-defined] - def __init__(self, int_data, q_scales, transposed, shape, **kwargs): + def __init__(self, int_data, q_scales, transposed, shape, dtype=None, **kwargs): self.q_scales = q_scales super().__init__(int_data, transposed) @@ -266,14 +288,15 @@ def _change_shape(self, shape): ) def __tensor_flatten__(self): - return ["int_data", "q_scales"], [self.transposed, self.dtype, self.shape] + # note: the order of args must match the order of args in __init__ + return ["int_data", "q_scales"], [self.transposed, self.shape, self.dtype] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None ): int_data, q_scales = tensor_data_dict["int_data"], tensor_data_dict["q_scales"] - transposed, dtype, shape = tensor_attributes + transposed, shape, dtype = tensor_attributes return cls( int_data, q_scales, @@ -284,7 +307,7 @@ def __tensor_unflatten__( ) @classmethod - def from_float(cls, input_float, qmin=-128, qmax=127): + def from_float(cls, input_float, qmin=-128, qmax=127, dtype=None): """ Method used to convert a linear weight tensor to an instance of the Int8DynamicallyQuantizedLinearWeight subclass. @@ -295,6 +318,9 @@ def from_float(cls, input_float, qmin=-128, qmax=127): Int8DynamicallyQuantizedLinearWeight.from_float(model.lin_mod.weight) ) """ + if dtype is None: + dtype = input_float.dtype + # because we call transpose in dequantization w_int_repr, w_scales, _ = dynamically_quantize_per_channel( input_float, qmin, qmax, torch.int8 @@ -308,16 +334,27 @@ def from_float(cls, input_float, qmin=-128, qmax=127): if not issubclass(cls, Int8DynamicallyQuantizedLinearWeight): int_data = int_data.contiguous() return cls( - int_data, w_scales, False, input_float.shape, dtype=input_float.dtype + int_data, w_scales, False, input_float.shape, dtype=dtype, ) +@torch._dynamo.allow_in_graph +def from_qtensor_components_int8wo(*args, **kwargs): + return Int8WeightOnlyQuantizedLinearWeight(*args, **kwargs) + + +class ConstructTensorSubclassInt8wo(ConstructTensorSubclass): + def forward(self, int_data, q_scales): + return from_qtensor_components_int8wo(int_data, q_scales, *self.args, **self.kwargs) + + class Int8WeightOnlyQuantizedLinearWeight(Int8DynamicallyQuantizedLinearWeight): """ A Tensor subclass that when applied to a weight used in a linear op/module, changes the linear op to a weight-only quantized linear op with symmetric per-channel quantization on the weight. """ + subclass_constructor = ConstructTensorSubclassInt8wo @staticmethod def _quantized_op(act_mat, w_qtensor, bias): @@ -335,12 +372,21 @@ def _quantized_op(act_mat, w_qtensor, bias): return y.to(orig_dtype) +@torch._dynamo.allow_in_graph +def from_qtensor_components_int4wo(*args, **kwargs): + return Int4WeightOnlyQuantizedLinearWeight(*args, **kwargs) + +class ConstructTensorSubclassInt4wo(ConstructTensorSubclass): + def forward(self, int_data, scales_and_zeros): + return from_qtensor_components_int4wo(int_data, scales_and_zeros, *self.args, **self.kwargs) + class Int4WeightOnlyQuantizedLinearWeight(QuantizedLinearWeightBase): """ A Tensor subclass that when applied to a weight used in a linear op/module, changes that linear op to a weight-only int4 quantized linear op with groupwise affine quantization on the weight. """ + subclass_constructor = ConstructTensorSubclassInt4wo @staticmethod def __new__( @@ -351,9 +397,12 @@ def __new__( shape, groupsize=128, inner_k_tiles=8, + dtype=None, **kwargs, ): - kwargs["dtype"] = kwargs.get("dtype", scales_and_zeros.dtype) + if dtype is None: + dtype = scales_and_zeros.dtype + kwargs["dtype"] = dtype return super().__new__(cls, int_data, transposed, shape, **kwargs) # type: ignore[attr-defined] def __init__( @@ -364,6 +413,7 @@ def __init__( shape, groupsize, inner_k_tiles, + dtype, **kwargs, ): # the transposed flag tracks whether the tensor subclass has been transposed relative @@ -372,9 +422,7 @@ def __init__( # square matrices can cause issues otherwise self.scales_and_zeros = scales_and_zeros - self.groupsize = groupsize - self.inner_k_tiles = inner_k_tiles super().__init__(int_data, transposed) @@ -465,10 +513,10 @@ def _change_shape(self, shape): def __tensor_flatten__(self): return ["int_data", "scales_and_zeros"], ( self.transposed, + self.shape, self.groupsize, self.inner_k_tiles, self.dtype, - self.shape, ) @classmethod @@ -482,7 +530,7 @@ def __tensor_unflatten__( tensor_data_dict["int_data"], tensor_data_dict["scales_and_zeros"], ) - transposed, groupsize, inner_k_tiles, dtype, shape = attributes + transposed, shape, groupsize, inner_k_tiles, dtype = attributes return cls( int_data, scales_and_zeros, @@ -495,7 +543,7 @@ def __tensor_unflatten__( ) @classmethod - def from_float(cls, input_float, groupsize=128, inner_k_tiles=8): + def from_float(cls, input_float, groupsize=128, inner_k_tiles=8, dtype=None): """ Method used to convert a linear weight tensor to an instance of the Int4WeightOnlyQuantizedLinearWeight subclass. @@ -506,9 +554,24 @@ def from_float(cls, input_float, groupsize=128, inner_k_tiles=8): Int4WeightOnlyQuantizedLinearWeight.from_float(model.lin_mod.weight) ) """ + if dtype is None: + dtype = input_float.dtype + + int_data, scales_and_zeros, transposed, groupsize, inner_k_tils = cls.to_qtensor_components(input_float, groupsize, inner_k_tiles) + return cls( + int_data, + scales_and_zeros, + transposed, + input_float.shape, + groupsize, + inner_k_tiles, + dtype=dtype, + ) + + @classmethod + def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8): assert groupsize in [256, 128, 64, 32] assert inner_k_tiles in [8, 4, 2] - orig_shape = input_float.shape orig_out_features, orig_in_features = input_float.shape # padding @@ -521,16 +584,7 @@ def from_float(cls, input_float, groupsize=128, inner_k_tiles=8): # quantization and packing input_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor( - input_float, 4, groupsize + input_float, 4, groupsize, dtype=input_float.dtype ) int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles) - - return cls( - int_data, - scales_and_zeros, - False, - orig_shape, - groupsize, - inner_k_tiles, - dtype=input_float.dtype, - ) + return int_data, scales_and_zeros, False, groupsize, inner_k_tiles