From 9d12e5ec0b9ddc75cf33f0a1108ec381a5259e1e Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 13 Aug 2024 18:16:47 -0700
Subject: [PATCH] Use `torch.uint1` to `torch.uint7` for Uintx tensor subclass

Summary:
Previously we were using bit_width for uintx quantization, but we can use `dtype` directly instead. There are still some workarounds to convert from a torch dtype to a bit_width right now; to remove all of these hacks we would need to support the Uintx tensor subclass properly and have the `torch.uintx` dtypes dispatch to the tensor subclass. That is probably not the highest priority for now, since good performance is more important.

Test Plan:
python test/dtypes/test_affine_quantized.py
pytest test/dtypes/test_uintx.py

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/dtypes/test_affine_quantized.py      | 44 ++++++++++++++++++++++-
 test/dtypes/test_uintx.py                 | 44 ++++++++++++-----------
 torchao/dtypes/affine_quantized_tensor.py |  2 ++
 torchao/dtypes/uintx/Uintx.py             | 35 ++++++++++++++----
 torchao/quantization/quant_api.py         | 24 ++++++-------
 torchao/quantization/quant_primitives.py  | 17 ++++++---
 6 files changed, 121 insertions(+), 45 deletions(-)

diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 2c1762c3a..648309b9a 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -9,13 +9,17 @@
     int8_dynamic_activation_int8_weight,
     int8_dynamic_activation_int8_semi_sparse_weight,
 )
+from torchao.dtypes import (
+    to_affine_quantized,
+)
 import torch
 import unittest
 import tempfile
 from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
     TORCH_VERSION_AFTER_2_5,
 )
-
+from torchao.dtypes.uintx.Uintx import _DTYPE_TO_BIT_WIDTH
 
 class TestAffineQuantized(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -51,6 +55,44 @@ def test_weights_only(self):
                 else:
                     _ = torch.load(f, weights_only=False)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "sub byte dtype requires torch 2.3+")
+    def test_uintx_target_dtype(self):
+        from torchao.quantization.quant_api import uintx_weight_only
+        for dtype in _DTYPE_TO_BIT_WIDTH.keys():
+            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+            # make sure it runs
+            uintx_weight_only(dtype)(l)
+            l = torch.compile(l)
+            l(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))
+
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "sub byte dtype requires torch 2.3+")
+    def test_uintx_model_size(self):
+        from torchao.quantization.quant_api import uintx_weight_only
+        from torchao.utils import get_model_size_in_bytes
+        # scale size = 1/64 * 2 bytes = 1/32 bytes
+        # zero_point size = 1/64 * 4 bytes = 1/16 bytes
+        # dtype data size = 1 * bit_width/8 = bit_width/8 bytes
+        _dtype_to_ratio = {
+            torch.uint1: (1/8 + 1/16 + 1/32) / 2,
+            torch.uint2: (2/8 + 1/16 + 1/32) / 2,
+            torch.uint3: (3/8 + 1/16 + 1/32) / 2,
+            torch.uint4: (4/8 + 1/16 + 1/32) / 2,
+            torch.uint5: (5/8 + 1/16 + 1/32) / 2,
+            torch.uint6: (6/8 + 1/16 + 1/32) / 2,
+            torch.uint7: (7/8 + 1/16 + 1/32) / 2,
+        }
+        for dtype in _DTYPE_TO_BIT_WIDTH.keys():
+            l = torch.nn.Sequential(
+                torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
+            )
+            bf16_size = get_model_size_in_bytes(l)
+            # make sure it runs
+            uintx_weight_only(dtype)(l[0])
+            quantized_size = get_model_size_in_bytes(l)
+            self.assertTrue(bf16_size * _dtype_to_ratio[dtype] == quantized_size)
 
 if __name__ == "__main__":
     run_tests()
diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py
index d17f90c64..6103d7a21 100644
--- a/test/dtypes/test_uintx.py
+++ b/test/dtypes/test_uintx.py
@@ -6,7 +6,10 @@
 
 from torchao.dtypes.uintx.Uintx import to_uintx
 from torchao.quantization.quant_api import quantize_, uintx_weight_only
-from torchao.utils import TORCH_VERSION_AFTER_2_5
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_5,
+)
 
 from torchao.quantization.quant_primitives import (
     MappingType,
@@ -16,7 +19,12 @@
     dequantize_affine,
 )
 
-bit_widths = (1, 2, 3, 4, 5, 6, 7)
+# torch.uintx dtypes are introduced in 2.3
+if TORCH_VERSION_AFTER_2_3:
+    dtypes = (torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7)
+else:
+    dtypes = ()
+
 group_sizes = [32, 64, 128]
 devices = ["cpu", "cuda"]
 @pytest.fixture(autouse=True)
@@ -36,57 +44,51 @@ def __init__(self, scale, device):
     def forward(self, x):
         return self.net(x)
 
-@pytest.mark.parametrize("bit_width", bit_widths)
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.parametrize("device", devices)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AFTER_2_5, reason="only works with fix in the nightly build")
-def test_uintx_weight_only_model_quant(bit_width, group_size, device):
+def test_uintx_weight_only_model_quant(dtype, group_size, device):
     scale = 512
     fp16 = Linear16(scale, device)
-    quantize_(fp16, uintx_weight_only(bit_width, group_size=group_size))
+    quantize_(fp16, uintx_weight_only(dtype, group_size=group_size))
     uintx = torch.compile(fp16, fullgraph=True)
     test_input = torch.randn(scale*2, dtype=torch.float16, device=device)
     output = uintx.forward(test_input)
     assert output != None, "model quantization failed"
 
-@pytest.mark.parametrize("bit_width", bit_widths)
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.parametrize("device", devices)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AFTER_2_5, reason="only works with fix in the nightly build")
-def test_uintx_weight_only_quant(bit_width, group_size, device):
+def test_uintx_weight_only_quant(dtype, group_size, device):
     input_float = torch.randn((1, 256), dtype=torch.float16, device = device)
     mapping_type = MappingType.SYMMETRIC
-    quant_min = 0
-    quant_max = 2 ** bit_width - 1
     eps = torch.finfo(torch.float32).eps
     zero_point_dtype = torch.int32
    zero_point_domain = ZeroPointDomain.INT
-    target_dtype = torch.uint8
     block_size = (1, group_size)
 
     scale, zero_point = choose_qparams_affine(
         input_float, mapping_type, block_size,
-        target_dtype, quant_min, quant_max, eps, torch.float32,
-        zero_point_dtype, True, zero_point_domain
+        dtype, eps=eps, scale_dtype=torch.float32,
+        zero_point_dtype=zero_point_dtype, preserve_zero=True, zero_point_domain=zero_point_domain
     )
 
     aqt = quantize_affine(
         input_float, block_size, scale,
-        zero_point, target_dtype,
-        quant_min = quant_min,
-        quant_max = quant_max,
-        zero_point_domain = zero_point_domain
+        zero_point, dtype,
+        zero_point_domain=zero_point_domain
     )
+    # Note: output will be uint8 tensor for sub byte tensors for now
 
-    q = to_uintx(aqt, bit_width, -1)
+    q = to_uintx(aqt, dtype, -1)
     assert q != None, "quantization failed"
     deqaunt = dequantize_affine(
         q, block_size, scale,
-        zero_point, target_dtype,
-        quant_min = quant_min,
-        quant_max = quant_max,
-        zero_point_domain = zero_point_domain
+        zero_point, dtype,
+        zero_point_domain=zero_point_domain
     )
     assert deqaunt != None, "deqauntization failed"
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 686ed925a..c694f32f1 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -30,6 +30,7 @@
 
 aten = torch.ops.aten
 
