From 13d643a2ebc1ae96569b009e8c8c5de1358d22d1 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 7 May 2024 13:38:44 -0700
Subject: [PATCH] Enable dispatch to tinygemm int4 and int8 kernels for
 unified quantized tensor

Summary:
This adds dispatch to the tinygemm int4 kernel on cuda and the int8 weight-only
kernel on cpu for the unified quantized tensor, although we still need to
resolve the implementation mismatch problem for tinygemm first.

Test Plan:
TODO

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/quantization/test_quant_api.py        | 93 ++++++++++++++++++--
 test/quantization/test_quant_primitives.py |  2 +-
 torchao/quantization/subclass.py           | 98 +++++++++++++++++++++-
 3 files changed, 180 insertions(+), 13 deletions(-)

diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 10d36f0c1b..6a15533b14 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -9,7 +9,6 @@
 import unittest
 import torch
 import os
-from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import (
     prepare_pt2e,
     convert_pt2e,
@@ -36,7 +35,7 @@
 
 
 def dynamic_quant(model, example_inputs):
-    m = capture_pre_autograd_graph(model, example_inputs)
+    m = torch.export.export(model, example_inputs).module()
     quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
     m = prepare_pt2e(m, quantizer)
     m = convert_pt2e(m)
@@ -50,14 +49,14 @@ def _apply_dynamic_quant(model):
     """
     _replace_with_custom_fn_if_matches_filter(
         model,
-        lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features))),
+        lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features),)),
         lambda mod, fqn: isinstance(mod, torch.nn.Linear),
     )
     return model
 
 
 def capture_and_prepare(model, example_inputs):
-    m = capture_pre_autograd_graph(model, example_inputs)
+    m = torch.export.export(model, example_inputs).module()
     quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
     m = prepare_pt2e(m, quantizer)
     # TODO: we can run the weight observer in convert_pt2e so that user don't need to run this
@@ -88,13 +87,13 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
         return model
 
 class ToyLinearModel(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, m=64, n=32, k=64):
         super().__init__()
-        self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
-        self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
+        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
+        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
 
     def example_inputs(self):
-        return (torch.randn(1, 64).to(torch.float),)
+        return (torch.randn(1, self.linear1.in_features).to(torch.float),)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -104,8 +103,10 @@ def forward(self, x):
 class TestQuantFlow(unittest.TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()
+        example_inputs = m.example_inputs()
         m = _apply_dynamic_quant(m)
-        quantized = m(*m.example_inputs())
+        example_inputs = (torch.randn(1, 64),)
+        quantized = m(*example_inputs)
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
         # m = torch.compile(m, mode="max-autotune")
@@ -442,7 +443,81 @@ def get_per_token_block_size(x):
         ref = m_copy(*example_inputs)
         self.assertTrue(torch.equal(res, ref))
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    def test_quantized_tensor_subclass_int4(self):
+        from torchao.quantization.subclass import TinygemmAffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        import copy
+
+        # weight settings
+        groupsize = 32
+        mapping_type = MappingType.ASYMMETRIC
+        block_size = (1, groupsize)
+        eps = 1e-6
+        preserve_zero = False
+
+        # weight only quantization
+        input_quant_func = None
+
+        # use 1024 so that we don't need padding
+        m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
+        m_copy = copy.deepcopy(m)
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
+
+        def to_quantized(weight):
+            return TinygemmAffineQuantizedTensor.from_float(weight, mapping_type, block_size, eps, input_quant_func=input_quant_func)
+
+        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
+        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
+        assert isinstance(m.linear1.weight, TinygemmAffineQuantizedTensor)
+        assert isinstance(m.linear2.weight, TinygemmAffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
+        change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)
+
+        torch.testing.assert_close(res, ref, rtol=0.00001, atol=0.02)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    def test_quantized_tensor_subclass_int8(self):
+        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        import copy
+
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+
+        # weight only quantization
+        input_quant_func = None
+
+        m = ToyLinearModel().eval().to(torch.bfloat16)
+        m_copy = copy.deepcopy(m)
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
+
+        def to_quantized(weight):
+            block_size = (1, weight.shape[1])
+            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, input_quant_func=input_quant_func)
+
+        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
+        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
+        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
+        change_linear_weights_to_int8_woqtensors(m_copy)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)
+        torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)
 
 
 if __name__ == "__main__":
diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 291039e42a..5adbb183b7 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -353,7 +353,7 @@ def test_tinygemm_get_groupwise_affine_qparams(self):
             preserve_zero=False,
         )
 
-        def int_zero_point_to_float(zero_point, scale, qaunt_min, mid_point):
+        def int_zero_point_to_float(zero_point, scale, quant_min, mid_point):
             return (quant_min - zero_point + mid_point) * scale
 
         mid_point = 2 ** (n_bit - 1)
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index 6128720d4d..3707ab953d 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -14,7 +14,9 @@
     dynamically_quantize_per_channel,
     groupwise_affine_quantize_tensor,
     quant_int8_dynamic_per_token_linear,
+    pack_tinygemm_scales_and_zeros,
     unpack_tinygemm_scales_and_zeros,
+    groupwise_affine_quantize_tensor_from_qparams,
     choose_qparams_affine,
     quantize_affine,
     dequantize_affine,
@@ -619,7 +621,7 @@ class AffineQuantizedTensor(torch.Tensor):
       shape (torch.Size): the shape for the Tensor
       quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
       quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
-      input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes input Tensor as input and outputs an AffineQuantizedTensor object
+      input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object
       dtype: dtype for external representation of the tensor, e.g. torch.float32
     """
 
@@ -635,6 +637,7 @@ def __new__(
         quant_max: Optional[int] = None,
         input_quant_func: Optional[Callable] = None,
         dtype=None,
+        # TODO: remove args and kwargs
         *args,
         **kwargs
     ):
@@ -677,7 +680,9 @@ def __repr__(self):
             f"device={self.device}, dtype={self.dtype}, input_quant_func={self.input_quant_func}, requires_grad={self.requires_grad})"
         )
 
-    def dequantize(self, output_dtype=torch.float32):
+    def dequantize(self, output_dtype=None):
+        if output_dtype is None:
+            output_dtype = self.dtype
         return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, output_dtype=output_dtype)
 
     def __tensor_flatten__(self):
@@ -740,7 +745,54 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
                 args[1],
                 args[2] if len(args) > 2 else None,
             )
-            if weight_qtensor.input_quant_func is not None:
+            if weight_qtensor.input_quant_func is None:
+                is_cuda = args[0].is_cuda
+                is_cpu = args[0].device == torch.device("cpu")
+                # weight only quantization
+                is_int8 = (
+                    weight_qtensor.int_data.dtype == torch.int8 and
+                    (weight_qtensor.quant_min is None or weight_qtensor.quant_min == -128) and
+                    (weight_qtensor.quant_max is None or weight_qtensor.quant_max == 127)
+                )
+                is_uint4 = (
+                    weight_qtensor.int_data.dtype == torch.int32 and
+                    weight_qtensor.quant_min == 0 and
+                    weight_qtensor.quant_max == 15
+                )
+
+                # TODO: enable cpu and mps path as well
+                # TODO: make sure weight dimension matches the expectation of the int4mm kernel
+                # TODO: move this to TinygemmAffineQuantizedTensor
+                if (
+                    is_cuda and
+                    is_uint4 and
+                    weight_qtensor.dtype == torch.bfloat16 and
+                    len(weight_qtensor.shape) == 2 and
+                    weight_qtensor.block_size[0] == 1
+                ):
+                    # groupwise int4 quantization
+                    # TODO: currently doing packing on the fly, we'll need to figure out
+                    # the API to do packing before hand
+                    # TODO: expose the arg
+                    innerKTiles = 8
+                    packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles)
+                    scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point)
+                    groupsize = weight_qtensor.block_size[-1]
+                    return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros)
+                elif (
+                    is_cpu and
+                    is_int8 and
+                    len(weight_qtensor.shape) == 2 and
+                    len(weight_qtensor.block_size) == 2 and
+                    weight_qtensor.block_size[0] == 1 and
+                    weight_qtensor.block_size[1] == weight_qtensor.shape[1]
+                ):
+                    # TODO: enable mps path as well
+                    # per channel int8 weight only quantized mm
+                    return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale)
+            else:
+                # dynamic quantization
+                # TODO: enable int8 dynamic quant dispatch
                 input_tensor = weight_qtensor.input_quant_func(input_tensor)
                 input_tensor = input_tensor.dequantize()
             weight_tensor = weight_qtensor.dequantize()
@@ -865,3 +917,43 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 kwargs,
                 args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
             )
+
+
+# TODO: add padding support
+class TinygemmAffineQuantizedTensor(AffineQuantizedTensor):
+    @classmethod
+    def from_float(
+        cls,
+        input_float,
+        mapping_type,
+        block_size,
+        eps = None,
+        scale_dtype = None,
+        zero_point_dtype = None,
+        input_quant_func = None,
+    ):
+        # TODO: replace this with uint4 dtype
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        preserve_zero = False
+        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero)
+        def int_zero_point_to_float(zero_point, scale, quant_min, mid_point):
+            return (quant_min - zero_point + mid_point) * scale
+
+        mid_point = (quant_min + quant_max + 1) / 2
+        zero_point_float = int_zero_point_to_float(zero_point, scale, quant_min, mid_point)
+        n_bit = 4
+        groupsize = block_size[1]
+        int_data = groupwise_affine_quantize_tensor_from_qparams(input_float, scale, zero_point_float, n_bit, groupsize)
+        return cls(
+            int_data,
+            scale,
+            zero_point_float,
+            block_size,
+            input_float.shape,
+            quant_min,
+            quant_max,
+            input_quant_func=input_quant_func,
+            dtype=input_float.dtype
+        )
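
Usage sketch (illustrative, mirroring test_quantized_tensor_subclass_int4 above, not part of the diff): the snippet below shows how the new cuda int4 weight-only path is expected to be hit through a regular nn.Linear call once the weight is swapped for a TinygemmAffineQuantizedTensor. It assumes a CUDA device with bfloat16 support and weight dimensions that are a multiple of the groupsize, since padding is not supported yet.

    import torch
    from torchao.quantization.subclass import TinygemmAffineQuantizedTensor
    from torchao.quantization.quant_primitives import MappingType

    groupsize = 32
    linear = torch.nn.Linear(1024, 1024, bias=False).to(torch.bfloat16).to("cuda")

    # Weight-only quantization (input_quant_func=None), so F.linear should take the
    # aten._weight_int4pack_mm branch added in __torch_function__ above.
    qweight = TinygemmAffineQuantizedTensor.from_float(
        linear.weight, MappingType.ASYMMETRIC, (1, groupsize), eps=1e-6, input_quant_func=None
    )
    linear.weight = torch.nn.Parameter(qweight, requires_grad=False)

    x = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda")
    out = linear(x)  # dispatched through the quantized weight subclass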