diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index 7d93a44c5059..e000d955cfc0 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -117,34 +117,13 @@ def test_prepare_inputs():
 ])
 @mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
 @mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
-@mock.patch('vllm.v1.spec_decode.eagle.ModelRegistry')
-@mock.patch('vllm.v1.spec_decode.eagle.get_model_loader')
-@mock.patch('vllm.v1.spec_decode.eagle.set_default_torch_dtype')
-@mock.patch('vllm.v1.spec_decode.eagle.set_current_vllm_config')
-def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
-                    mock_registry, mock_get_layers, mock_get_pp_group, method,
+@mock.patch('vllm.v1.spec_decode.eagle.get_model')
+def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
                     proposer_helper, draft_model_dir, target_attribute_path):
 
-    # Setup mock for model class
-    mock_model_cls = mock.MagicMock()
-    mock_registry.resolve_model_cls.return_value = (mock_model_cls,
-                                                    "test_arch")
-
-    # Create a real context manager for mocks
-    class MockContextManager:
-
-        def __init__(self):
-            pass
-
-        def __enter__(self):
-            return None
-
-        def __exit__(self, exc_type, exc_val, exc_tb):
-            return False
-
-    # Make the mocks return actual context manager objects
-    mock_set_dtype.return_value = MockContextManager()
-    mock_set_config.return_value = MockContextManager()
+    # Setup model mock
+    mock_model = mock.MagicMock()
+    mock_get_model.return_value = mock_model
 
     # Setup mocks for attention layers
     target_attn_layers = {
@@ -164,25 +143,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
     mock_pp_group.world_size = 2 if method == "eagle" else 1
     mock_get_pp_group.return_value = mock_pp_group
 
-    # Setup model loader mock
-    mock_loader = mock.MagicMock()
-    mock_get_loader.return_value = mock_loader
-
-    # Setup model mock
-    mock_model = mock.MagicMock()
-    mock_model_cls.return_value = mock_model
-    mock_model.to.return_value = mock_model
-
-    # Configure mock to test the attribute sharing path
-    if method == "eagle":
-        # For eagle, test the lm_head path
-        mock_model.load_weights.return_value = {
-            "model.embed_tokens.weight": torch.zeros(1)
-        }
-    else:
-        # For eagle3, test the embed_tokens path
-        mock_model.load_weights.return_value = {}
-
     # Setup target model with the appropriate attributes
     target_model = mock.MagicMock()
 
@@ -204,13 +164,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
     proposer.load_model(target_model)
 
     # Verify common interactions
-    mock_get_loader.assert_called_once()
-    mock_model_cls.assert_called_once()
-    mock_model.to.assert_called_once()
-    mock_model.load_weights.assert_called_once()
-
-    # Verify the loader was called with the right config
-    mock_get_loader.assert_called_once_with(proposer.vllm_config.load_config)
+    mock_get_model.assert_called_once()
 
     # Verify the specific attribute sharing based on the method
     if method == "eagle":
diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py
index 92a0b0923b6e..a443a652d8a3 100644
--- a/vllm/model_executor/model_loader/__init__.py
+++ b/vllm/model_executor/model_loader/__init__.py
@@ -1,8 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import Optional
+
 from torch import nn
 
-from vllm.config import LoadConfig, LoadFormat, VllmConfig
+from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.bitsandbytes_loader import (
     BitsAndBytesModelLoader)
@@ -47,9 +49,14 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     return DefaultModelLoader(load_config)
 
 
-def get_model(*, vllm_config: VllmConfig) -> nn.Module:
+def get_model(*,
+              vllm_config: VllmConfig,
+              model_config: Optional[ModelConfig] = None) -> nn.Module:
     loader = get_model_loader(vllm_config.load_config)
-    return loader.load_model(vllm_config=vllm_config)
+    if model_config is None:
+        model_config = vllm_config.model_config
+    return loader.load_model(vllm_config=vllm_config,
+                             model_config=model_config)
 
 
 __all__ = [
diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py
index f17cab05c25d..010dd515784a 100644
--- a/vllm/model_executor/model_loader/base_loader.py
+++ b/vllm/model_executor/model_loader/base_loader.py
@@ -18,6 +18,7 @@ def download_model(self, model_config: ModelConfig) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, *, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         """Load a model with the given configurations."""
         raise NotImplementedError
diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py
index 6771c128c5a1..0d83c8d53419 100644
--- a/vllm/model_executor/model_loader/bitsandbytes_loader.py
+++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -569,10 +569,9 @@ def _load_weights(self, model_config: ModelConfig,
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
-
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index 21eb7d8a75fb..ddbd60940e9e 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -264,13 +264,14 @@ def download_model(self, model_config: ModelConfig) -> None:
                               fall_back_to_pt=True,
                               allow_patterns_overrides=None)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
-                model = initialize_model(vllm_config=vllm_config)
+                model = initialize_model(vllm_config=vllm_config,
+                                         model_config=model_config)
 
                 weights_to_load = {name for name, _ in model.named_parameters()}
                 loaded_weights = model.load_weights(
diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py
index 5047a161f3f9..0e2f0be1ec26 100644
--- a/vllm/model_executor/model_loader/dummy_loader.py
+++ b/vllm/model_executor/model_loader/dummy_loader.py
@@ -22,9 +22,9 @@ def __init__(self, load_config: LoadConfig):
     def download_model(self, model_config: ModelConfig) -> None:
         pass  # Nothing to download
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py
index 2766c9787b83..806004bf9604 100644
--- a/vllm/model_executor/model_loader/gguf_loader.py
+++ b/vllm/model_executor/model_loader/gguf_loader.py
@@ -92,9 +92,9 @@ def _get_weights_iterator(
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         local_model_path = self._prepare_weights(model_config.model)
         gguf_weights_map = self._get_gguf_weights_map(model_config)
         # we can only know if tie word embeddings after mapping weights
diff --git a/vllm/model_executor/model_loader/runai_streamer_loader.py b/vllm/model_executor/model_loader/runai_streamer_loader.py
index a695ba03bd1d..9f1022c25925 100644
--- a/vllm/model_executor/model_loader/runai_streamer_loader.py
+++ b/vllm/model_executor/model_loader/runai_streamer_loader.py
@@ -100,11 +100,10 @@ def download_model(self, model_config: ModelConfig) -> None:
         """Download model if necessary"""
         self._prepare_weights(model_config.model, model_config.revision)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         """Perform streaming of the model to destination"""
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
-
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
diff --git a/vllm/model_executor/model_loader/sharded_state_loader.py b/vllm/model_executor/model_loader/sharded_state_loader.py
index 913bda7e007a..78bca89f0015 100644
--- a/vllm/model_executor/model_loader/sharded_state_loader.py
+++ b/vllm/model_executor/model_loader/sharded_state_loader.py
@@ -100,9 +100,9 @@ def _prepare_weights(self, model_name_or_path: str,
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
 
         from vllm.distributed import get_tensor_model_parallel_rank
diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py
index 4107e741fd8f..8e2121d56c37 100644
--- a/vllm/model_executor/model_loader/tensorizer_loader.py
+++ b/vllm/model_executor/model_loader/tensorizer_loader.py
@@ -92,8 +92,8 @@ def download_model(self, model_config: ModelConfig) -> None:
         with self.tensorizer_config.open_stream():
             pass
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
-        model_config = vllm_config.model_config
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         parallel_config = vllm_config.parallel_config
         self._verify_config(model_config, parallel_config)
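All of the loader hunks above make the same mechanical change: load_model now receives the ModelConfig to load as an explicit argument instead of reading vllm_config.model_config, and get_model gains an optional model_config parameter that falls back to the target model's config when omitted. The standalone sketch below illustrates that calling convention; StubModelConfig, StubVllmConfig and get_model_sketch are illustrative stand-ins, not vLLM APIs.

# Illustrative sketch only -- Stub* types and get_model_sketch are stand-ins
# for vLLM's ModelConfig / VllmConfig / get_model, not the real classes.
from dataclasses import dataclass
from typing import Optional


@dataclass
class StubModelConfig:
    model: str
    dtype: str = "bfloat16"


@dataclass
class StubVllmConfig:
    model_config: StubModelConfig  # the target model's config


def get_model_sketch(*,
                     vllm_config: StubVllmConfig,
                     model_config: Optional[StubModelConfig] = None) -> str:
    # Same fallback as the new get_model(): callers that pass only
    # vllm_config keep loading the target model, while speculative decoding
    # can pass the draft model's config explicitly.
    if model_config is None:
        model_config = vllm_config.model_config
    return f"loaded {model_config.model} in {model_config.dtype}"


cfg = StubVllmConfig(model_config=StubModelConfig("example/target-model"))
print(get_model_sketch(vllm_config=cfg))  # target model, old behaviour
print(get_model_sketch(vllm_config=cfg,
                       model_config=StubModelConfig("example/eagle-draft")))
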
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index 68b1f1ad74d3..39e380f07297 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -42,9 +42,11 @@ def initialize_model(
     *,
     prefix: str = "",
     model_class: Optional[type[nn.Module]] = None,
+    model_config: Optional[ModelConfig] = None,
 ) -> nn.Module:
     """Initialize a model with the given configurations."""
-    model_config = vllm_config.model_config
+    if model_config is None:
+        model_config = vllm_config.model_config
     if model_class is None:
         model_class, _ = get_model_architecture(model_config)
diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py
index 018ecc2a8c0f..172dc8b5ec06 100644
--- a/vllm/model_executor/models/llama_eagle.py
+++ b/vllm/model_executor/models/llama_eagle.py
@@ -130,13 +130,15 @@ def load_weights(self, weights: Iterable[tuple[str,
 
 class EagleLlamaForCausalLM(LlamaForCausalLM):
 
-    def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
+        target_layer_num = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config)
         self.model = LlamaModel(vllm_config=vllm_config,
                                 prefix="model",
-                                start_layer_id=start_layer_id)
+                                start_layer_id=target_layer_num)
 
         logit_scale = getattr(self.config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.config.vocab_size,
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 2302d1352de6..96e666a3543d 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -175,13 +175,15 @@ def load_weights(self, weights: Iterable[tuple[str,
 
 class Eagle3LlamaForCausalLM(LlamaForCausalLM):
 
-    def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
+        target_layer_num = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config)
         self.model = LlamaModel(vllm_config=vllm_config,
-                                start_layer_id=start_layer_id,
-                                prefix="model")
+                                prefix="model",
+                                start_layer_id=target_layer_num)
 
         logit_scale = getattr(self.config, "logit_scale", 1.0)
         self.lm_head = ParallelLMHead(
@@ -193,8 +195,7 @@ def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
         self.logits_processor = LogitsProcessor(self.config.draft_vocab_size,
                                                 scale=logit_scale)
         self.draft_id_to_target_id = nn.Parameter(
-            torch.zeros((self.config.draft_vocab_size),
-                        dtype=torch.long).type(torch.LongTensor),
+            torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
             requires_grad=False,
         )
diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py
index 588bcb628f8c..95ef1134b1bf 100644
--- a/vllm/model_executor/models/medusa.py
+++ b/vllm/model_executor/models/medusa.py
@@ -51,10 +51,7 @@ class Medusa(nn.Module):
     needs to have truncated_vocab_size (=k) as an attribute."""
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
-        if hasattr(vllm_config, 'draft_model_config'):
-            config = vllm_config.draft_model_config.hf_config
-        else:
-            config = vllm_config.model_config.hf_config
+        config = vllm_config.speculative_config.draft_model_config.hf_config
         super().__init__()
         self.config = config
         self.blocks = nn.ModuleList([
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 19fb2a2af7dd..460d645a1a6c 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -4,14 +4,11 @@
 
 from vllm.attention.layer import Attention
 from vllm.config import (CompilationLevel, VllmConfig,
-                         get_layers_from_vllm_config, set_current_vllm_config)
+                         get_layers_from_vllm_config)
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.utils import (
-    process_weights_after_loading, set_default_torch_dtype)
-from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
@@ -280,51 +277,28 @@ def prepare_inputs(
         return cu_num_tokens, token_indices
 
     def load_model(self, target_model: nn.Module) -> None:
-        loader = get_model_loader(self.vllm_config.load_config)
-        target_layer_num = self.vllm_config.model_config.get_num_layers(
-            self.vllm_config.parallel_config)
+        draft_model_config = \
+            self.vllm_config.speculative_config.draft_model_config
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config, Attention).keys())
 
-        draft_model_config = \
-            self.vllm_config.speculative_config.draft_model_config
-        # FIXME(lily): This does not handle with distributed inference.
-        target_device = self.vllm_config.device_config.device
-        # We need to set the vllm_config here to register attention
-        # layers in the forward context.
-        with set_default_torch_dtype(
-                draft_model_config.dtype), set_current_vllm_config(
-                    self.vllm_config):
-            draft_model_cls, arch = ModelRegistry.resolve_model_cls(
-                draft_model_config.architectures)
-            self.model = draft_model_cls(
-                vllm_config=self.vllm_config,
-                start_layer_id=target_layer_num).to(target_device)
+        self.model = get_model(vllm_config=self.vllm_config,
+                               model_config=draft_model_config)
 
         draft_attn_layer_names = (
             get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
             target_attn_layer_names)
         assert len(draft_attn_layer_names) == 1
         self.attn_layer_name = next(iter(draft_attn_layer_names))
-        loaded_weights = self.model.load_weights(
-            loader.get_all_weights(draft_model_config, self.model))
-
-        process_weights_after_loading(self.model, draft_model_config,
-                                      target_device)
 
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1:
-            assert "model.embed_tokens.weight" not in loaded_weights, \
-                "For PP = 1, Eagle draft should share embed with target model"
             logger.info(
                 "The EAGLE head shares the same vocab embedding" \
                 " with the target model."
             )
             self.model.model.embed_tokens = target_model.model.embed_tokens
         else:
-            assert "model.embed_tokens.weight" in loaded_weights, \
-                "For PP > 1, Eagle draft checkpoint should its own copy of "
-                " the model.embed_tokens.weight"
             logger.info(
                 "Since PP > 1, the EAGLE head loaded its own vocab embedding" \
                 " weights instead of sharing them with the target model."
diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py
index 14bc9c9e0d1a..fdac2ef64c3f 100644
--- a/vllm/v1/spec_decode/medusa.py
+++ b/vllm/v1/spec_decode/medusa.py
@@ -3,12 +3,10 @@
 import torch
 import torch.nn as nn
 
-from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
-from vllm.model_executor.models.medusa import Medusa
+from vllm.model_executor.model_loader import get_model
 from vllm.v1.sample.metadata import SamplingMetadata
 
 # Initialize logger
@@ -49,20 +47,9 @@ def propose(
         return [list(row) for row in zip(*draft_tokens)]
 
     def load_model(self, target_model: nn.Module) -> None:
-        # Get model loader and config
-        loader = get_model_loader(self.vllm_config.load_config)
-        draft_config = self.vllm_config.speculative_config.draft_model_config
-
-        # Load model with proper dtype and config
-        with set_default_torch_dtype(draft_config.dtype), \
-                set_current_vllm_config(self.vllm_config):
-            self.model = Medusa(
-                vllm_config=self.vllm_config.speculative_config).to(
-                    self.device)
-
-            # Load model weights
-            weights = loader.get_all_weights(draft_config, self.model)
-            self.model.load_weights(weights)
+        self.model = get_model(vllm_config=self.vllm_config,
+                               model_config=self.vllm_config.
+                               speculative_config.draft_model_config)
 
     @torch.inference_mode()
     def dummy_run(self, num_tokens: int) -> None:
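With the loader refactor in place, the EAGLE and Medusa proposers above shrink their load_model to a single get_model(...) call with the draft model's config, and the draft models derive start_layer_id from the target config themselves; the only proposer-side logic left is sharing the vocab embedding with the target model when pipeline parallelism is not used. The sketch below mirrors that control flow with stub objects; it is an assumption-laden illustration, not vLLM's actual classes.

# Illustrative sketch only -- the stubs below mimic the new
# EagleProposer.load_model() control flow; they are not vLLM classes.
from types import SimpleNamespace


def get_model_stub(*, vllm_config, model_config):
    # Stand-in for get_model(): builds and weight-loads the draft model
    # described by model_config (dtype/device handling omitted here).
    return SimpleNamespace(
        model=SimpleNamespace(embed_tokens=f"embeds<{model_config}>"))


class StubEagleProposer:

    def __init__(self, vllm_config, pp_world_size: int):
        self.vllm_config = vllm_config
        self.pp_world_size = pp_world_size

    def load_model(self, target_model):
        draft_model_config = \
            self.vllm_config.speculative_config.draft_model_config
        # One call replaces the old resolve-class / load-weights /
        # process-weights-after-loading sequence.
        self.model = get_model_stub(vllm_config=self.vllm_config,
                                    model_config=draft_model_config)
        if self.pp_world_size == 1:
            # PP == 1: the draft head reuses the target model's embedding.
            self.model.model.embed_tokens = target_model.model.embed_tokens


target = SimpleNamespace(model=SimpleNamespace(embed_tokens="target-embeds"))
cfg = SimpleNamespace(
    speculative_config=SimpleNamespace(draft_model_config="draft-config"))
proposer = StubEagleProposer(cfg, pp_world_size=1)
proposer.load_model(target)
assert proposer.model.model.embed_tokens == "target-embeds"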