Merged
@@ -94,12 +94,18 @@ def _preprocess(
fps: Optional[int] = None,
num_frames: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
device: Optional["torch.Tensor"] = None,
) -> BatchFeature:
if do_sample_frames:
videos = [
self.sample_frames(video, metadata, num_frames, fps) for video, metadata in zip(videos, video_metadata)
]

# If `do_sample_frames=True`, we need to sample frames before moving them to the device; otherwise
# moving the whole video incurs high GPU memory usage for long videos.
if device is not None:
videos = [video.to(device) for video in videos]

# Group videos by size for batched resizing
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
resized_videos_grouped = {}
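The same sample-then-move ordering is repeated in every video processor below. A minimal sketch of why the order matters for long videos (hypothetical shapes and frame counts, not the library's actual sampling code):

```python
import torch

def sample_frames(video: torch.Tensor, num_frames: int) -> torch.Tensor:
    # Uniformly pick `num_frames` frame indices across the whole clip.
    indices = torch.linspace(0, video.shape[0] - 1, num_frames).long()
    return video[indices]

# Hypothetical 10-minute clip at 24 fps: 14400 frames of 3x360x640 uint8,
# roughly 9.3 GB if moved to the GPU in full.
video = torch.zeros(14400, 3, 360, 640, dtype=torch.uint8)

# Sampling first means only 32 frames (~21 MB) ever reach the device.
sampled = sample_frames(video, num_frames=32)
if torch.cuda.is_available():
    sampled = sampled.to("cuda")
```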
6 changes: 6 additions & 0 deletions src/transformers/models/internvl/video_processing_internvl.py
@@ -147,6 +147,7 @@ def _preprocess(
num_frames: Optional[int] = None,
initial_shift: Optional[Union[bool, float, int]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
device: Optional["torch.Tensor"] = None,
) -> BatchFeature:
if do_sample_frames:
# Sample video frames
@@ -155,6 +156,11 @@
for video, metadata in zip(videos, video_metadata)
]

# If `do_sample_frames=True`, we need to sample frames before moving them to the device; otherwise
# moving the whole video incurs high GPU memory usage for long videos.
if device is not None:
videos = [video.to(device) for video in videos]

# Group videos by size for batched resizing
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
resized_videos_grouped = {}
6 changes: 6 additions & 0 deletions src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
@@ -213,6 +213,7 @@ def _preprocess(
min_frames: Optional[int] = None,
max_frames: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
device: Optional["torch.Tensor"] = None,
**kwargs,
):
if do_sample_frames:
@@ -230,6 +231,11 @@
for video, metadata in zip(videos, video_metadata)
]

# If `do_sample_frames=True`, we need to sample frames before moving them to the device; otherwise
# moving the whole video incurs high GPU memory usage for long videos.
if device is not None:
videos = [video.to(device) for video in videos]

# Group videos by size for batched resizing
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
resized_videos_grouped = {}
6 changes: 6 additions & 0 deletions src/transformers/models/smolvlm/video_processing_smolvlm.py
@@ -332,6 +332,7 @@ def _preprocess(
num_frames: Optional[int] = None,
skip_secs: Optional[int] = 0,
return_tensors: Optional[Union[str, TensorType]] = None,
device: Optional["torch.Tensor"] = None,
**kwargs,
):
# Group videos by size for batched resizing
@@ -356,6 +357,11 @@ def _preprocess(
]
durations_list = [len(video) // 24 for video in videos]

# If `do_sample_frames=True`, we need to sample frames before moving them to the device; otherwise
# moving the whole video incurs high GPU memory usage for long videos.
if device is not None:
videos = [video.to(device) for video in videos]

grouped_videos, grouped_videos_index = group_videos_by_shape(processed_videos)
resized_videos_grouped = {}
for shape, stacked_videos in grouped_videos.items():
11 changes: 11 additions & 0 deletions src/transformers/testing_utils.py
@@ -158,6 +158,7 @@
is_torch_xpu_available,
is_torchao_available,
is_torchaudio_available,
is_torchcodec_available,
is_torchdynamo_available,
is_torchvision_available,
is_vision_available,
@@ -634,6 +635,16 @@ def require_torchvision(test_case):
return unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")(test_case)


def require_torchcodec(test_case):
"""
Decorator marking a test that requires Torchcodec.

These tests are skipped when Torchcodec isn't installed.
"""
return unittest.skipUnless(is_torchcodec_available(), "test requires Torchcodec")(test_case)

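For reference, a hypothetical test sketch showing how the new guard is meant to be used (same pattern as `require_torchvision`):

```python
import unittest

from transformers.testing_utils import require_torchcodec


class VideoDecoderTest(unittest.TestCase):
    @require_torchcodec  # skipped automatically when torchcodec isn't installed
    def test_metadata(self):
        from torchcodec.decoders import VideoDecoder  # safe: the guard ran first

        decoder = VideoDecoder("sample.mp4")  # hypothetical local file
        self.assertGreater(decoder.metadata.num_frames, 0)
```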

def require_torch_or_tf(test_case):
"""
Decorator marking a test that requires PyTorch or TensorFlow.
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -254,6 +254,7 @@
is_torch_xpu_available,
is_torchao_available,
is_torchaudio_available,
is_torchcodec_available,
is_torchdistx_available,
is_torchdynamo_available,
is_torchdynamo_compiling,
14 changes: 14 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -119,6 +119,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_vptq_available, _vptq_version = _is_package_available("vptq", return_version=True)
_av_available = importlib.util.find_spec("av") is not None
_decord_available = importlib.util.find_spec("decord") is not None
_torchcodec_available = importlib.util.find_spec("torchcodec") is not None
_bitsandbytes_available = _is_package_available("bitsandbytes")
_eetq_available = _is_package_available("eetq")
_fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
@@ -976,6 +977,10 @@ def is_decord_available():
return _decord_available


def is_torchcodec_available():
return _torchcodec_available


def is_ninja_available():
r"""
Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
@@ -1502,6 +1507,14 @@ def check_torch_load_is_safe():
Please note that you may need to restart your runtime after installation.
"""

TORCHCODEC_IMPORT_ERROR = """
{0} requires the TorchCodec (https://github.com/pytorch/torchcodec) library, but it was not found in your environment. You can install it with:
```
pip install torchcodec
```
Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
CV2_IMPORT_ERROR = """
{0} requires the OpenCV library but it was not found in your environment. You can install it with:
@@ -1882,6 +1895,7 @@ def check_torch_load_is_safe():
("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)),
("torchcodec", (is_torchcodec_available, TORCHCODEC_IMPORT_ERROR)),
("vision", (is_vision_available, VISION_IMPORT_ERROR)),
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
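The new `("torchcodec", (is_torchcodec_available, TORCHCODEC_IMPORT_ERROR))` entry is what lets `requires_backends` raise the error message above. A simplified sketch of that dispatch, assuming the mapping is named `BACKENDS_MAPPING` as elsewhere in `import_utils.py`:

```python
BACKENDS_MAPPING = {
    "torchcodec": (is_torchcodec_available, TORCHCODEC_IMPORT_ERROR),
    # ... every other ("name", (checker, error_template)) pair from the table above
}

def requires_backends(obj, backends):
    # Collect an error message for every requested backend that is missing.
    name = getattr(obj, "__name__", obj.__class__.__name__)
    checks = (BACKENDS_MAPPING[backend] for backend in backends)
    failed = [template.format(name) for available, template in checks if not available()]
    if failed:
        raise ImportError("".join(failed))
```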
16 changes: 9 additions & 7 deletions src/transformers/video_processing_utils.py
@@ -294,7 +294,6 @@ def _prepare_input_videos(
videos: VideoInput,
video_metadata: VideoMetadata = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
device: Optional["torch.device"] = None,
) -> list["torch.Tensor"]:
"""
Prepare the input videos for processing.
@@ -313,10 +312,6 @@
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
video = torch.from_numpy(video).contiguous()

# Now that we have torch tensors, we can move them to the right device
if device is not None:
video = video.to(device)

processed_videos.append(video)
return processed_videos, batch_metadata

@@ -336,10 +331,9 @@ def preprocess(
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))

input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
video_metadata = kwargs.pop("video_metadata")
videos, video_metadata = self._prepare_input_videos(
videos=videos, video_metadata=video_metadata, input_data_format=input_data_format, device=device
videos=videos, video_metadata=video_metadata, input_data_format=input_data_format
)

kwargs = self._further_process_kwargs(**kwargs)
@@ -378,6 +372,7 @@ def _preprocess(
fps: Optional[int] = None,
num_frames: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
device: Optional["torch.Tensor"] = None,
) -> BatchFeature:
if do_sample_frames:
# Sample video frames
Expand All @@ -386,6 +381,11 @@ def _preprocess(
for video, metadata in zip(videos, video_metadata)
]

# If `do_sample_frames=True`, we need to sample frames before moving them to the device; otherwise
# moving the whole video incurs high GPU memory usage for long videos.
if device is not None:
videos = [video.to(device) for video in videos]

# Group videos by size for batched resizing
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
resized_videos_grouped = {}
@@ -775,6 +775,8 @@ def to_dict(self) -> dict[str, Any]:
`dict[str, Any]`: Dictionary of all the attributes that make up this video processor instance.
"""
output = copy.deepcopy(self.__dict__)
output.pop("model_valid_processing_keys", None)
output.pop("_valid_kwargs_names", None)
output["video_processor_type"] = self.__class__.__name__

return output
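A minimal sketch of what the two new `pop` calls accomplish (toy class, not the real `BaseVideoProcessor`): per-instance bookkeeping attributes stay out of the serialized config.

```python
import copy


class ToyVideoProcessor:
    def __init__(self):
        self.size = {"height": 224, "width": 224}
        self.model_valid_processing_keys = ["size", "do_resize"]  # internal bookkeeping

    def to_dict(self):
        output = copy.deepcopy(self.__dict__)
        output.pop("model_valid_processing_keys", None)
        output.pop("_valid_kwargs_names", None)
        output["video_processor_type"] = self.__class__.__name__
        return output


assert "model_valid_processing_keys" not in ToyVideoProcessor().to_dict()
assert ToyVideoProcessor().to_dict()["size"] == {"height": 224, "width": 224}
```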
59 changes: 57 additions & 2 deletions src/transformers/video_utils.py
@@ -14,6 +14,7 @@
# limitations under the License.

import os
import warnings
from collections.abc import Iterable
from contextlib import redirect_stdout
from dataclasses import dataclass
@@ -33,6 +34,7 @@
is_numpy_array,
is_torch_available,
is_torch_tensor,
is_torchcodec_available,
is_torchvision_available,
is_vision_available,
is_yt_dlp_available,
@@ -425,6 +427,10 @@ def sample_indices_fn(metadata, **kwargs):
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
warnings.warn(
"Using `torchvision` for video decoding is deprecated and will be removed in future versions. "
"Please use `torchcodec` instead."
)
video, _, info = torchvision_io.read_video(
video_path,
start_pts=0.0,
@@ -449,11 +455,59 @@
return video, metadata


def read_video_torchcodec(
video_path: str,
sample_indices_fn: Callable,
**kwargs,
):
"""
Decode the video with torchcodec decoder.

Args:
video_path (`str`):
Path to the video file.
sample_indices_fn (`Callable`, *optional*):
A callable function that returns the indices at which the video should be sampled. If the video has to be sampled
with a different technique than the one provided by the `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
If not provided, simple uniform sampling with fps is performed.
Example:
def sample_indices_fn(metadata, **kwargs):
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

Returns:
Tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
- Torch tensor of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
# Lazy import torchcodec
requires_backends(read_video_torchcodec, ["torchcodec"])
from torchcodec.decoders import VideoDecoder

decoder = VideoDecoder(
video_path,
dimension_order="NHWC", # to be consistent with other decoders
# Interestingly, `exact` mode takes less time than `approximate` when we load the whole video
seek_mode="exact",
Comment on lines +489 to +490

Contributor: do we have to load the whole video? I suppose the entire idea is to avoid loading long videos -> save on RAM and increase decoding speed, no?

Member (Author): Ideally yeah, we should get only the necessary frames. This is a result of previous modifications, which moved the video-sampling logic inside video processors (more intuitive than having it only in apply_templates).

I hadn't thought about RAM usage at that time, and now I see that it's not very efficient. Seems like the best way for videos would be to load -> sample-with-decoder -> optionally cast to torch -> transforms. Here I am facing an issue, because "load_media" is decoupled from the processors' call. We can load media only for instruct models when a conversation history is defined, and for base models users are expected to pre-load all images/videos themselves.

Do you think we should start allowing users to pass a url/path to the processor's call directly (like Pixtral already does)? I want to keep the sampling code in each model's processing file, to make it explicit for users/contributors.

Member (Author): Looks like working with videos is in general less efficient than images...

Contributor: > Do you think we should start allowing users to pass a url/path to the processor's call directly (like Pixtral already does)?

You mean Processor.__call__ (not the image processor), right? I don't have a strong opinion on it, it looks okay to me. Now that we allow users to pass fps/num_frames, it seems like a logical next step to allow reading only the required frames if the backend supports this.

Member (Author): Yeah, no objections from me as well. The only concern I have is that we might be bloating the video processors. If we finally accept "url/path" in ModelProcessor.__call__, then I will think about abstracting it (given we have tons of decoders with many options) (cc @yonigozlan, we talked last year about allowing urls for __call__).

Prob by default we won't let users configure decoder-related stuff.

Contributor: IMO, it's also OK to provide just basic usage, such as loading with the default settings, and clearly let users know that for an advanced setup they can read and sample by themselves.
)
metadata = VideoMetadata(
total_num_frames=decoder.metadata.num_frames,
fps=decoder.metadata.average_fps,
duration=decoder.metadata.duration_seconds,
video_backend="torchcodec",
)
indices = sample_indices_fn(metadata=metadata, **kwargs)

video = decoder.get_frames_at(indices=indices).data.contiguous()
metadata.frames_indices = indices
return video, metadata
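
A usage sketch for the new decoder path, assuming a hypothetical local `sample.mp4` and the uniform-sampling callback from the docstring:

```python
import numpy as np

def sample_indices_fn(metadata, **kwargs):
    # Uniformly sample 16 frames across the clip, per the docstring example.
    return np.linspace(0, metadata.total_num_frames - 1, 16, dtype=int)

frames, metadata = read_video_torchcodec("sample.mp4", sample_indices_fn=sample_indices_fn)
print(frames.shape)            # (16, height, width, 3) -- NHWC, matching the other decoders
print(metadata.video_backend)  # "torchcodec"
```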


VIDEO_DECODERS = {
"decord": read_video_decord,
"opencv": read_video_opencv,
"pyav": read_video_pyav,
"torchvision": read_video_torchvision,
"torchcodec": read_video_torchcodec,
}


@@ -477,7 +531,7 @@ def load_video(
Number of frames to sample per second. Should be passed only when `num_frames=None`.
If not specified and `num_frames==None`, all frames are sampled.
backend (`str`, *optional*, defaults to `"pyav"`):
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav".
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision", "torchcodec"]. Defaults to "pyav".
sample_indices_fn (`Callable`, *optional*):
A callable function that returns the indices at which the video should be sampled. If the video has to be sampled
with a different technique than the one provided by the `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
@@ -535,7 +589,7 @@ def sample_indices_fn_func(metadata, **fn_kwargs):
video_is_url = video.startswith("http://") or video.startswith("https://")
if video_is_url and backend in ["opencv", "torchvision"]:
raise ValueError(
"If you are trying to load a video from URL, you can decode the video only with `pyav` or `decord` as backend"
"If you are trying to load a video from URL, you can decode the video only with `pyav`, `decord` or `torchcodec` as backend"
)

if file_obj is None:
@@ -546,6 +600,7 @@ def sample_indices_fn_func(metadata, **fn_kwargs):
or (not is_av_available() and backend == "pyav")
or (not is_cv2_available() and backend == "opencv")
or (not is_torchvision_available() and backend == "torchvision")
or (not is_torchcodec_available() and backend == "torchcodec")
):
raise ImportError(
f"You chose backend={backend} for loading the video but the required library is not found in your environment "
13 changes: 13 additions & 0 deletions tests/utils/test_video_utils.py
@@ -27,6 +27,7 @@
require_cv2,
require_decord,
require_torch,
require_torchcodec,
require_torchvision,
require_vision,
)
@@ -261,6 +262,7 @@ def test_load_video_local(self):

@require_decord
@require_torchvision
@require_torchcodec
@require_cv2
def test_load_video_backend_url(self):
video, _ = load_video(
@@ -269,6 +271,12 @@ def test_load_video_backend_url(self):
)
self.assertEqual(video.shape, (243, 360, 640, 3))

video, _ = load_video(
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
backend="torchcodec",
)
self.assertEqual(video.shape, (243, 360, 640, 3))

# Can't use certain backends with url
with self.assertRaises(ValueError):
video, _ = load_video(
@@ -283,6 +291,7 @@ def test_load_video_backend_url(self):

@require_decord
@require_torchvision
@require_torchcodec
@require_cv2
def test_load_video_backend_local(self):
video_file_path = hf_hub_download(
@@ -300,6 +309,10 @@
self.assertEqual(video.shape, (243, 360, 640, 3))
self.assertIsInstance(metadata, VideoMetadata)

video, metadata = load_video(video_file_path, backend="torchcodec")
self.assertEqual(video.shape, (243, 360, 640, 3))
self.assertIsInstance(metadata, VideoMetadata)

def test_load_video_num_frames(self):
video, _ = load_video(
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",