diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 0b5fd64120..5cd5e9d459 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -9,6 +9,7 @@
 # ruff: noqa: F841
 
 
+import copy
 import unittest
 
 import torch
@@ -21,6 +22,7 @@
     weight_observer_range_neg_127_to_127,
 )
 from torch.fx import Node
+from torch.testing import FileCheck
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
 )
@@ -1630,6 +1632,101 @@ def forward(self, x):
                     if key != FROM_NODE_KEY:
                         self.assertEqual(n.meta[key], weight_meta[key])
 
+    def test_constant_folding_pass(self):
+        from torchao.quantization import (
+            MappingType,
+            PerGroup,
+            PerToken,
+        )
+        from torchao.quantization.pt2e._affine_quantization import (
+            AffineQuantizedMinMaxObserver,
+        )
+        from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+        from torchao.quantization.pt2e.quantizer import (
+            QuantizationAnnotation,
+            QuantizationSpec,
+            Quantizer,
+        )
+
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                for node in model.graph.nodes:
+                    if (
+                        node.op == "call_function"
+                        and node.target == torch.ops.aten.linear.default
+                    ):
+                        input_act = node.args[0]
+                        assert isinstance(input_act, torch.fx.Node)
+                        weight = node.args[1]
+                        assert isinstance(weight, torch.fx.Node)
+
+                        act_qspec = QuantizationSpec(
+                            dtype=torch.uint8,
+                            quant_min=0,
+                            quant_max=255,
+                            qscheme=None,
+                            is_dynamic=False,
+                            observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args(
+                                # TODO: maybe align the arg name here
+                                target_dtype=torch.uint8,
+                                mapping_type=MappingType.SYMMETRIC,
+                                granularity=PerToken(),
+                            ),
+                        )
+
+                        weight_qspec = QuantizationSpec(
+                            dtype=torch.uint8,
+                            quant_min=0,
+                            quant_max=255,
+                            qscheme=None,
+                            is_dynamic=False,
+                            observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args(
+                                target_dtype=torch.uint8,
+                                mapping_type=MappingType.SYMMETRIC,
+                                granularity=PerGroup(group_size=128),
+                            ),
+                        )
+                        node.meta["quantization_annotation"] = QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act: act_qspec,
+                                weight: weight_qspec,
+                            },
+                            _annotated=True,
+                        )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(128, 20)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        example_inputs = (torch.randn(5, 128),)
+        model = M()
+        quantizer = BackendAQuantizer()
+        m = torch.export.export(model.eval(), example_inputs, strict=True).module()
+        m = prepare_pt2e(m, quantizer)
+        # Calibration
+        m(*example_inputs)
+        # Get the quantized model
+        m_fold = copy.deepcopy(m)
+        m_fold = convert_pt2e(m_fold, fold_quantize=True)
+
+        # With fold_quantize=True, the graph should only contain frozen params and no linear_weight
+        FileCheck().check("_frozen_param0").check_not("linear_weight").run(m_fold.code)
+
+        m_not_fold = copy.deepcopy(m)
+        m_not_fold = convert_pt2e(m_not_fold, fold_quantize=False)
+
+        # With fold_quantize=False, the graph should contain no frozen params and still reference linear_weight
+        FileCheck().check_not("_frozen_param0").check("linear_weight").run(
+            m_not_fold.code
+        )
+
     def test_save_load(self):
         """Test save/load a quantized model"""
         m = self._get_pt2e_quantized_linear()