From 556c5bb2e68b3279c4581418cb66e08ea08f648f Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Tue, 6 Aug 2024 15:11:37 -0700
Subject: [PATCH] Enable/disable fake quant for QAT tensor subclasses

Summary: Support enabling and disabling fake quantization on the QAT
tensor subclasses. `LinearActivationQuantizedTensor` gains an
`input_quant_func_enabled` flag plus an `apply_input_quant_func` helper
that skips activation fake quantization when the flag is off.
`AffineFakeQuantizedTensor` no longer caches `fq_data`: it stores the
original float tensor together with an `apply_fake_quant_fn` callable and
a `fake_quant_enabled` flag, and `to_fake_quantized()` computes the fake
quantized values on demand. The `enable/disable_8da4w_fake_quant` and
`enable/disable_4w_fake_quant` module hooks now toggle both flags through
a new `_enable_fake_quant` helper instead of requiring the module-swap
QAT linear classes.

Test Plan: Extended test_qat_8da4w_quantizer_disable_fake_quant in
test/quantization/test_qat.py to check the new flags on both tensor
subclasses after disabling and re-enabling fake quantization.

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/quantization/test_qat.py                  |  25 +++-
 .../linear_activation_quantized_tensor.py     |  25 ++--
 .../qat/affine_fake_quantized_tensor.py       | 121 +++++++++---------
 torchao/quantization/prototype/qat/api.py     |  21 ++-
 torchao/quantization/prototype/qat/utils.py   |  33 ++++-
 5 files changed, 137 insertions(+), 88 deletions(-)

diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index f58068e42..7f0501c98 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -12,6 +12,12 @@
 import torch
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+from torchao.quantization.linear_activation_quantized_tensor import (
+    LinearActivationQuantizedTensor,
+)
+from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+    AffineFakeQuantizedTensor,
+)
 from torchao.quantization.prototype.qat.utils import (
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
@@ -252,6 +258,13 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
             enable_8da4w_fake_quant,
         )
 
+        def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
+            assert isinstance(m.weight, LinearActivationQuantizedTensor)
+            self.assertEqual(m.weight.input_quant_func_enabled, enabled)
+            weight = m.weight.original_weight_tensor
+            self.assertTrue(isinstance(weight, AffineFakeQuantizedTensor))
+            self.assertEqual(weight.fake_quant_enabled, enabled)
+
         group_size = 16
         torch.manual_seed(self.SEED)
         m = M()
@@ -260,9 +273,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         qat_model = quantizer.prepare(m)
         qat_model.apply(disable_8da4w_fake_quant)
-        self.assertFalse(qat_model.linear1._fake_quant_enabled)
-        self.assertFalse(qat_model.linear2._fake_quant_enabled)
-        self.assertFalse(qat_model.sub.linear._fake_quant_enabled)
+        assert_fake_quant_enabled(qat_model.linear1, enabled=False)
+        assert_fake_quant_enabled(qat_model.linear2, enabled=False)
+        assert_fake_quant_enabled(qat_model.sub.linear, enabled=False)
 
         # Disabled fake quant is just a normal linear
         m2.linear1.weight = qat_model.linear1.weight
@@ -277,9 +290,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
 
         # Renable fake quant
         qat_model.apply(enable_8da4w_fake_quant)
-        self.assertTrue(qat_model.linear1._fake_quant_enabled)
-        self.assertTrue(qat_model.linear2._fake_quant_enabled)
-        self.assertTrue(qat_model.sub.linear._fake_quant_enabled)
+        assert_fake_quant_enabled(qat_model.linear1, enabled=True)
+        assert_fake_quant_enabled(qat_model.linear2, enabled=True)
+        assert_fake_quant_enabled(qat_model.sub.linear, enabled=True)
 
         # Fake quant should be applied as normal
         quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py
index e4e4fedc4..b79b64a3b 100644
--- a/torchao/quantization/linear_activation_quantized_tensor.py
+++ b/torchao/quantization/linear_activation_quantized_tensor.py
@@ -22,6 +22,7 @@ def __new__(
         cls,
         original_weight_tensor: torch.Tensor,
         input_quant_func: Callable,
+        input_quant_func_enabled: bool = True,
     ):
         kwargs = {}
         dtype = original_weight_tensor.dtype
@@ -35,22 +36,25 @@ def __init__(
         self,
         original_weight_tensor: torch.Tensor,
         input_quant_func: Callable,
+        input_quant_func_enabled: bool = True,
     ):
         self.original_weight_tensor = original_weight_tensor
         self.input_quant_func = input_quant_func
+        self.input_quant_func_enabled = input_quant_func_enabled
 
     def __tensor_flatten__(self):
-        return ["original_weight_tensor"], [self.input_quant_func]
+        return ["original_weight_tensor"], [self.input_quant_func, self.input_quant_func_enabled]
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         original_weight_tensor = tensor_data_dict["original_weight_tensor"]
-        input_quant_func, = tensor_attributes
+        (input_quant_func, input_quant_func_enabled) = tensor_attributes
         return cls(
             original_weight_tensor,
             input_quant_func,
+            input_quant_func_enabled,
         )
 
     @classmethod
@@ -61,8 +65,15 @@ def _apply_fn_to_data(self, fn):
         return self.__class__(
             fn(self.original_weight_tensor),
             self.input_quant_func,
+            self.input_quant_func_enabled,
         )
 
+    def apply_input_quant_func(self, t: torch.Tensor):
+        if self.input_quant_func_enabled:
+            return self.input_quant_func(t)
+        else:
+            return t
+
     def _get_to_kwargs(self, *args, **kwargs):
         device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
         device = self.device if device is None else device
@@ -82,6 +93,7 @@ def to(self, *args, **kwargs):
         return self.__class__(
             self.original_weight_tensor.to(**kwargs),
             self.input_quant_func,
+            self.input_quant_func_enabled,
         )
 
     implements = classmethod(_implements)
@@ -98,9 +110,8 @@ def _(func, types, *args, **kwargs):
         args[2] if len(args) > 2 else None,
     )
     if isinstance(weight_tensor, LinearActivationQuantizedTensor):
-        input_quant_func = weight_tensor.input_quant_func
         original_weight_tensor = weight_tensor.original_weight_tensor
-        aqt = input_quant_func(input_tensor)
+        aqt = weight_tensor.apply_input_quant_func(input_tensor)
         return torch.nn.functional.linear(aqt, original_weight_tensor, bias)
     raise NotImplementedError("LinearActivationQuantizedTensor: No specialized dispatch found for linear op")
@@ -120,9 +131,8 @@ def _(func, types, *args, **kwargs):
             args[2],
             args[0],
         )
-        input_quant_func = weight_tensor.input_quant_func
         original_weight_tensor = weight_tensor.original_weight_tensor
-        aqt = input_quant_func(input_tensor)
+        aqt = weight_tensor.apply_input_quant_func(input_tensor)
         return func(bias, aqt, original_weight_tensor)
     else:
         # aten.mm.default
             args[0],
             args[1],
         )
-        input_quant_func = weight_tensor.input_quant_func
         original_weight_tensor = weight_tensor.original_weight_tensor
-        aqt = input_quant_func(input_tensor)
+        aqt = weight_tensor.apply_input_quant_func(input_tensor)
         return func(aqt, original_weight_tensor)
diff --git a/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py b/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py
index 9892a1acf..c0f0f734d 100644
--- a/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py
+++ b/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py
@@ -1,5 +1,5 @@
 import torch
-from typing import Tuple, Optional
+from typing import Callable, Optional, Tuple
 from torchao.quantization.quant_primitives import (
     _get_and_check_qmin_qmax,
     choose_qparams_affine,
@@ -30,50 +30,47 @@ class AffineFakeQuantizedTensor(torch.Tensor):
     regardless of the internal representation's type or orientation.
 
     fields:
-      float_data (torch.Tensor): tensor holding the original float values, needed for actual quantization later
-      fq_data (torch.Tensor): tensor holding the fake quantized values
-      block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
-        e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
-      shape (torch.Size): the shape for the Tensor
-      quant_min (Optional[int]): minimum quantized value for the Tensor
-      quant_max (Optional[int]): maximum quantized value for the Tensor
-      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
-        if zero_point is in integer domain, zero point is added to the quantized integer value during quantization
-        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
-        value during quantization
-        default is ZeroPointDomain.INT
+      original_tensor (torch.Tensor): tensor holding the original float values, needed for actual quantization later
+      apply_fake_quant_fn (Callable): function that transforms `original_tensor` to fake quantized values
     """
 
     @staticmethod
     def __new__(
         cls,
-        float_data: torch.Tensor,
-        fq_data: torch.Tensor,
+        original_tensor: torch.Tensor,
+        apply_fake_quant_fn: Callable,
+        fake_quant_enabled: bool = True,
     ):
         kwargs = {}
-        kwargs["device"] = float_data.device
-        kwargs["dtype"] = float_data.dtype
+        kwargs["device"] = original_tensor.device
+        kwargs["dtype"] = original_tensor.dtype
         kwargs["requires_grad"] = True
-        return torch.Tensor._make_wrapper_subclass(cls, float_data.shape, **kwargs)  # type: ignore[attr-defined]
+        return torch.Tensor._make_wrapper_subclass(cls, original_tensor.shape, **kwargs)  # type: ignore[attr-defined]
 
     def __init__(
         self,
-        float_data: torch.Tensor,
-        fq_data: torch.Tensor,
+        original_tensor: torch.Tensor,
+        apply_fake_quant_fn: Callable,
+        fake_quant_enabled: bool = True,
     ):
-        self.float_data = float_data
-        self.fq_data = fq_data
+        self.original_tensor = original_tensor
+        self.apply_fake_quant_fn = apply_fake_quant_fn
+        self.fake_quant_enabled = fake_quant_enabled
 
     def __tensor_flatten__(self):
-        return ["float_data", "fq_data"], []
+        return ["original_tensor"], [self.apply_fake_quant_fn, self.fake_quant_enabled]
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride,
     ):
-        float_data = tensor_data_dict["float_data"]
-        fq_data = tensor_data_dict["fq_data"]
-        return cls(float_data, fq_data)
+        original_tensor = tensor_data_dict["original_tensor"]
+        (apply_fake_quant_fn, fake_quant_enabled) = tensor_attributes
+        return cls(
+            original_tensor,
+            apply_fake_quant_fn,
+            fake_quant_enabled,
+        )
 
     @classmethod
     def from_float(
@@ -90,30 +87,35 @@ def from_float(
         preserve_zero: bool = True,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     ):
-        quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
-        scale, zero_point = choose_qparams_affine(
-            input_float,
-            mapping_type,
-            block_size,
-            target_dtype,
-            quant_min,
-            quant_max,
-            eps,
-            scale_dtype,
-            zero_point_dtype,
-            preserve_zero,
-            zero_point_domain,
-        )
-        fq_data = _GenericFakeQuantize.apply(
-            input_float,
-            block_size,
-            scale,
-            zero_point,
-            quant_min,
-            quant_max,
-            zero_point_domain,
-        )
-        return cls(input_float, fq_data)
+        def apply_fake_quant_fn(t: torch.Tensor):
+            qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
+            scale, zero_point = choose_qparams_affine(
+                t,
+                mapping_type,
+                block_size,
+                target_dtype,
+                qmin,
+                qmax,
+                eps,
+                scale_dtype,
+                zero_point_dtype,
+                preserve_zero,
+                zero_point_domain,
+            )
+            fq = _GenericFakeQuantize.apply(
+                t,
+                block_size,
+                scale,
+                zero_point,
+                qmin,
+                qmax,
+                zero_point_domain,
+            )
+            return fq
+        return cls(input_float, apply_fake_quant_fn)
+
+    def to_fake_quantized(self) -> torch.Tensor:
+        return self.apply_fake_quant_fn(self.original_tensor)
 
     def _get_to_kwargs(self, *args, **kwargs):
         device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
@@ -135,13 +137,18 @@ def to(self, *args, **kwargs):
         # not supported yet
         kwargs.pop("memory_format")
         return self.__class__(
-            self.float_data.to(device),
-            self.fq_data.to(device),
+            self.original_tensor.to(device),
+            self.apply_fake_quant_fn,
+            self.fake_quant_enabled,
             **kwargs,
         )
 
     def _apply_fn_to_data(self, fn):
-        return self.__class__(self.float_data, fn(self.fq_data))
+        return self.__class__(
+            fn(self.original_tensor),
+            self.apply_fake_quant_fn,
+            self.fake_quant_enabled,
+        )
 
     implements = classmethod(_implements)
     __torch_function__ = classmethod(_dispatch__torch_function__)
@@ -158,9 +165,9 @@ def _(func, types, *args, **kwargs):
         args[2] if len(args) > 2 else None,
     )
     if isinstance(input_tensor, AffineFakeQuantizedTensor):
-        input_tensor = input_tensor.fq_data
+        input_tensor = input_tensor.to_fake_quantized()
     if isinstance(weight_tensor, AffineFakeQuantizedTensor):
-        weight_tensor = weight_tensor.fq_data
+        weight_tensor = weight_tensor.to_fake_quantized()
     return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
 
 @implements([aten.mm.default, aten.addmm.default])
@@ -174,9 +181,9 @@ def _(func, types, *args, **kwargs):
     input_tensor = args[input_index]
     weight_tensor = args[input_index + 1]
     if isinstance(input_tensor, AffineFakeQuantizedTensor):
-        input_tensor = input_tensor.fq_data
+        input_tensor = input_tensor.to_fake_quantized()
     if isinstance(weight_tensor, AffineFakeQuantizedTensor):
-        weight_tensor = weight_tensor.fq_data
+        weight_tensor = weight_tensor.to_fake_quantized()
     if bias is not None:
         return func(bias, input_tensor, weight_tensor)
     else:
diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/prototype/qat/api.py
index 334be0504..58b761905 100644
--- a/torchao/quantization/prototype/qat/api.py
+++ b/torchao/quantization/prototype/qat/api.py
@@ -40,6 +40,7 @@
 from .affine_fake_quantized_tensor import to_affine_fake_quantized
 from .utils import (
     _choose_qparams_per_token_asymmetric,
+    _enable_fake_quant,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
     _is_linear_with_fq_weight,
@@ -230,17 +231,15 @@ def _get_qmin_qmax(self, n_bit: int):
 
 def enable_8da4w_fake_quant(mod: torch.nn.Module):
     """
-    Enable fake quantization for `Int8DynActInt4WeightQATLinear`.
+    Enable fake quantization for int8 dynamic activations + int4 weight.
     """
-    if isinstance(mod, Int8DynActInt4WeightQATLinear):
-        mod.enable_fake_quant()
+    _enable_fake_quant(mod, enable=True)
 
 def disable_8da4w_fake_quant(mod: torch.nn.Module):
     """
-    Disable fake quantization for `Int8DynActInt4WeightQATLinear`.
+    Disable fake quantization for int8 dynamic activations + int4 weight.
     """
-    if isinstance(mod, Int8DynActInt4WeightQATLinear):
-        mod.disable_fake_quant()
+    _enable_fake_quant(mod, enable=False)
 
 # ==================
@@ -390,14 +389,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 def enable_4w_fake_quant(mod: torch.nn.Module):
     """
-    Enable fake quantization for `Int4WeightOnlyQATLinear`.
+    Enable fake quantization for int4 weight only.
     """
-    if isinstance(mod, Int4WeightOnlyQATLinear):
-        mod.enable_fake_quant()
+    _enable_fake_quant(mod, enable=True)
 
 def disable_4w_fake_quant(mod: torch.nn.Module):
     """
-    Disable fake quantization for `Int4WeightOnlyQATLinear`.
+    Disable fake quantization for int4 weight only.
     """
-    if isinstance(mod, Int4WeightOnlyQATLinear):
-        mod.disable_fake_quant()
+    _enable_fake_quant(mod, enable=False)
diff --git a/torchao/quantization/prototype/qat/utils.py b/torchao/quantization/prototype/qat/utils.py
index 8ac0d1912..01bd6c93b 100644
--- a/torchao/quantization/prototype/qat/utils.py
+++ b/torchao/quantization/prototype/qat/utils.py
@@ -153,18 +153,41 @@ def _unwrap_affine_fake_quantized_tensor(t: torch.Tensor):
         AffineFakeQuantizedTensor,
     )
     assert isinstance(t, AffineFakeQuantizedTensor)
-    return t.float_data
+    return t.original_tensor
 
 def _is_linear_with_fq_weight(mod: torch.nn.Module, *args):
     """
     Return whether this is a nn.Linear module with `AffineFakeQuantizeTensor` weights.
     """
     # avoid circular dependencies
+    from torchao.quantization.linear_activation_quantized_tensor import (
+        LinearActivationQuantizedTensor,
+    )
     from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
         AffineFakeQuantizedTensor,
     )
-    return (
-        isinstance(mod, torch.nn.Linear)
-        and hasattr(mod, "weight")
-        and isinstance(mod.weight, AffineFakeQuantizedTensor)
+    if not isinstance(mod, torch.nn.Linear) or not hasattr(mod, "weight"):
+        return False
+    weight = mod.weight
+    if isinstance(weight, LinearActivationQuantizedTensor):
+        weight = weight.original_weight_tensor
+    return isinstance(weight, AffineFakeQuantizedTensor)
+
+def _enable_fake_quant(mod: torch.nn.Module, enable: bool):
+    """
+    Enable or disable fake quantization in the activations and weights of a `nn.Linear` module.
+    """
+    from torchao.quantization.linear_activation_quantized_tensor import (
+        LinearActivationQuantizedTensor,
+    )
+    from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+        AffineFakeQuantizedTensor,
     )
+    if not _is_linear_with_fq_weight(mod):
+        return
+    weight = mod.weight
+    if isinstance(weight, LinearActivationQuantizedTensor):
+        weight.input_quant_func_enabled = enable
+        weight = weight.original_weight_tensor
+    assert isinstance(weight, AffineFakeQuantizedTensor)
+    weight.fake_quant_enabled = enable
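
Usage sketch (illustrative, not part of the patch): toggling fake quantization
on a prepared QAT model with the hooks above. The toy model and shapes are
made up, and the import locations are assumed from this patch's file layout
(`Int8DynActInt4WeightQATQuantizer` is referenced by the updated test but not
defined in this diff):

    import torch

    from torchao.quantization.prototype.qat.api import (
        Int8DynActInt4WeightQATQuantizer,
        disable_8da4w_fake_quant,
        enable_8da4w_fake_quant,
    )

    # Toy model; any module with nn.Linear children works.
    model = torch.nn.Sequential(torch.nn.Linear(64, 64))

    # prepare() swaps each linear weight for a LinearActivationQuantizedTensor
    # wrapping an AffineFakeQuantizedTensor, with fake quant enabled.
    quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
    qat_model = quantizer.prepare(model)

    # Disable to run float forward passes, e.g. for a baseline comparison.
    qat_model.apply(disable_8da4w_fake_quant)
    float_out = qat_model(torch.randn(1, 64))

    # Re-enable to resume QAT with fake quantization in the forward pass.
    qat_model.apply(enable_8da4w_fake_quant)
    qat_out = qat_model(torch.randn(1, 64))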
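A second sketch of the design change in AffineFakeQuantizedTensor: from_float
no longer materializes fq_data at construction; it captures the quantization
arguments in a closure, and to_fake_quantized() recomputes the fake quantized
values from the original tensor on each call. The mapping type, block size,
and quant_min/quant_max below are illustrative choices (int4 range carried in
an int8 container, one qparam per group of 16), not values fixed by this
patch, and MappingType is assumed to come from
torchao.quantization.quant_primitives:

    import torch

    from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
        AffineFakeQuantizedTensor,
    )
    from torchao.quantization.quant_primitives import MappingType

    weight = torch.randn(64, 64)
    fq_weight = AffineFakeQuantizedTensor.from_float(
        weight,
        MappingType.SYMMETRIC,  # illustrative: symmetric weight quantization
        (1, 16),                # block_size: one scale/zero_point per group of 16
        torch.int8,             # container dtype for the quantized values
        quant_min=-8,           # illustrative int4 range
        quant_max=7,
    )

    # Nothing is cached: qparams are chosen and the quantize/dequantize round
    # trip runs against original_tensor every time this is called.
    fq = fq_weight.to_fake_quantized()

    # fake_quant_enabled defaults to True; _enable_fake_quant toggles it.
    assert fq_weight.fake_quant_enabled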