diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 022e1d622..b329577a1 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -12,11 +12,22 @@ import torch from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torchao.dtypes import ( + TensorCoreTiledLayoutType, +) +from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, +) from torchao.quantization.prototype.qat.utils import ( _choose_qparams_per_token_asymmetric, _fake_quantize_per_channel_group, _fake_quantize_per_token, _GenericFakeQuantize, + _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, +) +from torchao.quantization.quant_api import ( + int4_weight_only, + quantize_, ) from torchao.quantization.quant_primitives import ( fake_quantize_affine, @@ -190,6 +201,7 @@ def test_qat_8da4w_linear(self): ptq_out = ptq_linear(x2) torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) + # TODO: compare against quantize_ API instead @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer @@ -217,13 +229,6 @@ def test_qat_8da4w_quantizer(self): converted_out = converted_model(*x) torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0) - # Compare converted state dict - ptq_state_dict = ptq_model.state_dict() - converted_state_dict = converted_model.state_dict() - self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) - for k in ptq_state_dict.keys(): - torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_meta_weights(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer @@ -236,6 +241,20 @@ def test_qat_8da4w_quantizer_meta_weights(self): qat_model = qat_quantizer.prepare(m) self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values())) + def _copy_subclass_weights( + self, + nn_linear: torch.nn.Linear, + subclass_linear: AffineFakeQuantizedTensor, + ): + nn_linear.weight = torch.nn.Parameter(subclass_linear.weight.original_tensor) + + def _assert_matches_subclass_weights( + self, + nn_linear: torch.nn.Linear, + subclass_linear: AffineFakeQuantizedTensor, + ): + torch.testing.assert_close(nn_linear.weight, subclass_linear.weight.original_tensor, atol=0, rtol=0) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_disable_fake_quant(self): """ @@ -247,6 +266,16 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): enable_8da4w_fake_quant, ) + def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool): + self.assertTrue(isinstance(m.weight, AffineFakeQuantizedTensor)) + self.assertEqual(m.weight.fake_quant_enabled, enabled) + self.assertTrue(hasattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK)) + (_, handle) = getattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK) + if enabled: + self.assertIsNotNone(handle) + else: + self.assertIsNone(handle) + group_size = 16 torch.manual_seed(self.SEED) m = M() @@ -255,14 +284,14 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model = quantizer.prepare(m) qat_model.apply(disable_8da4w_fake_quant) - self.assertFalse(qat_model.linear1._fake_quant_enabled) - self.assertFalse(qat_model.linear2._fake_quant_enabled) - self.assertFalse(qat_model.sub.linear._fake_quant_enabled) + assert_fake_quant_enabled(qat_model.linear1, enabled=False) + assert_fake_quant_enabled(qat_model.linear2, enabled=False) + assert_fake_quant_enabled(qat_model.sub.linear, enabled=False) # Disabled fake quant is just a normal linear - m2.linear1.weight = qat_model.linear1.weight - m2.linear2.weight = qat_model.linear2.weight - m2.sub.linear.weight = qat_model.sub.linear.weight + self._copy_subclass_weights(m2.linear1, qat_model.linear1) + self._copy_subclass_weights(m2.linear2, qat_model.linear2) + self._copy_subclass_weights(m2.sub.linear, qat_model.sub.linear) torch.manual_seed(self.SEED) x = m.example_inputs() x2 = copy.deepcopy(x) @@ -272,16 +301,16 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): # Renable fake quant qat_model.apply(enable_8da4w_fake_quant) - self.assertTrue(qat_model.linear1._fake_quant_enabled) - self.assertTrue(qat_model.linear2._fake_quant_enabled) - self.assertTrue(qat_model.sub.linear._fake_quant_enabled) + assert_fake_quant_enabled(qat_model.linear1, enabled=True) + assert_fake_quant_enabled(qat_model.linear2, enabled=True) + assert_fake_quant_enabled(qat_model.sub.linear, enabled=True) # Fake quant should be applied as normal quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model2 = quantizer2.prepare(m3) - qat_model2.linear1.weight = qat_model.linear1.weight - qat_model2.linear2.weight = qat_model.linear2.weight - qat_model2.sub.linear.weight = qat_model.sub.linear.weight + qat_model2.linear1.weight.original_tensor = qat_model.linear1.weight.original_tensor + qat_model2.linear2.weight.original_tensor = qat_model.linear2.weight.original_tensor + qat_model2.sub.linear.weight.original_tensor = qat_model.sub.linear.weight.original_tensor torch.manual_seed(self.SEED) x = m.example_inputs() x2 = copy.deepcopy(x) @@ -306,9 +335,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) qat_model = quantizer.prepare(m) qat_model.apply(disable_8da4w_fake_quant) - nn_model.linear1.weight = qat_model.linear1.weight - nn_model.linear2.weight = qat_model.linear2.weight - nn_model.sub.linear.weight = qat_model.sub.linear.weight + self._copy_subclass_weights(nn_model.linear1, qat_model.linear1) + self._copy_subclass_weights(nn_model.linear2, qat_model.linear2) + self._copy_subclass_weights(nn_model.sub.linear, qat_model.sub.linear) # Simulate training for both models optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) @@ -330,9 +359,55 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): optimizer2.step() # After 1 training step, weights should match exactly - torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0) - torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0) - torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0) + self._assert_matches_subclass_weights(nn_model.linear1, qat_model.linear1) + self._assert_matches_subclass_weights(nn_model.linear2, qat_model.linear2) + self._assert_matches_subclass_weights(nn_model.sub.linear, qat_model.sub.linear) + + def _test_qat_quantized_gradients(self, quantizer): + """ + Test that QAT produces gradients in the backward pass. + """ + num_steps = 10 + torch.manual_seed(self.SEED) + m = M() + model = quantizer.prepare(m) + optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) + loss_fn = torch.nn.CrossEntropyLoss() + + # Simulate training + current_step = 0 + last_linear1_grad = None + last_linear2_grad = None + last_sub_linear_grad = None + while current_step < num_steps: + example_inputs = model.example_inputs() + target = torch.randn(1, 512).float() + output = model(*example_inputs) + loss = loss_fn(output, target) + loss.backward() + # assert each linear grad is updated + new_linear1_grad = model.linear1.weight.grad + new_linear2_grad = model.linear2.weight.grad + new_sub_linear_grad = model.sub.linear.weight.grad + self.assertIsNotNone(new_linear1_grad) + self.assertIsNotNone(new_linear2_grad) + self.assertIsNotNone(new_sub_linear_grad) + if current_step > 0: + self.assertFalse(torch.equal(last_linear1_grad, new_linear1_grad)) + self.assertFalse(torch.equal(last_linear2_grad, new_linear2_grad)) + self.assertFalse(torch.equal(last_sub_linear_grad, new_sub_linear_grad)) + last_linear1_grad = new_linear1_grad + last_linear2_grad = new_linear2_grad + last_sub_linear_grad = new_sub_linear_grad + optimizer.zero_grad() + optimizer.step() + current_step += 1 + + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + def test_qat_8da4w_quantizer_gradients(self): + from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer + quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16) + self._test_qat_quantized_gradients(quantizer) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") def test_qat_generic_fake_quantize(self): @@ -353,7 +428,7 @@ def test_qat_generic_fake_quantize(self): block_size = (1, ao_input.shape[-1]) ao_s = copy.deepcopy(py_s) ao_zp = copy.deepcopy(py_zp) - ao_out = _GenericFakeQuantize.apply(ao_input, ao_s, ao_zp, qmin, qmax, block_size) + ao_out = _GenericFakeQuantize.apply(ao_input, block_size, ao_s, ao_zp, qmin, qmax) ao_out.sum().backward() torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0) @@ -373,10 +448,7 @@ def _assert_close_4w(self, val, ref): print(mean_err) self.assertTrue(mean_err < 0.05) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") - # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517 - @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now") def test_qat_4w_primitives(self): n_bit = 4 group_size = 32 @@ -464,11 +536,9 @@ def test_qat_4w_quantizer(self): qat_quantizer = Int4WeightOnlyQATQuantizer( groupsize=group_size, inner_k_tiles=inner_k_tiles, ) - ptq_quantizer = Int4WeightOnlyQuantizer( - groupsize=group_size, inner_k_tiles=inner_k_tiles, - ) qat_model = qat_quantizer.prepare(m) - ptq_model = ptq_quantizer.quantize(m2) + ptq_model = m2 + quantize_(ptq_model, int4_weight_only(group_size, TensorCoreTiledLayoutType(inner_k_tiles))) # Compare model values torch.manual_seed(self.SEED) @@ -483,12 +553,11 @@ def test_qat_4w_quantizer(self): converted_out = converted_model(*x) torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0) - # Compare converted state dict - ptq_state_dict = ptq_model.state_dict() - converted_state_dict = converted_model.state_dict() - self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) - for k in ptq_state_dict.keys(): - torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + def test_qat_4w_quantizer_gradients(self): + from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer + quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8) + self._test_qat_quantized_gradients(quantizer) if __name__ == "__main__": diff --git a/torchao/quantization/prototype/qat/__init__.py b/torchao/quantization/prototype/qat/__init__.py index ed9701de5..ccb7aac0e 100644 --- a/torchao/quantization/prototype/qat/__init__.py +++ b/torchao/quantization/prototype/qat/__init__.py @@ -3,6 +3,8 @@ disable_8da4w_fake_quant, enable_4w_fake_quant, enable_8da4w_fake_quant, + int4_weight_only_fake_quantize, + int8_dynamic_activation_int4_weight_fake_quantize, Int4WeightOnlyQATQuantizer, Int8DynActInt4WeightQATQuantizer, Int8DynActInt4WeightQATLinear, @@ -13,6 +15,8 @@ "disable_8da4w_fake_quant", "enable_4w_fake_quant", "enable_8da4w_fake_quant", + "int4_weight_only_fake_quantize", + "int8_dynamic_activation_int4_weight_fake_quantize", "Int4WeightOnlyQATQuantizer", "Int8DynActInt4WeightQATQuantizer", "Int8DynActInt4WeightQATLinear", diff --git a/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py b/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py new file mode 100644 index 000000000..c5f820477 --- /dev/null +++ b/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py @@ -0,0 +1,336 @@ +import torch +import torch.utils._pytree as pytree +from typing import Callable, Optional, Tuple +from torchao.quantization.quant_primitives import ( + _get_and_check_qmin_qmax, + choose_qparams_affine, + fake_quantize_affine, + ZeroPointDomain, + MappingType, +) +from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.dtypes.utils import ( + _implements, + _dispatch__torch_function__, + _dispatch__torch_dispatch__, +) +from .utils import ( + _GenericFakeQuantize, + _UnwrapAffineFakeQuantizedTensor, +) + +aten = torch.ops.aten + + +class _ToAffineFakeQuantized(torch.autograd.Function): + """ + Differentiable constructor for `AffineFakeQuantizedTensor`, + needed for input activation fake quantization. + """ + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + original_tensor: torch.Tensor, + mapping_type: MappingType, + block_size: Tuple[int, ...], + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + ) -> "AffineFakeQuantizedTensor": + def apply_fake_quant_fn(t: torch.Tensor): + assert isinstance(t, AffineFakeQuantizedTensor) + qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) + scale, zero_point = choose_qparams_affine( + t.original_tensor, + mapping_type, + block_size, + target_dtype, + qmin, + qmax, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) + fq = _GenericFakeQuantize.apply( + t, + block_size, + scale, + zero_point, + qmin, + qmax, + zero_point_domain, + ) + return fq + return AffineFakeQuantizedTensor( + original_tensor, + apply_fake_quant_fn, + fake_quant_enabled=True, + ) + + @staticmethod + def backward(ctx, gy): + return gy, None, None, None, None, None, None, None, None, None, None + + +class AffineFakeQuantizedTensor(torch.Tensor): + """ + Affine fake quantized tensor subclass. Affine quantization means we quantize the floating point tensor + with an affine transformation: + quantized_tensor = float_tensor / scale + zero_point + + Fake quantization refers to performing the quantization math without actually casting the floating point + tensor into lower bit-width dtypes. It is commonly used for quantization-aware training (QAT). + + The shape and dtype of the tensor subclass represent how the tensor subclass looks externally, + regardless of the internal representation's type or orientation. + + fields: + original_tensor (torch.Tensor): tensor holding the original float values, needed for actual quantization later + apply_fake_quant_fn (Callable): function that transforms `original_tensor` to fake quantized values + """ + + @staticmethod + def __new__( + cls, + original_tensor: torch.Tensor, + apply_fake_quant_fn: Callable, + fake_quant_enabled: bool = True, + **kwargs, + ): + kwargs.setdefault("dtype", original_tensor.dtype) + kwargs.setdefault("device", original_tensor.device) + kwargs.setdefault("requires_grad", original_tensor.requires_grad) + return torch.Tensor._make_wrapper_subclass( + cls, + original_tensor.shape, + **kwargs, + ) + + def __init__( + self, + original_tensor: torch.Tensor, + apply_fake_quant_fn: Callable, + fake_quant_enabled: bool = True, + **kwargs + ): + self.original_tensor = original_tensor + self.apply_fake_quant_fn = apply_fake_quant_fn + self.fake_quant_enabled = fake_quant_enabled + + def __tensor_flatten__(self): + return ["original_tensor"], [self.apply_fake_quant_fn, self.fake_quant_enabled] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride, + ): + original_tensor = tensor_data_dict["original_tensor"] + (apply_fake_quant_fn, fake_quant_enabled) = tensor_attributes + return cls( + original_tensor, + apply_fake_quant_fn, + fake_quant_enabled, + ) + + @classmethod + def from_float( + cls, + original_input: torch.Tensor, + mapping_type: MappingType, + block_size: Tuple[int, ...], + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + ): + return _ToAffineFakeQuantized.apply( + original_input, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) + + def get_value(self) -> torch.Tensor: + if self.fake_quant_enabled: + return self.apply_fake_quant_fn(self) + else: + return _UnwrapAffineFakeQuantizedTensor.apply(self) + + def _get_to_kwargs(self, *args, **kwargs): + device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) + device = self.device if device is None else device + dtype = self.dtype if dtype is None else dtype + memory_format = ( + memory_format if memory_format is not None else torch.preserve_format + ) + kwargs = { + "device": device, + "dtype": dtype, + "memory_format": memory_format, + "requires_grad": self.requires_grad, + } + return kwargs + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs.pop("device") + # not supported yet + kwargs.pop("memory_format") + return self.__class__( + self.original_tensor.to(device), + self.apply_fake_quant_fn, + self.fake_quant_enabled, + **kwargs, + ) + + def _apply_fn_to_data(self, fn: Callable): + """ + Create a new `AffineFakeQuantizedTensor` with `fn` applied to the + original tensor, to be called within __torch_dispatch__. + """ + return self._create_new(fn(self.original_tensor)) + + def _create_new(self, new_value: torch.Tensor): + """ + Create a new `AffineFakeQuantizedTensor` with a new value, + to be called within __torch_dispatch__. + + Note: `requires_grad` must be False here because tensors created + in `__torch_dispatch__` cannot produce gradients, since autograd + will try to attach autograd metadata to these tensors when we exit + `__torch_dispatch__`, but if these tensors already have metadata + attached then autograd will throw an error. + """ + return self.__class__( + new_value, + self.apply_fake_quant_fn, + self.fake_quant_enabled, + requires_grad=False, + ) + + implements = classmethod(_implements) + __torch_function__ = classmethod(_dispatch__torch_function__) + __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) + +implements = AffineFakeQuantizedTensor.implements + + +@implements(torch.nn.functional.linear) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + if isinstance(input_tensor, AffineFakeQuantizedTensor): + input_tensor = input_tensor.get_value() + if isinstance(weight_tensor, AffineFakeQuantizedTensor): + weight_tensor = weight_tensor.get_value() + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + + +@implements(aten.mm.default) +def _(func, types, args, kwargs): + bias = None + input_tensor = args[0] + weight_tensor = args[1] + if isinstance(input_tensor, AffineFakeQuantizedTensor): + input_tensor = input_tensor.get_value() + if isinstance(weight_tensor, AffineFakeQuantizedTensor): + weight_tensor = weight_tensor.get_value() + return func(input_tensor, weight_tensor) + + +@implements(aten.addmm.default) +def _(func, types, args, kwargs): + bias = args[0] + input_tensor = args[1] + weight_tensor = args[2] + if isinstance(input_tensor, AffineFakeQuantizedTensor): + input_tensor = input_tensor.get_value() + if isinstance(weight_tensor, AffineFakeQuantizedTensor): + weight_tensor = weight_tensor.get_value() + return func(bias, input_tensor, weight_tensor) + + +@implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach), + ) + + +@implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone), + ) + + +@implements(aten.t.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.t), + ) + + +@implements([ + aten.add.Tensor, + aten.add_.Tensor, + aten.mul_.Tensor, + aten.copy_.default, +]) +def _(func, types, args, kwargs): + assert len(args) == 2, f"dispatched the wrong op to the binary handler: {func}" + new_args = pytree.tree_map_only(AffineFakeQuantizedTensor, lambda x: x.original_tensor, args) + first_afq_tensor = args[0] if isinstance(args[0], AffineFakeQuantizedTensor) else args[1] + new_value = func(*new_args, **kwargs) + out = first_afq_tensor._create_new(new_value) + return return_and_correct_aliasing(func, args, kwargs, out) + + +# Needed by FSDP: + +@implements(aten.empty_like.default) +def _(func, types, args, kwargs): + out = torch.empty_like(args[0].original_tensor, **kwargs) + return return_and_correct_aliasing(func, args, kwargs, out) + + +@implements(aten.split.Tensor) +def _(func, types, args, kwargs): + new_values = torch.split(args[0].original_tensor, *args[1:], **kwargs) + + def make_new_tensor(value): + out = args[0]._create_new(value) + return return_and_correct_aliasing(func, args, kwargs, out) + + return list(map(make_new_tensor, new_values)) + + +@implements(aten.new_zeros.default) +def _(func, types, args, kwargs): + out = args[0].original_tensor.new_zeros(*args[1:], **kwargs) + return return_and_correct_aliasing(func, args, kwargs, out) + + +to_affine_fake_quantized = AffineFakeQuantizedTensor.from_float diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/prototype/qat/api.py index 668b73787..6cd9a704b 100644 --- a/torchao/quantization/prototype/qat/api.py +++ b/torchao/quantization/prototype/qat/api.py @@ -9,6 +9,9 @@ import torch import torch.nn.functional as F +from torchao.dtypes import ( + TensorCoreTiledLayoutType, +) from torchao.quantization.GPTQ import ( _check_linear_int4_k, _replace_linear_int4, @@ -18,13 +21,34 @@ Int8DynActInt4WeightLinear, WeightOnlyInt4Linear, ) -from torchao.quantization.quant_primitives import ZeroPointDomain +from torchao.quantization.linear_activation_quantized_tensor import ( + to_linear_activation_quantized, +) +from torchao.quantization.quant_api import ( + _get_linear_subclass_inserter, + _replace_with_custom_fn_if_matches_filter, + int4_weight_only, + int8_dynamic_activation_int4_weight, + quantize_, +) +from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, +) from torchao.quantization.unified import TwoStepQuantizer -from torchao.quantization.utils import get_group_qparams_symmetric +from torchao.quantization.utils import ( + _get_per_token_block_size, + get_group_qparams_symmetric, +) +from .affine_fake_quantized_tensor import to_affine_fake_quantized from .utils import ( _choose_qparams_per_token_asymmetric, + _enable_fake_quant, _fake_quantize_per_channel_group, _fake_quantize_per_token, + _get_qat_linear_subclass_inserter, + _is_linear_with_fq_weight, + _unwrap_affine_fake_quantized_tensor, ) @@ -32,6 +56,51 @@ # | 8da4w QAT | # ================= +def int8_dynamic_activation_int4_weight_fake_quantize(group_size=32): + """ + Applies int8 dynamic per token asymmetric activation fake quantization and + int4 per group weight symmetric fake quantization to linear. Please see + :func:`~torchao.quantization.int8_dynamic_activation_int4_weight` for more details. + + Example usage: + from torchao.quantization import quantize_ + quantize_(model, int8_dynamic_activation_int4_weight_fake_quantize(group_size=32)) + """ + # avoid circular dep + from torchao.dtypes import to_affine_quantized + + def _apply_weight_fake_quant(weight: torch.Tensor): + mapping_type = MappingType.SYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + quant_min = -8 + quant_max = 7 + return to_affine_fake_quantized( + weight, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + ) + + def _apply_input_activation_fake_quant(x: torch.Tensor): + mapping_type = MappingType.ASYMMETRIC + target_dtype = torch.int8 + return to_affine_fake_quantized( + x, + mapping_type, + _get_per_token_block_size(x), + target_dtype, + ) + + return _get_qat_linear_subclass_inserter( + _apply_weight_fake_quant, + _apply_input_activation_fake_quant, + ) + class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer): """ Quantizer for performing QAT on a model, where linear layers have int8 @@ -58,14 +127,9 @@ def prepare( *args: Any, **kwargs: Any ) -> torch.nn.Module: - _replace_linear_8da4w( + quantize_( model, - self.groupsize, - self.padding_allowed, - self.precision, - self.scales_precision, - Int8DynActInt4WeightQATLinear, - copy_weights=True, + int8_dynamic_activation_int4_weight_fake_quantize(group_size=self.groupsize), ) return model @@ -75,39 +139,14 @@ def convert( *args: Any, **kwargs: Any ) -> torch.nn.Module: - _convert_qat_linear_8da4w(model) + unwrap_fn = _get_linear_subclass_inserter(_unwrap_affine_fake_quantized_tensor) + filter_fn = _is_linear_with_fq_weight + model = _replace_with_custom_fn_if_matches_filter(model, unwrap_fn, filter_fn) + quantize_fn = int8_dynamic_activation_int4_weight(self.groupsize) + quantize_(model, quantize_fn) return model -def _convert_qat_linear_8da4w(module: torch.nn.Module): - """ - Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`. - """ - for name, child in module.named_children(): - if isinstance(child, Int8DynActInt4WeightQATLinear): - quantized_linear = Int8DynActInt4WeightLinear( - child.in_features, - child.out_features, - bias=False, - groupsize=child.groupsize, - precision=child.precision, - scales_precision=child.scales_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (qmin, qmax) = child._get_qmin_qmax(n_bit) - (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize) - from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper - q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( - child.weight, s, zp, qmin, qmax, torch.int8, child.groupsize, - ) - quantized_linear.weight = q_weight - quantized_linear.scales = s - quantized_linear.zeros = zp - else: - _convert_qat_linear_8da4w(child) - +# TODO: deprecate class Int8DynActInt4WeightQATLinear(torch.nn.Linear): """ This module implements a linear layer with int8 dynamic per token fake @@ -194,23 +233,54 @@ def _get_qmin_qmax(self, n_bit: int): def enable_8da4w_fake_quant(mod: torch.nn.Module): """ - Enable fake quantization for `Int8DynActInt4WeightQATLinear`. + Enable fake quantization for int8 dynamic activations + int4 weight. """ - if isinstance(mod, Int8DynActInt4WeightQATLinear): - mod.enable_fake_quant() + _enable_fake_quant(mod, enable=True) def disable_8da4w_fake_quant(mod: torch.nn.Module): """ - Disable fake quantization for `Int8DynActInt4WeightQATLinear`. + Disable fake quantization for int8 dynamic activations + int4 weight. """ - if isinstance(mod, Int8DynActInt4WeightQATLinear): - mod.disable_fake_quant() + _enable_fake_quant(mod, enable=False) # ================== # | int4wo QAT | # ================== +def int4_weight_only_fake_quantize(group_size=128): + """ + Applies uint4 weight-only asymmetric per-group fake quantization to linear layers. + Please see :func:`~torchao.quantization.int4_weight_only` for more details. + + Example usage: + from torchao.quantization import quantize_ + quantize_(model, int4_weight_only_fake_quantize(group_size=32)) + """ + def _apply_fake_quant(weight): + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + preserve_zero = False + zero_point_dtype = torch.bfloat16 + zero_point_domain = ZeroPointDomain.FLOAT + return to_affine_fake_quantized( + weight, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + ) + return _get_qat_linear_subclass_inserter(_apply_fake_quant) + class Int4WeightOnlyQATQuantizer(TwoStepQuantizer): """ Quantizer for performing QAT on a model, where linear layers have @@ -238,16 +308,7 @@ def prepare( *args: Any, **kwargs: Any ) -> torch.nn.Module: - _replace_linear_int4( - model, - self.groupsize, - self.inner_k_tiles, - padding_allowed=True, - precision=self.precision, - scales_precision=self.scales_precision, - linear_class=Int4WeightOnlyQATLinear, - copy_weights=True, - ) + quantize_(model, int4_weight_only_fake_quantize(group_size=self.groupsize)) return model def convert( @@ -256,43 +317,15 @@ def convert( *args: Any, **kwargs: Any ) -> torch.nn.Module: - _convert_qat_linear_4w(model) + unwrap_fn = _get_linear_subclass_inserter(_unwrap_affine_fake_quantized_tensor) + filter_fn = _is_linear_with_fq_weight + model = _replace_with_custom_fn_if_matches_filter(model, unwrap_fn, filter_fn) + layout_type = TensorCoreTiledLayoutType(self.inner_k_tiles) + quantize_fn = int4_weight_only(self.groupsize, layout_type) + quantize_(model, quantize_fn) return model -def _convert_qat_linear_4w(module: torch.nn.Module): - """ - Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. - """ - for name, child in module.named_children(): - if isinstance(child, Int4WeightOnlyQATLinear): - in_features = child.in_features - out_features = child.out_features - groupsize = child.groupsize - inner_k_tiles = child.inner_k_tiles - quantized_linear = WeightOnlyInt4Linear( - in_features, - out_features, - bias=False, - groupsize=groupsize, - inner_k_tiles=inner_k_tiles, - precision=child.precision, - scales_precision=child.scales_precision, - ) - setattr(module, name, quantized_linear) - - # Load weights and qparams into quantized linear - n_bit = 4 - (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - child.weight, n_bit, child.groupsize, - ) - q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to(child.weight.device), child.inner_k_tiles, - ) - quantized_linear.weight = q_weight - quantized_linear.scales_and_zeros = scales_and_zeros - else: - _convert_qat_linear_4w(child) - +# TODO: deprecate class Int4WeightOnlyQATLinear(torch.nn.Linear): """ This module implements a linear layer with int4 fake quantized grouped @@ -359,14 +392,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def enable_4w_fake_quant(mod: torch.nn.Module): """ - Enable fake quantization for `Int4WeightOnlyQATLinear`. + Enable fake quantization for int4 weight only. """ - if isinstance(mod, Int4WeightOnlyQATLinear): - mod.enable_fake_quant() + _enable_fake_quant(mod, enable=True) def disable_4w_fake_quant(mod: torch.nn.Module): """ - Disable fake quantization for `Int4WeightOnlyQATLinear`. + Disable fake quantization for int4 weight only. """ - if isinstance(mod, Int4WeightOnlyQATLinear): - mod.disable_fake_quant() + _enable_fake_quant(mod, enable=False) diff --git a/torchao/quantization/prototype/qat/utils.py b/torchao/quantization/prototype/qat/utils.py index 580df1ea4..625da4e39 100644 --- a/torchao/quantization/prototype/qat/utils.py +++ b/torchao/quantization/prototype/qat/utils.py @@ -4,7 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Tuple +from typing import Callable, List, Optional, Tuple import torch @@ -17,6 +17,15 @@ ) +# Attribute name representing the forward prehook wrapping the +# linear input in an `AffineFakeQuantizedTensor` on a linear module. +# +# The value of this attribute is a 2-tuple of (prehook, handle). +# The prehook can be disabled by calling `handle.remove()`, and +# re-enabled by calling `module.register_forward_pre_hook(prehook)`. +_QAT_LINEAR_SUBCLASS_INPUT_PREHOOK = "_qat_linear_subclass_input_prehook" + + class _GenericFakeQuantize(torch.autograd.Function): """ Implementation of generic fake quantize with backward STE. @@ -29,15 +38,25 @@ class _GenericFakeQuantize(torch.autograd.Function): def forward( ctx: torch.autograd.function.FunctionCtx, input: torch.Tensor, + block_size: List[int], scales: torch.Tensor, zero_points: torch.Tensor, quant_min: int, quant_max: int, - block_size: List[int], zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> torch.Tensor: + # avoid circular dependencies + from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, + ) + + if isinstance(input, AffineFakeQuantizedTensor): + _input = input.original_tensor + else: + _input = input + (fq, mask) = fake_quantize_affine_cachemask( - input, + _input, block_size, scales, zero_points, @@ -55,6 +74,31 @@ def backward(ctx, gy): (mask,) = ctx.saved_tensors return gy * mask, None, None, None, None, None, None + +class _UnwrapAffineFakeQuantizedTensor(torch.autograd.Function): + """ + Helper autograd function to unwrap `AffineFakeQuantizedTensor` while ensuring + gradients are still passed to the tensor subclass. This is used in place of + `_GenericFakeQuantize` when fake quant is disabled. + """ + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + input: torch.Tensor, + ) -> torch.Tensor: + # avoid circular dependencies + from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, + ) + assert isinstance(input, AffineFakeQuantizedTensor) + return input.original_tensor + + @staticmethod + def backward(ctx, gy): + return gy, + + def _fake_quantize_per_channel_group( input: torch.Tensor, scales: torch.Tensor, @@ -69,7 +113,7 @@ def _fake_quantize_per_channel_group( assert input.dim() == 2 block_size = (1, group_size) return _GenericFakeQuantize.apply( - input, scales, zero_points, quant_min, quant_max, block_size, zero_point_domain, + input, block_size, scales, zero_points, quant_min, quant_max, zero_point_domain, ) def _fake_quantize_per_token( @@ -85,7 +129,7 @@ def _fake_quantize_per_token( block_size = _get_per_token_block_size(input) fq_input = input.to(torch.float32) fq = _GenericFakeQuantize.apply( - fq_input, scales, zero_points, quant_min, quant_max, block_size, + fq_input, block_size, scales, zero_points, quant_min, quant_max, ) return fq.reshape_as(input).to(input.dtype) @@ -136,3 +180,82 @@ def _choose_qparams_per_token_asymmetric( zero_point = torch.clamp(zero_point, qmin, qmax).round() return scale.to(scales_precision), zero_point.to(zero_points_precision) + +def _forward_pre_hook_handler( + mod: torch.nn.Linear, + prehook: Callable, + handler: torch.utils.hooks.RemovableHandle, +): + """ + Store a 2-tuple (prehook function, handler) as an attribute on the given linear module. + """ + setattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, (prehook, handler)) + +def _unwrap_affine_fake_quantized_tensor(t: torch.Tensor): + """ + Return the original, non-fake-quantized float tensor from a `AffineFakeQuantizedTensor`. + """ + # avoid circular dependencies + from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, + ) + assert isinstance(t, AffineFakeQuantizedTensor) + return t.original_tensor + +def _is_linear_with_fq_weight(mod: torch.nn.Module, *args): + """ + Return whether this is a nn.Linear module with `AffineFakeQuantizeTensor` weights. + """ + # avoid circular dependencies + from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, + ) + if not isinstance(mod, torch.nn.Linear) or not hasattr(mod, "weight"): + return False + weight = mod.weight + return isinstance(weight, AffineFakeQuantizedTensor) + +def _enable_fake_quant(mod: torch.nn.Module, enable: bool): + """ + Enable or disable fake quantization in the activations and weights of a `nn.Linear` module. + """ + from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, + ) + if not _is_linear_with_fq_weight(mod): + return + weight = mod.weight + assert isinstance(weight, AffineFakeQuantizedTensor) + weight.fake_quant_enabled = enable + + # Enable/disable input fake quant + if hasattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK): + (prehook, handle) = getattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK) + if enable and handle is None: + handle = mod.register_forward_pre_hook(prehook) + elif not enable and handle is not None: + handle.remove() + handle = None + setattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, (prehook, handle)) + +def _get_qat_linear_subclass_inserter( + weight_constructor: Callable, + input_constructor: Optional[Callable] = None, +) -> Callable: + """ + Return a function that inserts wraps the weight and/or input activation of a + linear module in tensor subclasses. + + Args: + weight_constructor: constructor of the weight subclass, accepts a tensor + input_constructor: (optional) constructor of the input subclass, accepts a tensor + """ + def insert_subclass(lin): + lin.weight = torch.nn.Parameter(weight_constructor(lin.weight), requires_grad=True) + if input_constructor is not None: + prehook = lambda _, args: tuple([input_constructor(args[0])] + list(args[1:])) + handle = lin.register_forward_pre_hook(prehook) + setattr(lin, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK, (prehook, handle)) + return lin + + return insert_subclass diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index a0ad665ea..95fa63b39 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -31,6 +31,7 @@ ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_4, + TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass, ) from .subclass import ( @@ -55,7 +56,6 @@ from .utils import _get_per_token_block_size import logging from .autoquant import autoquant, AutoQuantizableLinearWeight -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 __all__ = [ @@ -189,6 +189,11 @@ def _replace_with_custom_fn_if_matches_filter( def _is_linear(mod, *args): + # avoid circular dependencies + from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, + ) + # adding weight tensor subclass isinstance check to make sure the weight is only quantized once # when it is shared by multiple linear modules return ( @@ -198,6 +203,7 @@ def _is_linear(mod, *args): and not isinstance(mod.weight, AutoQuantizableLinearWeight) and not isinstance(mod.weight, AffineQuantizedTensor) and not isinstance(mod.weight, LinearActivationQuantizedTensor) + and not isinstance(mod.weight, AffineFakeQuantizedTensor) ) import torch.nn.utils.parametrize as parametrize