Commit 688350a

[bugfix] fix inference of quantized ds-w8a8-mtp (#2134)

Running inference of ds-w8a8-mtp failed with "'ParallelLMHead' has no attribute 'params_dtype'". Fix: add a wrapper around VocabParallelEmbedding.__init__ that backfills params_dtype, which unblocks deepseek-w8a8-mtp.

Signed-off-by: curryliu <120010041@link.cuhk.edu.cn>

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@ad57f23
1 parent 4b3a210 commit 688350a
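For context, the failure boils down to downstream code reading an attribute that the embedding's __init__ accepts but never stores. A minimal, self-contained sketch of both the failure mode and the wrapper-style fix used in this commit (the Embedding class below is an illustrative stand-in, not the actual vLLM ParallelLMHead/VocabParallelEmbedding):

    from typing import Optional

    import torch


    class Embedding:  # stand-in for vLLM's VocabParallelEmbedding
        def __init__(self, num_embeddings: int,
                     params_dtype: Optional[torch.dtype] = None):
            self.num_embeddings = num_embeddings
            # params_dtype is accepted but never stored, so any later read
            # of `self.params_dtype` raises AttributeError -- the reported bug.


    def wrapper_init(func):
        def init(self, num_embeddings, params_dtype=None):
            func(self, num_embeddings, params_dtype)  # run the original __init__
            # backfill the attribute the original __init__ never set
            if params_dtype is None:
                params_dtype = torch.get_default_dtype()
            self.params_dtype = params_dtype
        return init


    Embedding.__init__ = wrapper_init(Embedding.__init__)
    assert Embedding(10).params_dtype == torch.get_default_dtype()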

2 files changed: +38 −1 lines changed

vllm_ascend/quantization/func_wrapper.py

Lines changed: 33 additions & 0 deletions
@@ -22,6 +22,39 @@
 from vllm.logger import logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, QuantizationConfig)
+
+
+# func refers to VocabParallelEmbedding.__init__
+def wrapper_vocab_parallel_embedding_init(func):
+
+    def init(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        params_dtype: Optional[torch.dtype] = None,
+        org_num_embeddings: Optional[int] = None,
+        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        func(
+            self,
+            num_embeddings,
+            embedding_dim,
+            params_dtype,
+            org_num_embeddings,
+            padding_size,
+            quant_config,
+            prefix,
+        )
+        # TODO: ask vLLM maintainers to add a `params_dtype` attribute to
+        # the `VocabParallelEmbedding` class so this wrapper can be dropped.
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+    return init


 # func refers to RMSNorm.__init__
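At patch time this wrapper replaces VocabParallelEmbedding.__init__, wired up through VLLMAscendQuantizer.apply_patch in quantizer.py below. The hand-applied equivalent would be roughly:

    # Hand-applied equivalent of the apply_patch call below (a sketch; the
    # real code routes this through VLLMAscendQuantizer.apply_patch):
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        VocabParallelEmbedding)

    VocabParallelEmbedding.__init__ = wrapper_vocab_parallel_embedding_init(
        VocabParallelEmbedding.__init__)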

vllm_ascend/quantization/quantizer.py

Lines changed: 5 additions & 1 deletion
@@ -22,7 +22,8 @@

 from vllm.logger import logger

-from .func_wrapper import wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init
+from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init,
+                           wrapper_vocab_parallel_embedding_init)
 from .w4a8_dynamic import AscendW4A8DynamicLinearMethod
 from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
                    AscendW8A8LinearMethod)
@@ -75,6 +76,9 @@ def __init__(self, quant_description):
                 VLLMAscendQuantizer.apply_patch(
                     "vllm.model_executor.layers.layernorm.RMSNorm",
                     "forward_oot", [wrapper_rmsnorm_forward_oot])
+                VLLMAscendQuantizer.apply_patch(
+                    "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding",
+                    "__init__", [wrapper_vocab_parallel_embedding_init])
                 break
         VLLMAscendQuantizer.patched = True
         logger.info("Using the vLLM Ascend Quantizer version now!")
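apply_patch's implementation is not part of this diff. Assuming it resolves the dotted class path and threads the original callable through each wrapper in the list, a plausible shape for it would be the sketch below (the names and logic are inferred, not taken from vllm_ascend):

    import importlib


    def apply_patch(target_path: str, attr_name: str, wrappers) -> None:
        # Sketch under the assumptions above; the real implementation
        # in vllm_ascend may differ.
        module_path, cls_name = target_path.rsplit(".", 1)
        cls = getattr(importlib.import_module(module_path), cls_name)
        original = getattr(cls, attr_name)
        for wrapper in wrappers:
            original = wrapper(original)  # each wrapper returns a new callable
        setattr(cls, attr_name, original)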
