5 changes: 3 additions & 2 deletions examples/offline_inference/spec_decode.py
@@ -54,6 +54,7 @@ def parse_args():
"--method",
type=str,
default="eagle",
choices=["ngram", "eagle", "eagle3", "mtp"],
)
parser.add_argument("--num-spec-tokens", type=int, default=2)
parser.add_argument("--prompt-lookup-max", type=int, default=5)
@@ -118,9 +119,9 @@ def main(args):
"prompt_lookup_max": args.prompt_lookup_max,
"prompt_lookup_min": args.prompt_lookup_min,
}
elif args.method.endswith("mtp"):
elif args.method == "mtp":
speculative_config = {
"method": args.method,
"method": "mtp",
"num_speculative_tokens": args.num_spec_tokens,
}
else:
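For context, a minimal sketch of what the unified "mtp" path enables through the offline LLM API; this is illustrative only (the model name is borrowed from the tests below, and it assumes a checkpoint that bundles MTP weights), not the example script itself:

from vllm import LLM, SamplingParams

# Minimal sketch, assuming a checkpoint that ships MTP weights
# (model name is illustrative, taken from the tests below).
llm = LLM(
    model="XiaomiMiMo/MiMo-7B-Base",
    trust_remote_code=True,
    speculative_config={
        "method": "mtp",
        "num_speculative_tokens": 1,
    },
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0, max_tokens=16))
print(outputs[0].outputs[0].text)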
65 changes: 65 additions & 0 deletions tests/v1/e2e/test_spec_decode.py
@@ -15,6 +15,8 @@
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform

MTP_SIMILARITY_RATE = 0.8


def get_test_prompts(mm_enabled: bool):
prompt_types = ["repeat", "sentence"]
@@ -222,3 +224,66 @@ def test_eagle_correctness(
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()


@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
(("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
(("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
],
ids=["mimo", "deepseek"])
def test_mtp_correctness(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_setup: tuple[str, str, int],
mm_enabled: bool,
):
'''
Compare the outputs of the original LLM and the speculative LLM:
they should be the same when using MTP speculative decoding.
model_setup: (method, model_name, tp_size)
'''
# Generate test prompts inside the function instead of using a fixture
test_prompts = get_test_prompts(mm_enabled)
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_MLA_DISABLE", "1")

method, model_name, tp_size = model_setup

ref_llm = LLM(model=model_name,
max_model_len=2048,
tensor_parallel_size=tp_size,
trust_remote_code=True)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

spec_llm = LLM(
model=model_name,
trust_remote_code=True,
tensor_parallel_size=tp_size,
speculative_config={
"method": method,
"num_speculative_tokens": 1,
"max_model_len": 2048,
},
max_model_len=2048,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")

# Heuristic: expect at least 80% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
195 changes: 195 additions & 0 deletions tests/v1/spec_decode/test_mtp.py
@@ -0,0 +1,195 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest import mock

import pytest
import torch

from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer

mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"


def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
"""Create an MTP proposer with unified model configuration."""
model_config = ModelConfig(model=mimo_7b_dir,
runner="generate",
max_model_len=100,
trust_remote_code=True)

speculative_config = SpeculativeConfig(
target_model_config=model_config,
target_parallel_config=ParallelConfig(),
model=mimo_7b_dir,
method="mtp",
num_speculative_tokens=num_speculative_tokens,
)

vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
speculative_config=speculative_config,
device_config=DeviceConfig(device=current_platform.device_type),
parallel_config=ParallelConfig(),
load_config=LoadConfig(),
scheduler_config=SchedulerConfig())

return EagleProposer(vllm_config=vllm_config,
device=current_platform.device_type)


@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
mock_get_pp_group):
"""Test MTP-specific model loading with unified model approach."""

# Setup mocks
mock_model = mock.MagicMock()
mock_model.model.embed_tokens.weight.shape = (131072, 4096)
mock_get_model.return_value = mock_model

target_attn_layers = {"target_attn_1": mock.MagicMock()}
all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]

mock_pp_group = mock.MagicMock()
mock_pp_group.world_size = 1
mock_get_pp_group.return_value = mock_pp_group

# Create target model
class _TargetModelStub(LlamaForCausalLM):
model: mock.MagicMock
lm_head: mock.MagicMock

target_model = mock.create_autospec(_TargetModelStub, instance=True)
target_model.model = mock.MagicMock()
target_model.model.embed_tokens.weight.shape = (131072, 4096)
target_model.lm_head = mock.MagicMock()

# Create MTP proposer
proposer = _create_mtp_proposer(num_speculative_tokens=4)
proposer.load_model(target_model)

# Verify MTP-specific behavior:
# Model is loaded
mock_get_model.assert_called_once()
# MTP shares lm_head with target model
assert proposer.model.lm_head == target_model.lm_head
# MTP shares embed_tokens with target model
assert proposer.model.model.embed_tokens == target_model.model.embed_tokens


@pytest.mark.parametrize("num_speculative_tokens", [1])
def test_mtp_propose(num_speculative_tokens, monkeypatch):
"""Test that MTP's forward method returns hidden states directly"""

device = torch.device(current_platform.device_type)
batch_size = 2
seq_lens = [5, 3]
total_tokens = sum(seq_lens)
vocab_size = 100

proposer = _create_mtp_proposer(num_speculative_tokens)
hidden_size = proposer.hidden_size

# Mock the MTP model to verify it returns hidden states directly
model_mock = mock.MagicMock()

# MTP returns hidden states directly
if num_speculative_tokens == 1:
model_mock.return_value = torch.zeros(total_tokens,
hidden_size,
device=device)
else:
# Multiple forward passes for multi-token speculation
forward_returns = []
for i in range(num_speculative_tokens):
if i == 0:
h_states = torch.zeros(total_tokens,
hidden_size,
device=device)
else:
h_states = torch.zeros(batch_size, hidden_size, device=device)
forward_returns.append(h_states)
model_mock.side_effect = forward_returns

# Mock compute_logits
def create_deterministic_logits(batch_size, vocab_size, token_offset):
logits = torch.full((batch_size, vocab_size), -100.0, device=device)
logits[:, token_offset] = 100.0
return logits

if num_speculative_tokens == 1:
model_mock.compute_logits.return_value = create_deterministic_logits(
batch_size, vocab_size, 42)
else:
logits_returns = [
create_deterministic_logits(batch_size, vocab_size, 42 + i)
for i in range(num_speculative_tokens)
]
model_mock.compute_logits.side_effect = logits_returns

proposer.model = model_mock
proposer.attn_layer_names = ["layer.0"]

# Prepare inputs
batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
common_attn_metadata = create_common_attn_metadata(batch_spec,
block_size=16,
device=device)

target_token_ids = torch.randint(0,
vocab_size, (total_tokens, ),
device=device)
target_positions = torch.cat([
torch.arange(seq_lens[0], device=device),
torch.arange(seq_lens[1], device=device)
])
target_hidden_states = torch.randn(total_tokens,
hidden_size,
device=device)
next_token_ids = torch.randint(0,
vocab_size, (batch_size, ),
dtype=torch.int32,
device=device)
sampling_metadata = mock.MagicMock()

# Setup attention metadata
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)

attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,
vllm_config=proposer.vllm_config,
device=device,
)

proposer.runner = mock.MagicMock()
proposer.attn_metadata_builder = attn_metadata_builder

# Run propose
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)

# Verify the model was called correctly
assert model_mock.called
# Verify output shape
assert result.shape == (batch_size, num_speculative_tokens)
50 changes: 19 additions & 31 deletions vllm/config/speculative.py
@@ -32,7 +32,9 @@
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
"longcat_flash_mtp"]
"longcat_flash_mtp", "mtp"]
MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
"qwen3_next_mtp", "longcat_flash_mtp")


@config
@@ -207,11 +209,16 @@ def __post_init__(self):
# can not be detected, it will be considered as the "draft_model" by
# default.

if self.method in MTP_MODEL_TYPES:
logger.warning("method `%s` is deprecated and replaced with mtp.",
self.method)
self.method = "mtp"

if self.model is None and self.num_speculative_tokens is not None:
# TODO(Shangming): Refactor mtp configuration logic when supporting
if (self.target_model_config
and self.target_model_config.hf_text_config.model_type
in ("deepseek_v3", "mimo", "ernie4_5_moe", "qwen3_next")):
if self.method == "mtp":
assert (
self.target_model_config
is not None), "target_model_config must be present for mtp"
# use the draft model from the same model:
self.model = self.target_model_config.model
# Align the quantization of draft model for cases such as
@@ -312,31 +319,13 @@ def __post_init__(self):
"mlp_speculator"):
self.method = "mlp_speculator"
elif (self.draft_model_config.hf_config.model_type
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
self.method = "deepseek_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Deepseek MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"ernie_mtp"):
self.method = "ernie_mtp"
in MTP_MODEL_TYPES):
self.method = "mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Ernie MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"qwen3_next_mtp"):
self.method = "qwen3_next_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Qwen3Next MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
"Enabling num_speculative_tokens > 1 will run" \
"multiple times of forward on same MTP layer" \
",which may result in lower acceptance rate" \
)
elif (self.draft_model_config.hf_config.model_type
in ("longcat_flash_mtp")):
@@ -353,7 +342,7 @@ def __post_init__(self):
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or deepseek_mtp.")
"eagle, or mtp.")

# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
@@ -562,8 +551,7 @@ def num_lookahead_slots(self) -> int:
return self.num_speculative_tokens

def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp", "longcat_flash_mtp")
return self.method in ("eagle", "eagle3", "mtp")

def __repr__(self) -> str:
method = self.method
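For reference, a minimal sketch of the normalization this change introduces (constructor arguments mirror those in tests/v1/spec_decode/test_mtp.py; treat it as illustrative rather than canonical): a legacy per-model MTP method name is folded into the unified "mtp" method inside __post_init__, with a deprecation warning, and "mtp" now reuses the Eagle-style proposer path.

from vllm.config import ModelConfig, ParallelConfig, SpeculativeConfig

# Illustrative sketch; model name and arguments follow the unit test above.
model_config = ModelConfig(model="XiaomiMiMo/MiMo-7B-Base",
                           runner="generate",
                           max_model_len=100,
                           trust_remote_code=True)
spec_config = SpeculativeConfig(
    target_model_config=model_config,
    target_parallel_config=ParallelConfig(),
    model="XiaomiMiMo/MiMo-7B-Base",
    method="mimo_mtp",  # deprecated spelling; remapped with a warning
    num_speculative_tokens=1,
)
assert spec_config.method == "mtp"
assert spec_config.use_eagle()  # "mtp" follows the Eagle-style proposer path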
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
@@ -1481,7 +1481,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
raise NotImplementedError(
"Draft model speculative decoding is not supported yet. "
"Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or deepseek_mtp.")
"such as ngram, medusa, eagle, or mtp.")

V1_BACKENDS = [
"FLASH_ATTN",