Merged

45 commits
5e784b0
EVS support for Qwen 2.5 VL
BloodAxe Aug 26, 2025
d5eb66e
Merge branch 'main' into feature/evs-support
BloodAxe Aug 26, 2025
f00efd7
Added missing sync op for mrope_positions buffer
BloodAxe Aug 26, 2025
93d212f
Optimize copy operation by copying only necessary part of positions
BloodAxe Aug 26, 2025
338f460
Make the default come from MultiModalConfig
BloodAxe Aug 29, 2025
9dba062
Avoid extra slicing operation when not using EVS
BloodAxe Aug 29, 2025
f440149
Fix support for DP in EVS
BloodAxe Aug 29, 2025
d94cb7f
Simplify SupportsMultiModalPruning interface:
BloodAxe Aug 29, 2025
5cee296
Added tests for EVS
BloodAxe Sep 1, 2025
6a1ace0
Added tests for EVS
BloodAxe Sep 1, 2025
d177d71
Merge branch 'main' into feature/evs-support
BloodAxe Sep 1, 2025
29bd7b8
Added check and update mrope positions for EVS model that actually us…
BloodAxe Sep 1, 2025
8d0285d
Merge remote-tracking branch 'origin/feature/evs-support' into featur…
BloodAxe Sep 1, 2025
a34818d
Remove method with unnecessary comments
BloodAxe Sep 1, 2025
9071216
Merge branch 'main' into feature/evs-support
BloodAxe Sep 2, 2025
89588f4
Merge branch 'main' into feature/evs-support
ywang96 Sep 6, 2025
456eb6b
Added NVidia copyright
BloodAxe Sep 15, 2025
e3af30e
Merge remote-tracking branch 'origin/feature/evs-support' into featur…
BloodAxe Sep 15, 2025
3408e79
Merge branch 'refs/heads/main' into feature/evs-support-clean
BloodAxe Sep 15, 2025
2db1293
Merge branch 'main' into feature/evs-support
BloodAxe Sep 15, 2025
1866d18
Revert vllm/_custom_ops.py
BloodAxe Sep 15, 2025
5e3bceb
Merge remote-tracking branch 'origin/feature/evs-support' into featur…
BloodAxe Sep 15, 2025
42c32c5
Fix EVS logic of computing number of retained tokens:
BloodAxe Sep 15, 2025
22c7f5a
Update with SPDX
BloodAxe Sep 15, 2025
82eeb7d
Bugfix: Wrong clamping
BloodAxe Sep 16, 2025
c1b8137
Merge branch 'refs/heads/main' into feature/evs-support-clean
BloodAxe Sep 17, 2025
d7eca9f
Update initialization of video_pruning_rate
BloodAxe Sep 17, 2025
17aab5b
Ensure we cast input video size to int to avoid inconsistencies in ro…
BloodAxe Sep 17, 2025
f32f76e
Merge branch 'main' into feature/evs-support
BloodAxe Sep 18, 2025
dd8f78d
Merge branch 'main' into feature/evs-support-clean
BloodAxe Sep 25, 2025
32ae043
Post-merge fixes for video_pruning_rate argument
BloodAxe Sep 25, 2025
6573e67
Post-merge fixes for video_pruning_rate argument
BloodAxe Sep 25, 2025
3b3677c
Post-merge fixes for video_pruning_rate argument
BloodAxe Sep 25, 2025
46d76ab
Merge branch 'main' into feature/evs-support
BloodAxe Sep 25, 2025
e05a485
Update docstrings to make CI happy
BloodAxe Sep 25, 2025
76eb839
Merge remote-tracking branch 'origin/feature/evs-support' into featur…
BloodAxe Sep 25, 2025
d33d15d
Update docstrings to make CI happy
BloodAxe Sep 25, 2025
783b80d
Merge branch 'main' into feature/evs-support
BloodAxe Sep 25, 2025
d3375d5
Merge branch 'main' into feature/evs-support-clean
BloodAxe Sep 25, 2025
2f00af4
Update docstrings to make CI happy
BloodAxe Sep 25, 2025
2d2e4a2
Merge branch 'main' into feature/evs-support
BloodAxe Sep 25, 2025
85a648b
Update docstrings to make CI happy
BloodAxe Sep 25, 2025
f0ba647
Merge remote-tracking branch 'origin/feature/evs-support' into featur…
BloodAxe Sep 25, 2025
6d35a7f
Fix initialization of is_multimodal_pruning_enabled
BloodAxe Sep 25, 2025
ca121c9
Merge branch 'main' into feature/evs-support
BloodAxe Sep 25, 2025
132 changes: 132 additions & 0 deletions tests/models/multimodal/generation/test_qwen2_5_vl.py
@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.multimodal.video import sample_frames_from_video

from ....conftest import VIDEO_ASSETS

models = ["Qwen/Qwen2.5-VL-3B-Instruct"]
target_dtype = "bfloat16"

VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"


def qwen2_5_vl_chat_template(*query):
return f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{''.join(query)}<|im_end|><|im_start|>assistant\n" # noqa: E501


VIDEO_PROMPTS = VIDEO_ASSETS.prompts({
"baby_reading":
qwen2_5_vl_chat_template(
VIDEO_PLACEHOLDER,
"Describe this video with a short sentence ",
"(no more than 20 words)",
),
})


@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75])
@pytest.mark.parametrize("num_frames", [16])
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
def test_qwen2_5_vl_evs_functionality(vllm_runner, video_assets, model,
video_pruning_rate: float,
num_frames: int, dtype: str,
max_tokens: int) -> None:
"""Test EVS (Efficient Video Sampling) functionality with different
pruning rates.
"""

# Sample frames from video assets
sampled_vids = [
sample_frames_from_video(asset.np_ndarrays, num_frames)
for asset in video_assets
]

prompts = [VIDEO_PROMPTS[0]]
videos = [sampled_vids[0]]

# Initialize model with EVS configuration
with vllm_runner(model,
runner="generate",
max_model_len=4000,
max_num_seqs=1,
dtype=dtype,
limit_mm_per_prompt={"video": 1},
tensor_parallel_size=1,
video_pruning_rate=video_pruning_rate) as vllm_model:

# Generate output - this should not crash
outputs = vllm_model.generate_greedy(prompts,
max_tokens,
videos=videos)

# Basic validation that we got a response
assert len(outputs) == 1
output_ids, output_text = outputs[0]

# Ensure we got some output
assert len(output_ids) > 0
assert len(output_text) > 0

# Ensure the output is a string
assert isinstance(output_text, str)


@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75])
@pytest.mark.parametrize("num_frames", [16])
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
def test_qwen2_5_vl_evs_batched_videos(vllm_runner, video_assets, model,
video_pruning_rate: float,
num_frames: int, dtype: str,
max_tokens: int) -> None:
"""Test EVS functionality with batched videos.

This test validates that:
1. The model handles batched video inputs correctly with EVS
2. Both pruning configurations work with multiple videos
3. The model doesn't crash when processing multiple videos simultaneously
"""
# Sample frames from video assets
sampled_vids = [
sample_frames_from_video(asset.np_ndarrays, num_frames)
for asset in video_assets
]

# Test batched videos
prompts = [VIDEO_PROMPTS[0], VIDEO_PROMPTS[0]]
videos = [sampled_vids[0],
sampled_vids[0]] # Use same video twice for testing

# Initialize model with EVS configuration
with vllm_runner(model,
runner="generate",
max_model_len=4000,
max_num_seqs=2,
dtype=dtype,
limit_mm_per_prompt={"video": 2},
tensor_parallel_size=1,
video_pruning_rate=video_pruning_rate) as vllm_model:

# Generate output - this should not crash
outputs = vllm_model.generate_greedy(prompts,
max_tokens,
videos=videos)

# Basic validation that we got responses for both videos
assert len(outputs) == 2

for output_ids, output_text in outputs:
# Ensure we got some output for each video
assert len(output_ids) > 0
assert len(output_text) > 0

# Ensure the output is a string
assert isinstance(output_text, str)
27 changes: 16 additions & 11 deletions vllm/config/model.py
@@ -283,6 +283,7 @@ class ModelConfig:
mm_encoder_tp_mode: InitVar[Optional[MMEncoderTPMode]] = None
interleave_mm_strings: InitVar[Optional[bool]] = None
skip_mm_profiling: InitVar[Optional[bool]] = None
video_pruning_rate: InitVar[Optional[float]] = None

def compute_hash(self) -> str:
"""
@@ -311,6 +312,7 @@ def compute_hash(self) -> str:
factors.append(self.override_generation_config)
factors.append(self.rope_scaling)
factors.append(self.rope_theta)
factors.append(self.video_pruning_rate)

# hf_config can control how the model looks!
try:
@@ -338,17 +340,19 @@ def compute_hash(self) -> str:
return hashlib.sha256(str(factors).encode()).hexdigest()

def __post_init__(
self,
# Multimodal config init vars
limit_mm_per_prompt: Optional[dict[str, int]],
media_io_kwargs: Optional[dict[str, dict[str, Any]]],
mm_processor_kwargs: Optional[dict[str, Any]],
mm_processor_cache_gb: Optional[float],
mm_processor_cache_type: Optional[MMCacheType],
mm_shm_cache_max_object_size_mb: Optional[int],
mm_encoder_tp_mode: Optional[MMEncoderTPMode],
interleave_mm_strings: Optional[bool],
skip_mm_profiling: Optional[bool]) -> None:
self,
# Multimodal config init vars
limit_mm_per_prompt: Optional[dict[str, int]],
media_io_kwargs: Optional[dict[str, dict[str, Any]]],
mm_processor_kwargs: Optional[dict[str, Any]],
mm_processor_cache_gb: Optional[float],
mm_processor_cache_type: Optional[MMCacheType],
mm_shm_cache_max_object_size_mb: Optional[int],
mm_encoder_tp_mode: Optional[MMEncoderTPMode],
interleave_mm_strings: Optional[bool],
skip_mm_profiling: Optional[bool],
video_pruning_rate: Optional[float],
) -> None:
# Set the default seed to 0 in V1.
# NOTE(woosuk): In V0, we set the default seed to None because the
# driver worker shares the same process as the user process, and thus
@@ -612,6 +616,7 @@ def _task_to_convert(task: TaskOption) -> ConvertType:
mm_encoder_tp_mode=mm_encoder_tp_mode,
interleave_mm_strings=interleave_mm_strings,
skip_mm_profiling=skip_mm_profiling,
video_pruning_rate=video_pruning_rate,
)

mm_config_kwargs = {
9 changes: 9 additions & 0 deletions vllm/config/multimodal.py
@@ -78,6 +78,11 @@ class MultiModalConfig:
This reduces engine startup time but shifts the responsibility to users for
estimating the peak memory usage of the activation of multimodal encoder and
embedding cache."""
video_pruning_rate: Optional[float] = None
"""Sets pruning rate for video pruning via Efficient Video Sampling.
Value sits in range [0;1) and determines fraction of media tokens
from each video to be pruned.
"""

def compute_hash(self) -> str:
"""
@@ -118,3 +123,7 @@ def merge_mm_processor_kwargs(
"""
kwargs = self.mm_processor_kwargs or {}
return kwargs | dict(inference_kwargs)

def is_multimodal_pruning_enabled(self):
return (self.video_pruning_rate is not None
and self.video_pruning_rate > 0)
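As a reading aid for the `video_pruning_rate` docstring above, here is a minimal sketch of how a pruning rate maps to a number of retained video tokens. The helper name and the rounding/clamping behaviour are illustrative assumptions, not the exact EVS logic adjusted by the "Fix EVS logic of computing number of retained tokens" and "Bugfix: Wrong clamping" commits.

```python
# Illustrative only: how a pruning rate in [0, 1) maps to retained tokens.
# The rounding and clamping here are assumptions, not the PR's exact logic.
def retained_video_tokens(total_video_tokens: int,
                          video_pruning_rate: float) -> int:
    if not 0.0 <= video_pruning_rate < 1.0:
        raise ValueError("video_pruning_rate must lie in [0, 1)")
    # Keep at least one token so a video is never pruned away entirely.
    return max(1, int(total_video_tokens * (1.0 - video_pruning_rate)))


assert retained_video_tokens(1024, 0.0) == 1024  # pruning disabled
assert retained_video_tokens(1024, 0.75) == 256  # ~75% of video tokens pruned
```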
5 changes: 5 additions & 0 deletions vllm/engine/arg_utils.py
@@ -391,6 +391,7 @@ class EngineArgs:
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
io_processor_plugin: Optional[str] = None
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
video_pruning_rate: Optional[float] = MultiModalConfig.video_pruning_rate
# LoRA fields
enable_lora: bool = False
enable_lora_bias: bool = LoRAConfig.bias_enabled
@@ -813,6 +814,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
multimodal_group.add_argument("--skip-mm-profiling",
**multimodal_kwargs["skip_mm_profiling"])

multimodal_group.add_argument(
"--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"])

# LoRA related configs
lora_kwargs = get_kwargs(LoRAConfig)
lora_group = parser.add_argument_group(
@@ -1032,6 +1036,7 @@ def create_model_config(self) -> ModelConfig:
model_impl=self.model_impl,
override_attention_dtype=self.override_attention_dtype,
logits_processors=self.logits_processors,
video_pruning_rate=self.video_pruning_rate,
io_processor_plugin=self.io_processor_plugin,
)

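For completeness, a hedged sketch of how the new engine argument can be exercised offline. It assumes that `LLM(**kwargs)` forwards `video_pruning_rate` to `EngineArgs`, mirroring the `vllm_runner(..., video_pruning_rate=...)` usage in the tests above; the served equivalent would be the `--video-pruning-rate` flag added to the multimodal argument group.

```python
# Sketch (assumption): enabling EVS pruning when building an engine offline,
# mirroring how the tests pass video_pruning_rate through vllm_runner.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    max_model_len=4000,
    limit_mm_per_prompt={"video": 1},
    video_pruning_rate=0.75,  # prune ~75% of the video tokens per video
)

sampling = SamplingParams(temperature=0.0, max_tokens=128)
```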
55 changes: 55 additions & 0 deletions vllm/model_executor/models/interfaces.py
@@ -115,6 +115,42 @@ def get_input_embeddings(
...


@runtime_checkable
class SupportsMultiModalPruning(Protocol):
"""The interface required for models that support returning both input
embeddings and positions. Model may require custom positions for dynamic
pruning of multimodal embeddings.
"""
supports_multimodal_pruning: ClassVar[Literal[True]] = True

def recompute_mrope_positions(
self, input_ids: list[int],
multimodal_embeddings: MultiModalEmbeddings,
mrope_positions: torch.LongTensor, num_computed_tokens: int
) -> tuple[MultiModalEmbeddings, Tensor, int]:
"""
Update the input mrope positions from index num_computed_tokens
onward. The original mrope_positions are computed for the unpruned
sequence and become incorrect once pruning occurs, so after media
tokens are pruned the change must be reflected in mrope_positions
before they are fed to the LLM.

Args:
input_ids: (N,) All input tokens of the prompt, covering the
entire sequence.
multimodal_embeddings: Tuple of multimodal embeddings that
fit into the prefill chunk being processed.
mrope_positions: Existing mrope positions (3, N) for the
entire sequence.
num_computed_tokens: The number of tokens computed so far.

Returns:
Tuple of (multimodal_embeddings, mrope_positions,
mrope_position_delta).
"""
...


@overload
def supports_multimodal(
model: type[object]) -> TypeIs[type[SupportsMultiModal]]:
@@ -142,6 +178,25 @@ def supports_multimodal_encoder_tp_data(
return getattr(model, "supports_encoder_tp_data", False)


@overload
def supports_multimodal_pruning(
model: type[object]) -> TypeIs[type[SupportsMultiModalPruning]]:
...


@overload
def supports_multimodal_pruning(
model: object) -> TypeIs[SupportsMultiModalPruning]:
...


def supports_multimodal_pruning(
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsMultiModalPruning]],
TypeIs[SupportsMultiModalPruning]]:
return getattr(model, "supports_multimodal_pruning", False)


@runtime_checkable
class SupportsScoreTemplate(Protocol):
"""The interface required for all models that support score template."""
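To make the detection pattern above concrete, here is a small self-contained sketch of how `supports_multimodal_pruning` behaves: any class (or instance) carrying the `supports_multimodal_pruning` class flag is reported as supporting pruning. `DummyEVSModel` and `DummyPlainModel` are hypothetical stand-ins, and the helper is re-declared locally so the snippet runs without vLLM installed.

```python
# Self-contained illustration of the ClassVar-flag check used by
# supports_multimodal_pruning(); the Dummy* classes are hypothetical.
from typing import ClassVar, Literal, Union


def supports_multimodal_pruning(model: Union[type, object]) -> bool:
    # Same shape as the helper in the diff: read the class-level flag.
    return getattr(model, "supports_multimodal_pruning", False)


class DummyEVSModel:
    supports_multimodal_pruning: ClassVar[Literal[True]] = True

    def recompute_mrope_positions(self, input_ids, multimodal_embeddings,
                                  mrope_positions, num_computed_tokens):
        ...  # a real model re-derives positions for the pruned sequence


class DummyPlainModel:
    pass


assert supports_multimodal_pruning(DummyEVSModel)    # works on classes
assert supports_multimodal_pruning(DummyEVSModel())  # and on instances
assert not supports_multimodal_pruning(DummyPlainModel)
```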