From fec5420a753bdc1259cf10f75b40fbc9b0c36487 Mon Sep 17 00:00:00 2001
From: HDCharles <39544797+HDCharles@users.noreply.github.com>
Date: Fri, 25 Oct 2024 14:15:54 -0700
Subject: [PATCH] is_linear fix for MHA (#1141)

* [wip] fix for mha

Summary: a filter_fn may need access to the parent types of a module, as
in the case of MHA: its out_proj is a NonDynamicallyQuantizableLinear
subclass of nn.Linear that _is_linear should not match.

Test Plan: test_autoquant_mha in test/integration/test_integration.py

* testing

* fixing dtype and device in test

* fixing test
---
 test/integration/test_integration.py | 28 ++++++++++++++++++++++++++++
 torchao/quantization/quant_api.py    |  1 +
 2 files changed, 29 insertions(+)

diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 6b90b38a9..078dfe727 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -1327,6 +1327,34 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)
 
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")
+    def test_autoquant_mha(self, device, dtype):
+        if device != "cuda" or not torch.cuda.is_available():
+            self.skipTest(f"autoquant currently does not support {device}")
+
+        class MHAModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mha = torch.nn.MultiheadAttention(4096, 32)
+                self.lin = torch.nn.Linear(4096, 4096)
+
+            def forward(self, x):
+                y = self.mha(x, x, x)[0]
+                return self.lin(y)
+
+        mod = MHAModel().to(device).to(dtype)
+        input = torch.randn(1, 1, 4096).to(device).to(dtype)
+        out = mod(*input)
+
+        torchao.autoquant(mod, set_inductor_config=False)
+        assert not isinstance(mod.mha.out_proj.weight, AutoQuantizableLinearWeight)
+        assert isinstance(mod.lin.weight, AutoQuantizableLinearWeight)
+        mod(*input)
+        from torchao.quantization.autoquant import AUTOQUANT_CACHE
+        assert len(AUTOQUANT_CACHE) > 0
+
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")
     def test_autoquant_manual(self, device, dtype):
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index a2e67bb6c..503da1008 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -235,6 +235,7 @@ def _is_linear(mod, *args):
         and not isinstance(mod.weight, AffineQuantizedTensor)
         and not isinstance(mod.weight, LinearActivationQuantizedTensor)
         and not isinstance(mod.weight, AffineFakeQuantizedTensor)
+        and not isinstance(mod, nn.modules.linear.NonDynamicallyQuantizableLinear)
     )
 
 import torch.nn.utils.parametrize as parametrize
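
Why the quant_api.py check works: torch.nn.MultiheadAttention stores its output
projection as torch.nn.modules.linear.NonDynamicallyQuantizableLinear, a
subclass of nn.Linear, so a filter that only tests isinstance(mod, nn.Linear)
would match out_proj too; the added subclass check excludes it, matching the
new test's assertion that out_proj.weight is left unswapped. A minimal sketch
of the subclass relationship the check relies on (not part of the patch;
assumes only stock PyTorch, and the embed_dim/num_heads values are arbitrary):

    import torch.nn as nn
    from torch.nn.modules.linear import NonDynamicallyQuantizableLinear

    # MHA builds its output projection from an nn.Linear subclass.
    mha = nn.MultiheadAttention(embed_dim=64, num_heads=4)

    # A bare nn.Linear check matches out_proj ...
    assert isinstance(mha.out_proj, nn.Linear)

    # ... while the subclass check added to _is_linear can single it out,
    # keeping autoquant away from out_proj while plain nn.Linear modules
    # are still converted.
    assert isinstance(mha.out_proj, NonDynamicallyQuantizableLinear)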