4 changes: 4 additions & 0 deletions tests/v1/spec_decode/test_eagle.py
@@ -534,6 +534,8 @@ def create_deterministic_logits(token_ids):
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
attn_metadata_builder
proposer._get_attention_metadata_builder = mock.MagicMock(
return_value=attn_metadata_builder)

result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
@@ -660,6 +662,8 @@ def create_deterministic_logits(token_ids, k: int):
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
attn_metadata_builder
proposer._get_attention_metadata_builder = mock.MagicMock(
return_value=attn_metadata_builder)

# Setup inputs for the proposer.
target_token_ids = torch.randint(0,
1 change: 1 addition & 0 deletions vllm/config/model.py
@@ -1003,6 +1003,7 @@ def _verify_quantization(self) -> None:
self.quantization = quantization_override
break

quant_method = quant_method if quant_method != "" else None
# Verify quantization configurations.
if self.quantization is None:
self.quantization = quant_method
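Aside (not part of the patch): a minimal standalone sketch of what the added normalization prevents, assuming quant_method is read from the model's hf config. Without mapping the empty string to None, an empty quant_method would be propagated as if it were a detected quantization method.

# Hypothetical illustration, not vLLM code.
quant_method = ""  # e.g. an hf config carrying an empty quant_method field
quant_method = quant_method if quant_method != "" else None

quantization = None  # user did not request a quantization method
if quantization is None:
    quantization = quant_method
assert quantization is None  # stays unset instead of becoming ""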
5 changes: 5 additions & 0 deletions vllm/model_executor/models/llama_eagle.py
@@ -134,6 +134,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
# Ensure draft_vocab_size is set; default to the base vocab size when absent.
if getattr(self.config, "draft_vocab_size", None) is None:
base_vocab_size = getattr(self.config, "vocab_size", None)
self.config.draft_vocab_size = base_vocab_size
target_layer_num = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config)
self.model = LlamaModel(vllm_config=vllm_config,
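Aside (not part of the patch): a minimal standalone sketch of the defaulting pattern added above, which llama_eagle3.py below repeats, using a hypothetical config object.

from types import SimpleNamespace

# Hypothetical draft config that lacks draft_vocab_size.
config = SimpleNamespace(vocab_size=32000)
if getattr(config, "draft_vocab_size", None) is None:
    # Fall back to the base vocabulary size.
    config.draft_vocab_size = getattr(config, "vocab_size", None)
assert config.draft_vocab_size == 32000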
5 changes: 5 additions & 0 deletions vllm/model_executor/models/llama_eagle3.py
@@ -203,6 +203,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
# Ensure draft_vocab_size is set; default to the base vocab size when absent.
if getattr(self.config, "draft_vocab_size", None) is None:
base_vocab_size = getattr(self.config, "vocab_size", None)
self.config.draft_vocab_size = base_vocab_size
target_layer_num = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config)

70 changes: 56 additions & 14 deletions vllm/v1/spec_decode/eagle.py
@@ -9,6 +9,7 @@
import torch
import torch.nn as nn

from vllm.attention.backends.abstract import AttentionMetadataBuilder
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
@@ -77,6 +78,8 @@ def __init__(
self.is_multimodal_model = vllm_config.model_config \
.is_multimodal_model

self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None

self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager)
@@ -117,7 +120,7 @@ def __init__(
with_numpy=True)

# Determine allowed attention backends once during initialization.
self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...]
self.allowed_attn_types: tuple[type, ...]
if current_platform.is_rocm():
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
# vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
@@ -190,10 +193,12 @@ def propose(

assert self.runner is not None

# FIXME: need to consider multiple kv_cache_groups
attn_metadata_builder = \
self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata = attn_metadata_builder.build_for_drafting(
# Select the attention metadata builder for the EAGLE layers.
# Look it up once and reuse it for the subsequent drafting steps.
builder = (self._get_attention_metadata_builder()
if self.attn_metadata_builder is None else
self.attn_metadata_builder)
attn_metadata = builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0)

# At this moment, we assume all eagle layers belong to the same KV
@@ -327,11 +332,9 @@ def propose(
exceeds_max_model_len, PADDING_SLOT_ID)

# Rebuild attention metadata
attn_metadata_builder = \
self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata = attn_metadata_builder\
.build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1)
attn_metadata = builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1)
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata

@@ -851,10 +854,24 @@ def load_model(self, target_model: nn.Module) -> None:
# share lm_head with the target model if needed
# some model definitions do not define lm_head explicitly
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
if self.vllm_config.speculative_config.method != "eagle3" and \
hasattr(target_language_model, "lm_head"):
logger.info("Loading EAGLE LM head weights from the target model.")
self.model.lm_head = target_language_model.lm_head
if self.vllm_config.speculative_config.method != "eagle3":
if hasattr(target_language_model, "lm_head"):
logger.info(
"Loading EAGLE LM head weights from the target model.")
self.model.lm_head = target_language_model.lm_head
else:
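# For EAGLE3, share the target model's lm_head only when the shapes
# match; otherwise the draft head loads its own lm_head weights.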
if (hasattr(self.model, "lm_head")
and hasattr(target_language_model, "lm_head")
and self.model.lm_head.weight.shape
== target_language_model.lm_head.weight.shape):
logger.info("Assuming the EAGLE head shares the same lm_head"
" with the target model.")
del self.model.lm_head
self.model.lm_head = target_language_model.lm_head
else:
logger.info(
"The EAGLE head's lm_head will be loaded separately"
" from the target model.")

@torch.inference_mode()
def dummy_run(
@@ -877,6 +894,31 @@ def dummy_run(
inputs_embeds=inputs_embeds,
)

def _get_attention_metadata_builder(
        self) -> AttentionMetadataBuilder:
    """Find and return the attention metadata builder for the EAGLE layers.

Returns:
The metadata builder for the EAGLE layers.

Raises:
AssertionError: If no metadata builder is found for the EAGLE layers.
"""
builder = None
chosen_layer = self.attn_layer_names[0]
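# All EAGLE attention layers are assumed to sit in the same KV cache group
# (see validate_same_kv_cache_group), so finding the group that contains
# the first layer is sufficient.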

for kv_cache_group in self.runner.attn_groups:
for attn_group in kv_cache_group:
if chosen_layer in attn_group.layer_names:
builder = attn_group.get_metadata_builder()
break
if builder is not None:
break

assert builder is not None, (
"Failed to find attention metadata builder for EAGLE layers.")
return builder

def validate_same_kv_cache_group(self,
kv_cache_config: KVCacheConfig) -> None:
"""
11 changes: 8 additions & 3 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1177,9 +1177,14 @@ def _prepare_inputs(
encoder_seq_lens=encoder_seq_lens,
)

if self.speculative_config and \
spec_decode_common_attn_metadata is None:
spec_decode_common_attn_metadata = common_attn_metadata
if (self.speculative_config
and spec_decode_common_attn_metadata is None):
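# For EAGLE, take the metadata only from the KV cache group that
# contains the drafter's attention layers.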
if isinstance(self.drafter, EagleProposer):
if (self.drafter.attn_layer_names[0]
in kv_cache_group_spec.layer_names):
spec_decode_common_attn_metadata = common_attn_metadata
else:
spec_decode_common_attn_metadata = common_attn_metadata

for attn_group in self.attn_groups[kv_cache_group_id]:
# Prepare for cascade attention if enabled & beneficial.