Merged
3 changes: 2 additions & 1 deletion .github/workflows/vllm_ascend_test.yaml
@@ -161,7 +161,8 @@ jobs:
if: steps.filter_spec_decode.outputs.speculative_tests_changed == 'true'
run: |
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
pytest -sv tests/singlecard/spec_decode
pytest -sv tests/singlecard/spec_decode/e2e/test_mtp_correctness.py # needs a clean process: it uses bf16 while the other spec_decode UTs use fp16
Collaborator
Let's add more comments on why this test needs a clean process.

Collaborator Author

I will add more comments.

  1. I see many "it needs a clean process" comments in the vLLM test pipeline, so I added the same note here; perhaps the vLLM UT architecture has some isolation problems?
  2. In this case, the test needs a clean process because it uses bf16 while the other UTs use fp16; if they run in a shared process, the test fails (see the sketch after this thread).

Collaborator
Yeah, I think 2 is the main reason in our case.
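For illustration of point 2, a minimal sketch (not part of this PR; the helper name is hypothetical) of how a dtype-sensitive test can be forced into its own interpreter process, so bf16 state cannot leak into the fp16 tests:

import subprocess
import sys

def run_isolated(test_path: str) -> int:
    """Run a single test file under pytest in a fresh Python process."""
    return subprocess.call([sys.executable, "-m", "pytest", "-sv", test_path])

if __name__ == "__main__":
    # Mirrors what the CI step does with two separate pytest invocations.
    sys.exit(run_isolated("tests/singlecard/spec_decode/e2e/test_mtp_correctness.py"))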

pytest -sv tests/singlecard/spec_decode --ignore=tests/singlecard/spec_decode/e2e/test_mtp_correctness.py
fi

- name: Run vllm-project/vllm test for V0 Engine
355 changes: 355 additions & 0 deletions tests/singlecard/spec_decode/e2e/test_mtp_correctness.py
@@ -0,0 +1,355 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_mtp_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.

However, we still need to verify that the scenarios below pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.

With those tests, we can at least say that mtp does not break the
correctness of the target model outputs.
"""

import pytest

from .conftest import run_equality_correctness_test
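# NOTE: run_equality_correctness_test (defined in .conftest, not shown in this
# diff) is assumed to build a baseline LLM without speculative decoding and a
# test LLM with the given speculative_config, generate greedily from the same
# prompts and seed, and assert token-for-token equality of the two outputs.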

# main model
# NOTE vLLM use fp8 model, vllm-ascend use bf16 model
MAIN_MODEL = "wemaster/deepseek_mtp_main_random_bf16"

# max. number of speculative tokens: this corresponds to
# num_nextn_predict_layers in the config.json of the speculator model.
MAX_SPEC_TOKENS = 1

# precision
PRECISION = "bfloat16"


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Print spec metrics.
"disable_log_stats": False,

# Precision
"dtype": PRECISION,

# Main model
"model_name": MAIN_MODEL,

# GPU memory utilization
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):

run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Print spec metrics.
"disable_log_stats": False,

# Precision
"dtype": PRECISION,

# Main model
"model_name": MAIN_MODEL,

# GPU memory utilization
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": False,
},
},
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):

run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])


@pytest.mark.skipif(
True,
reason=
"Open it when vllm-ascend support graph mode and support enforce_eager status is False to run model in graph mode"
)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"enforce_eager": False,

# Print spec metrics.
"disable_log_stats": False,

# Precision
"dtype": PRECISION,

# Main model
"model_name": MAIN_MODEL,
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
output_len: int, seed: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Precision
"dtype": PRECISION,

# Main model
"model_name": MAIN_MODEL,

# GPU memory utilization
"gpu_memory_utilization": 0.9
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Precision
"dtype": PRECISION,

# Main model
"model_name": MAIN_MODEL,

# GPU memory utilization
"gpu_memory_utilization": 0.9
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_config": {
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mtp_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that mtp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Precision
"dtype": PRECISION,

# Main model
"model_name": MAIN_MODEL,

# GPU memory utilization
"gpu_memory_utilization": 0.9
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mtp_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that mtp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)


if __name__ == "__main__":
import pytest
pytest.main([__file__])
3 changes: 2 additions & 1 deletion vllm_ascend/attention/attention.py
@@ -113,7 +113,8 @@ def get_splitfuse_attn_mask(
self.update_attn_cache(max_seq_len, dtype, device)
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
# is not the same. Fix this in the future when kernel is ready.
if self.attn_mask_cache[0][1] > 0:
if self.attn_mask_cache.numel(
) > 1 and self.attn_mask_cache[0][1] > 0:
attn_mask = self.get_attn_mask( # type: ignore
max_seq_len, dtype, device)
attn_mask *= -10000
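A hedged reading of the new guard (shapes below are assumed, not taken from this PR): indexing [0][1] on a cache that has degenerated to a single element raises an IndexError, so numel() has to be checked first:

import torch

mask = torch.zeros(1, 1)  # degenerate 1x1 attention-mask cache
# mask[0][1]              # would raise IndexError: mask[0] has size 1 in dim 0
if mask.numel() > 1 and mask[0][1] > 0:
    print("cache has a positive off-diagonal entry")  # only reached for larger caches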
3 changes: 0 additions & 3 deletions vllm_ascend/attention/mla_v1.py
@@ -6,7 +6,6 @@
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
UnquantizedLinearMethod)
@@ -21,8 +20,6 @@
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch

logger = init_logger(__name__)


class AscendMLABackend(AttentionBackend):
