is_linear fix for MHA (#1141)
* [wip] fix for mha

Summary: a filter fn may need access to a module's parent type, as in the
case of MHA: nn.MultiheadAttention builds its out projection as a
NonDynamicallyQuantizableLinear, which the _is_linear filter previously
matched as a plain nn.Linear.

Test Plan: TODO
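
For context, the parent-type issue the summary describes is easy to
reproduce: nn.MultiheadAttention builds its output projection as a
NonDynamicallyQuantizableLinear, a subclass of nn.Linear, so a filter fn
that only checks isinstance(mod, nn.Linear) will match it. A minimal
sketch (illustrative, not part of this commit):

    import torch.nn as nn

    mha = nn.MultiheadAttention(embed_dim=64, num_heads=4)

    # out_proj subclasses nn.Linear, so a plain isinstance check cannot
    # distinguish it from a standalone linear layer
    print(type(mha.out_proj).__name__)          # NonDynamicallyQuantizableLinear
    print(isinstance(mha.out_proj, nn.Linear))  # True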

* testing

* fixing dtype and device in test

* fixing test

HDCharles authored Oct 25, 2024
1 parent e85c1a3 commit fec5420
Showing 2 changed files with 29 additions and 0 deletions.
28 changes: 28 additions & 0 deletions test/integration/test_integration.py
@@ -1327,6 +1327,34 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
        sqnr = SQNR(out, out2)
        self.assertTrue(sqnr >= 30)

    @parameterized.expand(COMMON_DEVICE_DTYPE)
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")
    def test_autoquant_mha(self, device, dtype):
        if device != "cuda" or not torch.cuda.is_available():
            self.skipTest(f"autoquant currently does not support {device}")

        class MHAModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mha = torch.nn.MultiheadAttention(4096, 32)
                self.lin = torch.nn.Linear(4096, 4096)

            def forward(self, x):
                y = self.mha(x, x, x)[0]
                return self.lin(y)

        mod = MHAModel().to(device).to(dtype)
        # unpacking the (1, 1, 4096) tensor yields a single unbatched
        # (1, 4096) query for the MHA block
        input = torch.randn(1, 1, 4096).to(device).to(dtype)
        out = mod(*input)

        torchao.autoquant(mod, set_inductor_config=False)
        # MHA's out projection (a NonDynamicallyQuantizableLinear) must be
        # left alone; the standalone nn.Linear should be swapped
        assert not isinstance(mod.mha.out_proj.weight, AutoQuantizableLinearWeight)
        assert isinstance(mod.lin.weight, AutoQuantizableLinearWeight)
        mod(*input)
        from torchao.quantization.autoquant import AUTOQUANT_CACHE
        assert len(AUTOQUANT_CACHE) > 0
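
For reference, the flow the test exercises is the ordinary autoquant entry
point; a minimal usage sketch assuming a CUDA device (the model and shapes
are illustrative, not from this commit):

    import torch
    import torchao

    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda", torch.bfloat16)
    torchao.autoquant(model, set_inductor_config=False)

    # a single forward pass is enough to populate AUTOQUANT_CACHE with
    # benchmark results, which is what the assertion above relies on
    x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
    model(x)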

    @parameterized.expand(COMMON_DEVICE_DTYPE)
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")
    def test_autoquant_manual(self, device, dtype):
1 change: 1 addition & 0 deletions torchao/quantization/quant_api.py
@@ -235,6 +235,7 @@ def _is_linear(mod, *args):
        and not isinstance(mod.weight, AffineQuantizedTensor)
        and not isinstance(mod.weight, LinearActivationQuantizedTensor)
        and not isinstance(mod.weight, AffineFakeQuantizedTensor)
        and not isinstance(mod, nn.modules.linear.NonDynamicallyQuantizableLinear)
    )

import torch.nn.utils.parametrize as parametrize
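
The single added clause is the whole fix: _is_linear already matched MHA's
out projection through the nn.Linear check, so autoquant would try to swap
a weight whose type exists precisely so quantization flows can tell it
apart from a plain linear layer. A stripped-down sketch of the resulting
predicate (illustrative; the hypothetical _is_plain_linear omits the
weight-subclass checks the real _is_linear performs):

    import torch.nn as nn

    def _is_plain_linear(mod: nn.Module) -> bool:
        # match ordinary linear layers, but skip MultiheadAttention's
        # internal out projection, identified by its parent type
        return isinstance(mod, nn.Linear) and not isinstance(
            mod, nn.modules.linear.NonDynamicallyQuantizableLinear
        )

    mha = nn.MultiheadAttention(64, 4)
    assert _is_plain_linear(nn.Linear(64, 64))  # still matched
    assert not _is_plain_linear(mha.out_proj)   # excluded by the new clause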
