diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 92d2dcd5c..327948954 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -777,6 +777,21 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
             AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
         )
 
+    def test_autoquantizable_flatten_unflatten(self):
+        from torchao.quantization import DEFAULT_AUTOQUANT_CLASS_LIST
+        weight = torch.randn(16, 32)
+        qtensor_class_list = DEFAULT_AUTOQUANT_CLASS_LIST
+        aqw = AutoQuantizableLinearWeight.from_float(weight, qtensor_class_list)
+        tensor_data_name_dict, tensor_attributes = aqw.__tensor_flatten__()
+        tensor_data_dict = {name: getattr(aqw, name) for name in tensor_data_name_dict}
+        outer_size = aqw.size()
+        outer_stride = aqw.stride()
+        reconstructed = type(aqw).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
+        # Round-trip through flatten/unflatten must preserve the outer tensor's metadata.
+        self.assertEqual(reconstructed.shape, aqw.shape)
+        self.assertEqual(reconstructed.dtype, aqw.dtype)
+
+
 @parameterized.expand(COMMON_DEVICE_DTYPE)
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
 @unittest.skipIf(not is_H100, "Need H100 to run")
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index ee6bf9885..87cb5e265 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -217,7 +217,7 @@ def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
     ):
         weight = tensor_data_dict["weight"]
-        qtensor_class_list, mode, dtype, shape = tensor_attributes[0]
+        qtensor_class_list, mode, dtype, shape = tensor_attributes
         return cls(
             weight,
             qtensor_class_list,