diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 1d9c2c089..4e5d6a604 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -531,7 +531,6 @@ torch::Tensor validateFrames(const torch::Tensor& frames) { frames.sizes()[1] == 3, "frame must have 3 channels (R, G, B), got ", frames.sizes()[1]); - // TODO-VideoEncoder: Investigate if non-contiguous frames can be accepted return frames.contiguous(); } diff --git a/src/torchcodec/encoders/__init__.py b/src/torchcodec/encoders/__init__.py index 51f5942b3..cf78fe427 100644 --- a/src/torchcodec/encoders/__init__.py +++ b/src/torchcodec/encoders/__init__.py @@ -1 +1,2 @@ from ._audio_encoder import AudioEncoder # noqa +from ._video_encoder import VideoEncoder # noqa diff --git a/src/torchcodec/encoders/_video_encoder.py b/src/torchcodec/encoders/_video_encoder.py new file mode 100644 index 000000000..f6a725278 --- /dev/null +++ b/src/torchcodec/encoders/_video_encoder.py @@ -0,0 +1,92 @@ +from pathlib import Path +from typing import Union + +import torch +from torch import Tensor + +from torchcodec import _core + + +class VideoEncoder: + """A video encoder. + + Args: + frames (``torch.Tensor``): The frames to encode. This must be a 4D + tensor of shape ``(N, C, H, W)`` where N is the number of frames, + C is 3 channels (RGB), H is height, and W is width. + Values must be uint8 in the range ``[0, 255]``. + frame_rate (int): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate. + """ + + def __init__(self, frames: Tensor, *, frame_rate: int): + torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder") + if not isinstance(frames, Tensor): + raise ValueError(f"Expected frames to be a Tensor, got {type(frames) = }.") + if frames.ndim != 4: + raise ValueError(f"Expected 4D frames, got {frames.shape = }.") + if frames.dtype != torch.uint8: + raise ValueError(f"Expected uint8 frames, got {frames.dtype = }.") + if frame_rate <= 0: + raise ValueError(f"{frame_rate = } must be > 0.") + + self._frames = frames + self._frame_rate = frame_rate + + def to_file( + self, + dest: Union[str, Path], + ) -> None: + """Encode frames into a file. + + Args: + dest (str or ``pathlib.Path``): The path to the output file, e.g. + ``video.mp4``. The extension of the file determines the video + container format. + """ + _core.encode_video_to_file( + frames=self._frames, + frame_rate=self._frame_rate, + filename=str(dest), + ) + + def to_tensor( + self, + format: str, + ) -> Tensor: + """Encode frames into raw bytes, as a 1D uint8 Tensor. + + Args: + format (str): The container format of the encoded frames, e.g. "mp4", "mov", + "mkv", "avi", "webm", "flv", or "gif" + + Returns: + Tensor: The raw encoded bytes as 4D uint8 Tensor. + """ + return _core.encode_video_to_tensor( + frames=self._frames, + frame_rate=self._frame_rate, + format=format, + ) + + def to_file_like( + self, + file_like, + format: str, + ) -> None: + """Encode frames into a file-like object. + + Args: + file_like: A file-like object that supports ``write()`` and + ``seek()`` methods, such as io.BytesIO(), an open file in binary + write mode, etc. Methods must have the following signature: + ``write(data: bytes) -> int`` and ``seek(offset: int, whence: + int = 0) -> int``. + format (str): The container format of the encoded frames, e.g. "mp4", "mov", + "mkv", "avi", "webm", "flv", or "gif". + """ + _core.encode_video_to_file_like( + frames=self._frames, + frame_rate=self._frame_rate, + format=format, + file_like=file_like, + ) diff --git a/test/test_encoders.py b/test/test_encoders.py index c5946654d..b7223c88a 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -11,7 +11,7 @@ import torch from torchcodec.decoders import AudioDecoder -from torchcodec.encoders import AudioEncoder +from torchcodec.encoders import AudioEncoder, VideoEncoder from .utils import ( assert_tensor_close_on_at_least, @@ -564,3 +564,115 @@ def write(self, data): RuntimeError, match="File like object must implement a seek method" ): encoder.to_file_like(NoSeekMethod(), format="wav") + + +class TestVideoEncoder: + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) + def test_bad_input_parameterized(self, tmp_path, method): + if method == "to_file": + valid_params = dict(dest=str(tmp_path / "output.mp4")) + elif method == "to_tensor": + valid_params = dict(format="mp4") + elif method == "to_file_like": + valid_params = dict(file_like=io.BytesIO(), format="mp4") + else: + raise ValueError(f"Unknown method: {method}") + + with pytest.raises( + ValueError, match="Expected uint8 frames, got frames.dtype = torch.float32" + ): + encoder = VideoEncoder( + frames=torch.rand(5, 3, 64, 64), + frame_rate=30, + ) + getattr(encoder, method)(**valid_params) + + with pytest.raises( + ValueError, match=r"Expected 4D frames, got frames.shape = torch.Size" + ): + encoder = VideoEncoder( + frames=torch.zeros(10), + frame_rate=30, + ) + getattr(encoder, method)(**valid_params) + + with pytest.raises( + RuntimeError, match=r"frame must have 3 channels \(R, G, B\), got 2" + ): + encoder = VideoEncoder( + frames=torch.zeros((5, 2, 64, 64), dtype=torch.uint8), + frame_rate=30, + ) + getattr(encoder, method)(**valid_params) + + def test_bad_input(self, tmp_path): + encoder = VideoEncoder( + frames=torch.zeros((5, 3, 64, 64), dtype=torch.uint8), + frame_rate=30, + ) + + with pytest.raises( + RuntimeError, + match=r"Couldn't allocate AVFormatContext. The destination file is ./file.bad_extension, check the desired extension\?", + ): + encoder.to_file("./file.bad_extension") + + with pytest.raises( + RuntimeError, + match=r"avio_open failed. The destination file is ./bad/path.mp3, make sure it's a valid path\?", + ): + encoder.to_file("./bad/path.mp3") + + with pytest.raises( + RuntimeError, + match=r"Couldn't allocate AVFormatContext. Check the desired format\? Got format=bad_format", + ): + encoder.to_tensor(format="bad_format") + + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) + def test_contiguity(self, method, tmp_path): + # Ensure that 2 sets of video frames with the same pixel values are encoded + # in the same way, regardless of their memory layout. Here we encode 2 equal + # frame tensors, one is contiguous while the other is non-contiguous. + + num_frames, channels, height, width = 5, 3, 64, 64 + contiguous_frames = torch.randint( + 0, 256, size=(num_frames, channels, height, width), dtype=torch.uint8 + ).contiguous() + assert contiguous_frames.is_contiguous() + + # Permute NCHW to NHWC, then update the memory layout, then permute back + non_contiguous_frames = ( + contiguous_frames.permute(0, 2, 3, 1).contiguous().permute(0, 3, 1, 2) + ) + assert non_contiguous_frames.stride() != contiguous_frames.stride() + assert not non_contiguous_frames.is_contiguous() + assert non_contiguous_frames.is_contiguous(memory_format=torch.channels_last) + + torch.testing.assert_close( + contiguous_frames, non_contiguous_frames, rtol=0, atol=0 + ) + + def encode_to_tensor(frames): + if method == "to_file": + dest = str(tmp_path / "output.mp4") + VideoEncoder(frames, frame_rate=30).to_file(dest=dest) + with open(dest, "rb") as f: + return torch.frombuffer(f.read(), dtype=torch.uint8) + elif method == "to_tensor": + return VideoEncoder(frames, frame_rate=30).to_tensor(format="mp4") + elif method == "to_file_like": + file_like = io.BytesIO() + VideoEncoder(frames, frame_rate=30).to_file_like( + file_like, format="mp4" + ) + return torch.frombuffer(file_like.getvalue(), dtype=torch.uint8) + else: + raise ValueError(f"Unknown method: {method}") + + encoded_from_contiguous = encode_to_tensor(contiguous_frames) + encoded_from_non_contiguous = encode_to_tensor(non_contiguous_frames) + + torch.testing.assert_close( + encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0 + ) diff --git a/test/test_ops.py b/test/test_ops.py index 627829689..e798a7a2b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1152,68 +1152,6 @@ def test_bad_input(self, tmp_path): class TestVideoEncoderOps: - # TODO-VideoEncoder: Test encoding against different memory layouts (ex. test_contiguity) - # TODO-VideoEncoder: Parametrize test after moving to test_encoders - def test_bad_input(self, tmp_path): - output_file = str(tmp_path / ".mp4") - - with pytest.raises( - RuntimeError, match="frames must have uint8 dtype, got float" - ): - encode_video_to_file( - frames=torch.rand((10, 3, 60, 60), dtype=torch.float), - frame_rate=10, - filename=output_file, - ) - - with pytest.raises( - RuntimeError, match=r"frames must have 4 dimensions \(N, C, H, W\), got 3" - ): - encode_video_to_file( - frames=torch.randint(high=1, size=(3, 60, 60), dtype=torch.uint8), - frame_rate=10, - filename=output_file, - ) - - with pytest.raises( - RuntimeError, match=r"frame must have 3 channels \(R, G, B\), got 2" - ): - encode_video_to_file( - frames=torch.randint(high=1, size=(10, 2, 60, 60), dtype=torch.uint8), - frame_rate=10, - filename=output_file, - ) - - with pytest.raises( - RuntimeError, - match=r"Couldn't allocate AVFormatContext. The destination file is ./file.bad_extension, check the desired extension\?", - ): - encode_video_to_file( - frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8), - frame_rate=10, - filename="./file.bad_extension", - ) - - with pytest.raises( - RuntimeError, - match=r"avio_open failed. The destination file is ./bad/path.mp3, make sure it's a valid path\?", - ): - encode_video_to_file( - frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8), - frame_rate=10, - filename="./bad/path.mp3", - ) - - with pytest.raises( - RuntimeError, - match=r"Couldn't allocate AVFormatContext. Check the desired format\? Got format=bad_format", - ): - encode_video_to_tensor( - frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8), - frame_rate=10, - format="bad_format", - ) - def decode(self, source=None) -> torch.Tensor: return VideoDecoder(source).get_frames_in_range(start=0, stop=60) @@ -1406,7 +1344,7 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): ) def test_to_file_like_custom_file_object(self): - """Test with a custom file-like object that implements write and seek.""" + """Test to_file_like with a custom file-like object that implements write and seek.""" class CustomFileObject: def __init__(self):