Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 71 additions & 8 deletions tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from __future__ import annotations

import os

import pytest
from vllm import SamplingParams
from vllm.config import CompilationConfig, CUDAGraphMode

from tests.e2e.conftest import VllmRunner

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


@pytest.fixture
def sampling_config():
Expand All @@ -17,12 +21,12 @@ def model_name():
return "wemaster/deepseek_mtp_main_random_bf16"


def mtp_correctness(
sampling_config: SamplingParams,
model_name: str,
num_speculative_tokens: int,
graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
):
def mtp_correctness(sampling_config: SamplingParams,
model_name: str,
num_speculative_tokens: int,
graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
enforce_eager=False,
disable_padded_drafter_batch=True):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
Expand All @@ -37,7 +41,7 @@ def mtp_correctness(
tensor_parallel_size=1,
gpu_memory_utilization=0.7,
max_model_len=256,
enforce_eager=False) as ref_llm:
enforce_eager=enforce_eager) as ref_llm:
ref_outputs = ref_llm.generate(example_prompts, sampling_config)

graph_mode_str = "PIECEWISE"
Expand All @@ -54,8 +58,9 @@ def mtp_correctness(
speculative_config={
"method": "deepseek_mtp",
"num_speculative_tokens": num_speculative_tokens,
"disable_padded_drafter_batch": disable_padded_drafter_batch,
},
enforce_eager=False,
enforce_eager=enforce_eager,
max_model_len=2000,
compilation_config=CompilationConfig(
cudagraph_mode=graph_mode_str),
Expand All @@ -82,6 +87,20 @@ def mtp_correctness(
del spec_llm


def test_mtp1_correctness_eager(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 1, enforce_eager=True)


def test_mtp2_correctness_eager(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 2, enforce_eager=True)


@pytest.mark.skip("TODO(cmq): Revert me when mtp aclgraph is fixed")
def test_mtp1_correctness_piecewise_graph(
sampling_config: SamplingParams,
Expand Down Expand Up @@ -110,3 +129,47 @@ def test_mtp2_correctness_full_graph(
model_name: str,
):
mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)


def test_mtp1_correctness_eager_with_pad(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config,
model_name,
1,
enforce_eager=True,
disable_padded_drafter_batch=False)


def test_mtp2_correctness_eager_with_pad(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config,
model_name,
2,
enforce_eager=True,
disable_padded_drafter_batch=False)


@pytest.mark.skip("TODO(xyx): Revert me when mtp aclgraph is fixed")
def test_mtp1_correctness_piecewise_graph_with_pad(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config,
model_name,
1,
disable_padded_drafter_batch=False)


@pytest.mark.skip("TODO(xyx): Revert me when mtp aclgraph is fixed")
def test_mtp2_correctness_piecewise_graph_with_pad(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config,
model_name,
2,
disable_padded_drafter_batch=False)
9 changes: 8 additions & 1 deletion vllm_ascend/spec_decode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,21 @@
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer


def get_spec_decode_method(method, vllm_config, device, runner):
def get_spec_decode_method(method,
vllm_config,
device,
runner,
is_torchair_graph=False):
if method == "ngram":
return NgramProposer(vllm_config, device, runner)
elif method in ["eagle", "eagle3"]:
return EagleProposer(vllm_config, device, runner)
elif method == 'deepseek_mtp':
if is_torchair_graph:
return TorchairMtpProposer(vllm_config, device, runner)
return MtpProposer(vllm_config, device, runner)
else:
raise ValueError("Unknown speculative decoding method: "
Expand Down
Loading
Loading