diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 2425d341e2..fefda2583a 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -66,20 +66,27 @@ from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
 from parameterized import parameterized
+import itertools
+import logging
 from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
 
+logger = logging.getLogger(__name__)
+
 torch.manual_seed(0)
 config.cache_size_limit = 100
 
-COMMON_DEVICE_DTYPE=[
-    ("cpu", torch.float32),
-    ("cpu", torch.float16),
-    ("cpu", torch.bfloat16),
-    ("cuda", torch.float32),
-    ("cuda", torch.float16),
-    ("cuda", torch.bfloat16),
+TENSOR_SUBCLASS_APIS = [
+    change_linear_weights_to_int8_dqtensors,
+    change_linear_weights_to_int8_woqtensors,
+    change_linear_weights_to_int4_woqtensors,
 ]
+COMMON_DEVICES = ["cpu", "cuda"]
+
+COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
+
+COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES))
+
 
 
 def combine_parameters(a, b):
     new_tuples = []
     for (tuple1, tuple2) in itertools.product(a, b):
@@ -87,10 +94,17 @@ def combine_parameters(a, b):
     return new_tuples
 
 def run_supported_device_dtype(test_method):
+    """Assumes that the third arg (args[2]) of the decorated method is the device, and that
+    the dtype is given either by a `test_dtype` kwarg or by the fourth arg (args[3]).
+    """
     def wrapper(*args, **kwargs):
-        if args[2] == "cuda" and not torch.cuda.is_available():
+        if len(args) < 3:
+            raise unittest.SkipTest("Not enough args")
+        device = args[2]
+        dtype = kwargs["test_dtype"] if "test_dtype" in kwargs else args[3]
+        if device == "cuda" and not torch.cuda.is_available():
             raise unittest.SkipTest(f"Need CUDA available.")
-        if args[2] == "cuda" and torch.cuda.is_available() and kwargs['test_dtype'] == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
+        if device == "cuda" and torch.cuda.is_available() and dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
             raise unittest.SkipTest("Need CUDA and SM80+ available.")
         return test_method(*args, **kwargs)
     return wrapper
@@ -1145,6 +1159,7 @@ def _test_handle_save_load_meta_impl(
     min_sqnr=35,
     test_dtype=torch.bfloat16
 ):
+    logger.info(f"TestSaveLoad: {api}, {test_device}, {test_dtype}")
     m, k, n = 32, 64, 32
 
     class test_model(nn.Module):
@@ -1170,7 +1185,9 @@ def forward(self, x):
     api(model)
     torch.save(model.state_dict(), "test.pth")
     # get quantized reference
-    model_qc = torch.compile(model, mode="max-autotune")
+    # model_qc = torch.compile(model, mode="max-autotune")
+    model_qc = torch.export.export(model, (x,)).module()
+    # model_qc = model
     ref_q = model_qc(x).detach()
 
     assert SQNR(ref_f, ref_q) > min_sqnr
@@ -1187,7 +1204,8 @@ def forward(self, x):
     model = model.to(device=test_device, dtype=test_dtype).eval()
 
     # get quantized reference
-    model_qc = torch.compile(model, mode="max-autotune")
+    # model_qc = torch.compile(model, mode="max-autotune")
+    model_qc = model
     test = model_qc(x).detach()
 
     assert SQNR(ref_f, test) > min_sqnr
@@ -1404,5 +1422,52 @@ def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)
 
+
+class TestAOTI(unittest.TestCase):
+    @run_supported_device_dtype
+    @torch.no_grad()
+    @parameterized.expand(
+        list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
+    )
+    def test_aoti(self, api, test_device, test_dtype):
+        logger.info(f"TestAOTI: {api}, {test_device}, {test_dtype}")
+        if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda":
+            self.skipTest(f"{api} in {test_device} is not supported for AOTI compilation yet")
+        m, k, n = 32, 64, 32
+
+        class test_model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin1 = nn.Linear(k, n)
+                self.relu = nn.ReLU()
+                self.lin2 = nn.Linear(n, n)
+
+            def forward(self, x):
+                x = self.lin1(x)
+                x = self.relu(x)
+                x = self.lin2(x)
+                return x
+
+        x = torch.randn(m, k, dtype=test_dtype, device=test_device)
+
+        # get float reference
+        model = test_model().to(dtype=test_dtype, device=test_device).eval()
+        ref_f = model(x)
+
+        print("calling quant")
+        api(model)
+
+        # running model
+        print("running model")
+        model(x)
+        print("model:", model)
+        print("model weight:", model.lin1.weight)
+
+        # make sure it compiles
+        example_inputs = (x,)
+        print("compiling model")
+        torch._export.aot_compile(model, example_inputs)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index a830d52d78..eee0925298 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -117,13 +117,50 @@ def apply_dynamic_quant(model, filter_fn=None):
     change_linear_weights_to_int8_dqtensors(model, filter_fn)
 
 
-def _get_subclass_inserter(cls, **kwargs):
+import torch.nn.utils.parametrize as parametrize
+
+
+class ConstructTensorSubclass(torch.nn.Module):
+    def __init__(self, tensor_subclass_ctr, transposed, shape, groupsize, inner_k_tiles, dtype):
+        super().__init__()
+        self.tensor_subclass_ctr = tensor_subclass_ctr
+        self.transposed = transposed
+        self.shape = shape
+        self.groupsize = groupsize
+        self.inner_k_tiles = inner_k_tiles
+        self.dtype = dtype
+
+    def forward(self, int_data, scales_and_zeros):
+        return Int4WeightOnlyQuantizedLinearWeight.from_qtensor_components(int_data, scales_and_zeros, self.transposed, self.shape, self.groupsize, self.inner_k_tiles, dtype=self.dtype)
+
+    def right_inverse(self, tensor_subclass_instance):
+        # new_kwargs = {"groupsize": self.groupsize, "inner_k_tiles": self.inner_k_tiles}
+        # tensor_subclass_instance = self.tensor_subclass_ctr.from_float(input_float, **new_kwargs)
+        return tensor_subclass_instance.int_data, tensor_subclass_instance.scales_and_zeros
+
+
+def _get_subclass_inserter(cls, use_param=False, **kwargs):
     method = kwargs.pop("method", "from_float")
     def insert_subclass(lin):
-        lin.weight = torch.nn.Parameter(
-            # cls.from_float(...)
-            getattr(cls, method)(lin.weight, **kwargs), requires_grad=False
-        )
+        if use_param:
+            new_kwargs = {}
+            if "groupsize" in kwargs:
+                new_kwargs["groupsize"] = kwargs["groupsize"]
+            if "inner_k_tiles" in kwargs:
+                new_kwargs["inner_k_tiles"] = kwargs["inner_k_tiles"]
+            int_data, scales_and_zeros, transposed, groupsize, inner_k_tiles = cls.to_qtensor_components(lin.weight, **new_kwargs)
+            kwargs["transposed"] = transposed
+            kwargs["shape"] = lin.weight.shape
+            kwargs["dtype"] = lin.weight.dtype
+            kwargs["groupsize"] = groupsize
+            kwargs["inner_k_tiles"] = inner_k_tiles
+            lin.weight = torch.nn.Parameter(cls(int_data, scales_and_zeros, **kwargs), requires_grad=False)
+            parametrize.register_parametrization(lin, "weight", ConstructTensorSubclass(cls, **kwargs))
+        else:
+            lin.weight = torch.nn.Parameter(
+                # cls.from_float(...)
+                getattr(cls, method)(lin.weight, **kwargs), requires_grad=False
+            )
         return lin
     return insert_subclass
 
@@ -168,9 +205,10 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):
     """
     filter_fn = kwargs.pop("filter_fn", _is_linear)
+    print("kwargs in change linear to int4:", kwargs)
 
     _replace_with_custom_fn_if_matches_filter(
         model,
-        _get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, **kwargs),
+        _get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, use_param=True, **kwargs),
         filter_fn,
     )
 
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index 64689b8d95..74e11db86c 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -502,9 +502,21 @@ def from_float(cls, input_float, groupsize=128, inner_k_tiles=8):
         Int4WeightOnlyQuantizedLinearWeight.from_float(model.lin_mod.weight)
         )
         """
+        int_data, scales_and_zeros, transposed, groupsize, inner_k_tiles = cls.to_qtensor_components(input_float, groupsize, inner_k_tiles)
+        return cls(
+            int_data,
+            scales_and_zeros,
+            transposed,
+            input_float.shape,
+            groupsize,
+            inner_k_tiles,
+            dtype=input_float.dtype,
+        )
+
+    @classmethod
+    def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
         assert groupsize in [256, 128, 64, 32]
         assert inner_k_tiles in [8, 4, 2]
-        orig_shape = input_float.shape
         orig_out_features, orig_in_features = input_float.shape
 
         # padding
@@ -520,13 +532,25 @@ def from_float(cls, input_float, groupsize=128, inner_k_tiles=8):
             input_float, 4, groupsize
         )
         int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
+        return int_data, scales_and_zeros, False, groupsize, inner_k_tiles
 
+    @classmethod
+    def from_qtensor_components(
+        cls,
+        int_data,
+        scales_and_zeros,
+        transposed,
+        shape,
+        groupsize,
+        inner_k_tiles,
+        **kwargs
+    ):
         return cls(
             int_data,
             scales_and_zeros,
-            False,
-            orig_shape,
+            transposed,
+            shape,
             groupsize,
             inner_k_tiles,
-            dtype=input_float.dtype,
+            **kwargs,
         )
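
Note on the mechanism used in the quant_api.py change above: the int4 weight is stored through torch.nn.utils.parametrize rather than as a bare tensor subclass. right_inverse() decomposes the weight into plain tensors (int_data, scales_and_zeros) that land in the state_dict, and forward() rebuilds the Int4WeightOnlyQuantizedLinearWeight whenever lin.weight is accessed; this is how the use_param=True path aims to make the quantized weights digestible by torch._export.aot_compile in TestAOTI. The snippet below is a minimal, self-contained sketch of that round trip and is not part of this PR: the DecomposedWeight module and its scale/direction split are toy stand-ins for ConstructTensorSubclass and the real int4 components, while register_parametrization itself is the public PyTorch API used by the diff.

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class DecomposedWeight(nn.Module):
    # forward() rebuilds the weight that nn.Linear actually uses.
    def forward(self, scales, directions):
        return scales.unsqueeze(-1) * directions

    # right_inverse() decomposes the weight into the plain tensors that are
    # stored in the state_dict (mirroring int_data / scales_and_zeros above).
    def right_inverse(self, weight):
        scales = weight.norm(dim=1)
        directions = weight / scales.unsqueeze(-1)
        return scales, directions

lin = nn.Linear(64, 32)
parametrize.register_parametrization(lin, "weight", DecomposedWeight())

# The weight itself no longer appears in the state_dict, only its components:
# ['bias', 'parametrizations.weight.original0', 'parametrizations.weight.original1']
print(sorted(lin.state_dict().keys()))

y = lin(torch.randn(2, 64))  # the parametrization reconstructs the weight on the fly

Because only plain tensors remain as module state after this transformation, saving, loading, and exporting the module can treat it like any other nn.Linear, which is the property the new TestAOTI test exercises.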