diff --git a/tests/models/registry.py b/tests/models/registry.py index 6a6e2538559f..e321acc873c6 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -651,6 +651,9 @@ def check_available_online( "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True, speculative_model="XiaomiMiMo/MiMo-7B-RL"), + "Eagle3Qwen2_5vlForCausalLM": _HfExamplesInfo( + "Qwen/Qwen2.5-VL-7B-Instruct", + speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl"), "Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", min_transformers_version="4.56.3"), } diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index ea8d94722859..8f048775352e 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -129,6 +129,11 @@ def test_ngram_correctness( ["model_setup", "mm_enabled"], [ (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), + pytest.param(("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct", + "Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1), + False, + marks=pytest.mark.skip(reason="Skipping due to its " \ + "head_dim not being a a multiple of 32")), (("eagle", "meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", @@ -145,8 +150,8 @@ def test_ngram_correctness( "eagle618/eagle-deepseek-v3-random", 1), False), ], ids=[ - "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", "llama4_eagle", - "llama4_eagle_mm", "deepseek_eagle" + "qwen3_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", "llama3_eagle3", + "llama4_eagle", "llama4_eagle_mm", "deepseek_eagle" ]) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 68a937d5750e..f0c0d829a393 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -1450,6 +1450,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: ): dataset_class = MLPerfDataset args.hf_split = "train" + elif ( + args.dataset_path in MMStarDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in MMStarDataset.SUPPORTED_DATASET_PATHS + ): + dataset_class = MMStarDataset + args.hf_split = "val" + args.hf_subset = None else: supported_datasets = set([ dataset_name for cls in HuggingFaceDataset.__subclasses__() @@ -2721,3 +2728,76 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]: random.shuffle(requests) return requests + + +# ----------------------------------------------------------------------------- +# MMStar Dataset Implementation +# ----------------------------------------------------------------------------- + + +class MMStarDataset(HuggingFaceDataset): + """ + Lin-Chen/MMStar: https://huggingface.co/datasets/Lin-Chen/MMStar + refer to: https://github.com/sgl-project/SpecForge/pull/106 + """ + DEFAULT_OUTPUT_LEN = 128 + SUPPORTED_DATASET_PATHS = {"Lin-Chen/MMStar"} + IS_MULTIMODAL = True + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + request_id_prefix: str = "", + no_oversample: bool = False, + **kwargs, + ) -> list[SampleRequest]: + # If --hf-output-len is not set, use the default output length. 
+ output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests: list[SampleRequest] = [] + + for ind, item in enumerate(self.data): + if len(sampled_requests) >= num_requests: + break + # Split the question text from options + # (keep only the part before "Options:"). + full_q: str = item.get("question", "") + question_text = full_q.split("Options:", 1)[0].strip() + + # Multimodal image content. + mm_content = process_image(item["image"]) + + # Compute prompt token length (note: this is plain text length + # if enable_multimodal_chat is False). + prompt_len = len(tokenizer(question_text).input_ids) + + if enable_multimodal_chat: + # If multimodal content should be embedded in the chat message, + # convert to [{"role":"user","content":[...]}] + prompt = self.apply_multimodal_chat_transformation( + question_text, mm_content + ) + mm_for_request = None # Already embedded in chat content. + else: + # Default: prompt is plain text, + # image is in mm_content for the bench to assemble. + prompt = question_text + mm_for_request = mm_content + + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_for_request, + request_id=request_id_prefix + str(ind), + ) + ) + + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) + return sampled_requests diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index b99a1547918e..55b6ae6ee0e9 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -8,7 +8,6 @@ import torch.nn as nn from transformers import LlamaConfig -from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm @@ -19,6 +18,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaForCausalLM) @@ -102,7 +102,6 @@ def forward( return hidden_states, residual -@support_torch_compile class LlamaModel(nn.Module): def __init__( @@ -145,13 +144,21 @@ def __init__( eps=self.config.rms_norm_eps, ) + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, + input_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - input_embeds = self.embed_tokens(input_ids) + if input_embeds is None: + input_embeds = self.get_input_embeddings(input_ids) assert hidden_states.shape[-1] == input_embeds.shape[-1] residual = None @@ -239,11 +246,7 @@ def forward( hidden_states: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - if inputs_embeds is not None: - raise NotImplementedError( - f"{type(self).__name__} does not support multimodal inputs yet." 
- ) - return self.model(input_ids, positions, hidden_states) + return self.model(input_ids, positions, hidden_states, inputs_embeds) def compute_logits( self, @@ -299,3 +302,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): skip_substrs=skip_substrs, ) loader.load_weights(model_weights.items()) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + return inputs_embeds diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 6af6faa2b296..3199f53a0539 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -68,7 +68,7 @@ from vllm.utils import is_pin_memory_available from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, +from .interfaces import (MultiModalEmbeddings, SupportsEagle3, SupportsLoRA, SupportsMultiModal, SupportsMultiModalPruning, SupportsPP, SupportsQuant) from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder @@ -965,7 +965,7 @@ def get_replacement_qwen2vl(item_idx: int, modality: str): dummy_inputs=Qwen2_5_VLDummyInputsBuilder) class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, - SupportsQuant, + SupportsQuant, SupportsEagle3, SupportsMultiModalPruning): packed_modules_mapping = { @@ -1028,6 +1028,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.language_model.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.language_model.model.layers) + return (2, num_layers // 2, num_layers - 3) + def _validate_and_reshape_mm_tensor(self, mm_input: object, name: str) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 10e9aa4db078..0471164ab8a6 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -286,6 +286,7 @@ "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), + "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 394df48b4153..51e54e0dc337 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -80,9 +80,17 @@ def __init__( self.input_ids = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=device) - self.positions = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=device) + self.uses_mrope = self.vllm_config.model_config.uses_mrope + if self.uses_mrope: + # M-RoPE need (3, max_num_tokens) + self.mrope_positions = torch.zeros((3, self.max_num_tokens), + dtype=torch.int64, + device=device) + else: + # RoPE need (max_num_tokens,) + self.positions = 
torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device=device) self.hidden_states = torch.zeros( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, @@ -143,11 +151,22 @@ def __init__( dtype=torch.int32, ).repeat(max_batch_size, 1) + def _get_positions(self, num_tokens: int): + if self.uses_mrope: + return self.mrope_positions[:, :num_tokens] + return self.positions[:num_tokens] + + def _set_positions(self, num_tokens: int, positions: torch.Tensor): + if self.uses_mrope: + self.mrope_positions[:, :num_tokens] = positions + else: + self.positions[:num_tokens] = positions + def propose( self, # [num_tokens] target_token_ids: torch.Tensor, - # [num_tokens] + # [num_tokens] or [3, num_tokens] when M-RoPE is enabled target_positions: torch.Tensor, # [num_tokens, hidden_size] target_hidden_states: torch.Tensor, @@ -198,7 +217,7 @@ def propose( else: num_input_tokens = num_tokens # copy inputs to buffer for cudagraph - self.positions[:num_tokens] = target_positions + self._set_positions(num_tokens, target_positions) self.hidden_states[:num_tokens] = target_hidden_states if self.is_multimodal_model: input_ids = self.input_ids[:num_tokens] @@ -218,7 +237,7 @@ def propose( num_tokens=num_input_tokens): ret_hidden_states = self.model( input_ids=input_ids, - positions=self.positions[:num_input_tokens], + positions=self._get_positions(num_input_tokens), hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=inputs_embeds, ) @@ -235,7 +254,10 @@ def propose( draft_token_ids = logits.argmax(dim=-1) return draft_token_ids.view(-1, 1) - positions = target_positions[last_token_indices] + if self.uses_mrope: + positions = target_positions[:, last_token_indices] + else: + positions = target_positions[last_token_indices] if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"): hidden_states = self.hidden_states[last_token_indices] else: @@ -282,25 +304,34 @@ def propose( # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. input_ids = draft_token_ids_list[-1].int() - positions += 1 - - # NOTE(woosuk): We should handle the case where the draft model - # generates tokens beyond the max model length. Since it is complex - # to remove such requests from the batch, we keep them in the batch - # but adjust the position ids and slot mappings to avoid the - # out-of-range access during the model execution. The draft tokens - # generated with this adjustment should be ignored. - exceeds_max_model_len = positions >= self.max_model_len - # Mask out the position ids that exceed the max model length. - # Otherwise, we may get out-of-range error in RoPE. - clamped_positions = torch.where(exceeds_max_model_len, 0, - positions) + if self.uses_mrope: + positions += 1 + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. + # Since it is complex to remove such requests from the batch, + # we keep them in the batch but adjust the position ids + # and slot mappings to avoid the + # out-of-range access during the model execution. + # The draft tokens generated with this adjustment + # should be ignored. + exceeds_max_model_len = positions[0] >= self.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. 
+ clamped_positions = torch.where\ + (exceeds_max_model_len.unsqueeze(0), \ + torch.zeros_like(positions), positions) + else: + positions += 1 + exceeds_max_model_len = positions >= self.max_model_len + clamped_positions = torch.where(exceeds_max_model_len, 0, + positions) # Increment the sequence lengths. common_attn_metadata.seq_lens += 1 common_attn_metadata.seq_lens_cpu += 1 # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. + common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) @@ -308,13 +339,22 @@ def propose( common_attn_metadata.seq_lens_cpu - 1 # Compute the slot mapping. - block_numbers = clamped_positions // self.block_size + if self.uses_mrope: + # all dimensions of positions are the same + block_numbers = clamped_positions[0] // self.block_size + else: + block_numbers = clamped_positions // self.block_size block_ids = common_attn_metadata.block_table_tensor.gather( dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) - common_attn_metadata.slot_mapping = ( - block_ids * self.block_size + - clamped_positions % self.block_size) + if self.uses_mrope: + common_attn_metadata.slot_mapping = ( + block_ids * self.block_size + + clamped_positions[0] % self.block_size) + else: + common_attn_metadata.slot_mapping = ( + block_ids * self.block_size + + clamped_positions % self.block_size) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. @@ -330,7 +370,7 @@ def propose( # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids - self.positions[:batch_size] = clamped_positions + self._set_positions(batch_size, clamped_positions) self.hidden_states[:batch_size] = hidden_states if self.is_multimodal_model: inputs_embeds = self.model.get_input_embeddings(input_ids) @@ -347,7 +387,7 @@ def propose( num_tokens=input_batch_size): ret_hidden_states = self.model( input_ids=input_ids, - positions=self.positions[:input_batch_size], + positions=self._get_positions(input_batch_size), hidden_states=self.hidden_states[:input_batch_size], inputs_embeds=inputs_embeds, ) @@ -787,6 +827,11 @@ def prepare_inputs( return spec_common_attn_metadata, token_indices + def get_model_name(self, model: nn.Module) -> str: + if hasattr(model, 'module'): # multi-GPU + model = model.module + return model.__class__.__name__ + def load_model(self, target_model: nn.Module) -> None: draft_model_config = \ self.vllm_config.speculative_config.draft_model_config @@ -820,8 +865,13 @@ def load_model(self, target_model: nn.Module) -> None: if supports_multimodal(target_model): # handle multimodality - self.model.config.image_token_index = ( - target_model.config.image_token_index) + if (self.get_model_name(target_model) == + "Qwen2_5_VLForConditionalGeneration"): + self.model.config.image_token_index = ( + target_model.config.image_token_id) + else: + self.model.config.image_token_index = ( + target_model.config.image_token_index) target_language_model = target_model.get_language_model() else: target_language_model = target_model @@ -892,7 +942,7 @@ def dummy_run( self.model( input_ids=input_ids, - positions=self.positions[:num_tokens], + positions=self._get_positions(num_tokens), hidden_states=self.hidden_states[:num_tokens], inputs_embeds=inputs_embeds, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f87a327d02a5..22a177dd7cc7 100644 --- 
a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -442,6 +442,16 @@ def __init__( device="cpu", pin_memory=self.pin_memory) + def _get_positions(self, num_tokens: Any): + if isinstance(num_tokens, int): + if self.uses_mrope: + return self.mrope_positions.gpu[:, :num_tokens] + return self.positions.gpu[:num_tokens] + else: + if self.uses_mrope: + return self.mrope_positions.gpu[:, num_tokens] + return self.positions.gpu[num_tokens] + def _make_buffer(self, *size: Union[int, torch.SymInt], dtype: torch.dtype, @@ -2544,8 +2554,7 @@ def propose_draft_token_ids( token_indices_to_sample = None # input_ids can be None for multimodal models. target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] - # TODO(woosuk): Support M-RoPE. - target_positions = self.positions.gpu[:num_scheduled_tokens] + target_positions = self._get_positions(num_scheduled_tokens) if self.use_aux_hidden_state_outputs: assert aux_hidden_states is not None target_hidden_states = torch.cat( @@ -2570,8 +2579,7 @@ def propose_draft_token_ids( valid_sampled_tokens_count) target_token_ids = self.input_ids.gpu[token_indices] - # TODO(woosuk): Support M-RoPE. - target_positions = self.positions.gpu[token_indices] + target_positions = self._get_positions(token_indices) if self.use_aux_hidden_state_outputs: assert aux_hidden_states is not None target_hidden_states = torch.cat(
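
A quick way to exercise the path this diff adds, outside the currently skipped "qwen2_5_vl_eagle3" e2e case, is to pair the Qwen2.5-VL target with the EAGLE3 draft head through the offline LLM API. The sketch below is illustrative and not part of the diff: the model names come from the new test matrix above, but the speculative_config dict form is assumed to match what recent vLLM releases accept, and num_speculative_tokens=3 is an arbitrary example value.

from vllm import LLM, SamplingParams

# Target and draft models taken from the test matrix added in this diff;
# the config keys and the num_speculative_tokens value are illustrative.
llm = LLM(
    model="Qwen/Qwen2.5-VL-7B-Instruct",
    speculative_config={
        "method": "eagle3",
        "model": "Rayzl/qwen2.5-vl-7b-eagle3-sgl",
        "num_speculative_tokens": 3,  # example value, not taken from the diff
    },
)

# A text-only prompt still goes through the M-RoPE position buffers added
# in eagle.py, since Qwen2.5-VL uses M-RoPE regardless of image inputs.
outputs = llm.generate(
    ["Explain briefly why speculative decoding reduces generation latency."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)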