Skip to content

Commit 039cef4

Browse files
authored
Add marlin and semi sparse + quant option to autoquant (#1399)
* Add marlin and semi sparse + quant option to autoquant Summary: Added DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST for autoquant (v1) that contains: AQDefaultLinearWeight, AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight (float16 only) and AQInt8DynamicallyQuantizedSemiSparseLinearWeight Test Plan: tested on llama and sam python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress autoquant-sparse +cuda,vit_h,32,10271,12,25.575582921440905,39.099793074967025,0.5424332682384179,max-autotune,torch.bfloat16,autoquant-sparse,False,True,True,32,154,4928,None,None Baseline: around 22/23 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-sparse --precision float16 Average tokens/sec: 160.55 Base: Average tokens/sec: 110.47 Reviewers: Subscribers: Tasks: Tags: * ruff
1 parent 63b30ca commit 039cef4

File tree

5 files changed

+59
-2
lines changed

5 files changed

+59
-2
lines changed

torchao/_models/llama/generate.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,13 @@ def ffn_or_attn_only(mod, fqn):
685685
qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
686686
example_input=inputs,
687687
)
688+
if "autoquant-sparse" == quantization:
689+
model = autoquant(
690+
model,
691+
manual=True,
692+
qtensor_class_list = torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST,
693+
example_input=inputs,
694+
)
688695
if "autoquant-all" == quantization:
689696
all_qtensor_classes = (
690697
torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST

torchao/_models/sam/eval_combo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,8 @@ def mlp_only(mod, name):
364364
autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
365365
elif "autoquant-float8" == compress:
366366
autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST)
367+
elif "autoquant-sparse" == compress:
368+
autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST)
367369
elif "autoquant-all" == compress:
368370
autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST)
369371
else:

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
DEFAULT_AUTOQUANT_CLASS_LIST,
1515
DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
1616
DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
17+
DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST,
1718
OTHER_AUTOQUANT_CLASS_LIST,
1819
autoquant,
1920
)
@@ -92,6 +93,7 @@
9293
"DEFAULT_AUTOQUANT_CLASS_LIST",
9394
"DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
9495
"DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
96+
"DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST",
9597
"OTHER_AUTOQUANT_CLASS_LIST",
9698
"ALL_AUTOQUANT_CLASS_LIST",
9799
# top level API - manual

torchao/quantization/autoquant.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@
66
from torchao.dtypes import (
77
AffineQuantizedTensor,
88
Float8Layout,
9+
MarlinSparseLayout,
910
PlainLayout,
11+
SemiSparseLayout,
1012
TensorCoreTiledLayout,
1113
)
14+
from torchao.dtypes.utils import Layout
1215
from torchao.float8.inference import Float8MMConfig
1316
from torchao.kernel import safe_int_mm
1417
from torchao.quantization.linear_activation_quantized_tensor import (
@@ -46,6 +49,7 @@
4649
"DEFAULT_AUTOQUANT_CLASS_LIST",
4750
"DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
4851
"DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
52+
"DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST",
4953
"OTHER_AUTOQUANT_CLASS_LIST",
5054
"ALL_AUTOQUANT_CLASS_LIST",
5155
]
@@ -406,6 +410,8 @@ class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedT
406410
AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
407411
"""
408412

413+
layout: Layout = PlainLayout()
414+
409415
@classmethod
410416
def from_float(cls, weight):
411417
# TODO test if this is valid
@@ -414,6 +420,9 @@ def from_float(cls, weight):
414420
# if in_features <= 16:
415421
# return weight
416422

423+
if weight.dim() != 2:
424+
return weight
425+
417426
# avoid circular dep
418427
from torchao.dtypes import to_affine_quantized_intx
419428

@@ -439,7 +448,7 @@ def get_per_token_block_size(x):
439448
input_eps = 1e-5
440449
input_quant_min = -127
441450
input_quant_max = 127
442-
_layout = PlainLayout()
451+
_layout = cls.layout
443452
input_quant_func = lambda x: to_affine_quantized_intx(
444453
x,
445454
input_mapping_type,
@@ -526,6 +535,16 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
526535
return res_f
527536

528537

538+
class AQInt8DynamicallyQuantizedSemiSparseLinearWeight(
539+
AQInt8DynamicallyQuantizedLinearWeight
540+
):
541+
layout: Layout = SemiSparseLayout()
542+
543+
@classmethod
544+
def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
545+
return super()._autoquant_test(act_mat, weight, bias, best_time, None)
546+
547+
529548
class AQInt8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
530549
"""
531550
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
@@ -613,14 +632,16 @@ class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
613632
"""
614633

615634
group_size: int = 32
635+
layout: Layout = TensorCoreTiledLayout(inner_k_tiles=8)
616636

617637
@classmethod
618638
def from_float(cls, weight):
619639
group_size = cls.group_size
620-
_layout = TensorCoreTiledLayout(inner_k_tiles=8)
640+
_layout = cls.layout
621641

622642
if weight.shape[-1] % group_size != 0:
623643
return weight
644+
624645
use_hqq = True
625646
mapping_type = MappingType.ASYMMETRIC
626647
block_size = (1, group_size)
@@ -631,6 +652,13 @@ def from_float(cls, weight):
631652
preserve_zero = False
632653
zero_point_dtype = torch.bfloat16
633654
zero_point_domain = ZeroPointDomain.FLOAT
655+
656+
if isinstance(_layout, MarlinSparseLayout):
657+
mapping_type = MappingType.SYMMETRIC
658+
preserve_zero = True
659+
zero_point_domain = ZeroPointDomain.INT
660+
use_hqq = False
661+
634662
return super(AQInt4G32WeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx(
635663
weight,
636664
mapping_type,
@@ -665,6 +693,13 @@ class AQInt4G256WeightOnlyQuantizedLinearWeight(
665693
group_size: int = 256
666694

667695

696+
class AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight(
697+
AQInt4G32WeightOnlyQuantizedLinearWeight
698+
):
699+
group_size: int = 128
700+
layout: Layout = MarlinSparseLayout()
701+
702+
668703
class AQDefaultLinearWeight(torch.Tensor, AQMixin):
669704
"""
670705
A class to be used in concert with AutoQuantizableLinearWeight to provide a
@@ -949,16 +984,24 @@ def get_weight_block_size(x):
949984
]
950985

951986
OTHER_AUTOQUANT_CLASS_LIST = [
987+
AQDefaultLinearWeight,
952988
AQFloat8WeightOnlyQuantizedLinearWeight,
953989
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
954990
AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
955991
]
956992

993+
DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST = [
994+
AQDefaultLinearWeight,
995+
AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight,
996+
AQInt8DynamicallyQuantizedSemiSparseLinearWeight,
997+
]
998+
957999
ALL_AUTOQUANT_CLASS_LIST = list(
9581000
set(
9591001
DEFAULT_AUTOQUANT_CLASS_LIST
9601002
+ DEFAULT_INT4_AUTOQUANT_CLASS_LIST
9611003
+ DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
1004+
+ DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST
9621005
)
9631006
)
9641007
if is_sm_at_least_89():

torchao/quantization/quant_api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,9 @@ def apply_int4_weight_only_quant(weight):
676676
mapping_type = MappingType.SYMMETRIC
677677
preserve_zero = True
678678
zero_point_domain = ZeroPointDomain.INT
679+
assert (
680+
group_size == 128 or group_size == weight.shape[-1]
681+
), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}"
679682

680683
return to_affine_quantized_intx(
681684
weight,

0 commit comments

Comments
 (0)