
Commit 6fb605a

david6666666, Junhong, and LJH-LBJ authored and committed
[Multimodal][Speculative Decoding] Eagle/Eagle3 multimodal support, enablement on Qwen2.5-VL (vllm-project#22872)
Signed-off-by: Junhong <liujunhong11@huawei.com>
Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
Co-authored-by: Junhong <liujunhong11@huawei.com>
Co-authored-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 9ec315f commit 6fb605a

File tree

8 files changed, +210 −45 lines

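For context, a minimal sketch of how the new Qwen2.5-VL draft model might be used once this commit lands. It is modeled on the speculative_config dictionaries exercised in tests/v1/e2e/test_spec_decode.py below; the prompt and the num_speculative_tokens value are illustrative, not taken from the commit.

# Hedged sketch: offline Eagle3 speculative decoding for Qwen2.5-VL.
# The speculative_config keys mirror those used in the e2e test; the exact
# values here (num_speculative_tokens, prompt) are illustrative assumptions.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-VL-7B-Instruct",
    speculative_config={
        "method": "eagle3",
        "model": "Rayzl/qwen2.5-vl-7b-eagle3-sgl",
        "num_speculative_tokens": 3,
    },
)
outputs = llm.generate(["Describe speculative decoding in one sentence."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)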

tests/models/registry.py

Lines changed: 3 additions & 0 deletions
@@ -651,6 +651,9 @@ def check_available_online(
     "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
                                     trust_remote_code=True,
                                     speculative_model="XiaomiMiMo/MiMo-7B-RL"),
+    "Eagle3Qwen2_5vlForCausalLM": _HfExamplesInfo(
+        "Qwen/Qwen2.5-VL-7B-Instruct",
+        speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl"),
     "Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
                                     min_transformers_version="4.56.3"),
 }

tests/v1/e2e/test_spec_decode.py

Lines changed: 7 additions & 2 deletions
@@ -129,6 +129,11 @@ def test_ngram_correctness(
     ["model_setup", "mm_enabled"],
     [
         (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
+        pytest.param(("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct",
+                      "Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1),
+                     False,
+                     marks=pytest.mark.skip(reason="Skipping due to its "
+                                            "head_dim not being a multiple of 32")),
         (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
           "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
         (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
@@ -145,8 +150,8 @@ def test_ngram_correctness(
           "eagle618/eagle-deepseek-v3-random", 1), False),
     ],
     ids=[
-        "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", "llama4_eagle",
-        "llama4_eagle_mm", "deepseek_eagle"
+        "qwen3_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", "llama3_eagle3",
+        "llama4_eagle", "llama4_eagle_mm", "deepseek_eagle"
     ])
 @pytest.mark.parametrize("attn_backend",
                          get_attn_backend_list_based_on_platform())

vllm/benchmarks/datasets.py

Lines changed: 80 additions & 0 deletions
@@ -1450,6 +1450,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
     ):
         dataset_class = MLPerfDataset
         args.hf_split = "train"
+    elif (
+        args.dataset_path in MMStarDataset.SUPPORTED_DATASET_PATHS
+        or args.hf_name in MMStarDataset.SUPPORTED_DATASET_PATHS
+    ):
+        dataset_class = MMStarDataset
+        args.hf_split = "val"
+        args.hf_subset = None
     else:
         supported_datasets = set([
             dataset_name for cls in HuggingFaceDataset.__subclasses__()
@@ -2721,3 +2728,76 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]:

     random.shuffle(requests)
     return requests
+
+
+# -----------------------------------------------------------------------------
+# MMStar Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class MMStarDataset(HuggingFaceDataset):
+    """
+    Lin-Chen/MMStar: https://huggingface.co/datasets/Lin-Chen/MMStar
+    refer to: https://github.com/sgl-project/SpecForge/pull/106
+    """
+    DEFAULT_OUTPUT_LEN = 128
+    SUPPORTED_DATASET_PATHS = {"Lin-Chen/MMStar"}
+    IS_MULTIMODAL = True
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        output_len: Optional[int] = None,
+        enable_multimodal_chat: bool = False,
+        request_id_prefix: str = "",
+        no_oversample: bool = False,
+        **kwargs,
+    ) -> list[SampleRequest]:
+        # If --hf-output-len is not set, use the default output length.
+        output_len = (output_len
+                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+        sampled_requests: list[SampleRequest] = []
+
+        for ind, item in enumerate(self.data):
+            if len(sampled_requests) >= num_requests:
+                break
+            # Split the question text from the options
+            # (keep only the part before "Options:").
+            full_q: str = item.get("question", "")
+            question_text = full_q.split("Options:", 1)[0].strip()
+
+            # Multimodal image content.
+            mm_content = process_image(item["image"])
+
+            # Compute prompt token length (note: this is plain-text length
+            # if enable_multimodal_chat is False).
+            prompt_len = len(tokenizer(question_text).input_ids)
+
+            if enable_multimodal_chat:
+                # If multimodal content should be embedded in the chat message,
+                # convert to [{"role": "user", "content": [...]}].
+                prompt = self.apply_multimodal_chat_transformation(
+                    question_text, mm_content
+                )
+                mm_for_request = None  # Already embedded in chat content.
+            else:
+                # Default: prompt is plain text,
+                # image is in mm_content for the bench to assemble.
+                prompt = question_text
+                mm_for_request = mm_content
+
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                    multi_modal_data=mm_for_request,
+                    request_id=request_id_prefix + str(ind),
+                )
+            )
+
+        self.maybe_oversample_requests(
+            sampled_requests, num_requests, request_id_prefix, no_oversample
+        )
+        return sampled_requests
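A hypothetical usage sketch of the new dataset class for local experimentation. The constructor keyword names follow the HuggingFaceDataset pattern used elsewhere in vllm/benchmarks/datasets.py and are assumptions, not verified against its actual __init__ signature.

# Hedged sketch: pull a handful of multimodal requests from MMStar.
# dataset_path / dataset_split keyword names are assumed from the
# HuggingFaceDataset convention in this file.
from transformers import AutoTokenizer
from vllm.benchmarks.datasets import MMStarDataset

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
dataset = MMStarDataset(dataset_path="Lin-Chen/MMStar", dataset_split="val")

requests = dataset.sample(tokenizer=tokenizer, num_requests=8)
for req in requests:
    # With enable_multimodal_chat=False (the default), the prompt is plain
    # text and the image travels separately in multi_modal_data.
    print(req.request_id, req.prompt_len, req.expected_output_len)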

vllm/model_executor/models/llama_eagle3.py

Lines changed: 19 additions & 8 deletions
@@ -8,7 +8,6 @@
 import torch.nn as nn
 from transformers import LlamaConfig

-from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -19,6 +18,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.llama import (LlamaDecoderLayer,
                                               LlamaForCausalLM)

@@ -102,7 +102,6 @@ def forward(
         return hidden_states, residual


-@support_torch_compile
 class LlamaModel(nn.Module):

     def __init__(
@@ -145,13 +144,21 @@ def __init__(
             eps=self.config.rms_norm_eps,
         )

+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        input_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        input_embeds = self.embed_tokens(input_ids)
+        if input_embeds is None:
+            input_embeds = self.get_input_embeddings(input_ids)
         assert hidden_states.shape[-1] == input_embeds.shape[-1]

         residual = None
@@ -239,11 +246,7 @@ def forward(
         hidden_states: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        if inputs_embeds is not None:
-            raise NotImplementedError(
-                f"{type(self).__name__} does not support multimodal inputs yet."
-            )
-        return self.model(input_ids, positions, hidden_states)
+        return self.model(input_ids, positions, hidden_states, inputs_embeds)

     def compute_logits(
         self,
@@ -299,3 +302,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
             skip_substrs=skip_substrs,
         )
         loader.load_weights(model_weights.items())
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.model.get_input_embeddings(input_ids)
+        return inputs_embeds

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 9 additions & 2 deletions
@@ -68,7 +68,7 @@
 from vllm.utils import is_pin_memory_available
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

-from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+from .interfaces import (MultiModalEmbeddings, SupportsEagle3, SupportsLoRA,
                          SupportsMultiModal, SupportsMultiModalPruning,
                          SupportsPP, SupportsQuant)
 from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
@@ -965,7 +965,7 @@ def get_replacement_qwen2vl(item_idx: int, modality: str):
                         dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
 class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                          SupportsLoRA, SupportsPP,
-                                         SupportsQuant,
+                                         SupportsQuant, SupportsEagle3,
                                          SupportsMultiModalPruning):

     packed_modules_mapping = {
@@ -1028,6 +1028,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)

+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.language_model.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.language_model.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
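As a small worked example of the default auxiliary-layer selection added above, assuming the Qwen2.5-VL-7B language model has 28 decoder layers (an assumption; check the checkpoint config):

# Hedged illustration of get_eagle3_aux_hidden_state_layers(): Eagle3 taps an
# early, a middle, and a late decoder layer for auxiliary hidden states.
num_layers = 28  # assumed depth of Qwen2.5-VL-7B's language model
print((2, num_layers // 2, num_layers - 3))  # -> (2, 14, 25)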

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -286,6 +286,7 @@
     "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
+    "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
     "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
