77 changes: 77 additions & 0 deletions tests/v1/e2e/test_spec_decode.py
@@ -223,3 +223,80 @@ def test_eagle_correctness(
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()


@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
(("deepseek_mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
],
ids=["deepseek_mtp"])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
def test_mtp_correctness(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_setup: tuple[str, str, int],
mm_enabled: bool,
attn_backend: str,
):
if attn_backend == "TREE_ATTN":
pytest.skip("MTP does not support tree-based speculative decoding")

if (attn_backend == "TRITON_ATTN_VLLM_V1"
and not current_platform.is_rocm()):
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
"multi-token MTP spec decode on current platform")

# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using MTP speculative decoding.
model_setup: (method, model_name, tp_size)
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_MLA_DISABLE", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1")

method, model_name, tp_size = model_setup

ref_llm = LLM(model=model_name,
max_model_len=2048,
tensor_parallel_size=tp_size,
trust_remote_code=True)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

spec_llm = LLM(
model=model_name,
trust_remote_code=True,
tensor_parallel_size=tp_size,
speculative_config={
"method": method,
"num_speculative_tokens": 1,
"max_model_len": 2048,
},
max_model_len=2048,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")

# Heuristic: expect at least 80% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.8 * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
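
For illustration only (not part of this diff): the exact-match heuristic that test_mtp_correctness applies above, factored into a standalone helper. The helper name exact_match_rate is hypothetical; the comparison logic simply mirrors the loop in the test.

def exact_match_rate(ref_outputs, spec_outputs) -> float:
    # Fraction of prompts whose generated text matches the reference exactly.
    matches = sum(
        ref.outputs[0].text == spec.outputs[0].text
        for ref, spec in zip(ref_outputs, spec_outputs))
    return matches / len(ref_outputs)

# The test passes when this rate exceeds roughly 80%, i.e. MTP speculative
# decoding reproduces the reference output on at least 80% of the prompts.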
223 changes: 223 additions & 0 deletions tests/v1/spec_decode/test_mtp.py
@@ -0,0 +1,223 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest import mock

import pytest
import torch

from tests.utils import get_attn_backend_list_based_on_platform
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer

mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"


def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
"""Create an MTP proposer with unified model configuration."""
model_config = ModelConfig(model=mimo_7b_dir,
runner="generate",
max_model_len=100,
trust_remote_code=True)

speculative_config = SpeculativeConfig(
target_model_config=model_config,
target_parallel_config=ParallelConfig(),
model=mimo_7b_dir,
method="deepseek_mtp",
num_speculative_tokens=num_speculative_tokens,
)

vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
speculative_config=speculative_config,
device_config=DeviceConfig(device=current_platform.device_type),
parallel_config=ParallelConfig(),
load_config=LoadConfig(),
scheduler_config=SchedulerConfig())

return EagleProposer(vllm_config=vllm_config,
device=current_platform.device_type)


@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
mock_get_pp_group):
"""Test MTP-specific model loading with unified model approach."""

# Setup mocks
mock_model = mock.MagicMock()
mock_model.model.embed_tokens.weight.shape = (131072, 4096)
mock_get_model.return_value = mock_model

target_attn_layers = {"target_attn_1": mock.MagicMock()}
all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]

mock_pp_group = mock.MagicMock()
mock_pp_group.world_size = 1
mock_get_pp_group.return_value = mock_pp_group

# Create target model
class _TargetModelStub(LlamaForCausalLM):
model: mock.MagicMock
lm_head: mock.MagicMock

target_model = mock.create_autospec(_TargetModelStub, instance=True)
target_model.model = mock.MagicMock()
target_model.model.embed_tokens.weight.shape = (131072, 4096)
target_model.lm_head = mock.MagicMock()

# Create MTP proposer
proposer = _create_mtp_proposer(num_speculative_tokens=4)
proposer.load_model(target_model)

# Verify MTP-specific behavior:
# Model is loaded
mock_get_model.assert_called_once()
# MTP shares lm_head with target model
assert proposer.model.lm_head == target_model.lm_head
# MTP shares embed_tokens with target model
assert proposer.model.model.embed_tokens == target_model.model.embed_tokens


@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1])
def test_mtp_propose_returns_hidden_states(attn_backend,
num_speculative_tokens,
monkeypatch):
"""Test that MTP's forward method returns hidden states directly"""

monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

if (attn_backend == "TRITON_ATTN_VLLM_V1"
and not current_platform.is_rocm()):
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
"multi-token spec decode on current platform")

if attn_backend == "TREE_ATTN":
pytest.skip("MTP does not support tree-based speculative decoding")

if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

device = torch.device(current_platform.device_type)
batch_size = 2
seq_lens = [5, 3]
total_tokens = sum(seq_lens)
vocab_size = 100

proposer = _create_mtp_proposer(num_speculative_tokens)
hidden_size = proposer.hidden_size

# Mock the MTP model to verify it returns hidden states directly
model_mock = mock.MagicMock()

# MTP returns hidden states directly
if num_speculative_tokens == 1:
model_mock.return_value = torch.zeros(total_tokens,
hidden_size,
device=device)
else:
# Multiple forward passes for multi-token speculation
forward_returns = []
for i in range(num_speculative_tokens):
if i == 0:
h_states = torch.zeros(total_tokens,
hidden_size,
device=device)
else:
h_states = torch.zeros(batch_size, hidden_size, device=device)
forward_returns.append(h_states)
model_mock.side_effect = forward_returns

# Mock compute_logits
def create_deterministic_logits(batch_size, vocab_size, token_offset):
logits = torch.full((batch_size, vocab_size), -100.0, device=device)
logits[:, token_offset] = 100.0
return logits

if num_speculative_tokens == 1:
model_mock.compute_logits.return_value = create_deterministic_logits(
batch_size, vocab_size, 42)
else:
logits_returns = [
create_deterministic_logits(batch_size, vocab_size, 42 + i)
for i in range(num_speculative_tokens)
]
model_mock.compute_logits.side_effect = logits_returns

proposer.model = model_mock
proposer.attn_layer_names = ["layer.0"]

# Prepare inputs
batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
common_attn_metadata = create_common_attn_metadata(batch_spec,
block_size=16,
device=device)

target_token_ids = torch.randint(0,
vocab_size, (total_tokens, ),
device=device)
target_positions = torch.cat([
torch.arange(seq_lens[0], device=device),
torch.arange(seq_lens[1], device=device)
])
target_hidden_states = torch.randn(total_tokens,
hidden_size,
device=device)
next_token_ids = torch.randint(0,
vocab_size, (batch_size, ),
dtype=torch.int32,
device=device)
sampling_metadata = mock.MagicMock()

# Setup attention metadata
if attn_backend == "FLASH_ATTN_VLLM_V1":
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.FLASH_ATTN_VLLM_V1)
elif attn_backend == "TRITON_ATTN_VLLM_V1":
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.TRITON_ATTN_VLLM_V1)
else:
raise ValueError(f"Unsupported attention backend: {attn_backend}")

attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,
vllm_config=proposer.vllm_config,
device=device,
)

proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builders = [
attn_metadata_builder
]

# Run propose
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)

# Verify the model was called correctly
assert model_mock.called
# Verify output shape
assert result.shape == (batch_size, num_speculative_tokens)
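
A side note on the deterministic-logits trick used in the mock above (illustration only, not part of this diff): filling the logits with -100.0 and setting a single column to +100.0 guarantees that greedy argmax selects token_offset, which is what makes the mocked proposer's behavior predictable. A minimal sketch, assuming the same shapes as the test:

import torch

def create_deterministic_logits(batch_size, vocab_size, token_offset):
    # All logits strongly negative except the chosen token.
    logits = torch.full((batch_size, vocab_size), -100.0)
    logits[:, token_offset] = 100.0
    return logits

logits = create_deterministic_logits(batch_size=2, vocab_size=100, token_offset=42)
assert torch.argmax(logits, dim=-1).tolist() == [42, 42]
# With num_speculative_tokens == 1 the mocked compute_logits returns these
# logits once, and propose() is expected to yield a (batch_size, 1) tensor of
# proposed token ids, matching the shape assertion at the end of the test.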