From f470f496daae0796fc2b6898d62442bc28380bc4 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Mon, 29 Jul 2024 17:57:42 -0700 Subject: [PATCH] Refactor QAT into its own module Summary: Refactor QAT into its own module so future QAT features can live under the same folder without making qat.py longer, and a separate QAT README can be added in the future. Test Plan: python test/quantization/test_qat.py --- test/quantization/test_qat.py | 26 +-- .../quantization/prototype/qat/__init__.py | 17 ++ .../prototype/{qat.py => qat/api.py} | 155 ++---------------- torchao/quantization/prototype/qat/utils.py | 145 ++++++++++++++++ 4 files changed, 185 insertions(+), 158 deletions(-) create mode 100644 torchao/quantization/prototype/qat/__init__.py rename torchao/quantization/prototype/{qat.py => qat/api.py} (69%) create mode 100644 torchao/quantization/prototype/qat/utils.py diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 9a3888274b..75bd3061b0 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -12,11 +12,11 @@ import torch from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 -from torchao.quantization.prototype.qat import ( +from torchao.quantization.prototype.qat.utils import ( _choose_qparams_per_token_asymmetric, + _fake_quantize_per_channel_group, + _fake_quantize_per_token, _GenericFakeQuantize, - fake_quantize_per_channel_group, - fake_quantize_per_token, ) from torchao.quantization.quant_primitives import ( fake_quantize_affine, @@ -85,7 +85,7 @@ def test_fake_quantize_per_channel_group(self): x2 = copy.deepcopy(x) # fake quant op - out = fake_quantize_per_channel_group( + out = _fake_quantize_per_channel_group( x, s, zp, qmin, qmax, group_size, ) out.sum().backward() @@ -110,7 +110,7 @@ def test_fake_quantize_per_token(self): (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32) # fake quant op - out = fake_quantize_per_token(x, s, zp, qmin, qmax) + out = _fake_quantize_per_token(x, s, zp, qmin, qmax) out.sum().backward() # compare against PTQ ops @@ -135,7 +135,7 @@ def _set_ptq_weight( Int8DynActInt4WeightLinear, WeightOnlyInt4Linear, ) - from torchao.quantization.prototype.qat import ( + from torchao.quantization.prototype.qat.api import ( Int8DynActInt4WeightQATLinear, Int4WeightOnlyQATLinear, ) @@ -167,7 +167,7 @@ def _set_ptq_weight( @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_linear(self): - from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATLinear + from torchao.quantization.prototype.qat.api import Int8DynActInt4WeightQATLinear from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear group_size = 128 @@ -192,7 +192,7 @@ def test_qat_8da4w_linear(self): @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer(self): - from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer + from torchao.quantization.prototype.qat.api import Int8DynActInt4WeightQATQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer group_size = 16 @@ -226,7 +226,7 @@ def test_qat_8da4w_quantizer(self): @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_meta_weights(self): - from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer + from torchao.quantization.prototype.qat.api import 
Int8DynActInt4WeightQATQuantizer with torch.device("meta"): m = M() @@ -241,7 +241,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward. """ - from torchao.quantization.prototype.qat import ( + from torchao.quantization.prototype.qat.api import ( Int8DynActInt4WeightQATQuantizer, disable_8da4w_fake_quant, enable_8da4w_fake_quant, @@ -294,7 +294,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward. """ - from torchao.quantization.prototype.qat import ( + from torchao.quantization.prototype.qat.api import ( Int8DynActInt4WeightQATQuantizer, disable_8da4w_fake_quant, ) @@ -425,7 +425,7 @@ def test_qat_4w_primitives(self): # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517 @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now") def test_qat_4w_linear(self): - from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear + from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATLinear from torchao.quantization.GPTQ import WeightOnlyInt4Linear group_size = 128 @@ -455,7 +455,7 @@ def test_qat_4w_linear(self): # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517 @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now") def test_qat_4w_quantizer(self): - from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer + from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATQuantizer from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer group_size = 32 diff --git a/torchao/quantization/prototype/qat/__init__.py b/torchao/quantization/prototype/qat/__init__.py new file mode 100644 index 0000000000..ca0b1c1a00 --- /dev/null +++ b/torchao/quantization/prototype/qat/__init__.py @@ -0,0 +1,17 @@ +from .api import ( + disable_4w_fake_quant, + disable_8da4w_fake_quant, + enable_4w_fake_quant, + enable_8da4w_fake_quant, + Int4WeightOnlyQATQuantizer, + Int8DynActInt4WeightQATQuantizer, +) + +__all__ = [ + "disable_4w_fake_quant", + "disable_8da4w_fake_quant", + "enable_4w_fake_quant", + "enable_8da4w_fake_quant", + "Int4WeightOnlyQATQuantizer", + "Int8DynActInt4WeightQATQuantizer", +] diff --git a/torchao/quantization/prototype/qat.py b/torchao/quantization/prototype/qat/api.py similarity index 69% rename from torchao/quantization/prototype/qat.py rename to torchao/quantization/prototype/qat/api.py index f64351d7c6..668b737877 100644 --- a/torchao/quantization/prototype/qat.py +++ b/torchao/quantization/prototype/qat/api.py @@ -4,12 +4,10 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Any, List, Optional, Tuple +from typing import Any, Optional import torch import torch.nn.functional as F -from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib -from torch.library import impl from torchao.quantization.GPTQ import ( _check_linear_int4_k, @@ -20,14 +18,13 @@ Int8DynActInt4WeightLinear, WeightOnlyInt4Linear, ) -from torchao.quantization.quant_primitives import ( - fake_quantize_affine_cachemask, - ZeroPointDomain, -) +from torchao.quantization.quant_primitives import ZeroPointDomain from torchao.quantization.unified import TwoStepQuantizer -from torchao.quantization.utils import ( - _get_per_token_block_size, - get_group_qparams_symmetric, +from torchao.quantization.utils import get_group_qparams_symmetric +from .utils import ( + _choose_qparams_per_token_asymmetric, + _fake_quantize_per_channel_group, + _fake_quantize_per_token, ) @@ -163,7 +160,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x, self.scales_precision, self.zero_points_precision, ) (act_qmin, act_qmax) = self._get_qmin_qmax(8) - x_fq = fake_quantize_per_token( + x_fq = _fake_quantize_per_token( x, act_scales, act_zp, act_qmin, act_qmax, ) else: @@ -177,7 +174,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # TODO: pass zp dtype to `get_group_qparams_symmetric` instead weight_zp = weight_zp.to(self.zero_points_precision) (weight_qmin, weight_qmax) = self._get_qmin_qmax(4) - w_fq = fake_quantize_per_channel_group( + w_fq = _fake_quantize_per_channel_group( self.weight, weight_scales, weight_zp, @@ -349,7 +346,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: scales, zero_points = get_groupwise_affine_qparams( self.weight, n_bit, self.groupsize, self.scales_precision, ) - w_fq = fake_quantize_per_channel_group( + w_fq = _fake_quantize_per_channel_group( self.weight, scales, zero_points, @@ -373,135 +370,3 @@ def disable_4w_fake_quant(mod: torch.nn.Module): """ if isinstance(mod, Int4WeightOnlyQATLinear): mod.disable_fake_quant() - - -# ======================== -# | QUANT PRIMITIVES | -# ======================== - -class _GenericFakeQuantize(torch.autograd.Function): - """ - Implementation of generic fake quantize with backward STE. - - With the appropriate input tensor shape, this can be used to express - grouped per channel fake quantize or per token fake quantize. 
- """ - - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - input: torch.Tensor, - scales: torch.Tensor, - zero_points: torch.Tensor, - quant_min: int, - quant_max: int, - block_size: List[int], - zero_point_domain: ZeroPointDomain=ZeroPointDomain.INT, - ) -> torch.Tensor: - # Note: for bf16 inputs, casting them to fp32 has the unexpected - # side effect of reducing memory footprint significantly, presumably - # because bf16 * fp32 kernels are not as memory efficient - assert input.dtype == torch.float32 - assert scales.dtype == torch.float32 - assert zero_points.dtype == torch.int32 - - (fq, mask) = fake_quantize_affine_cachemask( - input, - block_size, - scales, - zero_points, - torch.int32, - quant_min, - quant_max, - zero_point_domain, - ) - - ctx.save_for_backward(mask) - return fq - - @staticmethod - def backward(ctx, gy): - (mask,) = ctx.saved_tensors - return gy * mask, None, None, None, None, None, None - -def fake_quantize_per_channel_group( - input: torch.Tensor, - scales: torch.Tensor, - zero_points: torch.Tensor, - quant_min: int, - quant_max: int, - group_size: int, - zero_point_domain: ZeroPointDomain=ZeroPointDomain.INT, -) -> torch.Tensor: - assert group_size > 1 - assert input.shape[-1] % group_size == 0 - assert input.dim() == 2 - block_size = (1, group_size) - return _GenericFakeQuantize.apply( - input, scales, zero_points, quant_min, quant_max, block_size, zero_point_domain, - ) - -def fake_quantize_per_token( - input: torch.Tensor, - scales: torch.Tensor, - zero_points: torch.Tensor, - quant_min: int, - quant_max: int, -) -> torch.Tensor: - from torch.ao.quantization.fx._decomposed import _per_token_quant_qparam_dim_check - - _per_token_quant_qparam_dim_check(input, scales, zero_points) - block_size = _get_per_token_block_size(input) - fq_input = input.to(torch.float32) - fq = _GenericFakeQuantize.apply( - fq_input, scales, zero_points, quant_min, quant_max, block_size, - ) - return fq.reshape_as(input).to(input.dtype) - -# TODO: This is copied from torch/ao/quantization/fx/_decomposed.py. -# The version in pytorch does not have backward support yet so we add -# it here for now until https://github.com/pytorch/pytorch/pull/123452 -# is landed. -def _choose_qparams_per_token_asymmetric( - input: torch.Tensor, - scales_precision: torch.dtype = torch.float32, - zero_points_precision: torch.dtype = torch.float32, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Choose quantization parameters for per token quantization. This means for a N dimension Tensor - (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize - every N elements with the same quantization parameter. The dimension for scales/zero_points - will be (M1 * M2 ... * Mn) - - Args: - input (torch.Tensor): original float32/float16 Tensor - scales_precision (torch.dtype): precision of returned scales - zero_points_precision (torch.dtype): precision of returned zero points - - Returns: - scales and zero_points, both float32 Tensors - """ - # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18 - qmin, qmax = -128, 127 - min_val = torch.amin(input, dim=-1, keepdim=True) - max_val = torch.amax(input, dim=-1, keepdim=True) - min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) - max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) - eps = torch.finfo(torch.float32).eps # use xnnpack eps? 
- - # scale - scale = (max_val_pos - min_val_neg) / float(qmax - qmin) - scale = scale.clamp(min=eps) - - # zero point - descaled_min = min_val_neg / scale - descaled_max = max_val_pos / scale - zero_point_from_min_error = qmin + descaled_min - zero_point_from_max_error = qmax + descaled_max - zero_point = torch.where( - zero_point_from_min_error + zero_point_from_max_error > 0, - qmin - descaled_min, - qmax - descaled_max, - ) - zero_point = torch.clamp(zero_point, qmin, qmax).round() - - return scale.to(scales_precision), zero_point.to(zero_points_precision) diff --git a/torchao/quantization/prototype/qat/utils.py b/torchao/quantization/prototype/qat/utils.py new file mode 100644 index 0000000000..fee9f631b4 --- /dev/null +++ b/torchao/quantization/prototype/qat/utils.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Tuple + +import torch + +from torchao.quantization.quant_primitives import ( + fake_quantize_affine_cachemask, + ZeroPointDomain, +) +from torchao.quantization.utils import ( + _get_per_token_block_size, +) + + +class _GenericFakeQuantize(torch.autograd.Function): + """ + Implementation of generic fake quantize with backward STE. + + With the appropriate input tensor shape, this can be used to express + grouped per channel fake quantize or per token fake quantize. + """ + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + block_size: List[int], + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + ) -> torch.Tensor: + # Note: for bf16 inputs, casting them to fp32 has the unexpected + # side effect of reducing memory footprint significantly, presumably + # because bf16 * fp32 kernels are not as memory efficient + assert input.dtype == torch.float32 + assert scales.dtype == torch.float32 + assert zero_points.dtype == torch.int32 + + (fq, mask) = fake_quantize_affine_cachemask( + input, + block_size, + scales, + zero_points, + torch.int32, + quant_min, + quant_max, + zero_point_domain, + ) + + ctx.save_for_backward(mask) + return fq + + @staticmethod + def backward(ctx, gy): + (mask,) = ctx.saved_tensors + return gy * mask, None, None, None, None, None, None + +def _fake_quantize_per_channel_group( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + group_size: int, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, +) -> torch.Tensor: + assert group_size > 1 + assert input.shape[-1] % group_size == 0 + assert input.dim() == 2 + block_size = (1, group_size) + return _GenericFakeQuantize.apply( + input, scales, zero_points, quant_min, quant_max, block_size, zero_point_domain, + ) + +def _fake_quantize_per_token( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, +) -> torch.Tensor: + from torch.ao.quantization.fx._decomposed import _per_token_quant_qparam_dim_check + + _per_token_quant_qparam_dim_check(input, scales, zero_points) + block_size = _get_per_token_block_size(input) + fq_input = input.to(torch.float32) + fq = _GenericFakeQuantize.apply( + fq_input, scales, zero_points, quant_min, quant_max, block_size, + ) + return fq.reshape_as(input).to(input.dtype) + +# TODO: This is copied from 
torch/ao/quantization/fx/_decomposed.py. +# The version in pytorch does not have backward support yet so we add +# it here for now until https://github.com/pytorch/pytorch/pull/123452 +# is landed. +def _choose_qparams_per_token_asymmetric( + input: torch.Tensor, + scales_precision: torch.dtype = torch.float32, + zero_points_precision: torch.dtype = torch.float32, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Choose quantization parameters for per token quantization. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): original float32/float16 Tensor + scales_precision (torch.dtype): precision of returned scales + zero_points_precision (torch.dtype): precision of returned zero points + + Returns: + scales and zero_points, both float32 Tensors + """ + # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18 + qmin, qmax = -128, 127 + min_val = torch.amin(input, dim=-1, keepdim=True) + max_val = torch.amax(input, dim=-1, keepdim=True) + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + eps = torch.finfo(torch.float32).eps # use xnnpack eps? + + # scale + scale = (max_val_pos - min_val_neg) / float(qmax - qmin) + scale = scale.clamp(min=eps) + + # zero point + descaled_min = min_val_neg / scale + descaled_max = max_val_pos / scale + zero_point_from_min_error = qmin + descaled_min + zero_point_from_max_error = qmax + descaled_max + zero_point = torch.where( + zero_point_from_min_error + zero_point_from_max_error > 0, + qmin - descaled_min, + qmax - descaled_max, + ) + zero_point = torch.clamp(zero_point, qmin, qmax).round() + + return scale.to(scales_precision), zero_point.to(zero_points_precision)
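
For context, a minimal usage sketch of the refactored package: the new qat/__init__.py re-exports the public quantizers, so existing `from torchao.quantization.prototype.qat import ...` call sites keep working after the qat.py -> qat/api.py rename. The sketch assumes the two-step prepare/convert interface implied by the TwoStepQuantizer base class and exercised in test_qat.py; the toy model and the training step are hypothetical.

import torch

# Public entry point, re-exported from torchao/quantization/prototype/qat/__init__.py,
# so this import path is unchanged by the rename in this patch.
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Hypothetical toy model, used purely for illustration.
model = torch.nn.Sequential(torch.nn.Linear(256, 256))

# prepare() swaps nn.Linear modules for QAT linears that fake-quantize
# activations and weights in the forward pass (assumed TwoStepQuantizer API).
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
model = qat_quantizer.prepare(model)

# ... fine-tune the prepared model as usual ...

# convert() then swaps in the real quantized linears for inference.
model = qat_quantizer.convert(model)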
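
Similarly, a short sketch of the relocated fake-quantize helpers (now underscore-prefixed in qat/utils.py), mirroring test_fake_quantize_per_token above; it assumes this patch is applied to an installed torchao.

import torch

# Internal primitives moved from qat.py to qat/utils.py in this patch.
from torchao.quantization.prototype.qat.utils import (
    _choose_qparams_per_token_asymmetric,
    _fake_quantize_per_token,
)

# Per-token asymmetric qparams: fp32 scales and int32 zero points, matching
# the dtypes asserted inside _GenericFakeQuantize.forward.
x = torch.randn(100, 256, requires_grad=True)
(qmin, qmax) = (-128, 127)  # int8 range, as returned by _get_qmin_qmax(8)
(scales, zero_points) = _choose_qparams_per_token_asymmetric(
    x, torch.float32, torch.int32,
)

# Fake quantize with a straight-through estimator: the backward pass of
# _GenericFakeQuantize passes gradients through only where values were not
# clamped, using the mask cached by fake_quantize_affine_cachemask.
out = _fake_quantize_per_token(x, scales, zero_points, qmin, qmax)
out.sum().backward()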