Skip to content

Commit

Permalink
Additional fixes for autoquant serialization (#1486)
Browse files Browse the repository at this point in the history
Summary:
We can't overwrite the `.layout` attribute of a Tensor, because `tensor.layout` is a reserved PyTorch attribute that must hold a `torch.layout` value; the autoquant classes therefore store their layout under a different attribute name (`aq_layout`).

discovered when loading the autoquantized model https://huggingface.co/jerryzh168/llama3-8b-autoquant

Test Plan:
tested locally with Hugging Face Transformers
Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 authored Jan 3, 2025
1 parent d9fe2c2 commit c59bce5
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions torchao/quantization/autoquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedT
AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
"""

layout: Layout = PlainLayout()
aq_layout: Layout = PlainLayout()

@classmethod
def from_float(cls, weight):
Expand Down Expand Up @@ -442,7 +442,7 @@ def get_weight_block_size(x):
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64
_layout = cls.layout
_layout = cls.aq_layout
block_size = get_weight_block_size(weight)

weight = to_affine_quantized_intx(
Expand Down Expand Up @@ -616,12 +616,13 @@ class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
"""

group_size: int = 32
layout: Layout = TensorCoreTiledLayout(inner_k_tiles=8)
# can't override the `layout` attribute
aq_layout: Layout = TensorCoreTiledLayout(inner_k_tiles=8)

@classmethod
def from_float(cls, weight):
group_size = cls.group_size
_layout = cls.layout
_layout = cls.aq_layout

if weight.shape[-1] % group_size != 0:
return weight
Expand Down Expand Up @@ -681,7 +682,7 @@ class AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight(
AQInt4G32WeightOnlyQuantizedLinearWeight
):
group_size: int = 128
layout: Layout = MarlinSparseLayout()
aq_layout: Layout = MarlinSparseLayout()


class AQGemliteInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
Expand Down

0 comments on commit c59bce5

Please sign in to comment.