From c59bce585ee83bca6a2e32f6099600a6e4bbe631 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 3 Jan 2025 10:41:04 -0800
Subject: [PATCH] Additional fixes for autoquant serialization (#1486)

Summary:
Can't overwrite `.layout` attribute of a Tensor since `tensor.layout` should
be `torch.layout`

discovered when loading the autoquantized model
https://huggingface.co/jerryzh168/llama3-8b-autoquant

Test Plan:
tested locally with huggingface transformers

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchao/quantization/autoquant.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 15166aca0d..e7806f07ad 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -411,7 +411,7 @@ class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedT
     AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
     """
 
-    layout: Layout = PlainLayout()
+    aq_layout: Layout = PlainLayout()
 
     @classmethod
     def from_float(cls, weight):
@@ -442,7 +442,7 @@ def get_weight_block_size(x):
        target_dtype = torch.int8
        eps = torch.finfo(torch.float32).eps
        zero_point_dtype = torch.int64
-        _layout = cls.layout
+        _layout = cls.aq_layout
        block_size = get_weight_block_size(weight)
 
        weight = to_affine_quantized_intx(
@@ -616,12 +616,13 @@ class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
 
     group_size: int = 32
-    layout: Layout = TensorCoreTiledLayout(inner_k_tiles=8)
+    # can't override the `layout` attribute
+    aq_layout: Layout = TensorCoreTiledLayout(inner_k_tiles=8)
 
     @classmethod
     def from_float(cls, weight):
         group_size = cls.group_size
-        _layout = cls.layout
+        _layout = cls.aq_layout
 
         if weight.shape[-1] % group_size != 0:
             return weight
@@ -681,7 +682,7 @@ class AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight(
     AQInt4G32WeightOnlyQuantizedLinearWeight
 ):
     group_size: int = 128
-    layout: Layout = MarlinSparseLayout()
+    aq_layout: Layout = MarlinSparseLayout()
 
 
 class AQGemliteInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):