From 058e78fb8a11a35cc5627a598d470a337f8f4913 Mon Sep 17 00:00:00 2001
From: Mark McLoughlin
Date: Fri, 16 May 2025 13:58:16 -0400
Subject: [PATCH 1/3] [V1][Spec Decoding] Use model_loader.get_model() to load models

With:

```
$> vllm serve lmsys/vicuna-7b-v1.3 --speculative-config '{"method": "medusa", "model": "abhigoyal/vllm-medusa-vicuna-7b-v1.3", "num_speculative_tokens": 6}'
```

(A misconfiguration: num_spec_tokens > num_heads)

I noticed in V0 we get this fairly obscure error:

  File "/home/markmc/vllm-project/vllm/vllm/model_executor/model_loader/default_loader.py", line 294, in load_model
    raise ValueError(
ValueError: Following weights were not initialized from checkpoint: {'blocks.5.layers.0.weight', 'lm_heads.5.weight'}

but in V1, we silently accept the config.

It turns out the difference is that in V0 we're using
ModelLoader.load_model(), but in V1 we're reimplementing a subset of that
logic in the Medusa and EAGLE proposers.

If we use the full model loading infrastructure, we can reduce the
complexity of the drafters and automatically gain support for e.g.
quantized weights as per #18290.

Signed-off-by: Mark McLoughlin
---
 vllm/model_executor/model_loader/__init__.py  | 13 +++++--
 .../model_loader/base_loader.py               |  3 +-
 .../model_loader/bitsandbytes_loader.py       |  5 +--
 .../model_loader/default_loader.py            |  7 ++--
 .../model_loader/dummy_loader.py              |  4 +-
 .../model_loader/gguf_loader.py               |  4 +-
 .../model_loader/runai_streamer_loader.py     |  5 +--
 .../model_loader/sharded_state_loader.py      |  4 +-
 .../model_loader/tensorizer_loader.py         |  4 +-
 vllm/model_executor/model_loader/utils.py     |  4 +-
 vllm/model_executor/models/llama_eagle.py     |  6 ++-
 vllm/model_executor/models/llama_eagle3.py    |  8 ++--
 vllm/model_executor/models/medusa.py          |  5 +--
 vllm/v1/spec_decode/eagle.py                  | 38 +++----------------
 vllm/v1/spec_decode/medusa.py                 | 23 +++--------
 15 files changed, 52 insertions(+), 81 deletions(-)

diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py
index 92a0b0923b6e..a443a652d8a3 100644
--- a/vllm/model_executor/model_loader/__init__.py
+++ b/vllm/model_executor/model_loader/__init__.py
@@ -1,8 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import Optional
+
 from torch import nn
 
-from vllm.config import LoadConfig, LoadFormat, VllmConfig
+from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.bitsandbytes_loader import (
     BitsAndBytesModelLoader)
@@ -47,9 +49,14 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     return DefaultModelLoader(load_config)
 
 
-def get_model(*, vllm_config: VllmConfig) -> nn.Module:
+def get_model(*,
+              vllm_config: VllmConfig,
+              model_config: Optional[ModelConfig] = None) -> nn.Module:
     loader = get_model_loader(vllm_config.load_config)
-    return loader.load_model(vllm_config=vllm_config)
+    if model_config is None:
+        model_config = vllm_config.model_config
+    return loader.load_model(vllm_config=vllm_config,
+                             model_config=model_config)
 
 
 __all__ = [
diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py
index f17cab05c25d..010dd515784a 100644
--- a/vllm/model_executor/model_loader/base_loader.py
+++ b/vllm/model_executor/model_loader/base_loader.py
@@ -18,6 +18,7 @@ def download_model(self, model_config: ModelConfig) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, *, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         """Load a model with the given configurations."""
         raise NotImplementedError
diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py
index 6771c128c5a1..0d83c8d53419 100644
--- a/vllm/model_executor/model_loader/bitsandbytes_loader.py
+++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -569,10 +569,9 @@ def _load_weights(self, model_config: ModelConfig,
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
-
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index 21eb7d8a75fb..ddbd60940e9e 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -264,13 +264,14 @@ def download_model(self, model_config: ModelConfig) -> None:
                               fall_back_to_pt=True,
                               allow_patterns_overrides=None)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
-                model = initialize_model(vllm_config=vllm_config)
+                model = initialize_model(vllm_config=vllm_config,
+                                         model_config=model_config)
 
                 weights_to_load = {name for name, _ in model.named_parameters()}
                 loaded_weights = model.load_weights(
diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py
index 5047a161f3f9..0e2f0be1ec26 100644
--- a/vllm/model_executor/model_loader/dummy_loader.py
+++ b/vllm/model_executor/model_loader/dummy_loader.py
@@ -22,9 +22,9 @@ def __init__(self, load_config: LoadConfig):
     def download_model(self, model_config: ModelConfig) -> None:
         pass  # Nothing to download
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py
index 2766c9787b83..806004bf9604 100644
--- a/vllm/model_executor/model_loader/gguf_loader.py
+++ b/vllm/model_executor/model_loader/gguf_loader.py
@@ -92,9 +92,9 @@ def _get_weights_iterator(
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         local_model_path = self._prepare_weights(model_config.model)
         gguf_weights_map = self._get_gguf_weights_map(model_config)
         # we can only know if tie word embeddings after mapping weights
diff --git a/vllm/model_executor/model_loader/runai_streamer_loader.py b/vllm/model_executor/model_loader/runai_streamer_loader.py
index a695ba03bd1d..9f1022c25925 100644
--- a/vllm/model_executor/model_loader/runai_streamer_loader.py
+++ b/vllm/model_executor/model_loader/runai_streamer_loader.py
@@ -100,11 +100,10 @@ def download_model(self, model_config: ModelConfig) -> None:
         """Download model if necessary"""
         self._prepare_weights(model_config.model, model_config.revision)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         """Perform streaming of the model to destination"""
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
-
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
diff --git a/vllm/model_executor/model_loader/sharded_state_loader.py b/vllm/model_executor/model_loader/sharded_state_loader.py
index 913bda7e007a..78bca89f0015 100644
--- a/vllm/model_executor/model_loader/sharded_state_loader.py
+++ b/vllm/model_executor/model_loader/sharded_state_loader.py
@@ -100,9 +100,9 @@ def _prepare_weights(self, model_name_or_path: str,
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         device_config = vllm_config.device_config
-        model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
 
         from vllm.distributed import get_tensor_model_parallel_rank
diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py
index 4107e741fd8f..8e2121d56c37 100644
--- a/vllm/model_executor/model_loader/tensorizer_loader.py
+++ b/vllm/model_executor/model_loader/tensorizer_loader.py
@@ -92,8 +92,8 @@ def download_model(self, model_config: ModelConfig) -> None:
         with self.tensorizer_config.open_stream():
             pass
 
-    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
-        model_config = vllm_config.model_config
+    def load_model(self, vllm_config: VllmConfig,
+                   model_config: ModelConfig) -> nn.Module:
         parallel_config = vllm_config.parallel_config
         self._verify_config(model_config, parallel_config)
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index 68b1f1ad74d3..39e380f07297 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -42,9 +42,11 @@ def initialize_model(
     *,
     prefix: str = "",
     model_class: Optional[type[nn.Module]] = None,
+    model_config: Optional[ModelConfig] = None,
 ) -> nn.Module:
     """Initialize a model with the given configurations."""
-    model_config = vllm_config.model_config
+    if model_config is None:
+        model_config = vllm_config.model_config
     if model_class is None:
         model_class, _ = get_model_architecture(model_config)
diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py
index 018ecc2a8c0f..172dc8b5ec06 100644
--- a/vllm/model_executor/models/llama_eagle.py
+++ b/vllm/model_executor/models/llama_eagle.py
@@ -130,13 +130,15 @@ def load_weights(self, weights: Iterable[tuple[str,
 
 class EagleLlamaForCausalLM(LlamaForCausalLM):
 
-    def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
+        target_layer_num = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config)
         self.model = LlamaModel(vllm_config=vllm_config,
                                 prefix="model",
-                                start_layer_id=start_layer_id)
+                                start_layer_id=target_layer_num)
 
         logit_scale = getattr(self.config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.config.vocab_size,
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 2302d1352de6..d358a9dae06e 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -175,13 +175,15 @@ def load_weights(self, weights: Iterable[tuple[str,
 
 class Eagle3LlamaForCausalLM(LlamaForCausalLM):
 
-    def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
+        target_layer_num = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config)
         self.model = LlamaModel(vllm_config=vllm_config,
-                                start_layer_id=start_layer_id,
-                                prefix="model")
+                                prefix="model",
+                                start_layer_id=target_layer_num)
 
         logit_scale = getattr(self.config, "logit_scale", 1.0)
         self.lm_head = ParallelLMHead(
diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py
index 588bcb628f8c..95ef1134b1bf 100644
--- a/vllm/model_executor/models/medusa.py
+++ b/vllm/model_executor/models/medusa.py
@@ -51,10 +51,7 @@ class Medusa(nn.Module):
     needs to have truncated_vocab_size (=k) as an attribute."""
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
-        if hasattr(vllm_config, 'draft_model_config'):
-            config = vllm_config.draft_model_config.hf_config
-        else:
-            config = vllm_config.model_config.hf_config
+        config = vllm_config.speculative_config.draft_model_config.hf_config
         super().__init__()
         self.config = config
         self.blocks = nn.ModuleList([
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 19fb2a2af7dd..460d645a1a6c 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -4,14 +4,11 @@
 
 from vllm.attention.layer import Attention
 from vllm.config import (CompilationLevel, VllmConfig,
-                         get_layers_from_vllm_config, set_current_vllm_config)
+                         get_layers_from_vllm_config)
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.utils import (
-    process_weights_after_loading, set_default_torch_dtype)
-from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
@@ -280,51 +277,28 @@ def prepare_inputs(
         return cu_num_tokens, token_indices
 
     def load_model(self, target_model: nn.Module) -> None:
-        loader = get_model_loader(self.vllm_config.load_config)
-        target_layer_num = self.vllm_config.model_config.get_num_layers(
-            self.vllm_config.parallel_config)
+        draft_model_config = \
+            self.vllm_config.speculative_config.draft_model_config
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config, Attention).keys())
 
-        draft_model_config = \
-            self.vllm_config.speculative_config.draft_model_config
-        # FIXME(lily): This does not handle with distributed inference.
-        target_device = self.vllm_config.device_config.device
-        # We need to set the vllm_config here to register attention
-        # layers in the forward context.
-        with set_default_torch_dtype(
-                draft_model_config.dtype), set_current_vllm_config(
-                    self.vllm_config):
-            draft_model_cls, arch = ModelRegistry.resolve_model_cls(
-                draft_model_config.architectures)
-            self.model = draft_model_cls(
-                vllm_config=self.vllm_config,
-                start_layer_id=target_layer_num).to(target_device)
+        self.model = get_model(vllm_config=self.vllm_config,
+                               model_config=draft_model_config)
 
         draft_attn_layer_names = (
             get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
             target_attn_layer_names)
         assert len(draft_attn_layer_names) == 1
         self.attn_layer_name = next(iter(draft_attn_layer_names))
 
-        loaded_weights = self.model.load_weights(
-            loader.get_all_weights(draft_model_config, self.model))
-
-        process_weights_after_loading(self.model, draft_model_config,
-                                      target_device)
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1:
-            assert "model.embed_tokens.weight" not in loaded_weights, \
-                "For PP = 1, Eagle draft should share embed with target model"
             logger.info(
                 "The EAGLE head shares the same vocab embedding" \
                 " with the target model."
             )
             self.model.model.embed_tokens = target_model.model.embed_tokens
         else:
-            assert "model.embed_tokens.weight" in loaded_weights, \
-                "For PP > 1, Eagle draft checkpoint should its own copy of "
-                " the model.embed_tokens.weight"
             logger.info(
                 "Since PP > 1, the EAGLE head loaded its own vocab embedding" \
                 " weights instead of sharing them with the target model."
diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py
index 14bc9c9e0d1a..fdac2ef64c3f 100644
--- a/vllm/v1/spec_decode/medusa.py
+++ b/vllm/v1/spec_decode/medusa.py
@@ -3,12 +3,10 @@
 import torch
 import torch.nn as nn
 
-from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
-from vllm.model_executor.models.medusa import Medusa
+from vllm.model_executor.model_loader import get_model
 from vllm.v1.sample.metadata import SamplingMetadata
 
 # Initialize logger
@@ -49,20 +47,9 @@ def propose(
         return [list(row) for row in zip(*draft_tokens)]
 
     def load_model(self, target_model: nn.Module) -> None:
-        # Get model loader and config
-        loader = get_model_loader(self.vllm_config.load_config)
-        draft_config = self.vllm_config.speculative_config.draft_model_config
-
-        # Load model with proper dtype and config
-        with set_default_torch_dtype(draft_config.dtype), \
-                set_current_vllm_config(self.vllm_config):
-            self.model = Medusa(
-                vllm_config=self.vllm_config.speculative_config).to(
-                    self.device)
-
-        # Load model weights
-        weights = loader.get_all_weights(draft_config, self.model)
-        self.model.load_weights(weights)
+        self.model = get_model(vllm_config=self.vllm_config,
+                               model_config=self.vllm_config.
+                               speculative_config.draft_model_config)
 
     @torch.inference_mode()
     def dummy_run(self, num_tokens: int) -> None:

From 236ab6d0d77aeef5d26f6f0a7b6ec14115e688cc Mon Sep 17 00:00:00 2001
From: Mark McLoughlin
Date: Tue, 20 May 2025 04:56:13 -0400
Subject: [PATCH 2/3] [V1][Spec Decode] Fix v1/spec_decode/test_eagle.py

Signed-off-by: Mark McLoughlin
---
 tests/v1/spec_decode/test_eagle.py | 58 ++++--------------------------
 1 file changed, 6 insertions(+), 52 deletions(-)

diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index 7d93a44c5059..e000d955cfc0 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -117,34 +117,13 @@ def test_prepare_inputs():
 ])
 @mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
 @mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
-@mock.patch('vllm.v1.spec_decode.eagle.ModelRegistry')
-@mock.patch('vllm.v1.spec_decode.eagle.get_model_loader')
-@mock.patch('vllm.v1.spec_decode.eagle.set_default_torch_dtype')
-@mock.patch('vllm.v1.spec_decode.eagle.set_current_vllm_config')
-def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
-                    mock_registry, mock_get_layers, mock_get_pp_group, method,
+@mock.patch('vllm.v1.spec_decode.eagle.get_model')
+def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
                     proposer_helper, draft_model_dir, target_attribute_path):
 
-    # Setup mock for model class
-    mock_model_cls = mock.MagicMock()
-    mock_registry.resolve_model_cls.return_value = (mock_model_cls,
-                                                    "test_arch")
-
-    # Create a real context manager for mocks
-    class MockContextManager:
-
-        def __init__(self):
-            pass
-
-        def __enter__(self):
-            return None
-
-        def __exit__(self, exc_type, exc_val, exc_tb):
-            return False
-
-    # Make the mocks return actual context manager objects
-    mock_set_dtype.return_value = MockContextManager()
-    mock_set_config.return_value = MockContextManager()
+    # Setup model mock
+    mock_model = mock.MagicMock()
+    mock_get_model.return_value = mock_model
 
     # Setup mocks for attention layers
     target_attn_layers = {
@@ -164,25 +143,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
     mock_pp_group.world_size = 2 if method == "eagle" else 1
     mock_get_pp_group.return_value = mock_pp_group
 
-    # Setup model loader mock
-    mock_loader = mock.MagicMock()
-    mock_get_loader.return_value = mock_loader
-
-    # Setup model mock
-    mock_model = mock.MagicMock()
-    mock_model_cls.return_value = mock_model
-    mock_model.to.return_value = mock_model
-
-    # Configure mock to test the attribute sharing path
-    if method == "eagle":
-        # For eagle, test the lm_head path
-        mock_model.load_weights.return_value = {
-            "model.embed_tokens.weight": torch.zeros(1)
-        }
-    else:
-        # For eagle3, test the embed_tokens path
-        mock_model.load_weights.return_value = {}
-
     # Setup target model with the appropriate attributes
     target_model = mock.MagicMock()
@@ -204,13 +164,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
     proposer.load_model(target_model)
 
     # Verify common interactions
-    mock_get_loader.assert_called_once()
-    mock_model_cls.assert_called_once()
-    mock_model.to.assert_called_once()
-    mock_model.load_weights.assert_called_once()
-
-    # Verify the loader was called with the right config
-    mock_get_loader.assert_called_once_with(proposer.vllm_config.load_config)
+    mock_get_model.assert_called_once()
 
     # Verify the specific attribute sharing based on the method
     if method == "eagle":

From 758c21b6001e32e7fcd562a8f00800211422dbfa Mon Sep 17 00:00:00 2001
From: Mark McLoughlin
Date: Thu, 22 May 2025 18:53:12 -0400
Subject: [PATCH 3/3] [V1][Spec Decode] Fix eagle3 parameter definition

The redundant .type(torch.LongTensor) moves the GPU tensor to CPU, causing:

```
targets = base + self.draft_id_to_target_id
          ~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

Signed-off-by: Mark McLoughlin
---
 vllm/model_executor/models/llama_eagle3.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index d358a9dae06e..96e666a3543d 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -195,8 +195,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.logits_processor = LogitsProcessor(self.config.draft_vocab_size,
                                                 scale=logit_scale)
         self.draft_id_to_target_id = nn.Parameter(
-            torch.zeros((self.config.draft_vocab_size),
-                        dtype=torch.long).type(torch.LongTensor),
+            torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
             requires_grad=False,
         )