diff --git a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
new file mode 100644
index 0000000000..d7d793685e
--- /dev/null
+++ b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import tempfile
+import unittest
+
+import torch
+from torch.testing._internal.common_utils import (
+    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
+
+from torchao.quantization import (
+    Int4WeightOnlyConfig,
+    quantize_,
+)
+from torchao.quantization.utils import compute_error
+from torchao.utils import (
+    torch_version_at_least,
+)
+
+
+def get_config(group_size):
+    return Int4WeightOnlyConfig(
+        group_size=group_size,
+        packing_format="plain_int32",
+        version=2,
+    )
+
+
+@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
+@unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
+class Int4PlainInt32Tensor(TestCase):
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 512, 128),
+            ((2, 32, 128), 256, 12),
+        ],
+    )
+    @parametrize("dtype", [torch.bfloat16, torch.half])
+    @parametrize("group_size", [32, 64, 128])
+    def test_linear(self, sizes, dtype, group_size):
+        device = "xpu"
+        M, N, K = sizes
+        input = torch.randn(*M, K, dtype=dtype, device=device)
+        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        original = linear(input)
+        quantize_(linear, get_config(group_size))
+        quantized = linear(input)
+        self.assertTrue(compute_error(original, quantized) > 20)
+
+        compiled_linear = torch.compile(linear)
+        quantized_and_compiled = compiled_linear(input)
+        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)
+
+    @parametrize("dtype", [torch.bfloat16, torch.half])
+    def test_module_path(self, dtype):
+        linear = torch.nn.Linear(128, 256, dtype=dtype, device="xpu")
+        quantize_(linear, get_config(group_size=128))
+        self.assertEqual(
+            str(type(linear.weight)),
+            "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
+        )
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
+            )
+
+
+instantiate_parametrized_tests(Int4PlainInt32Tensor)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index ab49fb1d12..407a83bcd7 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -92,6 +92,7 @@
     Float8Tensor,
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
+    Int4PlainInt32Tensor,
     Int4PreshuffledTensor,
     Int4Tensor,
     Int4TilePackedTo4dTensor,
@@ -163,6 +164,7 @@
     "FbgemmConfig",
     # tensor subclasses
     "Int4Tensor",
+    "Int4PlainInt32Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
     "IntxOpaqueTensor",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index e83abd3953..682d07a2b1 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -74,6 +74,7 @@
     Float8Tensor,
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
+    Int4PlainInt32Tensor,
     Int4PreshuffledTensor,
     Int4Tensor,
     Int4TilePackedTo4dTensor,
@@ -522,7 +523,6 @@ def quantize_(
     torch._C._log_api_usage_once("torchao.quantization.quantize_")
     filter_fn = _is_linear if filter_fn is None else filter_fn
-
     if isinstance(config, ModuleFqnToConfig):
         _replace_with_custom_fn_if_matches_filter_with_name(
             model,
@@ -1131,6 +1131,12 @@ def _int4_weight_only_quantize_tensor(weight, config):
             block_size,
         )
         return new_weight
+    elif packing_format == PackingFormat.PLAIN_INT32:
+        new_weight = Int4PlainInt32Tensor.from_hp(
+            weight,
+            block_size,
+        )
+        return new_weight
     elif packing_format == PackingFormat.MARLIN_SPARSE:
         new_weight = Int4MarlinSparseTensor.from_hp(
             weight,
diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py
index 788e554692..94d45917b9 100644
--- a/torchao/quantization/quantize_/common/packing_format.py
+++ b/torchao/quantization/quantize_/common/packing_format.py
@@ -41,6 +41,12 @@ class PackingFormat(str, Enum):
     """
     UNPACKED_TO_INT8 = "unpacked_to_int8"
 
+    """
+    plain_int32 refers to the format used by int4 weight-only quantization,
+    a groupwise format where 2*int4 are stored in a byte and 4*(int4*2) are stored in an int32.
+    """
+    PLAIN_INT32 = "plain_int32"
+
     """
     tile_packed_to_4d is referring to the format used by tinygemm kernels for int4 quantization
     """
diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py
index fb4c6bcc11..3402ffdefa 100644
--- a/torchao/quantization/quantize_/workflows/__init__.py
+++ b/torchao/quantization/quantize_/workflows/__init__.py
@@ -8,6 +8,9 @@
 from .int4.int4_opaque_tensor import (
     Int4OpaqueTensor,
 )
+from .int4.int4_plain_int32_tensor import (
+    Int4PlainInt32Tensor,
+)
 from .int4.int4_preshuffled_tensor import (
     Int4PreshuffledTensor,
 )
@@ -26,6 +29,7 @@
     "Int4Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
+    "Int4PlainInt32Tensor",
     "Int4TilePackedTo4dTensor",
     "Float8Tensor",
     "QuantizeTensorToFloat8Kwargs",
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
new file mode 100644
index 0000000000..388134f040
--- /dev/null
+++ b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py
@@ -0,0 +1,182 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import List
+
+import torch
+
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    choose_qparams_affine,
+    quantize_affine,
+)
+from torchao.utils import (
+    TorchAOBaseTensor,
+)
+
+__all__ = [
+    "Int4PlainInt32Tensor",
+]
+
+aten = torch.ops.aten
+
+
+class Int4PlainInt32Tensor(TorchAOBaseTensor):
+    """
+    int4 weight-only quantization on XPU with oneDNN as backend (groupwise quantization only)
+
+    Tensor Attributes:
+        qdata: (N, K/8), packed int4 weight; the data type is int32 here, each element holding 4*(int4*2); the original data type can be half or bfloat16
+        scale: (K/group_size, N), dtype is the same as the original Tensor dtype
+        zero_point: (K/group_size, N), dtype is int8
+
+    Non-Tensor Attributes:
+        block_size: the block size for quantization, representing the granularity.
+        shape: shape of the original Tensor
+
+    """
+
+    tensor_data_names = ["qdata", "scale", "zero_point"]
+    tensor_attribute_names = ["block_size", "shape"]
+
+    def __new__(
+        cls,
+        qdata,
+        scale,
+        zero_point,
+        block_size,
+        shape,
+    ):
+        kwargs = {}
+        kwargs["device"] = qdata.device
+        kwargs["dtype"] = scale.dtype
+        kwargs["requires_grad"] = False
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(self, qdata, scale, zero_point, block_size, shape):
+        self.qdata = qdata
+        self.scale = scale
+        self.zero_point = zero_point
+        self.block_size = block_size
+
+    def _quantization_type(self):
+        return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"
+
+    @classmethod
+    def from_hp(
+        cls,
+        w: torch.Tensor,
+        block_size: List[int],
+    ):
+        assert w.ndim == 2 and w.device.type == "xpu", (
+            f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}"
+        )
+        assert len(block_size) == w.ndim
+        assert w.dtype in [torch.float16, torch.bfloat16], (
+            f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}"
+        )
+        original_shape = w.shape
+        mapping_type = MappingType.ASYMMETRIC
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        eps = 1e-6
+        scale_dtype = None
+        zero_point_dtype = torch.int32
+        scale, zero_point = choose_qparams_affine(
+            w,
+            mapping_type,
+            block_size,
+            target_dtype,
+            quant_min,
+            quant_max,
+            eps,
+            scale_dtype,
+            zero_point_dtype,
+        )
+        int_data = quantize_affine(
+            w,
+            block_size,
+            scale,
+            zero_point,
+            target_dtype,
+            quant_min,
+            quant_max,
+        )
+        assert int_data.dtype == torch.int32, (
+            "torch.ops.aten._convert_weight_to_int4pack expects `int32` dtype"
+        )
+        packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
+        packed_weight = torch.ops.aten._convert_weight_to_int4pack(
+            packed_weight.contiguous(), 8
+        )
+        scale = scale.reshape(int_data.shape[0], -1)
+        zero_point = zero_point.reshape(int_data.shape[0], -1)
+        return Int4PlainInt32Tensor(
+            packed_weight,
+            scale.transpose(0, 1).contiguous(),
+            zero_point.transpose(0, 1).contiguous().to(torch.int8),
+            block_size,
+            original_shape,
+        )
+
+
+implements = Int4PlainInt32Tensor.implements
+
+
+@implements([torch.nn.functional.linear, aten.linear.default])
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    assert input_tensor.device.type == "xpu", (
+        f"For XPU device only but got: {input_tensor.device}"
+    )
+    assert isinstance(weight_tensor, Int4PlainInt32Tensor), (
+        f"Expected weight_tensor to be Int4PlainInt32Tensor, got: {type(weight_tensor)}"
+    )
+    assert weight_tensor.block_size[0] == 1, (
+        f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
+    )
+    assert input_tensor.shape[-1] == weight_tensor.shape[1], (
+        f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}"
+    )
+
+    act_mat = input_tensor
+    packed_weight = weight_tensor.qdata
+    scale = weight_tensor.scale
+    zero_point = weight_tensor.zero_point
+
+    orig_act_size = act_mat.size()
+    orig_dtype = act_mat.dtype
+
+    # reshape to 2D
+    act_mat = act_mat.reshape(-1, act_mat.shape[-1])
+
+    # groupwise int4 quantization
+    groupsize = weight_tensor.block_size[1]
+    y = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
+        act_mat, packed_weight, groupsize, scale, zero_point
+    )
+
+    # remove out_feature padding
+    assert weight_tensor.ndim == 2
+    orig_out_features = weight_tensor.shape[-2]
+    y = y[:, :orig_out_features]
+    y = y.reshape(*orig_act_size[:-1], orig_out_features)
+
+    if bias is not None:
+        y += bias
+    return y.to(orig_dtype)
+
+
+Int4PlainInt32Tensor.__module__ = "torchao.quantization"
+
+# Allow a model with Int4PlainInt32Tensor weights to be loaded with `weights_only=True`
+torch.serialization.add_safe_globals([Int4PlainInt32Tensor])
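
Note for reviewers (not part of the patch): a minimal usage sketch of the new path, mirroring the tests above. It assumes an XPU device and the version=2 Int4WeightOnlyConfig added in this PR.

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# quantize a bf16 linear layer on XPU to the plain_int32 int4 format
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="xpu")
quantize_(
    linear,
    Int4WeightOnlyConfig(group_size=128, packing_format="plain_int32", version=2),
)

x = torch.randn(32, 128, dtype=torch.bfloat16, device="xpu")
y = linear(x)  # dispatches to aten._weight_int4pack_mm_with_scales_and_zeros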
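
Also not part of the patch: a toy sketch of the byte-level packing that from_hp performs before handing the weight to torch.ops.aten._convert_weight_to_int4pack; the example tensor and printed values are illustrative only.

import torch

# eight int4 values; from_hp packs pairs of columns into one uint8:
# even columns go to the low nibble, odd columns to the high nibble
int_data = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.int32)
packed = (int_data[:, 1::2] << 4 | int_data[:, ::2]).to(torch.uint8)
print(packed)  # tensor([[ 33,  67, 101, 135]], dtype=torch.uint8), i.e. 0x21, 0x43, 0x65, 0x87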