From fa4db667139c3f19655913c5f7e39f34c43e2b96 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 28 May 2024 17:38:20 -0700 Subject: [PATCH] Replace implementation for int8 dynamic quantization with call to `quantize` Summary: Previously we added `quantize` as a general API (https://github.com/pytorch/ao/pull/256) for Affine Quantized tensor subclass, and also tensor subclass based dtype conversion in general. The plan is to use this to replace existing quant APIs including int4 weight only, int8 weight only, int8 dynamic quant and 8da4w (for executorch). This PR we started replacing the implementation of int8 dynamic quant API with `quantize` API with affine quantized tensor subclass. We'll make sure the performance does not regress for vit model. Test Plan: TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py reference: elapsed_time: 1.4821058654785155 milliseconds after refactor: elapsed_time: 1.4804757690429688 milliseconds generated code diff: https://gist.github.com/jerryzh168/90c71107a5aaaa5d8dd2170c573e076d Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 1 + test/quantization/test_quant_api.py | 62 +++++++++++++- torchao/dtypes/aqt.py | 118 ++++++++++++++++++--------- torchao/quantization/quant_api.py | 19 +++-- torchao/quantization/subclass.py | 53 ++++++++++-- torchao/quantization/utils.py | 5 +- 6 files changed, 206 insertions(+), 52 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 3bc8ded79..2cd34f427 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1033,6 +1033,7 @@ def _test_lin_weight_subclass_api_impl( @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen") def test_int8_dynamic_quant_subclass_api(self, device, dtype): self._test_lin_weight_subclass_api_impl( change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 35b010783..6cdd9b148 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -118,6 +118,26 @@ def forward(self, x): x = self.linear2(x) return x + +def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): + """ + The deprecated implementation for int8 dynamic quant API, used as a reference for + numerics and performance + """ + from torchao.quantization.quant_api import _in_features_greater_than_16 + from torchao.quantization.quant_api import _is_linear + from torchao.quantization.quant_api import _get_subclass_inserter + from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight + + if filter_fn is None: + filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16( + *args + ) + + _replace_with_custom_fn_if_matches_filter( + model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn + ) + class TestQuantFlow(unittest.TestCase): def test_dynamic_quant_gpu_singleline(self): m = ToyLinearModel().eval() @@ -492,8 +512,8 @@ def test_quantized_tensor_subclass_int8(self): @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_int8_dyn_quant(self): - # use 1024 so that we don't need padding - m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda") + # use multiples of 1024 so that we don't need padding + m = ToyLinearModel(1024, 1024, 2048).eval().to(torch.bfloat16).to("cuda") m_copy = copy.deepcopy(m) # setting batch_size to 20 to be compatible with the kernel example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda") @@ -525,6 +545,44 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self): # make sure it compiles torch._export.aot_compile(m_unwrapped, example_inputs) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int8 dynamic quant implementation") + def test_quantized_tensor_subclass_int8_dyn_quant_perf(self): + m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda") + m_ref = copy.deepcopy(m) + # setting batch_size to 20 to be compatible with the kernel + example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda") + + from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors + change_linear_weights_to_int8_dqtensors(m) + + # reference + _ref_change_linear_weights_to_int8_dqtensors(m_ref) + + res = m(*example_inputs) + ref = m_ref(*example_inputs) + + self.assertTrue(torch.equal(res, ref)) + + # perf comparison + from torchao.utils import benchmark_model + # warmup + WARMUP = 5 + RUNS = 100 + input_tensor = example_inputs[0] + m = torch.compile(m, mode='max-autotune', fullgraph=True) + + benchmark_model(m, WARMUP, input_tensor) + elapsed_time = benchmark_model(m, RUNS, input_tensor) + + m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True) + benchmark_model(m_ref, WARMUP, input_tensor) + ref_elapsed_time = benchmark_model(m_ref, RUNS, input_tensor) + + print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}") + self.assertTrue(elapsed_time < 1.05 * ref_elapsed_time) + if __name__ == "__main__": diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/aqt.py index f4b758ddc..f660a759c 100644 --- a/torchao/dtypes/aqt.py +++ b/torchao/dtypes/aqt.py @@ -177,6 +177,11 @@ def _apply_fn_to_data(self, fn): fn(self.zero_point), ) + def _change_shape(self, shape): + return self.__class__( + self.int_data.view(shape), self.scale, self.zero_point + ) + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): kwargs = {} if kwargs is None else kwargs @@ -186,6 +191,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) + if func is aten.view.default: + assert len(args) == 2 + new = args[0]._change_shape(args[1]) + return return_and_correct_aliasing(func, args, kwargs, new) + raise NotImplementedError( f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported" ) @@ -245,6 +255,7 @@ def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"] + # TODO: fix the unflatten logic return cls(packed_weight, scale_and_zero) def to(self, *args, **kwargs): @@ -470,6 +481,11 @@ def _apply_fn_to_data(self, fn): strides=self.stride(), ) + def _change_shape(self, shape, block_size): + return self.__class__( + self.layout_tensor.view(shape), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride() + ) + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): # Note: we only added cpu path here for 8da4w, this is for executorch, in the future @@ -491,13 +507,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported" ) -@implements_aqt_torch_function(torch.nn.functional.linear) -def functional_linear(*args, **kwargs): - input_tensor, weight_qtensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) +def _quantized_linear_op(input_tensor, weight_qtensor, bias): is_cuda = weight_qtensor.is_cuda is_cpu = weight_qtensor.device == torch.device("cpu") if isinstance(weight_qtensor, AffineQuantizedTensor): @@ -508,14 +518,10 @@ def functional_linear(*args, **kwargs): # if input tensor is quantized, either dispatch to the int8 mm kernel # or just dequantize the input tensor input_is_int8 = _aqt_is_int8_reduced_range(input_tensor) - input_tensor_dtype_is_expected = input_tensor.dtype in [ - torch.float, - torch.bfloat16 - ] if ( is_cuda and input_is_int8 and - input_tensor_dtype_is_expected and + input_tensor.dtype == weight_qtensor.dtype and input_tensor.layout == "plain" and weight_qtensor.layout == "plain" ): @@ -576,45 +582,83 @@ def functional_linear(*args, **kwargs): weight_qtensor.block_size[1] == weight_qtensor.shape[1] and weight_qtensor.layout == "plain" ): - # TODO: enable mps path as well + # TODO: enable cpu and mps efficient path # per channel int8 weight only quantizated mm - return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.layout_tensor.int_data, weight_qtensor.layout_tensor.scale) - else: - weight_tensor = weight_qtensor.dequantize() - return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - else: + w_vals_int8_t = weight_qtensor.layout_tensor.int_data.t().contiguous() + orig_dtype = input_tensor.dtype + y = ( + torch.mm( + input_tensor.reshape(-1, input_tensor.shape[-1]), + w_vals_int8_t.to(input_tensor.dtype), + ) + * weight_qtensor.scale + ) + y = y.reshape(*input_tensor.shape[:-1], y.shape[-1]) + if bias is not None: + y += bias + return y.to(orig_dtype) + + # is_cpu and is_mps only, some issue with is_contiguous() currently + # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_qtensor.layout_tensor.scale) + + raise NotImplementedError("No specialized dispatch found for quantized linear op") + + +@implements_aqt_torch_function(torch.nn.functional.linear) +def functional_linear(*args, **kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + # using try/except here so that we can have a general fallback when input_tensor/weight_tensor + # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to + # make the branches easier to understand in `_quantized_linear_op` + try: + return _quantized_linear_op(input_tensor, weight_tensor, bias) + except: if isinstance(input_tensor, AffineQuantizedTensor): input_tensor = input_tensor.dequantize() + if isinstance(weight_tensor, AffineQuantizedTensor): + weight_tensor = weight_tensor.dequantize() return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - @implements_aqt_aten_ops([aten.mm.default, aten.addmm.default]) def aten_mm(func, *args, **kwargs): if not args[0].is_floating_point(): raise NotImplementedError(f"{func} is not implemented for non floating point input") + # using try/except here so that we can have a general fallback when input_tensor/weight_tensor + # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to + # make the branches easier to understand in `_quantized_linear_op` if func == aten.addmm.default: - assert args[1].shape[-1] == args[2].shape[0], ( - f"need mat1 shape: {args[1].shape} final" - f"dim to match mat2 shape: {args[2].shape} first dim " - ) - input_tensor, weight_qtensor, bias = ( + input_tensor, weight_tensor, bias = ( args[1], args[2], args[0], ) + try: + return _quantized_linear_op(input_tensor, weight_tensor, bias) + except: + if isinstance(input_tensor, AffineQuantizedTensor): + input_tensor = input_tensor.dequantize() + if isinstance(weight_tensor, AffineQuantizedTensor): + weight_tensor = weight_tensor.dequantize() + return func(bias, input_tensor, weight_tensor) else: - assert args[0].shape[-1] == args[1].shape[0], ( - f"need mat1 shape: {args[0].shape} final dim" - f"to match mat2 shape: {args[1].shape} first dim" - ) - input_tensor, weight_qtensor, bias = ( + input_tensor, weight_tensor, bias = ( args[0], args[1], - None if len(args) == 2 else args[2], + None ) - weight_tensor = weight_qtensor.dequantize() - return func(input_tensor, weight_tensor, bias) + try: + return _quantized_linear_op(input_tensor, weight_tensor, bias) + except: + if isinstance(input_tensor, AffineQuantizedTensor): + input_tensor = input_tensor.dequantize() + if isinstance(weight_tensor, AffineQuantizedTensor): + weight_tensor = weight_tensor.dequantize() + return func(input_tensor, weight_tensor) @implements_aqt_aten_ops([aten.detach.default]) def detach(func, *args, **kwargs): @@ -641,10 +685,10 @@ def _to_copy(func, *args, **kwargs): @implements_aqt_aten_ops([aten.t.default]) def t(func, *args, **kwargs): - # TODO: need to implement this - # args[0].transposed = not args[0].transposed - # new = args[0]._change_shape(args[0].shape[::-1]) - # return return_and_correct_aliasing(func, args, kwargs, new) - raise Exception("transpose not implemented yet") + block_size = args[0].block_size + assert len(block_size) == 2 + transposed_block_size = (block_size[1], block_size[0]) + new = args[0]._change_shape(args[0].shape[::-1], transposed_block_size) + return return_and_correct_aliasing(func, args, kwargs, new) to_aq = AffineQuantizedTensor.from_float diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 7ec88c749..04019c209 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -25,7 +25,11 @@ from typing import Any, Callable from .dynamic_quant import DynamicallyPerAxisQuantizedLinear -from .utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 +from .utils import ( + TORCH_VERSION_AFTER_2_3, + TORCH_VERSION_AFTER_2_4, + unwrap_tensor_subclass, +) from .subclass import ( Int4WeightOnlyQuantizedLinearWeight, @@ -33,6 +37,7 @@ Int8WeightOnlyQuantizedLinearWeight, QuantizedLinearWeightBase, to_laq, + LinearActQuantizedTensor, ) from .quant_primitives import ( @@ -187,9 +192,13 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): *args ) - _replace_with_custom_fn_if_matches_filter( - model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn - ) + if TORCH_VERSION_AFTER_2_4: + quantize(model, get_apply_int8dyn_quant(), filter_fn) + unwrap_tensor_subclass(model, filter_fn) + else: + _replace_with_custom_fn_if_matches_filter( + model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn + ) def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs): @@ -282,7 +291,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT - apply_weight_quant = lambda x: to_aqt(x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain) + apply_weight_quant = lambda x: to_aq(x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain) # apply to modules under block0 submodule def filter_fn(module, fqn): diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index ee13512e9..972699f0b 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -610,6 +610,7 @@ def __new__( dtype = original_weight_tensor.dtype kwargs["dtype"] = dtype kwargs["requires_grad"] = False + kwargs["device"] = original_weight_tensor.device shape = original_weight_tensor.shape return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] @@ -664,6 +665,27 @@ def _apply_fn_to_data(self, fn): self.input_quant_func, ) + def _get_to_kwargs(self, *args, **kwargs): + device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) + device = self.device if device is None else device + dtype = self.dtype if dtype is None else dtype + memory_format = ( + memory_format if memory_format is not None else torch.preserve_format + ) + kwargs = { + "device": device, + "dtype": dtype, + "memory_format": memory_format, + } + return kwargs + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + return self.__class__( + self.original_weight_tensor.to(**kwargs), + self.input_quant_func, + ) + def __torch_dispatch__(cls, func, types, args, kwargs): if ( func in [aten.mm.default, aten.addmm.default] @@ -674,25 +696,29 @@ def __torch_dispatch__(cls, func, types, args, kwargs): f"need mat1 shape: {args[1].shape} final" f"dim to match mat2 shape: {args[2].shape} first dim " ) - input_tensor, weight_qtensor, bias = ( + input_tensor, weight_tensor, bias = ( args[1], args[2], args[0], ) - aqt = self.input_quant_func(input_tensor) - return func(bias, aqt, weight_tensor) + input_quant_func = weight_tensor.input_quant_func + original_weight_tensor = weight_tensor.original_weight_tensor + aqt = input_quant_func(input_tensor) + return func(bias, aqt, original_weight_tensor) else: + # aten.mm.default assert args[0].shape[-1] == args[1].shape[0], ( f"need mat1 shape: {args[0].shape} final dim" f"to match mat2 shape: {args[1].shape} first dim" ) - input_tensor, weight_qtensor, bias = ( + input_tensor, weight_tensor = ( args[0], args[1], - None if len(args) == 2 else args[2], ) - aqt = self.input_quant_func(input_tensor) - return func(aqt, weight_tensor, bias) + input_quant_func = weight_tensor.input_quant_func + original_weight_tensor = weight_tensor.original_weight_tensor + aqt = input_quant_func(input_tensor) + return func(aqt, original_weight_tensor) if func is aten.detach.default: return return_and_correct_aliasing( @@ -704,6 +730,19 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) ) + if func is aten._to_copy.default: + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) + + if func is aten.t.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.t) + ) + raise NotImplementedError( f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported" ) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 78a76863f..e6787b0cf 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -133,11 +133,14 @@ def right_inverse(self, tensor): def unwrap_tensor_subclass(model, filter_fn=None): for name, child in model.named_children(): + # make sure child.weight is a tensor subclass if ( isinstance(child, torch.nn.Linear) and hasattr(child, "weight") and type(child.weight) is not torch.Tensor and - isinstance(child.weight, torch.Tensor) + type(child.weight) is not torch.nn.Parameter and + isinstance(child.weight, torch.Tensor) and + issubclass(type(child.weight), torch.Tensor) ): parametrize.register_parametrization(child, "weight", UnwrapTensorSubclass()) unwrap_tensor_subclass(child)