Skip to content

Commit 9f94737

Browse files
RonaldBXu and amitm02
authored and committed
[V1] [Bugfix] eagle bugfix and enable correct lm_head for multimodal (2) (vllm-project#18781)
Signed-off-by: Ronald Xu <ronaldxu@amazon.com> Signed-off-by: amit <amit.man@gmail.com>
1 parent 395d5d3 commit 9f94737

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

vllm/transformers_utils/configs/eagle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(self,
7070

7171
if self.model is not None:
7272
for k, v in self.model.to_dict().items():
73-
if not hasattr(self, k):
73+
if k not in kwargs:
7474
setattr(self, k, v)
7575

7676
@classmethod

vllm/v1/spec_decode/eagle.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from vllm.forward_context import set_forward_context
1010
from vllm.logger import init_logger
1111
from vllm.model_executor.model_loader import get_model
12+
from vllm.model_executor.models import supports_multimodal
1213
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
1314
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
1415
FlashAttentionMetadata)
@@ -346,7 +347,10 @@ def load_model(self, target_model: nn.Module) -> None:
346347
if self.vllm_config.speculative_config.method != "eagle3" and \
347348
hasattr(target_model, "lm_head"):
348349
logger.info("Loading EAGLE LM head weights from the target model.")
349-
self.model.lm_head = target_model.lm_head
350+
if supports_multimodal(target_model):
351+
self.model.lm_head = target_model.get_language_model().lm_head
352+
else:
353+
self.model.lm_head = target_model.lm_head
350354

351355
@torch.inference_mode()
352356
def dummy_run(

0 commit comments

Comments (0)