11 changes: 8 additions & 3 deletions .github/workflows/vllm_ascend_test.yaml
@@ -138,13 +138,18 @@ jobs:
speculative_tests_changed:
- "tests/singlecard/spec_decode/**"
- "tests/multicard/spec_decode_e2e/**"
- "vllm_ascend/worker/worker.py"
- "vllm_ascend/worker/model_runner.py"
- "vllm_ascend/worker/multi_step_runner.py"
- "vllm_ascend/worker/multi_step_worker.py"
- "vllm_ascend/patch/patch_rejection_sampler.py"
- "vllm_ascend/patch/patch_spec_decode_worker.py"
- "vllm_ascend/patch/patch_multi_step_worker.py"
- "vllm_ascend/worker/draft_model_runner.py"
- "vllm_ascend/patch/worker/patch_common/patch_metrics.py"
- "vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py"
- "vllm_ascend/patch/worker/patch_common/patch_multi_step_worker.py"

- name: Run vllm-project/vllm-ascend Speculative Decode test
env:
VLLM_USE_V1: 0
if: steps.filter_spec_decode.outputs.speculative_tests_changed == 'true'
run: |
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
18 changes: 18 additions & 0 deletions tests/singlecard/spec_decode/__init__.py
@@ -0,0 +1,18 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from vllm_ascend.patch import worker # noqa: F401
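Placing this side-effect import in the package __init__.py applies the vllm-ascend worker patches once, before pytest imports any test module in the package; that is why the per-module patch_common imports are removed later in this PR. A minimal sketch of how a test module relies on it (illustrative only, not part of the diff; it assumes the patches are applied on import, as the noqa: F401 side-effect import suggests):

# Any module under tests/singlecard/spec_decode/ is imported after the
# package __init__.py above, so the patched classes are already in place.
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
from vllm_ascend.worker.worker import NPUWorker


def test_patched_symbols_importable():
    # No per-file "from vllm_ascend.patch.worker import patch_common" is
    # needed any more; the package-level import covers it.
    assert NPUWorker is not None
    assert TP1DraftModelRunner is not None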
18 changes: 18 additions & 0 deletions tests/singlecard/spec_decode/e2e/conftest.py
@@ -17,7 +17,9 @@
# limitations under the License.
#

import shutil
from itertools import cycle
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union

import pytest
@@ -177,6 +179,12 @@ def _check_logprobs_when_output_disabled(
assert spec_pos_logprob_token_id in baseline_pos_logprobs


def _clean_torchair_cache():
cache_path = Path.cwd() / '.torchair_cache'
if cache_path.exists() and cache_path.is_dir():
shutil.rmtree(cache_path)


def run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
@@ -219,10 +227,20 @@ def run_equality_correctness_test(
logprobs=logprobs,
prompt_logprobs=prompt_logprobs)

# TODO: the current torchair graph mode needs the torchair cache to be
# cleaned before each run; if it is not cleaned, an error is raised.
additional_config = common_llm_kwargs.get("additional_config")
enable_graph_mode = additional_config.get(
"enable_graph_mode") if additional_config else False

with vllm_runner(**org_args) as vllm_model:
if enable_graph_mode:
_clean_torchair_cache()
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

with vllm_runner(**sd_args) as vllm_model:
if enable_graph_mode:
_clean_torchair_cache()
if ensure_all_accepted or expected_acceptance_rate is not None:
# Force log interval to be 0 to catch all metrics.
stat_logger = vllm_model.model.llm_engine.stat_loggers[
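For reference, the cache cleanup added above only runs when a test opts into torchair graph mode through additional_config. A minimal sketch of that call pattern, assembled here purely for illustration (it mirrors the diff; it is not part of the PR):

import shutil
from pathlib import Path


def _clean_torchair_cache():
    # Torchair writes its compiled-graph cache to ./.torchair_cache; a stale
    # cache from a previous run breaks the next graph-mode run, so it is
    # removed before each vllm_runner is constructed.
    cache_path = Path.cwd() / '.torchair_cache'
    if cache_path.exists() and cache_path.is_dir():
        shutil.rmtree(cache_path)


# A graph-mode test opts in through common_llm_kwargs:
common_llm_kwargs = {"additional_config": {"enable_graph_mode": True}}
additional_config = common_llm_kwargs.get("additional_config")
enable_graph_mode = additional_config.get(
    "enable_graph_mode") if additional_config else False
if enable_graph_mode:
    _clean_torchair_cache()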
148 changes: 125 additions & 23 deletions tests/singlecard/spec_decode/e2e/test_mtp_correctness.py
@@ -36,23 +36,32 @@
With those tests, we can say at least that MTP would not break the
correctness of the target model outputs.
"""
import os

import pytest

from .conftest import run_equality_correctness_test

# main model
# NOTE vLLM use fp8 model, vllm-ascend use bf16 model
MAIN_MODEL = "wemaster/deepseek_mtp_main_random_bf16"
# NOTE both main model and MTP are bfloat16
FLOAT_MODEL = "wemaster/deepseek_mtp_main_random_bf16"

# NOTE main model is w8a8, MTP is bfloat16
QUANT_MODEL = "wemaster/deepseek_mtp_main_random_w8a8_part"

# TODO: when msmodelslim can quantize both the main and the MTP model,
# this UT should use fully w8a8 weights.

# max. number of speculative tokens: this corresponds to
# num_nextn_predict_layers in the config.json of the speculator model.
MAX_SPEC_TOKENS = 1

# precision
PRECISION = "bfloat16"
os.environ["VLLM_USE_MODELSCOPE"] = "True"


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
Expand All @@ -66,7 +75,7 @@
"dtype": PRECISION,

# Main model
"model_name": MAIN_MODEL,
"model_name": FLOAT_MODEL,

# GPU memory utilization
"gpu_memory_utilization": 0.85
@@ -97,6 +106,7 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
batch_size, output_len, seed)


@pytest.mark.skipif(True, reason="quant model is not ready.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -110,7 +120,53 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
"dtype": PRECISION,

# Main model
"model_name": MAIN_MODEL,
"model_name": QUANT_MODEL,

# GPU memory utilization
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_quant_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):

run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Print spec metrics.
"disable_log_stats": False,

# Precision
"dtype": PRECISION,

# Main model
"model_name": FLOAT_MODEL,

# GPU memory utilization
"gpu_memory_utilization": 0.85
@@ -158,15 +214,13 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
["disable_logprobs"])


@pytest.mark.skipif(
True,
reason=
"Open it when vllm-ascend support graph mode and support enforce_eager status is False to run model in graph mode"
)
@pytest.mark.skipif(True, reason="torchair ut cannot clean memory.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"enforce_eager": False,
"additional_config": {
'enable_graph_mode': True,
},

# Print spec metrics.
"disable_log_stats": False,
@@ -175,7 +229,7 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
"dtype": PRECISION,

# Main model
"model_name": MAIN_MODEL,
"model_name": FLOAT_MODEL,
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@@ -192,20 +246,64 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
output_len: int, seed: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
def test_mtp_e2e_greedy_correctness_torchair_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with torchair graph enabled and different
batch sizes using bfloat16 weights."""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)


@pytest.mark.skipif(True, reason="quant model is not ready.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"additional_config": {
'enable_graph_mode': True,
},

# Print spec metrics.
"disable_log_stats": False,

# Precision
"dtype": PRECISION,

# Main model
"model_name": QUANT_MODEL,
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_quant_greedy_correctness_torchair_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with torchair graph enabled and different
batch sizes using quant weights."""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -221,7 +319,7 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
"dtype": PRECISION,

# Main model
"model_name": MAIN_MODEL,
"model_name": FLOAT_MODEL,

# GPU memory utilization
"gpu_memory_utilization": 0.9
@@ -256,6 +354,8 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
batch_size, output_len, seed)


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -266,7 +366,7 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
"dtype": PRECISION,

# Main model
"model_name": MAIN_MODEL,
"model_name": FLOAT_MODEL,

# GPU memory utilization
"gpu_memory_utilization": 0.9
@@ -305,6 +405,8 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
batch_size, output_len, seed)


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -315,7 +417,7 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
"dtype": PRECISION,

# Main model
"model_name": MAIN_MODEL,
"model_name": FLOAT_MODEL,

# GPU memory utilization
"gpu_memory_utilization": 0.9
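Taken together, the torchair graph-mode tests above differ from the eager tests only in their common_llm_kwargs. The two configurations being compared look roughly like this, with values copied from the parametrizations above and assembled purely for illustration:

# Eager baseline configuration (test_mtp_e2e_greedy_correctness).
eager_kwargs = {
    "enforce_eager": True,             # skip graph capture for a fast test
    "disable_log_stats": False,        # keep spec-decode metrics visible
    "dtype": "bfloat16",
    "model_name": "wemaster/deepseek_mtp_main_random_bf16",
    "gpu_memory_utilization": 0.85,
}

# Torchair graph-mode configuration (*_torchair_graph tests).
graph_kwargs = {
    "additional_config": {"enable_graph_mode": True},
    "disable_log_stats": False,
    "dtype": "bfloat16",
    "model_name": "wemaster/deepseek_mtp_main_random_bf16",
    "gpu_memory_utilization": 0.85,
}

# Both variants speculate MAX_SPEC_TOKENS tokens, matching
# num_nextn_predict_layers in the speculator model's config.json.
test_llm_kwargs = {"speculative_config": {"num_speculative_tokens": 1}}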
1 change: 0 additions & 1 deletion tests/singlecard/spec_decode/test_dynamic_spec_decode.py
@@ -29,7 +29,6 @@

from tests.singlecard.spec_decode.test_utils import mock_spec_decode_sampler
from tests.singlecard.spec_decode.utils import create_batch, mock_worker
from vllm_ascend.patch.worker import patch_common # noqa: F401


@pytest.mark.parametrize('queue_size', [4])
1 change: 0 additions & 1 deletion tests/singlecard/spec_decode/test_multi_step_worker.py
@@ -33,7 +33,6 @@
assert_logprobs_dict_allclose, create_batch,
create_seq_group_metadata_from_prompts, create_worker,
patch_execute_model_with_seeds, zero_kv_cache)
from vllm_ascend.patch.worker import patch_common # noqa: F401
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
from vllm_ascend.worker.worker import NPUWorker

1 change: 0 additions & 1 deletion tests/singlecard/spec_decode/test_ngram_worker.py
@@ -24,7 +24,6 @@

from tests.singlecard.spec_decode.utils import (
create_seq_group_metadata_from_prompts, create_worker)
from vllm_ascend.patch.worker import patch_common # noqa: F401


def test_ngram_algo_correctness_for_single_no_match():
2 changes: 0 additions & 2 deletions tests/singlecard/spec_decode/test_spec_decode_worker.py
@@ -39,8 +39,6 @@
from tests.singlecard.spec_decode.utils import (create_batch,
create_sampler_output_list,
create_worker, mock_worker)
# patch SpecDecodeWorker, AsyncMetricsCollector
from vllm_ascend.patch.worker import patch_common # noqa: F401
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
from vllm_ascend.worker.worker import NPUWorker
