@@ -9,6 +9,7 @@
 import torch
 import torch.nn as nn
 
+from vllm.attention.backends.abstract import AttentionMetadataBuilder
 from vllm.attention.layer import Attention
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
@@ -77,6 +78,8 @@ def __init__(
         self.is_multimodal_model = vllm_config.model_config \
             .is_multimodal_model
 
+        self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
+
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
                                not self.vllm_config.model_config.enforce_eager)
@@ -117,7 +120,7 @@ def __init__(
                                           with_numpy=True)
 
         # Determine allowed attention backends once during initialization.
-        self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...]
+        self.allowed_attn_types: tuple[type, ...]
         if current_platform.is_rocm():
             rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
             # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
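
The annotation change above widens `allowed_attn_types` to a plain tuple of classes. A tuple like this is typically consumed by an `isinstance` check against the metadata object the chosen backend actually builds. Below is a minimal, self-contained sketch of that usage; the two metadata classes are placeholders rather than vLLM's real backend types.

```python
# Placeholder metadata classes standing in for the backend-specific
# attention metadata types.
class TritonAttentionMetadata:
    pass


class FlashAttentionMetadata:
    pass


# A plain tuple of classes works directly as the second argument of
# isinstance(), which is how an "allowed types" tuple is usually checked.
allowed_attn_types: tuple[type, ...] = (TritonAttentionMetadata,
                                        FlashAttentionMetadata)

metadata = FlashAttentionMetadata()
assert isinstance(metadata, allowed_attn_types)
```
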
@@ -190,10 +193,12 @@ def propose(
 
         assert self.runner is not None
 
-        # FIXME: need to consider multiple kv_cache_groups
-        attn_metadata_builder = \
-            self.runner.attn_groups[0][0].get_metadata_builder()
-        attn_metadata = attn_metadata_builder.build_for_drafting(
+        # Select the correct attention metadata builder for the EAGLE layers.
+        # Look the builder up once and reuse it for later drafting steps.
+        builder = (self._get_attention_metadata_builder()
+                   if self.attn_metadata_builder is None else
+                   self.attn_metadata_builder)
+        attn_metadata = builder.build_for_drafting(
             common_attn_metadata=common_attn_metadata, draft_index=0)
 
         # At this moment, we assume all eagle layers belong to the same KV
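
The hunk above drops the hardcoded `attn_groups[0][0]` lookup in favor of a helper (added further down in this diff) that scans the runner's attention groups for the one that actually contains the EAGLE layers. The sketch below illustrates that selection logic in isolation; the `AttnGroup` dataclass, the layer names, and the builder objects are made-up stand-ins rather than vLLM's real types.

```python
from dataclasses import dataclass, field


@dataclass
class AttnGroup:
    """Stand-in for one attention group: some layer names plus one builder."""
    layer_names: list[str]
    builder: object = field(default_factory=object)

    def get_metadata_builder(self) -> object:
        return self.builder


def select_builder(attn_groups: list[list[AttnGroup]],
                   eagle_layer: str) -> object:
    """Return the builder of the group that owns the given EAGLE layer."""
    for kv_cache_group in attn_groups:
        for group in kv_cache_group:
            if eagle_layer in group.layer_names:
                return group.get_metadata_builder()
    raise AssertionError("No attention group contains the EAGLE layer.")


# Two hypothetical KV-cache groups: the target model's layer in the first,
# the drafter's layer in the second. A hardcoded attn_groups[0][0] would pick
# the wrong builder here; the name-based scan finds the right one.
target_group = AttnGroup(layer_names=["model.layers.0.self_attn"])
eagle_group = AttnGroup(layer_names=["drafter.layers.0.self_attn"])
attn_groups = [[target_group], [eagle_group]]

assert select_builder(attn_groups,
                      "drafter.layers.0.self_attn") is eagle_group.builder
```

With a single KV-cache group the scan degenerates to the old `attn_groups[0][0]` behaviour, so the common single-group case is unchanged.
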
@@ -327,11 +332,9 @@ def propose(
                                               exceeds_max_model_len, PADDING_SLOT_ID)
 
             # Rebuild attention metadata
-            attn_metadata_builder = \
-                self.runner.attn_groups[0][0].get_metadata_builder()
-            attn_metadata = attn_metadata_builder \
-                .build_for_drafting(common_attn_metadata=common_attn_metadata,
-                                    draft_index=token_index + 1)
+            attn_metadata = builder.build_for_drafting(
+                common_attn_metadata=common_attn_metadata,
+                draft_index=token_index + 1)
             for layer_name in self.attn_layer_names:
                 per_layer_attn_metadata[layer_name] = attn_metadata
 
@@ -851,10 +854,24 @@ def load_model(self, target_model: nn.Module) -> None:
         # share lm_head with the target model if needed
         # some model definition do not define lm_head explicitly
         # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
-        if self.vllm_config.speculative_config.method != "eagle3" and \
-                hasattr(target_language_model, "lm_head"):
-            logger.info("Loading EAGLE LM head weights from the target model.")
-            self.model.lm_head = target_language_model.lm_head
+        if self.vllm_config.speculative_config.method != "eagle3":
+            if hasattr(target_language_model, "lm_head"):
+                logger.info(
+                    "Loading EAGLE LM head weights from the target model.")
+                self.model.lm_head = target_language_model.lm_head
+        else:
+            if (hasattr(self.model, "lm_head")
+                    and hasattr(target_language_model, "lm_head")
+                    and self.model.lm_head.weight.shape
+                    == target_language_model.lm_head.weight.shape):
+                logger.info("Assuming the EAGLE head shares the same lm_head"
+                            " with the target model.")
+                del self.model.lm_head
+                self.model.lm_head = target_language_model.lm_head
+            else:
+                logger.info(
+                    "The EAGLE head's lm_head will be loaded separately"
+                    " from the target model.")
 
     @torch.inference_mode()
     def dummy_run(
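
In the eagle3 branch added above, whether the draft model can reuse the target model's `lm_head` reduces to a shape comparison between the two output projections: share the module when the shapes match, otherwise keep a separate head and load it from the drafter checkpoint. A toy illustration of that decision, with made-up sizes and plain `nn.Linear` modules rather than vLLM code, is below.

```python
import torch.nn as nn

# Hypothetical sizes: an EAGLE-3 head may use a reduced draft vocabulary, in
# which case the shapes differ and sharing is impossible.
hidden_size, target_vocab, draft_vocab = 64, 1000, 1000

target_lm_head = nn.Linear(hidden_size, target_vocab, bias=False)
draft_lm_head = nn.Linear(hidden_size, draft_vocab, bias=False)

if draft_lm_head.weight.shape == target_lm_head.weight.shape:
    # Shapes match: point the draft head at the target model's module so the
    # weights exist only once in memory.
    draft_lm_head = target_lm_head
    assert draft_lm_head.weight is target_lm_head.weight
else:
    # Shapes differ: keep a separate lm_head and load its weights from the
    # drafter checkpoint instead.
    pass
```

Sharing the module rather than copying its weights keeps a single tensor in memory, which matters when the vocabulary is large.
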
@@ -877,6 +894,31 @@ def dummy_run(
             inputs_embeds=inputs_embeds,
         )
 
+    def _get_attention_metadata_builder(
+            self) -> AttentionMetadataBuilder:
+        """Find and return the attention metadata builder for EAGLE layers.
+
+        Returns:
+            The metadata builder for EAGLE layers.
+
+        Raises:
+            AssertionError: If no metadata builder is found for EAGLE layers.
+        """
+        builder = None
+        chosen_layer = self.attn_layer_names[0]
+
+        for kv_cache_group in self.runner.attn_groups:
+            for attn_group in kv_cache_group:
+                if chosen_layer in attn_group.layer_names:
+                    builder = attn_group.get_metadata_builder()
+                    break
+            if builder is not None:
+                break
+
+        assert builder is not None, (
+            "Failed to find attention metadata builder for EAGLE layers.")
+        return builder
+
     def validate_same_kv_cache_group(self,
                                      kv_cache_config: KVCacheConfig) -> None:
         """