diff --git a/tests/e2e/multicard/test_qwen3_next.py b/tests/e2e/multicard/test_qwen3_next.py
index 9fda522021..cf3382318d 100644
--- a/tests/e2e/multicard/test_qwen3_next.py
+++ b/tests/e2e/multicard/test_qwen3_next.py
@@ -20,10 +20,17 @@
 Run `pytest tests/e2e/multicard/test_qwen3_next.py`.
 """
+import os
+from unittest.mock import patch
 
 from tests.e2e.conftest import VllmRunner
 
+# NZ will cause a precision error in Qwen3-Next.
+# When it is fixed, this set-up can be removed.
+_IS_ENABLE_NZ = "VLLM_ASCEND_ENABLE_NZ"
+
 
+@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
 def test_models_distributed_Qwen3_NEXT_TP4():
     example_prompts = [
         "Hello, my name is",
@@ -36,8 +43,10 @@ def test_models_distributed_Qwen3_NEXT_TP4():
                     distributed_executor_backend="mp",
                     enforce_eager=True) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+    del vllm_model
 
 
+@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
 def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
     example_prompts = [
         "Hello, my name is",
@@ -54,3 +63,50 @@ def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
                         "cudagraph_capture_sizes": [1, 8, 24, 48, 60]
                     }) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+    del vllm_model
+
+
+@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
+def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY():
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    max_tokens = 20
+
+    with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
+                    tensor_parallel_size=4,
+                    max_model_len=4096,
+                    gpu_memory_utilization=0.8,
+                    distributed_executor_backend="mp") as vllm_model:
+        ref_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+    del vllm_model
+
+    with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
+                    tensor_parallel_size=4,
+                    max_model_len=4096,
+                    gpu_memory_utilization=0.8,
+                    distributed_executor_backend="mp",
+                    speculative_config={
+                        "method": "qwen3_next_mtp",
+                        "num_speculative_tokens": 1
+                    }) as spec_vllm_model:
+        spec_outputs = spec_vllm_model.generate_greedy(example_prompts,
+                                                       max_tokens)
+    del spec_vllm_model
+
+    matches = 0
+    misses = 0
+    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+        ref_token_ids = ref_output[0]
+        spec_token_ids = spec_output[0]
+        if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
+            matches += 1
+        else:
+            misses += 1
+            print(f"ref_output: {ref_output[1]}")
+            print(f"spec_output: {spec_output[1]}")
+
+    assert matches > int(0.66 * len(ref_outputs))
diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py
index 6806563e35..9c732f61de 100644
--- a/tests/ut/attention/test_attention_v1.py
+++ b/tests/ut/attention/test_attention_v1.py
@@ -77,6 +77,7 @@ def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group):
         mock_get_dcp_group.return_value = dcp_group
 
         self.mock_vllm_config = MagicMock()
+        self.mock_vllm_config.speculative_config = None
         self.mock_vllm_config.model_config.max_model_len = 640
         self.mock_vllm_config.cache_config.block_size = 64
         self.mock_vllm_config.compilation_config.cudagraph_mode = None
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 7a1f79a493..6ea9058fd7 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -252,6 +252,17 @@ def __init__(
         self.dcp_rank = get_decode_context_model_parallel_rank(
         ) if self.dcp_size > 1 else 0
 
+        self.speculative_config = vllm_config.speculative_config
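+        # A decode step now carries 1 target token plus num_speculative_tokens
+        # draft tokens, so the batch-reorder threshold grows accordingly.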
+        self.decode_threshold = 1
+        if self.speculative_config:
+            spec_token_num = self.speculative_config.num_speculative_tokens
+            self.decode_threshold += spec_token_num
+            assert self.decode_threshold <= 16, f"decode_threshold exceeded \
+                npu_fused_infer_attention_score TND layout's limit of 16, \
+                got {self.decode_threshold}"
+
+        AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold
+
     def reorder_batch(self, input_batch,
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py
index 21ea48e3ab..956df2eb31 100644
--- a/vllm_ascend/models/__init__.py
+++ b/vllm_ascend/models/__init__.py
@@ -35,6 +35,10 @@ def register_model():
         "PanguProMoEForCausalLM",
         "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
     )
+
     ModelRegistry.register_model(
         "Qwen3NextForCausalLM",
         "vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM")
+
+    ModelRegistry.register_model(
+        "Qwen3NextMTP", "vllm_ascend.models.qwen3_next_mtp:CustomQwen3NextMTP")
diff --git a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py
index f5b4b8a142..b0bfde0eb6 100644
--- a/vllm_ascend/models/qwen3_next.py
+++ b/vllm_ascend/models/qwen3_next.py
@@ -260,6 +260,24 @@ def _forward(
             mixed_qkv_spec = None
             mixed_qkv_non_spec = mixed_qkv
 
+        # 2.1: process the multi-query part
+        if spec_sequence_masks is not None:
+            mixed_qkv_spec = mixed_qkv_spec.view(
+                attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
+            mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
+            mixed_qkv_spec = causal_conv1d_update(
+                mixed_qkv_spec,
+                conv_state,
+                conv_weights,
+                self.conv1d.bias,
+                self.activation,
+                conv_state_indices=spec_state_indices_tensor[:, 0]
+                [:attn_metadata.num_spec_decodes],
+                num_accepted_tokens=num_accepted_tokens,
+                validate_data=False,
+            )
+            mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')
+
         # 2.2: process the remaining part
         if attn_metadata.num_prefills > 0:
             # - "cache_indices" updates the conv_state cache in positions
diff --git a/vllm_ascend/models/qwen3_next_mtp.py b/vllm_ascend/models/qwen3_next_mtp.py
new file mode 100644
index 0000000000..c17d969cb2
--- /dev/null
+++ b/vllm_ascend/models/qwen3_next_mtp.py
@@ -0,0 +1,109 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Inference-only Qwen3Next MTP model."""
+import torch
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.linear import ColumnParallelLinear
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.models.interfaces import SupportsPP
+from vllm.model_executor.models.qwen3_next_mtp import (
+    Qwen3NextMTP, Qwen3NextMultiTokenPredictor)
+from vllm.model_executor.models.utils import (
+    make_empty_intermediate_tensors_factory, maybe_prefix)
+from vllm.transformers_utils.configs import Qwen3NextConfig
+
+from vllm_ascend.models.qwen3_next import (CustomQwen3NextDecoderLayer,
+                                           Qwen3NextRMSNorm)
+
+
+@support_torch_compile
+class CustomQwen3NextMultiTokenPredictor(Qwen3NextMultiTokenPredictor):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
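+        # Bypass Qwen3NextMultiTokenPredictor.__init__ on purpose: the modules
+        # are rebuilt below on top of the Ascend CustomQwen3NextDecoderLayer.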
+        super(Qwen3NextMultiTokenPredictor, self).__init__()
+
+        model_config = vllm_config.model_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+        config: Qwen3NextConfig = model_config.hf_config
+
+        self.config = config
+        lora_vocab = ((lora_config.lora_extra_vocab_size *
+                       (lora_config.max_loras or 1)) if lora_config else 0)
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+
+        self.mtp_start_layer_idx = config.num_hidden_layers
+        self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)
+
+        self.embed_tokens = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+
+        self.fc = ColumnParallelLinear(self.config.hidden_size * 2,
+                                       self.config.hidden_size,
+                                       gather_output=True,
+                                       bias=False,
+                                       return_bias=False,
+                                       quant_config=quant_config,
+                                       prefix=f'{prefix}.fc')
+
+        # use the old-version MTP layer name to avoid an exception in vllm
+        self.layers = torch.nn.ModuleList(
+            CustomQwen3NextDecoderLayer(
+                vllm_config,
+                layer_type="full_attention",
+                prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
+            ) for idx in range(self.num_mtp_layers))
+
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
+
+        self.norm = Qwen3NextRMSNorm(config.hidden_size,
+                                     eps=config.rms_norm_eps)
+        self.pre_fc_norm_hidden = Qwen3NextRMSNorm(config.hidden_size,
+                                                   eps=config.rms_norm_eps)
+        self.pre_fc_norm_embedding = Qwen3NextRMSNorm(config.hidden_size,
+                                                      eps=config.rms_norm_eps)
+
+
+@support_torch_compile
+class CustomQwen3NextMTP(Qwen3NextMTP, SupportsPP):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": ["up_proj", "down_proj"]
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        config = vllm_config.model_config.hf_config
+        self.vllm_config = vllm_config
+        cache_config = vllm_config.cache_config
+        assert not cache_config.enable_prefix_caching, \
+            "Qwen3NextMTP currently does not support prefix caching"
+
+        self.quant_config = vllm_config.quant_config
+
+        super(Qwen3NextMTP, self).__init__()
+        self.config = config
+        self.model = CustomQwen3NextMultiTokenPredictor(
+            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
+        self.unpadded_vocab_size = config.vocab_size
+        self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
+                                      config.hidden_size,
+                                      org_num_embeddings=config.vocab_size,
+                                      padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+                                      prefix=maybe_prefix(prefix, "lm_head"))
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
diff --git a/vllm_ascend/ops/casual_conv1d.py b/vllm_ascend/ops/casual_conv1d.py
index 2d008899ad..7ddc9cecca 100644
--- a/vllm_ascend/ops/casual_conv1d.py
+++ b/vllm_ascend/ops/casual_conv1d.py
@@ -55,7 +55,7 @@ def causal_conv1d_ref(
         final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
             dtype_in)  # (batch, dim, width - 1)
         if final_states_out is not None:
-            final_states_out.copy_(final_states)
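+            # final_states_out may be wider than the (width - 1) steps of conv
+            # history held in final_states, so only fill that leading slice.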
+            final_states_out[..., :(width - 1)].copy_(final_states)
         else:
             final_states_out = final_states
     out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py
index 3e17944c5c..6abe8777cd 100644
--- a/vllm_ascend/spec_decode/__init__.py
+++ b/vllm_ascend/spec_decode/__init__.py
@@ -29,9 +29,9 @@ def get_spec_decode_method(method,
                            is_torchair_graph=False):
     if method == "ngram":
         return NgramProposer(vllm_config, device, runner)
-    elif method in ["eagle", "eagle3"]:
+    elif method in ("eagle", "eagle3"):
         return EagleProposer(vllm_config, device, runner)
-    elif method == 'deepseek_mtp':
+    elif method in ('deepseek_mtp', 'qwen3_next_mtp'):
         if is_torchair_graph:
             return TorchairMtpProposer(vllm_config, device, runner)
         return MtpProposer(vllm_config, device, runner)
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 9f6d787471..2d4e239e27 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -1,3 +1,4 @@
+import importlib
 from typing import Optional
 
 import numpy as np
@@ -12,7 +13,6 @@
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.utils import \
     process_weights_after_loading
-from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
 from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.utils import cdiv
@@ -42,6 +42,26 @@
 
 PADDING_SLOT_ID = -1
 
 
+_MTP_MODELS = {
+    "DeepseekV3ForCausalLM":
+    ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
+    "Qwen3NextForCausalLM":
+    ("vllm_ascend.models.qwen3_next_mtp", "CustomQwen3NextMTP")
+}
+
+_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
+
+_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
+
+
+def _load_model(architecture):
+    if architecture not in _MTP_MODELS:
+        raise ValueError("Invalid architecture for mtp.")
+    module_name, model_name = _MTP_MODELS[architecture]
+    module = importlib.import_module(module_name)
+    model = getattr(module, model_name)
+    return model
+
+
 class MtpProposer(Proposer):
 
@@ -150,9 +170,7 @@ def load_model(self, model) -> None:
         with set_default_torch_dtype(
                 draft_model_config.dtype), set_current_vllm_config(
                     self.vllm_config):
-            self.model = DeepSeekMTP(
-                vllm_config=self.vllm_config).to(target_device)
-
+            self._init_mtp_model()
             draft_attn_layer_names = (get_layers_from_vllm_config(
                 self.vllm_config, AttentionLayerBase).keys() -
                                       target_attn_layer_names)
@@ -228,8 +246,7 @@ def generate_token_ids(self,
                            attn_metadata=None,
                            aux_hidden_states: torch.Tensor = None):
         common_attn_metadata = self.runner.spec_decode_common_attn_metadata
-        if attn_metadata is not None and isinstance(attn_metadata, dict):
-            attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
+        attn_metadata = self._get_attn_metadata(attn_metadata)
 
         if self.speculative_config.disable_padded_drafter_batch:
             # When padded-batch is disabled, the sampled_token_ids should be
@@ -311,6 +328,20 @@
 
         return draft_token_ids
 
+    def _init_mtp_model(self):
+        architecture = self.vllm_config.model_config.architecture
+        target_device = self.vllm_config.device_config.device
+        model = _load_model(architecture)
+        self.model = model(vllm_config=self.vllm_config).to(target_device)
+
+    def _get_attn_metadata(self, attn_metadata):
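+        # The first attention layer name is architecture-specific (Qwen3-Next's
+        # is 'model.layers.3.self_attn.attn'), so look it up instead of
+        # hard-coding layer 0.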
+        if attn_metadata is not None and isinstance(attn_metadata, dict):
+            architecture = self.vllm_config.model_config.architecture
+            layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
+            attn_metadata = attn_metadata[layer_name]
+
+        return attn_metadata
+
     def _prepare_inputs(
         self,
         common_attn_metadata: CommonAttentionMetadata,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 66868dd649..2643d19e81 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1847,7 +1847,7 @@ def _prepare_inputs(
                 extra_attn_metadata_args = dict(
                     num_accepted_tokens=self.num_accepted_tokens.
                     gpu[:num_reqs],
-                    num_draft_tokens=self.num_draft_tokens.
+                    num_decode_draft_tokens_cpu=self.num_draft_tokens.
                     gpu[:num_reqs],
                 )
             attn_metadata_i = builder.build(
@@ -1943,11 +1943,10 @@ def _build_attn_state(self, num_reqs, num_scheduled_tokens,
             attn_state = AscendAttentionState.SpecDecoding
         # Speculative decoding.
         elif np.all(num_valid_tokens == 1):
-            if self.drafter and (self.drafter.name == SpecDcodeType.EAGLE
-                                 or self.drafter.name == SpecDcodeType.EAGLE3):
-                attn_state = AscendAttentionState.ChunkedPrefill
-            else:
+            if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
                 attn_state = AscendAttentionState.SpecDecoding
+            else:
+                attn_state = AscendAttentionState.ChunkedPrefill
         # splitfuse
         elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
             attn_state = AscendAttentionState.ChunkedPrefill
@@ -2543,7 +2542,7 @@ def propose_draft_token_ids(sampled_token_ids):
         with ProfileExecuteDuration().capture_async("Draft"):
             if self.speculative_config:
                 use_padded_batch_for_eagle = self.speculative_config and \
-                    self.speculative_config.method == "deepseek_mtp" and \
+                    self.speculative_config.method in ("deepseek_mtp", "qwen3_next_mtp") and \
                     not self.speculative_config.disable_padded_drafter_batch
                 if use_padded_batch_for_eagle:
                     # EAGLE speculative decoding can use the GPU sampled tokens