Commit df3bcf9

[Feat] Adapted mtp function to Qwen3-next
Signed-off-by: drslark <slarksblood@qq.com>
1 parent fcc9a0e commit df3bcf9

File tree

10 files changed: +254 -13 lines

tests/e2e/multicard/test_qwen3_next.py

Lines changed: 66 additions & 0 deletions
@@ -36,3 +36,69 @@ def test_models_distributed_Qwen3_NEXT_TP4():
                     distributed_executor_backend="mp",
                     enforce_eager=True) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+    del vllm_model
+
+
+def test_models_distributed_Qwen3_NEXT_MTP_TP4():
+    example_prompts = [
+        "Hello, my name is",
+    ] * 4
+    max_tokens = 5
+    with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
+                    tensor_parallel_size=4,
+                    max_model_len=4096,
+                    gpu_memory_utilization=0.8,
+                    distributed_executor_backend="mp",
+                    speculative_config={
+                        "method": "qwen3_next_mtp",
+                        "num_speculative_tokens": 1
+                    }) as spec_vllm_model:
+        spec_vllm_model.generate_greedy(example_prompts, max_tokens)
+    del spec_vllm_model
+
+
+def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY():
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    max_tokens = 20
+
+    with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
+                    tensor_parallel_size=4,
+                    max_model_len=4096,
+                    gpu_memory_utilization=0.8,
+                    distributed_executor_backend="mp",
+                    enforce_eager=True) as vllm_model:
+        ref_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+    del vllm_model
+
+    with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
+                    tensor_parallel_size=4,
+                    max_model_len=4096,
+                    gpu_memory_utilization=0.8,
+                    distributed_executor_backend="mp",
+                    speculative_config={
+                        "method": "qwen3_next_mtp",
+                        "num_speculative_tokens": 1
+                    },
+                    enforce_eager=True) as spec_vllm_model:
+        spec_outputs = spec_vllm_model.generate_greedy(example_prompts,
+                                                       max_tokens)
+    del spec_vllm_model
+
+    matches = 0
+    misses = 0
+    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+        ref_token_ids = ref_output[0]
+        spec_token_ids = spec_output[0]
+        if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
+            matches += 1
+        else:
+            misses += 1
+            print(f"ref_output: {ref_output[1]}")
+            print(f"spec_output: {spec_output[1]}")
+
+    assert matches > int(0.66 * len(ref_outputs))
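For reference only, the same speculative configuration exercised by these tests can be passed to vLLM's offline entry point. This is a minimal sketch, assuming the standard vLLM Python API accepts a speculative_config dict and that four NPUs are available; the model path and memory settings simply mirror the test above and are not part of this commit.

from vllm import LLM, SamplingParams

# Hedged offline-inference sketch mirroring the e2e test's speculative_config.
llm = LLM(
    model="Qwen/Qwen3-Next-80B-A3B-Instruct",
    tensor_parallel_size=4,
    max_model_len=4096,
    gpu_memory_utilization=0.8,
    speculative_config={
        "method": "qwen3_next_mtp",      # dispatched to MtpProposer (see vllm_ascend/spec_decode/__init__.py)
        "num_speculative_tokens": 1,     # one draft token per step, as in the tests
    },
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)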

tests/ut/attention/test_attention_v1.py

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):

     def setUp(self):
         self.mock_vllm_config = MagicMock()
+        self.mock_vllm_config.speculative_config = None
         self.mock_vllm_config.model_config.max_model_len = 640
         self.mock_vllm_config.cache_config.block_size = 64
         self.mock_device = 'cpu:0'

vllm_ascend/attention/attention_v1.py

Lines changed: 11 additions & 0 deletions
@@ -237,6 +237,17 @@ def __init__(
             self.model_config.max_model_len,
             AscendAttentionBackend.get_supported_block_size()[0])

+        self.speculative_config = vllm_config.speculative_config
+        self.decode_threshold = 1
+        if self.speculative_config:
+            spec_token_num = self.speculative_config.num_speculative_tokens
+            self.decode_threshold += spec_token_num
+            assert self.decode_threshold <= 16, f"decode_threshold exceeded \
+                npu_fused_infer_attention_score TND layout's limit of 16, \
+                got {self.decode_threshold}"
+
+        AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold
+
     def reorder_batch(self, input_batch,
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
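The threshold arithmetic above is simple: a decode step carries one verified token plus the configured draft tokens, and that count becomes the builder's reorder_batch_threshold, capped at 16 by the npu_fused_infer_attention_score TND layout (so at most 15 speculative tokens). A standalone restatement, not code from this commit:

def compute_decode_threshold(num_speculative_tokens=None):
    # Mirrors the __init__ logic above: 1 verified token plus any draft tokens.
    threshold = 1
    if num_speculative_tokens:
        threshold += num_speculative_tokens
    # The TND layout handles at most 16 tokens per request in a decode step.
    assert threshold <= 16, f"decode_threshold {threshold} exceeds the TND limit of 16"
    return threshold

print(compute_decode_threshold(None))  # 1 -> plain decode
print(compute_decode_threshold(1))     # 2 -> qwen3_next_mtp with one draft token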

vllm_ascend/models/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -35,6 +35,10 @@ def register_model():
         "PanguProMoEForCausalLM",
         "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
     )
+
     ModelRegistry.register_model(
         "Qwen3NextForCausalLM",
         "vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM")
+
+    ModelRegistry.register_model(
+        "Qwen3NextMTP", "vllm_ascend.models.qwen3_next_mtp:CustomQwen3NextMTP")

vllm_ascend/models/qwen3_next.py

Lines changed: 18 additions & 0 deletions
@@ -260,6 +260,24 @@ def _forward(
             mixed_qkv_spec = None
             mixed_qkv_non_spec = mixed_qkv

+        # 2.1: process the multi-query part
+        if spec_sequence_masks is not None:
+            mixed_qkv_spec = mixed_qkv_spec.view(
+                attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
+            mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
+            mixed_qkv_spec = causal_conv1d_update(
+                mixed_qkv_spec,
+                conv_state,
+                conv_weights,
+                self.conv1d.bias,
+                self.activation,
+                conv_state_indices=spec_state_indices_tensor[:, 0]
+                [:attn_metadata.num_spec_decodes],
+                num_accepted_tokens=num_accepted_tokens,
+                validate_data=False,
+            )
+            mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')
+
         # 2.2: process the remaining part
         if attn_metadata.num_prefills > 0:
             # - "cache_indices" updates the conv_state cache in positions
vllm_ascend/models/qwen3_next_mtp.py (new file)

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Inference-only Qwen3Next MTP model."""
+import torch
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.linear import ColumnParallelLinear
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.models.interfaces import SupportsPP
+from vllm.model_executor.models.qwen3_next_mtp import (
+    Qwen3NextMTP, Qwen3NextMultiTokenPredictor)
+from vllm.model_executor.models.utils import (
+    make_empty_intermediate_tensors_factory, maybe_prefix)
+from vllm.transformers_utils.configs import Qwen3NextConfig
+
+from vllm_ascend.models.qwen3_next import (CustomQwen3NextDecoderLayer,
+                                           Qwen3NextRMSNorm)
+
+
+@support_torch_compile
+class CustomQwen3NextMultiTokenPredictor(Qwen3NextMultiTokenPredictor):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super(Qwen3NextMultiTokenPredictor, self).__init__()
+
+        model_config = vllm_config.model_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+        config: Qwen3NextConfig = model_config.hf_config
+
+        self.config = config
+        lora_vocab = ((lora_config.lora_extra_vocab_size *
+                       (lora_config.max_loras or 1)) if lora_config else 0)
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+
+        self.mtp_start_layer_idx = config.num_hidden_layers
+        self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)
+
+        self.embed_tokens = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+
+        self.fc = ColumnParallelLinear(self.config.hidden_size * 2,
+                                       self.config.hidden_size,
+                                       gather_output=True,
+                                       bias=False,
+                                       return_bias=False,
+                                       quant_config=quant_config,
+                                       prefix=f'{prefix}.fc')
+
+        # use old-version mtp layer names to avoid an exception in vllm
+        self.layers = torch.nn.ModuleList(
+            CustomQwen3NextDecoderLayer(
+                vllm_config,
+                layer_type="full_attention",
+                prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
+            ) for idx in range(self.num_mtp_layers))
+
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
+
+        self.norm = Qwen3NextRMSNorm(config.hidden_size,
+                                     eps=config.rms_norm_eps)
+        self.pre_fc_norm_hidden = Qwen3NextRMSNorm(config.hidden_size,
+                                                   eps=config.rms_norm_eps)
+        self.pre_fc_norm_embedding = Qwen3NextRMSNorm(config.hidden_size,
+                                                      eps=config.rms_norm_eps)
+
+
+@support_torch_compile
+class CustomQwen3NextMTP(Qwen3NextMTP, SupportsPP):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": ["up_proj", "down_proj"]
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        config = vllm_config.model_config.hf_config
+        self.vllm_config = vllm_config
+        cache_config = vllm_config.cache_config
+        assert not cache_config.enable_prefix_caching, \
+            "Qwen3NextMTP currently does not support prefix caching"
+
+        self.quant_config = vllm_config.quant_config
+
+        super(Qwen3NextMTP, self).__init__()
+        self.config = config
+        self.model = CustomQwen3NextMultiTokenPredictor(
+            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
+        self.unpadded_vocab_size = config.vocab_size
+        self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
+                                      config.hidden_size,
+                                      org_num_embeddings=config.vocab_size,
+                                      padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+                                      prefix=maybe_prefix(prefix, "lm_head"))
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
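The prefix arithmetic above determines which parameter names the draft layer is loaded under: mtp_start_layer_idx equals the target model's layer count, so the MTP layer's weights appear as model.layers.{num_hidden_layers}.*. A standalone arithmetic sketch, with an assumed layer count (the real value comes from the checkpoint's config.num_hidden_layers):

num_hidden_layers = 48          # assumed; read from config.num_hidden_layers in practice
num_mtp_layers = 1              # config.num_nextn_predict_layers defaults to 1 here

mtp_start_layer_idx = num_hidden_layers
prefixes = [
    f"model.layers.{mtp_start_layer_idx + idx}"
    for idx in range(num_mtp_layers)
]
print(prefixes)                 # ['model.layers.48'] -> the "old-version mtp layer name"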

vllm_ascend/ops/casual_conv1d.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def causal_conv1d_ref(
         final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
             dtype_in)  # (batch, dim, width - 1)
         if final_states_out is not None:
-            final_states_out.copy_(final_states)
+            final_states_out[..., :(width - 1)].copy_(final_states)
         else:
             final_states_out = final_states
     out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
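The one-line fix narrows the copy to the first width - 1 positions of the destination, which matters when the caller's state buffer is wider than the reference padding. A toy sketch with assumed sizes, independent of this repo:

import torch
import torch.nn.functional as F

batch, dim, width = 1, 2, 4
x = torch.randn(batch, dim, 2)                           # only 2 cached inputs so far
final_states = F.pad(x, (width - 1 - x.shape[-1], 0))    # -> (batch, dim, width - 1)

final_states_out = torch.zeros(batch, dim, width)        # caller's buffer is wider than width - 1
# Old: final_states_out.copy_(final_states) would raise on the shape mismatch.
final_states_out[..., :(width - 1)].copy_(final_states)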

vllm_ascend/spec_decode/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,8 @@ def get_spec_decode_method(method,
         if is_torchair_graph:
             return TorchairMtpProposer(vllm_config, device, runner)
         return MtpProposer(vllm_config, device, runner)
+    elif method == 'qwen3_next_mtp':
+        return MtpProposer(vllm_config, device, runner)
     else:
         raise ValueError("Unknown speculative decoding method: "
                          f"{method}")

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 37 additions & 6 deletions
@@ -1,3 +1,4 @@
+import importlib
 from typing import Optional

 import numpy as np
@@ -12,7 +13,6 @@
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.utils import \
     process_weights_after_loading
-from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
@@ -41,6 +41,26 @@

 PADDING_SLOT_ID = -1

+_MTP_MODELS = {
+    "DeepseekV3ForCausalLM":
+    ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
+    "Qwen3NextForCausalLM":
+    ("vllm_ascend.models.qwen3_next_mtp", "CustomQwen3NextMTP")
+}
+
+_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
+
+_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
+
+
+def _load_model(architecture):
+    if architecture not in _MTP_MODELS:
+        raise ValueError("Invalid architecture for mtp.")
+    module_name, model_name = _MTP_MODELS[architecture]
+    module = importlib.import_module(module_name)
+    model = getattr(module, model_name)
+    return model
+

 class MtpProposer(Proposer):

@@ -146,9 +166,7 @@ def load_model(self, model) -> None:
         with set_default_torch_dtype(
                 draft_model_config.dtype), set_current_vllm_config(
                     self.vllm_config):
-            self.model = DeepSeekMTP(
-                vllm_config=self.vllm_config).to(target_device)
-
+            self._init_mtp_model()
         draft_attn_layer_names = (get_layers_from_vllm_config(
             self.vllm_config, AttentionLayerBase).keys() -
                                   target_attn_layer_names)
@@ -217,8 +235,7 @@ def generate_token_ids(self,
                           attn_metadata=None,
                           aux_hidden_states: torch.Tensor = None):
        common_attn_metadata = self.runner.spec_decode_common_attn_metadata
-        if attn_metadata is not None and isinstance(attn_metadata, dict):
-            attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
+        attn_metadata = self._get_attn_metadata(attn_metadata)

        if self.speculative_config.disable_padded_drafter_batch:
            # When padded-batch is disabled, the sampled_token_ids should be
@@ -300,6 +317,20 @@ def generate_token_ids(self,

         return draft_token_ids

+    def _init_mtp_model(self):
+        architecture = self.vllm_config.model_config.architecture
+        target_device = self.vllm_config.device_config.device
+        model = _load_model(architecture)
+        self.model = model(vllm_config=self.vllm_config).to(target_device)
+
+    def _get_attn_metadata(self, attn_metadata):
+        if attn_metadata is not None and isinstance(attn_metadata, dict):
+            architecture = self.vllm_config.model_config.architecture
+            layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
+            attn_metadata = attn_metadata[layer_name]
+
+        return attn_metadata
+
     def _prepare_inputs(
         self,
         common_attn_metadata: CommonAttentionMetadata,
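The _MTP_MODELS/_load_model pair is a small importlib-based dispatch from target architecture to draft-model class. The sketch below restates the pattern standalone, using stdlib classes as stand-ins so it runs without vLLM installed; the registry entries here are placeholders, not the real ones above:

import importlib

# Stand-in registry mirroring the shape of _MTP_MODELS.
_DEMO_MODELS = {
    "DeepseekV3ForCausalLM": ("collections", "OrderedDict"),
    "Qwen3NextForCausalLM": ("collections", "Counter"),
}

def load_demo_model(architecture):
    # Same control flow as _load_model: validate, import the module, fetch the class.
    if architecture not in _DEMO_MODELS:
        raise ValueError("Invalid architecture for mtp.")
    module_name, class_name = _DEMO_MODELS[architecture]
    return getattr(importlib.import_module(module_name), class_name)

print(load_demo_model("Qwen3NextForCausalLM"))  # <class 'collections.Counter'>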

vllm_ascend/worker/model_runner_v1.py

Lines changed: 5 additions & 6 deletions
@@ -1851,7 +1851,7 @@ def _prepare_inputs(
                 extra_attn_metadata_args = dict(
                     num_accepted_tokens=self.num_accepted_tokens.
                     gpu[:num_reqs],
-                    num_draft_tokens=self.num_draft_tokens.
+                    num_decode_draft_tokens_cpu=self.num_draft_tokens.
                     gpu[:num_reqs],
                 )
             attn_metadata_i = builder.build(
@@ -1944,11 +1944,10 @@ def _build_attn_state(self, num_reqs, num_scheduled_tokens,
             attn_state = AscendAttentionState.SpecDecoding
         # Speculative decoding.
         elif np.all(num_valid_tokens == 1):
-            if self.drafter and (self.drafter.name == SpecDcodeType.EAGLE
-                                 or self.drafter.name == SpecDcodeType.EAGLE3):
-                attn_state = AscendAttentionState.ChunkedPrefill
-            else:
+            if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
                 attn_state = AscendAttentionState.SpecDecoding
+            else:
+                attn_state = AscendAttentionState.ChunkedPrefill
         # splitfuse
         elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
             attn_state = AscendAttentionState.ChunkedPrefill
@@ -2544,7 +2543,7 @@ def propose_draft_token_ids(sampled_token_ids):
         with ProfileExecuteDuration().capture_async("Draft"):
             if self.speculative_config:
                 use_padded_batch_for_eagle = self.speculative_config and \
-                    self.speculative_config.method == "deepseek_mtp" and \
+                    self.speculative_config.method in ("deepseek_mtp", "qwen3_next_mtp") and \
                     not self.speculative_config.disable_padded_drafter_batch
                 if use_padded_batch_for_eagle:
                     # EAGLE speculative decoding can use the GPU sampled tokens
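Once every request carries a single valid token, the reworked branch in _build_attn_state reduces to a simple rule: only the deepseek_mtp drafter keeps the SpecDecoding state, while everything else (eagle, eagle3, and the new qwen3_next_mtp) falls back to ChunkedPrefill. A minimal restatement with a stand-in enum (names assumed, not the real AscendAttentionState):

from enum import Enum, auto

class AttnState(Enum):            # stand-in for AscendAttentionState
    SpecDecoding = auto()
    ChunkedPrefill = auto()

def state_for_single_valid_token(speculative_method=None):
    # Mirrors the new branch: only deepseek_mtp stays on the SpecDecoding path here.
    if speculative_method == "deepseek_mtp":
        return AttnState.SpecDecoding
    return AttnState.ChunkedPrefill

print(state_for_single_valid_token("deepseek_mtp"))    # AttnState.SpecDecoding
print(state_for_single_valid_token("qwen3_next_mtp"))  # AttnState.ChunkedPrefill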
