
Commit 810d097

[Feat] Adapted mtp function to Qwen3-next
Signed-off-by: drslark <slarksblood@qq.com>
1 parent fcc9a0e commit 810d097

File tree

9 files changed: +168 additions, -11 deletions

tests/ut/attention/test_attention_v1.py

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
 
     def setUp(self):
         self.mock_vllm_config = MagicMock()
+        self.mock_vllm_config.speculative_config = None
         self.mock_vllm_config.model_config.max_model_len = 640
         self.mock_vllm_config.cache_config.block_size = 64
         self.mock_device = 'cpu:0'
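
A side note on why the explicit None matters (a minimal sketch, not part of the commit): MagicMock auto-creates truthy child mocks for any attribute, so without pinning speculative_config to None the builder's new speculative branch would be entered during this test.

    from unittest.mock import MagicMock

    cfg = MagicMock()
    # Unset attributes on a MagicMock are auto-created child mocks, which are truthy.
    assert bool(cfg.speculative_config) is True

    cfg.speculative_config = None
    # With the explicit None, `if self.speculative_config:` takes the non-speculative path.
    assert bool(cfg.speculative_config) is False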

vllm_ascend/attention/attention_v1.py

Lines changed: 11 additions & 0 deletions
@@ -237,6 +237,17 @@ def __init__(
             self.model_config.max_model_len,
             AscendAttentionBackend.get_supported_block_size()[0])
 
+        self.speculative_config = vllm_config.speculative_config
+        self.decode_threshold = 1
+        if self.speculative_config:
+            spec_token_num = self.speculative_config.num_speculative_tokens
+            self.decode_threshold += spec_token_num
+            assert self.decode_threshold <= 16, f"decode_threshold exceeded \
+                npu_fused_infer_attention_score TND layout's limit of 16, \
+                got {self.decode_threshold}"
+
+        AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold
+
     def reorder_batch(self, input_batch,
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
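
For orientation, a minimal sketch of what the new threshold means (hypothetical helper names, not part of the commit): with MTP enabled, a decode request carries the regular token plus up to num_speculative_tokens draft tokens, and the TND layout of npu_fused_infer_attention_score caps that total at 16.

    from typing import Optional

    def compute_decode_threshold(num_speculative_tokens: Optional[int]) -> int:
        # One token for the normal decode step, plus one per speculative draft token.
        threshold = 1 + (num_speculative_tokens or 0)
        assert threshold <= 16, "TND layout limit of npu_fused_infer_attention_score"
        return threshold

    def is_decode_request(num_scheduled_tokens: int, decode_threshold: int) -> bool:
        # Requests at or below the threshold are treated as decodes when reordering the batch.
        return num_scheduled_tokens <= decode_threshold

    assert compute_decode_threshold(None) == 1
    assert compute_decode_threshold(3) == 4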

vllm_ascend/models/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -35,6 +35,10 @@ def register_model():
         "PanguProMoEForCausalLM",
         "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
     )
+
     ModelRegistry.register_model(
         "Qwen3NextForCausalLM",
         "vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM")
+
+    ModelRegistry.register_model(
+        "Qwen3NextMTP", "vllm_ascend.models.qwen3_next_mtp:CustomQwen3NextMTP")

vllm_ascend/models/qwen3_next.py

Lines changed: 18 additions & 0 deletions
@@ -260,6 +260,24 @@ def _forward(
             mixed_qkv_spec = None
             mixed_qkv_non_spec = mixed_qkv
 
+        # 2.1: process the multi-query part
+        if spec_sequence_masks is not None:
+            mixed_qkv_spec = mixed_qkv_spec.view(
+                attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
+            mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
+            mixed_qkv_spec = causal_conv1d_update(
+                mixed_qkv_spec,
+                conv_state,
+                conv_weights,
+                self.conv1d.bias,
+                self.activation,
+                conv_state_indices=spec_state_indices_tensor[:, 0]
+                [:attn_metadata.num_spec_decodes],
+                num_accepted_tokens=num_accepted_tokens,
+                validate_data=False,
+            )
+            mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')
+
         # 2.2: process the remaining part
         if attn_metadata.num_prefills > 0:
             # - "cache_indices" updates the conv_state cache in positions
vllm_ascend/models/qwen3_next_mtp.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next MTP model."""
import torch
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.qwen3_next_mtp import (
    Qwen3NextMTP, Qwen3NextMultiTokenPredictor)
from vllm.model_executor.models.utils import (
    make_empty_intermediate_tensors_factory, maybe_prefix)
from vllm.transformers_utils.configs import Qwen3NextConfig

from vllm_ascend.models.qwen3_next import (CustomQwen3NextDecoderLayer,
                                           Qwen3NextRMSNorm)


@support_torch_compile
class CustomQwen3NextMultiTokenPredictor(Qwen3NextMultiTokenPredictor):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super(Qwen3NextMultiTokenPredictor, self).__init__()

        model_config = vllm_config.model_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        config: Qwen3NextConfig = model_config.hf_config

        self.config = config
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        self.fc = ColumnParallelLinear(self.config.hidden_size * 2,
                                       self.config.hidden_size,
                                       gather_output=True,
                                       bias=False,
                                       return_bias=False,
                                       quant_config=quant_config,
                                       prefix=f'{prefix}.fc')

        # use the old-version mtp layer name to avoid an exception in vllm
        self.layers = torch.nn.ModuleList(
            CustomQwen3NextDecoderLayer(
                vllm_config,
                layer_type="full_attention",
                prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
            ) for idx in range(self.num_mtp_layers))

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        self.norm = Qwen3NextRMSNorm(config.hidden_size,
                                     eps=config.rms_norm_eps)
        self.pre_fc_norm_hidden = Qwen3NextRMSNorm(config.hidden_size,
                                                   eps=config.rms_norm_eps)
        self.pre_fc_norm_embedding = Qwen3NextRMSNorm(config.hidden_size,
                                                      eps=config.rms_norm_eps)


@support_torch_compile
class CustomQwen3NextMTP(Qwen3NextMTP, SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["up_proj", "down_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        cache_config = vllm_config.cache_config
        assert not cache_config.enable_prefix_caching, \
            "Qwen3NextMTP currently does not support prefix caching"

        self.quant_config = vllm_config.quant_config

        super(Qwen3NextMTP, self).__init__()
        self.config = config
        self.model = CustomQwen3NextMultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
                                      config.hidden_size,
                                      org_num_embeddings=config.vocab_size,
                                      padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
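
To make the layer naming concrete (a small sketch with assumed numbers, not taken from the commit): the MTP layer is appended after the target model's decoder layers, so with, say, 48 hidden layers and a single predict layer the extra layer is registered as model.layers.48, which is the old-style name the comment above refers to.

    num_hidden_layers = 48          # assumed value, for illustration only
    num_nextn_predict_layers = 1    # typical single MTP head
    mtp_start_layer_idx = num_hidden_layers

    prefixes = [
        f"model.layers.{mtp_start_layer_idx + idx}"
        for idx in range(num_nextn_predict_layers)
    ]
    assert prefixes == ["model.layers.48"]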

vllm_ascend/ops/casual_conv1d.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def causal_conv1d_ref(
         final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
             dtype_in)  # (batch, dim, width - 1)
         if final_states_out is not None:
-            final_states_out.copy_(final_states)
+            final_states_out[..., :(width - 1)].copy_(final_states)
         else:
             final_states_out = final_states
     out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
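
Why the slice matters (a shape-only sketch with made-up sizes, assuming the caller may preallocate a wider state buffer): F.pad(x, (width - 1 - x.shape[-1], 0)) always yields a last dimension of width - 1, so copying into a buffer whose last dimension is larger must target only its first width - 1 columns.

    import torch
    import torch.nn.functional as F

    batch, dim, width, seqlen = 2, 3, 4, 2            # made-up sizes
    x = torch.randn(batch, dim, seqlen)
    final_states = F.pad(x, (width - 1 - x.shape[-1], 0))   # (batch, dim, width - 1)

    # Suppose the caller hands in a conv-state buffer with `width` slots instead of `width - 1`.
    final_states_out = torch.zeros(batch, dim, width)

    # final_states_out.copy_(final_states) would raise a size-mismatch error here;
    # copying into the leading width - 1 columns keeps the shapes aligned.
    final_states_out[..., :(width - 1)].copy_(final_states)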

vllm_ascend/spec_decode/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,8 @@ def get_spec_decode_method(method,
         if is_torchair_graph:
             return TorchairMtpProposer(vllm_config, device, runner)
         return MtpProposer(vllm_config, device, runner)
+    elif method == 'qwen3_next_mtp':
+        return MtpProposer(vllm_config, device, runner)
     else:
         raise ValueError("Unknown speculative decoding method: "
                          f"{method}")

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 17 additions & 4 deletions
@@ -146,9 +146,18 @@ def load_model(self, model) -> None:
         with set_default_torch_dtype(
                 draft_model_config.dtype), set_current_vllm_config(
                     self.vllm_config):
-            self.model = DeepSeekMTP(
-                vllm_config=self.vllm_config).to(target_device)
-
+            architecture = self.vllm_config.model_config.architecture
+            if architecture == "DeepseekV3ForCausalLM":
+                self.model = DeepSeekMTP(
+                    vllm_config=self.vllm_config).to(target_device)
+            elif architecture == "Qwen3NextForCausalLM":
+                # use lazy import to avoid a patch bug
+                from vllm_ascend.models.qwen3_next_mtp import \
+                    CustomQwen3NextMTP
+                self.model = CustomQwen3NextMTP(
+                    vllm_config=self.vllm_config).to(target_device)
+            else:
+                raise ValueError("Invalid architecture for mtp.")
         draft_attn_layer_names = (get_layers_from_vllm_config(
             self.vllm_config, AttentionLayerBase).keys() -
                                   target_attn_layer_names)

@@ -218,7 +227,11 @@ def generate_token_ids(self,
                            aux_hidden_states: torch.Tensor = None):
         common_attn_metadata = self.runner.spec_decode_common_attn_metadata
         if attn_metadata is not None and isinstance(attn_metadata, dict):
-            attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
+            architecture = self.vllm_config.model_config.architecture
+            if architecture == "Qwen3NextForCausalLM":
+                attn_metadata = attn_metadata['model.layers.3.self_attn.attn']
+            else:
+                attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
 
         if self.speculative_config.disable_padded_drafter_batch:
             # When padded-batch is disabled, the sampled_token_ids should be
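
A hedged note on the hard-coded index (the exact layer pattern is an assumption about the Qwen3-Next config, not taken from this diff): the model interleaves linear-attention layers with full-attention layers, so model.layers.0 has no self_attn.attn entry in the metadata dict and the proposer picks the first full-attention layer instead, e.g. index 3 if every fourth layer uses full attention.

    # Assumed pattern for illustration: three linear-attention layers followed by
    # one full-attention layer, repeated across the model.
    layer_types = (["linear_attention"] * 3 + ["full_attention"]) * 2

    first_full_attn = next(
        i for i, t in enumerate(layer_types) if t == "full_attention")
    attn_metadata_key = f"model.layers.{first_full_attn}.self_attn.attn"
    assert attn_metadata_key == "model.layers.3.self_attn.attn"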

vllm_ascend/worker/model_runner_v1.py

Lines changed: 5 additions & 6 deletions
@@ -1851,7 +1851,7 @@ def _prepare_inputs(
                 extra_attn_metadata_args = dict(
                     num_accepted_tokens=self.num_accepted_tokens.
                     gpu[:num_reqs],
-                    num_draft_tokens=self.num_draft_tokens.
+                    num_decode_draft_tokens_cpu=self.num_draft_tokens.
                     gpu[:num_reqs],
                 )
             attn_metadata_i = builder.build(

@@ -1944,11 +1944,10 @@ def _build_attn_state(self, num_reqs, num_scheduled_tokens,
             attn_state = AscendAttentionState.SpecDecoding
         # Speculative decoding.
         elif np.all(num_valid_tokens == 1):
-            if self.drafter and (self.drafter.name == SpecDcodeType.EAGLE
-                                 or self.drafter.name == SpecDcodeType.EAGLE3):
-                attn_state = AscendAttentionState.ChunkedPrefill
-            else:
+            if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
                 attn_state = AscendAttentionState.SpecDecoding
+            else:
+                attn_state = AscendAttentionState.ChunkedPrefill
         # splitfuse
         elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
             attn_state = AscendAttentionState.ChunkedPrefill

@@ -2544,7 +2543,7 @@ def propose_draft_token_ids(sampled_token_ids):
         with ProfileExecuteDuration().capture_async("Draft"):
             if self.speculative_config:
                 use_padded_batch_for_eagle = self.speculative_config and \
-                    self.speculative_config.method == "deepseek_mtp" and \
+                    self.speculative_config.method in ("deepseek_mtp", "qwen3_next_mtp") and \
                     not self.speculative_config.disable_padded_drafter_batch
                 if use_padded_batch_for_eagle:
                     # EAGLE speculative decoding can use the GPU sampled tokens
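
Putting the pieces together, roughly how this path would be exercised (a hedged sketch: the model name is a placeholder and the speculative_config keys follow vLLM's usual dict form; prefix caching is disabled to satisfy the assertion in CustomQwen3NextMTP):

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="Qwen/Qwen3-Next-80B-A3B-Instruct",   # placeholder model id
        speculative_config={
            "method": "qwen3_next_mtp",             # dispatched to MtpProposer above
            "num_speculative_tokens": 2,            # keeps decode_threshold at 3 <= 16
        },
        enable_prefix_caching=False,                # Qwen3NextMTP asserts prefix caching is off
    )
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))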
