From 1a5b599bdead292e7363c3aa90c136b543600763 Mon Sep 17 00:00:00 2001
From: raushan
Date: Tue, 17 Jun 2025 15:47:32 +0200
Subject: [PATCH 1/9] don't move the whole video to GPU

---
 .../models/internvl/video_processing_internvl.py  |  6 ++++++
 .../models/qwen2_vl/video_processing_qwen2_vl.py  |  6 ++++++
 .../models/smolvlm/video_processing_smolvlm.py    |  6 ++++++
 src/transformers/video_processing_utils.py        | 14 +++++++-------
 4 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/src/transformers/models/internvl/video_processing_internvl.py b/src/transformers/models/internvl/video_processing_internvl.py
index 74f5981af95b..173ef2f5856d 100644
--- a/src/transformers/models/internvl/video_processing_internvl.py
+++ b/src/transformers/models/internvl/video_processing_internvl.py
@@ -147,6 +147,7 @@ def _preprocess(
         num_frames: Optional[int] = None,
         initial_shift: Optional[Union[bool, float, int]] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
+        device: Optional["torch.device"] = None,
     ) -> BatchFeature:
         if do_sample_frames:
             # Sample video frames
@@ -155,6 +156,11 @@
                 for video, metadata in zip(videos, video_metadata)
             ]
 
+        # We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
+        # moving the whole video incurs high GPU mem usage for long videos
+        if device is not None:
+            videos = [video.to(device) for video in videos]
+
         # Group videos by size for batched resizing
         grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
         resized_videos_grouped = {}
diff --git a/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
index 49a4e9d2efca..e993d4a22a9c 100644
--- a/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
@@ -213,6 +213,7 @@ def _preprocess(
         min_frames: Optional[int] = None,
         max_frames: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
+        device: Optional["torch.device"] = None,
         **kwargs,
     ):
         if do_sample_frames:
@@ -230,6 +231,11 @@
                 for video, metadata in zip(videos, video_metadata)
             ]
 
+        # We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
+        # moving the whole video incurs high GPU mem usage for long videos
+        if device is not None:
+            videos = [video.to(device) for video in videos]
+
         # Group videos by size for batched resizing
         grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
         resized_videos_grouped = {}
diff --git a/src/transformers/models/smolvlm/video_processing_smolvlm.py b/src/transformers/models/smolvlm/video_processing_smolvlm.py
index dbfc94ba2f8d..76afa7cbfec4 100644
--- a/src/transformers/models/smolvlm/video_processing_smolvlm.py
+++ b/src/transformers/models/smolvlm/video_processing_smolvlm.py
@@ -332,6 +332,7 @@ def _preprocess(
         num_frames: Optional[int] = None,
         skip_secs: Optional[int] = 0,
         return_tensors: Optional[Union[str, TensorType]] = None,
+        device: Optional["torch.device"] = None,
         **kwargs,
     ):
         # Group videos by size for batched resizing
@@ -356,6 +357,11 @@
             ]
             durations_list = [len(video) // 24 for video in videos]
 
+        # We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
+        # moving the whole video incurs high GPU mem usage for long videos
+        if device is not None:
+            videos = [video.to(device) for video in videos]
+
         grouped_videos, grouped_videos_index = group_videos_by_shape(processed_videos)
         resized_videos_grouped = {}
         for shape, stacked_videos in grouped_videos.items():
diff --git a/src/transformers/video_processing_utils.py b/src/transformers/video_processing_utils.py
index c316b8ffb027..23f150b42563 100644
--- a/src/transformers/video_processing_utils.py
+++ b/src/transformers/video_processing_utils.py
@@ -294,7 +294,6 @@ def _prepare_input_videos(
         videos: VideoInput,
         video_metadata: VideoMetadata = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
-        device: Optional["torch.device"] = None,
     ) -> List["torch.Tensor"]:
         """
         Prepare the input videos for processing.
@@ -313,10 +312,6 @@
                 # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
                 video = torch.from_numpy(video).contiguous()
 
-            # Now that we have torch tensors, we can move them to the right device
-            if device is not None:
-                video = video.to(device)
-
             processed_videos.append(video)
         return processed_videos, batch_metadata
 
@@ -336,10 +331,9 @@ def preprocess(
             kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
 
         input_data_format = kwargs.pop("input_data_format")
-        device = kwargs.pop("device")
         video_metadata = kwargs.pop("video_metadata")
         videos, video_metadata = self._prepare_input_videos(
-            videos=videos, video_metadata=video_metadata, input_data_format=input_data_format, device=device
+            videos=videos, video_metadata=video_metadata, input_data_format=input_data_format
         )
 
         kwargs = self._further_process_kwargs(**kwargs)
@@ -378,6 +372,7 @@ def _preprocess(
         fps: Optional[int] = None,
         num_frames: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
+        device: Optional["torch.device"] = None,
     ) -> BatchFeature:
         if do_sample_frames:
             # Sample video frames
@@ -386,6 +381,11 @@
                 for video, metadata in zip(videos, video_metadata)
             ]
 
+        # We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
+        # moving the whole video incurs high GPU mem usage for long videos
+        if device is not None:
+            videos = [video.to(device) for video in videos]
+
         # Group videos by size for batched resizing
         grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
         resized_videos_grouped = {}
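To make the intent of this first patch concrete, here is a minimal sketch of the new flow; the checkpoint name and dummy input are illustrative assumptions, not part of the patch. With `do_sample_frames=True`, frames now stay on CPU until sampling is done, and only the sampled subset is moved to the requested device.

```python
import numpy as np
from transformers import AutoVideoProcessor

# Illustrative checkpoint; any model with a fast video processor should behave the same way.
processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")

# A dummy 300-frame clip kept on CPU: (num_frames, height, width, channels), uint8 RGB.
video = np.random.randint(0, 256, (300, 360, 640, 3), dtype=np.uint8)

# After this patch, only the 16 sampled frames are copied to "cuda" inside `_preprocess`,
# instead of all 300 frames being moved up front in `_prepare_input_videos`.
inputs = processor(videos=[video], do_sample_frames=True, num_frames=16, device="cuda", return_tensors="pt")
```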
From e4b3478727bf4e15af58cf4dade6a8783ad9938f Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 18 Jun 2025 10:20:00 +0200
Subject: [PATCH 2/9] add torchcodec

---
 src/transformers/utils/import_utils.py | 14 ++++++
 src/transformers/video_utils.py        | 59 +++++++++++++++++++++++++-
 2 files changed, 71 insertions(+), 2 deletions(-)

diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index ebd1ae9ef19e..b55f39acf5ce 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -119,6 +119,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _vptq_available, _vptq_version = _is_package_available("vptq", return_version=True)
 _av_available = importlib.util.find_spec("av") is not None
 _decord_available = importlib.util.find_spec("decord") is not None
+_torchcodec_available = importlib.util.find_spec("torchcoded") is not None
 _bitsandbytes_available = _is_package_available("bitsandbytes")
 _eetq_available = _is_package_available("eetq")
 _fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
@@ -957,6 +958,10 @@ def is_decord_available():
     return _decord_available
 
 
+def is_torchcodec_available():
+    return _torchcodec_available
+
+
 def is_ninja_available():
     r"""
     Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
@@ -1479,6 +1484,14 @@ def check_torch_load_is_safe():
 Please note that you may need to restart your runtime after installation.
 """
 
+TORCHCODEC_IMPORT_ERROR = """
+{0} requires the PyAv library but it was not found in your environment. You can install it with:
+```
+pip install torchcodec
+```
+Please note that you may need to restart your runtime after installation.
+"""
+
 # docstyle-ignore
 CV2_IMPORT_ERROR = """
 {0} requires the OpenCV library but it was not found in your environment. You can install it with:
@@ -1859,6 +1872,7 @@
     ("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
     ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
     ("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)),
+    ("torchcodec", (is_torchcodec_available, TORCHCODEC_IMPORT_ERROR)),
     ("vision", (is_vision_available, VISION_IMPORT_ERROR)),
     ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
     ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
diff --git a/src/transformers/video_utils.py b/src/transformers/video_utils.py
index 7e23e377d8b9..1830834c7faf 100644
--- a/src/transformers/video_utils.py
+++ b/src/transformers/video_utils.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import os
+import warnings
 from collections.abc import Iterable
 from contextlib import redirect_stdout
 from dataclasses import dataclass
@@ -33,6 +34,7 @@
     is_numpy_array,
     is_torch_available,
     is_torch_tensor,
+    is_torchcodec_available,
     is_torchvision_available,
     is_vision_available,
     is_yt_dlp_available,
@@ -425,6 +427,10 @@ def sample_indices_fn(metadata, **kwargs):
             - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
             - `VideoMetadata` object.
     """
+    warnings.warn(
+        "Using `torchvision` for video decoding is deprecated and will be removed in future versions. "
+        "Please use `torchcodec` instead."
+    )
     video, _, info = torchvision_io.read_video(
         video_path,
         start_pts=0.0,
@@ -449,11 +455,59 @@ def sample_indices_fn(metadata, **kwargs):
     return video, metadata
 
 
+def read_video_torchcodec(
+    video_path: str,
+    sample_indices_fn: Callable,
+    **kwargs,
+):
+    """
+    Decode the video with torchcodec decoder.
+
+    Args:
+        video_path (`str`):
+            Path to the video file.
+        sample_indices_fn (`Callable`, *optional*):
+            A callable function that will return indices at which the video should be sampled. If the video has to be loaded
+            using a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their
+            own `sample_indices_fn`.
+            If not provided, simple uniform sampling with fps is performed.
+            Example:
+            def sample_indices_fn(metadata, **kwargs):
+                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
+
+    Returns:
+        Tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
+            - Torch tensor of frames in RGB (shape: [num_frames, height, width, 3]).
+            - `VideoMetadata` object.
+    """
+    # Lazy import av
+    requires_backends(read_video_torchcodec, ["torchcodec"])
+    from torchcodec.decoders import VideoDecoder
+
+    decoder = VideoDecoder(
+        video_path,
+        dimension_order="NCHW",
+        # Interestingly, `exact` mode takes less time than `approximate` when we load the whole video
+        seek_mode="exact",
+    )
+    metadata = VideoMetadata(
+        total_num_frames=decoder.metadata.num_frames,
+        fps=decoder.metadata.average_fps,
+        duration=decoder.metadata.duration_seconds,
+        video_backend="torchcodec",
+    )
+    indices = sample_indices_fn(metadata=metadata, **kwargs)
+
+    video = decoder.get_frames_at(indices=indices).data.contiguous()
+    metadata.frames_indices = indices
+    return video, metadata
+
+
 VIDEO_DECODERS = {
     "decord": read_video_decord,
     "opencv": read_video_opencv,
     "pyav": read_video_pyav,
     "torchvision": read_video_torchvision,
+    "torchcodec": read_video_torchcodec,
 }
 
 
@@ -477,7 +531,7 @@ def load_video(
             Number of frames to sample per second. Should be passed only when `num_frames=None`.
             If not specified and `num_frames==None`, all frames are sampled.
         backend (`str`, *optional*, defaults to `"pyav"`):
-            The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav".
+            The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision", "torchcodec"]. Defaults to "pyav".
         sample_indices_fn (`Callable`, *optional*):
             A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
             by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
@@ -535,7 +589,7 @@ def sample_indices_fn_func(metadata, **fn_kwargs):
     video_is_url = video.startswith("http://") or video.startswith("https://")
     if video_is_url and backend in ["opencv", "torchvision"]:
         raise ValueError(
-            "If you are trying to load a video from URL, you can decode the video only with `pyav` or `decord` as backend"
+            "If you are trying to load a video from URL, you can decode the video only with `pyav`, `decord` or `torchcodec` as backend"
         )
 
     if file_obj is None:
@@ -546,6 +600,7 @@
         or (not is_av_available() and backend == "pyav")
         or (not is_cv2_available() and backend == "opencv")
         or (not is_torchvision_available() and backend == "torchvision")
+        or (not is_torchcodec_available() and backend == "torchcodec")
     ):
         raise ImportError(
             f"You chose backend={backend} for loading the video but the required library is not found in your environment "
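The new backend plugs into `load_video` like the existing ones. A hedged usage sketch follows; the URL is the sample clip used by the tests in the next patch, and the printed shape assumes the NHWC layout adopted later in this series.

```python
import numpy as np
from transformers.video_utils import load_video

# Mirrors the docstring example: uniformly pick 8 frame indices across the clip.
def sample_indices_fn(metadata, **kwargs):
    return np.linspace(0, metadata.total_num_frames - 1, 8, dtype=int)

video, metadata = load_video(
    "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
    backend="torchcodec",
    sample_indices_fn=sample_indices_fn,
)
print(video.shape)  # e.g. torch.Size([8, 360, 640, 3])
print(metadata.video_backend, metadata.frames_indices)
```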
From 6bfcce8da388bf1179e8624e0b6a2d795080de9a Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 18 Jun 2025 10:42:10 +0200
Subject: [PATCH 3/9] add tests

---
 src/transformers/testing_utils.py | 11 +++++++++++
 tests/utils/test_video_utils.py   | 13 +++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 8f7aedb625a0..f5cf2c21d4df 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -158,6 +158,7 @@
     is_torch_xpu_available,
     is_torchao_available,
     is_torchaudio_available,
+    is_torchcodec_available,
     is_torchdynamo_available,
     is_torchvision_available,
     is_vision_available,
@@ -634,6 +635,16 @@ def require_torchvision(test_case):
     return unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")(test_case)
 
 
+def require_torchcodec(test_case):
+    """
+    Decorator marking a test that requires Torchcodec.
+
+    These tests are skipped when Torchcodec isn't installed.
+
+    """
+    return unittest.skipUnless(is_torchcodec_available(), "test requires Torchcodec")(test_case)
+
+
 def require_torch_or_tf(test_case):
     """
     Decorator marking a test that requires PyTorch or TensorFlow.
diff --git a/tests/utils/test_video_utils.py b/tests/utils/test_video_utils.py
index 21a5b44ff8e1..a8bc961ea153 100644
--- a/tests/utils/test_video_utils.py
+++ b/tests/utils/test_video_utils.py
@@ -27,6 +27,7 @@
     require_cv2,
     require_decord,
     require_torch,
+    require_torchcodec,
     require_torchvision,
     require_vision,
 )
@@ -261,6 +262,7 @@ def test_load_video_local(self):
 
     @require_decord
     @require_torchvision
+    @require_torchcodec
     @require_cv2
     def test_load_video_backend_url(self):
         video, _ = load_video(
@@ -269,6 +271,12 @@ def test_load_video_backend_url(self):
         )
         self.assertEqual(video.shape, (243, 360, 640, 3))
 
+        video, _ = load_video(
+            "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
+            backend="torchcodec",
+        )
+        self.assertEqual(video.shape, (243, 360, 640, 3))
+
         # Can't use certain backends with url
         with self.assertRaises(ValueError):
             video, _ = load_video(
@@ -283,6 +291,7 @@
 
     @require_decord
     @require_torchvision
+    @require_torchcodec
     @require_cv2
     def test_load_video_backend_local(self):
         video_file_path = hf_hub_download(
@@ -300,6 +309,10 @@ def test_load_video_backend_local(self):
         self.assertEqual(video.shape, (243, 360, 640, 3))
         self.assertIsInstance(metadata, VideoMetadata)
 
+        video, metadata = load_video(video_file_path, backend="torchcodec")
+        self.assertEqual(video.shape, (243, 360, 640, 3))
+        self.assertIsInstance(metadata, VideoMetadata)
+
     def test_load_video_num_frames(self):
         video, _ = load_video(
             "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
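The decorator composes with the other backend guards exactly like `require_torchvision`. A minimal hypothetical test, not part of this patch, that gets skipped when TorchCodec is missing:

```python
import unittest

from transformers.testing_utils import require_torchcodec
from transformers.video_utils import load_video


@require_torchcodec
class TorchcodecDecodeTest(unittest.TestCase):
    def test_num_sampled_frames(self):
        # Hypothetical local file; the real tests above download a sample from the Hub.
        video, _ = load_video("sample_demo_1.mp4", backend="torchcodec", num_frames=8)
        self.assertEqual(len(video), 8)
```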
From 30f0b4be416217501183f524c310c57034cbb159 Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 18 Jun 2025 10:45:30 +0200
Subject: [PATCH 4/9] make style

---
 src/transformers/utils/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 386a85228ffe..4a742db13466 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -259,6 +259,7 @@
     is_torchdynamo_exporting,
     is_torchvision_available,
     is_torchvision_v2_available,
+    is_torchcodec_available,
     is_training_run_on_sagemaker,
     is_uroman_available,
     is_vision_available,

From 8cd050166545824564437291fbc239cf206f750f Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 18 Jun 2025 11:09:59 +0200
Subject: [PATCH 5/9] instructblip as well

---
 .../instructblipvideo/video_processing_instructblipvideo.py | 6 ++++++
 src/transformers/utils/__init__.py                          | 2 +-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py
index ea08466568e3..330dba0c3b85 100644
--- a/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py
+++ b/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py
@@ -94,12 +94,18 @@ def _preprocess(
         fps: Optional[int] = None,
         num_frames: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
+        device: Optional["torch.device"] = None,
     ) -> BatchFeature:
         if do_sample_frames:
             videos = [
                 self.sample_frames(video, metadata, num_frames, fps) for video, metadata in zip(videos, video_metadata)
             ]
 
+        # We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
+        # moving the whole video incurs high GPU mem usage for long videos
+        if device is not None:
+            videos = [video.to(device) for video in videos]
+
         # Group videos by size for batched resizing
         grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
         resized_videos_grouped = {}
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 4a742db13466..8f09939cd268 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -253,13 +253,13 @@
     is_torch_xpu_available,
     is_torchao_available,
     is_torchaudio_available,
+    is_torchcodec_available,
     is_torchdistx_available,
     is_torchdynamo_available,
     is_torchdynamo_compiling,
     is_torchdynamo_exporting,
     is_torchvision_available,
     is_torchvision_v2_available,
-    is_torchcodec_available,
     is_training_run_on_sagemaker,
     is_uroman_available,
     is_vision_available,

From 1907bfc3e5eff2b7aefb11b6b571d62fdbcfb344 Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 18 Jun 2025 16:06:38 +0200
Subject: [PATCH 6/9] consistency

---
 src/transformers/video_processing_utils.py | 2 ++
 src/transformers/video_utils.py            | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/video_processing_utils.py b/src/transformers/video_processing_utils.py
index 504f54897a53..b21b38d34f08 100644
--- a/src/transformers/video_processing_utils.py
+++ b/src/transformers/video_processing_utils.py
@@ -775,6 +775,8 @@ def to_dict(self) -> dict[str, Any]:
             `dict[str, Any]`: Dictionary of all the attributes that make up this video processor instance.
         """
         output = copy.deepcopy(self.__dict__)
+        output.pop("model_valid_processing_keys", None)
+        output.pop("_valid_kwargs_names", None)
         output["video_processor_type"] = self.__class__.__name__
         return output
 
diff --git a/src/transformers/video_utils.py b/src/transformers/video_utils.py
index 4387c434adf2..430fad42ef6b 100644
--- a/src/transformers/video_utils.py
+++ b/src/transformers/video_utils.py
@@ -485,7 +485,7 @@ def sample_indices_fn(metadata, **kwargs):
 
     decoder = VideoDecoder(
         video_path,
-        dimension_order="NCHW",
+        dimension_order="NHWC",  # to be consistent with other decoders
         # Interestingly, `exact` mode takes less time than `approximate` when we load the whole video
         seek_mode="exact",
     )
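A quick sketch of what the `to_dict` change guards against; the checkpoint name is an illustrative assumption. Transient attributes populated at runtime should not leak into the serialized processor config.

```python
from transformers import AutoVideoProcessor

processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")  # illustrative
config = processor.to_dict()

# The two popped attributes are internal bookkeeping, not configuration.
assert "model_valid_processing_keys" not in config
assert "_valid_kwargs_names" not in config
assert config["video_processor_type"] == type(processor).__name__
```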
From e1b106b6fecbd7ec260bb312d12a39d81e5866b4 Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Fri, 20 Jun 2025 12:31:05 +0200
Subject: [PATCH 7/9] Update src/transformers/utils/import_utils.py

Co-authored-by: Pavel Iakubovskii
---
 src/transformers/utils/import_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 5bc22e458b31..90ba221e943a 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -1490,7 +1490,7 @@ def check_torch_load_is_safe():
 """
 
 TORCHCODEC_IMPORT_ERROR = """
-{0} requires the PyAv library but it was not found in your environment. You can install it with:
+{0} requires the TorchCodec (https://github.com/pytorch/torchcodec) library, but it was not found in your environment. You can install it with:
 ```
 pip install torchcodec
 ```

From e33b3a8d1209a1971045a32a76866e4d98d6e4ec Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Fri, 20 Jun 2025 12:31:11 +0200
Subject: [PATCH 8/9] Update src/transformers/utils/import_utils.py

Co-authored-by: Pavel Iakubovskii
---
 src/transformers/utils/import_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 90ba221e943a..9a404dd5f160 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -119,7 +119,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _vptq_available, _vptq_version = _is_package_available("vptq", return_version=True)
 _av_available = importlib.util.find_spec("av") is not None
 _decord_available = importlib.util.find_spec("decord") is not None
-_torchcodec_available = importlib.util.find_spec("torchcoded") is not None
+_torchcodec_available = importlib.util.find_spec("torchcodec") is not None
 _bitsandbytes_available = _is_package_available("bitsandbytes")
 _eetq_available = _is_package_available("eetq")
 _fbgemm_gpu_available = _is_package_available("fbgemm_gpu")

From f55e9b3902697b679b4de9b4a6c59e1a891296fd Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Fri, 20 Jun 2025 12:31:18 +0200
Subject: [PATCH 9/9] Update src/transformers/video_utils.py

Co-authored-by: Pavel Iakubovskii
---
 src/transformers/video_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/video_utils.py b/src/transformers/video_utils.py
index 430fad42ef6b..ea02eefd5fbd 100644
--- a/src/transformers/video_utils.py
+++ b/src/transformers/video_utils.py
@@ -479,7 +479,7 @@ def sample_indices_fn(metadata, **kwargs):
             - Torch tensor of frames in RGB (shape: [num_frames, height, width, 3]).
             - `VideoMetadata` object.
     """
-    # Lazy import av
+    # Lazy import torchcodec
     requires_backends(read_video_torchcodec, ["torchcodec"])
     from torchcodec.decoders import VideoDecoder
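Taken together, the series wires TorchCodec in end to end. As a closing sketch, these are the decoder calls `read_video_torchcodec` relies on, handy for quick local experiments; the file path is an illustrative assumption.

```python
from torchcodec.decoders import VideoDecoder

decoder = VideoDecoder("sample_demo_1.mp4", dimension_order="NHWC", seek_mode="exact")
print(decoder.metadata.num_frames, decoder.metadata.average_fps, decoder.metadata.duration_seconds)

# FrameBatch.data is a uint8 tensor; with NHWC it has shape (len(indices), height, width, 3).
frames = decoder.get_frames_at(indices=[0, 10, 20]).data
```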