+
 ###############################
 # Base Layout Tensor Subclass #
 ###############################
@@ -200,6 +201,7 @@ def from_float(
 
         scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
         int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+        # Note: output will be uint8 tensor for sub byte tensors for now
         int_data = layout_type.post_process(int_data)
 
         layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
diff --git a/torchao/dtypes/uintx/Uintx.py b/torchao/dtypes/uintx/Uintx.py
index 9fdaab0f4..5af45deee 100644
--- a/torchao/dtypes/uintx/Uintx.py
+++ b/torchao/dtypes/uintx/Uintx.py
@@ -11,10 +11,30 @@
     _dispatch__torch_dispatch__,
 )
 from torchao.dtypes.affine_quantized_tensor import PlainAQTLayout, register_layout_cls
-
+from torchao.utils import TORCH_VERSION_AFTER_2_3
 
 aten = torch.ops.aten
 
+# Note: Uintx does not work for torch 2.3 and below
+_DTYPE_TO_BIT_WIDTH = {}
+_BIT_WIDTH_TO_DTYPE = {}
+
+if TORCH_VERSION_AFTER_2_3:
+    _DTYPE_TO_BIT_WIDTH = {
+        torch.uint1: 1,
+        torch.uint2: 2,
+        torch.uint3: 3,
+        torch.uint4: 4,
+        torch.uint5: 5,
+        torch.uint6: 6,
+        torch.uint7: 7,
+    }
+
+    _BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()}
+else:
+    print("uintx feature needs torch 2.3+, please upgrade pytorch")
+
+
 class UintxTensor(torch.Tensor):
     """
     Splits int data into packed shards based on bit size
@@ -90,7 +110,8 @@ def get_plain(self):
     def apply_transformation(self, fn):
         og = self.get_plain()
         new = fn(og)
-        return self.from_uint8(new, self.bit_width, self.pack_dim)
+        dtype = _BIT_WIDTH_TO_DTYPE[self.bit_width]
+        return self.from_uint8(new, dtype, self.pack_dim)
 
     # temporary until kernels on packed tensors are created
     def apply_fn_to_shards(self, fn):
@@ -98,7 +119,9 @@ def apply_fn_to_shards(self, fn):
         return self.__class__(new_shards, self.packed_shape, self.bit_width, self.pack_dim)
 
     @classmethod
-    def from_uint8(cls, int_data: torch.Tensor, bit_width, pack_dim: int = -1):
+    def from_uint8(cls, int_data: torch.Tensor, dtype: torch.dtype, pack_dim: int = -1):
+        assert dtype in _DTYPE_TO_BIT_WIDTH.keys(), f"Expected dtype to be one of {_DTYPE_TO_BIT_WIDTH.keys()}"
+        bit_width = _DTYPE_TO_BIT_WIDTH[dtype]
         shards = pack(int_data, bit_width, dim=pack_dim)
         shape = list(int_data.shape)
         shape[pack_dim] = shape[pack_dim] * bit_width // 8
@@ -107,7 +130,6 @@ def from_uint8(cls, int_data: torch.Tensor, bit_width, pack_dim: int = -1):
 
 implements = UintxTensor.implements
 
-
 @implements(aten.detach.default)
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
@@ -137,16 +159,17 @@ def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
         func, args, kwargs, args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8))
     )
+
 # quantization api integrations
 to_uintx = UintxTensor.from_uint8
 
 @dataclass(frozen=True)
 class UintxLayoutType(LayoutType):
-    bit_width: int
+    dtype: torch.dtype
     pack_dim: int = -1
 
     def post_process(self, input: torch.Tensor) -> torch.Tensor:
-        return to_uintx(input, self.bit_width, self.pack_dim)
+        return to_uintx(input, self.dtype, self.pack_dim)
 
 @register_layout_cls(UintxLayoutType)
 class UintxAQTLayout(PlainAQTLayout):
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 3a329989a..d55a077e5 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -483,36 +483,36 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())
 
 
-def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
+def uintx_weight_only(dtype, group_size=64, pack_dim=-1):
     """
     Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
