
Commit

Merge branch 'main' into add-api
jerryzh168 authored May 24, 2024
2 parents 899ed62 + 163cb93 commit 82ec155
Showing 3 changed files with 45 additions and 6 deletions.
44 changes: 41 additions & 3 deletions test/integration/test_integration.py
@@ -61,7 +61,8 @@
     AQInt8DynamicallyQuantizedLinearWeight,
     AQWeightOnlyQuantizedLinearWeight,
     AQWeightOnlyQuantizedLinearWeight2,
-    AQWeightOnlyQuantizedLinearWeight3
+    AQWeightOnlyQuantizedLinearWeight3,
+    AutoQuantizableLinearWeight,

 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
@@ -1104,7 +1105,6 @@ def test_weight_only_quant(self):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @torch.no_grad()
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skip("This test is flaky, we'll enable later")
     def test_weight_only_quant_force_mixed_mm(self, device, dtype):
         if device != "cuda":
             self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}")
@@ -1127,7 +1127,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
         sqnr = compute_error(y_ref, y_wo)
         self.assertGreaterEqual(sqnr, 42.75)
         if device == "cuda":
-            self.assertTrue("mixed_mm" in code)
+            self.assertTrue("mixed_mm" in code, f"got code: {code}")

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -1472,6 +1472,44 @@ def forward(self, x, y):
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)

+    @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
+        [
+            (16, 128, 128),
+        ]))
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
+    def test_autoquant_double_access(self, device, dtype, m, k, n):
+        if device != "cuda" and dtype != torch.bfloat16:
+            self.skipTest(f"autoquant currently does not support {device}")
+        if device != "cuda" or not torch.cuda.is_available():
+            self.skipTest(f"autoquant currently does not support {device}")
+        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
+            if dtype == torch.bfloat16:
+                self.skipTest(f"bfloat16 requires sm80+")
+
+        class DoubleAccess(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin1 = torch.nn.Linear(k, n)
+                self.lin2 = torch.nn.Linear(n, k)
+                self.lin3 = torch.nn.Linear(k, n)
+                self.lin3.weight = self.lin1.weight
+
+            def forward(self, x):
+                x = self.lin1(x)
+                x = self.lin2(x)
+                x = self.lin3(x)
+                return x
+
+        x_in = torch.randn(m, k, device=device, dtype=dtype)
+        model = DoubleAccess().to(device).to(dtype)
+        model(x_in)
+        torchao.autoquant(model)
+        assert not isinstance(model.lin1.weight.weight, AutoQuantizableLinearWeight)
+        model(x_in)
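The assertion above appears to guard the "double access" case the test is named for: lin1 and lin3 share a single weight, and autoquant should wrap that shared weight only once rather than nesting one AutoQuantizableLinearWeight inside another. A rough sketch of that invariant as a standalone check; assert_no_double_wrap is a hypothetical helper written for illustration, not part of the commit:

def assert_no_double_wrap(model):
    # hypothetical check: after torchao.autoquant(model), a weight shared by several
    # Linear modules should be wrapped at most once, so the tensor stored inside an
    # AutoQuantizableLinearWeight must not itself be an AutoQuantizableLinearWeight
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear) and isinstance(mod.weight, AutoQuantizableLinearWeight):
            assert not isinstance(mod.weight.weight, AutoQuantizableLinearWeight), name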




 class TestAOTI(unittest.TestCase):
     @parameterized.expand(
         list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
4 changes: 2 additions & 2 deletions torchao/quantization/autoquant.py
@@ -252,7 +252,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         )
         q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
         with torch.no_grad():
-            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data)
+            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_qtensor.int_data)
         print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")

         # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
@@ -384,7 +384,7 @@ def change_autoquantizable_to_quantized(model, **kwargs):
     torch._dynamo.reset()

 @torch.no_grad()
-def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["relu",None], **aq_kwargs):
+def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["interpolate", .85], **aq_kwargs):
     """
     wraps model in AutoQuantWrapper, if example_input is provided, runs forward on it, otherwise returns the wrapped model.
     AutoQuantWrapper handles instances where model is torch.compiled by first performing autoquantization on the original
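A minimal usage sketch based only on the signature and docstring shown in this hunk; the toy model, shapes, dtype, and device are illustrative assumptions, not part of the commit:

import torch
import torchao

# illustrative model and input; any nn.Module with torch.nn.Linear layers should work
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 128),
).to("cuda").to(torch.bfloat16)
x = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16)

# per the docstring: wraps the model in AutoQuantWrapper and, because example_input
# is supplied, runs forward on it (with example_input=None the wrapped model is
# simply returned); mode defaults to ["interpolate", .85] after this commit
torchao.autoquant(model, example_input=x)
out = model(x)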
3 changes: 2 additions & 1 deletion torchao/quantization/quant_api.py
@@ -35,7 +35,7 @@
     Int4WeightOnlyGPTQQuantizer,
     Int4WeightOnlyQuantizer,
 )
-from .autoquant import autoquant
+from .autoquant import autoquant, AutoQuantizableLinearWeight


 __all__ = [
@@ -93,6 +93,7 @@ def _is_linear(mod, *args):
         isinstance(mod, torch.nn.Linear)
         and hasattr(mod, "weight")
         and not isinstance(mod.weight, QuantizedLinearWeightBase)
+        and not isinstance(mod.weight, AutoQuantizableLinearWeight)
     )


