
Commit 77ffc22

[V1][BUGFIX][0.10.1] FIX mtp on main branch
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
1 parent 6f1047d commit 77ffc22
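
In brief: the MTP proposer now loads the torchair variant of the DeepSeek MTP draft model when the torchair graph is enabled, the torchair fused-MoE layer falls back to the unquantized method for layers that a mixed-quantization config skips, and the commit adds an end-to-end MTP/torchair correctness test plus unit tests covering the mixed-quant path.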

File tree

5 files changed: +131 -3 lines changed

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+import os
+
+import pytest
+from vllm import SamplingParams
+
+from tests.e2e.conftest import VllmRunner
+
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+
+@pytest.fixture
+def sampling_config():
+    return SamplingParams(temperature=0, max_tokens=256, ignore_eos=False)
+
+
+@pytest.fixture
+def model_name():
+    return "wemaster/deepseek_mtp_main_random_bf16"
+
+
+def test_mtp_torchair_correctness(
+        sampling_config: SamplingParams,
+        model_name: str,
+):
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    '''
+    Compare the outputs of the original LLM and the speculative LLM;
+    they should be the same when using MTP speculative decoding.
+    '''
+    with VllmRunner(model_name,
+                    tensor_parallel_size=1,
+                    gpu_memory_utilization=0.7,
+                    max_model_len=256,
+                    enforce_eager=False,
+                    additional_config={
+                        "torchair_graph_config": {
+                            "enabled": True,
+                            "use_cached_graph": False,
+                            "graph_batch_sizes": [1, 2, 4],
+                        },
+                    }) as ref_llm:
+        ref_outputs = ref_llm.generate(example_prompts, sampling_config)
+
+    with VllmRunner(model_name,
+                    tensor_parallel_size=1,
+                    max_num_seqs=256,
+                    gpu_memory_utilization=0.7,
+                    distributed_executor_backend="mp",
+                    enable_expert_parallel=True,
+                    speculative_config={
+                        "method": "deepseek_mtp",
+                        "num_speculative_tokens": 1,
+                    },
+                    enforce_eager=False,
+                    max_model_len=2000,
+                    additional_config={
+                        "torchair_graph_config": {
+                            "enabled": True,
+                            "use_cached_graph": False,
+                            "graph_batch_sizes": [1, 2, 4],
+                        }
+                    }) as spec_llm:
+        spec_outputs = spec_llm.generate(example_prompts, sampling_config)
+
+    matches = 0
+    misses = 0
+    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+        ref_token_ids = ref_output[0][0]
+        spec_token_ids = spec_output[0][0]
+        if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
+            matches += 1
+        else:
+            misses += 1
+            print(f"ref_output: {ref_output[1][0]}")
+            print(f"spec_output: {spec_output[1][0]}")
+
+    # Heuristic: expect at least 66% of the prompts to match exactly
+    # Upon failure, inspect the outputs to check for inaccuracy.
+    assert matches > int(0.66 * len(ref_outputs))
+    del spec_llm
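
This new e2e test compares greedy outputs from a plain run and an MTP speculative run of the same DeepSeek MTP checkpoint, both with the torchair graph enabled, and requires at least two thirds of the prompts to match exactly. A minimal sketch of a local invocation, assuming the file is saved somewhere pytest can collect it (its exact path is not shown in this view):

    # Sketch only: run just the new test by name. The test file's path is not
    # shown in this commit view, so rely on pytest's name-based selection.
    import pytest

    if __name__ == "__main__":
        raise SystemExit(
            pytest.main(["-s", "-k", "test_mtp_torchair_correctness"]))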

tests/ut/ops/test_fused_ops.py

Lines changed: 14 additions & 0 deletions
@@ -30,6 +30,7 @@
                                     AscendUnquantizedFusedMoEMethod,
                                     unified_apply_mlp)
 from vllm_ascend.ops.layers.experts_selector import select_experts
+from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
 from vllm_ascend.utils import AscendSocVersion, adapt_patch

 adapt_patch(True)
@@ -319,6 +320,19 @@ def test_init_with_quant(self, mock_dist_env, default_moe_config):
         assert moe.quant_method is not None
         assert moe.quant_method == mock_quant_method

+    def test_init_with_mixed_quant(self, mock_dist_env, default_moe_config):
+        mock_quant_config = MagicMock()
+        mock_quant_method = MockFusedMoEMethod()
+        mock_quant_config.get_quant_method.return_value = mock_quant_method
+        mock_quant_config.is_layer_skipped_ascend.return_value = True
+
+        quantized_moe = AscendFusedMoE(**default_moe_config,
+                                       quant_config=mock_quant_config)
+
+        assert quantized_moe.quant_method is not None
+        assert isinstance(quantized_moe.quant_method,
+                          AscendUnquantizedFusedMoEMethod)
+
     @pytest.mark.parametrize(
         "others_param",
         [[None,

tests/ut/torchair/ops/test_torchair_fused_moe.py

Lines changed: 14 additions & 0 deletions
@@ -23,6 +23,7 @@
 from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase

 from vllm_ascend.ascend_forward_context import _get_fused_moe_state
+from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
 from vllm_ascend.torchair.ops.torchair_fused_moe import (
     TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod)
 from vllm_ascend.utils import AscendSocVersion, adapt_patch  # noqa E402
@@ -240,6 +241,19 @@ def test_init_with_quant(self, mock_dist_env, default_moe_config):
         assert moe.quant_method is not None
         assert moe.quant_method == mock_quant_method

+    def test_init_with_mixed_quant(self, mock_dist_env, default_moe_config):
+        mock_quant_config = MagicMock()
+        mock_quant_method = MockFusedMoEMethod()
+        mock_quant_config.get_quant_method.return_value = mock_quant_method
+        mock_quant_config.is_layer_skipped_ascend.return_value = True
+
+        moe = TorchairAscendFusedMoE(**default_moe_config,
+                                     quant_config=mock_quant_config)
+
+        assert moe.quant_method is not None
+        assert isinstance(moe.quant_method,
+                          TorchairAscendUnquantizedFusedMoEMethod)
+
     @pytest.mark.parametrize(
         "others_param",
         [[None,

vllm_ascend/torchair/ops/torchair_fused_moe.py

Lines changed: 8 additions & 1 deletion
@@ -48,6 +48,7 @@
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
     MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
 from vllm_ascend.ops.sequence_parallel import MetadataForPadding
+from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
 from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                                get_all_reduce_merge_state,
@@ -1291,7 +1292,13 @@ def __init__(
             self.quant_method = TorchairAscendUnquantizedFusedMoEMethod(
                 self.moe)
         else:
-            self.quant_method = quant_config.get_quant_method(self, prefix)
+            if quant_config.is_layer_skipped_ascend(
+                    prefix, quant_config.packed_modules_mapping):
+                self.quant_method = TorchairAscendUnquantizedFusedMoEMethod(
+                    self.moe)
+            else:
+                self.quant_method = AscendFusedMoEMethod(
+                    quant_config, prefix, quant_config.packed_modules_mapping)

         assert self.quant_method is not None
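
For mixed-quantization models, the torchair MoE layer previously delegated to quant_config.get_quant_method(self, prefix) even when the quantization config skips the layer; it now checks is_layer_skipped_ascend() and falls back to the unquantized fused-MoE method in that case, wrapping fully quantized layers in AscendFusedMoEMethod. The same selection rule, pulled out of __init__ as a standalone sketch (the helper name and explicit class parameters are illustrative, not part of this commit):

    # Sketch of the quant-method selection introduced above; the helper name
    # `resolve_moe_quant_method` is hypothetical, the branch logic mirrors the hunk.
    def resolve_moe_quant_method(moe, quant_config, prefix,
                                 unquantized_cls, quantized_cls):
        if quant_config is None:
            # No quantization configured: always use the unquantized kernels.
            return unquantized_cls(moe)
        if quant_config.is_layer_skipped_ascend(
                prefix, quant_config.packed_modules_mapping):
            # Mixed-quant checkpoint: this layer is left unquantized, so it
            # must not receive a quantized method.
            return unquantized_cls(moe)
        # Fully quantized layer: wrap it with the Ascend quantized MoE method.
        return quantized_cls(quant_config, prefix,
                             quant_config.packed_modules_mapping)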

vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 8 additions & 2 deletions
@@ -18,6 +18,8 @@
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
+from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
+    TorchairDeepSeekMTP
 from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
 from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable

@@ -266,8 +268,12 @@ def load_model(self) -> None:
         with set_default_torch_dtype(
                 draft_model_config.dtype), set_current_vllm_config(
                     self.vllm_config):
-            self.model = CustomDeepSeekMTP(
-                vllm_config=self.vllm_config).to(target_device)
+            if self.torchair_graph_enabled:
+                self.model = TorchairDeepSeekMTP(
+                    vllm_config=self.vllm_config).to(target_device)
+            else:
+                self.model = CustomDeepSeekMTP(
+                    vllm_config=self.vllm_config).to(target_device)

         draft_attn_layer_names = (
             get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
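
The draft model is now chosen per backend: when self.torchair_graph_enabled is set, the proposer loads TorchairDeepSeekMTP instead of CustomDeepSeekMTP. Judging by the e2e test added in this commit, that path is exercised by enabling the torchair graph in additional_config together with an MTP speculative_config. A minimal configuration sketch reusing the settings from the new test (VllmRunner comes from tests/e2e/conftest.py; this is an illustration, not a new API):

    # Sketch: configuration that exercises the TorchairDeepSeekMTP branch above,
    # copied from the speculative run in the new e2e test.
    from vllm import SamplingParams

    from tests.e2e.conftest import VllmRunner

    with VllmRunner("wemaster/deepseek_mtp_main_random_bf16",
                    tensor_parallel_size=1,
                    gpu_memory_utilization=0.7,
                    max_model_len=2000,
                    enforce_eager=False,
                    # MTP speculative decoding with one draft token per step.
                    speculative_config={
                        "method": "deepseek_mtp",
                        "num_speculative_tokens": 1,
                    },
                    # Enabling the torchair graph is what routes the draft model
                    # to TorchairDeepSeekMTP in load_model() above.
                    additional_config={
                        "torchair_graph_config": {
                            "enabled": True,
                            "use_cached_graph": False,
                            "graph_batch_sizes": [1, 2, 4],
                        },
                    }) as llm:
        outputs = llm.generate(["The capital of France is"],
                               SamplingParams(temperature=0, max_tokens=64))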