-    x is the number of bits specified by the `bit_width` argument
+    x is the number of bits specified by `dtype`
+
+    Args:
+        `dtype`: torch.uint1 to torch.uint7 sub byte dtypes
+        `group_size`: parameter for quantization, controls the granularity of quantization, smaller
+         size is more fine grained, defaults to 64
+        `pack_dim`: the dimension we use for packing, defaults to -1
     """
     from torchao.quantization.quant_primitives import (
         MappingType,
         ZeroPointDomain,
-        choose_qparams_affine,
-        quantize_affine,
-        dequantize_affine,
     )
     from torchao.dtypes.uintx.Uintx import UintxLayoutType
     from torchao.dtypes import to_affine_quantized
     from torchao.quantization.quant_api import _get_linear_subclass_inserter
 
-    def apply_uintx_weight_only_quant(weight):
-        layout_type = UintxLayoutType(bit_width=bit_width, pack_dim=pack_dim)
+    def apply_uintx_weight_only_quant(weight):
+        layout_type = UintxLayoutType(dtype=dtype, pack_dim=pack_dim)
         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
-        quant_min = 0
-        quant_max = 2**bit_width - 1
         eps = torch.finfo(torch.float32).eps
         zero_point_dtype = torch.int32
         zero_point_domain = ZeroPointDomain.INT
 
         return to_affine_quantized(
-            weight, mapping_type, block_size, torch.uint8,
-            quant_min = quant_min, quant_max = quant_max,
-            eps = eps, zero_point_dtype=zero_point_dtype,
+            weight, mapping_type, block_size, dtype,
+            eps=eps, zero_point_dtype=zero_point_dtype,
             zero_point_domain=zero_point_domain,
             layout_type=layout_type,
         )
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 1d958be84..09ab1f537 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -65,9 +65,10 @@ class ZeroPointDomain(Enum):
     torch.int16: (-(2**15), 2**15 - 1),
     torch.int32: (-(2**31), 2**31 - 1),
 }
+_SUB_BYTE_DTYPE_BOUNDS: Dict[torch.dtype, Tuple[int, int]] = {}
 
 if TORCH_VERSION_AFTER_2_3:
-    _DTYPE_TO_QVALUE_BOUNDS.update({
+    _SUB_BYTE_DTYPE_BOUNDS = {
         torch.uint1: (0, 2**1-1),
         torch.uint2: (0, 2**2-1),
         torch.uint3: (0, 2**3-1),
@@ -75,7 +76,10 @@ class ZeroPointDomain(Enum):
         torch.uint5: (0, 2**5-1),
         torch.uint6: (0, 2**6-1),
         torch.uint7: (0, 2**7-1),
-    })
+    }
+    _DTYPE_TO_QVALUE_BOUNDS.update(
+        _SUB_BYTE_DTYPE_BOUNDS
+    )
 
 
 quant_lib = torch.library.Library("quant", "FRAGMENT")
@@ -213,6 +217,10 @@ def _quantize_affine(
     """op definition that has compatible signatures with custom op library
     """
     quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
+    # workaround for uintx dtypes, since we don't have native Uintx dtype connected with
+    # torch.uintx dtypes yet
+    if output_dtype in _SUB_BYTE_DTYPE_BOUNDS:
+        output_dtype = torch.uint8
     return _quantize_affine_no_dtype_cast(
         input,
         block_size,
@@ -325,10 +333,9 @@ def _dequantize_affine(
 ) -> torch.Tensor:
     """op definition that has compatible signatures with custom op library
     """
-
-    # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
-    assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}"
+    if input_dtype not in _SUB_BYTE_DTYPE_BOUNDS:
+        assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}"
     assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}"
     quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
     return _dequantize_affine_no_dtype_check(