From 61bd8d3968ab8bcf19f0131567cfdd3cc99f332a Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Thu, 15 Aug 2024 19:21:07 -0700
Subject: [PATCH] Refactor QAT to use tensor subclasses

This commit refactors QAT to use tensor subclasses. This is motivated by the general move towards tensor subclasses in torchao for better composability with other subclasses like DTensors.

To achieve this, we introduce `AffineFakeQuantizedTensor`, which is analogous to `AffineQuantizedTensor` but applies fake quantization instead and requires gradient updates. `AffineFakeQuantizedTensor` wraps the original weight or input activation tensor and applies fake quantization dynamically, only when the linear function is called. Gradients flow only to the outer tensor (`AffineFakeQuantizedTensor`) and never to the inner tensor. For weights, the outer tensor is also a `torch.nn.Parameter`, and gradient updates received by the outer tensor are then passed to the inner tensor through ops like `aten.add_` and `aten.mul_`.

An important difference between the PTQ and QAT flows is how the input activation subclasses are inserted. For QAT, we use an nn.Module `forward_pre_hook` instead of relying on another subclass, `LinearActivationQuantizedTensor`, that wraps the weight subclass. The problem with the old PTQ approach is that it can create subclasses under `__torch_dispatch__`, which runs below autograd, so the created subclasses cannot carry gradients, and it was difficult to get gradients to flow correctly in such cases. It is also not very intuitive, because quantizing the input activations has to go through the weights. In the new approach used by QAT, we instead register a `forward_pre_hook` that wraps the input activations before each call to forward. This approach is also motivated by how [DTensor wraps its subclasses](https://github.com/pytorch/pytorch/blob/844103197d3e8cf6b4b59176e473365113f4f962/torch/distributed/tensor/parallel/style.py#L521).
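To illustrate the hook-based wrapping described above, here is a minimal, self-contained sketch of the pattern (not the actual implementation in this patch). The names `wrap_input_activation`, `insert_input_subclass`, and `_input_prehook` are hypothetical placeholders; the real helpers added here are `_get_qat_linear_subclass_inserter` and `_QAT_LINEAR_SUBCLASS_INPUT_PREHOOK` in `torchao/quantization/prototype/qat/utils.py`, and the real wrapper is `AffineFakeQuantizedTensor` rather than an identity function.

```python
import torch

def wrap_input_activation(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for the real constructor: in this patch the input would be
    # wrapped in an AffineFakeQuantizedTensor here; we return x unchanged.
    return x

def insert_input_subclass(linear: torch.nn.Linear) -> torch.nn.Linear:
    # The pre-hook receives (module, args) and returns the modified args,
    # so only the first positional argument (the input activation) is wrapped.
    def prehook(module, args):
        return (wrap_input_activation(args[0]),) + tuple(args[1:])

    handle = linear.register_forward_pre_hook(prehook)
    # Keeping (prehook, handle) on the module lets fake quant be toggled:
    # handle.remove() disables it, re-registering the prehook re-enables it.
    linear._input_prehook = (prehook, handle)
    return linear

# The hook runs at the nn.Module level on every forward call, above
# __torch_dispatch__, so the wrapped input can participate in autograd.
lin = insert_input_subclass(torch.nn.Linear(16, 8))
out = lin(torch.randn(2, 16))
```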
- [x] Add AffineFakeQuantizedTensor - [x] Add support for int4 weight only fake quantize - [x] Add support for int8 dynamic activations + int4 weight fake quantize (8da4w) - [x] Add prepare and convert path to int4 QAT quantizer - [x] Add prepare and convert path to 8da4w QAT quantizer - [x] Support enabling and disabling fake quant dynamically - [x] Support `__repr__` in AffineFakeQuantizedTensor - [x] Fix backward pass for int4 weight only - [x] Fix backward pass for int8 dynamic activations + int4 weight --- test/quantization/test_qat.py | 147 ++++++--- .../quantization/prototype/qat/__init__.py | 4 + .../qat/affine_fake_quantized_tensor.py | 295 ++++++++++++++++++ torchao/quantization/prototype/qat/api.py | 225 +++++++------ torchao/quantization/prototype/qat/utils.py | 133 +++++++- torchao/quantization/quant_api.py | 8 +- 6 files changed, 670 insertions(+), 142 deletions(-) create mode 100644 torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 022e1d622..b329577a1 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -12,11 +12,22 @@ import torch from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torchao.dtypes import ( + TensorCoreTiledLayoutType, +) +from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, +) from torchao.quantization.prototype.qat.utils import ( _choose_qparams_per_token_asymmetric, _fake_quantize_per_channel_group, _fake_quantize_per_token, _GenericFakeQuantize, + _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, +) +from torchao.quantization.quant_api import ( + int4_weight_only, + quantize_, ) from torchao.quantization.quant_primitives import ( fake_quantize_affine, @@ -190,6 +201,7 @@ def test_qat_8da4w_linear(self): ptq_out = ptq_linear(x2) torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) + # TODO: compare against quantize_ API instead @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer @@ -217,13 +229,6 @@ def test_qat_8da4w_quantizer(self): converted_out = converted_model(*x) torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0) - # Compare converted state dict - ptq_state_dict = ptq_model.state_dict() - converted_state_dict = converted_model.state_dict() - self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) - for k in ptq_state_dict.keys(): - torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_meta_weights(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer @@ -236,6 +241,20 @@ def test_qat_8da4w_quantizer_meta_weights(self): qat_model = qat_quantizer.prepare(m) self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values())) + def _copy_subclass_weights( + self, + nn_linear: torch.nn.Linear, + subclass_linear: AffineFakeQuantizedTensor, + ): + nn_linear.weight = torch.nn.Parameter(subclass_linear.weight.original_tensor) + + def _assert_matches_subclass_weights( + self, + nn_linear: torch.nn.Linear, + subclass_linear: AffineFakeQuantizedTensor, + ): + torch.testing.assert_close(nn_linear.weight, subclass_linear.weight.original_tensor, atol=0, rtol=0) + 
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_disable_fake_quant(self): """ @@ -247,6 +266,16 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): enable_8da4w_fake_quant, ) + def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool): + self.assertTrue(isinstance(m.weight, AffineFakeQuantizedTensor)) + self.assertEqual(m.weight.fake_quant_enabled, enabled) + self.assertTrue(hasattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK)) + (_, handle) = getattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK) + if enabled: + self.assertIsNotNone(handle) + else: + self.assertIsNone(handle) + group_size = 16 torch.manual_seed(self.SEED) m = M() @@ -255,14 +284,14 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model = quantizer.prepare(m) qat_model.apply(disable_8da4w_fake_quant) - self.assertFalse(qat_model.linear1._fake_quant_enabled) - self.assertFalse(qat_model.linear2._fake_quant_enabled) - self.assertFalse(qat_model.sub.linear._fake_quant_enabled) + assert_fake_quant_enabled(qat_model.linear1, enabled=False) + assert_fake_quant_enabled(qat_model.linear2, enabled=False) + assert_fake_quant_enabled(qat_model.sub.linear, enabled=False) # Disabled fake quant is just a normal linear - m2.linear1.weight = qat_model.linear1.weight - m2.linear2.weight = qat_model.linear2.weight - m2.sub.linear.weight = qat_model.sub.linear.weight + self._copy_subclass_weights(m2.linear1, qat_model.linear1) + self._copy_subclass_weights(m2.linear2, qat_model.linear2) + self._copy_subclass_weights(m2.sub.linear, qat_model.sub.linear) torch.manual_seed(self.SEED) x = m.example_inputs() x2 = copy.deepcopy(x) @@ -272,16 +301,16 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): # Renable fake quant qat_model.apply(enable_8da4w_fake_quant) - self.assertTrue(qat_model.linear1._fake_quant_enabled) - self.assertTrue(qat_model.linear2._fake_quant_enabled) - self.assertTrue(qat_model.sub.linear._fake_quant_enabled) + assert_fake_quant_enabled(qat_model.linear1, enabled=True) + assert_fake_quant_enabled(qat_model.linear2, enabled=True) + assert_fake_quant_enabled(qat_model.sub.linear, enabled=True) # Fake quant should be applied as normal quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model2 = quantizer2.prepare(m3) - qat_model2.linear1.weight = qat_model.linear1.weight - qat_model2.linear2.weight = qat_model.linear2.weight - qat_model2.sub.linear.weight = qat_model.sub.linear.weight + qat_model2.linear1.weight.original_tensor = qat_model.linear1.weight.original_tensor + qat_model2.linear2.weight.original_tensor = qat_model.linear2.weight.original_tensor + qat_model2.sub.linear.weight.original_tensor = qat_model.sub.linear.weight.original_tensor torch.manual_seed(self.SEED) x = m.example_inputs() x2 = copy.deepcopy(x) @@ -306,9 +335,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model = quantizer.prepare(m) qat_model.apply(disable_8da4w_fake_quant) - nn_model.linear1.weight = qat_model.linear1.weight - nn_model.linear2.weight = qat_model.linear2.weight - nn_model.sub.linear.weight = qat_model.sub.linear.weight + self._copy_subclass_weights(nn_model.linear1, qat_model.linear1) + self._copy_subclass_weights(nn_model.linear2, qat_model.linear2) + self._copy_subclass_weights(nn_model.sub.linear, qat_model.sub.linear) # Simulate training for both models 
optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) @@ -330,9 +359,55 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): optimizer2.step() # After 1 training step, weights should match exactly - torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0) - torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0) - torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0) + self._assert_matches_subclass_weights(nn_model.linear1, qat_model.linear1) + self._assert_matches_subclass_weights(nn_model.linear2, qat_model.linear2) + self._assert_matches_subclass_weights(nn_model.sub.linear, qat_model.sub.linear) + + def _test_qat_quantized_gradients(self, quantizer): + """ + Test that QAT produces gradients in the backward pass. + """ + num_steps = 10 + torch.manual_seed(self.SEED) + m = M() + model = quantizer.prepare(m) + optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) + loss_fn = torch.nn.CrossEntropyLoss() + + # Simulate training + current_step = 0 + last_linear1_grad = None + last_linear2_grad = None + last_sub_linear_grad = None + while current_step < num_steps: + example_inputs = model.example_inputs() + target = torch.randn(1, 512).float() + output = model(*example_inputs) + loss = loss_fn(output, target) + loss.backward() + # assert each linear grad is updated + new_linear1_grad = model.linear1.weight.grad + new_linear2_grad = model.linear2.weight.grad + new_sub_linear_grad = model.sub.linear.weight.grad + self.assertIsNotNone(new_linear1_grad) + self.assertIsNotNone(new_linear2_grad) + self.assertIsNotNone(new_sub_linear_grad) + if current_step > 0: + self.assertFalse(torch.equal(last_linear1_grad, new_linear1_grad)) + self.assertFalse(torch.equal(last_linear2_grad, new_linear2_grad)) + self.assertFalse(torch.equal(last_sub_linear_grad, new_sub_linear_grad)) + last_linear1_grad = new_linear1_grad + last_linear2_grad = new_linear2_grad + last_sub_linear_grad = new_sub_linear_grad + optimizer.zero_grad() + optimizer.step() + current_step += 1 + + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + def test_qat_8da4w_quantizer_gradients(self): + from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer + quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16) + self._test_qat_quantized_gradients(quantizer) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_generic_fake_quantize(self): @@ -353,7 +428,7 @@ def test_qat_generic_fake_quantize(self): block_size = (1, ao_input.shape[-1]) ao_s = copy.deepcopy(py_s) ao_zp = copy.deepcopy(py_zp) - ao_out = _GenericFakeQuantize.apply(ao_input, ao_s, ao_zp, qmin, qmax, block_size) + ao_out = _GenericFakeQuantize.apply(ao_input, block_size, ao_s, ao_zp, qmin, qmax) ao_out.sum().backward() torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0) @@ -373,10 +448,7 @@ def _assert_close_4w(self, val, ref): print(mean_err) self.assertTrue(mean_err < 0.05) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") - # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517 - @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now") def 
test_qat_4w_primitives(self): n_bit = 4 group_size = 32 @@ -464,11 +536,9 @@ def test_qat_4w_quantizer(self): qat_quantizer = Int4WeightOnlyQATQuantizer( groupsize=group_size, inner_k_tiles=inner_k_tiles, ) - ptq_quantizer = Int4WeightOnlyQuantizer( - groupsize=group_size, inner_k_tiles=inner_k_tiles, - ) qat_model = qat_quantizer.prepare(m) - ptq_model = ptq_quantizer.quantize(m2) + ptq_model = m2 + quantize_(ptq_model, int4_weight_only(group_size, TensorCoreTiledLayoutType(inner_k_tiles))) # Compare model values torch.manual_seed(self.SEED) @@ -483,12 +553,11 @@ def test_qat_4w_quantizer(self): converted_out = converted_model(*x) torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0) - # Compare converted state dict - ptq_state_dict = ptq_model.state_dict() - converted_state_dict = converted_model.state_dict() - self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) - for k in ptq_state_dict.keys(): - torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + def test_qat_4w_quantizer_gradients(self): + from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer + quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8) + self._test_qat_quantized_gradients(quantizer) if __name__ == "__main__": diff --git a/torchao/quantization/prototype/qat/__init__.py b/torchao/quantization/prototype/qat/__init__.py index ed9701de5..ccb7aac0e 100644 --- a/torchao/quantization/prototype/qat/__init__.py +++ b/torchao/quantization/prototype/qat/__init__.py @@ -3,6 +3,8 @@ disable_8da4w_fake_quant, enable_4w_fake_quant, enable_8da4w_fake_quant, + int4_weight_only_fake_quantize, + int8_dynamic_activation_int4_weight_fake_quantize, Int4WeightOnlyQATQuantizer, Int8DynActInt4WeightQATQuantizer, Int8DynActInt4WeightQATLinear, @@ -13,6 +15,8 @@ "disable_8da4w_fake_quant", "enable_4w_fake_quant", "enable_8da4w_fake_quant", + "int4_weight_only_fake_quantize", + "int8_dynamic_activation_int4_weight_fake_quantize", "Int4WeightOnlyQATQuantizer", "Int8DynActInt4WeightQATQuantizer", "Int8DynActInt4WeightQATLinear", diff --git a/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py b/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py new file mode 100644 index 000000000..9f629e7a5 --- /dev/null +++ b/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py @@ -0,0 +1,295 @@ +import torch +import torch.utils._pytree as pytree +from typing import Callable, Optional, Tuple +from torchao.quantization.quant_primitives import ( + _get_and_check_qmin_qmax, + choose_qparams_affine, + fake_quantize_affine, + ZeroPointDomain, + MappingType, +) +from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.dtypes.utils import ( + _implements, + _dispatch__torch_function__, + _dispatch__torch_dispatch__, +) +from .utils import ( + _GenericFakeQuantize, + _UnwrapAffineFakeQuantizedTensor, +) + +aten = torch.ops.aten + +class AffineFakeQuantizedTensor(torch.Tensor): + """ + Affine fake quantized tensor subclass. Affine quantization means we quantize the floating point tensor + with an affine transformation: + quantized_tensor = float_tensor / scale + zero_point + + Fake quantization refers to performing the quantization math without actually casting the floating point + tensor into lower bit-width dtypes. It is commonly used for quantization-aware training (QAT). 
+ + The shape and dtype of the tensor subclass represent how the tensor subclass looks externally, + regardless of the internal representation's type or orientation. + + fields: + original_tensor (torch.Tensor): tensor holding the original float values, needed for actual quantization later + apply_fake_quant_fn (Callable): function that transforms `original_tensor` to fake quantized values + """ + + @staticmethod + def __new__( + cls, + original_tensor: torch.Tensor, + apply_fake_quant_fn: Callable, + fake_quant_enabled: bool = True, + **kwargs, + ): + kwargs.setdefault("dtype", original_tensor.dtype) + kwargs.setdefault("device", original_tensor.device) + kwargs.setdefault("requires_grad", original_tensor.requires_grad) + return torch.Tensor._make_wrapper_subclass( + cls, + original_tensor.shape, + **kwargs, + ) + + def __init__( + self, + original_tensor: torch.Tensor, + apply_fake_quant_fn: Callable, + fake_quant_enabled: bool = True, + **kwargs + ): + self.original_tensor = original_tensor + self.apply_fake_quant_fn = apply_fake_quant_fn + self.fake_quant_enabled = fake_quant_enabled + + def __tensor_flatten__(self): + return ["original_tensor"], [self.apply_fake_quant_fn, self.fake_quant_enabled] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride, + ): + original_tensor = tensor_data_dict["original_tensor"] + (apply_fake_quant_fn, fake_quant_enabled) = tensor_attributes + return cls( + original_tensor, + apply_fake_quant_fn, + fake_quant_enabled, + ) + + @classmethod + def from_float( + cls, + input_float: torch.Tensor, + mapping_type: MappingType, + block_size: Tuple[int, ...], + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + ): + def apply_fake_quant_fn(t: torch.Tensor): + assert isinstance(t, cls) + qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) + scale, zero_point = choose_qparams_affine( + t.original_tensor, + mapping_type, + block_size, + target_dtype, + qmin, + qmax, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) + fq = _GenericFakeQuantize.apply( + t, + block_size, + scale, + zero_point, + qmin, + qmax, + zero_point_domain, + ) + return fq + return cls( + input_float, + apply_fake_quant_fn, + fake_quant_enabled=True, + ) + + def get_value(self) -> torch.Tensor: + if self.fake_quant_enabled: + return self.apply_fake_quant_fn(self) + else: + return _UnwrapAffineFakeQuantizedTensor.apply(self) + + def _get_to_kwargs(self, *args, **kwargs): + device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) + device = self.device if device is None else device + dtype = self.dtype if dtype is None else dtype + memory_format = ( + memory_format if memory_format is not None else torch.preserve_format + ) + kwargs = { + "device": device, + "dtype": dtype, + "memory_format": memory_format, + "requires_grad": self.requires_grad, + } + return kwargs + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs.pop("device") + # not supported yet + kwargs.pop("memory_format") + return self.__class__( + self.original_tensor.to(device), + self.apply_fake_quant_fn, + self.fake_quant_enabled, + **kwargs, + ) + + def _apply_fn_to_data(self, fn: Callable): + 
""" + Create a new `AffineFakeQuantizedTensor` with `fn` applied to the + original tensor, to be called within __torch_dispatch__. + """ + return self._create_new(fn(self.original_tensor)) + + def _create_new(self, new_value: torch.Tensor): + """ + Create a new `AffineFakeQuantizedTensor` with a new value, + to be called within __torch_dispatch__. + + Note: `requires_grad` must be False here because tensors created + in `__torch_dispatch__` cannot produce gradients, since autograd + will try to attach autograd metadata to these tensors when we exit + `__torch_dispatch__`, but if these tensors already have metadata + attached then autograd will throw an error. + """ + return self.__class__( + new_value, + self.apply_fake_quant_fn, + self.fake_quant_enabled, + requires_grad=False, + ) + + implements = classmethod(_implements) + __torch_function__ = classmethod(_dispatch__torch_function__) + __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) + +implements = AffineFakeQuantizedTensor.implements + + +@implements(torch.nn.functional.linear) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + if isinstance(input_tensor, AffineFakeQuantizedTensor): + input_tensor = input_tensor.get_value() + if isinstance(weight_tensor, AffineFakeQuantizedTensor): + weight_tensor = weight_tensor.get_value() + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + + +@implements(aten.mm.default) +def _(func, types, args, kwargs): + bias = None + input_tensor = args[0] + weight_tensor = args[1] + if isinstance(input_tensor, AffineFakeQuantizedTensor): + input_tensor = input_tensor.get_value() + if isinstance(weight_tensor, AffineFakeQuantizedTensor): + weight_tensor = weight_tensor.get_value() + return func(input_tensor, weight_tensor) + + +@implements(aten.addmm.default) +def _(func, types, args, kwargs): + bias = args[0] + input_tensor = args[1] + weight_tensor = args[2] + if isinstance(input_tensor, AffineFakeQuantizedTensor): + input_tensor = input_tensor.get_value() + if isinstance(weight_tensor, AffineFakeQuantizedTensor): + weight_tensor = weight_tensor.get_value() + return func(bias, input_tensor, weight_tensor) + + +@implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach), + ) + + +@implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone), + ) + + +@implements(aten.t.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.t), + ) + + +@implements([ + aten.add.Tensor, + aten.add_.Tensor, + aten.mul_.Tensor, + aten.copy_.default, +]) +def _(func, types, args, kwargs): + assert len(args) == 2, f"dispatched the wrong op to the binary handler: {func}" + new_args = pytree.tree_map_only(AffineFakeQuantizedTensor, lambda x: x.original_tensor, args) + first_afq_tensor = args[0] if isinstance(args[0], AffineFakeQuantizedTensor) else args[1] + new_value = func(*new_args, **kwargs) + out = first_afq_tensor._create_new(new_value) + return return_and_correct_aliasing(func, args, kwargs, out) + + +# Needed by FSDP: + +@implements(aten.empty_like.default) +def _(func, types, args, kwargs): + out = torch.empty_like(args[0].original_tensor, **kwargs) + return return_and_correct_aliasing(func, args, kwargs, out) + + 
+@implements(aten.split.Tensor) +def _(func, types, args, kwargs): + new_values = torch.split(args[0].original_tensor, *args[1:], **kwargs) + + def make_new_tensor(value): + out = args[0]._create_new(value) + return return_and_correct_aliasing(func, args, kwargs, out) + + return list(map(make_new_tensor, new_values)) + + +@implements(aten.new_zeros.default) +def _(func, types, args, kwargs): + out = args[0].original_tensor.new_zeros(*args[1:], **kwargs) + return return_and_correct_aliasing(func, args, kwargs, out) + + +to_affine_fake_quantized = AffineFakeQuantizedTensor.from_float diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/prototype/qat/api.py index 668b73787..6cd9a704b 100644 --- a/torchao/quantization/prototype/qat/api.py +++ b/torchao/quantization/prototype/qat/api.py @@ -9,6 +9,9 @@ import torch import torch.nn.functional as F +from torchao.dtypes import ( + TensorCoreTiledLayoutType, +) from torchao.quantization.GPTQ import ( _check_linear_int4_k, _replace_linear_int4, @@ -18,13 +21,34 @@ Int8DynActInt4WeightLinear, WeightOnlyInt4Linear, ) -from torchao.quantization.quant_primitives import ZeroPointDomain +from torchao.quantization.linear_activation_quantized_tensor import ( + to_linear_activation_quantized, +) +from torchao.quantization.quant_api import ( + _get_linear_subclass_inserter, + _replace_with_custom_fn_if_matches_filter, + int4_weight_only, + int8_dynamic_activation_int4_weight, + quantize_, +) +from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, +) from torchao.quantization.unified import TwoStepQuantizer -from torchao.quantization.utils import get_group_qparams_symmetric +from torchao.quantization.utils import ( + _get_per_token_block_size, + get_group_qparams_symmetric, +) +from .affine_fake_quantized_tensor import to_affine_fake_quantized from .utils import ( _choose_qparams_per_token_asymmetric, + _enable_fake_quant, _fake_quantize_per_channel_group, _fake_quantize_per_token, + _get_qat_linear_subclass_inserter, + _is_linear_with_fq_weight, + _unwrap_affine_fake_quantized_tensor, ) @@ -32,6 +56,51 @@ # | 8da4w QAT | # ================= +def int8_dynamic_activation_int4_weight_fake_quantize(group_size=32): + """ + Applies int8 dynamic per token asymmetric activation fake quantization and + int4 per group weight symmetric fake quantization to linear. Please see + :func:`~torchao.quantization.int8_dynamic_activation_int4_weight` for more details. 
+ + Example usage: + from torchao.quantization import quantize_ + quantize_(model, int8_dynamic_activation_int4_weight_fake_quantize(group_size=32)) + """ + # avoid circular dep + from torchao.dtypes import to_affine_quantized + + def _apply_weight_fake_quant(weight: torch.Tensor): + mapping_type = MappingType.SYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + quant_min = -8 + quant_max = 7 + return to_affine_fake_quantized( + weight, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + ) + + def _apply_input_activation_fake_quant(x: torch.Tensor): + mapping_type = MappingType.ASYMMETRIC + target_dtype = torch.int8 + return to_affine_fake_quantized( + x, + mapping_type, + _get_per_token_block_size(x), + target_dtype, + ) + + return _get_qat_linear_subclass_inserter( + _apply_weight_fake_quant, + _apply_input_activation_fake_quant, + ) + class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer): """ Quantizer for performing QAT on a model, where linear layers have int8 @@ -58,14 +127,9 @@ def prepare( *args: Any, **kwargs: Any ) -> torch.nn.Module: - _replace_linear_8da4w( + quantize_( model, - self.groupsize, - self.padding_allowed, - self.precision, - self.scales_precision, - Int8DynActInt4WeightQATLinear, - copy_weights=True, + int8_dynamic_activation_int4_weight_fake_quantize(group_size=self.groupsize), ) return model @@ -75,39 +139,14 @@ def convert( *args: Any, **kwargs: Any ) -> torch.nn.Module: - _convert_qat_linear_8da4w(model) + unwrap_fn = _get_linear_subclass_inserter(_unwrap_affine_fake_quantized_tensor) + filter_fn = _is_linear_with_fq_weight + model = _replace_with_custom_fn_if_matches_filter(model, unwrap_fn, filter_fn) + quantize_fn = int8_dynamic_activation_int4_weight(self.groupsize) + quantize_(model, quantize_fn) return model -def _convert_qat_linear_8da4w(module: torch.nn.Module): - """ - Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`. - """ - for name, child in module.named_children(): - if isinstance(child, Int8DynActInt4WeightQATLinear): - quantized_linear = Int8DynActInt4WeightLinear( - child.in_features, - child.out_features, - bias=False, - groupsize=child.groupsize, - precision=child.precision, - scales_precision=child.scales_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (qmin, qmax) = child._get_qmin_qmax(n_bit) - (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize) - from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper - q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( - child.weight, s, zp, qmin, qmax, torch.int8, child.groupsize, - ) - quantized_linear.weight = q_weight - quantized_linear.scales = s - quantized_linear.zeros = zp - else: - _convert_qat_linear_8da4w(child) - +# TODO: deprecate class Int8DynActInt4WeightQATLinear(torch.nn.Linear): """ This module implements a linear layer with int8 dynamic per token fake @@ -194,23 +233,54 @@ def _get_qmin_qmax(self, n_bit: int): def enable_8da4w_fake_quant(mod: torch.nn.Module): """ - Enable fake quantization for `Int8DynActInt4WeightQATLinear`. + Enable fake quantization for int8 dynamic activations + int4 weight. 
""" - if isinstance(mod, Int8DynActInt4WeightQATLinear): - mod.enable_fake_quant() + _enable_fake_quant(mod, enable=True) def disable_8da4w_fake_quant(mod: torch.nn.Module): """ - Disable fake quantization for `Int8DynActInt4WeightQATLinear`. + Disable fake quantization for int8 dynamic activations + int4 weight. """ - if isinstance(mod, Int8DynActInt4WeightQATLinear): - mod.disable_fake_quant() + _enable_fake_quant(mod, enable=False) # ================== # | int4wo QAT | # ================== +def int4_weight_only_fake_quantize(group_size=128): + """ + Applies uint4 weight-only asymmetric per-group fake quantization to linear layers. + Please see :func:`~torchao.quantization.int4_weight_only` for more details. + + Example usage: + from torchao.quantization import quantize_ + quantize_(model, int4_weight_only_fake_quantize(group_size=32)) + """ + def _apply_fake_quant(weight): + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + preserve_zero = False + zero_point_dtype = torch.bfloat16 + zero_point_domain = ZeroPointDomain.FLOAT + return to_affine_fake_quantized( + weight, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + ) + return _get_qat_linear_subclass_inserter(_apply_fake_quant) + class Int4WeightOnlyQATQuantizer(TwoStepQuantizer): """ Quantizer for performing QAT on a model, where linear layers have @@ -238,16 +308,7 @@ def prepare( *args: Any, **kwargs: Any ) -> torch.nn.Module: - _replace_linear_int4( - model, - self.groupsize, - self.inner_k_tiles, - padding_allowed=True, - precision=self.precision, - scales_precision=self.scales_precision, - linear_class=Int4WeightOnlyQATLinear, - copy_weights=True, - ) + quantize_(model, int4_weight_only_fake_quantize(group_size=self.groupsize)) return model def convert( @@ -256,43 +317,15 @@ def convert( *args: Any, **kwargs: Any ) -> torch.nn.Module: - _convert_qat_linear_4w(model) + unwrap_fn = _get_linear_subclass_inserter(_unwrap_affine_fake_quantized_tensor) + filter_fn = _is_linear_with_fq_weight + model = _replace_with_custom_fn_if_matches_filter(model, unwrap_fn, filter_fn) + layout_type = TensorCoreTiledLayoutType(self.inner_k_tiles) + quantize_fn = int4_weight_only(self.groupsize, layout_type) + quantize_(model, quantize_fn) return model -def _convert_qat_linear_4w(module: torch.nn.Module): - """ - Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. 
- """ - for name, child in module.named_children(): - if isinstance(child, Int4WeightOnlyQATLinear): - in_features = child.in_features - out_features = child.out_features - groupsize = child.groupsize - inner_k_tiles = child.inner_k_tiles - quantized_linear = WeightOnlyInt4Linear( - in_features, - out_features, - bias=False, - groupsize=groupsize, - inner_k_tiles=inner_k_tiles, - precision=child.precision, - scales_precision=child.scales_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - child.weight, n_bit, child.groupsize, - ) - q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to(child.weight.device), child.inner_k_tiles, - ) - quantized_linear.weight = q_weight - quantized_linear.scales_and_zeros = scales_and_zeros - else: - _convert_qat_linear_4w(child) - +# TODO: deprecate class Int4WeightOnlyQATLinear(torch.nn.Linear): """ This module implements a linear layer with int4 fake quantized grouped @@ -359,14 +392,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def enable_4w_fake_quant(mod: torch.nn.Module): """ - Enable fake quantization for `Int4WeightOnlyQATLinear`. + Enable fake quantization for int4 weight only. """ - if isinstance(mod, Int4WeightOnlyQATLinear): - mod.enable_fake_quant() + _enable_fake_quant(mod, enable=True) def disable_4w_fake_quant(mod: torch.nn.Module): """ - Disable fake quantization for `Int4WeightOnlyQATLinear`. + Disable fake quantization for int4 weight only. """ - if isinstance(mod, Int4WeightOnlyQATLinear): - mod.disable_fake_quant() + _enable_fake_quant(mod, enable=False) diff --git a/torchao/quantization/prototype/qat/utils.py b/torchao/quantization/prototype/qat/utils.py index 580df1ea4..625da4e39 100644 --- a/torchao/quantization/prototype/qat/utils.py +++ b/torchao/quantization/prototype/qat/utils.py @@ -4,7 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Tuple +from typing import Callable, List, Optional, Tuple import torch @@ -17,6 +17,15 @@ ) +# Attribute name representing the forward prehook wrapping the +# linear input in an `AffineFakeQuantizedTensor` on a linear module. +# +# The value of this attribute is a 2-tuple of (prehook, handle). +# The prehook can be disabled by calling `handle.remove()`, and +# re-enabled by calling `module.register_forward_pre_hook(prehook)`. +_QAT_LINEAR_SUBCLASS_INPUT_PREHOOK = "_qat_linear_subclass_input_prehook" + + class _GenericFakeQuantize(torch.autograd.Function): """ Implementation of generic fake quantize with backward STE. 
@@ -29,15 +38,25 @@ class _GenericFakeQuantize(torch.autograd.Function): def forward( ctx: torch.autograd.function.FunctionCtx, input: torch.Tensor, + block_size: List[int], scales: torch.Tensor, zero_points: torch.Tensor, quant_min: int, quant_max: int, - block_size: List[int], zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> torch.Tensor: + # avoid circular dependencies + from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, + ) + + if isinstance(input, AffineFakeQuantizedTensor): + _input = input.original_tensor + else: + _input = input + (fq, mask) = fake_quantize_affine_cachemask( - input, + _input, block_size, scales, zero_points, @@ -55,6 +74,31 @@ def backward(ctx, gy): (mask,) = ctx.saved_tensors return gy * mask, None, None, None, None, None, None + +class _UnwrapAffineFakeQuantizedTensor(torch.autograd.Function): + """ + Helper autograd function to unwrap `AffineFakeQuantizedTensor` while ensuring + gradients are still passed to the tensor subclass. This is used in place of + `_GenericFakeQuantize` when fake quant is disabled. + """ + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + input: torch.Tensor, + ) -> torch.Tensor: + # avoid circular dependencies + from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, + ) + assert isinstance(input, AffineFakeQuantizedTensor) + return input.original_tensor + + @staticmethod + def backward(ctx, gy): + return gy, + + def _fake_quantize_per_channel_group( input: torch.Tensor, scales: torch.Tensor, @@ -69,7 +113,7 @@ def _fake_quantize_per_channel_group( assert input.dim() == 2 block_size = (1, group_size) return _GenericFakeQuantize.apply( - input, scales, zero_points, quant_min, quant_max, block_size, zero_point_domain, + input, block_size, scales, zero_points, quant_min, quant_max, zero_point_domain, ) def _fake_quantize_per_token( @@ -85,7 +129,7 @@ def _fake_quantize_per_token( block_size = _get_per_token_block_size(input) fq_input = input.to(torch.float32) fq = _GenericFakeQuantize.apply( - fq_input, scales, zero_points, quant_min, quant_max, block_size, + fq_input, block_size, scales, zero_points, quant_min, quant_max, ) return fq.reshape_as(input).to(input.dtype) @@ -136,3 +180,82 @@ def _choose_qparams_per_token_asymmetric( zero_point = torch.clamp(zero_point, qmin, qmax).round() return scale.to(scales_precision), zero_point.to(zero_points_precision) + +def _forward_pre_hook_handler( + mod: torch.nn.Linear, + prehook: Callable, + handler: torch.utils.hooks.RemovableHandle, +): + """ + Store a 2-tuple (prehook function, handler) as an attribute on the given linear module. + """ + setattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, (prehook, handler)) + +def _unwrap_affine_fake_quantized_tensor(t: torch.Tensor): + """ + Return the original, non-fake-quantized float tensor from a `AffineFakeQuantizedTensor`. + """ + # avoid circular dependencies + from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, + ) + assert isinstance(t, AffineFakeQuantizedTensor) + return t.original_tensor + +def _is_linear_with_fq_weight(mod: torch.nn.Module, *args): + """ + Return whether this is a nn.Linear module with `AffineFakeQuantizeTensor` weights. 
+ """ + # avoid circular dependencies + from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, + ) + if not isinstance(mod, torch.nn.Linear) or not hasattr(mod, "weight"): + return False + weight = mod.weight + return isinstance(weight, AffineFakeQuantizedTensor) + +def _enable_fake_quant(mod: torch.nn.Module, enable: bool): + """ + Enable or disable fake quantization in the activations and weights of a `nn.Linear` module. + """ + from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, + ) + if not _is_linear_with_fq_weight(mod): + return + weight = mod.weight + assert isinstance(weight, AffineFakeQuantizedTensor) + weight.fake_quant_enabled = enable + + # Enable/disable input fake quant + if hasattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK): + (prehook, handle) = getattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK) + if enable and handle is None: + handle = mod.register_forward_pre_hook(prehook) + elif not enable and handle is not None: + handle.remove() + handle = None + setattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, (prehook, handle)) + +def _get_qat_linear_subclass_inserter( + weight_constructor: Callable, + input_constructor: Optional[Callable] = None, +) -> Callable: + """ + Return a function that inserts wraps the weight and/or input activation of a + linear module in tensor subclasses. + + Args: + weight_constructor: constructor of the weight subclass, accepts a tensor + input_constructor: (optional) constructor of the input subclass, accepts a tensor + """ + def insert_subclass(lin): + lin.weight = torch.nn.Parameter(weight_constructor(lin.weight), requires_grad=True) + if input_constructor is not None: + prehook = lambda _, args: tuple([input_constructor(args[0])] + list(args[1:])) + handle = lin.register_forward_pre_hook(prehook) + setattr(lin, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, (prehook, handle)) + return lin + + return insert_subclass diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index a0ad665ea..95fa63b39 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -31,6 +31,7 @@ ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_4, + TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass, ) from .subclass import ( @@ -55,7 +56,6 @@ from .utils import _get_per_token_block_size import logging from .autoquant import autoquant, AutoQuantizableLinearWeight -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 __all__ = [ @@ -189,6 +189,11 @@ def _replace_with_custom_fn_if_matches_filter( def _is_linear(mod, *args): + # avoid circular dependencies + from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, + ) + # adding weight tensor subclass isinstance check to make sure the weight is only quantized once # when it is shared by multiple linear modules return ( @@ -198,6 +203,7 @@ def _is_linear(mod, *args): and not isinstance(mod.weight, AutoQuantizableLinearWeight) and not isinstance(mod.weight, AffineQuantizedTensor) and not isinstance(mod.weight, LinearActivationQuantizedTensor) + and not isinstance(mod.weight, AffineFakeQuantizedTensor) ) import torch.nn.utils.parametrize as parametrize