From 523db6b5d030b11e8c8241233f4a217c0ffd3ab8 Mon Sep 17 00:00:00 2001 From: mengwei805 Date: Mon, 21 Apr 2025 15:19:08 +0800 Subject: [PATCH] [bugfix] main-sd-bugfix Signed-off-by: mengwei805 --- .github/workflows/vllm_ascend_test.yaml | 3 +- .../spec_decode/e2e/test_mtp_correctness.py | 355 ++++++++++++++++++ vllm_ascend/attention/attention.py | 3 +- vllm_ascend/attention/mla_v1.py | 3 - vllm_ascend/core/scheduler.py | 4 +- vllm_ascend/distributed/parallel_state.py | 3 - vllm_ascend/models/deepseek_mtp.py | 21 +- .../patch_common/patch_spec_decode_worker.py | 4 +- vllm_ascend/quantization/quant_config.py | 6 +- vllm_ascend/worker/draft_model_runner.py | 4 +- 10 files changed, 375 insertions(+), 31 deletions(-) create mode 100644 tests/singlecard/spec_decode/e2e/test_mtp_correctness.py diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index f5a7038e2a..19b021e950 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -161,7 +161,8 @@ jobs: if: steps.filter_spec_decode.outputs.speculative_tests_changed == 'true' run: | if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then - pytest -sv tests/singlecard/spec_decode + pytest -sv tests/singlecard/spec_decode/e2e/test_mtp_correctness.py # it needs a clean process + pytest -sv tests/singlecard/spec_decode --ignore=tests/singlecard/spec_decode/e2e/test_mtp_correctness.py fi - name: Run vllm-project/vllm test for V0 Engine diff --git a/tests/singlecard/spec_decode/e2e/test_mtp_correctness.py b/tests/singlecard/spec_decode/e2e/test_mtp_correctness.py new file mode 100644 index 0000000000..18841fb749 --- /dev/null +++ b/tests/singlecard/spec_decode/e2e/test_mtp_correctness.py @@ -0,0 +1,355 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_mtp_correctness.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""This docstring details important information on the testing methodology. + +Most of the tests rely on "greedy equality", where we expect the output of +speculative decoding on a sequence to exactly match the output of normal non- +speculative decoding. + +Since speculative decoding with rejection sampling guarantees that the output +distribution matches the target model's output distribution (up to hardware +numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy +equality. + +However, we still need to verify below scenario could be passed: + * Batch size 1 greedy equality + * Batch size >1 greedy equality + * Test greedy equality under preemption + * Test greedy equality under various number of speculative tokens. + +With those tests, we can say at least, mtp would not break the +correctess for the target model outputs. 
+""" + +import pytest + +from .conftest import run_equality_correctness_test + +# main model +# NOTE vLLM use fp8 model, vllm-ascend use bf16 model +MAIN_MODEL = "wemaster/deepseek_mtp_main_random_bf16" + +# max. number of speculative tokens: this corresponds to +# num_nextn_predict_layers in the config.json of the speculator model. +MAX_SPEC_TOKENS = 1 + +# precision +PRECISION = "bfloat16" + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.85 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, + seed: int): + + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.85 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs": False, + }, + }, + { + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs": True, + }, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + + run_equality_correctness_test( + vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs["speculative_config"] + ["disable_logprobs"]) + + +@pytest.mark.skipif( + True, + reason= + "Open it when vllm-ascend support graph mode and support enforce_eager status is False to run model in graph mode" +) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "enforce_eager": False, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + "gpu_memory_utilization": 0.85 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size: int, + output_len: int, seed: int): + """Verify greedy equality with cuda graph enabled and different + batch sizes.""" + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "block_size": 8, + # 2 for small prompt, 256//8 for generated. + "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.9 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. + 128, + ]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_e2e_greedy_correctness_with_preemption( + vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.9 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_config": { + "num_speculative_tokens": k, + }, + } + # Try a range of num. speculative tokens + for k in range(1, 1 + MAX_SPEC_TOKENS) + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_different_k(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify that mtp speculative decoding produces exact equality + to without spec decode with different values of num_speculative_tokens. 
+ """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + + # GPU memory utilization + "gpu_memory_utilization": 0.9 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_config": { + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_by_batch_size": 4 + }, +}]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_disable_queue(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify that mtp speculative decoding produces exact equality + to without spec decode when speculation is disabled for large + batch sizes. + """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__]) diff --git a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py index f3acf087db..6943fe8759 100644 --- a/vllm_ascend/attention/attention.py +++ b/vllm_ascend/attention/attention.py @@ -113,7 +113,8 @@ def get_splitfuse_attn_mask( self.update_attn_cache(max_seq_len, dtype, device) # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation # is not the same. Fix this in the future when kernel is ready. 
- if self.attn_mask_cache[0][1] > 0: + if self.attn_mask_cache.numel( + ) > 1 and self.attn_mask_cache[0][1] > 0: attn_mask = self.get_attn_mask( # type: ignore max_seq_len, dtype, device) attn_mask *= -10000 diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index e5b7e73035..3e064ec6f9 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -6,7 +6,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) -from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, RowParallelLinear, UnquantizedLinearMethod) @@ -21,8 +20,6 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch -logger = init_logger(__name__) - class AscendMLABackend(AttentionBackend): diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py index 22b503eea3..cdcd58bbff 100644 --- a/vllm_ascend/core/scheduler.py +++ b/vllm_ascend/core/scheduler.py @@ -16,14 +16,12 @@ # from collections import deque -from vllm.logger import init_logger +from vllm.logger import logger from vllm.utils import cdiv from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.request import Request, RequestStatus -logger = init_logger(__name__) - class AscendScheduler(Scheduler): """This Scheduler extends vllm's original v1 scheduler diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index acb5048efb..d7be8c3bdc 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -36,7 +36,6 @@ def init_ascend_model_parallel( expert_tensor_parallel_size) global _EP - assert _EP is None, ("expert parallel group is already initialized") group_ranks = [] for i in range(num_expert_parallel_groups): ranks = list(range(i, world_size, num_expert_parallel_groups)) @@ -49,8 +48,6 @@ def init_ascend_model_parallel( group_ranks = [] global _ETP - assert _ETP is None, ( - "expert tensor parallel group is already initialized") for i in range(num_expert_tensor_parallel_groups): ranks = list( range(i * expert_tensor_parallel_size, diff --git a/vllm_ascend/models/deepseek_mtp.py b/vllm_ascend/models/deepseek_mtp.py index 983b7fd140..a19d666a86 100644 --- a/vllm_ascend/models/deepseek_mtp.py +++ b/vllm_ascend/models/deepseek_mtp.py @@ -1,6 +1,6 @@ # # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Adapted from vllm/model_executor/models/qwen2_vl.py +# Adapted from vllm/model_executor/models/deepseek_mtp.py # Copyright 2023 The vLLM team. # # This file is a part of the vllm-ascend project. @@ -17,12 +17,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional +from typing import Optional import torch import torch.nn as nn from transformers import PretrainedConfig -from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -70,8 +69,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, previous_hidden_states: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, spec_step_index: int = 0, @@ -91,8 +88,6 @@ def forward( hidden_states, residual = self.mtp_block(positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, residual=None) hidden_states = residual + hidden_states return hidden_states @@ -130,8 +125,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, previous_hidden_states: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, @@ -140,8 +133,6 @@ def forward( return self.layers_list[current_step_idx]( input_ids, positions, - kv_caches[current_step_idx], - attn_metadata, previous_hidden_states, inputs_embeds, current_step_idx, @@ -162,6 +153,14 @@ def compute_logits( class CustomDeepSeekMTP(DeepSeekMTP): + # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized; + # NOTE 2.The description file generated by the current msmodelslim tool does not have + # MTP layer info. Please manually add it and set the value to FLOAT. + packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) diff --git a/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py b/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py index 223fa3d36f..040e62e6fa 100644 --- a/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py +++ b/vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py @@ -18,7 +18,7 @@ from typing import Any, Dict, Optional from vllm.config import ParallelConfig -from vllm.logger import init_logger +from vllm.logger import logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.spec_decode_base_sampler import \ SpecDecodeBaseSampler @@ -34,8 +34,6 @@ from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner -logger = init_logger(__name__) - def create_worker( cls, diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 702829e041..3f3646ba99 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -23,8 +23,6 @@ from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.fused_moe.layer import \ - UnquantizedFusedMoEMethod from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, RowParallelLinear, UnquantizedLinearMethod) @@ -36,6 +34,8 @@ from vllm.model_executor.parameter import PerTensorScaleParameter from vllm.model_executor.utils import set_weight_attrs +from vllm_ascend.ops.fused_moe import 
AscendUnquantizedFusedMoEMethod + from .quantizer import AscendQuantizer @@ -97,7 +97,7 @@ def get_quant_method(self, layer: torch.nn.Module, elif isinstance(layer, FusedMoE): if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): - return UnquantizedFusedMoEMethod() + return AscendUnquantizedFusedMoEMethod() return AscendFusedMoEMethod(self, prefix, self.packed_modules_mapping) return None diff --git a/vllm_ascend/worker/draft_model_runner.py b/vllm_ascend/worker/draft_model_runner.py index 9ec2c79bdb..162c1ee55a 100644 --- a/vllm_ascend/worker/draft_model_runner.py +++ b/vllm_ascend/worker/draft_model_runner.py @@ -19,7 +19,7 @@ import torch from vllm.forward_context import set_forward_context -from vllm.logger import init_logger +from vllm.logger import logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MultiModalKwargs from vllm.sequence import ExecuteModelRequest, IntermediateTensors @@ -29,8 +29,6 @@ from vllm_ascend.attention.attention import AscendMetadata -logger = init_logger(__name__) - # A flag to enable debug prints for the updated input tensors # before each step. debug_advance_input = False
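For readers who want to reproduce the new e2e checks outside pytest, the test kwargs in tests/singlecard/spec_decode/e2e/test_mtp_correctness.py map onto a plain vLLM invocation: the target model is loaded with a speculative_config dict and greedy sampling, and its output is expected to match a non-speculative run token for token. Below is a minimal sketch, assuming a vLLM build (as used by these tests) that accepts speculative_config as an LLM keyword argument and the same bf16 test model; the prompt string is only illustrative.

from vllm import LLM, SamplingParams

# Engine options mirror common_llm_kwargs in the new test file.
llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",   # MAIN_MODEL
    dtype="bfloat16",                                 # PRECISION
    enforce_eager=True,                               # skip graph capture for a fast run
    gpu_memory_utilization=0.85,
    speculative_config={
        "num_speculative_tokens": 1,                  # MAX_SPEC_TOKENS (MTP depth)
    },
)

# Greedy decoding, so the speculative output must equal the baseline exactly.
sampling = SamplingParams(temperature=0.0, max_tokens=128)
outputs = llm.generate(["The future of AI is"], sampling)
print(outputs[0].outputs[0].text)

In CI the file is run in its own process first (pytest -sv tests/singlecard/spec_decode/e2e/test_mtp_correctness.py) and then excluded from the rest of the spec_decode suite via --ignore, matching the workflow change above.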
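The attention.py change guards the chunked-prefill mask check against a degenerate mask cache: indexing attn_mask_cache[0][1] requires at least two elements, so a cache holding a single element would make the old check raise IndexError before any mask was produced. The sketch below shows the failure mode and the added numel() guard; the 1x1 tensor is a hypothetical stand-in for the real builder's cache, not the actual cache construction.

import torch

# Hypothetical degenerate cache, e.g. a mask built for a single-token sequence.
attn_mask_cache = torch.zeros(1, 1)

# Pre-fix condition: attn_mask_cache[0][1] > 0
# -> IndexError, because row 0 holds only one element.
# Post-fix condition short-circuits on numel() before the risky lookup:
needs_rescale = attn_mask_cache.numel() > 1 and attn_mask_cache[0][1] > 0
print(needs_rescale)  # False, and no IndexError is raised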