diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 2fe684f1b1..2bb94984bd 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -37,7 +37,6 @@
     choose_qparams_affine,
     quantize_affine,
     dequantize_affine,
-    MappingType,
 )
 from torchao.quantization.utils import (
     dequantize_per_channel,
@@ -1436,7 +1435,7 @@ def test_get_model_size_aqt(self, api, test_device, test_dtype):
         api(model)
         size2 = torchao.utils.get_model_size_in_bytes(model)
         self.assertTrue(size2 < size)
-        
+
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index be8ef5795f..c879e944e9 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -22,10 +22,6 @@
 from torchao.dtypes import (
     AffineQuantizedTensor,
 )
-from torchao.quantization.quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
 from torchao.quantization.subclass import (
     LinearActQuantizedTensor,
     Int8WeightOnlyQuantizedLinearWeight,
diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/aqt.py
index 83c7d22fb4..129a30b947 100644
--- a/torchao/dtypes/aqt.py
+++ b/torchao/dtypes/aqt.py
@@ -6,8 +6,6 @@
     choose_qparams_affine,
     quantize_affine,
     dequantize_affine,
-    ZeroPointDomain,
-    MappingType,
     int_scaled_matmul,
 )
 from torchao.quantization.utils import (
@@ -98,12 +96,12 @@ class AffineQuantizedTensor(torch.Tensor):
       shape (torch.Size): the shape for the Tensor
       quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
       quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
-      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
+      zero_point_domain (str): the domain that zero_point is in, should be either "int" or "float"
        if zero_point is in integer domain, zero point is added to the quantized integer value during
          quantization
        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
          value during quantization
-        default is ZeroPointDomain.INT
+        default is "int"
      input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object
      dtype: dtype for external representation of the tensor, e.g. torch.float32
torch.float32 """ @@ -116,7 +114,7 @@ def __new__( shape: torch.Size, quant_min: Optional[int] = None, quant_max: Optional[int] = None, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + zero_point_domain: str = "int", dtype=None, strides=None, ): @@ -138,7 +136,7 @@ def __init__( shape: torch.Size, quant_min: Optional[int] = None, quant_max: Optional[int] = None, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + zero_point_domain: str = "int", dtype=None, strides=None, ): @@ -184,7 +182,7 @@ def __tensor_unflatten__( def from_float( cls, input_float: torch.Tensor, - mapping_type: MappingType, + mapping_type: str, block_size: Tuple[int, ...], target_dtype: torch.dtype, quant_min: Optional[int] = None, @@ -193,7 +191,7 @@ def from_float( scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + zero_point_domain: str = "int", extended_layout: str = "plain", # TODO: this is only for "tensor_core_tiled", need to figure out # the proper API for this arg @@ -520,7 +518,7 @@ def get_plain(self): target_dtype = torch.int32 quant_min = 0 quant_max = 15 - zero_point_domain = ZeroPointDomain.FLOAT + zero_point_domain = "int" assert len(block_size) == 2 and block_size[0] == 1 groupsize = block_size[-1] dequantized = torch.ops.aten._weight_int4pack_mm(torch.eye(eye_shape, device=device, dtype=original_dtype), self.packed_weight, groupsize, self.scale_and_zero) @@ -597,7 +595,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias): weight_is_uint4 and weight_qtensor.dtype == torch.bfloat16 and len(weight_qtensor.shape) == 2 and - weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and + weight_qtensor.zero_point_domain == "float" and weight_qtensor.extended_layout == "tensor_core_tiled" ): assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}" @@ -640,7 +638,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias): len(weight_qtensor.block_size) == 2 and weight_qtensor.block_size[0] == 1 and weight_qtensor.block_size[1] == weight_qtensor.shape[1] and - weight_qtensor.zero_point_domain == ZeroPointDomain.INT and + weight_qtensor.zero_point_domain == "int" and weight_qtensor.extended_layout == "plain" ): # TODO: enable cpu and mps efficient path diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 2a65d3c831..8bf4c87982 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -31,10 +31,6 @@ to_linear_act_quantized, ) -from .quant_primitives import ( - MappingType, - ZeroPointDomain, -) from .weight_only import WeightOnlyInt8QuantLinear from .unified import Quantizer, TwoStepQuantizer from .GPTQ import ( @@ -272,7 +268,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Union[str, Callable[ # weight settings groupsize = 32 - mapping_type = MappingType.ASYMMETRIC + mapping_type = "asymmetric" block_size = (1, groupsize) target_dtype = torch.int32 quant_min = 0 @@ -280,7 +276,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Union[str, Callable[ eps = 1e-6 preserve_zero = False zero_point_dtype = torch.bfloat16 - zero_point_domain = ZeroPointDomain.FLOAT + zero_point_domain = "float" apply_weight_quant = lambda x: to_affine_quantized( x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, @@ -321,7 +317,7 @@ def apply_8da4w_quant(weight): from torchao.dtypes import 

         # weight settings
-        mapping_type = MappingType.SYMMETRIC
+        mapping_type = "symmetric"
         block_size = (1, groupsize)
         target_dtype = torch.int8
         eps = torch.finfo(torch.float32).eps
@@ -338,7 +334,7 @@ def get_per_token_block_size(x):
             return block_size

         # input settings
-        input_mapping_type = MappingType.ASYMMETRIC
+        input_mapping_type = "asymmetric"
         input_target_dtype = torch.int8
         input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
@@ -363,7 +359,7 @@ def apply_int4wo_quant(weight):
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized

-        mapping_type = MappingType.ASYMMETRIC
+        mapping_type = "asymmetric"
         block_size = (1, groupsize)
         target_dtype = torch.int32
         quant_min = 0
@@ -371,7 +367,7 @@ def apply_int4wo_quant(weight):
         eps = 1e-6
         preserve_zero = False
         zero_point_dtype = torch.bfloat16
-        zero_point_domain = ZeroPointDomain.FLOAT
+        zero_point_domain = "float"
         return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, extended_layout="tensor_core_tiled", inner_k_tiles=inner_k_tiles)

     return apply_int4wo_quant
@@ -385,7 +381,7 @@ def apply_int8wo_quant(weight):
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized

-        mapping_type = MappingType.SYMMETRIC
+        mapping_type = "symmetric"
         target_dtype = torch.int8
         eps = torch.finfo(torch.float32).eps
         zero_point_dtype = torch.int64
@@ -407,7 +403,7 @@ def apply_int8dyn_quant(weight):
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized
         # weight settings
-        mapping_type = MappingType.SYMMETRIC
+        mapping_type = "symmetric"
         def get_weight_block_size(x):
             return (1, x.shape[1])
         target_dtype = torch.int8
@@ -421,7 +417,7 @@ def get_per_token_block_size(x):
                 block_size[i] = 1
             return block_size

-        input_mapping_type = MappingType.SYMMETRIC
+        input_mapping_type = "symmetric"
         input_target_dtype = torch.int8
         input_eps = 1e-5
         input_quant_min = -127
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index a78c42605a..d49806289d 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -4,13 +4,16 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from enum import Enum
+from enum import Enum, auto
 from typing import List, Optional, Tuple, Dict

 import torch
 from torchao.kernel.intmm import int_scaled_matmul
 from torchao.kernel.intmm import safe_int_mm
-from torchao.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_5,
+)


 __all__ = [
@@ -21,31 +24,6 @@
     "dequantize_affine",
 ]

-class MappingType(Enum):
-    """How floating point number is mapped to integer number
-
-    symmetric mapping means floating point range is symetrically mapped to integer range
-    let's say we have floating point range (-3.5, 10.2) and integer range (-8, 7) (int4)
-    we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7)
-    e.g. scale = (10.2 - (-10.2)) / (7 - (-8))
-
-    asymmetric mapping means we just directly map the floating point range to integer range,
-    for the above example, we will map (-3.5, 10.2) to (-8, 7) and calculate quantization parameter
-    based on this mapping
-    e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
-    """
-    SYMMETRIC = 0
-    ASYMMETRIC = 1
-
-class ZeroPointDomain(Enum):
-    """Enum that indicate whether zero_point is in integer domain or floating point domain
-
-    integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
-    float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
-    """
-    INT = 0
-    FLOAT = 1
-
 """
 Map from dtype to the bound value of integers
 TODO: maybe can replace this with call to torch.iinfo
@@ -130,17 +108,32 @@ def _get_reduction_params(block_size, input_size):
             cur_dim += 1
     return shape_for_reduction, reduction_dims

+def register_custom_op(name: str):
+    from torch._inductor.decomposition import register_decomposition
+
+    def decorator(fn):
+        if TORCH_VERSION_AFTER_2_5:
+            opdef = torch.library.custom_op(name, mutates_args=())(fn)
+            opdef.register_fake(fn)
+            register_decomposition([opdef._opoverload])(fn)
+            return opdef
+        else:
+            return fn
+
+    return decorator
+
+@register_custom_op("quant::quantize_affine")
 def quantize_affine(
     input: torch.Tensor,
-    block_size: Tuple[int, ...],
+    block_size: List[int],
     scale: torch.Tensor,
     zero_point: Optional[torch.Tensor],
     output_dtype: torch.dtype,
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
-    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-):
+    zero_point_domain: str = "int",
+) -> torch.Tensor:
     """
     Args:
       input (torch.Tensor): original float32, float16 or bfloat16 Tensor
@@ -151,12 +144,12 @@ def quantize_affine(
       output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
       quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
      quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
-      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
+      zero_point_domain (str): the domain that zero_point is in, should be either "int" or "float"
        if zero_point is in integer domain, zero point is added to the quantized integer value during
          quantization
        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
          value during quantization
-        default is ZeroPointDomain.INT
+        default is "int"

     Note:
       How can block_size represent different granularities?
@@ -170,6 +163,11 @@ def quantize_affine(
       per_group (groupsize=2) | (3, 3, 10, 2)
       per_group (groupsize=2) for axis = 3 | (3, 3, 2, 10)

+    Note:
+      zero_point_domain also affects how the floating point value is quantized:
+
+      integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
+      float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale

     Output:
       quantized tensor with requested dtype
@@ -188,12 +186,12 @@
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)

-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == "int":
         quant = torch.clamp(
             torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
         ).to(output_dtype)
     else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT
+        assert zero_point_domain == "float"
         mid_point = (quant_max + quant_min + 1) / 2
         min_val = zero_point - scale * mid_point
         quant = (
@@ -205,15 +203,16 @@
     return quant

+@register_custom_op("quant::dequantize_affine")
 def dequantize_affine(
     input: torch.Tensor,
-    block_size: Tuple[int, ...],
+    block_size: List[int],
     scale: torch.Tensor,
     zero_point: Optional[torch.Tensor],
     input_dtype: torch.dtype,
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
-    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+    zero_point_domain: str = "int",
     *,
     output_dtype: torch.dtype = torch.float32,
 ):
@@ -228,12 +227,12 @@
       quant_min (Optional[int]): minimum quantized value for input Tensor
       quant_max (Optional[int]): maximum quantized value for input Tensor
       output_dtype (torch.dtype): dtype for output Tensor, default is fp32
-      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
+      zero_point_domain (str): the domain that zero_point is in, should be either "int" or "float"
        if zero_point is in integer domain, zero point is added to the quantized integer value during
          quantization
        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
          value during quantization
-        default is ZeroPointDomain.INT
+        default is "int"

     Output:
       dequantized Tensor, with requested dtype or fp32
@@ -255,7 +254,7 @@
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)

-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == "int":
         # Force a copy to avoid input modification due
         # to upcoming in-place operations.
         dequant = input.to(torch.int32, copy=True)
@@ -264,7 +263,7 @@
         dequant = dequant.to(output_dtype)
         dequant *= scale
     else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
+        assert zero_point_domain == "float", f"Unexpected zero point domain: {zero_point_domain}"
         mid_point = (quant_max + quant_min + 1) / 2
         # This should allocate new memory and avoid input modification
         dequant = input - mid_point
@@ -275,10 +274,11 @@
     return dequant.view(original_shape).to(output_dtype)

+@register_custom_op("quant::choose_qparams_affine")
 def choose_qparams_affine(
     input: torch.Tensor,
-    mapping_type: MappingType,
-    block_size: Tuple[int, ...],
+    mapping_type: str,
+    block_size: List[int],
     target_dtype: torch.dtype,
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
@@ -286,13 +286,13 @@
     scale_dtype: Optional[torch.dtype] = None,
     zero_point_dtype: Optional[torch.dtype] = None,
     preserve_zero: bool = True,
-    zero_point_domain = ZeroPointDomain.INT,
+    zero_point_domain: str = "int",
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Args:
       input (torch.Tensor): fp32, bf16, fp16 input Tensor
-      mapping_type (MappingType): determines how the qparams are calculated, symmetric or asymmetric
-      block_size: (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
+      mapping_type (str): determines how the qparams are calculated, "symmetric" or "asymmetric"
+      block_size: (List[int]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
       e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
       target_dtype (torch.dtype): dtype for target quantized Tensor
       quant_min (Optional[int]): minimum quantized value for target quantized Tensor
@@ -310,18 +310,32 @@
       If we don't need zero to be exactly representable, we won't do rounding and clamping for zero_point

-      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
+      zero_point_domain (str): the domain that zero_point is in, should be either "int" or "float"
        if zero_point is in integer domain, zero point is added to the quantized integer value during
          quantization
        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
          value during quantization
-        default is ZeroPointDomain.INT
+        default is "int"
+
+
+    Note:
+      How is a floating point number mapped to an integer number?
+
+      symmetric mapping means the floating point range is symmetrically mapped to the integer range
+      let's say we have floating point range (-3.5, 10.2) and integer range (-8, 7) (int4)
+      we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7)
+      e.g. scale = (10.2 - (-10.2)) / (7 - (-8))
+
+      asymmetric mapping means we just directly map the floating point range to the integer range,
+      for the above example, we will map (-3.5, 10.2) to (-8, 7) and calculate the quantization parameters
+      based on this mapping
+      e.g. scale = (10.2 - (-3.5)) / (7 - (-8))

     Output:
       Tuple of scales and zero_points Tensor with requested dtype
     """
     quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
-    assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], f"Unsupported mapping type: {mapping_type}"
+    assert mapping_type in ["symmetric", "asymmetric"], f"Unsupported mapping type: {mapping_type}"

     if scale_dtype is None:
         scale_dtype = input.dtype
@@ -342,21 +356,22 @@
     min_val_neg = min_val
     max_val_pos = max_val

-    if mapping_type == MappingType.SYMMETRIC:
+    if mapping_type == "symmetric":
         max_val_pos = torch.max(-min_val_neg, max_val_pos)
         scale = max_val_pos / (float(quant_max - quant_min) / 2)
         if not preserve_zero:
             raise ValueError("preserve_zero == False is not supported for symmetric quantization")
-        if zero_point_domain != ZeroPointDomain.INT:
-            raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
+        if zero_point_domain != "int":
+            raise ValueError('zero_point_domain != "int" is not supported for symmetric quantization')
         zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
     else:
+        assert mapping_type == "asymmetric"
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
         if preserve_zero:
             zero_point = quant_min - torch.round(min_val_neg / scale)
             zero_point = torch.clamp(zero_point, quant_min, quant_max)
         else:
-            assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
+            assert zero_point_domain == "float", 'if not preserve_zero, zero_point must be in the "float" domain'
             mid_point = (quant_max + quant_min + 1) / 2
             zero_point = min_val_neg + scale * mid_point
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index a2801a622f..1eeb03b591 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -9,10 +9,6 @@
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing

-from .quant_primitives import (
-    MappingType,
-)
-
 from .utils import (
     find_multiple,
     dequantize_per_channel,
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index 3e3943c93c..f1ab56f0e5 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -10,8 +10,6 @@
 import torch.nn.utils.parametrize as parametrize
 from torchao.utils import find_multiple
 from .quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
     choose_qparams_affine,
     quantize_affine,
     dequantize_affine,
@@ -132,7 +130,7 @@ def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None):
 # and slightly modified
 def quantize_activation_per_token_absmax(t):
     # if the shape of t is [B, N, K], the shape of scales will be [B, N, 1]
-    mapping_type = MappingType.SYMMETRIC
+    mapping_type = "symmetric"
     block_size = list(t.shape)
     for i in range(len(block_size) - 1):
         block_size[i] = 1
@@ -241,7 +239,7 @@ def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
     block_size = (1, x.shape[1])
     zero_point_dtype = torch.int64

-    mapping_type = MappingType.SYMMETRIC
+    mapping_type = "symmetric"
     scale, zero_point = choose_qparams_affine(x, mapping_type, block_size, target_dtype=target_dtype, quant_min=quant_min, quant_max=quant_max, eps=eps, zero_point_dtype=zero_point_dtype)
     quant = quantize_affine(x, block_size, scale, zero_point, target_dtype, quant_min, quant_max)
     return quant, scale, zero_point
@@ -278,7 +276,7 @@ def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16
     assert w.dim() == 2
     assert n_bit <= 8, f"only n_bit smaller than 8 is supported, got: {n_bit}"

-    mapping_type = MappingType.ASYMMETRIC
+    mapping_type = "asymmetric"
     target_dtype = torch.int32
     block_size = (1, groupsize)
     quant_min = 0
@@ -298,7 +296,7 @@ def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16
         scale_dtype=scale_dtype,
         zero_point_dtype=zero_point_dtype,
         preserve_zero=False,
-        zero_point_domain=ZeroPointDomain.FLOAT
+        zero_point_domain="float",
     )

     return scale.to(dtype=dtype).reshape(w.shape[0], -1), zero_point.to(
@@ -347,7 +345,7 @@ def groupwise_affine_quantize_tensor_from_qparams(
     quant_min = 0
     quant_max = 2 ** n_bit - 1

-    return quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
+    return quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain="float")

 def groupwise_affine_dequantize_tensor_from_qparams(
     w_int4x8,
@@ -367,7 +365,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     input_dtype = torch.int32
     quant_min = 0
     quant_max = 2**n_bit - 1
-    return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)
+    return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain="float", output_dtype=scales.dtype)

 def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
@@ -401,7 +399,7 @@ def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float
     assert w.dim() == 2
     assert n_bit <= 8, f"unsupported n_bit: {n_bit}"

-    mapping_type = MappingType.SYMMETRIC
+    mapping_type = "symmetric"
     block_size = (1, groupsize)
     eps = torch.finfo(torch.float32).eps
     ranges = {}
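
Since `MappingType` and `ZeroPointDomain` no longer exist after this patch, callers pass plain strings everywhere. The following per-channel round-trip sketch is illustrative only (it is not part of the diff) and assumes the post-patch signatures shown above:

    import torch
    from torchao.quantization.quant_primitives import (
        choose_qparams_affine,
        quantize_affine,
        dequantize_affine,
    )

    x = torch.randn(4, 16)
    block_size = [1, x.shape[1]]  # one qparam per row, i.e. per-channel
    scale, zero_point = choose_qparams_affine(
        x, "symmetric", block_size, torch.int8,
        eps=torch.finfo(torch.float32).eps,
        zero_point_dtype=torch.int64,
    )
    q = quantize_affine(x, block_size, scale, zero_point, torch.int8)
    # round-trip error should be bounded by roughly scale/2 per element
    x_hat = dequantize_affine(q, block_size, scale, zero_point, torch.int8, output_dtype=x.dtype)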
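The string migration is what enables `register_custom_op`: `torch.library.custom_op` only accepts schema-representable argument types (Tensor, int, float, bool, str, torch.dtype, and lists of these), so enum-valued parameters like `MappingType` cannot cross the op boundary; the same constraint motivates loosening `block_size` from `Tuple[int, ...]` to `List[int]`. On PyTorch 2.5+ the decorated primitives should then also be reachable through the dispatcher; a sketch, assuming the `quant::` registrations above take effect at import time:

    import torch
    import torchao.quantization.quant_primitives  # importing runs the decorators
    from torchao.utils import TORCH_VERSION_AFTER_2_5

    if TORCH_VERSION_AFTER_2_5:
        # "quant::quantize_affine" is exposed as torch.ops.quant.quantize_affine
        op = torch.ops.quant.quantize_affine
        x = torch.randn(2, 8)
        scale = torch.full((2, 1), 0.1)
        zero_point = torch.zeros(2, 1, dtype=torch.int64)
        q = op(x, [1, 8], scale, zero_point, torch.int8)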
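Finally, a quick numeric check of the symmetric/asymmetric example carried in the restored `choose_qparams_affine` docstring, using the same float range (-3.5, 10.2) and int4 target range (-8, 7):

    min_val, max_val = -3.5, 10.2
    quant_min, quant_max = -8, 7

    # symmetric: widen the float range to (-10.2, 10.2) first
    sym_scale = max(abs(min_val), abs(max_val)) / ((quant_max - quant_min) / 2)  # 10.2 / 7.5 = 1.36

    # asymmetric: map (-3.5, 10.2) onto (-8, 7) directly
    asym_scale = (max_val - min_val) / (quant_max - quant_min)  # 13.7 / 15 ~ 0.913

    # integer-domain zero_point for the asymmetric, preserve_zero=True case:
    zero_point = quant_min - round(min_val / asym_scale)  # -8 - (-4) = -4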