From 101bf6345a14c59f407a6a4bb2142cdf0e7f133e Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Fri, 23 May 2025 10:17:35 +0800
Subject: [PATCH 1/5] fix moe args issues

Signed-off-by: wenhuach21
---
 README.md                             |  4 +++-
 .../layers/quantization/auto_round.py | 17 +++++++++--------
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index 5b87ae838885..a57e1c077592 100644
--- a/README.md
+++ b/README.md
@@ -58,7 +58,9 @@ vLLM is fast with:
 - Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
 - Continuous batching of incoming requests
 - Fast model execution with CUDA/HIP graph
-- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8.
+- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516),INT4,
+  INT8, and
+  FP8.
 - Optimized CUDA kernels, including integration with FlashAttention and FlashInfer.
 - Speculative decoding
 - Chunked prefill
diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py
index a5e63843cf62..73bdf7f19598 100644
--- a/vllm/model_executor/layers/quantization/auto_round.py
+++ b/vllm/model_executor/layers/quantization/auto_round.py
@@ -8,6 +8,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -74,7 +75,7 @@ def __repr__(self) -> str:
                 f"group_size={self.group_size}, sym={self.sym})")
 
     @classmethod
-    def get_name(cls):  ## use str will trigger preci issue
+    def get_name(cls) -> QuantizationMethods :  ## use str will trigger preci issue
         return "auto-round"
 
     @classmethod
@@ -145,7 +146,6 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
             if isinstance(layer, FusedMoE):
                 use_marlin = check_moe_marlin_supports_layer(layer, group_size)
             else:
-
                 AWQ_TYPE_MAP = {
                     4: scalar_types.uint4,
                     8: scalar_types.uint8,
@@ -180,10 +180,11 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
             from vllm.model_executor.layers.quantization.moe_wna16 import (
                 MoeWNA16Config)
             config = {
-                "linear_quant_method": "awq",
-                "weight_bits": weight_bits,
+                "quant_method": "awq",
+                "bits": weight_bits,
                 "group_size": group_size,
-                "zero_point": not sym,
+                "sym": sym,
+                "lm_head": False,
             }
             return MoeWNA16Config.from_config(config).get_quant_method(
                 layer, prefix)
@@ -251,11 +252,11 @@ def apply_gptq_quant_layer(self,
             from vllm.model_executor.layers.quantization.moe_wna16 import (
                 MoeWNA16Config)
             config = {
-                "linear_quant_method": "gptq",
-                "weight_bits": weight_bits,
+                "quant_method": "gptq",
+                "bits": weight_bits,
                 "group_size": group_size,
                 "sym": sym,
-                "lm_head_quantized": False,
+                "lm_head": False,
             }
             return MoeWNA16Config.from_config(config).get_quant_method(
                 layer, prefix)
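[Note on PATCH 1/5] The substantive fix is the argument set handed to
MoeWNA16Config.from_config when a quantized FusedMoE layer falls back to the
WNA16 MoE path: the old keys ("linear_quant_method", "weight_bits",
"zero_point", "lm_head_quantized") are renamed to match what that config
expects. A minimal sketch of the corrected GPTQ-path dict; the helper below is
hypothetical and only illustrates the shape, it is not vLLM API:

    # Illustration only: the argument names PATCH 1/5 passes for GPTQ-packed
    # MoE layers. The helper itself is hypothetical, not part of vLLM.
    def moe_wna16_fallback_config(weight_bits: int, group_size: int,
                                  sym: bool) -> dict:
        return {
            "quant_method": "gptq",  # was "linear_quant_method"
            "bits": weight_bits,     # was "weight_bits"
            "group_size": group_size,
            "sym": sym,
            "lm_head": False,        # was "lm_head_quantized"
        }

    print(moe_wna16_fallback_config(4, 128, True))

The AWQ path receives the same renames; PATCH 3/5 below later switches its
symmetry key from "sym": sym to "zero_point": not sym.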
From 6d310ceefd96bb89e93a40d3de22fb508d947f9d Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Fri, 23 May 2025 10:24:16 +0800
Subject: [PATCH 2/5] fix ruff issue

Signed-off-by: wenhuach21
---
 vllm/model_executor/layers/quantization/auto_round.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py
index 73bdf7f19598..59008900cd89 100644
--- a/vllm/model_executor/layers/quantization/auto_round.py
+++ b/vllm/model_executor/layers/quantization/auto_round.py
@@ -75,7 +75,7 @@ def __repr__(self) -> str:
                 f"group_size={self.group_size}, sym={self.sym})")
 
     @classmethod
-    def get_name(cls) -> QuantizationMethods :  ## use str will trigger preci issue
+    def get_name(cls) -> QuantizationMethods:
         return "auto-round"
 
     @classmethod

From 1ee43e5b4c4c90286d64335775a19f69804c6d69 Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Fri, 23 May 2025 11:27:51 +0800
Subject: [PATCH 3/5] fix

Signed-off-by: wenhuach21
---
 .../layers/quantization/auto_round.py | 43 +++++++++----------
 1 file changed, 21 insertions(+), 22 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py
index 59008900cd89..00fc14c98e23 100644
--- a/vllm/model_executor/layers/quantization/auto_round.py
+++ b/vllm/model_executor/layers/quantization/auto_round.py
@@ -143,17 +143,17 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
             prefix, layer.__class__.__name__, weight_bits, group_size, sym)
 
         if backend == "auto" or "marlin" in backend:
+            AWQ_TYPE_MAP = {
+                4: scalar_types.uint4,
+                8: scalar_types.uint8,
+            }
+            use_marlin = (weight_bits in AWQ_TYPE_MAP) and check_marlin_supported(
+                AWQ_TYPE_MAP[weight_bits], group_size,
+                not sym)
+
             if isinstance(layer, FusedMoE):
-                use_marlin = check_moe_marlin_supports_layer(layer, group_size)
-            else:
-                AWQ_TYPE_MAP = {
-                    4: scalar_types.uint4,
-                    8: scalar_types.uint8,
-                }
-                use_marlin = ((weight_bits, sym) in AWQ_TYPE_MAP
-                              and check_marlin_supported(
-                                  AWQ_TYPE_MAP[(weight_bits)], group_size,
-                                  not sym))
+                use_marlin = use_marlin and check_moe_marlin_supports_layer(layer, group_size)
+
         else:
             use_marlin = False
         if use_marlin:
@@ -183,7 +183,7 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
                 "quant_method": "awq",
                 "bits": weight_bits,
                 "group_size": group_size,
-                "sym": sym,
+                "zero_point": not sym,
                 "lm_head": False,
             }
             return MoeWNA16Config.from_config(config).get_quant_method(
@@ -214,18 +214,17 @@ def apply_gptq_quant_layer(self,
             prefix, layer.__class__.__name__, weight_bits, group_size, sym)
 
         if backend == "auto" or "marlin" in backend:
+            GPTQ_TYPE_MAP = {
+                (4, True): scalar_types.uint4b8,
+                (8, True): scalar_types.uint8b128,
+            }
+            use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP
+                          and check_marlin_supported(
+                              GPTQ_TYPE_MAP[(weight_bits, sym)],
+                              group_size,
+                              has_zp=not sym))
             if isinstance(layer, FusedMoE):
-                use_marlin = check_moe_marlin_supports_layer(layer, group_size)
-            else:
-                GPTQ_TYPE_MAP = {
-                    (4, True): scalar_types.uint4b8,
-                    (8, True): scalar_types.uint8b128,
-                }
-                use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP
-                              and check_marlin_supported(
-                                  GPTQ_TYPE_MAP[(weight_bits, sym)],
-                                  group_size,
-                                  has_zp=not sym))
+                use_marlin = use_marlin and check_moe_marlin_supports_layer(layer, group_size)
         else:
             use_marlin = False
         if use_marlin:
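[Note on PATCH 3/5] The restructuring changes how Marlin eligibility is
decided: the generic check_marlin_supported test now runs for every layer, and
a FusedMoE layer must additionally pass check_moe_marlin_supports_layer,
rather than the MoE-specific check replacing the generic one. A standalone
sketch of that decision follows; the two *_stub predicates are assumptions
standing in for vLLM's marlin_utils helpers, not their real implementations:

    # Illustration only: stub predicates replace the real marlin_utils checks.
    def check_marlin_supported_stub(bits: int, group_size: int,
                                    has_zp: bool) -> bool:
        return bits in (4, 8) and group_size in (-1, 32, 64, 128)

    def check_moe_marlin_supports_layer_stub(group_size: int) -> bool:
        return group_size in (-1, 64, 128)

    def marlin_eligible(weight_bits: int, group_size: int, sym: bool,
                        is_fused_moe: bool, backend: str = "auto") -> bool:
        if backend != "auto" and "marlin" not in backend:
            return False
        # Generic kernel check runs first, for every layer type.
        use_marlin = check_marlin_supported_stub(weight_bits, group_size,
                                                 not sym)
        if is_fused_moe:
            # A fused-MoE layer must also pass the MoE-specific check.
            use_marlin = (use_marlin
                          and check_moe_marlin_supports_layer_stub(group_size))
        return use_marlin

    print(marlin_eligible(4, 128, True, is_fused_moe=True))   # True with these stubs
    print(marlin_eligible(3, 128, True, is_fused_moe=False))  # False: unsupported bits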
From 60ac95863ee4373fc7ac4dde39adcf2dbddda6e0 Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Fri, 23 May 2025 11:28:43 +0800
Subject: [PATCH 4/5] fix ruff issue

Signed-off-by: wenhuach21
---
 .../model_executor/layers/quantization/auto_round.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py
index 00fc14c98e23..2d9f5e52bd65 100644
--- a/vllm/model_executor/layers/quantization/auto_round.py
+++ b/vllm/model_executor/layers/quantization/auto_round.py
@@ -147,12 +147,13 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
                 4: scalar_types.uint4,
                 8: scalar_types.uint8,
             }
-            use_marlin = (weight_bits in AWQ_TYPE_MAP) and check_marlin_supported(
-                AWQ_TYPE_MAP[weight_bits], group_size,
-                not sym)
+            use_marlin = (weight_bits
+                          in AWQ_TYPE_MAP) and check_marlin_supported(
+                              AWQ_TYPE_MAP[weight_bits], group_size, not sym)
 
             if isinstance(layer, FusedMoE):
-                use_marlin = use_marlin and check_moe_marlin_supports_layer(layer, group_size)
+                use_marlin = use_marlin and check_moe_marlin_supports_layer(
+                    layer, group_size)
 
         else:
             use_marlin = False
@@ -224,7 +225,8 @@ def apply_gptq_quant_layer(self,
                               group_size,
                               has_zp=not sym))
             if isinstance(layer, FusedMoE):
-                use_marlin = use_marlin and check_moe_marlin_supports_layer(layer, group_size)
+                use_marlin = use_marlin and check_moe_marlin_supports_layer(
+                    layer, group_size)
         else:
             use_marlin = False
         if use_marlin:

From cddf0e4b53ac83a306d39b61b846d8a7fea9df79 Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Fri, 23 May 2025 11:43:56 +0800
Subject: [PATCH 5/5] fix readme

Signed-off-by: wenhuach21
---
 README.md | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/README.md b/README.md
index a57e1c077592..c119ad42ac4b 100644
--- a/README.md
+++ b/README.md
@@ -58,9 +58,7 @@ vLLM is fast with:
 - Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
 - Continuous batching of incoming requests
 - Fast model execution with CUDA/HIP graph
-- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516),INT4,
-  INT8, and
-  FP8.
+- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516),INT4, INT8, and FP8.
 - Optimized CUDA kernels, including integration with FlashAttention and FlashInfer.
 - Speculative decoding
 - Chunked prefill
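[Usage note] A minimal sketch of running an AutoRound-quantized checkpoint on
the code path this series touches; the model id below is a placeholder, and
the explicit quantization argument may be redundant when the checkpoint's
quantization_config already names the method:

    from vllm import LLM, SamplingParams

    # Placeholder model id; any AutoRound-quantized checkpoint would do.
    llm = LLM(model="some-org/some-model-int4-auto-round",
              quantization="auto-round")  # matches AutoRoundConfig.get_name()
    params = SamplingParams(temperature=0.0, max_tokens=64)
    outputs = llm.generate(["What does AutoRound quantization do?"], params)
    print(outputs[0].outputs[0].text)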