diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index 23bfabfcf89b..5096f9fd647b 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -534,6 +534,8 @@ def create_deterministic_logits(token_ids):
     proposer.runner.attn_groups.append([mock.MagicMock()])
     proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
         attn_metadata_builder
+    proposer._get_attention_metadata_builder = mock.MagicMock(
+        return_value=attn_metadata_builder)
 
     result = proposer.propose(target_token_ids=target_token_ids,
                               target_positions=target_positions,
@@ -660,6 +662,8 @@ def create_deterministic_logits(token_ids, k: int):
     proposer.runner.attn_groups.append([mock.MagicMock()])
     proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
         attn_metadata_builder
+    proposer._get_attention_metadata_builder = mock.MagicMock(
+        return_value=attn_metadata_builder)
 
     # Setup inputs for the proposer.
     target_token_ids = torch.randint(0,
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 33e5d3ea04a4..d8a8fe20fd03 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -1003,6 +1003,7 @@ def _verify_quantization(self) -> None:
                     self.quantization = quantization_override
                     break
 
+        quant_method = quant_method if quant_method != "" else None
         # Verify quantization configurations.
         if self.quantization is None:
             self.quantization = quant_method
diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py
index dfae3c3ea543..2ff2d54a83aa 100644
--- a/vllm/model_executor/models/llama_eagle.py
+++ b/vllm/model_executor/models/llama_eagle.py
@@ -134,6 +134,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
+        # Ensure draft_vocab_size is set
+        # default to the base vocab size when absent
+        if getattr(self.config, "draft_vocab_size", None) is None:
+            base_vocab_size = getattr(self.config, "vocab_size", None)
+            self.config.draft_vocab_size = base_vocab_size
         target_layer_num = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
         self.model = LlamaModel(vllm_config=vllm_config,
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index fb10af6c53c9..b99a1547918e 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -203,6 +203,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
+        # Ensure draft_vocab_size is set
+        # default to the base vocab size when absent
+        if getattr(self.config, "draft_vocab_size", None) is None:
+            base_vocab_size = getattr(self.config, "vocab_size", None)
+            self.config.draft_vocab_size = base_vocab_size
         target_layer_num = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
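The draft_vocab_size fallback added to both EAGLE model definitions above can be read in isolation. A minimal sketch, assuming only an hf_config-like object with optional draft_vocab_size and vocab_size attributes; ensure_draft_vocab_size is a hypothetical helper, not part of the patch:

from types import SimpleNamespace


def ensure_draft_vocab_size(hf_config) -> None:
    # Hypothetical helper mirroring the patch: default draft_vocab_size to the
    # base vocab_size when the draft config does not define it.
    if getattr(hf_config, "draft_vocab_size", None) is None:
        hf_config.draft_vocab_size = getattr(hf_config, "vocab_size", None)


cfg = SimpleNamespace(vocab_size=32000)  # no draft_vocab_size attribute
ensure_draft_vocab_size(cfg)
assert cfg.draft_vocab_size == 32000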
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index a0f40828d42f..a9e0a38fe341 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -9,6 +9,7 @@
 import torch
 import torch.nn as nn
 
+from vllm.attention.backends.abstract import AttentionMetadataBuilder
 from vllm.attention.layer import Attention
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
@@ -77,6 +78,8 @@ def __init__(
         self.is_multimodal_model = vllm_config.model_config \
             .is_multimodal_model
 
+        self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
+
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
                                not self.vllm_config.model_config.enforce_eager)
@@ -117,7 +120,7 @@ def __init__(
                                          with_numpy=True)
 
         # Determine allowed attention backends once during initialization.
-        self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...]
+        self.allowed_attn_types: tuple[type, ...]
         if current_platform.is_rocm():
             rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
             # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
@@ -190,10 +193,12 @@ def propose(
 
         assert self.runner is not None
 
-        # FIXME: need to consider multiple kv_cache_groups
-        attn_metadata_builder = \
-            self.runner.attn_groups[0][0].get_metadata_builder()
-        attn_metadata = attn_metadata_builder.build_for_drafting(
+        # Select the correct attention metadata builder for the EAGLE layers.
+        # Get the attention metadata builder once and reuse it later.
+        builder = (self._get_attention_metadata_builder()
+                   if self.attn_metadata_builder is None else
+                   self.attn_metadata_builder)
+        attn_metadata = builder.build_for_drafting(
             common_attn_metadata=common_attn_metadata, draft_index=0)
 
         # At this moment, we assume all eagle layers belong to the same KV
@@ -327,11 +332,9 @@ def propose(
                 exceeds_max_model_len, PADDING_SLOT_ID)
 
             # Rebuild attention metadata
-            attn_metadata_builder = \
-                self.runner.attn_groups[0][0].get_metadata_builder()
-            attn_metadata = attn_metadata_builder\
-                .build_for_drafting(common_attn_metadata=common_attn_metadata,
-                                    draft_index=token_index + 1)
+            attn_metadata = builder.build_for_drafting(
+                common_attn_metadata=common_attn_metadata,
+                draft_index=token_index + 1)
             for layer_name in self.attn_layer_names:
                 per_layer_attn_metadata[layer_name] = attn_metadata
 
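The two propose() hunks above replace the hard-coded attn_groups[0][0] lookup with a single resolved builder that is reused for every drafting step. A simplified sketch of that call pattern, with FakeBuilder standing in for a backend's AttentionMetadataBuilder (names here are illustrative, not vLLM APIs):

class FakeBuilder:
    # Stand-in for a backend AttentionMetadataBuilder; build_for_drafting here
    # only records which draft step it was asked to build metadata for.
    def build_for_drafting(self, common_attn_metadata, draft_index):
        return {"common": common_attn_metadata, "draft_index": draft_index}


def run_drafting(builder, common_attn_metadata, num_speculative_steps):
    # Resolve the builder once, build metadata for draft_index=0, then rebuild
    # for each subsequent step with draft_index=token_index + 1.
    metadata = [builder.build_for_drafting(common_attn_metadata, draft_index=0)]
    for token_index in range(num_speculative_steps - 1):
        metadata.append(
            builder.build_for_drafting(common_attn_metadata,
                                       draft_index=token_index + 1))
    return metadata


steps = run_drafting(FakeBuilder(), common_attn_metadata={},
                     num_speculative_steps=3)
assert [m["draft_index"] for m in steps] == [0, 1, 2]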
@@ -851,10 +854,24 @@ def load_model(self, target_model: nn.Module) -> None:
         # share lm_head with the target model if needed
         # some model definition do not define lm_head explicitly
         # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
-        if self.vllm_config.speculative_config.method != "eagle3" and \
-                hasattr(target_language_model, "lm_head"):
-            logger.info("Loading EAGLE LM head weights from the target model.")
-            self.model.lm_head = target_language_model.lm_head
+        if self.vllm_config.speculative_config.method != "eagle3":
+            if hasattr(target_language_model, "lm_head"):
+                logger.info(
+                    "Loading EAGLE LM head weights from the target model.")
+                self.model.lm_head = target_language_model.lm_head
+        else:
+            if (hasattr(self.model, "lm_head")
+                    and hasattr(target_language_model, "lm_head")
+                    and self.model.lm_head.weight.shape
+                    == target_language_model.lm_head.weight.shape):
+                logger.info("Assuming the EAGLE head shares the same lm_head"
+                            " with the target model.")
+                del self.model.lm_head
+                self.model.lm_head = target_language_model.lm_head
+            else:
+                logger.info(
+                    "The EAGLE head's lm_head will be loaded separately"
+                    " from the target model.")
 
     @torch.inference_mode()
     def dummy_run(
@@ -877,6 +894,31 @@ def dummy_run(
                 inputs_embeds=inputs_embeds,
             )
 
+    def _get_attention_metadata_builder(
+            self) -> AttentionMetadataBuilder:
+        """Find and return the attention metadata builder for EAGLE layers.
+
+        Returns:
+            The metadata builder for EAGLE layers.
+
+        Raises:
+            AssertionError: If no metadata builder is found for EAGLE layers.
+        """
+        builder = None
+        chosen_layer = self.attn_layer_names[0]
+
+        for kv_cache_group in self.runner.attn_groups:
+            for attn_group in kv_cache_group:
+                if chosen_layer in attn_group.layer_names:
+                    builder = attn_group.get_metadata_builder()
+                    break
+            if builder is not None:
+                break
+
+        assert builder is not None, (
+            "Failed to find attention metadata builder for EAGLE layers.")
+        return builder
+
     def validate_same_kv_cache_group(self,
                                      kv_cache_config: KVCacheConfig) -> None:
         """
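The new _get_attention_metadata_builder() walks runner.attn_groups, a list of KV cache groups whose attention groups expose layer_names and get_metadata_builder(). A self-contained sketch of that lookup, using SimpleNamespace stand-ins and made-up layer names rather than vLLM's real attention group objects:

from types import SimpleNamespace

eagle_builder = object()
attn_groups = [
    [SimpleNamespace(layer_names=["model.layers.0.self_attn.attn"],
                     get_metadata_builder=lambda: object())],
    [SimpleNamespace(layer_names=["drafter.layers.0.self_attn.attn"],
                     get_metadata_builder=lambda: eagle_builder)],
]


def find_builder(attn_groups, chosen_layer):
    # Return the metadata builder of the attention group that owns the layer.
    for kv_cache_group in attn_groups:
        for attn_group in kv_cache_group:
            if chosen_layer in attn_group.layer_names:
                return attn_group.get_metadata_builder()
    raise AssertionError(
        "Failed to find attention metadata builder for EAGLE layers.")


assert find_builder(attn_groups,
                    "drafter.layers.0.self_attn.attn") is eagle_builder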
+ """ + builder = None + chosen_layer = self.attn_layer_names[0] + + for kv_cache_group in self.runner.attn_groups: + for attn_group in kv_cache_group: + if chosen_layer in attn_group.layer_names: + builder = attn_group.get_metadata_builder() + break + if builder is not None: + break + + assert builder is not None, ( + "Failed to find attention metadata builder for EAGLE layers.") + return builder + def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: """ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ed324138c6fe..c3dc0374ca9d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1177,9 +1177,14 @@ def _prepare_inputs( encoder_seq_lens=encoder_seq_lens, ) - if self.speculative_config and \ - spec_decode_common_attn_metadata is None: - spec_decode_common_attn_metadata = common_attn_metadata + if (self.speculative_config + and spec_decode_common_attn_metadata is None): + if isinstance(self.drafter, EagleProposer): + if (self.drafter.attn_layer_names[0] + in kv_cache_group_spec.layer_names): + spec_decode_common_attn_metadata = common_attn_metadata + else: + spec_decode_common_attn_metadata = common_attn_metadata for attn_group in self.attn_groups[kv_cache_group_id]: # Prepare for cascade attention if enabled & beneficial.