
Commit c214d69

zixi-qi authored and simon-mo committed
[spec decode] Consolidate speculative decode method name for MTP (#25232)
Signed-off-by: zixi-qi <qizixi@meta.com>
1 parent c3dfb0f commit c214d69
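
This commit consolidates the per-model MTP method names (deepseek_mtp, mimo_mtp, glm4_moe_mtp, ernie_mtp, qwen3_next_mtp, longcat_flash_mtp) into the single name "mtp". A minimal usage sketch of the consolidated name, assuming the MiMo-7B checkpoint exercised by the tests in this commit:

# Sketch only: request MTP speculative decoding by name after this change.
# Model name and token budget mirror the e2e test added below.
from vllm import LLM, SamplingParams

llm = LLM(
    model="XiaomiMiMo/MiMo-7B-Base",
    trust_remote_code=True,
    speculative_config={
        "method": "mtp",              # previously "mimo_mtp", "deepseek_mtp", ...
        "num_speculative_tokens": 1,
    },
)
out = llm.generate(["The capital of France is"],
                   SamplingParams(temperature=0, max_tokens=16))
print(out[0].outputs[0].text)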

6 files changed: +287 -40 lines changed


examples/offline_inference/spec_decode.py

Lines changed: 3 additions & 2 deletions

@@ -54,6 +54,7 @@ def parse_args():
         "--method",
         type=str,
         default="eagle",
+        choices=["ngram", "eagle", "eagle3", "mtp"],
     )
     parser.add_argument("--num-spec-tokens", type=int, default=2)
     parser.add_argument("--prompt-lookup-max", type=int, default=5)
@@ -118,9 +119,9 @@ def main(args):
             "prompt_lookup_max": args.prompt_lookup_max,
             "prompt_lookup_min": args.prompt_lookup_min,
         }
-    elif args.method.endswith("mtp"):
+    elif args.method == "mtp":
        speculative_config = {
-            "method": args.method,
+            "method": "mtp",
            "num_speculative_tokens": args.num_spec_tokens,
        }
    else:
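
For reference, a sketch of the speculative_config this example now builds when invoked with --method mtp, using the default --num-spec-tokens of 2 shown above:

# Sketch only: the dict the example constructs for `--method mtp` with the
# default `--num-spec-tokens 2`; it is passed to LLM(speculative_config=...).
speculative_config = {
    "method": "mtp",
    "num_speculative_tokens": 2,
}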

tests/v1/e2e/test_spec_decode.py

Lines changed: 65 additions & 0 deletions

@@ -15,6 +15,8 @@
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.platforms import current_platform
 
+MTP_SIMILARITY_RATE = 0.8
+
 
 def get_test_prompts(mm_enabled: bool):
     prompt_types = ["repeat", "sentence"]
@@ -222,3 +224,66 @@ def test_eagle_correctness(
     del spec_llm
     torch.cuda.empty_cache()
     cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
+    (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
+    (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
+],
+                         ids=["mimo", "deepseek"])
+def test_mtp_correctness(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_setup: tuple[str, str, int],
+    mm_enabled: bool,
+):
+    # Generate test prompts inside the function instead of using fixture
+    test_prompts = get_test_prompts(mm_enabled)
+    '''
+    Compare the outputs of a original LLM and a speculative LLM
+    should be the same when using MTP speculative decoding.
+    model_setup: (method, model_name, tp_size)
+    '''
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        m.setenv("VLLM_MLA_DISABLE", "1")
+
+        method, model_name, tp_size = model_setup
+
+        ref_llm = LLM(model=model_name,
+                      max_model_len=2048,
+                      tensor_parallel_size=tp_size,
+                      trust_remote_code=True)
+        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+        del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+        spec_llm = LLM(
+            model=model_name,
+            trust_remote_code=True,
+            tensor_parallel_size=tp_size,
+            speculative_config={
+                "method": method,
+                "num_speculative_tokens": 1,
+                "max_model_len": 2048,
+            },
+            max_model_len=2048,
+        )
+        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+        matches = 0
+        misses = 0
+        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+            if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"spec_output: {spec_output.outputs[0].text}")
+
+        # Heuristic: expect at least 80% of the prompts to match exactly
+        # Upon failure, inspect the outputs to check for inaccuracy.
+        assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs))
+        del spec_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
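
The assertion above is an exact-match heuristic rather than a strict equality check on every output. A standalone sketch of the same idea (the helper name is illustrative and not part of the commit):

# Illustrative helper (not in the commit): fraction of prompts whose greedy
# spec-decode output matches the reference exactly; the test requires > 0.8.
def exact_match_rate(ref_outputs, spec_outputs) -> float:
    matches = sum(
        ref.outputs[0].text == spec.outputs[0].text
        for ref, spec in zip(ref_outputs, spec_outputs))
    return matches / len(ref_outputs)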

tests/v1/spec_decode/test_mtp.py

Lines changed: 195 additions & 0 deletions

@@ -0,0 +1,195 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from unittest import mock
+
+import pytest
+import torch
+
+from tests.v1.attention.utils import (BatchSpec, _Backend,
+                                      create_common_attn_metadata,
+                                      create_standard_kv_cache_spec,
+                                      get_attention_backend)
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
+                         VllmConfig)
+from vllm.config.load import LoadConfig
+from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.platforms import current_platform
+from vllm.v1.spec_decode.eagle import EagleProposer
+
+mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"
+
+
+def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
+    """Create an MTP proposer with unified model configuration."""
+    model_config = ModelConfig(model=mimo_7b_dir,
+                               runner="generate",
+                               max_model_len=100,
+                               trust_remote_code=True)
+
+    speculative_config = SpeculativeConfig(
+        target_model_config=model_config,
+        target_parallel_config=ParallelConfig(),
+        model=mimo_7b_dir,
+        method="mtp",
+        num_speculative_tokens=num_speculative_tokens,
+    )
+
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=CacheConfig(),
+        speculative_config=speculative_config,
+        device_config=DeviceConfig(device=current_platform.device_type),
+        parallel_config=ParallelConfig(),
+        load_config=LoadConfig(),
+        scheduler_config=SchedulerConfig())
+
+    return EagleProposer(vllm_config=vllm_config,
+                         device=current_platform.device_type)
+
+
+@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
+@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
+@mock.patch('vllm.v1.spec_decode.eagle.get_model')
+def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
+                                mock_get_pp_group):
+    """Test MTP-specific model loading with unified model approach."""
+
+    # Setup mocks
+    mock_model = mock.MagicMock()
+    mock_model.model.embed_tokens.weight.shape = (131072, 4096)
+    mock_get_model.return_value = mock_model
+
+    target_attn_layers = {"target_attn_1": mock.MagicMock()}
+    all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
+    mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
+
+    mock_pp_group = mock.MagicMock()
+    mock_pp_group.world_size = 1
+    mock_get_pp_group.return_value = mock_pp_group
+
+    # Create target model
+    class _TargetModelStub(LlamaForCausalLM):
+        model: mock.MagicMock
+        lm_head: mock.MagicMock
+
+    target_model = mock.create_autospec(_TargetModelStub, instance=True)
+    target_model.model = mock.MagicMock()
+    target_model.model.embed_tokens.weight.shape = (131072, 4096)
+    target_model.lm_head = mock.MagicMock()
+
+    # Create MTP proposer
+    proposer = _create_mtp_proposer(num_speculative_tokens=4)
+    proposer.load_model(target_model)
+
+    # Verify MTP-specific behavior:
+    # Model is loaded
+    mock_get_model.assert_called_once()
+    # MTP shares lm_head with target model
+    assert proposer.model.lm_head == target_model.lm_head
+    # MTP shares embed_tokens with target model
+    assert proposer.model.model.embed_tokens == target_model.model.embed_tokens
+
+
+@pytest.mark.parametrize("num_speculative_tokens", [1])
+def test_mtp_propose(num_speculative_tokens, monkeypatch):
+    """Test that MTP's forward method returns hidden states directly"""
+
+    device = torch.device(current_platform.device_type)
+    batch_size = 2
+    seq_lens = [5, 3]
+    total_tokens = sum(seq_lens)
+    vocab_size = 100
+
+    proposer = _create_mtp_proposer(num_speculative_tokens)
+    hidden_size = proposer.hidden_size
+
+    # Mock the MTP model to verify it returns hidden states directly
+    model_mock = mock.MagicMock()
+
+    # MTP returns hidden states directly
+    if num_speculative_tokens == 1:
+        model_mock.return_value = torch.zeros(total_tokens,
+                                              hidden_size,
+                                              device=device)
+    else:
+        # Multiple forward passes for multi-token speculation
+        forward_returns = []
+        for i in range(num_speculative_tokens):
+            if i == 0:
+                h_states = torch.zeros(total_tokens,
+                                       hidden_size,
+                                       device=device)
+            else:
+                h_states = torch.zeros(batch_size, hidden_size, device=device)
+            forward_returns.append(h_states)
+        model_mock.side_effect = forward_returns
+
+    # Mock compute_logits
+    def create_deterministic_logits(batch_size, vocab_size, token_offset):
+        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
+        logits[:, token_offset] = 100.0
+        return logits
+
+    if num_speculative_tokens == 1:
+        model_mock.compute_logits.return_value = create_deterministic_logits(
+            batch_size, vocab_size, 42)
+    else:
+        logits_returns = [
+            create_deterministic_logits(batch_size, vocab_size, 42 + i)
+            for i in range(num_speculative_tokens)
+        ]
+        model_mock.compute_logits.side_effect = logits_returns
+
+    proposer.model = model_mock
+    proposer.attn_layer_names = ["layer.0"]
+
+    # Prepare inputs
+    batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
+    common_attn_metadata = create_common_attn_metadata(batch_spec,
+                                                       block_size=16,
+                                                       device=device)
+
+    target_token_ids = torch.randint(0,
+                                     vocab_size, (total_tokens, ),
+                                     device=device)
+    target_positions = torch.cat([
+        torch.arange(seq_lens[0], device=device),
+        torch.arange(seq_lens[1], device=device)
+    ])
+    target_hidden_states = torch.randn(total_tokens,
+                                       hidden_size,
+                                       device=device)
+    next_token_ids = torch.randint(0,
+                                   vocab_size, (batch_size, ),
+                                   dtype=torch.int32,
+                                   device=device)
+    sampling_metadata = mock.MagicMock()
+
+    # Setup attention metadata
+    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
+
+    attn_metadata_builder = attn_metadata_builder_cls(
+        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
+        layer_names=proposer.attn_layer_names,
+        vllm_config=proposer.vllm_config,
+        device=device,
+    )
+
+    proposer.runner = mock.MagicMock()
+    proposer.attn_metadata_builder = attn_metadata_builder
+
+    # Run propose
+    result = proposer.propose(target_token_ids=target_token_ids,
+                              target_positions=target_positions,
+                              target_hidden_states=target_hidden_states,
+                              next_token_ids=next_token_ids,
+                              last_token_indices=None,
+                              common_attn_metadata=common_attn_metadata,
+                              sampling_metadata=sampling_metadata)
+
+    # Verify the model was called correctly
+    assert model_mock.called
+    # Verify output shape
+    assert result.shape == (batch_size, num_speculative_tokens)

vllm/config/speculative.py

Lines changed: 19 additions & 31 deletions

@@ -32,7 +32,9 @@
 SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
                             "mlp_speculator", "draft_model", "deepseek_mtp",
                             "ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
-                            "longcat_flash_mtp"]
+                            "longcat_flash_mtp", "mtp"]
+MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
+                   "qwen3_next_mtp", "longcat_flash_mtp")
 
 
 @config
@@ -207,11 +209,16 @@ def __post_init__(self):
         # can not be detected, it will be considered as the "draft_model" by
         # default.
 
+        if self.method in MTP_MODEL_TYPES:
+            logger.warning("method `%s` is deprecated and replaced with mtp.",
+                           self.method)
+            self.method = "mtp"
+
         if self.model is None and self.num_speculative_tokens is not None:
-            # TODO(Shangming): Refactor mtp configuration logic when supporting
-            if (self.target_model_config
-                    and self.target_model_config.hf_text_config.model_type
-                    in ("deepseek_v3", "mimo", "ernie4_5_moe", "qwen3_next")):
+            if self.method == "mtp":
+                assert (
+                    self.target_model_config
+                    is not None), "target_model_config must be present for mtp"
                 # use the draft model from the same model:
                 self.model = self.target_model_config.model
                 # Align the quantization of draft model for cases such as
@@ -314,31 +321,13 @@ def __post_init__(self):
                     "mlp_speculator"):
                 self.method = "mlp_speculator"
             elif (self.draft_model_config.hf_config.model_type
-                  in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
-                self.method = "deepseek_mtp"
-                if self.num_speculative_tokens > 1:
-                    logger.warning(
-                        "All Deepseek MTP models only have " \
-                        "one layer. Might need some code changes " \
-                        "to support multiple layers."
-                    )
-            elif (self.draft_model_config.hf_config.model_type ==
-                  "ernie_mtp"):
-                self.method = "ernie_mtp"
+                  in MTP_MODEL_TYPES):
+                self.method = "mtp"
                 if self.num_speculative_tokens > 1:
                     logger.warning(
-                        "All Ernie MTP models only have " \
-                        "one layer. Might need some code changes " \
-                        "to support multiple layers."
-                    )
-            elif (self.draft_model_config.hf_config.model_type ==
-                  "qwen3_next_mtp"):
-                self.method = "qwen3_next_mtp"
-                if self.num_speculative_tokens > 1:
-                    logger.warning(
-                        "All Qwen3Next MTP models only have " \
-                        "one layer. Might need some code changes " \
-                        "to support multiple layers."
+                        "Enabling num_speculative_tokens > 1 will run" \
+                        "multiple times of forward on same MTP layer" \
+                        ",which may result in lower acceptance rate" \
                     )
             elif (self.draft_model_config.hf_config.model_type
                   in ("longcat_flash_mtp")):
@@ -355,7 +344,7 @@ def __post_init__(self):
                     "Speculative decoding with draft model is not "
                     "supported yet. Please consider using other "
                     "speculative decoding methods such as ngram, medusa, "
-                    "eagle, or deepseek_mtp.")
+                    "eagle, or mtp.")
 
         # Replace hf_config for EAGLE draft_model
         if self.method in ("eagle", "eagle3"):
@@ -564,8 +553,7 @@ def num_lookahead_slots(self) -> int:
         return self.num_speculative_tokens
 
     def use_eagle(self) -> bool:
-        return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
-                               "qwen3_next_mtp", "longcat_flash_mtp")
+        return self.method in ("eagle", "eagle3", "mtp")
 
     def __repr__(self) -> str:
         method = self.method

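A minimal sketch of the backward-compatibility path added above, reusing the constructor arguments from tests/v1/spec_decode/test_mtp.py; building a ModelConfig fetches the Hugging Face config, so treat this as illustrative rather than a drop-in test:

# Illustrative sketch: legacy per-model method names are normalized to "mtp"
# with a deprecation warning, and "mtp" now routes through the EAGLE-style proposer.
from vllm.config import ModelConfig, ParallelConfig, SpeculativeConfig

target = ModelConfig(model="XiaomiMiMo/MiMo-7B-Base",
                     runner="generate",
                     max_model_len=100,
                     trust_remote_code=True)
cfg = SpeculativeConfig(
    target_model_config=target,
    target_parallel_config=ParallelConfig(),
    model="XiaomiMiMo/MiMo-7B-Base",
    method="mimo_mtp",  # deprecated name, logged and rewritten in __post_init__
    num_speculative_tokens=1,
)
assert cfg.method == "mtp"
assert cfg.use_eagle()
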
vllm/engine/arg_utils.py

Lines changed: 1 addition & 1 deletion

@@ -1486,7 +1486,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             raise NotImplementedError(
                 "Draft model speculative decoding is not supported yet. "
                 "Please consider using other speculative decoding methods "
-                "such as ngram, medusa, eagle, or deepseek_mtp.")
+                "such as ngram, medusa, eagle, or mtp.")
 
         V1_BACKENDS = [
             "FLASH_ATTN",