
Commit 3ce5629

zucchini-nlp authored and Cyrilvallez committed
[Glm4.5V] fix vLLM support (#40696)
* fix

* add a test case
1 parent 26a7e6d commit 3ce5629

4 files changed: +23 -6 lines changed


src/transformers/models/glm4v/image_processing_glm4v.py

Lines changed: 1 addition & 1 deletion
@@ -453,7 +453,7 @@ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
         """
         patch_size = images_kwargs.get("patch_size", self.patch_size)
         merge_size = images_kwargs.get("merge_size", self.merge_size)
-        size = images_kwargs.get("size", self.size)
+        size = images_kwargs.get("size", {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000})
 
         factor = patch_size * merge_size
         resized_height, resized_width = smart_resize(
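
The only functional change here is the fallback for `size`: when the caller passes `images_kwargs` without a `"size"` entry, the method now falls back to an explicit pixel budget instead of `self.size`. A minimal sketch of exercising that path, assuming `Glm4vImageProcessor` can be default-constructed (in practice these values would come from the checkpoint config):

from transformers import Glm4vImageProcessor

processor = Glm4vImageProcessor()  # assumption: built-in defaults stand in for a real checkpoint config
# No "size" in images_kwargs, so the new fallback
# {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000} applies.
num_patches = processor.get_number_of_image_patches(height=480, width=640, images_kwargs={})
print(num_patches)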

src/transformers/video_processing_utils.py

Lines changed: 4 additions & 0 deletions
@@ -305,10 +305,14 @@ def _decode_and_sample_videos(
         # Only sample frames if an array video is passed, otherwise first decode -> then sample
         if is_valid_video(videos[0]) and do_sample_frames:
             sampled_videos = []
+            sampled_metadata = []
             for video, metadata in zip(videos, video_metadata):
                 indices = sample_indices_fn(metadata=metadata)
+                metadata.frames_indices = indices
                 sampled_videos.append(video[indices])
+                sampled_metadata.append(metadata)
             videos = sampled_videos
+            video_metadata = sampled_metadata
         elif not is_valid_video(videos[0]):
             if isinstance(videos[0], list):
                 # Videos sometimes are passed as a list of image URLs, especially through templates
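
The added lines keep the per-video metadata in sync with the frames that were actually kept; in particular, recording `frames_indices` is what makes `timestamps` computable afterwards. A rough standalone sketch of the same bookkeeping, using a dummy NumPy video and a hand-picked stride in place of `sample_indices_fn`:

import numpy as np
from transformers.video_utils import VideoMetadata

video = np.zeros((8, 3, 32, 32), dtype=np.uint8)  # 8 dummy frames
metadata = VideoMetadata(total_num_frames=8, fps=4)

indices = list(range(0, 8, 2))     # stand-in for what sample_indices_fn would return
metadata.frames_indices = indices  # mirrors the new assignment in the hunk above
sampled = video[indices]

print(sampled.shape)        # (4, 3, 32, 32)
print(metadata.timestamps)  # [0.0, 0.5, 1.0, 1.5]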

src/transformers/video_utils.py

Lines changed: 11 additions & 5 deletions
@@ -15,9 +15,9 @@
 
 import os
 import warnings
-from collections.abc import Iterable
+from collections.abc import Iterable, Mapping
 from contextlib import redirect_stdout
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from io import BytesIO
 from typing import Callable, NewType, Optional, Union
 from urllib.parse import urlparse
@@ -78,7 +78,7 @@
 
 
 @dataclass
-class VideoMetadata:
+class VideoMetadata(Mapping):
     total_num_frames: int
     fps: float = None
     width: int = None
@@ -87,6 +87,12 @@ class VideoMetadata:
     video_backend: str = None
     frames_indices: list[int] = None
 
+    def __iter__(self):
+        return (f.name for f in fields(self))
+
+    def __len__(self):
+        return len(fields(self))
+
     def __getitem__(self, item):
         return getattr(self, item)
 
@@ -96,8 +102,8 @@ def __setitem__(self, key, value):
     @property
     def timestamps(self) -> float:
         "Timestamps of the sampled frames in seconds."
-        if self.fps is None:
-            raise ValueError("Cannot infer video `timestamps` when `fps` is None.")
+        if self.fps is None or self.frames_indices is None:
+            raise ValueError("Cannot infer video `timestamps` when `fps` or `frames_indices` is None.")
         return [frame_idx / self.fps for frame_idx in self.frames_indices]
 
     def update(self, dictionary):
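
Subclassing `Mapping` lets a `VideoMetadata` instance be consumed anywhere a plain dict is expected (dict conversion, `**` unpacking, iteration over keys), which is presumably what the vLLM integration relies on. A quick sketch of what the two new dunder methods enable, with made-up values:

from transformers.video_utils import VideoMetadata

meta = VideoMetadata(total_num_frames=8, fps=4, frames_indices=[0, 2, 4, 6])

print(list(meta))       # field names, via the new __iter__
print(len(meta))        # number of fields, via the new __len__
print(dict(meta))       # the Mapping protocol makes dict(meta) and **meta work
print(meta["fps"])      # __getitem__ existed before this change
print(meta.timestamps)  # [0.0, 0.5, 1.0, 1.5], now guarded against missing frames_indices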

tests/test_video_processing_common.py

Lines changed: 7 additions & 0 deletions
@@ -342,6 +342,13 @@ def test_call_sample_frames(self):
         self.assertEqual(encoded_videos.shape[1], 6)
         self.assertEqual(encoded_videos_batched.shape[1], 6)
 
+        # The same as above but uses a `VideoMetadata` object in the input
+        metadata = [[VideoMetadata(duration=2.0, total_num_frames=8, fps=4)]]
+        batched_metadata = metadata * len(video_inputs)
+        encoded_videos = video_processing(video_inputs[0], return_tensors="pt", fps=3, video_metadata=metadata)[
+            self.input_name
+        ]
+
         # We should raise error when asked to sample more frames than there are in input video
         with self.assertRaises(ValueError):
             encoded_videos = video_processing(video_inputs[0], return_tensors="pt", num_frames=10)[self.input_name]
