Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/torchcodec/_core/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,6 @@ torch::Tensor validateFrames(const torch::Tensor& frames) {
frames.sizes()[1] == 3,
"frame must have 3 channels (R, G, B), got ",
frames.sizes()[1]);
// TODO-VideoEncoder: Investigate if non-contiguous frames can be accepted
return frames.contiguous();
}

Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/encoders/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from ._audio_encoder import AudioEncoder # noqa
from ._video_encoder import VideoEncoder # noqa
97 changes: 97 additions & 0 deletions src/torchcodec/encoders/_video_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from pathlib import Path
from typing import Union

import torch
from torch import Tensor

from torchcodec import _core


class VideoEncoder:
"""A video encoder.
Args:
frames (``torch.Tensor``): The frames to encode. This must be a 4D
tensor of shape ``(N, C, H, W)`` where N is the number of frames,
C is 3 channels (RGB), H is height, and W is width.
A 3D tensor of shape ``(C, H, W)`` is also accepted as a single RGB frame.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

        A 3D tensor of shape ``(C, H, W)`` is also accepted as a single RGB frame.

Q - Do we need to support that? I'm wondering if it makes a lot of sense to just encode a single image as a video. I suspect this was made to mimic the AudioEncoder behavior but that was a different use-case. In the AudioEncoder we want to allow for 1D audio to be supported as it's still a valid waveform. But I don't think we need to treat a single frame as a valid video.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what the use case is for encoding an image as a video, but since FFmpeg allows encoding an image to video, I believe we can retain this functionality for a relatively low cost.

Values must be uint8 in the range ``[0, 255]``.
frame_rate (int): The frame rate to use when encoding the
**input** ``frames``.
Comment on lines +19 to +20
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My interpretation of this description is that this parameter actually defines the frame rate of the encoded output, because of the "frame rate to use when encoding" part.

I think it might be less ambiguous as

The frame rate of the input frames. This is not the frame rate of the encoded output.

"""

def __init__(self, frames: Tensor, *, frame_rate: int):
torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder")
if not isinstance(frames, Tensor):
raise ValueError(f"Expected frames to be a Tensor, got {type(frames) = }.")
if frames.ndim == 3:
# make it 4D and assume single RGB frame, CHW -> NCHW
frames = torch.unsqueeze(frames, 0)
if frames.ndim != 4:
raise ValueError(f"Expected 3D or 4D frames, got {frames.shape = }.")
if frames.dtype != torch.uint8:
raise ValueError(f"Expected uint8 frames, got {frames.dtype = }.")
if frame_rate <= 0:
raise ValueError(f"{frame_rate = } must be > 0.")

self._frames = frames
self._frame_rate = frame_rate

def to_file(
self,
dest: Union[str, Path],
) -> None:
"""Encode frames into a file.
Args:
dest (str or ``pathlib.Path``): The path to the output file, e.g.
``video.mp4``. The extension of the file determines the video
format and container.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the distinction between "format" and "container" here? I would just use one or the other?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The terms are used interchangeably, so I included both here to be understandable to users of both terms. Let me know if that is actually more confusing than just naming one term.

I might have added the distinction after discovering the format mkv is defined as a matroska container. I have not encountered other formats where these are different.

"""
_core.encode_video_to_file(
frames=self._frames,
frame_rate=self._frame_rate,
filename=str(dest),
)

def to_tensor(
    self,
    format: str,
) -> Tensor:
    """Encode frames into raw bytes, as a 1D uint8 Tensor.

    Args:
        format (str): The format of the encoded frames, e.g. "mp4", "mov",
            "mkv", "avi", "webm", "flv", or "gif".

    Returns:
        Tensor: The raw encoded bytes as 1D uint8 Tensor.
    """
    return _core.encode_video_to_tensor(
        frames=self._frames,
        frame_rate=self._frame_rate,
        format=format,
    )

def to_file_like(
    self,
    file_like,
    format: str,
) -> None:
    """Encode the stored frames and write them to a file-like object.

    Args:
        file_like: A file-like object that supports ``write()`` and
            ``seek()`` methods, such as io.BytesIO(), an open file in binary
            write mode, etc. Methods must have the following signature:
            ``write(data: bytes) -> int`` and ``seek(offset: int, whence:
            int = 0) -> int``.
        format (str): The format of the encoded frames, e.g. "mp4", "mov",
            "mkv", "avi", "webm", "flv", or "gif".
    """
    # Gather the arguments first, then forward them in one call.
    encode_kwargs = {
        "frames": self._frames,
        "frame_rate": self._frame_rate,
        "format": format,
        "file_like": file_like,
    }
    _core.encode_video_to_file_like(**encode_kwargs)
