From 8874e1613c7ff0725819ff14265e188ed0ebf50e Mon Sep 17 00:00:00 2001 From: Junhong Date: Thu, 14 Aug 2025 20:19:22 +0800 Subject: [PATCH 01/32] [V1] support eagle and eagle3 for qwen2_5vl Signed-off-by: Junhong --- tests/models/registry.py | 6 +- tests/v1/e2e/test_spec_decode.py | 14 +- vllm/model_executor/models/qwen2_5_vl.py | 15 +- .../model_executor/models/qwen2_5_vl_eagle.py | 208 ++++++++++++ .../models/qwen2_5_vl_eagle3.py | 321 ++++++++++++++++++ vllm/model_executor/models/registry.py | 4 + vllm/v1/spec_decode/eagle.py | 19 +- 7 files changed, 579 insertions(+), 8 deletions(-) create mode 100644 vllm/model_executor/models/qwen2_5_vl_eagle.py create mode 100644 vllm/model_executor/models/qwen2_5_vl_eagle3.py diff --git a/tests/models/registry.py b/tests/models/registry.py index eb48c0f6a773..49b724fde66f 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -549,7 +549,11 @@ def check_available_online( is_available_online=False), "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True, - speculative_model="XiaomiMiMo/MiMo-7B-RL") + speculative_model="XiaomiMiMo/MiMo-7B-RL"), + "Eagle3Qwen2_5_VLForCausalLM": _HfExamplesInfo( + "Qwen/Qwen2.5-VL-7B-Instruct", + trust_remote_code=True, + speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl"), } _TRANSFORMERS_BACKEND_MODELS = { diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index dde95fbe590b..85b1f0e0f762 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -130,6 +130,8 @@ def test_ngram_correctness( [ # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 # (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), + (("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct", + "Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1), False), (("eagle", "meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", @@ -144,14 +146,21 @@ def test_ngram_correctness( "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), True, marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + pytest.param(("eagle", "Qwen/Qwen2.5-VL-7B-Instruct", + "Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1), + False, + marks=pytest.mark.skip( + reason="Skipping due to lack of eagle model")), ], ids=[ # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 # "qwen3_eagle3", + "qwen2.5_vl_eagle3", "llama3_eagle", "llama3_eagle3", "llama4_eagle", - "llama4_eagle_mm" + "llama4_eagle_mm", + "qwen2.5_vl_eagle" ]) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @@ -183,6 +192,9 @@ def test_eagle_correctness( method, model_name, spec_model_name, tp_size = model_setup + if "Qwen2.5-VL" in model_name and attn_backend == "TREE_ATTN": + pytest.skip("TREE ATTN not support Qwen2.5-VL Model yet") + print(f"model_setup={model_setup}") ref_llm = LLM(model=model_name, max_model_len=2048, tensor_parallel_size=tp_size) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 5bcbcc4f0e37..144ffd67a7ef 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -247,13 +247,14 @@ def __init__( self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, - _Backend.ROCM_AITER_FA + _Backend.ROCM_AITER_FA, 
_Backend.FLASH_ATTN_VLLM_V1 }: raise RuntimeError( f"Qwen2.5-VL does not support {self.attn_backend} backend now." ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA, + _Backend.FLASH_ATTN_VLLM_V1 } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -643,7 +644,8 @@ def compute_attn_mask_seqlen( ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None if (self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA): + or self.attn_backend == _Backend.ROCM_AITER_FA + or self.attn_backend == _Backend.FLASH_ATTN_VLLM_V1): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() @@ -864,6 +866,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) + def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None: + self.language_model.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]: + num_layers = len(self.language_model.model.layers) + return (2, num_layers // 2, num_layers - 3) + def _maybe_ignore_quant_config(self, config: Optional[QuantizationConfig]): # GPTQ configs do not have a list of ignored modules, however AutoGPTQ # seems to avoid vision encoder sections for some models. diff --git a/vllm/model_executor/models/qwen2_5_vl_eagle.py b/vllm/model_executor/models/qwen2_5_vl_eagle.py new file mode 100644 index 000000000000..342740fdff9a --- /dev/null +++ b/vllm/model_executor/models/qwen2_5_vl_eagle.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.qwen2 import (Qwen2DecoderLayer, + Qwen2ForCausalLM) +from vllm.sequence import IntermediateTensors + +from .interfaces import MultiModalEmbeddings +from .utils import (AutoWeightsLoader, PPMissingLayer, maybe_prefix, + merge_multimodal_embeddings) + +logger = init_logger(__name__) + + +@support_torch_compile +class Qwen2_5Model(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + start_layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = ( + vllm_config.speculative_config.draft_model_config.hf_config) + self.multimodal_config = (vllm_config.speculative_config. 
+ draft_model_config.multimodal_config) + # embbeding + if get_pp_group().is_first_rank or (self.config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + # language model initial + self.layers = nn.ModuleList([ + Qwen2DecoderLayer( + self.config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + ) for i in range(self.config.num_hidden_layers) + ]) + # Eagle feature fusion + self.fc = torch.nn.Linear(self.config.hidden_size * 2, + self.config.hidden_size, + bias=False) + self.norm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + # Eagle feature fusion + hidden_states = self.fc( + torch.cat((inputs_embeds, hidden_states), dim=-1)) + residual = None + for layer in self.layers: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states, hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + # name = name.removeprefix("model.") + # TODO :related to the trained model and may need to be modified + if (name.find("t2d") or name.find("d2t") + or name.find("hidden_norm")) and name not in params_dict: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # if PP disabled then draft will share embed with target + if get_pp_group().world_size == 1 and \ + "embed_tokens." 
in name: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + # TODO: train a suitable model + if name.startswith("fc"): + loaded_weight = loaded_weight[:, :self.config.hidden_size * + 2] + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class EagleQwen2_5_VLForCausalLM(Qwen2ForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + self.config = vllm_config.speculative_config.\ + draft_model_config.hf_config + self.multimodal_config = vllm_config.model_config.multimodal_config + + # The number of layers in the target model + # start_layer_id for the draft model + target_layer_num = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config) + # draft model quantization config may differ from target model + quant_config = VllmConfig.get_quantization_config( + vllm_config.speculative_config.draft_model_config, + vllm_config.load_config) + # Initialize the EAGLE model of QWEN2.5 + self.model = Qwen2_5Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "draft_model"), + start_layer_id=target_layer_num, + quant_config=quant_config) + + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.config.vocab_size, + scale=logit_scale) + + def load_weights(self, weights): + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."]), + ) + model_weights = {} + + for name, loaded_weight in weights: + if "lm_head" not in name: + name = "model." + name + model_weights[name] = loaded_weight + + loader.load_weights(model_weights.items()) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self.model(input_ids, positions, hidden_states, inputs_embeds) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.image_token_index) + return inputs_embeds diff --git a/vllm/model_executor/models/qwen2_5_vl_eagle3.py b/vllm/model_executor/models/qwen2_5_vl_eagle3.py new file mode 100644 index 000000000000..fa260ee293f4 --- /dev/null +++ b/vllm/model_executor/models/qwen2_5_vl_eagle3.py @@ -0,0 +1,321 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import Qwen2Config + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import QKVParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from 
vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.qwen2 import (Qwen2DecoderLayer, + Qwen2ForCausalLM) +from vllm.sequence import IntermediateTensors +from vllm.v1.sample.metadata import SamplingMetadata + +from .interfaces import MultiModalEmbeddings +from .utils import (AutoWeightsLoader, PPMissingLayer, maybe_prefix, + merge_multimodal_embeddings) + +logger = init_logger(__name__) + + +class Qwen2_5DecodeLayer(Qwen2DecoderLayer): + + def __init__( + self, + config: Qwen2Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, quant_config=quant_config, prefix=prefix) + + # override qkv + self.self_attn.qkv_proj = QKVParallelLinear( + 2 * self.hidden_size, + self.self_attn.head_dim, + self.self_attn.total_num_heads, + self.self_attn.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "qkv_proj"), + ) + # Add a normalization layer + self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + embeds: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + residual = hidden_states + embeds = self.input_layernorm(embeds) + hidden_states = self.hidden_norm(hidden_states) + # Reuse the target model's features + hidden_states = torch.cat([embeds, hidden_states], dim=-1) + + # Self Attention + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + # Fully Connected + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +@support_torch_compile +class Qwen2_5Model(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + start_layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = ( + vllm_config.speculative_config.draft_model_config.hf_config) + self.multimodal_config = (vllm_config.speculative_config. + draft_model_config.multimodal_config) + # embbeding + if get_pp_group().is_first_rank or (self.config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + # small language model initialization + self.layers = nn.ModuleList([ + Qwen2_5DecodeLayer( + config=self.config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + ) for i in range(self.config.num_hidden_layers) + ]) + # The EAGLE3 feature fusion layer needs to + # fuse high, medium, and low-level features. 
+ # Therefore, the input size is hidden_size * 3 + if hasattr(self.config, "target_hidden_size"): + self.fc = torch.nn.Linear(self.config.target_hidden_size * 3, + self.config.hidden_size, + bias=False) + else: + self.fc = torch.nn.Linear(self.config.hidden_size * 3, + self.config.hidden_size, + bias=False) + self.norm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + assert hidden_states.shape[-1] == inputs_embeds.shape[-1] + residual = None # No residual on the first layer + for layer in self.layers: + hidden_states, residual = layer( + positions, + inputs_embeds, + hidden_states, + residual, + ) + # Normalized features (hidden_states) + # original features (hidden_prenorm) + hidden_states, hidden_prenorm = self.norm(hidden_states, residual) + return hidden_states, hidden_prenorm + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if 'midlayer.' in name: + name = name.replace('midlayer.', 'layers.0.') + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Eagle3Qwen2_5_VLForCausalLM(Qwen2ForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + self.config = vllm_config.speculative_config.\ + draft_model_config.hf_config + self.multimodal_config = vllm_config.model_config.multimodal_config + + # The number of layers in the target model + # start_layer_id for the draft model + target_layer_num = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config) + # draft model quantization config may differ from target model + quant_config = VllmConfig.get_quantization_config( + vllm_config.speculative_config.draft_model_config, + vllm_config.load_config) + # Initialize the EAGLE model of QWEN2.5 + self.model = Qwen2_5Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "draft_model"), + start_layer_id=target_layer_num, + quant_config=quant_config) + + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.config.vocab_size, + scale=logit_scale) + # Establish a mapping relationship between + # the draft model vocabulary and the target model vocabulary. 
+ self.draft_id_to_target_id = nn.Parameter( + torch.zeros(self.config.draft_vocab_size, dtype=torch.long), + requires_grad=False, + ) + + self.lm_head = ParallelLMHead( + self.config.draft_vocab_size, + self.config.hidden_size, + org_num_embeddings=self.config.draft_vocab_size, + padding_size=(DEFAULT_VOCAB_PADDING_SIZE), + prefix="") + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + if self.draft_id_to_target_id is None: + assert logits.shape[1] == self.config.vocab_size, \ + "Expected logits to have shape " \ + f"(*, {self.config.vocab_size}), but got {logits.shape}" + return logits + + base = torch.arange(self.config.draft_vocab_size, device=logits.device) + # Mapping to the main model vocabulary space + targets = base + self.draft_id_to_target_id + logits_new = logits.new_full(( + logits.shape[0], + self.config.vocab_size, + ), float('-inf')) + logits_new[:, targets] = logits # Only valid positions are filled + return logits_new + + def combine_hidden_states( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + # combine multiple auxiliary hidden states returned by eagle3 + return self.model.fc(hidden_states) + + def load_weights(self, weights): + """ + Load weights + Not shared lm_head with target model + Skip t2d + """ + model_weights = {} + include_draft_id_mapping = False + include_embb_tokens_mapping = False + for name, loaded_weight in weights: + if "t2d" in name: + continue + if "d2t" in name: + name = name.replace("d2t", "draft_id_to_target_id") + include_draft_id_mapping = True + elif "lm_head" not in name: + name = "model." + name + if "embed_tokens" in name: + include_embb_tokens_mapping = True + model_weights[name] = loaded_weight + + skip_substrs = [] + if not include_draft_id_mapping: + skip_substrs.append("d2t") + if not include_embb_tokens_mapping: + skip_substrs.append("embed_tokens") + # Not shared lm_head with target model + loader = AutoWeightsLoader( + self, + skip_prefixes=None, + skip_substrs=skip_substrs, + ) + loader.load_weights(model_weights.items()) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self.model(input_ids, positions, hidden_states, inputs_embeds) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.image_token_index) + return inputs_embeds diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b817615b4356..6729716bc82d 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -262,6 +262,10 @@ "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), + "EagleQwen2_5_VLForCausalLM": ("qwen2_5_vl_eagle", \ + "EagleQwen2_5_VLForCausalLM"), + "Eagle3Qwen2_5_VLForCausalLM": 
("qwen2_5_vl_eagle3", \ + "Eagle3Qwen2_5_VLForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "MedusaModel": ("medusa", "Medusa"), diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index a8a160a0f995..65e33bc57ee6 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -17,6 +17,8 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.model_executor.models.qwen2_5_vl_eagle3 import ( + Eagle3Qwen2_5_VLForCausalLM) from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata @@ -140,7 +142,8 @@ def propose( last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 if self.method == "eagle3": - assert isinstance(self.model, Eagle3LlamaForCausalLM) + assert isinstance(self.model, \ + (Eagle3LlamaForCausalLM,Eagle3Qwen2_5_VLForCausalLM)) target_hidden_states = self.model.combine_hidden_states( target_hidden_states) assert target_hidden_states.shape[-1] == self.hidden_size @@ -589,6 +592,11 @@ def prepare_inputs( return spec_common_attn_metadata, token_indices + def get_model_name(self, model: nn.Module) -> str: + if hasattr(model, 'module'): # multi-GPU + model = model.module + return model.__class__.__name__ + def load_model(self, target_model: nn.Module) -> None: draft_model_config = \ self.vllm_config.speculative_config.draft_model_config @@ -608,8 +616,13 @@ def load_model(self, target_model: nn.Module) -> None: if supports_multimodal(target_model): # handle multimodality - self.model.config.image_token_index = ( - target_model.config.image_token_index) + if (self.get_model_name(target_model) == + "Qwen2_5_VLForConditionalGeneration"): + self.model.config.image_token_index = ( + target_model.config.image_token_id) + else: + self.model.config.image_token_index = ( + target_model.config.image_token_index) target_language_model = target_model.get_language_model() else: target_language_model = target_model From b242260726f08184c0ddce1c2554d2d59c52e824 Mon Sep 17 00:00:00 2001 From: Junhong Date: Thu, 21 Aug 2025 19:28:46 +0800 Subject: [PATCH 02/32] fix bug Signed-off-by: Junhong --- tests/v1/e2e/test_spec_decode.py | 1 - vllm/model_executor/models/qwen2_5_vl_eagle3.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 6491530595ba..462dc40008b4 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -200,7 +200,6 @@ def test_eagle_correctness( if "Qwen2.5-VL" in model_name and attn_backend == "TREE_ATTN": pytest.skip("TREE ATTN not support Qwen2.5-VL Model yet") - print(f"model_setup={model_setup}") ref_llm = LLM(model=model_name, max_model_len=2048, tensor_parallel_size=tp_size) diff --git a/vllm/model_executor/models/qwen2_5_vl_eagle3.py b/vllm/model_executor/models/qwen2_5_vl_eagle3.py index fa260ee293f4..8b6c5dc65d19 100644 --- a/vllm/model_executor/models/qwen2_5_vl_eagle3.py +++ b/vllm/model_executor/models/qwen2_5_vl_eagle3.py @@ -217,7 +217,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config=quant_config) logit_scale = getattr(self.config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.config.vocab_size, + self.logits_processor = 
LogitsProcessor(self.config.draft_vocab_size, scale=logit_scale) # Establish a mapping relationship between # the draft model vocabulary and the target model vocabulary. From 6b682de11a394d32e6d27f5d85c9f2581ca9933d Mon Sep 17 00:00:00 2001 From: Junhong Date: Thu, 21 Aug 2025 21:31:22 +0800 Subject: [PATCH 03/32] support M-RoPE in eagle Signed-off-by: Junhong --- vllm/v1/spec_decode/eagle.py | 215 +++++++++++++++++++++-------- vllm/v1/worker/gpu_model_runner.py | 14 +- 2 files changed, 171 insertions(+), 58 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 65e33bc57ee6..927f9b4c3359 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -77,9 +77,18 @@ def __init__( self.input_ids = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=device) - self.positions = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=device) + # M-RoPE + self.uses_mrope = self.vllm_config.model_config.uses_mrope + if self.uses_mrope: + # M-RoPE need (3, max_num_tokens) + self.positions = torch.zeros((3, self.max_num_tokens), + dtype=torch.int64, + device=device) + else: + # RoPE need (max_num_tokens,) + self.positions = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device=device) self.hidden_states = torch.zeros( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, @@ -173,7 +182,11 @@ def propose( else: num_input_tokens = num_tokens # copy inputs to buffer for cudagraph - self.positions[:num_tokens] = target_positions + # M-RoPE + if self.uses_mrope: + self.positions[:, :num_tokens] = target_positions + else: + self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states if self.is_multimodal_model: input_ids = self.input_ids[:num_tokens] @@ -191,9 +204,14 @@ def propose( with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens): + # M-RoPE + if self.uses_mrope: + forward_positions = self.positions[:, :num_input_tokens] + else: + forward_positions = self.positions[:num_input_tokens] ret_hidden_states = self.model( input_ids=input_ids, - positions=self.positions[:num_input_tokens], + positions=forward_positions, hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=inputs_embeds, ) @@ -203,7 +221,11 @@ def propose( last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) - positions = target_positions[last_token_indices] + # M-RoPE + if self.uses_mrope: + positions = target_positions[:, last_token_indices] + else: + positions = target_positions[last_token_indices] hidden_states = hidden_states[last_token_indices] if isinstance(attn_metadata, TreeAttentionMetadata): @@ -258,19 +280,25 @@ def propose( # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. input_ids = draft_token_ids_list[-1].int() - positions += 1 - - # NOTE(woosuk): We should handle the case where the draft model - # generates tokens beyond the max model length. Since it is complex - # to remove such requests from the batch, we keep them in the batch - # but adjust the position ids and slot mappings to avoid the - # out-of-range access during the model execution. The draft tokens - # generated with this adjustment should be ignored. - exceeds_max_model_len = positions >= self.max_model_len - # Mask out the position ids that exceed the max model length. 
- # Otherwise, we may get out-of-range error in RoPE. - clamped_positions = torch.where(exceeds_max_model_len, 0, - positions) + # M-RoPE + if self.uses_mrope: + positions += 1 + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. Since it is complex + # to remove such requests from the batch, we keep them in the batch + # but adjust the position ids and slot mappings to avoid the + # out-of-range access during the model execution. The draft tokens + # generated with this adjustment should be ignored. + exceeds_max_model_len = positions[0] >= self.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_positions = torch.where\ + (exceeds_max_model_len.unsqueeze(0), 0,positions) + else: + positions += 1 + exceeds_max_model_len = positions >= self.max_model_len + clamped_positions = torch.where(exceeds_max_model_len, 0, + positions) # Increment the sequence lengths. attn_metadata.max_seq_len += 1 @@ -281,14 +309,23 @@ def propose( # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) - # Compute the slot mapping. - block_numbers = clamped_positions // self.block_size + # M-RoPE + if self.uses_mrope: + # all dimensions of positions are the same + block_numbers = clamped_positions[0] // self.block_size + else: + block_numbers = clamped_positions // self.block_size block_ids = attn_metadata.block_table.gather( dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) - attn_metadata.slot_mapping = (block_ids * self.block_size + - clamped_positions % self.block_size) + # M-RoPE + if self.uses_mrope: + attn_metadata.slot_mapping = (block_ids * self.block_size + + clamped_positions[0] % self.block_size) + else: + attn_metadata.slot_mapping = (block_ids * self.block_size + + clamped_positions % self.block_size) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. 
@@ -297,7 +334,11 @@ def propose( # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids - self.positions[:batch_size] = clamped_positions + # M-RoPE + if self.uses_mrope: + self.positions[:,:batch_size] = clamped_positions[0] + else: + self.positions[:batch_size] = clamped_positions self.hidden_states[:batch_size] = hidden_states if self.is_multimodal_model: inputs_embeds = self.model.get_input_embeddings(input_ids) @@ -312,9 +353,14 @@ def propose( with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size): + # M-RoPE + if self.uses_mrope: + forward_positions = self.positions[:, :input_batch_size] + else: + forward_positions = self.positions[:input_batch_size] last_hidden_states, hidden_states = self.model( input_ids=input_ids, - positions=self.positions[:input_batch_size], + positions=forward_positions, hidden_states=self.hidden_states[:input_batch_size], inputs_embeds=inputs_embeds, ) @@ -360,34 +406,60 @@ def propose_tree( tree_input_ids = torch.empty(0, device=self.input_ids.device, dtype=self.input_ids.dtype) - tree_positions = torch.empty(0, - device=self.positions.device, - dtype=self.positions.dtype) + # M-RoPE + if self.uses_mrope: + tree_positions = torch.empty((3, 0), + device=self.positions.device, + dtype=self.positions.dtype) + assert positions.dim() == 3 + # Precompute the draft token positions. + flattened_draft_positions = ( + positions.view(3, batch_size, -1) + + self.tree_draft_pos_offsets[:batch_size, :]).unsqueeze(0) + + else: + tree_positions = torch.empty(0, + device=self.positions.device, + dtype=self.positions.dtype) + # Precompute the draft token positions. + flattened_draft_positions = ( + positions.view(batch_size, -1) + + self.tree_draft_pos_offsets[:batch_size, :]) tree_hidden_states = torch.empty(0, device=self.hidden_states.device, dtype=self.hidden_states.dtype) - # Precompute the draft token positions. - flattened_draft_positions = ( - positions.view(batch_size, -1) + - self.tree_draft_pos_offsets[:batch_size, :]) tree_depth = len(self.cu_drafts_per_level) for level in range(tree_depth - 1): - # Get draft positions for RoPE. - draft_positions = positions + (level + 1) - exceeds_max_model_len = (positions + - total_num_drafts) >= self.max_model_len - # Mask out the position ids that exceed the max model length. - # Otherwise, we may get out-of-range error in RoPE. - draft_positions = torch.where( - exceeds_max_model_len, - 0, - draft_positions, - ).view(batch_size, -1) + # M-RoPE + if self.uses_mrope: + # Get draft positions for RoPE + draft_positions = positions + (level + 1) + exceeds_max_model_len = (positions[0] + total_num_drafts) >= self.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + draft_positions = torch.where( + exceeds_max_model_len.unsqueeze(0), + 0, + draft_positions, + ).view(3, batch_size, -1) + else: + draft_positions = positions + (level + 1) + exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len + draft_positions = torch.where( + exceeds_max_model_len, + 0, + draft_positions, + ).view(batch_size, -1) if level_num_drafts > 1: # Repeat the positions for each draft at this level. 
- draft_positions = draft_positions.repeat_interleave( - level_num_drafts, dim=1) + # M-RoPE + if self.uses_mrope: + draft_positions = draft_positions.repeat_interleave( + level_num_drafts, dim=2) + else: + draft_positions = draft_positions.repeat_interleave( + level_num_drafts, dim=1) if num_children > 1: # Repeat draft hidden states for each child. @@ -397,8 +469,13 @@ def propose_tree( # Concatenate the draft tokens, positions, and hidden states. tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1) - tree_positions = torch.cat([tree_positions, draft_positions], - dim=1) + # M-RoPE + if self.uses_mrope: + tree_positions = torch.cat([tree_positions, draft_positions.view(3, -1)], + dim=1) + else: + tree_positions = torch.cat([tree_positions, draft_positions.view(-1)], + dim=1) tree_hidden_states = torch.cat( [tree_hidden_states, draft_hidden_states], dim=1) @@ -430,13 +507,23 @@ def propose_tree( attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) # Compute the slot mapping. - query_positions = flattened_draft_positions[:, level:level + - query_len] - block_numbers = query_positions // self.block_size - block_ids = attn_metadata.block_table.gather(dim=1, + # M-RoPE + if self.uses_mrope: + query_positions = \ + flattened_draft_positions[:, :, level:level + query_len] + block_numbers = query_positions[0] // self.block_size + block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers) - slot_mapping = (block_ids * self.block_size + - query_positions % self.block_size) + slot_mapping = (block_ids * self.block_size + + query_positions[0] % self.block_size) + else: + query_positions = flattened_draft_positions[:, level:level + + query_len] + block_numbers = query_positions // self.block_size + block_ids = attn_metadata.block_table.gather(dim=1, + index=block_numbers) + slot_mapping = (block_ids * self.block_size + + query_positions % self.block_size) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. 
@@ -447,7 +534,12 @@ def propose_tree( num_tokens = attn_metadata.num_actual_tokens input_ids = tree_input_ids.view(-1) self.input_ids[:num_tokens] = input_ids - self.positions[:num_tokens] = tree_positions.view(-1) + # M-RoPE + if self.uses_mrope: + self.positions[:, :num_tokens] = \ + tree_positions.view(3, num_tokens) + else: + self.positions[:num_tokens] = tree_positions.view(-1) self.hidden_states[:num_tokens] = tree_hidden_states.view( num_tokens, -1) @@ -461,9 +553,14 @@ def propose_tree( with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens): + # M-RoPE + if self.uses_mrope: + forward_positions = self.positions[:, :num_input_tokens] + else: + forward_positions = self.positions[:num_input_tokens] last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], + positions=forward_positions, hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=None, ) @@ -658,6 +755,12 @@ def dummy_run( ) -> None: with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): + # M-RoPE + if self.uses_mrope: + forward_positions = self.positions[:, :num_tokens] + else: + forward_positions = self.positions[:num_tokens] + if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -667,7 +770,7 @@ def dummy_run( self.model( input_ids=input_ids, - positions=self.positions[:num_tokens], + positions=forward_positions, hidden_states=self.hidden_states[:num_tokens], inputs_embeds=inputs_embeds, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8fb9641844fb..cd70eaec58ed 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1845,7 +1845,12 @@ def propose_draft_token_ids( # input_ids can be None for multimodal models. target_token_ids = self.input_ids[:num_scheduled_tokens] # TODO(woosuk): Support M-RoPE. - target_positions = self.positions[:num_scheduled_tokens] + # M-RoPE + if self.uses_mrope: + target_positions = \ + self.mrope_positions[:, :num_scheduled_tokens] + else: + target_positions = self.positions[:num_scheduled_tokens] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[:num_scheduled_tokens] for h in aux_hidden_states], @@ -1867,7 +1872,12 @@ def propose_draft_token_ids( target_token_ids = self.input_ids[token_indices] # TODO(woosuk): Support M-RoPE. 
- target_positions = self.positions[token_indices] + # M-RoPE + if self.uses_mrope: + target_positions = \ + self.mrope_positions[:, :token_indices] + else: + target_positions = self.positions[:token_indices] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[token_indices] for h in aux_hidden_states], dim=-1) From 475ce2bfd4b8b80e77ba5a41a7746f315efb6342 Mon Sep 17 00:00:00 2001 From: Junhong Date: Thu, 21 Aug 2025 22:50:01 +0800 Subject: [PATCH 04/32] fix bug for SupportsEagle3 and graph compile Signed-off-by: Junhong --- vllm/model_executor/models/qwen2_5_vl.py | 4 ++-- vllm/model_executor/models/qwen2_5_vl_eagle3.py | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 144ffd67a7ef..3f1b5a5997e8 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -61,7 +61,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsEagle3, SupportsMultiModal, SupportsPP, SupportsQuant) from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, @@ -815,7 +815,7 @@ def _get_mm_fields_config( dummy_inputs=Qwen2_5_VLDummyInputsBuilder) class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, - SupportsQuant): + SupportsQuant, SupportsEagle3): # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( diff --git a/vllm/model_executor/models/qwen2_5_vl_eagle3.py b/vllm/model_executor/models/qwen2_5_vl_eagle3.py index 8b6c5dc65d19..0513c9013f97 100644 --- a/vllm/model_executor/models/qwen2_5_vl_eagle3.py +++ b/vllm/model_executor/models/qwen2_5_vl_eagle3.py @@ -82,7 +82,14 @@ def forward( return hidden_states, residual -@support_torch_compile +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, + "inputs_embeds": 0, + }) class Qwen2_5Model(nn.Module): def __init__( From 7df533ca907cd38d698baa01e54dc4fdf7de6478 Mon Sep 17 00:00:00 2001 From: Junhong Date: Thu, 21 Aug 2025 23:33:31 +0800 Subject: [PATCH 05/32] fix bug Signed-off-by: Junhong --- vllm/v1/spec_decode/eagle.py | 27 ++++++++++++++------------- vllm/v1/worker/gpu_model_runner.py | 4 ++-- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 927f9b4c3359..095f2fd6afea 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -136,7 +136,7 @@ def propose( self, # [num_tokens] target_token_ids: torch.Tensor, - # [num_tokens] + # [num_tokens] or [3, num_tokens] when M-RoPE is enabled target_positions: torch.Tensor, # [num_tokens, hidden_size] target_hidden_states: torch.Tensor, @@ -293,7 +293,8 @@ def propose( # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. 
clamped_positions = torch.where\ - (exceeds_max_model_len.unsqueeze(0), 0,positions) + (exceeds_max_model_len.unsqueeze(0), \ + torch.zeros_like(positions), positions) else: positions += 1 exceeds_max_model_len = positions >= self.max_model_len @@ -336,7 +337,7 @@ def propose( self.input_ids[:batch_size] = input_ids # M-RoPE if self.uses_mrope: - self.positions[:,:batch_size] = clamped_positions[0] + self.positions[:,:batch_size] = clamped_positions else: self.positions[:batch_size] = clamped_positions self.hidden_states[:batch_size] = hidden_states @@ -411,16 +412,16 @@ def propose_tree( tree_positions = torch.empty((3, 0), device=self.positions.device, dtype=self.positions.dtype) - assert positions.dim() == 3 - # Precompute the draft token positions. + assert positions.dim() == 3 or 2 + # Precompute the draft token positions. -> (3, B, L) flattened_draft_positions = ( - positions.view(3, batch_size, -1) + - self.tree_draft_pos_offsets[:batch_size, :]).unsqueeze(0) + positions.view(3, batch_size, 1) + + self.tree_draft_pos_offsets[:batch_size, :].unsqueeze(0)) else: - tree_positions = torch.empty(0, - device=self.positions.device, - dtype=self.positions.dtype) + tree_positions = torch.empty((batch_size, 0), + device=self.positions.device, + dtype=self.positions.dtype) # Precompute the draft token positions. flattened_draft_positions = ( positions.view(batch_size, -1) + @@ -474,7 +475,7 @@ def propose_tree( tree_positions = torch.cat([tree_positions, draft_positions.view(3, -1)], dim=1) else: - tree_positions = torch.cat([tree_positions, draft_positions.view(-1)], + tree_positions = torch.cat([tree_positions, draft_positions], dim=1) tree_hidden_states = torch.cat( [tree_hidden_states, draft_hidden_states], dim=1) @@ -509,8 +510,8 @@ def propose_tree( # Compute the slot mapping. # M-RoPE if self.uses_mrope: - query_positions = \ - flattened_draft_positions[:, :, level:level + query_len] + query_positions = flattened_draft_positions[:, :, \ + level:level + query_len] block_numbers = query_positions[0] // self.block_size block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cd70eaec58ed..a9eb244dc830 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1875,9 +1875,9 @@ def propose_draft_token_ids( # M-RoPE if self.uses_mrope: target_positions = \ - self.mrope_positions[:, :token_indices] + self.mrope_positions[:, token_indices] else: - target_positions = self.positions[:token_indices] + target_positions = self.positions[token_indices] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[token_indices] for h in aux_hidden_states], dim=-1) From ce18a917f9e4de527c9c7ef7818583ce2fd31186 Mon Sep 17 00:00:00 2001 From: Junhong Date: Mon, 25 Aug 2025 19:46:26 +0800 Subject: [PATCH 06/32] fix bug Signed-off-by: Junhong --- vllm/v1/spec_decode/eagle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 61152cde21f6..14f0774591ce 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -424,7 +424,7 @@ def propose_tree( tree_positions = torch.empty((3, 0), device=self.positions.device, dtype=self.positions.dtype) - assert positions.dim() == 3 or 2 + assert positions.dim() in (2, 3) # Precompute the draft token positions. 
-> (3, B, L) flattened_draft_positions = ( positions.view(3, batch_size, 1) + From bfde2ca1e794911600ddad97dad92f6d978be888 Mon Sep 17 00:00:00 2001 From: Junhong Date: Tue, 2 Sep 2025 19:42:28 +0800 Subject: [PATCH 07/32] optimize code Signed-off-by: Junhong --- vllm/v1/spec_decode/eagle.py | 69 ++++++++++++++++++++---------------- 1 file changed, 39 insertions(+), 30 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 14f0774591ce..2d5ace41106f 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -92,8 +92,8 @@ def __init__( if self.uses_mrope: # M-RoPE need (3, max_num_tokens) self.positions = torch.zeros((3, self.max_num_tokens), - dtype=torch.int64, - device=device) + dtype=torch.int64, + device=device) else: # RoPE need (max_num_tokens,) self.positions = torch.zeros(self.max_num_tokens, @@ -296,11 +296,13 @@ def propose( if self.uses_mrope: positions += 1 # NOTE(woosuk): We should handle the case where the draft model - # generates tokens beyond the max model length. Since it is complex - # to remove such requests from the batch, we keep them in the batch - # but adjust the position ids and slot mappings to avoid the - # out-of-range access during the model execution. The draft tokens - # generated with this adjustment should be ignored. + # generates tokens beyond the max model length. + # Since it is complex to remove such requests from the batch, + # we keep them in the batch but adjust the position ids + # and slot mappings to avoid the + # out-of-range access during the model execution. + # The draft tokens generated with this adjustment + # should be ignored. exceeds_max_model_len = positions[0] >= self.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. @@ -334,11 +336,13 @@ def propose( block_ids = block_ids.view(-1) # M-RoPE if self.uses_mrope: - attn_metadata.slot_mapping = (block_ids * self.block_size + - clamped_positions[0] % self.block_size) + attn_metadata.slot_mapping = ( + block_ids * self.block_size + + clamped_positions[0] % self.block_size) else: - attn_metadata.slot_mapping = (block_ids * self.block_size + - clamped_positions % self.block_size) + attn_metadata.slot_mapping = ( + block_ids * self.block_size + + clamped_positions % self.block_size) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. @@ -416,7 +420,8 @@ def propose_tree( draft_hidden_states = hidden_states.view(batch_size, 1, -1) # Initialize empty tensors for concatenation with the level outputs. - tree_input_ids = torch.empty(0, + # M-RoPE + tree_input_ids = torch.empty((batch_size, 0), device=self.input_ids.device, dtype=self.input_ids.dtype) # M-RoPE @@ -429,16 +434,15 @@ def propose_tree( flattened_draft_positions = ( positions.view(3, batch_size, 1) + self.tree_draft_pos_offsets[:batch_size, :].unsqueeze(0)) - else: tree_positions = torch.empty((batch_size, 0), - device=self.positions.device, - dtype=self.positions.dtype) + device=self.positions.device, + dtype=self.positions.dtype) # Precompute the draft token positions. 
flattened_draft_positions = ( positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]) - tree_hidden_states = torch.empty(0, + tree_hidden_states = torch.empty((batch_size, 0, self.hidden_size), device=self.hidden_states.device, dtype=self.hidden_states.dtype) tree_depth = len(self.cu_drafts_per_level) @@ -447,7 +451,8 @@ def propose_tree( if self.uses_mrope: # Get draft positions for RoPE draft_positions = positions + (level + 1) - exceeds_max_model_len = (positions[0] + total_num_drafts) >= self.max_model_len + exceeds_max_model_len = ( + positions[0] + total_num_drafts) >= self.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. draft_positions = torch.where( @@ -457,7 +462,8 @@ def propose_tree( ).view(3, batch_size, -1) else: draft_positions = positions + (level + 1) - exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len + exceeds_max_model_len = ( + positions + total_num_drafts) >= self.max_model_len draft_positions = torch.where( exceeds_max_model_len, 0, @@ -484,11 +490,12 @@ def propose_tree( dim=1) # M-RoPE if self.uses_mrope: - tree_positions = torch.cat([tree_positions, draft_positions.view(3, -1)], - dim=1) + tree_positions = torch.cat( + [tree_positions, + draft_positions.view(3, -1)], dim=1) else: tree_positions = torch.cat([tree_positions, draft_positions], - dim=1) + dim=1) tree_hidden_states = torch.cat( [tree_hidden_states, draft_hidden_states], dim=1) @@ -525,16 +532,16 @@ def propose_tree( query_positions = flattened_draft_positions[:, :, \ level:level + query_len] block_numbers = query_positions[0] // self.block_size - block_ids = attn_metadata.block_table.gather(dim=1, - index=block_numbers) + block_ids = attn_metadata.block_table.gather( + dim=1, index=block_numbers) slot_mapping = (block_ids * self.block_size + - query_positions[0] % self.block_size) + query_positions[0] % self.block_size) else: query_positions = flattened_draft_positions[:, level:level + query_len] block_numbers = query_positions // self.block_size - block_ids = attn_metadata.block_table.gather(dim=1, - index=block_numbers) + block_ids = attn_metadata.block_table.gather( + dim=1, index=block_numbers) slot_mapping = (block_ids * self.block_size + query_positions % self.block_size) # Mask out the slot mappings that exceed the max model length. @@ -545,16 +552,18 @@ def propose_tree( # Copy inputs to buffer for cudagraph. 
num_tokens = attn_metadata.num_actual_tokens - input_ids = tree_input_ids.view(-1) + # M-RoPE + input_ids = tree_input_ids.view(-1)[-num_tokens:] # [B*qlen] self.input_ids[:num_tokens] = input_ids # M-RoPE if self.uses_mrope: self.positions[:, :num_tokens] = \ - tree_positions.view(3, num_tokens) + tree_positions.view(3, -1)[:,-num_tokens:] else: - self.positions[:num_tokens] = tree_positions.view(-1) + self.positions[:num_tokens] = tree_positions.view( + -1)[-num_tokens:] self.hidden_states[:num_tokens] = tree_hidden_states.view( - num_tokens, -1) + num_tokens, -1)[-num_tokens:, :] if self.use_cuda_graph and \ num_tokens <= self.cudagraph_batch_sizes[-1]: From b0f2181319e664472b06c6475ab10bdea0c85121 Mon Sep 17 00:00:00 2001 From: Junhong Date: Thu, 4 Sep 2025 00:27:58 +0800 Subject: [PATCH 08/32] llama_eagle3 support mm Signed-off-by: Junhong --- vllm/model_executor/models/llama_eagle3.py | 41 +++++++++++++++++----- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 572930c39a84..f5b67838e60e 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -19,11 +19,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaForCausalLM) from vllm.v1.sample.metadata import SamplingMetadata -from .utils import AutoWeightsLoader, maybe_prefix +from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings logger = init_logger(__name__) @@ -99,7 +100,14 @@ def forward( return hidden_states, residual -@support_torch_compile +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, + "inputs_embeds": 0, + }) class LlamaModel(nn.Module): def __init__( @@ -139,13 +147,21 @@ def __init__( eps=self.config.rms_norm_eps, ) + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, + input_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - input_embeds = self.embed_tokens(input_ids) + if input_embeds is None: + input_embeds = self.get_input_embeddings(input_ids) assert hidden_states.shape[-1] == input_embeds.shape[-1] residual = None @@ -224,11 +240,7 @@ def forward( hidden_states: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - if inputs_embeds is not None: - raise NotImplementedError( - f"{type(self).__name__} does not support multimodal inputs yet." 
- ) - return self.model(input_ids, positions, hidden_states) + return self.model(input_ids, positions, hidden_states, inputs_embeds) def compute_logits( self, @@ -286,3 +298,16 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): skip_substrs=skip_substrs, ) loader.load_weights(model_weights.items()) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.image_token_index) + return inputs_embeds From 9499db47b4b235c4c6516e41c288bbdd84fbbe3a Mon Sep 17 00:00:00 2001 From: Junhong Date: Mon, 22 Sep 2025 14:57:16 +0800 Subject: [PATCH 09/32] [llama_eagle3] delete support_torch_compile Signed-off-by: Junhong --- vllm/model_executor/models/llama_eagle3.py | 8 -------- vllm/v1/spec_decode/eagle.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 3 ++- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index f5b67838e60e..c1588d5f50cf 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -100,14 +100,6 @@ def forward( return hidden_states, residual -@support_torch_compile( - dynamic_arg_dims={ - "input_ids": 0, - # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, - # otherwise (seq_len, ). - "positions": -1, - "inputs_embeds": 0, - }) class LlamaModel(nn.Module): def __init__( diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 22e1af8c8e90..a9d6fd814d1c 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -354,7 +354,7 @@ def propose( self.input_ids[:batch_size] = input_ids # M-RoPE if self.uses_mrope: - self.positions[:,:batch_size] = clamped_positions + self.positions[:, :batch_size] = clamped_positions else: self.positions[:batch_size] = clamped_positions self.hidden_states[:batch_size] = hidden_states diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d82676ef7ebb..bf3383a0daf0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1829,7 +1829,8 @@ def propose_draft_token_ids( target_positions = \ self.mrope_positions.gpu[:, :num_scheduled_tokens] else: - target_positions = self.positions.gpu[:num_scheduled_tokens] + target_positions = self.positions.gpu[: + num_scheduled_tokens] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[:num_scheduled_tokens] for h in aux_hidden_states], From 15f24d2cc3c05c06935ccfab8229018a114bc35b Mon Sep 17 00:00:00 2001 From: Junhong Date: Tue, 23 Sep 2025 11:13:48 +0800 Subject: [PATCH 10/32] [llama_eagle3] delete get_input_embeddings Signed-off-by: Junhong --- vllm/model_executor/models/llama_eagle3.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index c1588d5f50cf..87347570b738 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -291,15 +291,3 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): ) loader.load_weights(model_weights.items()) - def get_input_embeddings( - self, - input_ids: torch.Tensor, 
- multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.image_token_index) - return inputs_embeds From ef445069b9178e079287c3103e7f1dfb194df967 Mon Sep 17 00:00:00 2001 From: Junhong Date: Tue, 23 Sep 2025 11:25:16 +0800 Subject: [PATCH 11/32] [llama_eagle3] fix get_input_embeddings Signed-off-by: Junhong --- vllm/model_executor/models/llama_eagle3.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 87347570b738..8d8aee13d54b 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -291,3 +291,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): ) loader.load_weights(model_weights.items()) + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + return inputs_embeds From 55f66b7984dc7f63c16b995438940a5ad9e38941 Mon Sep 17 00:00:00 2001 From: Junhong Date: Tue, 23 Sep 2025 16:36:32 +0800 Subject: [PATCH 12/32] [fix bug] delete duplicated code Signed-off-by: Junhong --- tests/models/registry.py | 2 +- tests/v1/e2e/test_spec_decode.py | 6 - .../model_executor/models/qwen2_5_vl_eagle.py | 208 ----------- .../models/qwen2_5_vl_eagle3.py | 328 ------------------ vllm/model_executor/models/registry.py | 4 - vllm/v1/spec_decode/eagle.py | 131 ++----- 6 files changed, 37 insertions(+), 642 deletions(-) delete mode 100644 vllm/model_executor/models/qwen2_5_vl_eagle.py delete mode 100644 vllm/model_executor/models/qwen2_5_vl_eagle3.py diff --git a/tests/models/registry.py b/tests/models/registry.py index efc51a2f50e2..10586e1f2940 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -593,7 +593,7 @@ def check_available_online( "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True, speculative_model="XiaomiMiMo/MiMo-7B-RL"), - "Eagle3Qwen2_5_VLForCausalLM": _HfExamplesInfo( + "Eagle3LlamaForCausalLM": _HfExamplesInfo( "Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=True, speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl"), diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index faf9da9617fe..5cd87be0774c 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -148,11 +148,6 @@ def test_ngram_correctness( marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), (("eagle", "eagle618/deepseek-v3-random", "eagle618/eagle-deepseek-v3-random", 1), False), - pytest.param(("eagle", "Qwen/Qwen2.5-VL-7B-Instruct", - "Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1), - False, - marks=pytest.mark.skip( - reason="Skipping due to lack of eagle model")), ], ids=[ # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 @@ -162,7 +157,6 @@ def test_ngram_correctness( "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm", - "qwen2.5_vl_eagle", "deepseek_eagle" ]) @pytest.mark.parametrize("attn_backend", diff --git a/vllm/model_executor/models/qwen2_5_vl_eagle.py b/vllm/model_executor/models/qwen2_5_vl_eagle.py deleted file mode 100644 index 342740fdff9a..000000000000 
--- a/vllm/model_executor/models/qwen2_5_vl_eagle.py +++ /dev/null @@ -1,208 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from collections.abc import Iterable -from typing import Optional - -import torch -import torch.nn as nn - -from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig -from vllm.distributed import get_pp_group -from vllm.logger import init_logger -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.qwen2 import (Qwen2DecoderLayer, - Qwen2ForCausalLM) -from vllm.sequence import IntermediateTensors - -from .interfaces import MultiModalEmbeddings -from .utils import (AutoWeightsLoader, PPMissingLayer, maybe_prefix, - merge_multimodal_embeddings) - -logger = init_logger(__name__) - - -@support_torch_compile -class Qwen2_5Model(nn.Module): - - def __init__( - self, - *, - vllm_config: VllmConfig, - prefix: str = "", - start_layer_id: int = 0, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.config = ( - vllm_config.speculative_config.draft_model_config.hf_config) - self.multimodal_config = (vllm_config.speculative_config. - draft_model_config.multimodal_config) - # embbeding - if get_pp_group().is_first_rank or (self.config.tie_word_embeddings - and get_pp_group().is_last_rank): - self.embed_tokens = VocabParallelEmbedding( - self.config.vocab_size, - self.config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens", - ) - else: - self.embed_tokens = PPMissingLayer() - - # language model initial - self.layers = nn.ModuleList([ - Qwen2DecoderLayer( - self.config, - quant_config=quant_config, - prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), - ) for i in range(self.config.num_hidden_layers) - ]) - # Eagle feature fusion - self.fc = torch.nn.Linear(self.config.hidden_size * 2, - self.config.hidden_size, - bias=False) - self.norm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings(input_ids) - # Eagle feature fusion - hidden_states = self.fc( - torch.cat((inputs_embeds, hidden_states), dim=-1)) - residual = None - for layer in self.layers: - hidden_states, residual = layer( - positions, - hidden_states, - residual, - ) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states, hidden_states - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = 
dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - # name = name.removeprefix("model.") - # TODO :related to the trained model and may need to be modified - if (name.find("t2d") or name.find("d2t") - or name.find("hidden_norm")) and name not in params_dict: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # if PP disabled then draft will share embed with target - if get_pp_group().world_size == 1 and \ - "embed_tokens." in name: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - # TODO: train a suitable model - if name.startswith("fc"): - loaded_weight = loaded_weight[:, :self.config.hidden_size * - 2] - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class EagleQwen2_5_VLForCausalLM(Qwen2ForCausalLM): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - nn.Module.__init__(self) - self.config = vllm_config.speculative_config.\ - draft_model_config.hf_config - self.multimodal_config = vllm_config.model_config.multimodal_config - - # The number of layers in the target model - # start_layer_id for the draft model - target_layer_num = vllm_config.model_config.get_num_layers( - vllm_config.parallel_config) - # draft model quantization config may differ from target model - quant_config = VllmConfig.get_quantization_config( - vllm_config.speculative_config.draft_model_config, - vllm_config.load_config) - # Initialize the EAGLE model of QWEN2.5 - self.model = Qwen2_5Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "draft_model"), - start_layer_id=target_layer_num, - quant_config=quant_config) - - logit_scale = getattr(self.config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.config.vocab_size, - scale=logit_scale) - - def load_weights(self, weights): - loader = AutoWeightsLoader( - self, - skip_prefixes=(["lm_head."]), - ) - model_weights = {} - - for name, loaded_weight in weights: - if "lm_head" not in name: - name = "model." 
+ name - model_weights[name] = loaded_weight - - loader.load_weights(model_weights.items()) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object, - ) -> tuple[torch.Tensor, torch.Tensor]: - return self.model(input_ids, positions, hidden_states, inputs_embeds) - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.image_token_index) - return inputs_embeds diff --git a/vllm/model_executor/models/qwen2_5_vl_eagle3.py b/vllm/model_executor/models/qwen2_5_vl_eagle3.py deleted file mode 100644 index 0513c9013f97..000000000000 --- a/vllm/model_executor/models/qwen2_5_vl_eagle3.py +++ /dev/null @@ -1,328 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from collections.abc import Iterable -from typing import Optional - -import torch -import torch.nn as nn -from transformers import Qwen2Config - -from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig -from vllm.distributed import get_pp_group -from vllm.logger import init_logger -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import QKVParallelLinear -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.qwen2 import (Qwen2DecoderLayer, - Qwen2ForCausalLM) -from vllm.sequence import IntermediateTensors -from vllm.v1.sample.metadata import SamplingMetadata - -from .interfaces import MultiModalEmbeddings -from .utils import (AutoWeightsLoader, PPMissingLayer, maybe_prefix, - merge_multimodal_embeddings) - -logger = init_logger(__name__) - - -class Qwen2_5DecodeLayer(Qwen2DecoderLayer): - - def __init__( - self, - config: Qwen2Config, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__(config, quant_config=quant_config, prefix=prefix) - - # override qkv - self.self_attn.qkv_proj = QKVParallelLinear( - 2 * self.hidden_size, - self.self_attn.head_dim, - self.self_attn.total_num_heads, - self.self_attn.total_num_kv_heads, - bias=False, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "qkv_proj"), - ) - # Add a normalization layer - self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - embeds: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - ) -> tuple[torch.Tensor, torch.Tensor]: - residual = hidden_states - embeds = self.input_layernorm(embeds) - hidden_states = self.hidden_norm(hidden_states) - # Reuse the target model's features - hidden_states = torch.cat([embeds, hidden_states], dim=-1) - - # Self Attention - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) - 
hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - - # Fully Connected - hidden_states = self.mlp(hidden_states) - - return hidden_states, residual - - -@support_torch_compile( - dynamic_arg_dims={ - "input_ids": 0, - # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, - # otherwise (seq_len, ). - "positions": -1, - "inputs_embeds": 0, - }) -class Qwen2_5Model(nn.Module): - - def __init__( - self, - *, - vllm_config: VllmConfig, - prefix: str = "", - start_layer_id: int = 0, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.config = ( - vllm_config.speculative_config.draft_model_config.hf_config) - self.multimodal_config = (vllm_config.speculative_config. - draft_model_config.multimodal_config) - # embbeding - if get_pp_group().is_first_rank or (self.config.tie_word_embeddings - and get_pp_group().is_last_rank): - self.embed_tokens = VocabParallelEmbedding( - self.config.vocab_size, - self.config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens", - ) - else: - self.embed_tokens = PPMissingLayer() - - # small language model initialization - self.layers = nn.ModuleList([ - Qwen2_5DecodeLayer( - config=self.config, - quant_config=quant_config, - prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), - ) for i in range(self.config.num_hidden_layers) - ]) - # The EAGLE3 feature fusion layer needs to - # fuse high, medium, and low-level features. - # Therefore, the input size is hidden_size * 3 - if hasattr(self.config, "target_hidden_size"): - self.fc = torch.nn.Linear(self.config.target_hidden_size * 3, - self.config.hidden_size, - bias=False) - else: - self.fc = torch.nn.Linear(self.config.hidden_size * 3, - self.config.hidden_size, - bias=False) - self.norm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - ) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings(input_ids) - assert hidden_states.shape[-1] == inputs_embeds.shape[-1] - residual = None # No residual on the first layer - for layer in self.layers: - hidden_states, residual = layer( - positions, - inputs_embeds, - hidden_states, - residual, - ) - # Normalized features (hidden_states) - # original features (hidden_prenorm) - hidden_states, hidden_prenorm = self.norm(hidden_states, residual) - return hidden_states, hidden_prenorm - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if 'midlayer.' 
in name: - name = name.replace('midlayer.', 'layers.0.') - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class Eagle3Qwen2_5_VLForCausalLM(Qwen2ForCausalLM): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - nn.Module.__init__(self) - self.config = vllm_config.speculative_config.\ - draft_model_config.hf_config - self.multimodal_config = vllm_config.model_config.multimodal_config - - # The number of layers in the target model - # start_layer_id for the draft model - target_layer_num = vllm_config.model_config.get_num_layers( - vllm_config.parallel_config) - # draft model quantization config may differ from target model - quant_config = VllmConfig.get_quantization_config( - vllm_config.speculative_config.draft_model_config, - vllm_config.load_config) - # Initialize the EAGLE model of QWEN2.5 - self.model = Qwen2_5Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "draft_model"), - start_layer_id=target_layer_num, - quant_config=quant_config) - - logit_scale = getattr(self.config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.config.draft_vocab_size, - scale=logit_scale) - # Establish a mapping relationship between - # the draft model vocabulary and the target model vocabulary. - self.draft_id_to_target_id = nn.Parameter( - torch.zeros(self.config.draft_vocab_size, dtype=torch.long), - requires_grad=False, - ) - - self.lm_head = ParallelLMHead( - self.config.draft_vocab_size, - self.config.hidden_size, - org_num_embeddings=self.config.draft_vocab_size, - padding_size=(DEFAULT_VOCAB_PADDING_SIZE), - prefix="") - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - if self.draft_id_to_target_id is None: - assert logits.shape[1] == self.config.vocab_size, \ - "Expected logits to have shape " \ - f"(*, {self.config.vocab_size}), but got {logits.shape}" - return logits - - base = torch.arange(self.config.draft_vocab_size, device=logits.device) - # Mapping to the main model vocabulary space - targets = base + self.draft_id_to_target_id - logits_new = logits.new_full(( - logits.shape[0], - self.config.vocab_size, - ), float('-inf')) - logits_new[:, targets] = logits # Only valid positions are filled - return logits_new - - def combine_hidden_states( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor: - # combine multiple auxiliary hidden states returned by eagle3 - return self.model.fc(hidden_states) - - def load_weights(self, weights): - """ - Load weights - Not shared lm_head with target model - Skip t2d - """ - model_weights = {} - include_draft_id_mapping = False - include_embb_tokens_mapping = False - for name, loaded_weight in weights: - if "t2d" in name: - continue - if "d2t" in name: - name = name.replace("d2t", "draft_id_to_target_id") - include_draft_id_mapping = True - elif "lm_head" not in name: - name = "model." 
+ name - if "embed_tokens" in name: - include_embb_tokens_mapping = True - model_weights[name] = loaded_weight - - skip_substrs = [] - if not include_draft_id_mapping: - skip_substrs.append("d2t") - if not include_embb_tokens_mapping: - skip_substrs.append("embed_tokens") - # Not shared lm_head with target model - loader = AutoWeightsLoader( - self, - skip_prefixes=None, - skip_substrs=skip_substrs, - ) - loader.load_weights(model_weights.items()) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object, - ) -> tuple[torch.Tensor, torch.Tensor]: - return self.model(input_ids, positions, hidden_states, inputs_embeds) - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.image_token_index) - return inputs_embeds diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index a7a7de73b988..f236040bb234 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -277,10 +277,6 @@ "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), - "EagleQwen2_5_VLForCausalLM": ("qwen2_5_vl_eagle", \ - "EagleQwen2_5_VLForCausalLM"), - "Eagle3Qwen2_5_VLForCausalLM": ("qwen2_5_vl_eagle3", \ - "Eagle3Qwen2_5_VLForCausalLM"), "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index a9d6fd814d1c..224f3d1ac9dd 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -18,8 +18,6 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM -from vllm.model_executor.models.qwen2_5_vl_eagle3 import ( - Eagle3Qwen2_5_VLForCausalLM) from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata @@ -175,8 +173,7 @@ def propose( last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 if self.method == "eagle3": - assert isinstance(self.model, \ - (Eagle3LlamaForCausalLM,Eagle3Qwen2_5_VLForCausalLM)) + assert isinstance(self.model, Eagle3LlamaForCausalLM) target_hidden_states = self.model.combine_hidden_states( target_hidden_states) assert target_hidden_states.shape[-1] == self.hidden_size @@ -421,65 +418,37 @@ def propose_tree( draft_hidden_states = hidden_states.view(batch_size, 1, -1) # Initialize empty tensors for concatenation with the level outputs. 
- # M-RoPE - tree_input_ids = torch.empty((batch_size, 0), + tree_input_ids = torch.empty(0, device=self.input_ids.device, dtype=self.input_ids.dtype) - # M-RoPE - if self.uses_mrope: - tree_positions = torch.empty((3, 0), - device=self.positions.device, - dtype=self.positions.dtype) - assert positions.dim() in (2, 3) - # Precompute the draft token positions. -> (3, B, L) - flattened_draft_positions = ( - positions.view(3, batch_size, 1) + - self.tree_draft_pos_offsets[:batch_size, :].unsqueeze(0)) - else: - tree_positions = torch.empty((batch_size, 0), - device=self.positions.device, - dtype=self.positions.dtype) - # Precompute the draft token positions. - flattened_draft_positions = ( - positions.view(batch_size, -1) + - self.tree_draft_pos_offsets[:batch_size, :]) - tree_hidden_states = torch.empty((batch_size, 0, self.hidden_size), + tree_positions = torch.empty(0, + device=self.positions.device, + dtype=self.positions.dtype) + tree_hidden_states = torch.empty(0, device=self.hidden_states.device, dtype=self.hidden_states.dtype) + # Precompute the draft token positions. + flattened_draft_positions = ( + positions.view(batch_size, -1) + + self.tree_draft_pos_offsets[:batch_size, :]) tree_depth = len(self.cu_drafts_per_level) for level in range(tree_depth - 1): - # M-RoPE - if self.uses_mrope: - # Get draft positions for RoPE - draft_positions = positions + (level + 1) - exceeds_max_model_len = ( - positions[0] + total_num_drafts) >= self.max_model_len - # Mask out the position ids that exceed the max model length. - # Otherwise, we may get out-of-range error in RoPE. - draft_positions = torch.where( - exceeds_max_model_len.unsqueeze(0), - 0, - draft_positions, - ).view(3, batch_size, -1) - else: - draft_positions = positions + (level + 1) - exceeds_max_model_len = ( - positions + total_num_drafts) >= self.max_model_len - draft_positions = torch.where( - exceeds_max_model_len, - 0, - draft_positions, - ).view(batch_size, -1) + # Get draft positions for RoPE. + draft_positions = positions + (level + 1) + exceeds_max_model_len = (positions + + total_num_drafts) >= self.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + draft_positions = torch.where( + exceeds_max_model_len, + 0, + draft_positions, + ).view(batch_size, -1) if level_num_drafts > 1: # Repeat the positions for each draft at this level. - # M-RoPE - if self.uses_mrope: - draft_positions = draft_positions.repeat_interleave( - level_num_drafts, dim=2) - else: - draft_positions = draft_positions.repeat_interleave( - level_num_drafts, dim=1) + draft_positions = draft_positions.repeat_interleave( + level_num_drafts, dim=1) if num_children > 1: # Repeat draft hidden states for each child. @@ -489,14 +458,8 @@ def propose_tree( # Concatenate the draft tokens, positions, and hidden states. tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1) - # M-RoPE - if self.uses_mrope: - tree_positions = torch.cat( - [tree_positions, - draft_positions.view(3, -1)], dim=1) - else: - tree_positions = torch.cat([tree_positions, draft_positions], - dim=1) + tree_positions = torch.cat([tree_positions, draft_positions], + dim=1) tree_hidden_states = torch.cat( [tree_hidden_states, draft_hidden_states], dim=1) @@ -528,23 +491,13 @@ def propose_tree( attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) # Compute the slot mapping. 
- # M-RoPE - if self.uses_mrope: - query_positions = flattened_draft_positions[:, :, \ - level:level + query_len] - block_numbers = query_positions[0] // self.block_size - block_ids = attn_metadata.block_table.gather( - dim=1, index=block_numbers) - slot_mapping = (block_ids * self.block_size + - query_positions[0] % self.block_size) - else: - query_positions = flattened_draft_positions[:, level:level + - query_len] - block_numbers = query_positions // self.block_size - block_ids = attn_metadata.block_table.gather( - dim=1, index=block_numbers) - slot_mapping = (block_ids * self.block_size + - query_positions % self.block_size) + query_positions = flattened_draft_positions[:, level:level + + query_len] + block_numbers = query_positions // self.block_size + block_ids = attn_metadata.block_table.gather(dim=1, + index=block_numbers) + slot_mapping = (block_ids * self.block_size + + query_positions % self.block_size) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. @@ -553,18 +506,11 @@ def propose_tree( # Copy inputs to buffer for cudagraph. num_tokens = attn_metadata.num_actual_tokens - # M-RoPE - input_ids = tree_input_ids.view(-1)[-num_tokens:] # [B*qlen] + input_ids = tree_input_ids.view(-1) self.input_ids[:num_tokens] = input_ids - # M-RoPE - if self.uses_mrope: - self.positions[:, :num_tokens] = \ - tree_positions.view(3, -1)[:,-num_tokens:] - else: - self.positions[:num_tokens] = tree_positions.view( - -1)[-num_tokens:] + self.positions[:num_tokens] = tree_positions.view(-1) self.hidden_states[:num_tokens] = tree_hidden_states.view( - num_tokens, -1)[-num_tokens:, :] + num_tokens, -1) if self.use_cuda_graph and \ num_tokens <= self.cudagraph_batch_sizes[-1]: @@ -576,14 +522,9 @@ def propose_tree( with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens): - # M-RoPE - if self.uses_mrope: - forward_positions = self.positions[:, :num_input_tokens] - else: - forward_positions = self.positions[:num_input_tokens] last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], - positions=forward_positions, + positions=self.positions[:num_input_tokens], hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=None, ) From 026dde743b73a86c92bc08633f9bc3fa1c79373e Mon Sep 17 00:00:00 2001 From: Junhong Date: Wed, 24 Sep 2025 10:50:15 +0800 Subject: [PATCH 13/32] fix pre-commit Signed-off-by: Junhong --- tests/v1/e2e/test_spec_decode.py | 8 ++++---- vllm/model_executor/models/llama_eagle3.py | 3 +-- vllm/model_executor/models/qwen2_5_vl.py | 2 +- vllm/v1/spec_decode/eagle.py | 4 ++-- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 6b8d492273c9..53ee44b68337 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -128,7 +128,7 @@ def test_ngram_correctness( @pytest.mark.parametrize(["model_setup", "mm_enabled"], [ (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), (("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct", - "Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1), False), + "Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1), False), (("eagle", "meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", @@ -147,9 +147,9 @@ def test_ngram_correctness( "eagle618/eagle-deepseek-v3-random", 1), False), ], ids=[ - "qwen3_eagle3", "qwen2.5_vl_eagle3", 
"llama3_eagle", - "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm", - "deepseek_eagle" + "qwen3_eagle3", "qwen2.5_vl_eagle3", + "llama3_eagle", "llama3_eagle3", "llama4_eagle", + "llama4_eagle_mm", "deepseek_eagle" ]) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 19504a0dc8a1..55b6ae6ee0e9 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -8,7 +8,6 @@ import torch.nn as nn from transformers import LlamaConfig -from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm @@ -23,7 +22,7 @@ from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaForCausalLM) -from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings +from .utils import AutoWeightsLoader, maybe_prefix logger = init_logger(__name__) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 6b6aa724b046..040edac7d2b3 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -65,7 +65,7 @@ from vllm.utils import is_pin_memory_available from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsEagle3, +from .interfaces import (MultiModalEmbeddings, SupportsEagle3, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant) from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 360ff4651ac4..4de0ac534747 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -273,7 +273,7 @@ def propose( else: positions = target_positions[last_token_indices] hidden_states = hidden_states[last_token_indices] - + if isinstance(attn_metadata, TreeAttentionMetadata): # Draft using tree attention. 
draft_token_ids_list = self.propose_tree( @@ -947,7 +947,7 @@ def dummy_run( forward_positions = self.positions[:, :num_tokens] else: forward_positions = self.positions[:num_tokens] - + if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] From fa20a26efcabaa59a0356ebdd5160142ccee4a3f Mon Sep 17 00:00:00 2001 From: Junhong Date: Wed, 24 Sep 2025 11:42:23 +0800 Subject: [PATCH 14/32] fix pre-commit Signed-off-by: Junhong --- tests/models/registry.py | 2 +- vllm/model_executor/models/registry.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 70a20378a05c..6e1fe41c7445 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -642,7 +642,7 @@ def check_available_online( "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True, speculative_model="XiaomiMiMo/MiMo-7B-RL"), - "Eagle3LlamaForCausalLM": _HfExamplesInfo( + "Eagle3Qwen2_5vlForCausalLM": _HfExamplesInfo( "Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=True, speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl"), diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 6ab3fa902c38..2b653e8c5ff7 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -284,6 +284,7 @@ "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), + "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), From 50b13f549827b8cb8fac0f64be14314f52c40f51 Mon Sep 17 00:00:00 2001 From: Junhong Date: Thu, 25 Sep 2025 10:44:59 +0800 Subject: [PATCH 15/32] fix bug Signed-off-by: Junhong --- vllm/v1/spec_decode/eagle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 4de0ac534747..9a1a93bb10b5 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -361,11 +361,11 @@ def propose( block_ids = block_ids.view(-1) # M-RoPE if self.uses_mrope: - attn_metadata.slot_mapping = ( + common_attn_metadata.slot_mapping = ( block_ids * self.block_size + clamped_positions[0] % self.block_size) else: - attn_metadata.slot_mapping = ( + common_attn_metadata.slot_mapping = ( block_ids * self.block_size + clamped_positions % self.block_size) # Mask out the slot mappings that exceed the max model length. 
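
For context, the slot mapping computed in the hunk above translates each clamped draft-token position into a physical KV-cache slot; when M-RoPE is enabled the position tensor has shape (3, num_tokens) and only its temporal row (index 0) addresses the cache. A minimal standalone sketch of that indexing, using illustrative names and a single-sequence block table (an assumption for brevity, not code from this patch):

import torch

def draft_slot_mapping(block_table_row: torch.Tensor,
                       clamped_positions: torch.Tensor,
                       block_size: int,
                       uses_mrope: bool) -> torch.Tensor:
    # With M-RoPE, positions are (3, num_tokens); only the temporal row
    # indexes the KV cache. Otherwise positions are (num_tokens,).
    pos = clamped_positions[0] if uses_mrope else clamped_positions
    block_numbers = pos // block_size            # logical block per token
    block_ids = block_table_row[block_numbers]   # physical block ids
    return block_ids * block_size + pos % block_size

For example, with block_size=16 and positions [0, 17], the tokens map to blocks 0 and 1 at offsets 0 and 1, respectively.
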
From ef940da267199ec39fc657b27dfba88f34321a15 Mon Sep 17 00:00:00 2001 From: Junhong Date: Thu, 25 Sep 2025 11:26:18 +0800 Subject: [PATCH 16/32] add benchmark_run_mmstar Signed-off-by: Junhong --- benchmarks/benchmarks_run_mmstar.py | 200 ++++++++++++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 benchmarks/benchmarks_run_mmstar.py diff --git a/benchmarks/benchmarks_run_mmstar.py b/benchmarks/benchmarks_run_mmstar.py new file mode 100644 index 000000000000..8204fb560b88 --- /dev/null +++ b/benchmarks/benchmarks_run_mmstar.py @@ -0,0 +1,200 @@ +import argparse +import os +import shutil +import time +import io +import base64 +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Dict, Any, List, Tuple + +import requests +from datasets import load_dataset +from PIL import Image + + +def to_data_url(image_path: str, fmt: str = "JPEG") -> str: + """Read image from disk and convert to base64 data URL (robust for vLLM OpenAI server).""" + with Image.open(image_path).convert("RGB") as img: + buf = io.BytesIO() + img.save(buf, format=fmt) + b64 = base64.b64encode(buf.getvalue()).decode("utf-8") + mime = "image/jpeg" if fmt.upper() == "JPEG" else f"image/{fmt.lower()}" + return f"data:{mime};base64,{b64}" + + +def build_messages(question_text: str, image_url: str) -> List[Dict[str, Any]]: + """Mimic the SGLang prompt: one user message containing image + question text.""" + return [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": question_text}, + ], + } + ] + + +def chat_once( + api_base: str, + model: str, + messages: List[Dict[str, Any]], + max_tokens: int = 2048, + temperature: float = 0.0, + top_p: float = 1.0, + top_k: int = -1, + repetition_penalty: float = 1.0, + seed: int = -1, + timeout: int = 120, + api_key: str = "", +) -> Tuple[str, int]: + """ + Send one /chat/completions request to vLLM OpenAI-compatible server. + Return (answer_text, completion_tokens). 
+ """ + url = api_base.rstrip("/") + "/chat/completions" + headers = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + payload: Dict[str, Any] = { + "model": model, + "messages": messages, + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + "repetition_penalty": repetition_penalty, + } + if top_k is not None and top_k >= 0: + payload["top_k"] = top_k + if seed is not None and seed >= 0: + payload["seed"] = seed + + resp = requests.post(url, json=payload, headers=headers, timeout=timeout) + resp.raise_for_status() + data = resp.json() + text = data["choices"][0]["message"]["content"] + usage = data.get("usage") or {} + completion_tokens = int(usage.get("completion_tokens") or 0) + return text, completion_tokens + + +def main(args): + # Apply deterministic overrides if requested + temperature = args.temperature + top_p = args.top_p + top_k = args.top_k + repetition_penalty = args.repetition_penalty + seed = args.seed + + if args.deterministic: + # for deterministic test + temperature = 0.0 + top_p = 1.0 + top_k = -1 + repetition_penalty = 1.0 + if seed is None or seed < 0: + seed = 42 + + # Prepare cache dirs (same structure as the SGLang script) + cache_dir = os.path.join(".cache", "mmstar") + image_dir = os.path.join(cache_dir, "images") + os.makedirs(cache_dir, exist_ok=True) + os.makedirs(image_dir, exist_ok=True) + print(f"Created temporary image directory: {cache_dir}") + + # Read data (MMStar val) + dataset = load_dataset("Lin-Chen/MMStar")["val"] + + # Build requests + requests_payload = [] + for idx, q in enumerate(dataset): + if idx >= args.num_questions: + break + # Save image to expected nested path under cache_dir (e.g., images/2.jpg) + rel_path = q["meta_info"]["image_path"] # e.g., "images/2.jpg" + image_path = os.path.join(cache_dir, rel_path) + os.makedirs(os.path.dirname(image_path), exist_ok=True) + q["image"].convert("RGB").save(image_path, "JPEG") + + # Strip options from question text, same as SGL script + question_text = q["question"].split("Options:", 1)[0].strip() + + # Use data URL so we don't depend on --allowed-local-media-path + img_url = to_data_url(image_path, fmt="JPEG") + + messages = build_messages(question_text, img_url) + requests_payload.append(messages) + + # Fire requests (parallel similar to num_threads in SGLang run_batch) + tic = time.perf_counter() + completion_tokens_sum = 0 + answers: List[str] = [""] * len(requests_payload) + + def submit_one(messages): + return chat_once( + args.api_base, + args.model, + messages, + args.max_new_tokens, + temperature, + top_p, + top_k, + repetition_penalty, + seed, + args.timeout, + args.api_key, + ) + + with ThreadPoolExecutor(max_workers=args.parallel) as ex: + futures = {ex.submit(submit_one, messages): i for i, messages in enumerate(requests_payload)} + for fut in as_completed(futures): + i = futures[fut] + try: + text, ctk = fut.result() + print(f"text={text},ctk={ctk}") + answers[i] = text + completion_tokens_sum += ctk + except Exception as e: + answers[i] = f"[ERROR] {e}" + + latency = time.perf_counter() - tic + + # Compute throughput (tokens/s) — matches SGL's "Output throughput" + output_throughput = completion_tokens_sum / latency if latency > 0 else 0.0 + + # Accept length: SGLang reports spec_verify_ct; not available via OpenAI API -> set to 1.0 + # accept_length = 1.0 + + # Print results (same fields as SGL script) + print(f"Latency: {latency:.3f} s") + print(f"Output throughput: {output_throughput:.3f} token/s") + + # Cleanup + if 
os.path.exists(cache_dir): + shutil.rmtree(cache_dir) + print(f"Deleted temporary directory: {cache_dir}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Keep SGL-like knobs + parser.add_argument("--num-questions", type=int, default=20) + parser.add_argument("--parallel", type=int, default=8, help="Number of concurrent requests") + # vLLM OpenAI-compatible endpoint args + parser.add_argument("--api-base", type=str, default="http://127.0.0.1:8080/v1", help="vLLM OpenAI-compatible base URL") + # If you didn't set --served-model-name when launching vLLM, set this to your served name or path + parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct", help="Served model name or path recognized by vLLM") + parser.add_argument("--api-key", type=str, default="", help="Bearer token if your server requires auth") + parser.add_argument("--max-new-tokens", type=int, default=2048) + parser.add_argument("--timeout", type=int, default=120) + + # Sampling / determinism controls + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--top-k", type=int, default=-1, help="-1 to disable; non-negative to enable") + parser.add_argument("--repetition-penalty", type=float, default=1.0) + parser.add_argument("--seed", type=int, default=-1, help="Fixed RNG seed; <0 means unset") + parser.add_argument("--deterministic", action="store_true", help="Force deterministic-like settings: temp=0, top_p=1, top_k=-1, rep_penalty=1, seed=42 if unset") + + args = parser.parse_args() + main(args) \ No newline at end of file From ae085017330a65215cff0744f56120c9b6204b3d Mon Sep 17 00:00:00 2001 From: Junhong Date: Thu, 25 Sep 2025 11:41:26 +0800 Subject: [PATCH 17/32] fix benchmark Signed-off-by: Junhong --- benchmarks/benchmarks_run_mmstar.py | 52 +++++++++++++++++++++++------ 1 file changed, 41 insertions(+), 11 deletions(-) diff --git a/benchmarks/benchmarks_run_mmstar.py b/benchmarks/benchmarks_run_mmstar.py index 8204fb560b88..52c2eaecc0c6 100644 --- a/benchmarks/benchmarks_run_mmstar.py +++ b/benchmarks/benchmarks_run_mmstar.py @@ -1,11 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse +import base64 +import io import os import shutil import time -import io -import base64 from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Dict, Any, List, Tuple +from typing import Any, Dict, List, Tuple import requests from datasets import load_dataset @@ -146,7 +148,10 @@ def submit_one(messages): ) with ThreadPoolExecutor(max_workers=args.parallel) as ex: - futures = {ex.submit(submit_one, messages): i for i, messages in enumerate(requests_payload)} + futures = { + ex.submit(submit_one, messages): i + for i, messages in enumerate(requests_payload) + } for fut in as_completed(futures): i = futures[fut] try: @@ -179,22 +184,47 @@ def submit_one(messages): parser = argparse.ArgumentParser() # Keep SGL-like knobs parser.add_argument("--num-questions", type=int, default=20) - parser.add_argument("--parallel", type=int, default=8, help="Number of concurrent requests") + parser.add_argument( + "--parallel", type=int, default=8, help="Number of concurrent requests" + ) # vLLM OpenAI-compatible endpoint args - parser.add_argument("--api-base", type=str, default="http://127.0.0.1:8080/v1", help="vLLM OpenAI-compatible base URL") + parser.add_argument( + "--api-base", + type=str, + 
default="http://127.0.0.1:8080/v1", + help="vLLM OpenAI-compatible base URL", + ) # If you didn't set --served-model-name when launching vLLM, set this to your served name or path - parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct", help="Served model name or path recognized by vLLM") - parser.add_argument("--api-key", type=str, default="", help="Bearer token if your server requires auth") + parser.add_argument( + "--model", + type=str, + default="Qwen/Qwen2.5-VL-7B-Instruct", + help="Served model name or path recognized by vLLM", + ) + parser.add_argument( + "--api-key", + type=str, + default="", + help="*** if your server requires auth", + ) parser.add_argument("--max-new-tokens", type=int, default=2048) parser.add_argument("--timeout", type=int, default=120) # Sampling / determinism controls parser.add_argument("--temperature", type=float, default=0.0) parser.add_argument("--top-p", type=float, default=1.0) - parser.add_argument("--top-k", type=int, default=-1, help="-1 to disable; non-negative to enable") + parser.add_argument( + "--top-k", type=int, default=-1, help="-1 to disable; non-negative to enable" + ) parser.add_argument("--repetition-penalty", type=float, default=1.0) - parser.add_argument("--seed", type=int, default=-1, help="Fixed RNG seed; <0 means unset") - parser.add_argument("--deterministic", action="store_true", help="Force deterministic-like settings: temp=0, top_p=1, top_k=-1, rep_penalty=1, seed=42 if unset") + parser.add_argument( + "--seed", type=int, default=-1, help="Fixed RNG seed; <0 means unset" + ) + parser.add_argument( + "--deterministic", + action="store_true", + help="Force deterministic-like settings: temp=0, top_p=1, top_k=-1, rep_penalty=1, seed=42 if unset", + ) args = parser.parse_args() main(args) \ No newline at end of file From 1447be534ea6129d862fbe52271dec9f9aa66101 Mon Sep 17 00:00:00 2001 From: Junhong Date: Thu, 25 Sep 2025 11:50:11 +0800 Subject: [PATCH 18/32] fix benchmark Signed-off-by: Junhong --- benchmarks/benchmarks_run_mmstar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmarks_run_mmstar.py b/benchmarks/benchmarks_run_mmstar.py index 52c2eaecc0c6..e2a29dc4d78a 100644 --- a/benchmarks/benchmarks_run_mmstar.py +++ b/benchmarks/benchmarks_run_mmstar.py @@ -227,4 +227,4 @@ def submit_one(messages): ) args = parser.parse_args() - main(args) \ No newline at end of file + main(args) From 6e15b6e1faad512f67d4bf29e5faa61f668720b2 Mon Sep 17 00:00:00 2001 From: Junhong Date: Thu, 25 Sep 2025 12:02:21 +0800 Subject: [PATCH 19/32] fix benchmark Signed-off-by: Junhong --- benchmarks/benchmarks_run_mmstar.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/benchmarks/benchmarks_run_mmstar.py b/benchmarks/benchmarks_run_mmstar.py index e2a29dc4d78a..abc9ca960b2e 100644 --- a/benchmarks/benchmarks_run_mmstar.py +++ b/benchmarks/benchmarks_run_mmstar.py @@ -7,7 +7,7 @@ import shutil import time from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Dict, List, Tuple +from typing import Any import requests from datasets import load_dataset @@ -15,7 +15,7 @@ def to_data_url(image_path: str, fmt: str = "JPEG") -> str: - """Read image from disk and convert to base64 data URL (robust for vLLM OpenAI server).""" + """Read image from disk and convert to base64 data URL.""" with Image.open(image_path).convert("RGB") as img: buf = io.BytesIO() img.save(buf, format=fmt) @@ -24,7 +24,7 @@ def to_data_url(image_path: 
str, fmt: str = "JPEG") -> str: return f"data:{mime};base64,{b64}" -def build_messages(question_text: str, image_url: str) -> List[Dict[str, Any]]: +def build_messages(question_text: str, image_url: str) -> list[dict[str, Any]]: """Mimic the SGLang prompt: one user message containing image + question text.""" return [ { @@ -40,7 +40,7 @@ def build_messages(question_text: str, image_url: str) -> List[Dict[str, Any]]: def chat_once( api_base: str, model: str, - messages: List[Dict[str, Any]], + messages: list[dict[str, Any]], max_tokens: int = 2048, temperature: float = 0.0, top_p: float = 1.0, @@ -49,7 +49,7 @@ def chat_once( seed: int = -1, timeout: int = 120, api_key: str = "", -) -> Tuple[str, int]: +) -> tuple[str, int]: """ Send one /chat/completions request to vLLM OpenAI-compatible server. Return (answer_text, completion_tokens). @@ -58,7 +58,7 @@ def chat_once( headers = {} if api_key: headers["Authorization"] = f"Bearer {api_key}" - payload: Dict[str, Any] = { + payload: dict[str, Any] = { "model": model, "messages": messages, "max_tokens": max_tokens, @@ -130,7 +130,7 @@ def main(args): # Fire requests (parallel similar to num_threads in SGLang run_batch) tic = time.perf_counter() completion_tokens_sum = 0 - answers: List[str] = [""] * len(requests_payload) + answers: list[str] = [""] * len(requests_payload) def submit_one(messages): return chat_once( @@ -167,7 +167,8 @@ def submit_one(messages): # Compute throughput (tokens/s) — matches SGL's "Output throughput" output_throughput = completion_tokens_sum / latency if latency > 0 else 0.0 - # Accept length: SGLang reports spec_verify_ct; not available via OpenAI API -> set to 1.0 + # Accept length: SGLang reports spec_verify_ct + # not available via OpenAI API -> set to 1.0 # accept_length = 1.0 # Print results (same fields as SGL script) @@ -194,7 +195,8 @@ def submit_one(messages): default="http://127.0.0.1:8080/v1", help="vLLM OpenAI-compatible base URL", ) - # If you didn't set --served-model-name when launching vLLM, set this to your served name or path + # If you didn't set --served-model-name when launching vLLM, + # set this to your served name or path parser.add_argument( "--model", type=str, @@ -223,7 +225,8 @@ def submit_one(messages): parser.add_argument( "--deterministic", action="store_true", - help="Force deterministic-like settings: temp=0, top_p=1, top_k=-1, rep_penalty=1, seed=42 if unset", + help="Force deterministic-like settings: temp=0, top_p=1, " \ + "top_k=-1, rep_penalty=1, seed=42 if unset", ) args = parser.parse_args() From 8cd5a98f4573085a82055f3c5404d39ed005f7c5 Mon Sep 17 00:00:00 2001 From: Junhong Date: Thu, 25 Sep 2025 14:26:52 +0800 Subject: [PATCH 20/32] fix benchmark Signed-off-by: Junhong --- benchmarks/benchmarks_run_mmstar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmarks_run_mmstar.py b/benchmarks/benchmarks_run_mmstar.py index abc9ca960b2e..3305eaadaaf7 100644 --- a/benchmarks/benchmarks_run_mmstar.py +++ b/benchmarks/benchmarks_run_mmstar.py @@ -225,7 +225,7 @@ def submit_one(messages): parser.add_argument( "--deterministic", action="store_true", - help="Force deterministic-like settings: temp=0, top_p=1, " \ + help="Force deterministic-like settings: temp=0, top_p=1, " "top_k=-1, rep_penalty=1, seed=42 if unset", ) From d2bc9e5e359d313e06ecb429507563788d255114 Mon Sep 17 00:00:00 2001 From: Junhong Date: Thu, 25 Sep 2025 14:49:54 +0800 Subject: [PATCH 21/32] fix pre-commit Signed-off-by: Junhong --- tests/v1/e2e/test_spec_decode.py | 2 +- 
vllm/v1/spec_decode/eagle.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 53ee44b68337..d7a063171ad1 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -147,7 +147,7 @@ def test_ngram_correctness( "eagle618/eagle-deepseek-v3-random", 1), False), ], ids=[ - "qwen3_eagle3", "qwen2.5_vl_eagle3", + "qwen3_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm", "deepseek_eagle" ]) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 3418eaf50211..ca60a6c6d185 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -254,7 +254,7 @@ def propose( draft_token_ids = logits.argmax(dim=-1) return draft_token_ids.view(-1, 1) - # M-RoPE + # M-RoPE if self.uses_mrope: positions = target_positions[:, last_token_indices] else: From 50f31df638fc8b41fa2c63fda90204e852f258dd Mon Sep 17 00:00:00 2001 From: Junhong Date: Thu, 25 Sep 2025 15:22:19 +0800 Subject: [PATCH 22/32] fix test Signed-off-by: Junhong --- tests/v1/e2e/test_spec_decode.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index d7a063171ad1..338997834c41 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -188,8 +188,9 @@ def test_eagle_correctness( method, model_name, spec_model_name, tp_size = model_setup - if "Qwen2.5-VL" in model_name and attn_backend == "TREE_ATTN": - pytest.skip("TREE ATTN not support Qwen2.5-VL Model yet") + if "Qwen2.5-VL" in model_name: + pytest.skip("FLASH_ATTN_VLLM_V1 does not support Qwen2.5-VL " + "due to its head_dim not being a a multiple of 32") ref_llm = LLM(model=model_name, max_model_len=2048, tensor_parallel_size=tp_size) From b629dbb3b25c62825180066c8c8fbdb67d61e833 Mon Sep 17 00:00:00 2001 From: Junhong Date: Thu, 25 Sep 2025 19:43:35 +0800 Subject: [PATCH 23/32] fix bug Signed-off-by: Junhong --- tests/models/registry.py | 1 - vllm/benchmarks/datasets.py | 81 ++++++++++++++++++++++++ vllm/model_executor/models/qwen2_5_vl.py | 4 +- vllm/v1/worker/gpu_model_runner.py | 2 - 4 files changed, 83 insertions(+), 5 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index b93f9045e9f7..874890c7bc75 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -650,7 +650,6 @@ def check_available_online( speculative_model="XiaomiMiMo/MiMo-7B-RL"), "Eagle3Qwen2_5vlForCausalLM": _HfExamplesInfo( "Qwen/Qwen2.5-VL-7B-Instruct", - trust_remote_code=True, speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl"), "Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", min_transformers_version="4.56.3"), diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 68a937d5750e..59ddb26f214a 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -1450,6 +1450,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: ): dataset_class = MLPerfDataset args.hf_split = "train" + elif ( + args.dataset_path in MMStarDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in MMStarDataset.SUPPORTED_DATASET_PATHS + ): + dataset_class = MMStarDataset + args.hf_split = "val" + args.hf_subset = None else: supported_datasets = set([ dataset_name for cls in HuggingFaceDataset.__subclasses__() @@ -2721,3 +2728,77 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]: random.shuffle(requests) 
return requests + + +# ----------------------------------------------------------------------------- +# MMStar Dataset Implementation +# ----------------------------------------------------------------------------- + + +class MMStarDataset(HuggingFaceDataset): + """ + Lin-Chen/MMStar: https://huggingface.co/datasets/Lin-Chen/MMStar + """ + DEFAULT_OUTPUT_LEN = 128 + SUPPORTED_DATASET_PATHS = {"Lin-Chen/MMStar"} + IS_MULTIMODAL = True + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + request_id_prefix: str = "", + no_oversample: bool = False, + **kwargs, + ) -> list[SampleRequest]: + # If --hf-output-len is not set, use the default output length. + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests: list[SampleRequest] = [] + ind = 0 + + for item in self.data: + if len(sampled_requests) >= num_requests: + break + # Split the question text from options + # (keep only the part before "Options:"). + full_q: str = item.get("question", "") + question_text = full_q.split("Options:", 1)[0].strip() + + # Multimodal image content. + mm_content = process_image(item["image"]) + + # Compute prompt token length (note: this is plain text length + # if enable_multimodal_chat is False). + prompt_len = len(tokenizer(question_text).input_ids) + + if enable_multimodal_chat: + # If multimodal content should be embedded in the chat message, + # convert to [{"role":"user","content":[...]}] + prompt = self.apply_multimodal_chat_transformation( + question_text, mm_content + ) + mm_for_request = None # Already embedded in chat content. + else: + # Default: prompt is plain text, + # image is in mm_content for the bench to assemble. + prompt = question_text + mm_for_request = mm_content + + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_for_request, + request_id=request_id_prefix + str(ind), + ) + ) + ind += 1 + + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) + return sampled_requests diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 040edac7d2b3..822345efc339 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -974,10 +974,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None: + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.language_model.model.aux_hidden_state_layers = layers - def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]: + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: num_layers = len(self.language_model.model.layers) return (2, num_layers // 2, num_layers - 3) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 520cb0e9eb4a..4d9d20c0caf0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2435,7 +2435,6 @@ def propose_draft_token_ids( token_indices_to_sample = None # input_ids can be None for multimodal models. target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] - # TODO(woosuk): Support M-RoPE. 
# M-RoPE if self.uses_mrope: target_positions = \ @@ -2467,7 +2466,6 @@ def propose_draft_token_ids( valid_sampled_tokens_count) target_token_ids = self.input_ids.gpu[token_indices] - # TODO(woosuk): Support M-RoPE. # M-RoPE if self.uses_mrope: target_positions = \ From 727be4423df3829ae829fe08b67b60c67effd751 Mon Sep 17 00:00:00 2001 From: Junhong Date: Thu, 25 Sep 2025 19:48:39 +0800 Subject: [PATCH 24/32] fix bug Signed-off-by: Junhong --- vllm/benchmarks/datasets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 59ddb26f214a..bd00b501b84e 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -2738,6 +2738,7 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]: class MMStarDataset(HuggingFaceDataset): """ Lin-Chen/MMStar: https://huggingface.co/datasets/Lin-Chen/MMStar + refer to: https://github.com/sgl-project/SpecForge/pull/106 """ DEFAULT_OUTPUT_LEN = 128 SUPPORTED_DATASET_PATHS = {"Lin-Chen/MMStar"} From f8b06517c42f91a0eddfcbefd60c35d811514b7f Mon Sep 17 00:00:00 2001 From: Junhong Date: Thu, 25 Sep 2025 20:02:20 +0800 Subject: [PATCH 25/32] fix pre-commit Signed-off-by: Junhong --- vllm/benchmarks/datasets.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index bd00b501b84e..f0c0d829a393 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -2758,9 +2758,8 @@ def sample( output_len = (output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN) sampled_requests: list[SampleRequest] = [] - ind = 0 - for item in self.data: + for ind, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break # Split the question text from options @@ -2797,7 +2796,6 @@ def sample( request_id=request_id_prefix + str(ind), ) ) - ind += 1 self.maybe_oversample_requests( sampled_requests, num_requests, request_id_prefix, no_oversample From 0a83ffb911b2a969d9efb1e04b8757b57d0d1da2 Mon Sep 17 00:00:00 2001 From: Junhong Date: Fri, 26 Sep 2025 16:36:13 +0800 Subject: [PATCH 26/32] delete benchmarks\benchmarks_run_mmstar.py Signed-off-by: Junhong --- benchmarks/benchmarks_run_mmstar.py | 233 ---------------------------- 1 file changed, 233 deletions(-) delete mode 100644 benchmarks/benchmarks_run_mmstar.py diff --git a/benchmarks/benchmarks_run_mmstar.py b/benchmarks/benchmarks_run_mmstar.py deleted file mode 100644 index 3305eaadaaf7..000000000000 --- a/benchmarks/benchmarks_run_mmstar.py +++ /dev/null @@ -1,233 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import argparse -import base64 -import io -import os -import shutil -import time -from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any - -import requests -from datasets import load_dataset -from PIL import Image - - -def to_data_url(image_path: str, fmt: str = "JPEG") -> str: - """Read image from disk and convert to base64 data URL.""" - with Image.open(image_path).convert("RGB") as img: - buf = io.BytesIO() - img.save(buf, format=fmt) - b64 = base64.b64encode(buf.getvalue()).decode("utf-8") - mime = "image/jpeg" if fmt.upper() == "JPEG" else f"image/{fmt.lower()}" - return f"data:{mime};base64,{b64}" - - -def build_messages(question_text: str, image_url: str) -> list[dict[str, Any]]: - """Mimic the SGLang prompt: one user message containing image + question text.""" - return [ - { - "role": "user", - "content": [ - 
{"type": "image_url", "image_url": {"url": image_url}}, - {"type": "text", "text": question_text}, - ], - } - ] - - -def chat_once( - api_base: str, - model: str, - messages: list[dict[str, Any]], - max_tokens: int = 2048, - temperature: float = 0.0, - top_p: float = 1.0, - top_k: int = -1, - repetition_penalty: float = 1.0, - seed: int = -1, - timeout: int = 120, - api_key: str = "", -) -> tuple[str, int]: - """ - Send one /chat/completions request to vLLM OpenAI-compatible server. - Return (answer_text, completion_tokens). - """ - url = api_base.rstrip("/") + "/chat/completions" - headers = {} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - payload: dict[str, Any] = { - "model": model, - "messages": messages, - "max_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - "repetition_penalty": repetition_penalty, - } - if top_k is not None and top_k >= 0: - payload["top_k"] = top_k - if seed is not None and seed >= 0: - payload["seed"] = seed - - resp = requests.post(url, json=payload, headers=headers, timeout=timeout) - resp.raise_for_status() - data = resp.json() - text = data["choices"][0]["message"]["content"] - usage = data.get("usage") or {} - completion_tokens = int(usage.get("completion_tokens") or 0) - return text, completion_tokens - - -def main(args): - # Apply deterministic overrides if requested - temperature = args.temperature - top_p = args.top_p - top_k = args.top_k - repetition_penalty = args.repetition_penalty - seed = args.seed - - if args.deterministic: - # for deterministic test - temperature = 0.0 - top_p = 1.0 - top_k = -1 - repetition_penalty = 1.0 - if seed is None or seed < 0: - seed = 42 - - # Prepare cache dirs (same structure as the SGLang script) - cache_dir = os.path.join(".cache", "mmstar") - image_dir = os.path.join(cache_dir, "images") - os.makedirs(cache_dir, exist_ok=True) - os.makedirs(image_dir, exist_ok=True) - print(f"Created temporary image directory: {cache_dir}") - - # Read data (MMStar val) - dataset = load_dataset("Lin-Chen/MMStar")["val"] - - # Build requests - requests_payload = [] - for idx, q in enumerate(dataset): - if idx >= args.num_questions: - break - # Save image to expected nested path under cache_dir (e.g., images/2.jpg) - rel_path = q["meta_info"]["image_path"] # e.g., "images/2.jpg" - image_path = os.path.join(cache_dir, rel_path) - os.makedirs(os.path.dirname(image_path), exist_ok=True) - q["image"].convert("RGB").save(image_path, "JPEG") - - # Strip options from question text, same as SGL script - question_text = q["question"].split("Options:", 1)[0].strip() - - # Use data URL so we don't depend on --allowed-local-media-path - img_url = to_data_url(image_path, fmt="JPEG") - - messages = build_messages(question_text, img_url) - requests_payload.append(messages) - - # Fire requests (parallel similar to num_threads in SGLang run_batch) - tic = time.perf_counter() - completion_tokens_sum = 0 - answers: list[str] = [""] * len(requests_payload) - - def submit_one(messages): - return chat_once( - args.api_base, - args.model, - messages, - args.max_new_tokens, - temperature, - top_p, - top_k, - repetition_penalty, - seed, - args.timeout, - args.api_key, - ) - - with ThreadPoolExecutor(max_workers=args.parallel) as ex: - futures = { - ex.submit(submit_one, messages): i - for i, messages in enumerate(requests_payload) - } - for fut in as_completed(futures): - i = futures[fut] - try: - text, ctk = fut.result() - print(f"text={text},ctk={ctk}") - answers[i] = text - completion_tokens_sum += ctk - except 
Exception as e: - answers[i] = f"[ERROR] {e}" - - latency = time.perf_counter() - tic - - # Compute throughput (tokens/s) — matches SGL's "Output throughput" - output_throughput = completion_tokens_sum / latency if latency > 0 else 0.0 - - # Accept length: SGLang reports spec_verify_ct - # not available via OpenAI API -> set to 1.0 - # accept_length = 1.0 - - # Print results (same fields as SGL script) - print(f"Latency: {latency:.3f} s") - print(f"Output throughput: {output_throughput:.3f} token/s") - - # Cleanup - if os.path.exists(cache_dir): - shutil.rmtree(cache_dir) - print(f"Deleted temporary directory: {cache_dir}") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Keep SGL-like knobs - parser.add_argument("--num-questions", type=int, default=20) - parser.add_argument( - "--parallel", type=int, default=8, help="Number of concurrent requests" - ) - # vLLM OpenAI-compatible endpoint args - parser.add_argument( - "--api-base", - type=str, - default="http://127.0.0.1:8080/v1", - help="vLLM OpenAI-compatible base URL", - ) - # If you didn't set --served-model-name when launching vLLM, - # set this to your served name or path - parser.add_argument( - "--model", - type=str, - default="Qwen/Qwen2.5-VL-7B-Instruct", - help="Served model name or path recognized by vLLM", - ) - parser.add_argument( - "--api-key", - type=str, - default="", - help="*** if your server requires auth", - ) - parser.add_argument("--max-new-tokens", type=int, default=2048) - parser.add_argument("--timeout", type=int, default=120) - - # Sampling / determinism controls - parser.add_argument("--temperature", type=float, default=0.0) - parser.add_argument("--top-p", type=float, default=1.0) - parser.add_argument( - "--top-k", type=int, default=-1, help="-1 to disable; non-negative to enable" - ) - parser.add_argument("--repetition-penalty", type=float, default=1.0) - parser.add_argument( - "--seed", type=int, default=-1, help="Fixed RNG seed; <0 means unset" - ) - parser.add_argument( - "--deterministic", - action="store_true", - help="Force deterministic-like settings: temp=0, top_p=1, " - "top_k=-1, rep_penalty=1, seed=42 if unset", - ) - - args = parser.parse_args() - main(args) From 63b726a8dcc63feaf374f4036a85fbed25cb603b Mon Sep 17 00:00:00 2001 From: Junhong Date: Sat, 27 Sep 2025 00:02:10 +0800 Subject: [PATCH 27/32] fix bug Signed-off-by: Junhong --- tests/v1/e2e/test_spec_decode.py | 3 -- vllm/model_executor/models/qwen2_5_vl.py | 8 ++--- vllm/v1/spec_decode/eagle.py | 43 +++++++++++------------- vllm/v1/worker/gpu_model_runner.py | 24 +++++++------ 4 files changed, 36 insertions(+), 42 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 7cac78d013c2..c2b184041a74 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -187,9 +187,6 @@ def test_eagle_correctness( method, model_name, spec_model_name, tp_size = model_setup - if "Qwen2.5-VL" in model_name: - pytest.skip("FLASH_ATTN_VLLM_V1 does not support Qwen2.5-VL " - "due to its head_dim not being a a multiple of 32") ref_llm = LLM(model=model_name, max_model_len=2048, tensor_parallel_size=tp_size) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 2aae82a7bd04..f84f6b088b18 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -317,14 +317,13 @@ def __init__( if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, - 
_Backend.ROCM_AITER_FA, _Backend.FLASH_ATTN_VLLM_V1 + _Backend.ROCM_AITER_FA }: raise RuntimeError( f"Qwen2.5-VL does not support {self.attn_backend} backend now." ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA, - _Backend.FLASH_ATTN_VLLM_V1 + _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -739,8 +738,7 @@ def compute_attn_mask_seqlen( ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None if (self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA - or self.attn_backend == _Backend.FLASH_ATTN_VLLM_V1): + or self.attn_backend == _Backend.ROCM_AITER_FA): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index ca60a6c6d185..6e75f3fc0341 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -84,9 +84,9 @@ def __init__( self.uses_mrope = self.vllm_config.model_config.uses_mrope if self.uses_mrope: # M-RoPE need (3, max_num_tokens) - self.positions = torch.zeros((3, self.max_num_tokens), - dtype=torch.int64, - device=device) + self.mrope_positions = torch.zeros((3, self.max_num_tokens), + dtype=torch.int64, + device=device) else: # RoPE need (max_num_tokens,) self.positions = torch.zeros(self.max_num_tokens, @@ -152,6 +152,18 @@ def __init__( dtype=torch.int32, ).repeat(max_batch_size, 1) + # M-RoPE + def _get_positions(self, num_tokens: int): + if self.uses_mrope: + return self.mrope_positions[:, :num_tokens] + return self.positions[:num_tokens] + + def _set_positions(self, num_tokens: int, positions: torch.Tensor): + if self.uses_mrope: + self.mrope_positions[:, :num_tokens] = positions + else: + self.positions[:num_tokens] = positions + def propose( self, # [num_tokens] @@ -208,10 +220,7 @@ def propose( num_input_tokens = num_tokens # copy inputs to buffer for cudagraph # M-RoPE - if self.uses_mrope: - self.positions[:, :num_tokens] = target_positions - else: - self.positions[:num_tokens] = target_positions + self._set_positions(num_tokens, target_positions) self.hidden_states[:num_tokens] = target_hidden_states if self.is_multimodal_model: input_ids = self.input_ids[:num_tokens] @@ -230,10 +239,7 @@ def propose( self.vllm_config, num_tokens=num_input_tokens): # M-RoPE - if self.uses_mrope: - forward_positions = self.positions[:, :num_input_tokens] - else: - forward_positions = self.positions[:num_input_tokens] + forward_positions = self._get_positions(num_input_tokens) ret_hidden_states = self.model( input_ids=input_ids, positions=forward_positions, @@ -375,10 +381,7 @@ def propose( # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids # M-RoPE - if self.uses_mrope: - self.positions[:, :batch_size] = clamped_positions - else: - self.positions[:batch_size] = clamped_positions + self._set_positions(batch_size, clamped_positions) self.hidden_states[:batch_size] = hidden_states if self.is_multimodal_model: inputs_embeds = self.model.get_input_embeddings(input_ids) @@ -394,10 +397,7 @@ def propose( self.vllm_config, num_tokens=input_batch_size): # M-RoPE - if self.uses_mrope: - forward_positions = self.positions[:, :input_batch_size] - else: - forward_positions = self.positions[:input_batch_size] + forward_positions = self._get_positions(input_batch_size) ret_hidden_states = self.model( 
input_ids=input_ids, positions=forward_positions, @@ -934,10 +934,7 @@ def dummy_run( with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): # M-RoPE - if self.uses_mrope: - forward_positions = self.positions[:, :num_tokens] - else: - forward_positions = self.positions[:num_tokens] + forward_positions = self._get_positions(num_tokens) if self.is_multimodal_model: input_ids = None diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0493f828050c..cf4311f3cc63 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -439,6 +439,17 @@ def __init__( device="cpu", pin_memory=self.pin_memory) + # M-RoPE + def _get_positions(self, num_tokens: Any): + if isinstance(num_tokens, int): + if self.uses_mrope: + return self.mrope_positions.gpu[:, :num_tokens] + return self.positions.gpu[:num_tokens] + else: + if self.uses_mrope: + return self.mrope_positions.gpu[:, num_tokens] + return self.positions.gpu[num_tokens] + def _make_buffer(self, *size: Union[int, torch.SymInt], dtype: torch.dtype, @@ -2519,12 +2530,7 @@ def propose_draft_token_ids( # input_ids can be None for multimodal models. target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] # M-RoPE - if self.uses_mrope: - target_positions = \ - self.mrope_positions.gpu[:, :num_scheduled_tokens] - else: - target_positions = self.positions.gpu[: - num_scheduled_tokens] + target_positions = self._get_positions(num_scheduled_tokens) if self.use_aux_hidden_state_outputs: assert aux_hidden_states is not None target_hidden_states = torch.cat( @@ -2550,11 +2556,7 @@ def propose_draft_token_ids( target_token_ids = self.input_ids.gpu[token_indices] # M-RoPE - if self.uses_mrope: - target_positions = \ - self.mrope_positions.gpu[:, token_indices] - else: - target_positions = self.positions.gpu[token_indices] + target_positions = self._get_positions(token_indices) if self.use_aux_hidden_state_outputs: assert aux_hidden_states is not None target_hidden_states = torch.cat( From 6d2882f5880e77718738018af38d22561af4a364 Mon Sep 17 00:00:00 2001 From: Junhong Date: Sat, 27 Sep 2025 00:19:29 +0800 Subject: [PATCH 28/32] opt Signed-off-by: Junhong --- vllm/v1/spec_decode/eagle.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 6e75f3fc0341..c2480c579685 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -238,11 +238,9 @@ def propose( with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens): - # M-RoPE - forward_positions = self._get_positions(num_input_tokens) ret_hidden_states = self.model( input_ids=input_ids, - positions=forward_positions, + positions=self._get_positions(num_input_tokens), # M-RoPE hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=inputs_embeds, ) @@ -397,10 +395,9 @@ def propose( self.vllm_config, num_tokens=input_batch_size): # M-RoPE - forward_positions = self._get_positions(input_batch_size) ret_hidden_states = self.model( input_ids=input_ids, - positions=forward_positions, + positions=self._get_positions(input_batch_size), # M-RoPE hidden_states=self.hidden_states[:input_batch_size], inputs_embeds=inputs_embeds, ) @@ -933,9 +930,6 @@ def dummy_run( ) -> None: with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): - # M-RoPE - forward_positions = self._get_positions(num_tokens) - if self.is_multimodal_model: input_ids = None inputs_embeds = 
self.inputs_embeds[:num_tokens] @@ -945,7 +939,7 @@ def dummy_run( self.model( input_ids=input_ids, - positions=forward_positions, + positions=self._get_positions(num_tokens), # M-RoPE hidden_states=self.hidden_states[:num_tokens], inputs_embeds=inputs_embeds, ) From cbe0744b213034f4ac67c94d995cdb0cc3dac0ab Mon Sep 17 00:00:00 2001 From: Junhong Date: Sat, 27 Sep 2025 00:25:07 +0800 Subject: [PATCH 29/32] opt Signed-off-by: Junhong --- vllm/v1/spec_decode/eagle.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index c2480c579685..7c631184587b 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -80,7 +80,6 @@ def __init__( self.input_ids = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=device) - # M-RoPE self.uses_mrope = self.vllm_config.model_config.uses_mrope if self.uses_mrope: # M-RoPE need (3, max_num_tokens) @@ -152,7 +151,6 @@ def __init__( dtype=torch.int32, ).repeat(max_batch_size, 1) - # M-RoPE def _get_positions(self, num_tokens: int): if self.uses_mrope: return self.mrope_positions[:, :num_tokens] @@ -219,7 +217,6 @@ def propose( else: num_input_tokens = num_tokens # copy inputs to buffer for cudagraph - # M-RoPE self._set_positions(num_tokens, target_positions) self.hidden_states[:num_tokens] = target_hidden_states if self.is_multimodal_model: @@ -240,7 +237,7 @@ def propose( num_tokens=num_input_tokens): ret_hidden_states = self.model( input_ids=input_ids, - positions=self._get_positions(num_input_tokens), # M-RoPE + positions=self._get_positions(num_input_tokens), hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=inputs_embeds, ) @@ -258,7 +255,6 @@ def propose( draft_token_ids = logits.argmax(dim=-1) return draft_token_ids.view(-1, 1) - # M-RoPE if self.uses_mrope: positions = target_positions[:, last_token_indices] else: @@ -309,7 +305,6 @@ def propose( # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. input_ids = draft_token_ids_list[-1].int() - # M-RoPE if self.uses_mrope: positions += 1 # NOTE(woosuk): We should handle the case where the draft model @@ -345,7 +340,6 @@ def propose( common_attn_metadata.seq_lens_cpu - 1 # Compute the slot mapping. 
- # M-RoPE if self.uses_mrope: # all dimensions of positions are the same block_numbers = clamped_positions[0] // self.block_size @@ -354,7 +348,6 @@ def propose( block_ids = common_attn_metadata.block_table_tensor.gather( dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) - # M-RoPE if self.uses_mrope: common_attn_metadata.slot_mapping = ( block_ids * self.block_size + @@ -378,7 +371,6 @@ def propose( # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids - # M-RoPE self._set_positions(batch_size, clamped_positions) self.hidden_states[:batch_size] = hidden_states if self.is_multimodal_model: @@ -394,10 +386,9 @@ def propose( with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size): - # M-RoPE ret_hidden_states = self.model( input_ids=input_ids, - positions=self._get_positions(input_batch_size), # M-RoPE + positions=self._get_positions(input_batch_size), hidden_states=self.hidden_states[:input_batch_size], inputs_embeds=inputs_embeds, ) @@ -939,7 +930,7 @@ def dummy_run( self.model( input_ids=input_ids, - positions=self._get_positions(num_tokens), # M-RoPE + positions=self._get_positions(num_tokens), hidden_states=self.hidden_states[:num_tokens], inputs_embeds=inputs_embeds, ) From e5e29e09a40fd46fd497d334659e8c263450d242 Mon Sep 17 00:00:00 2001 From: Junhong Date: Sat, 27 Sep 2025 00:29:52 +0800 Subject: [PATCH 30/32] opt Signed-off-by: Junhong --- vllm/v1/worker/gpu_model_runner.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ee1e091cfc58..2b7e09a0dd88 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -439,7 +439,6 @@ def __init__( device="cpu", pin_memory=self.pin_memory) - # M-RoPE def _get_positions(self, num_tokens: Any): if isinstance(num_tokens, int): if self.uses_mrope: @@ -2538,7 +2537,6 @@ def propose_draft_token_ids( token_indices_to_sample = None # input_ids can be None for multimodal models. 
             target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
-            # M-RoPE
             target_positions = self._get_positions(num_scheduled_tokens)
             if self.use_aux_hidden_state_outputs:
                 assert aux_hidden_states is not None
@@ -2564,7 +2562,6 @@ def propose_draft_token_ids(
                         valid_sampled_tokens_count)

                 target_token_ids = self.input_ids.gpu[token_indices]
-                # M-RoPE
                 target_positions = self._get_positions(token_indices)
                 if self.use_aux_hidden_state_outputs:
                     assert aux_hidden_states is not None

From 4e24437b505442e70577e081fc705d943951ca95 Mon Sep 17 00:00:00 2001
From: Junhong
Date: Sat, 27 Sep 2025 00:39:59 +0800
Subject: [PATCH 31/32] fix pre-commit

Signed-off-by: Junhong

---
 vllm/v1/spec_decode/eagle.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 7c631184587b..7f7ee67bbe1e 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -84,8 +84,8 @@ def __init__(
         if self.uses_mrope:
             # M-RoPE need (3, max_num_tokens)
             self.mrope_positions = torch.zeros((3, self.max_num_tokens),
-                                         dtype=torch.int64,
-                                         device=device)
+                                               dtype=torch.int64,
+                                               device=device)
         else:
             # RoPE need (max_num_tokens,)
             self.positions = torch.zeros(self.max_num_tokens,

From c38907b4fa6c0bfa054f518a0b42a5e8fa67196c Mon Sep 17 00:00:00 2001
From: Junhong
Date: Sat, 27 Sep 2025 09:27:53 +0800
Subject: [PATCH 32/32] fix bug

Signed-off-by: Junhong

---
 tests/v1/e2e/test_spec_decode.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index c2b184041a74..c850c3391975 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -186,7 +186,9 @@ def test_eagle_correctness(
         m.setenv("VLLM_ROCM_USE_AITER", "1")

     method, model_name, spec_model_name, tp_size = model_setup
-
+    if "Qwen2.5-VL" in model_name:
+        pytest.skip("FLASH_ATTN does not support Qwen2.5-VL "
+                    "due to its head_dim not being a multiple of 32")
     ref_llm = LLM(model=model_name,
                   max_model_len=2048,
                   tensor_parallel_size=tp_size)
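
The position-buffer change introduced in PATCH 27 and polished through PATCH 31 boils down to one idea: M-RoPE indexes three rotary axes (temporal, height, width), so the drafter's position buffer needs a leading dimension of 3, while plain RoPE keeps a flat vector; the _get_positions/_set_positions helpers hide that layout difference from callers. The standalone sketch below mirrors that pattern outside vLLM; the PositionBuffer class name is illustrative only and not part of the patches.

import torch


class PositionBuffer:
    """Minimal stand-in for the RoPE/M-RoPE position buffers in eagle.py."""

    def __init__(self, max_num_tokens: int, uses_mrope: bool,
                 device: str = "cpu") -> None:
        self.uses_mrope = uses_mrope
        if uses_mrope:
            # M-RoPE needs (3, max_num_tokens): one row per rotary axis.
            self.positions = torch.zeros((3, max_num_tokens),
                                         dtype=torch.int64, device=device)
        else:
            # Plain RoPE only needs (max_num_tokens,).
            self.positions = torch.zeros(max_num_tokens,
                                         dtype=torch.int64, device=device)

    def get(self, num_tokens: int) -> torch.Tensor:
        # Always slice the token dimension, whatever the layout.
        return (self.positions[:, :num_tokens]
                if self.uses_mrope else self.positions[:num_tokens])

    def set(self, num_tokens: int, values: torch.Tensor) -> None:
        if self.uses_mrope:
            self.positions[:, :num_tokens] = values
        else:
            self.positions[:num_tokens] = values


buf = PositionBuffer(max_num_tokens=8, uses_mrope=True)
buf.set(4, torch.arange(4).repeat(3, 1))  # same index on all three axes
print(buf.get(4).shape)                   # torch.Size([3, 4])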
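
The MMStarDataset added to vllm/benchmarks/datasets.py in PATCH 23 (with the ad-hoc benchmarks_run_mmstar.py driver removed again in PATCH 26) keeps only the question text before "Options:" and passes the PIL image through the benchmark's multimodal path. The sketch below reproduces that preprocessing outside vLLM so it can be inspected in isolation; sample_mmstar is an illustrative name, and plain dicts stand in for vLLM's SampleRequest/process_image helpers.

from datasets import load_dataset


def sample_mmstar(num_requests: int, output_len: int = 128) -> list[dict]:
    # Same dataset and split the new class registers: Lin-Chen/MMStar, "val".
    data = load_dataset("Lin-Chen/MMStar")["val"]
    requests: list[dict] = []
    for ind, item in enumerate(data):
        if len(requests) >= num_requests:
            break
        # Drop the multiple-choice options, keeping only the question text.
        question = item["question"].split("Options:", 1)[0].strip()
        requests.append({
            "prompt": question,
            "image": item["image"].convert("RGB"),  # PIL image for the MM path
            "expected_output_len": output_len,
            "request_id": str(ind),
        })
    return requests


print(len(sample_mmstar(4)))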
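
With the series applied, the new e2e entries run Qwen/Qwen2.5-VL-7B-Instruct with the Rayzl/qwen2.5-vl-7b-eagle3-sgl draft. For a quick manual check outside pytest, an offline run along the lines below should work; the model names come from tests/v1/e2e/test_spec_decode.py and tests/models/registry.py, the speculative_config dictionary form of the LLM constructor is assumed from current vLLM usage, and num_speculative_tokens=3 is an arbitrary illustrative choice.

from vllm import LLM, SamplingParams

# Hypothetical offline run of Qwen2.5-VL with the EAGLE3 draft model.
spec_llm = LLM(
    model="Qwen/Qwen2.5-VL-7B-Instruct",
    max_model_len=2048,
    speculative_config={
        "method": "eagle3",
        "model": "Rayzl/qwen2.5-vl-7b-eagle3-sgl",
        "num_speculative_tokens": 3,
    },
)
outputs = spec_llm.generate(
    ["Describe speculative decoding in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)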