Commit 03e655f

Isotr0py authored and dsxsteven committed
[VLM] Optimize GLM4.5-V-style video processing to only decode necessary frames (vllm-project#24161)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
1 parent f707706 commit 03e655f

File tree

5 files changed: 233 additions & 55 deletions
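For context on what the commit optimizes: the previous path decoded every frame of a video and then discarded most of them during FPS resampling in the processor, while the new opencv_dynamic backend computes which frames the sampler will keep and decodes only those. Below is a minimal sketch of the idea, assuming a cv2 grab/retrieve loop; it is illustrative only, not the actual OpenCVDynamicVideoBackend implementation.

import cv2
import numpy as np


def load_sampled_frames(path: str, requested_fps: float) -> np.ndarray:
    """Decode only the frames an FPS resampler would keep (sketch)."""
    cap = cv2.VideoCapture(path)
    orig_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # guard against 0 fps
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Keep roughly requested_fps frames per second of source video.
    num_samples = max(1, int(total / orig_fps * requested_fps))
    wanted = set(np.linspace(0, total - 1, num_samples).astype(int).tolist())

    frames = []
    for idx in range(total):
        if not cap.grab():  # grab() advances the decoder cheaply
            break
        if idx in wanted:
            ok, frame = cap.retrieve()  # full decode only for kept frames
            if ok:
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    cap.release()
    return np.stack(frames)

For long videos sampled at a low fps, this touches only a small fraction of the frames at full decode cost, which is the saving the commit title refers to.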

tests/models/multimodal/processing/test_glm4_1v.py

Lines changed: 47 additions & 0 deletions

@@ -5,6 +5,7 @@
 
 from vllm.assets.video import VideoAsset
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.video import OpenCVDynamicVideoBackend, OpenCVVideoBackend
 
 from ...utils import build_model_context
 
@@ -50,3 +51,49 @@ def test_processor_override(
 
     assert grid_t == expected_grid_t
     assert video_tok_count == expected_toks_per_frame * grid_t
+
+
+@pytest.mark.parametrize("model_id", ["zai-org/GLM-4.1V-9B-Thinking"])
+@pytest.mark.parametrize("fps", [2])
+def test_video_loader_consistency(
+    model_id: str,
+    fps: int,
+):
+    """
+    Ensure dynamic video loader (pre-sampled by loader) and normal video
+    loader (post-sampled by processor) produce the same processing outputs.
+    """
+    ctx = build_model_context(
+        model_id,
+        mm_processor_kwargs=None,
+        limit_mm_per_prompt={"video": 1},
+    )
+    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
+    hf_processor_mm_kwargs = {"fps": fps}
+
+    # Build the prompt with a single video placeholder
+    prompt = "<|begin_of_video|><|video|><|end_of_video|>"
+
+    video_path = VideoAsset(name="baby_reading", num_frames=-1).video_path
+    with open(video_path, "rb") as f:
+        video_bytes = f.read()
+
+    static_video, static_metadata = OpenCVVideoBackend.load_bytes(video_bytes)
+    dynamic_video, dynamic_metadata = OpenCVDynamicVideoBackend.load_bytes(
+        video_bytes, requested_fps=fps)
+
+    # pre-sampled loader shouldn't read all frames
+    assert len(dynamic_video) < len(static_video)
+
+    static_mm_data = {"video": [(static_video, static_metadata)]}
+    dynamic_mm_data = {"video": [(dynamic_video, dynamic_metadata)]}
+
+    static_outputs = processor.apply(prompt, static_mm_data,
+                                     hf_processor_mm_kwargs)
+    dynamic_outputs = processor.apply(prompt, dynamic_mm_data,
+                                      hf_processor_mm_kwargs)
+
+    assert static_outputs["prompt_token_ids"] == dynamic_outputs[
+        "prompt_token_ids"]
+    assert static_outputs["mm_kwargs"].get_data(
+    ) == dynamic_outputs["mm_kwargs"].get_data()
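The consistency test hinges on one invariant: sampling in the loader (with frame sampling later disabled in the processor) must select exactly the frames the processor would have sampled itself. A toy illustration of that equivalence follows, using a hypothetical evenly spaced sampler rather than the real backends.

import numpy as np


def sample_indices(total_frames: int, orig_fps: float,
                   fps: float) -> list[int]:
    # Hypothetical evenly spaced sampler, standing in for the real one.
    num = max(1, int(total_frames / orig_fps * fps))
    return np.linspace(0, total_frames - 1, num).astype(int).tolist()


total, orig_fps, fps = 300, 30.0, 2.0
idx = sample_indices(total, orig_fps, fps)

# Path A (static backend): decode all 300 frames, sample in the processor.
decoded_all = [f"frame{i}" for i in range(total)]
post_sampled = [decoded_all[i] for i in idx]

# Path B (dynamic backend): decode only the sampled frames in the loader.
pre_sampled = [f"frame{i}" for i in idx]

# Identical frame sequences reach the HF processor either way.
assert pre_sampled == post_sampled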

tests/multimodal/test_utils.py

Lines changed: 26 additions & 0 deletions

@@ -204,6 +204,32 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
     assert metadata_sync == metadata_async
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
+@pytest.mark.parametrize("max_duration", [1, 60, 1800])
+@pytest.mark.parametrize("requested_fps", [2, 24])
+async def test_fetch_video_http_with_dynamic_loader(
+        video_url: str, max_duration: int, requested_fps: int,
+        monkeypatch: pytest.MonkeyPatch):
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_VIDEO_LOADER_BACKEND", "opencv_dynamic")
+        connector = MediaConnector(
+            media_io_kwargs={
+                "video": {
+                    "max_duration": max_duration,
+                    "requested_fps": requested_fps,
+                }
+            })
+
+        video_sync, metadata_sync = connector.fetch_video(video_url)
+        video_async, metadata_async = await connector.fetch_video_async(
+            video_url)
+
+        assert np.array_equal(video_sync, video_async)
+        assert metadata_sync == metadata_async
+        assert metadata_sync["video_backend"] == "opencv_dynamic"
+
+
 # Used for `test_argsort_mm_positions`.
 class TestCase(NamedTuple):
     mm_positions: "MultiModalPlaceholderDict"
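As the test shows, the dynamic loader is opt-in via the VLLM_VIDEO_LOADER_BACKEND environment variable, with loader options forwarded through media_io_kwargs. A hedged usage sketch follows; the MediaConnector import path is assumed from the test module, and the URL is a placeholder.

import os

# Opt in to the pre-sampling backend, as the test does via monkeypatch.
os.environ["VLLM_VIDEO_LOADER_BACKEND"] = "opencv_dynamic"

from vllm.multimodal.utils import MediaConnector  # path per the test module

connector = MediaConnector(
    media_io_kwargs={"video": {
        "max_duration": 60,  # maximum seconds of video to consider
        "requested_fps": 2,  # sampling rate requested from the loader
    }})

# Placeholder URL; any HTTP(S) video source would do here.
video, metadata = connector.fetch_video("https://example.com/video.mp4")

# Downstream GLM-4.1V processing branches on this marker to skip resampling.
assert metadata["video_backend"] == "opencv_dynamic"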

vllm/assets/video.py

Lines changed: 8 additions & 8 deletions

@@ -110,22 +110,23 @@ class VideoAsset:
     def filename(self) -> str:
         return self._NAME_TO_FILE[self.name]
 
+    @property
+    def video_path(self) -> str:
+        return download_video_asset(self.filename)
+
     @property
     def pil_images(self) -> list[Image.Image]:
-        video_path = download_video_asset(self.filename)
-        ret = video_to_pil_images_list(video_path, self.num_frames)
+        ret = video_to_pil_images_list(self.video_path, self.num_frames)
         return ret
 
     @property
     def np_ndarrays(self) -> npt.NDArray:
-        video_path = download_video_asset(self.filename)
-        ret = video_to_ndarrays(video_path, self.num_frames)
+        ret = video_to_ndarrays(self.video_path, self.num_frames)
         return ret
 
     @property
     def metadata(self) -> dict[str, Any]:
-        video_path = download_video_asset(self.filename)
-        ret = video_get_metadata(video_path)
+        ret = video_get_metadata(self.video_path)
         return ret
 
     def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
@@ -134,5 +135,4 @@ def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
 
         See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
         """
-        video_path = download_video_asset(self.filename)
-        return librosa.load(video_path, sr=sampling_rate)[0]
+        return librosa.load(self.video_path, sr=sampling_rate)[0]
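The new video_path property centralizes the download-and-cache step, so callers can get the asset's local path directly instead of calling download_video_asset themselves. Usage mirroring the new test above:

from vllm.assets.video import VideoAsset

# Downloads the asset on first use and returns its local file path.
video_path = VideoAsset(name="baby_reading", num_frames=-1).video_path
with open(video_path, "rb") as f:
    video_bytes = f.read()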

vllm/model_executor/models/glm4_1v.py

Lines changed: 60 additions & 40 deletions

@@ -1023,6 +1023,43 @@ def _get_video_second_idx(self, metadata: dict[str, Any],
             selected_timestamps.append(timestamps_list[idx])
         return selected_timestamps
 
+    def _construct_video_placeholder(
+        self,
+        video_array: np.ndarray,
+        metadata: dict[str, Any],
+        grid_thw: torch.Tensor,
+    ) -> list[int]:
+        hf_processor = self.get_hf_processor()
+        tokenizer = self.get_tokenizer()
+        image_processor = hf_processor.image_processor
+
+        hf_config = self.get_hf_config()
+        boi_token_id = hf_config.image_start_token_id
+        eoi_token_id = hf_config.image_end_token_id
+        bov_token_id = hf_config.video_start_token_id
+        eov_token_id = hf_config.video_end_token_id
+        merge_length = image_processor.merge_size**2
+
+        assert isinstance(grid_thw, torch.Tensor)
+        timestamps = self._get_video_second_idx(metadata, len(video_array))
+        frames_idx_token = [
+            tokenizer.encode(str(i), add_special_tokens=False)
+            for i in timestamps
+        ]
+        T, H, W = grid_thw
+        num_tokens_per_frame = int(H * W) // merge_length
+        placeholder = []
+        placeholder.append(bov_token_id)
+        for frame_idx in frames_idx_token:
+            placeholder.append(boi_token_id)
+            placeholder.extend([hf_processor.video_token_id] *
+                               num_tokens_per_frame)
+            placeholder.append(eoi_token_id)
+            placeholder.extend(frame_idx)
+        placeholder.append(eov_token_id)
+
+        return placeholder
+
 
 class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
 
@@ -1118,17 +1155,10 @@ def _call_hf_processor(
         for item in mm_data.pop("videos", []):
             video_array, metadata = item
 
-            # FIXME(Isotr0py): Activate the below logic after we can disable
-            # resampling from video loader backend.
-            # assert metadata["total_num_frames"] == len(video_array), (
-            #     f"Total frames {metadata['total_num_frames']} does not "
-            #     f"match the length of video array {len(video_array)}.")
+            if metadata["video_backend"] == "opencv_dynamic":
+                mm_kwargs["do_sample_frames"] = False
 
-            # NOTE: Temporary workaround for resampled videos.
-            # this can cause a divergence with HF implementation if
-            # the input video is resampled in advance.
-
-            if metadata["total_num_frames"] != len(video_array):
+            elif metadata["total_num_frames"] != len(video_array):
                 logger.warning(
                     "Total frames in metadata "
                     "(%s) does not match the length of "
@@ -1140,23 +1170,34 @@ def _call_hf_processor(
                     len(video_array),
                 )
                 metadata["total_num_frames"] = len(video_array)
-            metadata = VideoMetadata(**metadata)
 
             video_mm_data = dict()
             video_mm_data["videos"] = [[video_array]]
-            video_mm_data["video_metadata"] = [[metadata]]
+            video_mm_data["video_metadata"] = [[VideoMetadata(**metadata)]]
 
             video_outputs = super()._call_hf_processor(
                 prompt="<|begin_of_video|><|video|><|end_of_video|>",
                 mm_data=video_mm_data,
                 mm_kwargs=mm_kwargs,
                 tok_kwargs=tok_kwargs,
             )
-            input_ids = video_outputs.pop("input_ids")
-            input_ids[input_ids == processor.image_token_id] = (
-                processor.video_token_id)
-            video_placeholder = processor.tokenizer.batch_decode(
-                input_ids)[0]
+            if "do_sample_frames" in mm_kwargs and not mm_kwargs[
+                    "do_sample_frames"]:
+                # Transformers v4.55 computes incorrect timestamps when frame
+                # sampling is skipped, so construct the placeholder manually
+                # to get correct timestamps.
+                placeholder = self.info._construct_video_placeholder(
+                    video_array,
+                    metadata,
+                    video_outputs["video_grid_thw"].squeeze(0),
+                )
+                video_placeholder = processor.tokenizer.decode(placeholder)
+            else:
+                input_ids = video_outputs.pop("input_ids")
+                input_ids[input_ids == processor.image_token_id] = (
+                    processor.video_token_id)
+                video_placeholder = processor.tokenizer.batch_decode(
+                    input_ids)[0]
             prompt = prompt.replace(
                 "<|begin_of_video|><|video|><|end_of_video|>",
                 video_placeholder,
@@ -1202,14 +1243,6 @@ def _get_prompt_updates(
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
         image_processor = self.info.get_image_processor(
            **hf_processor_mm_kwargs)
-        tokenizer = self.info.get_tokenizer()
-        hf_config = self.info.get_hf_config()
-
-        boi_token_id = hf_config.image_start_token_id
-        eoi_token_id = hf_config.image_end_token_id
-
-        bov_token_id = hf_config.video_start_token_id
-        eov_token_id = hf_config.video_end_token_id
 
         merge_length = image_processor.merge_size**2
 
@@ -1227,21 +1260,8 @@ def get_video_replacement_glm4v(item_idx: int):
             assert isinstance(grid_thw, torch.Tensor)
 
             video, metadata = mm_items["video"][item_idx]
-            timestamps = self.info._get_video_second_idx(metadata, len(video))
-            frames_idx_token = [
-                tokenizer.encode(str(i), add_special_tokens=False)
-                for i in timestamps
-            ]
-            num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length
-            placeholder = []
-            placeholder.append(bov_token_id)
-            for frame_idx in frames_idx_token:
-                placeholder.append(boi_token_id)
-                placeholder.extend([hf_processor.video_token_id] *
-                                   num_tokens_per_frame)
-                placeholder.append(eoi_token_id)
-                placeholder.extend(frame_idx)
-            placeholder.append(eov_token_id)
+            placeholder = self.info._construct_video_placeholder(
+                video, metadata, grid_thw)
             return PromptUpdateDetails.select_token_id(
                 placeholder,
                 embed_token_id=hf_processor.video_token_id,