From 22b937f7887e355f5c315d6b896ed259b156120f Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 18 Aug 2025 13:34:46 -0700 Subject: [PATCH] Add Int4TilePackedTo4dTensor for int4 quantization and tile packed to 4d packing This commit introduces Int4TilePackedTo4dTensor, a new tensor subclass for int4 weight-only quantization using tensor core tiled packing format. Key features: - Implements tensor core tiled packing for efficient computation on tensor cores - Supports PackingFormat.TILE_PACKED_TO_4D in Int4WeightOnlyConfig version 2 - Optimized for tinygemm int4mm kernel (_weight_int4pack_mm) - Includes comprehensive test suite The implementation follows the same pattern as other int4 tensor subclasses but uses a specialized packing format optimized for tensor core matrix multiplication performance. Changes: - Add Int4TilePackedTo4dTensor implementation - Update Int4WeightOnlyConfig version 2 to support TILE_PACKED_TO_4D packing format - Add TILE_PACKED_TO_4D to PackingFormat enum - Add comprehensive tests including serialization, different group sizes, and error conditions - Update __init__.py files to export new tensor class Test: python test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py --- .../test_int4_tile_packed_to_4d_tensor.py | 270 +++++++++++++++ torchao/quantization/__init__.py | 2 + torchao/quantization/quant_api.py | 18 +- .../quantize_/common/packing_format.py | 5 + .../quantize_/workflows/__init__.py | 2 + .../quantize_/workflows/int4/__init__.py | 7 - .../int4/int4_tile_packed_to_4d_tensor.py | 312 ++++++++++++++++++ torchao/testing/utils.py | 13 +- 8 files changed, 615 insertions(+), 14 deletions(-) create mode 100644 test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py create mode 100644 torchao/quantization/quantize_/workflows/int4/int4_tile_packed_to_4d_tensor.py diff --git a/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py new file mode 100644 index 0000000000..1c0e33c960 --- /dev/null +++ b/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
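+
+# Example usage (illustrative sketch): quantize a bfloat16 Linear with the
+# tile-packed-to-4d int4 format via Int4WeightOnlyConfig(version=2). Assumes a
+# CUDA device that meets the compute-capability requirement (these tests skip
+# below sm90):
+#
+#     import torch
+#     from torchao.quantization import Int4WeightOnlyConfig, quantize_
+#     from torchao.quantization.quantize_.common.packing_format import PackingFormat
+#
+#     linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+#     quantize_(
+#         linear,
+#         Int4WeightOnlyConfig(
+#             group_size=128,
+#             packing_format=PackingFormat.TILE_PACKED_TO_4D,
+#             version=2,
+#         ),
+#     )
+#     # linear.weight is now an Int4TilePackedTo4dTensor; F.linear on it
+#     # dispatches to the tinygemm torch.ops.aten._weight_int4pack_mm kernel.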
+
+import tempfile
+import unittest
+
+import torch
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
+
+from torchao.quantization import Int4WeightOnlyConfig, quantize_
+from torchao.quantization.quantize_.common.packing_format import PackingFormat
+from torchao.quantization.quantize_.workflows.int4.int4_tile_packed_to_4d_tensor import (
+    Int4TilePackedTo4dTensor,
+)
+from torchao.quantization.utils import compute_error
+from torchao.testing.utils import TorchAOIntegrationTestCase
+from torchao.utils import is_sm_at_least_90
+
+INT4_CONFIG = Int4WeightOnlyConfig(
+    group_size=128,
+    packing_format=PackingFormat.TILE_PACKED_TO_4D,
+    version=2,
+)
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
+class TestInt4TilePackedTo4dTensor(TorchAOIntegrationTestCase):
+    def setUp(self):
+        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
+
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 512, 128),
+            ((2, 32, 128), 256, 128),
+        ],
+    )
+    def test_linear(self, sizes):
+        config = INT4_CONFIG
+        dtype = torch.bfloat16
+        device = "cuda"
+
+        M, N, K = sizes
+        input = torch.randn(*M, K, dtype=dtype, device=device)
+        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
+
+        original = linear(input)
+        quantize_(linear, config)
+        quantized = linear(input)
+        self.assertTrue(compute_error(original, quantized) > 20)
+
+        compiled_linear = torch.compile(linear)
+        quantized_and_compiled = compiled_linear(input)
+        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)
+
+    def test_module_path(self):
+        config = INT4_CONFIG
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        quantize_(linear.cuda(), config)
+        self.assertEqual(
+            str(type(linear.weight)),
+            "<class 'torchao.quantization.Int4TilePackedTo4dTensor'>",
+        )
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Int4TilePackedTo4dTensor'>",
+            )
+
+    def test_slice(self):
+        """Note: we use multiples of 1024 for both in_features and out_features
+        so that padding does not affect the weight after slicing
+        """
+        config = INT4_CONFIG
+        dtype = torch.bfloat16
+        device = "cuda"
+
+        # Create a 2048x2048 linear layer for testing
+        dummy = torch.nn.Linear(2048, 2048, bias=False, dtype=dtype, device=device)
+
+        # Create reference sliced linear layers
+        dummy1 = torch.nn.Linear(2048, 1024, bias=False, dtype=dtype, device=device)
+        dummy1.weight = torch.nn.Parameter(
+            dummy.weight.narrow(0, 0, 1024), requires_grad=False
+        )
+        dummy2 = torch.nn.Linear(1024, 2048, dtype=dtype, device=device)
+        dummy2.weight = torch.nn.Parameter(
+            dummy.weight.narrow(1, 0, 1024), requires_grad=False
+        )
+
+        # Quantize the main linear layer
+        quantize_(dummy, config)
+
+        # Shape analysis for TilePackedTo4d format:
+        # Original weight shape: (2048, 2048) -> no padding needed (already multiple of 1024)
+        # n = 2048, k = 2048, inner_k_tiles = 8, group_size = 128
+        #
+        # qdata shape: [n/8, k/(inner_k_tiles*16), 32, inner_k_tiles/2]
+        #   = [2048/8, 2048/(8*16), 32, 8/2]
+        #   = [256, 16, 32, 4]
+        #
+        # scale_and_zero shape: [in_features/group_size, out_features, 2] (packed format)
+        #   = [2048/128, 2048, 2] = [16, 2048, 2]
+
+        # Test slicing along output dimension (dim=0: 2048 -> 1024)
+        weight1 = dummy.weight.narrow(0, 0, 1024)
+
+        # qdata slicing: narrow from [256, 16, 32, 4] to [128, 16, 32, 4]
+        # Calculation: 1024 out_features / 2048 total
* 256 qdata_dim0 = 128 + expected_qdata_slice_0 = dummy.weight.qdata.narrow(0, 0, 128) + self.assertEqual(weight1.qdata, expected_qdata_slice_0) + + # scale_and_zero slicing: narrow from [16, 2048, 2] to [16, 1024, 2] + # slicing 0th dim of qdata means we have to slice 1th dim of scale_and_zero + expected_scale_zero_slice_0 = dummy.weight.scale_and_zero.narrow(1, 0, 1024) + self.assertEqual(weight1.scale_and_zero, expected_scale_zero_slice_0) + + # Test slicing along input dimension (dim=1: 2048 -> 1024) + weight2 = dummy.weight.narrow(1, 0, 1024) + + # qdata slicing: narrow from [256, 16, 32, 4] to [256, 8, 32, 4] + # k = 2048 + # Calculation: 1024 in_features (1/2 of in_features) corresponds to 1/2 of qdata dimension 1 + # which is k / (inner_k_tiles * 16) / 2 = 2048 / (8 * 16) / 2 = 8 + expected_qdata_slice_1 = dummy.weight.qdata.narrow(1, 0, 8) + self.assertEqual(weight2.qdata, expected_qdata_slice_1) + + # scale_and_zero slicing: narrow from [16, 2048, 2] to [8, 2048, 2] + expected_scale_zero_slice_1 = dummy.weight.scale_and_zero.narrow(0, 0, 8) + self.assertEqual(weight2.scale_and_zero, expected_scale_zero_slice_1) + + # Verify that sliced weights produce similar results to reference implementations + input1 = torch.randn(2, 2048, dtype=dtype, device=device) + res_ref1 = dummy1(input1) + + # Create a new linear layer with the sliced weight + test_linear1 = torch.nn.Linear( + 2048, 1024, bias=False, dtype=dtype, device=device + ) + test_linear1.weight = torch.nn.Parameter( + weight1.contiguous(), requires_grad=False + ) + res1 = test_linear1(input1) + self.assertGreater(compute_error(res_ref1, res1), 14) + + input2 = torch.randn(2, 1024, dtype=dtype, device=device) + res_ref2 = dummy2(input2) + + # Create a new linear layer with the sliced weight + test_linear2 = torch.nn.Linear( + 1024, 2048, bias=False, dtype=dtype, device=device + ) + test_linear2.weight = torch.nn.Parameter( + weight2.contiguous(), requires_grad=False + ) + res2 = test_linear2(input2) + self.assertGreater(compute_error(res_ref2, res2), 14) + + def test_slice_preserves_aliasing(self): + config = INT4_CONFIG + l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) + l.weight = torch.nn.Parameter( + torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") + ) + quantize_(l, config) + param = l.weight + param_data = param.data + param_data = param_data.narrow(0, 0, 512) + # Making sure the aliasing is preserved in sliced quantized Tensor + assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr() + assert ( + param.data.scale_and_zero.data_ptr() == param_data.scale_and_zero.data_ptr() + ) + + def test_cant_initialize_in_cpu(self): + config = INT4_CONFIG + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + # make sure there is no cpu implementation of the packing op currently + with self.assertRaisesRegex( + NotImplementedError, + "Could not run 'aten::_convert_weight_to_int4pack' with arguments from the 'CPU' backend. 
", + ): + quantize_(linear, config) + + def test_to_device(self): + # test calling to on the tensor that's already on the same device works + config = INT4_CONFIG + + for device in self.GPU_DEVICES: + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device) + quantize_(linear, config) + linear.to(device) + + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device) + quantize_(linear, config) + linear.to(device=device) + + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device) + quantize_(linear, config) + linear.to(device) + + def test_slice_and_copy_similar_to_vllm(self): + self._test_slice_and_copy_similar_to_vllm(INT4_CONFIG) + + @parametrize("device", ["cuda"]) + @parametrize("dtype", [torch.bfloat16]) + def test_mm_int4wo(self, device, dtype): + weight = torch.randn(512, 1024).to(device).to(dtype) + weight = weight.t() + + l = torch.nn.Linear(512, 1024).to(device).to(dtype) + l.weight = torch.nn.Parameter(weight) + quantize_(l, INT4_CONFIG) + # weight shape: 1024 x 512 + weight = l.weight + + input = torch.randn(1, 512, device=device, dtype=dtype) + # make sure it runs + torch.nn.functional.linear(input, weight) + + @parametrize("group_size", [32, 64, 128]) + def test_different_group_sizes(self, group_size): + """Test with different group sizes""" + dtype = torch.bfloat16 + device = "cuda" + hp_tensor = torch.randn(256, 512, dtype=dtype, device=device) + block_size = (1, group_size) + + tensor = Int4TilePackedTo4dTensor.from_hp(hp_tensor, block_size) + + self.assertEqual(tensor.shape, hp_tensor.shape) + self.assertEqual(tensor.block_size, block_size) + + def test_error_conditions(self): + """Test various error conditions""" + dtype = torch.bfloat16 + device = "cuda" + hp_tensor = torch.randn(128, 256, dtype=dtype, device=device) + + # Test invalid block_size length + with self.assertRaises(AssertionError): + Int4TilePackedTo4dTensor.from_hp( + hp_tensor, (64,) + ) # block_size length mismatch + + # Test non-groupwise quantization + with self.assertRaises(AssertionError): + Int4TilePackedTo4dTensor.from_hp( + hp_tensor, (2, 64) + ) # first element should be 1 + + +instantiate_parametrized_tests(TestInt4TilePackedTo4dTensor) + + +if __name__ == "__main__": + run_tests() diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 90e42747b4..ab49fb1d12 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -94,6 +94,7 @@ Int4OpaqueTensor, Int4PreshuffledTensor, Int4Tensor, + Int4TilePackedTo4dTensor, IntxOpaqueTensor, IntxUnpackedToInt8Tensor, ) @@ -166,6 +167,7 @@ "Int4MarlinSparseTensor", "IntxOpaqueTensor", "IntxUnpackedToInt8Tensor", + "Int4TilePackedTo4dTensor", "Float8Tensor", "Int4OpaqueTensor", # smooth quant - subject to change diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 798ff2efd9..a5e86cf726 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -76,6 +76,7 @@ Int4OpaqueTensor, Int4PreshuffledTensor, Int4Tensor, + Int4TilePackedTo4dTensor, IntxOpaqueTensor, IntxUnpackedToInt8Tensor, QuantizeTensorToFloat8Kwargs, @@ -1142,6 +1143,12 @@ def _int4_weight_only_quantize_tensor(weight, config): block_size, ) return new_weight + elif packing_format == PackingFormat.TILE_PACKED_TO_4D: + new_weight = Int4TilePackedTo4dTensor.from_hp( + weight, + block_size, + ) + return new_weight else: raise ValueError(f"Unsupported packing format: {packing_format}") @@ -1516,10 +1523,12 @@ def 
int8_dynamic_activation_int8_semi_sparse_weight(): Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight quantization + 2:4 sparsity to linear layers. """ - warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead. + warnings.warn( + """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead. from torchao.dtypes import SemiSparseLayout - int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""") + int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""" + ) return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) @@ -2095,7 +2104,10 @@ def __post_init__(self): assert self.granularity.axis == 0, ( f"axis must be 0 with PerAxis, but got {self.granularity.axis}" ) - assert self.mapping_type in [MappingType.ASYMMETRIC, MappingType.SYMMETRIC], ( + assert self.mapping_type in [ + MappingType.ASYMMETRIC, + MappingType.SYMMETRIC, + ], ( f"mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.mapping_type}" ) diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py index ba969fff00..788e554692 100644 --- a/torchao/quantization/quantize_/common/packing_format.py +++ b/torchao/quantization/quantize_/common/packing_format.py @@ -41,6 +41,11 @@ class PackingFormat(str, Enum): """ UNPACKED_TO_INT8 = "unpacked_to_int8" + """ + tile_packed_to_4d is referring to the format used by tinygemm kernels for int4 quantization + """ + TILE_PACKED_TO_4D = "tile_packed_to_4d" + """ Opaque packing format that's used for tensors that does not have a predefined packing format (that may be decided on hardware, tensor shape, library availability etc.) and it's not diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 863608050e..fb4c6bcc11 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -14,6 +14,7 @@ from .int4.int4_tensor import ( Int4Tensor, ) +from .int4.int4_tile_packed_to_4d_tensor import Int4TilePackedTo4dTensor from .intx.intx_opaque_tensor import ( IntxOpaqueTensor, ) @@ -25,6 +26,7 @@ "Int4Tensor", "Int4PreshuffledTensor", "Int4MarlinSparseTensor", + "Int4TilePackedTo4dTensor", "Float8Tensor", "QuantizeTensorToFloat8Kwargs", "IntxOpaqueTensor", diff --git a/torchao/quantization/quantize_/workflows/int4/__init__.py b/torchao/quantization/quantize_/workflows/int4/__init__.py index 3394822214..e69de29bb2 100644 --- a/torchao/quantization/quantize_/workflows/int4/__init__.py +++ b/torchao/quantization/quantize_/workflows/int4/__init__.py @@ -1,7 +0,0 @@ -from .int4_preshuffled_tensor import Int4PreshuffledTensor -from .int4_tensor import Int4Tensor - -__all__ = [ - "Int4PreshuffledTensor", - "Int4Tensor", -] diff --git a/torchao/quantization/quantize_/workflows/int4/int4_tile_packed_to_4d_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_tile_packed_to_4d_tensor.py new file mode 100644 index 0000000000..f7237932df --- /dev/null +++ b/torchao/quantization/quantize_/workflows/int4/int4_tile_packed_to_4d_tensor.py @@ -0,0 +1,312 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List + +import torch + +from torchao.utils import TorchAOBaseTensor, fill_defaults, find_multiple + +__all__ = [ + "Int4TilePackedTo4dTensor", +] + +aten = torch.ops.aten + + +class Int4TilePackedTo4dTensor(TorchAOBaseTensor): + """ + int4 quantization with tile packed to 4d packing format for groupwise quantization + + Tensor Attributes: + qdata: tile packed to 4d int4 weight, 4-d tensor of dimension: + [n / 8][k / (inner_k_tiles * 16)][32][inner_k_tiles / 2] + (unpacked Tensor shape is n * k) + (inner_k_tiles is fixed to 8 for Int4TilePackedTo4dTensor) + scale_and_zero: combined scale and zero point tensor packed for tinygemm kernels + + Non-Tensor Attributes: + block_size: the block size for quantization, representing the granularity, + for example groupwise quantization will have block_size (1, group_size) + shape: shape of the original Tensor + + Note on Details for tile packed to 4d packing format: + + This is used by tinygemm kernels `_weight_int4pack_mm`. The weight is stored as + a 4-d packed tensor with specific packing format for efficient computation on tensor cores. + The packing format optimizes for tensor core matrix multiplication performance. + """ + + tensor_data_names = ["qdata", "scale_and_zero"] + tensor_attribute_names = ["block_size", "shape"] + + def __new__( + cls, + qdata: torch.Tensor, + scale_and_zero: torch.Tensor, + block_size: List[int], + shape: torch.Size, + ): + kwargs = {} + kwargs["device"] = qdata.device + kwargs["dtype"] = torch.bfloat16 # This tensor subclass only supports bfloat16 + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + qdata: torch.Tensor, + scale_and_zero: torch.Tensor, + block_size: List[int], + shape: torch.Size, + ): + self.qdata = qdata + self.scale_and_zero = scale_and_zero + self.block_size = block_size + + def _quantization_type(self): + return f"shape={self.shape}, block_size={self.block_size}, device={self.device}" + + @classmethod + def from_hp( + cls, + hp_tensor: torch.Tensor, + block_size: List[int], + ): + assert len(block_size) == hp_tensor.ndim, ( + f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {hp_tensor.ndim=}" + ) + + assert all(x == 1 for x in block_size[:-1]), ( + f"Only per group quantization is supported, got block_size: {block_size}" + ) + + assert hp_tensor.dtype == torch.bfloat16, ( + f"Only bfloat16 is supported for Int4TilePackedTo4dTensor, got {hp_tensor.dtype}" + ) + + original_shape = hp_tensor.shape + # use a fixed inner_k_tiles value to simplify the argument list and config + # for Int4TilePackedTo4dTensor + inner_k_tiles = 8 + + # Validate kernel requirements + orig_out_features, orig_in_features = hp_tensor.shape[-2:] + # TODO: relax checks to enable quantizing in other platoforms and run in A100 + if not torch.cuda.get_device_capability()[0] >= 8: + raise ValueError( + f"Cannot use tinygemm int4 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for tensor core kernels." 
+ ) + + # Pre-process: pad to required dimensions + in_features = find_multiple(orig_in_features, 1024) + out_features = find_multiple(orig_out_features, 8) + hp_tensor_padded = torch.nn.functional.pad( + hp_tensor, + (0, in_features - orig_in_features, 0, out_features - orig_out_features), + ) + + # Quantize + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + + from torchao.quantization.quant_primitives import ( + MappingType, + _choose_qparams_affine_tinygemm, + _quantize_affine_tinygemm, + ) + + # Calculate scale and zero_point for tinygemm + scale, zero_point = _choose_qparams_affine_tinygemm( + hp_tensor_padded, + mapping_type=MappingType.ASYMMETRIC, + block_size=tuple(block_size), + target_dtype=target_dtype, + quant_min=quant_min, + quant_max=quant_max, + scale_dtype=hp_tensor.dtype, + zero_point_dtype=hp_tensor.dtype, + ) + + # Quantize for tinygemm + int_data = _quantize_affine_tinygemm( + hp_tensor_padded, + block_size, + scale, + zero_point, + target_dtype, + quant_min=quant_min, + quant_max=quant_max, + ) + + # Convert to packed format + def quant_2d(int_data_2d): + int_data_2d = (int_data_2d[::, ::2] << 4 | int_data_2d[::, 1::2]).to( + torch.uint8 + ) + return torch.ops.aten._convert_weight_to_int4pack( + int_data_2d.contiguous(), inner_k_tiles + ) + + if int_data.dim() == 3: # for moe quant + num_experts = int_data.shape[0] + packed_weight_list = [] + for expert in range(num_experts): + packed_weight_list.append(quant_2d(int_data[expert]).unsqueeze(0)) + packed_weight = torch.cat(packed_weight_list, dim=0) + scale = scale.reshape(int_data.shape[0], int_data.shape[-2], -1) + zero_point = ( + zero_point.reshape(int_data.shape[0], int_data.shape[-2], -1) + if zero_point is not None + else None + ) + else: + assert int_data.dim() == 2 + packed_weight = quant_2d(int_data) + scale = scale.reshape(int_data.shape[0], -1) + zero_point = ( + zero_point.reshape(int_data.shape[0], -1) + if zero_point is not None + else None + ) + + from torchao.quantization.utils import pack_tinygemm_scales_and_zeros + + scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype) + + return cls( + qdata=packed_weight, + scale_and_zero=scale_and_zero, + block_size=block_size, + shape=original_shape, + ) + + +implements = Int4TilePackedTo4dTensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + + assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous" + assert weight_tensor.scale_and_zero.is_contiguous(), ( + "Expected scale_and_zero to be contiguous" + ) + + assert weight_tensor.block_size[0] == 1, ( + f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + ) + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"need input_tensor shape: {input_tensor.shape} final" + f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " + ) + + # weight is packed from padded (out_features, in_features) weight tensor + # (same dimension requirement as F.linear weight) + packed_weight = weight_tensor.qdata + scale_and_zero = weight_tensor.scale_and_zero + original_shape = weight_tensor.shape + + orig_act_size = input_tensor.size() + orig_dtype = input_tensor.dtype + + # Folds batch dimension into the first dimension + act_mat = input_tensor.reshape(-1, input_tensor.shape[-1]).to(torch.bfloat16) + pad_size = find_multiple(act_mat.shape[-1], 1024) + 
act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) + + # groupwise int4 quantization + groupsize = weight_tensor.block_size[-1] + if act_mat.numel() == 0: # handling for empty input + y = act_mat + else: + y = torch.ops.aten._weight_int4pack_mm( + act_mat, packed_weight, groupsize, scale_and_zero + ) + # remove out_feature padding + orig_out_features = original_shape[-2] + y = y[:, :orig_out_features] + + # Unfold the batch dimension + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + if bias is not None: + y += bias.to(y.dtype) + return y.to(orig_dtype) + + +@implements(aten.slice.Tensor) +def _(func, _types, args, _kwargs): + """Slice operation for tensor core tiled packed tensor""" + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + cur_shape = self.shape + + assert len(cur_shape) == 2 + assert self.qdata.dim() == 4 + # qdata has shape [n/8, k/(inner_k_tiles*16), 32, inner_k_tiles/2] + n_by_8, k_by_inner_tiles, _, _ = self.qdata.shape + sz_dim1, sz_dim0, _ = self.scale_and_zero.shape + + data_len = cur_shape[dim] + assert dim in [ + 0, + 1, + ], ( + f"Int4TilePackedTo4dTensor slice: attempting to run {func}, with dim={dim}, that is not supported" + ) + + if dim == 0: + pw_len = n_by_8 + sz_len = sz_dim0 + else: + pw_len = k_by_inner_tiles + sz_len = sz_dim1 + + if pw_len == 0 or sz_len == 0: + return Int4TilePackedTo4dTensor( + self.qdata, + self.scale_and_zero, + self.block_size, + self.shape, + ) + + pw_ratio = data_len / pw_len + start_pw = int(start / pw_ratio) + end_pw = int(end / pw_ratio) + + sz_ratio = data_len / sz_len + start_sz = int(start / sz_ratio) + end_sz = int(end / sz_ratio) + + qdata = aten.slice(self.qdata, dim, start_pw, end_pw, step) + scale_and_zero = aten.slice(self.scale_and_zero, 1 - dim, start_sz, end_sz, step) + + # Calculate new shape after slicing + new_shape = list(self.shape) + new_shape[dim] = end - start + + block_size = list(self.block_size) + block_size[dim] = min(block_size[dim], new_shape[dim]) + + return Int4TilePackedTo4dTensor( + qdata, + scale_and_zero, + block_size, + new_shape, + ) + + +Int4TilePackedTo4dTensor.__module__ = "torchao.quantization" + +# Allow a model with Int4TilePackedTo4dTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int4TilePackedTo4dTensor]) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 33def3f998..762fb31b30 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -455,18 +455,23 @@ def _test_slice_and_copy_similar_to_vllm(self, config): param = l.weight param_data = param.data param_data = param_data.narrow(output_dim, start_idx, shard_size) - orig_value = param_data.qdata[0][0].item() + orig_value = param_data.qdata[0][0] loaded_weight = dummy_l.weight loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0] - assert orig_value != loaded_weight.qdata[0][0] + assert not torch.equal(orig_value, loaded_weight.qdata[0][0]) param_data.copy_(loaded_weight) # making sure param.data is updated to loaded_weight - assert param_data.qdata[0][0] == loaded_weight.qdata[0][0] - assert torch.equal(param_data.scale, loaded_weight.scale) + assert torch.equal(param_data.qdata[0][0], loaded_weight.qdata[0][0]) + if hasattr(param_data, "scale"): + assert torch.equal(param_data.scale, loaded_weight.scale) if hasattr(param_data, "zero_point"): assert torch.equal(param_data.zero_point, 
loaded_weight.zero_point) + if hasattr(param_data, "scale_and_zero"): + assert torch.equal( + param_data.scale_and_zero, loaded_weight.scale_and_zero + ) def _test_moe_weight_reshape_ops(self, config): """This is testing the op call sequence in saving and loading quantization