Skip to content

Commit

Permalink
Fix an autoquant bug in flatten/unflatten
Browse files Browse the repository at this point in the history
Summary:
att

Test Plan:
python test/integration/test_integration.py -k test_autoquantizable_flatten_unflatten

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed Nov 14, 2024
1 parent 06ad55a commit 01ad96a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
12 changes: 12 additions & 0 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,18 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
)

def test_autoquantizable_flatten_unflatten(self):
from torchao.quantization import DEFAULT_AUTOQUANT_CLASS_LIST
weight = torch.randn(16, 32)
qtensor_class_list = DEFAULT_AUTOQUANT_CLASS_LIST
aqw = AutoQuantizableLinearWeight.from_float(weight, qtensor_class_list)
tensor_data_name_dict, tensor_attributes = aqw.__tensor_flatten__()
tensor_data_dict = {name: getattr(aqw, name) for name in tensor_data_name_dict}
outer_size = aqw.size()
outer_stride = aqw.stride()
reconstructed = type(aqw).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)


@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
@unittest.skipIf(not is_H100, "Need H100 to run")
Expand Down
2 changes: 1 addition & 1 deletion torchao/quantization/autoquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
):
weight = tensor_data_dict["weight"]
qtensor_class_list, mode, dtype, shape = tensor_attributes[0]
qtensor_class_list, mode, dtype, shape = tensor_attributes
return cls(
weight,
qtensor_class_list,
Expand Down

0 comments on commit 01ad96a

Please sign in to comment.