
Commit 404c0e7

Merge branch 'main' into unique_filepath
2 parents: 0ee6d16 + dd70437

File tree

16 files changed: +1024 −100 lines


examples/offline_inference/basic/chat.py

Lines changed: 1 addition & 0 deletions

@@ -87,6 +87,7 @@ def print_outputs(outputs):
         use_tqdm=False,
         chat_template=chat_template,
     )
+    print_outputs(outputs)


 if __name__ == "__main__":
Lines changed: 132 additions & 0 deletions (new file)

@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.multimodal.video import sample_frames_from_video

from ....conftest import VIDEO_ASSETS

models = ["Qwen/Qwen2.5-VL-3B-Instruct"]
target_dtype = "bfloat16"

VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"


def qwen2_5_vl_chat_template(*query):
    return f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{''.join(query)}<|im_end|><|im_start|>assistant\n"  # noqa: E501


VIDEO_PROMPTS = VIDEO_ASSETS.prompts({
    "baby_reading":
    qwen2_5_vl_chat_template(
        VIDEO_PLACEHOLDER,
        "Describe this video with a short sentence ",
        "(no more than 20 words)",
    ),
})


@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75])
@pytest.mark.parametrize("num_frames", [16])
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
def test_qwen2_5_vl_evs_functionality(vllm_runner, video_assets, model,
                                      video_pruning_rate: float,
                                      num_frames: int, dtype: str,
                                      max_tokens: int) -> None:
    """Test EVS (Efficient Video Sampling) functionality with different
    pruning rates.
    """

    # Sample frames from video assets
    sampled_vids = [
        sample_frames_from_video(asset.np_ndarrays, num_frames)
        for asset in video_assets
    ]

    prompts = [VIDEO_PROMPTS[0]]
    videos = [sampled_vids[0]]

    # Initialize model with EVS configuration
    with vllm_runner(model,
                     runner="generate",
                     max_model_len=4000,
                     max_num_seqs=1,
                     dtype=dtype,
                     limit_mm_per_prompt={"video": 1},
                     tensor_parallel_size=1,
                     video_pruning_rate=video_pruning_rate) as vllm_model:

        # Generate output - this should not crash
        outputs = vllm_model.generate_greedy(prompts,
                                             max_tokens,
                                             videos=videos)

        # Basic validation that we got a response
        assert len(outputs) == 1
        output_ids, output_text = outputs[0]

        # Ensure we got some output
        assert len(output_ids) > 0
        assert len(output_text) > 0

        # Ensure the output is a string
        assert isinstance(output_text, str)


@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75])
@pytest.mark.parametrize("num_frames", [16])
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
def test_qwen2_5_vl_evs_batched_videos(vllm_runner, video_assets, model,
                                       video_pruning_rate: float,
                                       num_frames: int, dtype: str,
                                       max_tokens: int) -> None:
    """Test EVS functionality with batched videos.

    This test validates that:
    1. The model handles batched video inputs correctly with EVS
    2. Both pruning configurations work with multiple videos
    3. The model doesn't crash when processing multiple videos simultaneously
    """
    # Sample frames from video assets
    sampled_vids = [
        sample_frames_from_video(asset.np_ndarrays, num_frames)
        for asset in video_assets
    ]

    # Test batched videos
    prompts = [VIDEO_PROMPTS[0], VIDEO_PROMPTS[0]]
    videos = [sampled_vids[0],
              sampled_vids[0]]  # Use same video twice for testing

    # Initialize model with EVS configuration
    with vllm_runner(model,
                     runner="generate",
                     max_model_len=4000,
                     max_num_seqs=2,
                     dtype=dtype,
                     limit_mm_per_prompt={"video": 2},
                     tensor_parallel_size=1,
                     video_pruning_rate=video_pruning_rate) as vllm_model:

        # Generate output - this should not crash
        outputs = vllm_model.generate_greedy(prompts,
                                             max_tokens,
                                             videos=videos)

        # Basic validation that we got responses for both videos
        assert len(outputs) == 2

        for output_ids, output_text in outputs:
            # Ensure we got some output for each video
            assert len(output_ids) > 0
            assert len(output_text) > 0

            # Ensure the output is a string
            assert isinstance(output_text, str)
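
For context on how this option reaches users: the video_pruning_rate engine argument exercised by these tests (and plumbed through the config and argument-parsing diffs below) can also be passed to the offline LLM entry point, which forwards extra keyword arguments to EngineArgs. A minimal sketch, assuming the Qwen2.5-VL model above and synthetic frames in place of a real clip; the prompt string and values are illustrative, not part of this commit:

import numpy as np

from vllm import LLM, SamplingParams

# Sketch: prune 75% of each video's tokens via Efficient Video Sampling.
# The keyword is forwarded to EngineArgs.video_pruning_rate; 0.0 disables pruning.
llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    max_model_len=4000,
    limit_mm_per_prompt={"video": 1},
    video_pruning_rate=0.75,
)

# 16 synthetic frames stand in for a real video here.
frames = np.random.randint(0, 255, size=(16, 224, 224, 3), dtype=np.uint8)
prompt = ("<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
          "Describe this video.<|im_end|>\n<|im_start|>assistant\n")

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"video": frames}},
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)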

vllm/config/model.py

Lines changed: 16 additions & 11 deletions

@@ -283,6 +283,7 @@ class ModelConfig:
     mm_encoder_tp_mode: InitVar[Optional[MMEncoderTPMode]] = None
     interleave_mm_strings: InitVar[Optional[bool]] = None
     skip_mm_profiling: InitVar[Optional[bool]] = None
+    video_pruning_rate: InitVar[Optional[float]] = None

     def compute_hash(self) -> str:
         """
@@ -311,6 +312,7 @@ def compute_hash(self) -> str:
         factors.append(self.override_generation_config)
         factors.append(self.rope_scaling)
         factors.append(self.rope_theta)
+        factors.append(self.video_pruning_rate)

         # hf_config can control how the model looks!
         try:
@@ -338,17 +340,19 @@ def compute_hash(self) -> str:
         return hashlib.sha256(str(factors).encode()).hexdigest()

     def __post_init__(
-            self,
-            # Multimodal config init vars
-            limit_mm_per_prompt: Optional[dict[str, int]],
-            media_io_kwargs: Optional[dict[str, dict[str, Any]]],
-            mm_processor_kwargs: Optional[dict[str, Any]],
-            mm_processor_cache_gb: Optional[float],
-            mm_processor_cache_type: Optional[MMCacheType],
-            mm_shm_cache_max_object_size_mb: Optional[int],
-            mm_encoder_tp_mode: Optional[MMEncoderTPMode],
-            interleave_mm_strings: Optional[bool],
-            skip_mm_profiling: Optional[bool]) -> None:
+        self,
+        # Multimodal config init vars
+        limit_mm_per_prompt: Optional[dict[str, int]],
+        media_io_kwargs: Optional[dict[str, dict[str, Any]]],
+        mm_processor_kwargs: Optional[dict[str, Any]],
+        mm_processor_cache_gb: Optional[float],
+        mm_processor_cache_type: Optional[MMCacheType],
+        mm_shm_cache_max_object_size_mb: Optional[int],
+        mm_encoder_tp_mode: Optional[MMEncoderTPMode],
+        interleave_mm_strings: Optional[bool],
+        skip_mm_profiling: Optional[bool],
+        video_pruning_rate: Optional[float],
+    ) -> None:
         # Set the default seed to 0 in V1.
         # NOTE(woosuk): In V0, we set the default seed to None because the
         # driver worker shares the same process as the user process, and thus
@@ -612,6 +616,7 @@ def _task_to_convert(task: TaskOption) -> ConvertType:
             mm_encoder_tp_mode=mm_encoder_tp_mode,
             interleave_mm_strings=interleave_mm_strings,
             skip_mm_profiling=skip_mm_profiling,
+            video_pruning_rate=video_pruning_rate,
         )

         mm_config_kwargs = {
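
Including video_pruning_rate in the factors list means two otherwise identical model configs that differ only in pruning rate produce different compute_hash() values, so anything keyed on the config hash is not reused across them. A toy illustration of that pattern with made-up factor values (not the actual implementation):

import hashlib

def toy_config_hash(video_pruning_rate):
    # Every behaviour-affecting field is appended to the factors list,
    # mirroring the compute_hash() pattern in the diff above.
    factors = []
    factors.append({"rope_type": "default"})  # stands in for rope_scaling
    factors.append(10000.0)                   # stands in for rope_theta
    factors.append(video_pruning_rate)
    return hashlib.sha256(str(factors).encode()).hexdigest()

assert toy_config_hash(None) != toy_config_hash(0.75)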

vllm/config/multimodal.py

Lines changed: 9 additions & 0 deletions

@@ -78,6 +78,11 @@ class MultiModalConfig:
     This reduces engine startup time but shifts the responsibility to users for
     estimating the peak memory usage of the activation of multimodal encoder and
     embedding cache."""
+    video_pruning_rate: Optional[float] = None
+    """Sets pruning rate for video pruning via Efficient Video Sampling.
+    Value sits in range [0;1) and determines fraction of media tokens
+    from each video to be pruned.
+    """

     def compute_hash(self) -> str:
         """
@@ -118,3 +123,7 @@ def merge_mm_processor_kwargs(
         """
         kwargs = self.mm_processor_kwargs or {}
         return kwargs | dict(inference_kwargs)
+
+    def is_multimodal_pruning_enabled(self):
+        return (self.video_pruning_rate is not None
+                and self.video_pruning_rate > 0)
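
The new is_multimodal_pruning_enabled() predicate treats both the default None and an explicit 0.0 as disabled; only a strictly positive rate turns pruning on. A quick sketch of the expected behaviour, constructing the config directly purely for illustration:

from vllm.config.multimodal import MultiModalConfig

# Illustration only: False for the default and for an explicit 0.0,
# True for any strictly positive pruning rate.
assert not MultiModalConfig().is_multimodal_pruning_enabled()
assert not MultiModalConfig(video_pruning_rate=0.0).is_multimodal_pruning_enabled()
assert MultiModalConfig(video_pruning_rate=0.75).is_multimodal_pruning_enabled()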

vllm/engine/arg_utils.py

Lines changed: 5 additions & 0 deletions

@@ -391,6 +391,7 @@ class EngineArgs:
     mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
     io_processor_plugin: Optional[str] = None
     skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
+    video_pruning_rate: float = MultiModalConfig.video_pruning_rate
     # LoRA fields
     enable_lora: bool = False
     enable_lora_bias: bool = LoRAConfig.bias_enabled
@@ -813,6 +814,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         multimodal_group.add_argument("--skip-mm-profiling",
                                       **multimodal_kwargs["skip_mm_profiling"])

+        multimodal_group.add_argument(
+            "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"])
+
         # LoRA related configs
         lora_kwargs = get_kwargs(LoRAConfig)
         lora_group = parser.add_argument_group(
@@ -1032,6 +1036,7 @@ def create_model_config(self) -> ModelConfig:
             model_impl=self.model_impl,
             override_attention_dtype=self.override_attention_dtype,
             logits_processors=self.logits_processors,
+            video_pruning_rate=self.video_pruning_rate,
             io_processor_plugin=self.io_processor_plugin,
         )
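
With these changes the option is reachable both from the CLI (the new --video-pruning-rate flag, for example when launching the API server) and programmatically through EngineArgs. A brief sketch of the programmatic path; the model name and value are examples only:

from vllm.engine.arg_utils import EngineArgs

# Sketch: the --video-pruning-rate CLI flag maps onto this dataclass field,
# and create_model_config() (see the hunk above) forwards it into ModelConfig.
args = EngineArgs(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    limit_mm_per_prompt={"video": 1},
    video_pruning_rate=0.75,
)
print(args.video_pruning_rate)  # 0.75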

vllm/inputs/preprocess.py

Lines changed: 5 additions & 0 deletions

@@ -278,6 +278,11 @@ def _process_embeds(
             raise ValueError(
                 "prompt_embeds must be of shape (seq_len, hidden_size).")

+        # Tensors must be on CPU for serialization between processes
+        # in the MsgpackEncoder. Casting to CPU here ensures that there is no
+        # hidden device transfer in the critical path of generation.
+        prompt_embeds = prompt_embeds.cpu()
+
         return embeds_inputs(prompt_embeds=prompt_embeds,
                              cache_salt=parsed_content.get("cache_salt"))
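
The eager .cpu() call makes the device transfer explicit at preprocessing time rather than leaving it to happen implicitly during serialization. A small self-contained sketch of the behaviour being relied on (plain PyTorch, not vLLM code):

import torch

# Prompt embeddings may arrive on an accelerator, but msgpack-based IPC needs
# host memory. .cpu() is a no-op for tensors already on the CPU and a single
# explicit copy otherwise, keeping the transfer out of the generation hot path.
device = "cuda" if torch.cuda.is_available() else "cpu"
prompt_embeds = torch.randn(8, 4096, device=device)  # (seq_len, hidden_size)
prompt_embeds = prompt_embeds.cpu()
assert prompt_embeds.device.type == "cpu"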
