diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 7570700c65..8ec6acccc9 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -685,6 +685,13 @@ def ffn_or_attn_only(mod, fqn):
                 qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
                 example_input=inputs,
             )
+        if "autoquant-sparse" == quantization:
+            model = autoquant(
+                model,
+                manual=True,
+                qtensor_class_list=torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST,
+                example_input=inputs,
+            )
         if "autoquant-all" == quantization:
             all_qtensor_classes = (
                 torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST
diff --git a/torchao/_models/sam/eval_combo.py b/torchao/_models/sam/eval_combo.py
index afc625a47d..09a3448d6a 100644
--- a/torchao/_models/sam/eval_combo.py
+++ b/torchao/_models/sam/eval_combo.py
@@ -364,6 +364,8 @@ def mlp_only(mod, name):
            autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
        elif "autoquant-float8" == compress:
            autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST)
+       elif "autoquant-sparse" == compress:
+           autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST)
        elif "autoquant-all" == compress:
            autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST)
        else:
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 8b46d97dc6..14dfbab52b 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -14,6 +14,7 @@
     DEFAULT_AUTOQUANT_CLASS_LIST,
     DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
     DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
+    DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST,
     OTHER_AUTOQUANT_CLASS_LIST,
     autoquant,
 )
@@ -92,6 +93,7 @@
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
     "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
+    "DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
     "ALL_AUTOQUANT_CLASS_LIST",
     # top level API - manual
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 949d156349..b8cd0125f0 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -6,9 +6,12 @@
 from torchao.dtypes import (
     AffineQuantizedTensor,
     Float8Layout,
+    MarlinSparseLayout,
     PlainLayout,
+    SemiSparseLayout,
     TensorCoreTiledLayout,
 )
+from torchao.dtypes.utils import Layout
 from torchao.float8.inference import Float8MMConfig
 from torchao.kernel import safe_int_mm
 from torchao.quantization.linear_activation_quantized_tensor import (
@@ -46,6 +49,7 @@
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
     "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
+    "DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
     "ALL_AUTOQUANT_CLASS_LIST",
 ]
@@ -406,6 +410,8 @@ class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedT
     AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
     """
 
+    layout: Layout = PlainLayout()
+
     @classmethod
     def from_float(cls, weight):
         # TODO test if this is valid
@@ -414,6 +420,9 @@ def from_float(cls, weight):
         # if in_features <= 16:
         #     return weight
 
+        if weight.dim() != 2:
+            return weight
+
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized_intx
 
@@ -439,7 +448,7 @@ def get_per_token_block_size(x):
         input_eps = 1e-5
         input_quant_min = -127
         input_quant_max = 127
-        _layout = PlainLayout()
+        _layout = cls.layout
         input_quant_func = lambda x: to_affine_quantized_intx(
             x,
             input_mapping_type,
@@ -526,6 +535,16 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         return res_f
 
 
+class AQInt8DynamicallyQuantizedSemiSparseLinearWeight(
+    AQInt8DynamicallyQuantizedLinearWeight
+):
+    layout: Layout = SemiSparseLayout()
+
+    @classmethod
+    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
+        return super()._autoquant_test(act_mat, weight, bias, best_time, None)
+
+
 class AQInt8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
     AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
@@ -613,14 +632,16 @@ class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
 
     group_size: int = 32
+    layout: Layout = TensorCoreTiledLayout(inner_k_tiles=8)
 
     @classmethod
     def from_float(cls, weight):
         group_size = cls.group_size
-        _layout = TensorCoreTiledLayout(inner_k_tiles=8)
+        _layout = cls.layout
 
         if weight.shape[-1] % group_size != 0:
             return weight
+
         use_hqq = True
         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
@@ -631,6 +652,13 @@ def from_float(cls, weight):
         preserve_zero = False
         zero_point_dtype = torch.bfloat16
         zero_point_domain = ZeroPointDomain.FLOAT
+
+        if isinstance(_layout, MarlinSparseLayout):
+            mapping_type = MappingType.SYMMETRIC
+            preserve_zero = True
+            zero_point_domain = ZeroPointDomain.INT
+            use_hqq = False
+
         return super(AQInt4G32WeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx(
             weight,
             mapping_type,
@@ -665,6 +693,13 @@ class AQInt4G256WeightOnlyQuantizedLinearWeight(
     group_size: int = 256
 
 
+class AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight(
+    AQInt4G32WeightOnlyQuantizedLinearWeight
+):
+    group_size: int = 128
+    layout: Layout = MarlinSparseLayout()
+
+
 class AQDefaultLinearWeight(torch.Tensor, AQMixin):
     """
     A class to be used in concert with AutoQuantizableLinearWeight to provide a
@@ -949,16 +984,24 @@ def get_weight_block_size(x):
 ]
 
 OTHER_AUTOQUANT_CLASS_LIST = [
+    AQDefaultLinearWeight,
     AQFloat8WeightOnlyQuantizedLinearWeight,
     AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
     AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
 ]
 
+DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST = [
+    AQDefaultLinearWeight,
+    AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight,
+    AQInt8DynamicallyQuantizedSemiSparseLinearWeight,
+]
+
 ALL_AUTOQUANT_CLASS_LIST = list(
     set(
         DEFAULT_AUTOQUANT_CLASS_LIST
         + DEFAULT_INT4_AUTOQUANT_CLASS_LIST
         + DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
+        + DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST
     )
 )
 if is_sm_at_least_89():
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 96ccb1889c..99da86b87b 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -676,6 +676,9 @@ def apply_int4_weight_only_quant(weight):
             mapping_type = MappingType.SYMMETRIC
             preserve_zero = True
             zero_point_domain = ZeroPointDomain.INT
+            assert (
+                group_size == 128 or group_size == weight.shape[-1]
+            ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}"
             return to_affine_quantized_intx(
                 weight,
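
Usage sketch for the new class list, mirroring the manual autoquant flow in `generate.py` above. `MyModel` and the input shape are hypothetical stand-ins, not part of this diff:

```python
import torch
from torchao.quantization import DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST, autoquant

# Hypothetical module with nn.Linear layers; the sparse candidates here
# (Marlin int4 + 2:4 sparse, int8 dynamic + semi-structured sparse) target CUDA.
model = MyModel().cuda().to(torch.bfloat16)
example_input = torch.randn(1, 1024, device="cuda", dtype=torch.bfloat16)

# Wrap linear weights with the sparse autoquant candidates; with manual=True,
# shapes/timings are gathered from example_input and the best candidate per
# layer is committed when finalize_autoquant() is called.
model = autoquant(
    model,
    manual=True,
    qtensor_class_list=DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST,
    example_input=example_input,
)
model.finalize_autoquant()
```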