# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest import mock

import pytest
import torch

from tests.v1.attention.utils import (BatchSpec, _Backend,
                                      create_common_attn_metadata,
                                      create_standard_kv_cache_spec,
                                      get_attention_backend)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
                         VllmConfig)
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer

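# MTP-capable checkpoint used as both target and draft model in these tests.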
mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"


def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
    """Create an MTP proposer with unified model configuration."""
    model_config = ModelConfig(model=mimo_7b_dir,
                               runner="generate",
                               max_model_len=100,
                               trust_remote_code=True)

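    # The draft model reuses the target checkpoint; method="mtp" selects the
    # multi-token prediction proposer.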
    speculative_config = SpeculativeConfig(
        target_model_config=model_config,
        target_parallel_config=ParallelConfig(),
        model=mimo_7b_dir,
        method="mtp",
        num_speculative_tokens=num_speculative_tokens,
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        speculative_config=speculative_config,
        device_config=DeviceConfig(device=current_platform.device_type),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=SchedulerConfig())

    return EagleProposer(vllm_config=vllm_config,
                         device=current_platform.device_type)


@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
                                mock_get_pp_group):
    """Test MTP-specific model loading with unified model approach."""

    # Setup mocks
    mock_model = mock.MagicMock()
    mock_model.model.embed_tokens.weight.shape = (131072, 4096)
    mock_get_model.return_value = mock_model

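    # load_model() queries the attention layers twice: first for the target
    # model alone, then for target + draft, so the extra "draft_attn_1" entry
    # stands in for the MTP layer.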
    target_attn_layers = {"target_attn_1": mock.MagicMock()}
    all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
    mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]

    mock_pp_group = mock.MagicMock()
    mock_pp_group.world_size = 1
    mock_get_pp_group.return_value = mock_pp_group

    # Create target model
    class _TargetModelStub(LlamaForCausalLM):
        model: mock.MagicMock
        lm_head: mock.MagicMock

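    # Autospec keeps the LlamaForCausalLM interface on the stub while its
    # submodules remain plain mocks.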
    target_model = mock.create_autospec(_TargetModelStub, instance=True)
    target_model.model = mock.MagicMock()
    target_model.model.embed_tokens.weight.shape = (131072, 4096)
    target_model.lm_head = mock.MagicMock()

    # Create MTP proposer
    proposer = _create_mtp_proposer(num_speculative_tokens=4)
    proposer.load_model(target_model)

    # Verify MTP-specific behavior:
    # Model is loaded
    mock_get_model.assert_called_once()
    # MTP shares lm_head with target model
    assert proposer.model.lm_head == target_model.lm_head
    # MTP shares embed_tokens with target model
    assert proposer.model.model.embed_tokens == target_model.model.embed_tokens


@pytest.mark.parametrize("num_speculative_tokens", [1])
def test_mtp_propose(num_speculative_tokens, monkeypatch):
    """Test that MTP's forward method returns hidden states directly"""

    device = torch.device(current_platform.device_type)
    batch_size = 2
    seq_lens = [5, 3]
    total_tokens = sum(seq_lens)
    vocab_size = 100

    proposer = _create_mtp_proposer(num_speculative_tokens)
    hidden_size = proposer.hidden_size

    # Mock the MTP model to verify it returns hidden states directly
    model_mock = mock.MagicMock()

    # MTP returns hidden states directly
    if num_speculative_tokens == 1:
        model_mock.return_value = torch.zeros(total_tokens,
                                              hidden_size,
                                              device=device)
    else:
        # Multiple forward passes for multi-token speculation
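        # The first pass consumes every prompt token; later passes only see
        # the newly sampled token per request, hence batch_size rows.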
        forward_returns = []
        for i in range(num_speculative_tokens):
            if i == 0:
                h_states = torch.zeros(total_tokens,
                                       hidden_size,
                                       device=device)
            else:
                h_states = torch.zeros(batch_size, hidden_size, device=device)
            forward_returns.append(h_states)
        model_mock.side_effect = forward_returns

    # Mock compute_logits
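    # Each call puts all probability mass on one known token id so the
    # proposed token ids are deterministic.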
    def create_deterministic_logits(batch_size, vocab_size, token_offset):
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        logits[:, token_offset] = 100.0
        return logits

    if num_speculative_tokens == 1:
        model_mock.compute_logits.return_value = create_deterministic_logits(
            batch_size, vocab_size, 42)
    else:
        logits_returns = [
            create_deterministic_logits(batch_size, vocab_size, 42 + i)
            for i in range(num_speculative_tokens)
        ]
        model_mock.compute_logits.side_effect = logits_returns

    proposer.model = model_mock
    proposer.attn_layer_names = ["layer.0"]

    # Prepare inputs
    batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
    common_attn_metadata = create_common_attn_metadata(batch_spec,
                                                       block_size=16,
                                                       device=device)

    target_token_ids = torch.randint(0,
                                     vocab_size, (total_tokens, ),
                                     device=device)
    target_positions = torch.cat([
        torch.arange(seq_lens[0], device=device),
        torch.arange(seq_lens[1], device=device)
    ])
    target_hidden_states = torch.randn(total_tokens,
                                       hidden_size,
                                       device=device)
    next_token_ids = torch.randint(0,
                                   vocab_size, (batch_size, ),
                                   dtype=torch.int32,
                                   device=device)
    sampling_metadata = mock.MagicMock()

    # Setup attention metadata
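    # Use a real FlashAttention metadata builder rather than a mock so
    # propose() can build valid attention metadata for the draft forward pass.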
    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)

    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    proposer.runner = mock.MagicMock()
    proposer.attn_metadata_builder = attn_metadata_builder

    # Run propose
    result = proposer.propose(target_token_ids=target_token_ids,
                              target_positions=target_positions,
                              target_hidden_states=target_hidden_states,
                              next_token_ids=next_token_ids,
                              last_token_indices=None,
                              common_attn_metadata=common_attn_metadata,
                              sampling_metadata=sampling_metadata)

    # Verify the model was called correctly
    assert model_mock.called
    # Verify output shape
    assert result.shape == (batch_size, num_speculative_tokens)