diff --git a/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py
index ea08466568e3..330dba0c3b85 100644
--- a/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py
+++ b/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py
@@ -94,12 +94,18 @@ def _preprocess(
         fps: Optional[int] = None,
         num_frames: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
+        device: Optional["torch.device"] = None,
     ) -> BatchFeature:
         if do_sample_frames:
             videos = [
                 self.sample_frames(video, metadata, num_frames, fps)
                 for video, metadata in zip(videos, video_metadata)
             ]
+        # We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
+        # moving the whole video incurs high GPU mem usage for long videos
+        if device is not None:
+            videos = [video.to(device) for video in videos]
+
         # Group videos by size for batched resizing
         grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
         resized_videos_grouped = {}
diff --git a/src/transformers/models/internvl/video_processing_internvl.py b/src/transformers/models/internvl/video_processing_internvl.py
index 149f780c8fdb..c9be4ebb94c8 100644
--- a/src/transformers/models/internvl/video_processing_internvl.py
+++ b/src/transformers/models/internvl/video_processing_internvl.py
@@ -147,6 +147,7 @@ def _preprocess(
         num_frames: Optional[int] = None,
         initial_shift: Optional[Union[bool, float, int]] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
+        device: Optional["torch.device"] = None,
     ) -> BatchFeature:
         if do_sample_frames:
             # Sample video frames
@@ -155,6 +156,11 @@
                 for video, metadata in zip(videos, video_metadata)
             ]

+        # We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
+        # moving the whole video incurs high GPU mem usage for long videos
+        if device is not None:
+            videos = [video.to(device) for video in videos]
+
         # Group videos by size for batched resizing
         grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
         resized_videos_grouped = {}
diff --git a/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
index 9782964ea137..5640b8d33381 100644
--- a/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
@@ -213,6 +213,7 @@ def _preprocess(
         min_frames: Optional[int] = None,
         max_frames: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
+        device: Optional["torch.device"] = None,
         **kwargs,
     ):
         if do_sample_frames:
@@ -230,6 +231,11 @@
                 for video, metadata in zip(videos, video_metadata)
             ]

+        # We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
+        # moving the whole video incurs high GPU mem usage for long videos
+        if device is not None:
+            videos = [video.to(device) for video in videos]
+
         # Group videos by size for batched resizing
         grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
         resized_videos_grouped = {}
diff --git a/src/transformers/models/smolvlm/video_processing_smolvlm.py b/src/transformers/models/smolvlm/video_processing_smolvlm.py
index d65b8affea85..730079f9b400 100644
--- a/src/transformers/models/smolvlm/video_processing_smolvlm.py
+++ b/src/transformers/models/smolvlm/video_processing_smolvlm.py
@@ -332,6 +332,7 @@ def _preprocess(
         num_frames: Optional[int] = None,
         skip_secs: Optional[int] = 0,
         return_tensors: Optional[Union[str, TensorType]] = None,
+        device: Optional["torch.device"] = None,
         **kwargs,
     ):
         # Group videos by size for batched resizing
@@ -356,6 +357,11 @@
             ]
         durations_list = [len(video) // 24 for video in videos]

+        # We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
+        # moving the whole video incurs high GPU mem usage for long videos
+        if device is not None:
+            videos = [video.to(device) for video in videos]
+
         grouped_videos, grouped_videos_index = group_videos_by_shape(processed_videos)
         resized_videos_grouped = {}
         for shape, stacked_videos in grouped_videos.items():
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index cedfad084ccd..1a4232adc8c5 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -158,6 +158,7 @@
     is_torch_xpu_available,
    is_torchao_available,
    is_torchaudio_available,
+    is_torchcodec_available,
    is_torchdynamo_available,
    is_torchvision_available,
    is_vision_available,
@@ -634,6 +635,16 @@ def require_torchvision(test_case):
     return unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")(test_case)


+def require_torchcodec(test_case):
+    """
+    Decorator marking a test that requires Torchcodec.
+
+    These tests are skipped when Torchcodec isn't installed.
+
+    """
+    return unittest.skipUnless(is_torchcodec_available(), "test requires Torchcodec")(test_case)
+
+
 def require_torch_or_tf(test_case):
     """
     Decorator marking a test that requires PyTorch or TensorFlow.
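A minimal sketch of how the new `require_torchcodec` decorator might be used; the test module below is hypothetical and not part of this diff, and the video path is a placeholder. It assumes `load_video` is importable from `transformers.video_utils` and that torchcodec decoding returns frames as a tensor with the frame dimension first, as the decoder code above suggests.

```python
# Hypothetical test, not part of this patch: skipped automatically when torchcodec is missing.
import unittest

from transformers.testing_utils import require_torchcodec
from transformers.video_utils import load_video


class TorchcodecBackendTest(unittest.TestCase):
    @require_torchcodec
    def test_decode_with_torchcodec_backend(self):
        # `num_frames=8` asks the default sampler for 8 uniformly spaced frames;
        # "my_video.mp4" is a placeholder path.
        video, metadata = load_video("my_video.mp4", num_frames=8, backend="torchcodec")
        self.assertEqual(video.shape[0], 8)
        self.assertEqual(metadata.video_backend, "torchcodec")
```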
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 21a361628107..6d73b8d0325b 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -254,6 +254,7 @@
     is_torch_xpu_available,
     is_torchao_available,
     is_torchaudio_available,
+    is_torchcodec_available,
     is_torchdistx_available,
     is_torchdynamo_available,
     is_torchdynamo_compiling,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index a933c9638d6f..0fe8ba55c9e9 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -119,6 +119,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _vptq_available, _vptq_version = _is_package_available("vptq", return_version=True)
 _av_available = importlib.util.find_spec("av") is not None
 _decord_available = importlib.util.find_spec("decord") is not None
+_torchcodec_available = importlib.util.find_spec("torchcodec") is not None
 _bitsandbytes_available = _is_package_available("bitsandbytes")
 _eetq_available = _is_package_available("eetq")
 _fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
@@ -976,6 +977,10 @@ def is_decord_available():
     return _decord_available


+def is_torchcodec_available():
+    return _torchcodec_available
+
+
 def is_ninja_available():
     r"""
     Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
@@ -1502,6 +1507,14 @@ def check_torch_load_is_safe():
 Please note that you may need to restart your runtime after installation.
 """

+TORCHCODEC_IMPORT_ERROR = """
+{0} requires the TorchCodec (https://github.com/pytorch/torchcodec) library, but it was not found in your environment. You can install it with:
+```
+pip install torchcodec
+```
+Please note that you may need to restart your runtime after installation.
+"""
+
 # docstyle-ignore
 CV2_IMPORT_ERROR = """
 {0} requires the OpenCV library but it was not found in your environment. You can install it with:
@@ -1882,6 +1895,7 @@ def check_torch_load_is_safe():
         ("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
         ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
         ("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)),
+        ("torchcodec", (is_torchcodec_available, TORCHCODEC_IMPORT_ERROR)),
         ("vision", (is_vision_available, VISION_IMPORT_ERROR)),
         ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
         ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
diff --git a/src/transformers/video_processing_utils.py b/src/transformers/video_processing_utils.py
index 87130c7fef78..b21b38d34f08 100644
--- a/src/transformers/video_processing_utils.py
+++ b/src/transformers/video_processing_utils.py
@@ -294,7 +294,6 @@ def _prepare_input_videos(
         videos: VideoInput,
         video_metadata: VideoMetadata = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
-        device: Optional["torch.device"] = None,
     ) -> list["torch.Tensor"]:
         """
         Prepare the input videos for processing.
@@ -313,10 +312,6 @@ def _prepare_input_videos(
                 # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
                 video = torch.from_numpy(video).contiguous()

-            # Now that we have torch tensors, we can move them to the right device
-            if device is not None:
-                video = video.to(device)
-
             processed_videos.append(video)

         return processed_videos, batch_metadata
@@ -336,10 +331,9 @@ def preprocess(
             kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))

         input_data_format = kwargs.pop("input_data_format")
-        device = kwargs.pop("device")
         video_metadata = kwargs.pop("video_metadata")
         videos, video_metadata = self._prepare_input_videos(
-            videos=videos, video_metadata=video_metadata, input_data_format=input_data_format, device=device
+            videos=videos, video_metadata=video_metadata, input_data_format=input_data_format
         )

         kwargs = self._further_process_kwargs(**kwargs)
@@ -378,6 +372,7 @@ def _preprocess(
         fps: Optional[int] = None,
         num_frames: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
+        device: Optional["torch.device"] = None,
     ) -> BatchFeature:
         if do_sample_frames:
             # Sample video frames
@@ -386,6 +381,11 @@
                 for video, metadata in zip(videos, video_metadata)
             ]

+        # We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
+        # moving the whole video incurs high GPU mem usage for long videos
+        if device is not None:
+            videos = [video.to(device) for video in videos]
+
         # Group videos by size for batched resizing
         grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
         resized_videos_grouped = {}
@@ -775,6 +775,8 @@ def to_dict(self) -> dict[str, Any]:
             `dict[str, Any]`: Dictionary of all the attributes that make up this video processor instance.
         """
         output = copy.deepcopy(self.__dict__)
+        output.pop("model_valid_processing_keys", None)
+        output.pop("_valid_kwargs_names", None)
         output["video_processor_type"] = self.__class__.__name__
         return output

diff --git a/src/transformers/video_utils.py b/src/transformers/video_utils.py
index 71594bb6bc25..ea02eefd5fbd 100644
--- a/src/transformers/video_utils.py
+++ b/src/transformers/video_utils.py
@@ -14,6 +14,7 @@
 # limitations under the License.

 import os
+import warnings
 from collections.abc import Iterable
 from contextlib import redirect_stdout
 from dataclasses import dataclass
@@ -33,6 +34,7 @@
     is_numpy_array,
     is_torch_available,
     is_torch_tensor,
+    is_torchcodec_available,
     is_torchvision_available,
     is_vision_available,
     is_yt_dlp_available,
@@ -425,6 +427,10 @@ def sample_indices_fn(metadata, **kwargs):
             - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
             - `VideoMetadata` object.
     """
+    warnings.warn(
+        "Using `torchvision` for video decoding is deprecated and will be removed in future versions. "
+        "Please use `torchcodec` instead."
+    )
     video, _, info = torchvision_io.read_video(
         video_path,
         start_pts=0.0,
@@ -449,11 +455,59 @@ def sample_indices_fn(metadata, **kwargs):
     return video, metadata


+def read_video_torchcodec(
+    video_path: str,
+    sample_indices_fn: Callable,
+    **kwargs,
+):
+    """
+    Decode the video with the torchcodec decoder.
+
+    Args:
+        video_path (`str`):
+            Path to the video file.
+        sample_indices_fn (`Callable`, *optional*):
+            A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
+            a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
+            If not provided, simple uniform sampling with fps is performed.
+            Example:
+            def sample_indices_fn(metadata, **kwargs):
+                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
+
+    Returns:
+        Tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
+            - Torch tensor of frames in RGB (shape: [num_frames, height, width, 3]).
+            - `VideoMetadata` object.
+    """
+    # Lazy import torchcodec
+    requires_backends(read_video_torchcodec, ["torchcodec"])
+    from torchcodec.decoders import VideoDecoder
+
+    decoder = VideoDecoder(
+        video_path,
+        dimension_order="NHWC",  # to be consistent with other decoders
+        # Interestingly, `exact` mode takes less time than `approximate` when we load the whole video
+        seek_mode="exact",
+    )
+    metadata = VideoMetadata(
+        total_num_frames=decoder.metadata.num_frames,
+        fps=decoder.metadata.average_fps,
+        duration=decoder.metadata.duration_seconds,
+        video_backend="torchcodec",
+    )
+    indices = sample_indices_fn(metadata=metadata, **kwargs)
+
+    video = decoder.get_frames_at(indices=indices).data.contiguous()
+    metadata.frames_indices = indices
+    return video, metadata
+
+
 VIDEO_DECODERS = {
     "decord": read_video_decord,
     "opencv": read_video_opencv,
     "pyav": read_video_pyav,
     "torchvision": read_video_torchvision,
+    "torchcodec": read_video_torchcodec,
 }


@@ -477,7 +531,7 @@ def load_video(
             Number of frames to sample per second. Should be passed only when `num_frames=None`.
             If not specified and `num_frames==None`, all frames are sampled.
         backend (`str`, *optional*, defaults to `"pyav"`):
-            The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav".
+            The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision", "torchcodec"]. Defaults to "pyav".
         sample_indices_fn (`Callable`, *optional*):
             A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
             by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
@@ -535,7 +589,7 @@ def sample_indices_fn_func(metadata, **fn_kwargs):
     video_is_url = video.startswith("http://") or video.startswith("https://")
     if video_is_url and backend in ["opencv", "torchvision"]:
         raise ValueError(
-            "If you are trying to load a video from URL, you can decode the video only with `pyav` or `decord` as backend"
+            "If you are trying to load a video from URL, you can decode the video only with `pyav`, `decord` or `torchcodec` as backend"
         )

     if file_obj is None:
@@ -546,6 +600,7 @@ def sample_indices_fn_func(metadata, **fn_kwargs):
         or (not is_av_available() and backend == "pyav")
         or (not is_cv2_available() and backend == "opencv")
         or (not is_torchvision_available() and backend == "torchvision")
+        or (not is_torchcodec_available() and backend == "torchcodec")
     ):
         raise ImportError(
             f"You chose backend={backend} for loading the video but the required library is not found in your environment "
diff --git a/tests/utils/test_video_utils.py b/tests/utils/test_video_utils.py
index 74f81cfe3624..7c598222bd6b 100644
--- a/tests/utils/test_video_utils.py
+++ b/tests/utils/test_video_utils.py
@@ -27,6 +27,7 @@
     require_cv2,
     require_decord,
     require_torch,
+    require_torchcodec,
     require_torchvision,
     require_vision,
 )
@@ -261,6 +262,7 @@ def test_load_video_local(self):

     @require_decord
     @require_torchvision
+    @require_torchcodec
     @require_cv2
     def test_load_video_backend_url(self):
         video, _ = load_video(
@@ -269,6 +271,12 @@ def test_load_video_backend_url(self):
         )
         self.assertEqual(video.shape, (243, 360, 640, 3))

+        video, _ = load_video(
+            "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
+            backend="torchcodec",
+        )
+        self.assertEqual(video.shape, (243, 360, 640, 3))
+
         # Can't use certain backends with url
         with self.assertRaises(ValueError):
             video, _ = load_video(
@@ -283,6 +291,7 @@ def test_load_video_backend_url(self):

     @require_decord
     @require_torchvision
+    @require_torchcodec
     @require_cv2
     def test_load_video_backend_local(self):
         video_file_path = hf_hub_download(
@@ -300,6 +309,10 @@ def test_load_video_backend_local(self):
         self.assertEqual(video.shape, (243, 360, 640, 3))
         self.assertIsInstance(metadata, VideoMetadata)

+        video, metadata = load_video(video_file_path, backend="torchcodec")
+        self.assertEqual(video.shape, (243, 360, 640, 3))
+        self.assertIsInstance(metadata, VideoMetadata)
+
     def test_load_video_num_frames(self):
         video, _ = load_video(
             "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
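As a usage sketch of the new backend (not part of the diff), `load_video` can be pointed at torchcodec and combined with a custom `sample_indices_fn`, mirroring the example in the `read_video_torchcodec` docstring above. The video path is a placeholder and the snippet assumes torchcodec is installed.

```python
# Hypothetical usage sketch; "my_video.mp4" is a placeholder path.
import numpy as np

from transformers.video_utils import load_video


def sample_indices_fn(metadata, **kwargs):
    # Pick 16 uniformly spaced frame indices across the whole clip,
    # following the pattern shown in the read_video_torchcodec docstring.
    return np.linspace(0, metadata.total_num_frames - 1, 16, dtype=int)


video, metadata = load_video("my_video.mp4", backend="torchcodec", sample_indices_fn=sample_indices_fn)
print(video.shape)             # (16, height, width, 3): RGB frames as a torch.Tensor in NHWC order
print(metadata.video_backend)  # "torchcodec"
```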