Skip to content

Commit 4c61134

Browse files
authored
[V1] [Bugfix] eagle bugfix and enable correct lm_head for multimodal (#18034)
Signed-off-by: Ronald Xu <ronaldxu@amazon.com>
1 parent 60cad94 commit 4c61134

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

vllm/transformers_utils/configs/eagle.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,7 @@ def __init__(self,
7070

7171
if self.model is not None:
7272
for k, v in self.model.to_dict().items():
73-
if not hasattr(self, k):
74-
setattr(self, k, v)
73+
setattr(self, k, v)
7574

7675
@classmethod
7776
def from_pretrained(

vllm/v1/spec_decode/eagle.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from vllm.forward_context import set_forward_context
1010
from vllm.logger import init_logger
1111
from vllm.model_executor.model_loader import get_model
12+
from vllm.model_executor.models import supports_multimodal
1213
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
1314
from vllm.triton_utils import tl, triton
1415
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
@@ -310,7 +311,10 @@ def load_model(self, target_model: nn.Module) -> None:
310311
if self.vllm_config.speculative_config.method != "eagle3" and \
311312
hasattr(target_model, "lm_head"):
312313
logger.info("Loading EAGLE LM head weights from the target model.")
313-
self.model.lm_head = target_model.lm_head
314+
if supports_multimodal(target_model):
315+
self.model.lm_head = target_model.get_language_model().lm_head
316+
else:
317+
self.model.lm_head = target_model.lm_head
314318

315319
@torch.inference_mode()
316320
def dummy_run(

0 commit comments

Comments
 (0)