114 changes: 113 additions & 1 deletion test/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
from torchcodec.decoders import AudioDecoder

from torchcodec.encoders import AudioEncoder
from torchcodec.encoders import AudioEncoder, VideoEncoder

from .utils import (
assert_tensor_close_on_at_least,
Expand Down Expand Up @@ -564,3 +564,115 @@ def write(self, data):
RuntimeError, match="File like object must implement a seek method"
):
encoder.to_file_like(NoSeekMethod(), format="wav")


class TestVideoEncoder:
    """Tests for ``VideoEncoder`` and its ``to_file``, ``to_tensor`` and ``to_file_like`` methods."""

    @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
    def test_bad_input_parameterized(self, tmp_path, method):
        """Check that invalid ``frames`` produce the expected error for each method."""
        # Build arguments that are valid for the selected method, so the
        # errors asserted below can only come from the frames.
        if method == "to_file":
            valid_params = dict(dest=str(tmp_path / "output.mp4"))
        elif method == "to_tensor":
            valid_params = dict(format="mp4")
        elif method == "to_file_like":
            valid_params = dict(file_like=io.BytesIO(), format="mp4")
        else:
            raise ValueError(f"Unknown method: {method}")

        # Wrong dtype: torch.rand yields float32 frames instead of uint8.
        with pytest.raises(
            ValueError, match="Expected uint8 frames, got frames.dtype = torch.float32"
        ):
            encoder = VideoEncoder(
                frames=torch.rand(5, 3, 64, 64),
                frame_rate=30,
            )
            getattr(encoder, method)(**valid_params)

        # Wrong rank: a 1D tensor is neither 3D nor 4D.
        with pytest.raises(
            ValueError, match=r"Expected 3D or 4D frames, got frames.shape = torch.Size"
        ):
            encoder = VideoEncoder(
                frames=torch.zeros(10),
                frame_rate=30,
            )
            getattr(encoder, method)(**valid_params)

        # Wrong channel count: expected as a RuntimeError rather than a
        # ValueError. NOTE(review): presumably raised by the encoding call and
        # not by the constructor -- confirm against VideoEncoder.__init__.
        with pytest.raises(
            RuntimeError, match=r"frame must have 3 channels \(R, G, B\), got 2"
        ):
            encoder = VideoEncoder(
                frames=torch.zeros((5, 2, 64, 64), dtype=torch.uint8),
                frame_rate=30,
            )
            getattr(encoder, method)(**valid_params)

    def test_bad_input(self, tmp_path):
        """Check the errors raised for bad destinations and formats with valid frames."""
        encoder = VideoEncoder(
            frames=torch.zeros((5, 3, 64, 64), dtype=torch.uint8),
            frame_rate=30,
        )

        # Unknown file extension.
        with pytest.raises(
            RuntimeError,
            match=r"Couldn't allocate AVFormatContext. The destination file is ./file.bad_extension, check the desired extension\?",
        ):
            encoder.to_file("./file.bad_extension")

        # Destination path that is expected to fail to open.
        with pytest.raises(
            RuntimeError,
            match=r"avio_open failed. The destination file is ./bad/path.mp3, make sure it's a valid path\?",
        ):
            encoder.to_file("./bad/path.mp3")

        # Unknown format name passed to to_tensor.
        with pytest.raises(
            RuntimeError,
            match=r"Couldn't allocate AVFormatContext. Check the desired format\? Got format=bad_format",
        ):
            encoder.to_tensor(format="bad_format")

    @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
    def test_contiguity(self, method, tmp_path):
        """Check that encoded bytes are identical for contiguous and non-contiguous frames."""
        # Ensure that 2 sets of video frames with the same pixel values are encoded
        # in the same way, regardless of their memory layout. Here we encode 2 equal
        # frame tensors, one is contiguous while the other is non-contiguous.

        num_frames, channels, height, width = 5, 3, 64, 64
        contiguous_frames = torch.randint(
            0, 256, size=(num_frames, channels, height, width), dtype=torch.uint8
        ).contiguous()
        assert contiguous_frames.is_contiguous()

        # Permute NCHW to NHWC, then update the memory layout, then permute back
        non_contiguous_frames = (
            contiguous_frames.permute(0, 2, 3, 1).contiguous().permute(0, 3, 1, 2)
        )
        assert non_contiguous_frames.stride() != contiguous_frames.stride()
        assert not non_contiguous_frames.is_contiguous()
        assert non_contiguous_frames.is_contiguous(memory_format=torch.channels_last)

        # Sanity check: both tensors hold exactly the same values.
        torch.testing.assert_close(
            contiguous_frames, non_contiguous_frames, rtol=0, atol=0
        )

        def encode_to_tensor(frames):
            # Encode `frames` as mp4 through the parametrized method and return
            # the encoded bytes as a uint8 tensor, so every method is compared
            # the same way.
            if method == "to_file":
                dest = str(tmp_path / "output.mp4")
                VideoEncoder(frames, frame_rate=30).to_file(dest=dest)
                with open(dest, "rb") as f:
                    return torch.frombuffer(f.read(), dtype=torch.uint8)
            elif method == "to_tensor":
                return VideoEncoder(frames, frame_rate=30).to_tensor(format="mp4")
            elif method == "to_file_like":
                file_like = io.BytesIO()
                VideoEncoder(frames, frame_rate=30).to_file_like(
                    file_like, format="mp4"
                )
                return torch.frombuffer(file_like.getvalue(), dtype=torch.uint8)
            else:
                raise ValueError(f"Unknown method: {method}")

        encoded_from_contiguous = encode_to_tensor(contiguous_frames)
        encoded_from_non_contiguous = encode_to_tensor(non_contiguous_frames)

        # The encoded outputs must match exactly.
        torch.testing.assert_close(
            encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0
        )
