diff --git a/README.md b/README.md
index 5b87ae838885..c119ad42ac4b 100644
--- a/README.md
+++ b/README.md
@@ -58,7 +58,7 @@ vLLM is fast with:
 - Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
 - Continuous batching of incoming requests
 - Fast model execution with CUDA/HIP graph
-- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8.
+- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516), INT4, INT8, and FP8.
 - Optimized CUDA kernels, including integration with FlashAttention and FlashInfer.
 - Speculative decoding
 - Chunked prefill
diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py
index a5e63843cf62..2d9f5e52bd65 100644
--- a/vllm/model_executor/layers/quantization/auto_round.py
+++ b/vllm/model_executor/layers/quantization/auto_round.py
@@ -8,6 +8,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -74,7 +75,7 @@ def __repr__(self) -> str:
                 f"group_size={self.group_size}, sym={self.sym})")
 
     @classmethod
-    def get_name(cls):  ## use str will trigger preci issue
+    def get_name(cls) -> QuantizationMethods:
         return "auto-round"
 
     @classmethod
@@ -142,18 +143,18 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
             prefix, layer.__class__.__name__, weight_bits, group_size, sym)
 
         if backend == "auto" or "marlin" in backend:
+            AWQ_TYPE_MAP = {
+                4: scalar_types.uint4,
+                8: scalar_types.uint8,
+            }
+            use_marlin = (weight_bits
+                          in AWQ_TYPE_MAP) and check_marlin_supported(
+                              AWQ_TYPE_MAP[weight_bits], group_size, not sym)
+
             if isinstance(layer, FusedMoE):
-                use_marlin = check_moe_marlin_supports_layer(layer, group_size)
-            else:
+                use_marlin = use_marlin and check_moe_marlin_supports_layer(
+                    layer, group_size)
-                AWQ_TYPE_MAP = {
-                    4: scalar_types.uint4,
-                    8: scalar_types.uint8,
-                }
-                use_marlin = ((weight_bits, sym) in AWQ_TYPE_MAP
-                              and check_marlin_supported(
-                                  AWQ_TYPE_MAP[(weight_bits)], group_size,
-                                  not sym))
         else:
             use_marlin = False
 
         if use_marlin:
@@ -180,10 +181,11 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
             from vllm.model_executor.layers.quantization.moe_wna16 import (
                 MoeWNA16Config)
             config = {
-                "linear_quant_method": "awq",
-                "weight_bits": weight_bits,
+                "quant_method": "awq",
+                "bits": weight_bits,
                 "group_size": group_size,
                 "zero_point": not sym,
+                "lm_head": False,
             }
             return MoeWNA16Config.from_config(config).get_quant_method(
                 layer, prefix)
@@ -213,18 +215,18 @@ def apply_gptq_quant_layer(self,
             prefix, layer.__class__.__name__, weight_bits, group_size, sym)
 
         if backend == "auto" or "marlin" in backend:
+            GPTQ_TYPE_MAP = {
+                (4, True): scalar_types.uint4b8,
+                (8, True): scalar_types.uint8b128,
+            }
+            use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP
+                          and check_marlin_supported(
+                              GPTQ_TYPE_MAP[(weight_bits, sym)],
+                              group_size,
+                              has_zp=not sym))
             if isinstance(layer, FusedMoE):
-                use_marlin = check_moe_marlin_supports_layer(layer, group_size)
-            else:
-                GPTQ_TYPE_MAP = {
-                    (4, True): scalar_types.uint4b8,
-                    (8, True): scalar_types.uint8b128,
-                }
-                use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP
-                              and check_marlin_supported(
-                                  GPTQ_TYPE_MAP[(weight_bits, sym)],
-                                  group_size,
-                                  has_zp=not sym))
+                use_marlin = use_marlin and check_moe_marlin_supports_layer(
+                    layer, group_size)
         else:
             use_marlin = False
         if use_marlin:
@@ -251,11 +253,11 @@ def apply_gptq_quant_layer(self,
             from vllm.model_executor.layers.quantization.moe_wna16 import (
                 MoeWNA16Config)
             config = {
-                "linear_quant_method": "gptq",
-                "weight_bits": weight_bits,
+                "quant_method": "gptq",
+                "bits": weight_bits,
                 "group_size": group_size,
                 "sym": sym,
-                "lm_head_quantized": False,
+                "lm_head": False,
             }
             return MoeWNA16Config.from_config(config).get_quant_method(
                 layer, prefix)