Add marlin and semi sparse + quant option to autoquant (#1399)
* Add Marlin and semi-sparse + quant options to autoquant

Summary:
Added DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST for autoquant (v1), which contains:
AQDefaultLinearWeight, AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight (float16 only), and AQInt8DynamicallyQuantizedSemiSparseLinearWeight.
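
For reference, a minimal sketch of exercising the new list through the manual autoquant API (the toy model and shapes below are illustrative placeholders, not part of this commit):

import torch
import torchao
from torchao.quantization import autoquant

# toy float16 model; the Marlin sparse int4 class only applies to float16
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().half()
example_input = torch.randn(32, 1024, device="cuda", dtype=torch.float16)

model = autoquant(
    model,
    manual=True,
    qtensor_class_list=torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST,
    example_input=example_input,
)
model(example_input)        # benchmark the candidate classes on real shapes
model.finalize_autoquant()  # pick the fastest class per layer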

Test Plan:
Tested on llama and sam:

python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress autoquant-sparse

+cuda,vit_h,32,10271,12,25.575582921440905,39.099793074967025,0.5424332682384179,max-autotune,torch.bfloat16,autoquant-sparse,False,True,True,32,154,4928,None,None
Baseline: around 22-23 img/s (vs. 25.58 img/s above)

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-sparse --precision float16
Average tokens/sec: 160.55

Base: Average tokens/sec: 110.47

* ruff
jerryzh168 authored Dec 11, 2024
1 parent 63b30ca commit 039cef4
Showing 5 changed files with 59 additions and 2 deletions.
7 changes: 7 additions & 0 deletions torchao/_models/llama/generate.py
@@ -685,6 +685,13 @@ def ffn_or_attn_only(mod, fqn):
qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
example_input=inputs,
)
if "autoquant-sparse" == quantization:
model = autoquant(
model,
manual=True,
qtensor_class_list=torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST,
example_input=inputs,
)
if "autoquant-all" == quantization:
all_qtensor_classes = (
torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST
2 changes: 2 additions & 0 deletions torchao/_models/sam/eval_combo.py
@@ -364,6 +364,8 @@ def mlp_only(mod, name):
autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
elif "autoquant-float8" == compress:
autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST)
elif "autoquant-sparse" == compress:
autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST)
elif "autoquant-all" == compress:
autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST)
else:
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -14,6 +14,7 @@
DEFAULT_AUTOQUANT_CLASS_LIST,
DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST,
OTHER_AUTOQUANT_CLASS_LIST,
autoquant,
)
@@ -92,6 +93,7 @@
"DEFAULT_AUTOQUANT_CLASS_LIST",
"DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
"DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
"DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST",
"OTHER_AUTOQUANT_CLASS_LIST",
"ALL_AUTOQUANT_CLASS_LIST",
# top level API - manual
47 changes: 45 additions & 2 deletions torchao/quantization/autoquant.py
@@ -6,9 +6,12 @@
from torchao.dtypes import (
AffineQuantizedTensor,
Float8Layout,
MarlinSparseLayout,
PlainLayout,
SemiSparseLayout,
TensorCoreTiledLayout,
)
from torchao.dtypes.utils import Layout
from torchao.float8.inference import Float8MMConfig
from torchao.kernel import safe_int_mm
from torchao.quantization.linear_activation_quantized_tensor import (
@@ -46,6 +49,7 @@
"DEFAULT_AUTOQUANT_CLASS_LIST",
"DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
"DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
"DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST",
"OTHER_AUTOQUANT_CLASS_LIST",
"ALL_AUTOQUANT_CLASS_LIST",
]
@@ -406,6 +410,8 @@ class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
"""

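    # layout is a class attribute so subclasses can swap in a sparse layout
    # (see AQInt8DynamicallyQuantizedSemiSparseLinearWeight below)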
layout: Layout = PlainLayout()

@classmethod
def from_float(cls, weight):
# TODO test if this is valid
@@ -414,6 +420,9 @@ def from_float(cls, weight):
# if in_features <= 16:
# return weight

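        # the quantization below assumes a 2D weight; leave anything else unquantized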
if weight.dim() != 2:
return weight

# avoid circular dep
from torchao.dtypes import to_affine_quantized_intx

@@ -439,7 +448,7 @@ def get_per_token_block_size(x):
input_eps = 1e-5
input_quant_min = -127
input_quant_max = 127
_layout = PlainLayout()
_layout = cls.layout
input_quant_func = lambda x: to_affine_quantized_intx(
x,
input_mapping_type,
@@ -526,6 +535,16 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
return res_f


class AQInt8DynamicallyQuantizedSemiSparseLinearWeight(
AQInt8DynamicallyQuantizedLinearWeight
):
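    # Same int8 dynamic quantization as the parent class, but with the weight
    # stored in a 2:4 semi-structured sparse layout.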
layout: Layout = SemiSparseLayout()

@classmethod
def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
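        # delegate to the parent benchmark with mode=None, skipping the
        # relu-fusion interpolation used for the dense int8 dynamic path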
return super()._autoquant_test(act_mat, weight, bias, best_time, None)


class AQInt8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
"""
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
@@ -613,14 +632,16 @@ class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
"""

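    # class attributes so subclasses (e.g. the Marlin sparse variant below) can override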
group_size: int = 32
layout: Layout = TensorCoreTiledLayout(inner_k_tiles=8)

@classmethod
def from_float(cls, weight):
group_size = cls.group_size
_layout = TensorCoreTiledLayout(inner_k_tiles=8)
_layout = cls.layout

if weight.shape[-1] % group_size != 0:
return weight

use_hqq = True
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size)
@@ -631,6 +652,13 @@ def from_float(cls, weight):
preserve_zero = False
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT

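        # Marlin's sparse int4 kernel requires symmetric quantization with an
        # integer zero-point domain and does not support HQQ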
if isinstance(_layout, MarlinSparseLayout):
mapping_type = MappingType.SYMMETRIC
preserve_zero = True
zero_point_domain = ZeroPointDomain.INT
use_hqq = False

return super(AQInt4G32WeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx(
weight,
mapping_type,
@@ -665,6 +693,13 @@ class AQInt4G256WeightOnlyQuantizedLinearWeight(
group_size: int = 256


class AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight(
AQInt4G32WeightOnlyQuantizedLinearWeight
):
group_size: int = 128
layout: Layout = MarlinSparseLayout()


class AQDefaultLinearWeight(torch.Tensor, AQMixin):
"""
A class to be used in concert with AutoQuantizableLinearWeight to provide a
@@ -949,16 +984,24 @@ def get_weight_block_size(x):
]

OTHER_AUTOQUANT_CLASS_LIST = [
AQDefaultLinearWeight,
AQFloat8WeightOnlyQuantizedLinearWeight,
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
]

DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST = [
AQDefaultLinearWeight,
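    # float16 only (Marlin 2:4 sparse kernel):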
AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight,
AQInt8DynamicallyQuantizedSemiSparseLinearWeight,
]

ALL_AUTOQUANT_CLASS_LIST = list(
set(
DEFAULT_AUTOQUANT_CLASS_LIST
+ DEFAULT_INT4_AUTOQUANT_CLASS_LIST
+ DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
+ DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST
)
)
if is_sm_at_least_89():
3 changes: 3 additions & 0 deletions torchao/quantization/quant_api.py
@@ -676,6 +676,9 @@ def apply_int4_weight_only_quant(weight):
mapping_type = MappingType.SYMMETRIC
preserve_zero = True
zero_point_domain = ZeroPointDomain.INT
assert (
group_size == 128 or group_size == weight.shape[-1]
), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}"

return to_affine_quantized_intx(
weight,