64 changes: 1 addition & 63 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,68 +1152,6 @@ def test_bad_input(self, tmp_path):


class TestVideoEncoderOps:
# TODO-VideoEncoder: Test encoding against different memory layouts (ex. test_contiguity)
# TODO-VideoEncoder: Parametrize test after moving to test_encoders
def test_bad_input(self, tmp_path):
output_file = str(tmp_path / ".mp4")

with pytest.raises(
RuntimeError, match="frames must have uint8 dtype, got float"
):
encode_video_to_file(
frames=torch.rand((10, 3, 60, 60), dtype=torch.float),
frame_rate=10,
filename=output_file,
)

with pytest.raises(
RuntimeError, match=r"frames must have 4 dimensions \(N, C, H, W\), got 3"
):
encode_video_to_file(
frames=torch.randint(high=1, size=(3, 60, 60), dtype=torch.uint8),
frame_rate=10,
filename=output_file,
)

with pytest.raises(
RuntimeError, match=r"frame must have 3 channels \(R, G, B\), got 2"
):
encode_video_to_file(
frames=torch.randint(high=1, size=(10, 2, 60, 60), dtype=torch.uint8),
frame_rate=10,
filename=output_file,
)

with pytest.raises(
RuntimeError,
match=r"Couldn't allocate AVFormatContext. The destination file is ./file.bad_extension, check the desired extension\?",
):
encode_video_to_file(
frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8),
frame_rate=10,
filename="./file.bad_extension",
)

with pytest.raises(
RuntimeError,
match=r"avio_open failed. The destination file is ./bad/path.mp3, make sure it's a valid path\?",
):
encode_video_to_file(
frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8),
frame_rate=10,
filename="./bad/path.mp3",
)

with pytest.raises(
RuntimeError,
match=r"Couldn't allocate AVFormatContext. Check the desired format\? Got format=bad_format",
):
encode_video_to_tensor(
frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8),
frame_rate=10,
format="bad_format",
)

def decode(self, source=None) -> torch.Tensor:
return VideoDecoder(source).get_frames_in_range(start=0, stop=60)

Expand Down Expand Up @@ -1406,7 +1344,7 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
)

def test_to_file_like_custom_file_object(self):
"""Test with a custom file-like object that implements write and seek."""
"""Test to_file_like with a custom file-like object that implements write and seek."""

class CustomFileObject:
def __init__(self):
Expand Down
Loading