56 changes: 56 additions & 0 deletions tests/e2e/multicard/test_qwen3_next.py
@@ -20,10 +20,17 @@

Run `pytest tests/e2e/multicard/test_qwen3_next.py`.
"""
import os
from unittest.mock import patch

from tests.e2e.conftest import VllmRunner

# NZ causes a precision error in Qwen3-Next
# Once that is fixed, this setup can be removed
_IS_ENABLE_NZ = "VLLM_ASCEND_ENABLE_NZ"


@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
def test_models_distributed_Qwen3_NEXT_TP4():
example_prompts = [
"Hello, my name is",
@@ -36,8 +43,10 @@ def test_models_distributed_Qwen3_NEXT_TP4():
distributed_executor_backend="mp",
enforce_eager=True) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model


@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
example_prompts = [
"Hello, my name is",
@@ -54,3 +63,50 @@ def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
"cudagraph_capture_sizes": [1, 8, 24, 48, 60]
}) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model


@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY():
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
max_tokens = 20

with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
tensor_parallel_size=4,
max_model_len=4096,
gpu_memory_utilization=0.8,
distributed_executor_backend="mp") as vllm_model:
ref_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model

with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
tensor_parallel_size=4,
max_model_len=4096,
gpu_memory_utilization=0.8,
distributed_executor_backend="mp",
speculative_config={
"method": "qwen3_next_mtp",
"num_speculative_tokens": 1
}) as spec_vllm_model:
spec_outputs = spec_vllm_model.generate_greedy(example_prompts,
max_tokens)
del spec_vllm_model

matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
ref_token_ids = ref_output[0]
spec_token_ids = spec_output[0]
if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output[1]}")
print(f"spec_output: {spec_output[1]}")

assert matches > int(0.66 * len(ref_outputs))
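
The assertion above accepts the speculative run when more than two thirds of the prompts reproduce the reference greedy output as a token-id prefix. A minimal standalone sketch of that acceptance rule, using hypothetical (token_ids, text) pairs in place of real generate_greedy output:

# Hypothetical outputs; each entry mirrors generate_greedy's (token_ids, text) shape.
ref_outputs = [([1, 2, 3], "a"), ([4, 5, 6], "b"), ([7, 8, 9], "c")]
spec_outputs = [([1, 2, 3, 10], "a'"), ([4, 5, 6], "b"), ([7, 0, 0], "c'")]

matches = sum(
    1 for (ref_ids, _), (spec_ids, _) in zip(ref_outputs, spec_outputs)
    if ref_ids == spec_ids[:len(ref_ids)])

# Same 2/3 bar as the test: 2 > int(0.66 * 3) == 1, so this sketch passes.
assert matches > int(0.66 * len(ref_outputs))
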
1 change: 1 addition & 0 deletions tests/ut/attention/test_attention_v1.py
@@ -77,6 +77,7 @@ def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group):
mock_get_dcp_group.return_value = dcp_group

self.mock_vllm_config = MagicMock()
self.mock_vllm_config.speculative_config = None
self.mock_vllm_config.model_config.max_model_len = 640
self.mock_vllm_config.cache_config.block_size = 64
self.mock_vllm_config.compilation_config.cudagraph_mode = None
11 changes: 11 additions & 0 deletions vllm_ascend/attention/attention_v1.py
@@ -252,6 +252,17 @@ def __init__(
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0

self.speculative_config = vllm_config.speculative_config
self.decode_threshold = 1
if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens
self.decode_threshold += spec_token_num
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
npu_fused_infer_attention_score TND layout's limit of 16, \
got {self.decode_threshold}"

AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold

def reorder_batch(self, input_batch,
scheduler_output: "SchedulerOutput") -> bool:
return False
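
With speculative decoding enabled, the builder treats a request as a decode as long as it schedules at most decode_threshold tokens (one sampled token plus up to num_speculative_tokens draft tokens), and publishes that value as reorder_batch_threshold. A sketch of how such a threshold is typically applied when splitting a batch; classify_requests is a hypothetical helper, not part of this PR:

def classify_requests(num_scheduled_tokens: dict, decode_threshold: int):
    # Requests scheduling <= decode_threshold tokens are decodes, the rest prefills.
    decodes, prefills = [], []
    for req_id, n_tokens in num_scheduled_tokens.items():
        (decodes if n_tokens <= decode_threshold else prefills).append(req_id)
    return decodes, prefills

# num_speculative_tokens=1 gives decode_threshold=2: a decode step may carry the
# sampled token plus one draft token.
decodes, prefills = classify_requests({"req-a": 2, "req-b": 57}, decode_threshold=2)
assert decodes == ["req-a"] and prefills == ["req-b"]
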
4 changes: 4 additions & 0 deletions vllm_ascend/models/__init__.py
@@ -35,6 +35,10 @@ def register_model():
"PanguProMoEForCausalLM",
"vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
)

ModelRegistry.register_model(
"Qwen3NextForCausalLM",
"vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM")

ModelRegistry.register_model(
"Qwen3NextMTP", "vllm_ascend.models.qwen3_next_mtp:CustomQwen3NextMTP")
18 changes: 18 additions & 0 deletions vllm_ascend/models/qwen3_next.py
@@ -260,6 +260,24 @@ def _forward(
mixed_qkv_spec = None
mixed_qkv_non_spec = mixed_qkv

# 2.1: process the multi-query part
if spec_sequence_masks is not None:
mixed_qkv_spec = mixed_qkv_spec.view(
attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
mixed_qkv_spec = causal_conv1d_update(
mixed_qkv_spec,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=spec_state_indices_tensor[:, 0]
[:attn_metadata.num_spec_decodes],
num_accepted_tokens=num_accepted_tokens,
validate_data=False,
)
mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')

# 2.2: process the remaining part
if attn_metadata.num_prefills > 0:
# - "cache_indices" updates the conv_state cache in positions
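
The causal_conv1d_update call above advances each speculative request's cached convolution window in place. A conceptual single-token sketch of that state update in plain PyTorch, shown only to illustrate the cache semantics; it is not the NPU kernel's signature and omits the per-request indexing:

import torch
import torch.nn.functional as F

def conv1d_state_update(x, conv_state, weight, bias=None, activation="silu"):
    # x: (batch, dim) new token; conv_state: (batch, dim, width - 1) cached inputs;
    # weight: (dim, width) depthwise filter.
    window = torch.cat([conv_state, x.unsqueeze(-1)], dim=-1)  # (batch, dim, width)
    conv_state.copy_(window[..., 1:])                          # roll the cache forward
    out = (window * weight).sum(dim=-1)                        # depthwise causal conv
    if bias is not None:
        out = out + bias
    return F.silu(out) if activation == "silu" else out
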
109 changes: 109 additions & 0 deletions vllm_ascend/models/qwen3_next_mtp.py
@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next MTP model."""
import torch
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.qwen3_next_mtp import (
Qwen3NextMTP, Qwen3NextMultiTokenPredictor)
from vllm.model_executor.models.utils import (
make_empty_intermediate_tensors_factory, maybe_prefix)
from vllm.transformers_utils.configs import Qwen3NextConfig

from vllm_ascend.models.qwen3_next import (CustomQwen3NextDecoderLayer,
Qwen3NextRMSNorm)


@support_torch_compile
class CustomQwen3NextMultiTokenPredictor(Qwen3NextMultiTokenPredictor):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super(Qwen3NextMultiTokenPredictor, self).__init__()

model_config = vllm_config.model_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
config: Qwen3NextConfig = model_config.hf_config

self.config = config
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size

self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)

self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)

self.fc = ColumnParallelLinear(self.config.hidden_size * 2,
self.config.hidden_size,
gather_output=True,
bias=False,
return_bias=False,
quant_config=quant_config,
prefix=f'{prefix}.fc')

# use the old-style MTP layer name to avoid an exception in vLLM
self.layers = torch.nn.ModuleList(
CustomQwen3NextDecoderLayer(
vllm_config,
layer_type="full_attention",
prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
) for idx in range(self.num_mtp_layers))

self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))

self.norm = Qwen3NextRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_fc_norm_hidden = Qwen3NextRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_fc_norm_embedding = Qwen3NextRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)


@support_torch_compile
class CustomQwen3NextMTP(Qwen3NextMTP, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": ["up_proj", "down_proj"]
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
cache_config = vllm_config.cache_config
assert not cache_config.enable_prefix_caching, \
"Qwen3NextMTP currently does not support prefix caching"

self.quant_config = vllm_config.quant_config

super(Qwen3NextMTP, self).__init__()
self.config = config
self.model = CustomQwen3NextMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
prefix=maybe_prefix(prefix, "lm_head"))
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
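
The fc layer takes 2 * hidden_size inputs and is fed by the two pre-fc norms, which matches the usual multi-token-prediction pattern: normalize the next-token embedding and the incoming hidden state separately, concatenate, and project back to hidden_size before the decoder layer. A sketch of that combination step, under the assumption that the forward inherited from Qwen3NextMultiTokenPredictor follows this pattern:

import torch

def mtp_combine(embedding, prev_hidden, pre_fc_norm_embedding, pre_fc_norm_hidden, fc):
    # embedding, prev_hidden: (num_tokens, hidden_size)
    fused = torch.cat(
        [pre_fc_norm_embedding(embedding), pre_fc_norm_hidden(prev_hidden)],
        dim=-1)          # (num_tokens, 2 * hidden_size)
    return fc(fused)     # (num_tokens, hidden_size); fc was built with return_bias=False
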
2 changes: 1 addition & 1 deletion vllm_ascend/ops/casual_conv1d.py
@@ -55,7 +55,7 @@ def causal_conv1d_ref(
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
dtype_in) # (batch, dim, width - 1)
if final_states_out is not None:
final_states_out.copy_(final_states)
final_states_out[..., :(width - 1)].copy_(final_states)
else:
final_states_out = final_states
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
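
The sliced copy_ matters because the caller-supplied state buffer can have a larger last dimension than the (width - 1) columns produced by the reference path, and copy_ requires broadcast-compatible shapes. A small illustration with made-up sizes:

import torch

width = 4
final_states = torch.randn(2, 8, width - 1)   # (batch, dim, width - 1)
final_states_out = torch.zeros(2, 8, width)   # caller's wider state buffer

final_states_out[..., :(width - 1)].copy_(final_states)   # writes the first 3 columns
# final_states_out.copy_(final_states)  # would raise: (2, 8, 4) vs (2, 8, 3)
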
4 changes: 2 additions & 2 deletions vllm_ascend/spec_decode/__init__.py
@@ -29,9 +29,9 @@ def get_spec_decode_method(method,
is_torchair_graph=False):
if method == "ngram":
return NgramProposer(vllm_config, device, runner)
elif method in ["eagle", "eagle3"]:
elif method in ("eagle", "eagle3"):
return EagleProposer(vllm_config, device, runner)
elif method == 'deepseek_mtp':
elif method in ('deepseek_mtp', 'qwen3_next_mtp'):
if is_torchair_graph:
return TorchairMtpProposer(vllm_config, device, runner)
return MtpProposer(vllm_config, device, runner)
Expand Down
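
The new branch is reached through the speculative_config "method" string; the e2e test above selects it with a configuration equivalent to the following, after which get_spec_decode_method(...) returns the same MtpProposer (or TorchairMtpProposer in torchair graph mode) used for deepseek_mtp:

speculative_config = {
    "method": "qwen3_next_mtp",
    "num_speculative_tokens": 1,
}
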
43 changes: 37 additions & 6 deletions vllm_ascend/spec_decode/mtp_proposer.py
@@ -1,3 +1,4 @@
import importlib
from typing import Optional

import numpy as np
@@ -12,7 +13,6 @@
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import \
process_weights_after_loading
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils import cdiv
@@ -42,6 +42,26 @@

PADDING_SLOT_ID = -1

_MTP_MODELS = {
"DeepseekV3ForCausalLM":
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
"Qwen3NextForCausalLM":
("vllm_ascend.models.qwen3_next_mtp", "CustomQwen3NextMTP")
}

_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'

_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}


def _load_model(architecture):
if architecture not in _MTP_MODELS:
raise ValueError(f"Unsupported architecture for MTP: {architecture}")
module_name, model_name = _MTP_MODELS[architecture]
module = importlib.import_module(module_name)
model = getattr(module, model_name)
return model


class MtpProposer(Proposer):

@@ -150,9 +170,7 @@ def load_model(self, model) -> None:
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
self.model = DeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)

self._init_mtp_model()
draft_attn_layer_names = (get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase).keys() -
target_attn_layer_names)
@@ -228,8 +246,7 @@ def generate_token_ids(self,
attn_metadata=None,
aux_hidden_states: torch.Tensor = None):
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
if attn_metadata is not None and isinstance(attn_metadata, dict):
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
attn_metadata = self._get_attn_metadata(attn_metadata)

if self.speculative_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
@@ -311,6 +328,20 @@ def generate_token_ids(self,

return draft_token_ids

def _init_mtp_model(self):
architecture = self.vllm_config.model_config.architecture
target_device = self.vllm_config.device_config.device
model = _load_model(architecture)
self.model = model(vllm_config=self.vllm_config).to(target_device)

def _get_attn_metadata(self, attn_metadata):
if attn_metadata is not None and isinstance(attn_metadata, dict):
architecture = self.vllm_config.model_config.architecture
layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
attn_metadata = attn_metadata[layer_name]

return attn_metadata

def _prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,
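
Taken together, the proposer now resolves both the draft model class and the first full-attention layer name from the target architecture instead of hard-coding the DeepSeek values. An illustrative walk-through for Qwen3-Next using the tables above (the layer-index rationale reflects Qwen3-Next's hybrid linear/full attention layout):

architecture = "Qwen3NextForCausalLM"

draft_cls = _load_model(architecture)
# -> vllm_ascend.models.qwen3_next_mtp.CustomQwen3NextMTP

layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
# -> 'model.layers.3.self_attn.attn': the first decoder layers use linear
#    attention, so the first per-layer attn_metadata entry is the
#    full-attention layer at index 3 rather than layer 0.
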
11 changes: 5 additions & 6 deletions vllm_ascend/worker/model_runner_v1.py
@@ -1847,7 +1847,7 @@ def _prepare_inputs(
extra_attn_metadata_args = dict(
num_accepted_tokens=self.num_accepted_tokens.
gpu[:num_reqs],
num_draft_tokens=self.num_draft_tokens.
num_decode_draft_tokens_cpu=self.num_draft_tokens.
gpu[:num_reqs],
)
attn_metadata_i = builder.build(
@@ -1943,11 +1943,10 @@ def _build_attn_state(self, num_reqs, num_scheduled_tokens,
attn_state = AscendAttentionState.SpecDecoding
# Speculative decoding.
elif np.all(num_valid_tokens == 1):
if self.drafter and (self.drafter.name == SpecDcodeType.EAGLE
or self.drafter.name == SpecDcodeType.EAGLE3):
attn_state = AscendAttentionState.ChunkedPrefill
else:
if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
attn_state = AscendAttentionState.SpecDecoding
else:
attn_state = AscendAttentionState.ChunkedPrefill
# splitfuse
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
attn_state = AscendAttentionState.ChunkedPrefill
@@ -2543,7 +2542,7 @@ def propose_draft_token_ids(sampled_token_ids):
with ProfileExecuteDuration().capture_async("Draft"):
if self.speculative_config:
use_padded_batch_for_eagle = self.speculative_config and \
self.speculative_config.method == "deepseek_mtp" and \
self.speculative_config.method in ("deepseek_mtp", "qwen3_next_mtp") and \
not self.speculative_config.disable_padded_drafter_batch
if use_padded_batch_for_eagle:
# EAGLE speculative decoding can use the GPU sampled tokens
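
The widened condition means Qwen3-Next MTP now takes the same padded-drafter-batch path as DeepSeek MTP. The gate reduces to a small predicate; a sketch, with use_padded_drafter_batch as a hypothetical extraction of the inline expression above:

def use_padded_drafter_batch(speculative_config) -> bool:
    # MTP-style drafters reuse the padded GPU-sampled token batch unless the
    # user disables it via disable_padded_drafter_batch.
    return bool(speculative_config) \
        and speculative_config.method in ("deepseek_mtp", "qwen3_next_mtp") \
        and not speculative_config.disable_padded_drafter_batch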