From cb5c553e5f11ef833664db8f36796f276d564df6 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Tue, 30 Sep 2025 15:24:00 +0000 Subject: [PATCH 1/7] Add Eagle3 config support for auxiliary hidden state layer IDs Support configuring eagle_aux_hidden_state_layer_ids and inference_type in the Eagle3 speculator configuration. This allows users to specify which verifier layers should output auxiliary hidden states for the drafter to consume during speculative decoding. Signed-off-by: rahul-tuli Signed-off-by: Rahul Tuli --- vllm/transformers_utils/configs/speculators/algos.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/transformers_utils/configs/speculators/algos.py b/vllm/transformers_utils/configs/speculators/algos.py index 1375eca28e41..70f19aa47b35 100644 --- a/vllm/transformers_utils/configs/speculators/algos.py +++ b/vllm/transformers_utils/configs/speculators/algos.py @@ -28,3 +28,8 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None: vllm_config["target_hidden_size"] = config_dict["target_hidden_size"] vllm_config["norm_before_residual"] = config_dict.get("norm_before_residual", True) vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"] + if config_dict.get("eagle_aux_hidden_state_layer_ids"): + vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[ + "eagle_aux_hidden_state_layer_ids"] + if config_dict.get("inference_type"): + vllm_config["inference_type"] = config_dict["inference_type"] From 07e7c78e82f8f826029c71df5ac661ed128d3cf8 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Tue, 30 Sep 2025 15:24:05 +0000 Subject: [PATCH 2/7] Document Eagle3 auxiliary layer default selection in Llama Add documentation explaining that get_eagle3_aux_hidden_state_layers() provides default layer selection and that the GPU model runner can override this with values from speculative config for dynamic configuration. Signed-off-by: rahul-tuli Signed-off-by: Rahul Tuli --- vllm/model_executor/models/llama.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index faed1abb3bab..948c9280f953 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -604,6 +604,11 @@ def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.model.aux_hidden_state_layers = layers def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + """Override to return default layers for Llama + + Note: The GPU model runner will override this with layers from + the speculative config if available, providing dynamic configuration. + """ num_layers = len(self.model.layers) return (2, num_layers // 2, num_layers - 3) From 58dfcf6841160f81263c32073bbc7ab32652f7ca Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Tue, 30 Sep 2025 15:24:11 +0000 Subject: [PATCH 3/7] Implement SupportsEagle3 interface for Llama4 multimodal models Add Eagle3 support to Llama4ForConditionalGeneration by implementing set_aux_hidden_state_layers() and get_eagle3_aux_hidden_state_layers() methods. Both methods delegate to the underlying Llama4ForCausalLM language model, enabling Eagle3 speculative decoding with Llama4 multimodal verifier models. This allows text-only Eagle3 drafters to work with Llama4 multimodal verifiers by consuming auxiliary hidden states from specified layers. Signed-off-by: rahul-tuli Signed-off-by: Rahul Tuli --- vllm/model_executor/models/mllama4.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 1c8e8686ccae..c001ba403839 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -64,7 +64,8 @@ from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .interfaces import (MultiModalEmbeddings, SupportsEagle3, + SupportsMultiModal, SupportsPP) from .llama4 import Llama4ForCausalLM from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix from .vision import run_dp_sharded_vision_model @@ -717,7 +718,9 @@ def get_dummy_mm_data( info=Mllama4ProcessingInfo, dummy_inputs=Mllama4DummyInputsBuilder, ) -class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): +class Llama4ForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3 +): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -767,6 +770,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.language_model.make_empty_intermediate_tensors ) + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + """Set which layers should output auxiliary hidden states for EAGLE3.""" + # Delegate to underlying language model (Llama4ForCausalLM) + assert hasattr(self.language_model, 'set_aux_hidden_state_layers') + self.language_model.set_aux_hidden_state_layers(layers) + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + """Get the layer indices for auxiliary hidden state outputs. + + Note: The GPU model runner will override this with layers from + the speculative config if available, providing dynamic configuration. + """ + # Delegate to underlying language model (Llama4ForCausalLM) + assert hasattr( + self.language_model, "get_eagle3_aux_hidden_state_layers" + ) + return self.language_model.get_eagle3_aux_hidden_state_layers() + def _parse_and_validate_image_input( self, **kwargs: object ) -> Optional[Llama4ImagePatchInputs]: From 730f04d88a97a05414e8fa4d89e2df58a7400ffb Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Tue, 30 Sep 2025 15:24:19 +0000 Subject: [PATCH 4/7] Override get_input_embeddings in Eagle3 to process text-only inputs Implement custom get_input_embeddings() in Eagle3LlamaForCausalLM that accepts multimodal parameters but only processes text embeddings. This ensures the Llama3-based Eagle3 drafter correctly handles text inputs while remaining compatible with multimodal verifier interfaces. The drafter receives multimodal context through auxiliary hidden states from the verifier rather than processing multimodal inputs directly. Signed-off-by: rahul-tuli Signed-off-by: Rahul Tuli --- vllm/model_executor/models/llama_eagle3.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 712c8df3dbbb..0372ed44a6cd 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -21,6 +21,7 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM +from vllm.multimodal.inputs import NestedTensors from .utils import AutoWeightsLoader, maybe_prefix @@ -241,8 +242,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): requires_grad=False, ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + is_multimodal: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # The llama3 drafter only processes text embeddings + return self.model.embed_tokens(input_ids) def forward( self, From 06c6c932a5e72713c1d4b1f7b83413d060c73e6d Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Tue, 30 Sep 2025 15:25:31 +0000 Subject: [PATCH 5/7] Add dynamic Eagle3 auxiliary layer configuration from speculative config Implement _get_eagle3_aux_layers_from_config() helper method to extract auxiliary layer IDs from the draft model's speculative config. The GPU model runner now prefers config-specified layers over model defaults, with fallback to model's get_eagle3_aux_hidden_state_layers() when not configured. Changes: - Refactor auxiliary layer setup with early return pattern for errors - Add config extraction with proper error handling - Log only when using non-default layer configuration - Enable dynamic layer configuration per deployment Signed-off-by: rahul-tuli Signed-off-by: Rahul Tuli --- vllm/v1/worker/gpu_model_runner.py | 50 +++++++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b31571a7c000..4800449e9ad0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2943,15 +2943,29 @@ def load_model(self, eep_scale_up: bool = False) -> None: logger.info("Loading drafter model...") self.drafter.load_model(self.model) if self.use_aux_hidden_state_outputs: - if supports_eagle3(self.model): - self.model.set_aux_hidden_state_layers( - self.model.get_eagle3_aux_hidden_state_layers() - ) - else: + if not supports_eagle3(self.model): raise RuntimeError( "Model does not support EAGLE3 interface but " "aux_hidden_state_outputs was requested" ) + + # Try to get auxiliary layers from speculative config, + # otherwise use model's default layers + aux_layers = ( + self._get_eagle3_aux_layers_from_config() + or self.model.get_eagle3_aux_hidden_state_layers() + ) + + if ( + aux_layers + != self.model.get_eagle3_aux_hidden_state_layers() + ): + logger.info( + "Using auxiliary layers from speculative config: %s", + aux_layers, + ) + + self.model.set_aux_hidden_state_layers(aux_layers) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory logger.info( @@ -3006,6 +3020,32 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model, self.vllm_config, CUDAGraphMode.NONE, self.device ) + def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]: + """Extract Eagle3 auxiliary layer IDs from speculative config. + + Returns: + Tuple of layer indices if found in draft model config, + None otherwise. + """ + if not (self.speculative_config + and self.speculative_config.draft_model_config): + return None + + try: + hf_config = self.speculative_config.draft_model_config.hf_config + if not hasattr(hf_config, 'eagle_aux_hidden_state_layer_ids'): + return None + + layer_ids = hf_config.eagle_aux_hidden_state_layer_ids + if layer_ids and isinstance(layer_ids, (list, tuple)): + return tuple(layer_ids) + except Exception as e: + logger.warning( + "Failed to read auxiliary layers from speculative config: %s", + e) + + return None + def reload_weights(self) -> None: assert getattr(self, "model", None) is not None, ( "Cannot reload weights before model is loaded." From 1c1d67949c1ab8f8fcbc1f51e30ff67b28e33387 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Fri, 3 Oct 2025 08:17:21 +0000 Subject: [PATCH 6/7] Review comments Signed-off-by: Rahul Tuli --- .../configs/speculators/algos.py | 6 ++- vllm/v1/worker/gpu_model_runner.py | 38 ++++++++----------- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/vllm/transformers_utils/configs/speculators/algos.py b/vllm/transformers_utils/configs/speculators/algos.py index 70f19aa47b35..11bdb053ba5e 100644 --- a/vllm/transformers_utils/configs/speculators/algos.py +++ b/vllm/transformers_utils/configs/speculators/algos.py @@ -21,6 +21,10 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None: - draft_vocab_size: Size of the draft model's vocabulary - target_hidden_size: Hidden size of the target model - norm_before_residual: Whether to apply norm before residual connection + - eagle_aux_hidden_state_layer_ids: List of layer indices from the base + model to use as auxiliary inputs for the Eagle3 drafter. These layers + provide intermediate hidden states that help the drafter make better + predictions. This is the standard field used in Eagle3 checkpoints. """ vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size") @@ -31,5 +35,3 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None: if config_dict.get("eagle_aux_hidden_state_layer_ids"): vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[ "eagle_aux_hidden_state_layer_ids"] - if config_dict.get("inference_type"): - vllm_config["inference_type"] = config_dict["inference_type"] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4800449e9ad0..a9c3984e3d13 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2951,19 +2951,14 @@ def load_model(self, eep_scale_up: bool = False) -> None: # Try to get auxiliary layers from speculative config, # otherwise use model's default layers - aux_layers = ( - self._get_eagle3_aux_layers_from_config() - or self.model.get_eagle3_aux_hidden_state_layers() - ) - - if ( - aux_layers - != self.model.get_eagle3_aux_hidden_state_layers() - ): + aux_layers = self._get_eagle3_aux_layers_from_config() + if aux_layers: logger.info( "Using auxiliary layers from speculative config: %s", aux_layers, ) + else: + aux_layers = self.model.get_eagle3_aux_hidden_state_layers() self.model.set_aux_hidden_state_layers(aux_layers) time_after_load = time.perf_counter() @@ -3021,7 +3016,11 @@ def load_model(self, eep_scale_up: bool = False) -> None: ) def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]: - """Extract Eagle3 auxiliary layer IDs from speculative config. + """Extract Eagle3 auxiliary layer indices from speculative config. + + These indices specify which hidden states from the base model should + be used as auxiliary inputs for the Eagle3 drafter model during + speculative decoding. Returns: Tuple of layer indices if found in draft model config, @@ -3031,18 +3030,13 @@ def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]: and self.speculative_config.draft_model_config): return None - try: - hf_config = self.speculative_config.draft_model_config.hf_config - if not hasattr(hf_config, 'eagle_aux_hidden_state_layer_ids'): - return None - - layer_ids = hf_config.eagle_aux_hidden_state_layer_ids - if layer_ids and isinstance(layer_ids, (list, tuple)): - return tuple(layer_ids) - except Exception as e: - logger.warning( - "Failed to read auxiliary layers from speculative config: %s", - e) + hf_config = self.speculative_config.draft_model_config.hf_config + if not hasattr(hf_config, 'eagle_aux_hidden_state_layer_ids'): + return None + + layer_ids = hf_config.eagle_aux_hidden_state_layer_ids + if layer_ids and isinstance(layer_ids, (list, tuple)): + return tuple(layer_ids) return None From 1037b3625b806b55928c5cf25c6b0cba9324449f Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Fri, 3 Oct 2025 14:06:51 +0530 Subject: [PATCH 7/7] Use get_input_embeddings Signed-off-by: Rahul Tuli --- vllm/model_executor/models/llama_eagle3.py | 3 +-- vllm/model_executor/models/mllama4.py | 14 ++++++++------ .../configs/speculators/algos.py | 3 ++- vllm/v1/worker/gpu_model_runner.py | 5 ++--- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 0372ed44a6cd..155a4ecea28f 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -248,8 +248,7 @@ def get_input_embeddings( multimodal_embeddings: Optional[NestedTensors] = None, is_multimodal: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # The llama3 drafter only processes text embeddings - return self.model.embed_tokens(input_ids) + return self.model.get_input_embeddings(input_ids) def forward( self, diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index c001ba403839..b624a6200ab3 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -64,8 +64,12 @@ from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsEagle3, - SupportsMultiModal, SupportsPP) +from .interfaces import ( + MultiModalEmbeddings, + SupportsEagle3, + SupportsMultiModal, + SupportsPP, +) from .llama4 import Llama4ForCausalLM from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix from .vision import run_dp_sharded_vision_model @@ -773,7 +777,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: """Set which layers should output auxiliary hidden states for EAGLE3.""" # Delegate to underlying language model (Llama4ForCausalLM) - assert hasattr(self.language_model, 'set_aux_hidden_state_layers') + assert hasattr(self.language_model, "set_aux_hidden_state_layers") self.language_model.set_aux_hidden_state_layers(layers) def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: @@ -783,9 +787,7 @@ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: the speculative config if available, providing dynamic configuration. """ # Delegate to underlying language model (Llama4ForCausalLM) - assert hasattr( - self.language_model, "get_eagle3_aux_hidden_state_layers" - ) + assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers") return self.language_model.get_eagle3_aux_hidden_state_layers() def _parse_and_validate_image_input( diff --git a/vllm/transformers_utils/configs/speculators/algos.py b/vllm/transformers_utils/configs/speculators/algos.py index 11bdb053ba5e..88bce3d4f79e 100644 --- a/vllm/transformers_utils/configs/speculators/algos.py +++ b/vllm/transformers_utils/configs/speculators/algos.py @@ -34,4 +34,5 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None: vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"] if config_dict.get("eagle_aux_hidden_state_layer_ids"): vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[ - "eagle_aux_hidden_state_layer_ids"] + "eagle_aux_hidden_state_layer_ids" + ] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a9c3984e3d13..5cbbe435a789 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3026,12 +3026,11 @@ def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]: Tuple of layer indices if found in draft model config, None otherwise. """ - if not (self.speculative_config - and self.speculative_config.draft_model_config): + if not (self.speculative_config and self.speculative_config.draft_model_config): return None hf_config = self.speculative_config.draft_model_config.hf_config - if not hasattr(hf_config, 'eagle_aux_hidden_state_layer_ids'): + if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"): return None layer_ids = hf_config.eagle_aux_hidden_state_layer_ids