
Commit 53ce4a0

[0.9.1]support deepseek w4a8 quantization (#1320)
### What this PR does / why we need it?
Supports DeepSeek-R1 w4a8 quantization. Since R1 w4a8 uses mixed quantization, only the MoE layers use w4a8_dynamic quantization, so we added the w4a8_dynamic.py file, which includes the AscendW4A8DynamicFusedMoEMethod class.

### Does this PR introduce _any_ user-facing change?
No, using `--quantization=ascend` is enough.

### How was this patch tested?

#### 1. How to get weights using Modelslim

##### Installation steps
Use the master branch at commit id 298e175d69b3b855111a1e09bbe2fcd12fdb4e24:

git clone https://gitee.com/ascend/msit.git
cd msit/msmodelslim
bash install.sh

##### Required transformers environment
pip install transformers==4.48.2

##### Generate w4a8 weights
cd /example/DeepSeek
Command reference: msmodelslim/example/DeepSeek/README.md
Execute the [pre-check](https://gitee.com/ascend/msit/blob/master/msmodelslim/example/DeepSeek/README.md#运行前必检) and [DeepSeek-R1 w4a8 mix quantization](https://gitee.com/ascend/msit/blob/master/msmodelslim/example/DeepSeek/README.md#deepseek-r1-w4a8-混合量化前三层-mlpw8a8-dynamic-量化mla共享专家w8a8量化路由专家w4a8-dynamic量化) chapters.
Reference command:

python3 quant_deepseek_w4a8.py --model_path {original weight path} --save_path {generated weight path} --mindie_format

##### Adapt to vllm-ascend
Since `--mindie_format` produces MindIE-format output, some adaptations are needed before vllm-ascend can use it:
- Rename `quant_model_description_w8a8_dynamic.json` to `quant_model_description.json`, and change `"group_size": 0` to `"group_size": 256`.
- In `config.json`, change `"model_type": "deepseekv2"` to `"model_type": "deepseek_v3"` and remove `quantization_config`.

#### 2. How to run w4a8
TP + EP:
python -m vllm.entrypoints.openai.api_server --model=$1 --trust-remote-code -tp $2 --enable_expert_parallel --quantization ascend --port $3 --max-model-len $4 --max-num-seqs $5 --enforce-eager
e.g.:
python -m vllm.entrypoints.openai.api_server --model=/weightpath/w4a8_4_layer --trust-remote-code -tp 4 --enable_expert_parallel --quantization ascend --port 8002 --max-model-len 2048 --max-num-seqs 128 --enforce-eager

DP + TP + EP:
python -m vllm.entrypoints.openai.api_server --model=$1 --trust-remote-code -tp $2 -dp $3 --enable_expert_parallel --quantization ascend --port $4 --max-model-len $5 --max-num-seqs $6 --enforce-eager
e.g.:
python -m vllm.entrypoints.openai.api_server --model=/weightpath/w4a8_4_layer --trust-remote-code -tp 2 -dp 2 --enable_expert_parallel --quantization ascend --port 8002 --max-model-len 2048 --max-num-seqs 128 --enforce-eager

#### 3. Usage constraints
export VLLM_USE_V1=1  # v1

---------

Signed-off-by: pichangping <1337510399@qq.com>
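The adaptation steps above can be scripted instead of edited by hand. A minimal sketch, assuming the quantized weights landed in a hypothetical `/path/to/w4a8_weights` directory and that `group_size` is a top-level key of the description file, as the instructions above imply:

```python
import json
import os

weight_dir = "/path/to/w4a8_weights"  # hypothetical output dir of quant_deepseek_w4a8.py

# Rename the quant description file and set the MoE group size to 256.
old_desc = os.path.join(weight_dir, "quant_model_description_w8a8_dynamic.json")
new_desc = os.path.join(weight_dir, "quant_model_description.json")
os.rename(old_desc, new_desc)
with open(new_desc) as f:
    desc = json.load(f)
desc["group_size"] = 256  # generated file carries "group_size": 0
with open(new_desc, "w") as f:
    json.dump(desc, f, indent=2, ensure_ascii=False)

# Patch config.json: deepseek_v3 model type, and drop quantization_config.
cfg_path = os.path.join(weight_dir, "config.json")
with open(cfg_path) as f:
    cfg = json.load(f)
cfg["model_type"] = "deepseek_v3"  # was "deepseekv2" in the MindIE-format output
cfg.pop("quantization_config", None)
with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=2, ensure_ascii=False)
```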
1 parent f1353d5 commit 53ce4a0

7 files changed, +571 -1 lines changed


.github/workflows/vllm_ascend_test.yaml

Lines changed: 1 addition & 0 deletions
@@ -204,6 +204,7 @@ jobs:
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py
+          VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_w4a8_deepseek.py::test_deepseek_W4A8
           fi

       - name: Run vllm-project/vllm-ascend test on V0 engine

tests/multicard/test_offline_inference_distributed.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
 import os
 from unittest.mock import patch

+import pytest
 from modelscope import snapshot_download  # type: ignore
 from vllm import SamplingParams
 from vllm.model_executor.models.registry import ModelRegistry
@@ -104,6 +105,7 @@ def test_models_distributed_DeepSeek_dbo():
         vllm_model.generate(example_prompts, sampling_params)


+@pytest.mark.skip(reason="Due to OOM,waiting for 1311pr to merge in")
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
 def test_models_distributed_DeepSeekV3_dbo():
     example_prompts = ["The president of the United States is"] * 41

tests/multicard/test_w4a8_deepseek.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import pytest
+
+from tests.conftest import VllmRunner
+
+
+@pytest.mark.skip(reason="Due to OOM,waiting for 1311pr to merge in")
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
+                    reason="w4a8_dynamic is not supported on v0")
+def test_deepseek_W4A8(monkeypatch: pytest.MonkeyPatch):
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        prompts = [
+            "Hello, my name is",
+            "The president of the United States is",
+            "The capital of France is",
+            "The future of AI is",
+        ]
+        dtype = "bfloat16"
+        max_tokens = 5
+        with VllmRunner(
+                "vllm-ascend/DeepSeek-R1-w4a8-pruning",
+                dtype=dtype,
+                tensor_parallel_size=2,
+                enforce_eager=True,
+                quantization="ascend",
+                enable_expert_parallel=True,
+                additional_config={
+                    "torchair_graph_config": {
+                        "enabled": False,
+                    },
+                    "ascend_scheduler_config": {
+                        "enabled": True,
+                    }
+                },
+        ) as vllm_model:
+            # use greedy sampler to make sure the generated results are fix
+            vllm_output = vllm_model.generate_greedy(prompts, max_tokens)
+
+        golden_results = [
+            'Hello, my name is逸研究发现IPPudsimentary',
+            'The president of the United States is逸 Ban Corporealistically',
+            'The capital of France is逸 Ban Corporealistically',
+            'The future of AI is逸 Ban Corporealistically',
+        ]
+        assert len(golden_results) == len(vllm_output)
+        for i in range(len(vllm_output)):
+            assert golden_results[i] == vllm_output[i][1]
+            print(f"Generated text: {vllm_output[i][1]!r}")

vllm_ascend/models/deepseek_v2.py

Lines changed: 9 additions & 1 deletion
@@ -25,7 +25,7 @@
 # # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
 # """Inference-only DeepseekV2/DeepseekV3 model."""

-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

 import torch
 import torch_npu
@@ -765,6 +765,14 @@ def forward(
                                       inputs_embeds)
         return hidden_states

+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        weights = filter(lambda x: ".module." not in x[0], weights)
+        # weights = ((name, data) for name, data in weights if ".module." not in name)
+        loaded_params = super().load_weights(weights)
+
+        return loaded_params
+

 class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
     pass
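A toy illustration of what the new load_weights override filters out before delegating to the stock loader. The tensor names below are hypothetical and only show which entries get dropped:

```python
# Hypothetical checkpoint entries; real MindIE-format weights carry their own names.
weights = [
    ("model.layers.3.mlp.experts.w13_weight", "tensor A"),
    ("model.layers.3.mlp.experts.w13_weight.module.weight_scale", "tensor B"),
]

# Same predicate as the override: skip any entry whose name contains ".module.".
kept = [name for name, _ in weights if ".module." not in name]
print(kept)  # ['model.layers.3.mlp.experts.w13_weight']
```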

vllm_ascend/quantization/quant_config.py

Lines changed: 4 additions & 0 deletions
@@ -305,6 +305,10 @@ def create_weights(
             param = torch.nn.Parameter(param_value, requires_grad=False)
             layer.register_parameter(param_key, param)
             set_weight_attrs(param, extra_weight_attrs)
+            if "weight_scale_second" in param_key or "weight_offset_second" in param_key:
+                setattr(param, "quant_method",
+                        FusedMoeWeightScaleSupported.GROUP.value)
+                param.quant_method = FusedMoeWeightScaleSupported.GROUP.value

     def apply(
         self,
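Note: `FusedMoeWeightScaleSupported` comes from vLLM's fused MoE layer; tagging the second-level scale/offset parameters with its GROUP value is presumably what makes the MoE weight loader treat them as group-wise quantization metadata, consistent with the `"group_size": 256` adaptation described in the PR message above.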

vllm_ascend/quantization/quantizer.py

Lines changed: 9 additions & 0 deletions
@@ -24,6 +24,7 @@

 from .func_wrapper import (wrapper_load_model, wrapper_rmsnorm_forward_oot,
                            wrapper_rmsnorm_init)
+from .w4a8_dynamic import AscendW4A8DynamicFusedMoEMethod
 from .w8a8 import AscendW8A8LinearMethod
 from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
                            AscendW8A8DynamicLinearMethod)
@@ -281,7 +282,15 @@ def build_moe_method():
         return AscendW8A8DynamicFusedMoEMethod()


+class W4A8DYNAMICQuantizer(VLLMAscendQuantizer):
+
+    @staticmethod
+    def build_moe_method():
+        return AscendW4A8DynamicFusedMoEMethod()
+
+
 SUPPORT_ASCEND_QUANTIZER_TYPE = {
     "W8A8": W8A8Quantizer,
     "W8A8_DYNAMIC": W8A8DYNAMICQuantizer,
+    "W4A8_DYNAMIC": W4A8DYNAMICQuantizer
 }
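With the registry entry above, a checkpoint whose quantization description reports `W4A8_DYNAMIC` resolves to the new MoE method. A minimal lookup sketch (assumes vllm-ascend and its NPU dependencies are importable; in a real run the selection happens inside the quantization machinery, not in user code):

```python
from vllm_ascend.quantization.quantizer import SUPPORT_ASCEND_QUANTIZER_TYPE

# "W4A8_DYNAMIC" now maps to W4A8DYNAMICQuantizer, whose build_moe_method()
# returns AscendW4A8DynamicFusedMoEMethod for the MoE layers.
quantizer_cls = SUPPORT_ASCEND_QUANTIZER_TYPE["W4A8_DYNAMIC"]
print(quantizer_cls.__name__)  # W4A8DYNAMICQuantizer
```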
