6
6
from torchao .dtypes import (
7
7
AffineQuantizedTensor ,
8
8
Float8Layout ,
9
+ MarlinSparseLayout ,
9
10
PlainLayout ,
11
+ SemiSparseLayout ,
10
12
TensorCoreTiledLayout ,
11
13
)
14
+ from torchao .dtypes .utils import Layout
12
15
from torchao .float8 .inference import Float8MMConfig
13
16
from torchao .kernel import safe_int_mm
14
17
from torchao .quantization .linear_activation_quantized_tensor import (
46
49
"DEFAULT_AUTOQUANT_CLASS_LIST" ,
47
50
"DEFAULT_INT4_AUTOQUANT_CLASS_LIST" ,
48
51
"DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST" ,
52
+ "DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST" ,
49
53
"OTHER_AUTOQUANT_CLASS_LIST" ,
50
54
"ALL_AUTOQUANT_CLASS_LIST" ,
51
55
]
@@ -406,6 +410,8 @@ class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedT
406
410
AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
407
411
"""
408
412
413
+ layout : Layout = PlainLayout ()
414
+
409
415
@classmethod
410
416
def from_float (cls , weight ):
411
417
# TODO test if this is valid
@@ -414,6 +420,9 @@ def from_float(cls, weight):
414
420
# if in_features <= 16:
415
421
# return weight
416
422
423
+ if weight .dim () != 2 :
424
+ return weight
425
+
417
426
# avoid circular dep
418
427
from torchao .dtypes import to_affine_quantized_intx
419
428
@@ -439,7 +448,7 @@ def get_per_token_block_size(x):
439
448
input_eps = 1e-5
440
449
input_quant_min = - 127
441
450
input_quant_max = 127
442
- _layout = PlainLayout ()
451
+ _layout = cls . layout
443
452
input_quant_func = lambda x : to_affine_quantized_intx (
444
453
x ,
445
454
input_mapping_type ,
@@ -526,6 +535,16 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
526
535
return res_f
527
536
528
537
538
+ class AQInt8DynamicallyQuantizedSemiSparseLinearWeight (
539
+ AQInt8DynamicallyQuantizedLinearWeight
540
+ ):
541
+ layout : Layout = SemiSparseLayout ()
542
+
543
+ @classmethod
544
+ def _autoquant_test (cls , act_mat , weight , bias , best_time , mode = ["relu" , None ]):
545
+ return super ()._autoquant_test (act_mat , weight , bias , best_time , None )
546
+
547
+
529
548
class AQInt8WeightOnlyQuantizedLinearWeight (AffineQuantizedTensor , AQMixin ):
530
549
"""
531
550
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
@@ -613,14 +632,16 @@ class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
613
632
"""
614
633
615
634
group_size : int = 32
635
+ layout : Layout = TensorCoreTiledLayout (inner_k_tiles = 8 )
616
636
617
637
@classmethod
618
638
def from_float (cls , weight ):
619
639
group_size = cls .group_size
620
- _layout = TensorCoreTiledLayout ( inner_k_tiles = 8 )
640
+ _layout = cls . layout
621
641
622
642
if weight .shape [- 1 ] % group_size != 0 :
623
643
return weight
644
+
624
645
use_hqq = True
625
646
mapping_type = MappingType .ASYMMETRIC
626
647
block_size = (1 , group_size )
@@ -631,6 +652,13 @@ def from_float(cls, weight):
631
652
preserve_zero = False
632
653
zero_point_dtype = torch .bfloat16
633
654
zero_point_domain = ZeroPointDomain .FLOAT
655
+
656
+ if isinstance (_layout , MarlinSparseLayout ):
657
+ mapping_type = MappingType .SYMMETRIC
658
+ preserve_zero = True
659
+ zero_point_domain = ZeroPointDomain .INT
660
+ use_hqq = False
661
+
634
662
return super (AQInt4G32WeightOnlyQuantizedLinearWeight , cls ).from_hp_to_intx (
635
663
weight ,
636
664
mapping_type ,
@@ -665,6 +693,13 @@ class AQInt4G256WeightOnlyQuantizedLinearWeight(
665
693
group_size : int = 256
666
694
667
695
696
+ class AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight (
697
+ AQInt4G32WeightOnlyQuantizedLinearWeight
698
+ ):
699
+ group_size : int = 128
700
+ layout : Layout = MarlinSparseLayout ()
701
+
702
+
668
703
class AQDefaultLinearWeight (torch .Tensor , AQMixin ):
669
704
"""
670
705
A class to be used in concert with AutoQuantizableLinearWeight to provide a
@@ -949,16 +984,24 @@ def get_weight_block_size(x):
949
984
]
950
985
951
986
OTHER_AUTOQUANT_CLASS_LIST = [
987
+ AQDefaultLinearWeight ,
952
988
AQFloat8WeightOnlyQuantizedLinearWeight ,
953
989
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight ,
954
990
AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight ,
955
991
]
956
992
993
+ DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST = [
994
+ AQDefaultLinearWeight ,
995
+ AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight ,
996
+ AQInt8DynamicallyQuantizedSemiSparseLinearWeight ,
997
+ ]
998
+
957
999
ALL_AUTOQUANT_CLASS_LIST = list (
958
1000
set (
959
1001
DEFAULT_AUTOQUANT_CLASS_LIST
960
1002
+ DEFAULT_INT4_AUTOQUANT_CLASS_LIST
961
1003
+ DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
1004
+ + DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST
962
1005
)
963
1006
)
964
1007
if is_sm_at_least_89 ():
0 commit comments