Commit c6b636f

[V1][Spec Decoding] Use model_loader.get_model() to load models (#18273)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent 04eb88d commit c6b636f

16 files changed (+59, -135 lines)
tests/v1/spec_decode/test_eagle.py
Lines changed: 6 additions & 52 deletions

@@ -117,34 +117,13 @@ def test_prepare_inputs():
 ])
 @mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
 @mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
-@mock.patch('vllm.v1.spec_decode.eagle.ModelRegistry')
-@mock.patch('vllm.v1.spec_decode.eagle.get_model_loader')
-@mock.patch('vllm.v1.spec_decode.eagle.set_default_torch_dtype')
-@mock.patch('vllm.v1.spec_decode.eagle.set_current_vllm_config')
-def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
-                    mock_registry, mock_get_layers, mock_get_pp_group, method,
+@mock.patch('vllm.v1.spec_decode.eagle.get_model')
+def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
                     proposer_helper, draft_model_dir, target_attribute_path):
 
-    # Setup mock for model class
-    mock_model_cls = mock.MagicMock()
-    mock_registry.resolve_model_cls.return_value = (mock_model_cls,
-                                                    "test_arch")
-
-    # Create a real context manager for mocks
-    class MockContextManager:
-
-        def __init__(self):
-            pass
-
-        def __enter__(self):
-            return None
-
-        def __exit__(self, exc_type, exc_val, exc_tb):
-            return False
-
-    # Make the mocks return actual context manager objects
-    mock_set_dtype.return_value = MockContextManager()
-    mock_set_config.return_value = MockContextManager()
+    # Setup model mock
+    mock_model = mock.MagicMock()
+    mock_get_model.return_value = mock_model
 
     # Setup mocks for attention layers
     target_attn_layers = {
@@ -164,25 +143,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
     mock_pp_group.world_size = 2 if method == "eagle" else 1
     mock_get_pp_group.return_value = mock_pp_group
 
-    # Setup model loader mock
-    mock_loader = mock.MagicMock()
-    mock_get_loader.return_value = mock_loader
-
-    # Setup model mock
-    mock_model = mock.MagicMock()
-    mock_model_cls.return_value = mock_model
-    mock_model.to.return_value = mock_model
-
-    # Configure mock to test the attribute sharing path
-    if method == "eagle":
-        # For eagle, test the lm_head path
-        mock_model.load_weights.return_value = {
-            "model.embed_tokens.weight": torch.zeros(1)
-        }
-    else:
-        # For eagle3, test the embed_tokens path
-        mock_model.load_weights.return_value = {}
-
     # Setup target model with the appropriate attributes
     target_model = mock.MagicMock()
 
@@ -204,13 +164,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
     proposer.load_model(target_model)
 
     # Verify common interactions
-    mock_get_loader.assert_called_once()
-    mock_model_cls.assert_called_once()
-    mock_model.to.assert_called_once()
-    mock_model.load_weights.assert_called_once()
-
-    # Verify the loader was called with the right config
-    mock_get_loader.assert_called_once_with(proposer.vllm_config.load_config)
+    mock_get_model.assert_called_once()
 
     # Verify the specific attribute sharing based on the method
     if method == "eagle":

vllm/model_executor/model_loader/__init__.py
Lines changed: 10 additions & 3 deletions

@@ -1,8 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import Optional
+
 from torch import nn
 
-from vllm.config import LoadConfig, LoadFormat, VllmConfig
+from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.bitsandbytes_loader import (
     BitsAndBytesModelLoader)
@@ -47,9 +49,14 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     return DefaultModelLoader(load_config)
 
 
-def get_model(*, vllm_config: VllmConfig) -> nn.Module:
+def get_model(*,
+              vllm_config: VllmConfig,
+              model_config: Optional[ModelConfig] = None) -> nn.Module:
     loader = get_model_loader(vllm_config.load_config)
-    return loader.load_model(vllm_config=vllm_config)
+    if model_config is None:
+        model_config = vllm_config.model_config
+    return loader.load_model(vllm_config=vllm_config,
+                             model_config=model_config)
 
 
 __all__ = [
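
The new optional model_config lets a caller load a model other than the one described by vllm_config.model_config, which is the point of this commit: spec decoding must load a draft model alongside the target. A minimal usage sketch (the speculative_config.draft_model_config lookup is an assumption, not shown in this diff):

    from vllm.model_executor.model_loader import get_model

    # Default behaviour is unchanged: load the target model described
    # by vllm_config.model_config.
    target = get_model(vllm_config=vllm_config)

    # Spec decoding: load a different model by passing its ModelConfig
    # explicitly; vllm_config still supplies load/device settings.
    draft = get_model(
        vllm_config=vllm_config,
        model_config=vllm_config.speculative_config.draft_model_config)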

vllm/model_executor/model_loader/base_loader.py
Lines changed: 2 additions & 1 deletion

@@ -18,6 +18,7 @@ def download_model(self, model_config: ModelConfig) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, *, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         """Load a model with the given configurations."""
         raise NotImplementedError
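
Because the abstract signature now requires model_config, every concrete loader must accept it and use it instead of reaching into vllm_config.model_config. A minimal conforming subclass might look like this (hypothetical class; the initialize_model/set_default_torch_dtype import path is assumed from their usage in the loaders below):

    from torch import nn

    from vllm.config import ModelConfig, VllmConfig
    from vllm.model_executor.model_loader.base_loader import BaseModelLoader
    from vllm.model_executor.model_loader.utils import (
        initialize_model, set_default_torch_dtype)

    class NoopLoader(BaseModelLoader):

        def download_model(self, model_config: ModelConfig) -> None:
            pass  # nothing to download in this sketch

        def load_model(self, *, vllm_config: VllmConfig,
                       model_config: ModelConfig) -> nn.Module:
            # Honour the passed-in model_config: it may describe a draft
            # model rather than vllm_config.model_config.
            with set_default_torch_dtype(model_config.dtype):
                return initialize_model(vllm_config=vllm_config,
                                        model_config=model_config)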

vllm/model_executor/model_loader/bitsandbytes_loader.py
Lines changed: 2 additions & 3 deletions

@@ -569,10 +569,9 @@ def _load_weights(self, model_config: ModelConfig,
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
-
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):

vllm/model_executor/model_loader/default_loader.py
Lines changed: 4 additions & 3 deletions

@@ -264,13 +264,14 @@ def download_model(self, model_config: ModelConfig) -> None:
             fall_back_to_pt=True,
             allow_patterns_overrides=None)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
-                model = initialize_model(vllm_config=vllm_config)
+                model = initialize_model(vllm_config=vllm_config,
+                                         model_config=model_config)
 
             weights_to_load = {name for name, _ in model.named_parameters()}
             loaded_weights = model.load_weights(

vllm/model_executor/model_loader/dummy_loader.py
Lines changed: 2 additions & 2 deletions

@@ -22,9 +22,9 @@ def __init__(self, load_config: LoadConfig):
     def download_model(self, model_config: ModelConfig) -> None:
         pass  # Nothing to download
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:

vllm/model_executor/model_loader/gguf_loader.py
Lines changed: 2 additions & 2 deletions

@@ -92,9 +92,9 @@ def _get_weights_iterator(
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         local_model_path = self._prepare_weights(model_config.model)
         gguf_weights_map = self._get_gguf_weights_map(model_config)
         # we can only know if tie word embeddings after mapping weights

vllm/model_executor/model_loader/runai_streamer_loader.py
Lines changed: 2 additions & 3 deletions

@@ -100,11 +100,10 @@ def download_model(self, model_config: ModelConfig) -> None:
         """Download model if necessary"""
         self._prepare_weights(model_config.model, model_config.revision)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         """Perform streaming of the model to destination"""
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
-
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:

vllm/model_executor/model_loader/sharded_state_loader.py
Lines changed: 2 additions & 2 deletions

@@ -100,9 +100,9 @@ def _prepare_weights(self, model_name_or_path: str,
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
 
         from vllm.distributed import get_tensor_model_parallel_rank

vllm/model_executor/model_loader/tensorizer_loader.py
Lines changed: 2 additions & 2 deletions

@@ -93,8 +93,8 @@ def download_model(self, model_config: ModelConfig) -> None:
         with self.tensorizer_config.open_stream():
             pass
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
-        model_config = vllm_config.model_config
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         parallel_config = vllm_config.parallel_config
         self._verify_config(model_config, parallel_config)