
Commit e212ff9

zucchini-nlp and qubvel authored

[video processor] support torchcodec and decrease cuda memory usage (#38880)

* don't move the whole video to GPU
* add torchcodec
* add tests
* make style
* instrucblip as well
* consistency
* Update src/transformers/utils/import_utils.py
  Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
* Update src/transformers/utils/import_utils.py
  Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
* Update src/transformers/video_utils.py
  Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

---------

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

1 parent: 11d0fea

File tree

10 files changed: +129 additions, -9 deletions


src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py

Lines changed: 6 additions & 0 deletions
@@ -94,12 +94,18 @@ def _preprocess(
         fps: Optional[int] = None,
         num_frames: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
+        device: Optional["torch.Tensor"] = None,
     ) -> BatchFeature:
         if do_sample_frames:
             videos = [
                 self.sample_frames(video, metadata, num_frames, fps) for video, metadata in zip(videos, video_metadata)
             ]

+        # We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
+        # moving the whole video incurs high GPU mem usage for long videos
+        if device is not None:
+            videos = [video.to(device) for video in videos]
+
         # Group videos by size for batched resizing
         grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
         resized_videos_grouped = {}
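
A rough back-of-the-envelope check shows why frames are sampled before the device move (figures are illustrative assumptions, not from the diff): a 10-minute 1080p clip at 30 fps holds 10 × 60 × 30 = 18,000 uint8 RGB frames, i.e. 18,000 × 3 × 1080 × 1920 bytes ≈ 112 GB if the whole video were moved to the GPU, while 32 sampled frames occupy only 32 × 3 × 1080 × 1920 bytes ≈ 0.2 GB. A minimal standalone sketch of the sample-then-move pattern (the helper name is hypothetical, not the transformers API):

```python
import torch

def sample_then_move(video: torch.Tensor, num_frames: int, device: str) -> torch.Tensor:
    """Pick `num_frames` uniformly spaced frames on CPU, then move only those
    frames to the accelerator, the same ordering the diff enforces in `_preprocess`."""
    total_frames = video.shape[0]  # video laid out as (T, C, H, W) on CPU
    indices = torch.linspace(0, total_frames - 1, num_frames).long()
    return video[indices].to(device)  # only `num_frames` frames reach device memory

# An 18,000-frame dummy video stays on CPU; just 32 frames reach the GPU.
video = torch.zeros(18_000, 3, 224, 224, dtype=torch.uint8)
device = "cuda" if torch.cuda.is_available() else "cpu"
frames = sample_then_move(video, num_frames=32, device=device)
print(frames.shape)  # torch.Size([32, 3, 224, 224])
```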

src/transformers/models/internvl/video_processing_internvl.py

Lines changed: 6 additions & 0 deletions
@@ -147,6 +147,7 @@ def _preprocess(
         num_frames: Optional[int] = None,
         initial_shift: Optional[Union[bool, float, int]] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
+        device: Optional["torch.Tensor"] = None,
     ) -> BatchFeature:
         if do_sample_frames:
             # Sample video frames
@@ -155,6 +156,11 @@ def _preprocess(
                 for video, metadata in zip(videos, video_metadata)
             ]

+        # We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
+        # moving the whole video incurs high GPU mem usage for long videos
+        if device is not None:
+            videos = [video.to(device) for video in videos]
+
         # Group videos by size for batched resizing
         grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
         resized_videos_grouped = {}

src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py

Lines changed: 6 additions & 0 deletions
@@ -213,6 +213,7 @@ def _preprocess(
         min_frames: Optional[int] = None,
         max_frames: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
+        device: Optional["torch.Tensor"] = None,
         **kwargs,
     ):
         if do_sample_frames:
@@ -230,6 +231,11 @@ def _preprocess(
                 for video, metadata in zip(videos, video_metadata)
             ]

+        # We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
+        # moving the whole video incurs high GPU mem usage for long videos
+        if device is not None:
+            videos = [video.to(device) for video in videos]
+
         # Group videos by size for batched resizing
         grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
         resized_videos_grouped = {}

src/transformers/models/smolvlm/video_processing_smolvlm.py

Lines changed: 6 additions & 0 deletions
@@ -332,6 +332,7 @@ def _preprocess(
         num_frames: Optional[int] = None,
         skip_secs: Optional[int] = 0,
         return_tensors: Optional[Union[str, TensorType]] = None,
+        device: Optional["torch.Tensor"] = None,
         **kwargs,
     ):
         # Group videos by size for batched resizing
@@ -356,6 +357,11 @@ def _preprocess(
             ]
             durations_list = [len(video) // 24 for video in videos]

+        # We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
+        # moving the whole video incurs high GPU mem usage for long videos
+        if device is not None:
+            videos = [video.to(device) for video in videos]
+
         grouped_videos, grouped_videos_index = group_videos_by_shape(processed_videos)
         resized_videos_grouped = {}
         for shape, stacked_videos in grouped_videos.items():

src/transformers/testing_utils.py

Lines changed: 11 additions & 0 deletions
@@ -158,6 +158,7 @@
     is_torch_xpu_available,
     is_torchao_available,
     is_torchaudio_available,
+    is_torchcodec_available,
     is_torchdynamo_available,
     is_torchvision_available,
     is_vision_available,
@@ -634,6 +635,16 @@ def require_torchvision(test_case):
     return unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")(test_case)


+def require_torchcodec(test_case):
+    """
+    Decorator marking a test that requires Torchcodec.
+
+    These tests are skipped when Torchcodec isn't installed.
+
+    """
+    return unittest.skipUnless(is_torchcodec_available(), "test requires Torchvision")(test_case)
+
+
 def require_torch_or_tf(test_case):
     """
     Decorator marking a test that requires PyTorch or TensorFlow.
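
For reference, the new marker composes like the existing require_* decorators; a minimal usage sketch (test class and body are illustrative, not from the diff):

```python
import unittest

from transformers.testing_utils import require_torchcodec

class TorchcodecDecodingTest(unittest.TestCase):
    @require_torchcodec  # skipped automatically when torchcodec isn't installed
    def test_decoder_is_importable(self):
        from torchcodec.decoders import VideoDecoder

        self.assertTrue(callable(VideoDecoder))

if __name__ == "__main__":
    unittest.main()
```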

src/transformers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -254,6 +254,7 @@
     is_torch_xpu_available,
     is_torchao_available,
     is_torchaudio_available,
+    is_torchcodec_available,
     is_torchdistx_available,
     is_torchdynamo_available,
     is_torchdynamo_compiling,

src/transformers/utils/import_utils.py

Lines changed: 14 additions & 0 deletions
@@ -119,6 +119,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _vptq_available, _vptq_version = _is_package_available("vptq", return_version=True)
 _av_available = importlib.util.find_spec("av") is not None
 _decord_available = importlib.util.find_spec("decord") is not None
+_torchcodec_available = importlib.util.find_spec("torchcodec") is not None
 _bitsandbytes_available = _is_package_available("bitsandbytes")
 _eetq_available = _is_package_available("eetq")
 _fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
@@ -976,6 +977,10 @@ def is_decord_available():
     return _decord_available


+def is_torchcodec_available():
+    return _torchcodec_available
+
+
 def is_ninja_available():
     r"""
     Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
@@ -1502,6 +1507,14 @@ def check_torch_load_is_safe():
 Please note that you may need to restart your runtime after installation.
 """

+TORCHCODEC_IMPORT_ERROR = """
+{0} requires the TorchCodec (https://github.com/pytorch/torchcodec) library, but it was not found in your environment. You can install it with:
+```
+pip install torchcodec
+```
+Please note that you may need to restart your runtime after installation.
+"""
+
 # docstyle-ignore
 CV2_IMPORT_ERROR = """
 {0} requires the OpenCV library but it was not found in your environment. You can install it with:
@@ -1882,6 +1895,7 @@ def check_torch_load_is_safe():
     ("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
     ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
     ("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)),
+    ("torchcodec", (is_torchcodec_available, TORCHCODEC_IMPORT_ERROR)),
     ("vision", (is_vision_available, VISION_IMPORT_ERROR)),
     ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
     ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),

src/transformers/video_processing_utils.py

Lines changed: 9 additions & 7 deletions
@@ -294,7 +294,6 @@ def _prepare_input_videos(
         videos: VideoInput,
         video_metadata: VideoMetadata = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
-        device: Optional["torch.device"] = None,
     ) -> list["torch.Tensor"]:
         """
         Prepare the input videos for processing.
@@ -313,10 +312,6 @@ def _prepare_input_videos(
                 # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
                 video = torch.from_numpy(video).contiguous()

-            # Now that we have torch tensors, we can move them to the right device
-            if device is not None:
-                video = video.to(device)
-
             processed_videos.append(video)
         return processed_videos, batch_metadata

@@ -336,10 +331,9 @@ def preprocess(
             kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))

         input_data_format = kwargs.pop("input_data_format")
-        device = kwargs.pop("device")
         video_metadata = kwargs.pop("video_metadata")
         videos, video_metadata = self._prepare_input_videos(
-            videos=videos, video_metadata=video_metadata, input_data_format=input_data_format, device=device
+            videos=videos, video_metadata=video_metadata, input_data_format=input_data_format
         )

         kwargs = self._further_process_kwargs(**kwargs)
@@ -378,6 +372,7 @@ def _preprocess(
         fps: Optional[int] = None,
         num_frames: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
+        device: Optional["torch.Tensor"] = None,
     ) -> BatchFeature:
         if do_sample_frames:
             # Sample video frames
@@ -386,6 +381,11 @@ def _preprocess(
                 for video, metadata in zip(videos, video_metadata)
             ]

+        # We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
+        # moving the whole video incurs high GPU mem usage for long videos
+        if device is not None:
+            videos = [video.to(device) for video in videos]
+
         # Group videos by size for batched resizing
         grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
         resized_videos_grouped = {}
@@ -775,6 +775,8 @@ def to_dict(self) -> dict[str, Any]:
         `dict[str, Any]`: Dictionary of all the attributes that make up this video processor instance.
         """
         output = copy.deepcopy(self.__dict__)
+        output.pop("model_valid_processing_keys", None)
+        output.pop("_valid_kwargs_names", None)
         output["video_processor_type"] = self.__class__.__name__

         return output
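
The two `pop` calls keep derived bookkeeping out of the serialized config, which matters because values such as key sets are not JSON-serializable. A toy sketch of the pattern (not the real base class; attribute values are invented):

```python
import copy
import json

class TinyVideoProcessor:
    def __init__(self):
        self.size = {"height": 384, "width": 384}
        # derived bookkeeping, not user config; a `set` would break json.dumps
        self.model_valid_processing_keys = {"size", "do_resize"}

    def to_dict(self):
        output = copy.deepcopy(self.__dict__)
        output.pop("model_valid_processing_keys", None)  # mirrors the diff
        output["video_processor_type"] = self.__class__.__name__
        return output

print(json.dumps(TinyVideoProcessor().to_dict(), indent=2))  # serializes cleanly
```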

src/transformers/video_utils.py

Lines changed: 57 additions & 2 deletions
@@ -14,6 +14,7 @@
 # limitations under the License.

 import os
+import warnings
 from collections.abc import Iterable
 from contextlib import redirect_stdout
 from dataclasses import dataclass
@@ -33,6 +34,7 @@
     is_numpy_array,
     is_torch_available,
     is_torch_tensor,
+    is_torchcodec_available,
     is_torchvision_available,
     is_vision_available,
     is_yt_dlp_available,
@@ -425,6 +427,10 @@ def sample_indices_fn(metadata, **kwargs):
         - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
         - `VideoMetadata` object.
     """
+    warnings.warn(
+        "Using `torchvision` for video decoding is deprecated and will be removed in future versions. "
+        "Please use `torchcodec` instead."
+    )
     video, _, info = torchvision_io.read_video(
         video_path,
         start_pts=0.0,
@@ -449,11 +455,59 @@ def sample_indices_fn(metadata, **kwargs):
     return video, metadata


+def read_video_torchcodec(
+    video_path: str,
+    sample_indices_fn: Callable,
+    **kwargs,
+):
+    """
+    Decode the video with torchcodec decoder.
+
+    Args:
+        video_path (`str`):
+            Path to the video file.
+        sample_indices_fn (`Callable`, *optional*):
+            A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
+            by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
+            If not provided, simple uniform sampling with fps is performed.
+            Example:
+            def sample_indices_fn(metadata, **kwargs):
+                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
+
+    Returns:
+        Tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
+        - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
+        - `VideoMetadata` object.
+    """
+    # Lazy import torchcodec
+    requires_backends(read_video_torchcodec, ["torchcodec"])
+    from torchcodec.decoders import VideoDecoder
+
+    decoder = VideoDecoder(
+        video_path,
+        dimension_order="NHWC",  # to be consistent with other decoders
+        # Interestingly `exact` mode takes less than approximate when we load the whole video
+        seek_mode="exact",
+    )
+    metadata = VideoMetadata(
+        total_num_frames=decoder.metadata.num_frames,
+        fps=decoder.metadata.average_fps,
+        duration=decoder.metadata.duration_seconds,
+        video_backend="torchcodec",
+    )
+    indices = sample_indices_fn(metadata=metadata, **kwargs)
+
+    video = decoder.get_frames_at(indices=indices).data.contiguous()
+    metadata.frames_indices = indices
+    return video, metadata
+
+
 VIDEO_DECODERS = {
     "decord": read_video_decord,
     "opencv": read_video_opencv,
     "pyav": read_video_pyav,
     "torchvision": read_video_torchvision,
+    "torchcodec": read_video_torchcodec,
 }


@@ -477,7 +531,7 @@ def load_video(
             Number of frames to sample per second. Should be passed only when `num_frames=None`.
             If not specified and `num_frames==None`, all frames are sampled.
         backend (`str`, *optional*, defaults to `"pyav"`):
-            The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav".
+            The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision", "torchcodec"]. Defaults to "pyav".
         sample_indices_fn (`Callable`, *optional*):
             A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
             by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
@@ -535,7 +589,7 @@ def sample_indices_fn_func(metadata, **fn_kwargs):
     video_is_url = video.startswith("http://") or video.startswith("https://")
     if video_is_url and backend in ["opencv", "torchvision"]:
         raise ValueError(
-            "If you are trying to load a video from URL, you can decode the video only with `pyav` or `decord` as backend"
+            "If you are trying to load a video from URL, you can decode the video only with `pyav`, `decord` or `torchcodec` as backend"
         )

     if file_obj is None:
@@ -546,6 +600,7 @@ def sample_indices_fn_func(metadata, **fn_kwargs):
         or (not is_av_available() and backend == "pyav")
         or (not is_cv2_available() and backend == "opencv")
         or (not is_torchvision_available() and backend == "torchvision")
+        or (not is_torchcodec_available() and backend == "torchcodec")
     ):
         raise ImportError(
             f"You chose backend={backend} for loading the video but the required library is not found in your environment "

tests/utils/test_video_utils.py

Lines changed: 13 additions & 0 deletions
@@ -27,6 +27,7 @@
     require_cv2,
     require_decord,
     require_torch,
+    require_torchcodec,
     require_torchvision,
     require_vision,
 )
@@ -261,6 +262,7 @@ def test_load_video_local(self):

     @require_decord
     @require_torchvision
+    @require_torchcodec
     @require_cv2
     def test_load_video_backend_url(self):
         video, _ = load_video(
@@ -269,6 +271,12 @@ def test_load_video_backend_url(self):
         )
         self.assertEqual(video.shape, (243, 360, 640, 3))

+        video, _ = load_video(
+            "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
+            backend="torchcodec",
+        )
+        self.assertEqual(video.shape, (243, 360, 640, 3))
+
         # Can't use certain backends with url
         with self.assertRaises(ValueError):
             video, _ = load_video(
@@ -283,6 +291,7 @@ def test_load_video_backend_url(self):

     @require_decord
     @require_torchvision
+    @require_torchcodec
     @require_cv2
     def test_load_video_backend_local(self):
         video_file_path = hf_hub_download(
@@ -300,6 +309,10 @@ def test_load_video_backend_local(self):
         self.assertEqual(video.shape, (243, 360, 640, 3))
         self.assertIsInstance(metadata, VideoMetadata)

+        video, metadata = load_video(video_file_path, backend="torchcodec")
+        self.assertEqual(video.shape, (243, 360, 640, 3))
+        self.assertIsInstance(metadata, VideoMetadata)
+
     def test_load_video_num_frames(self):
         video, _ = load_video(
             "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
