Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/torchcodec/_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
encode_audio_to_file_like,
encode_audio_to_tensor,
encode_video_to_file,
encode_video_to_file_like,
encode_video_to_tensor,
get_ffmpeg_library_versions,
get_frame_at_index,
Expand Down
27 changes: 27 additions & 0 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
m.def(
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor");
m.def(
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()");
m.def(
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
m.def(
Expand Down Expand Up @@ -606,6 +608,30 @@ at::Tensor encode_video_to_tensor(
.encodeToTensor();
}

// Encodes `frames` into the container named by `format`, writing the output
// through the AVIOFileLikeContext whose address is passed in
// `file_like_context`. `crf` is forwarded to the encoder via
// VideoStreamOptions. Errors out if `file_like_context` is a null address.
void _encode_video_to_file_like(
    const at::Tensor& frames,
    int64_t frame_rate,
    std::string_view format,
    int64_t file_like_context,
    std::optional<int64_t> crf = std::nullopt) {
  // The context arrives as a plain integer; turn it back into a pointer.
  auto* rawContext = reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
  TORCH_CHECK(
      rawContext != nullptr, "file_like_context must be a valid pointer");

  // From here on this function owns the context, until ownership is handed
  // over to the encoder below.
  std::unique_ptr<AVIOFileLikeContext> ownedContext(rawContext);

  VideoStreamOptions streamOptions;
  streamOptions.crf = crf;

  VideoEncoder(
      frames,
      validateInt64ToInt(frame_rate, "frame_rate"),
      format,
      std::move(ownedContext),
      streamOptions)
      .encode();
}

// For testing only. We need to implement this operation as a core library
// function because what we're testing is round-tripping pts values as
// double-precision floating point numbers from C++ to Python and back to C++.
Expand Down Expand Up @@ -870,6 +896,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like);
m.impl("encode_video_to_file", &encode_video_to_file);
m.impl("encode_video_to_tensor", &encode_video_to_tensor);
m.impl("_encode_video_to_file_like", &_encode_video_to_file_like);
m.impl("seek_to_pts", &seek_to_pts);
m.impl("add_video_stream", &add_video_stream);
m.impl("_add_video_stream", &_add_video_stream);
Expand Down
41 changes: 41 additions & 0 deletions src/torchcodec/_core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def load_torchcodec_shared_libraries():
encode_video_to_tensor = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns.encode_video_to_tensor.default
)
_encode_video_to_file_like = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns._encode_video_to_file_like.default
)
create_from_tensor = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns.create_from_tensor.default
)
Expand Down Expand Up @@ -203,6 +206,33 @@ def encode_audio_to_file_like(
)


def encode_video_to_file_like(
    frames: torch.Tensor,
    frame_rate: int,
    format: str,
    file_like: Union[io.RawIOBase, io.BufferedIOBase],
    crf: Optional[int] = None,
) -> None:
    """Encode video frames and write the result into a file-like object.

    Args:
        frames: Tensor holding the video frames to encode.
        frame_rate: Frame rate of the encoded video, in frames per second.
        format: Video format name (e.g., "mp4", "mov", "mkv").
        file_like: Destination object; it must support write() and seek().
        crf: Optional constant rate factor for encoding quality. It is
            forwarded unchanged to the underlying op.
    """
    assert _pybind_ops is not None

    # The second argument flags the context as being created for writing.
    write_context = _pybind_ops.create_file_like_context(file_like, True)
    _encode_video_to_file_like(frames, frame_rate, format, write_context, crf)


def get_frames_at_indices(
decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, list[int]]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -302,6 +332,17 @@ def encode_video_to_tensor_abstract(
return torch.empty([], dtype=torch.long)


@register_fake("torchcodec_ns::_encode_video_to_file_like")
def _encode_video_to_file_like_abstract(
    frames: torch.Tensor,
    frame_rate: int,
    format: str,
    file_like_context: int,
    crf: Optional[int] = None,
) -> None:
    """Fake kernel for ``_encode_video_to_file_like``; it returns nothing."""
    return None


@register_fake("torchcodec_ns::create_from_tensor")
def create_from_tensor_abstract(
video_tensor: torch.Tensor, seek_mode: Optional[str]
Expand Down
118 changes: 111 additions & 7 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
create_from_tensor,
encode_audio_to_file,
encode_video_to_file,
encode_video_to_file_like,
encode_video_to_tensor,
get_ffmpeg_library_versions,
get_frame_at_index,
Expand Down Expand Up @@ -1329,7 +1330,7 @@ def test_bad_input(self, tmp_path):


class TestVideoEncoderOps:

# TODO-VideoEncoder: Test encoding against different memory layouts (ex. test_contiguity)
# TODO-VideoEncoder: Parametrize test after moving to test_encoders
def test_bad_input(self, tmp_path):
output_file = str(tmp_path / ".mp4")
Expand Down Expand Up @@ -1397,7 +1398,7 @@ def decode(self, source=None) -> torch.Tensor:
@pytest.mark.parametrize(
"format", ("mov", "mp4", "mkv", pytest.param("webm", marks=pytest.mark.slow))
)
@pytest.mark.parametrize("method", ("to_file", "to_tensor"))
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
def test_video_encoder_round_trip(self, tmp_path, format, method):
# Test that decode(encode(decode(frames))) == decode(frames)
ffmpeg_version = get_ffmpeg_major_version()
Expand All @@ -1424,11 +1425,22 @@ def test_video_encoder_round_trip(self, tmp_path, format, method):
**params,
)
round_trip_frames = self.decode(encoded_path).data
else: # to_tensor
elif method == "to_tensor":
encoded_tensor = encode_video_to_tensor(
source_frames, format=format, **params
)
round_trip_frames = self.decode(encoded_tensor).data
elif method == "to_file_like":
file_like = io.BytesIO()
encode_video_to_file_like(
frames=source_frames,
format=format,
file_like=file_like,
**params,
)
round_trip_frames = self.decode(file_like.getvalue()).data
else:
raise ValueError(f"Unknown method: {method}")

assert source_frames.shape == round_trip_frames.shape
assert source_frames.dtype == round_trip_frames.dtype
Expand All @@ -1445,6 +1457,7 @@ def test_video_encoder_round_trip(self, tmp_path, format, method):
assert psnr(s_frame, rt_frame) > 30
assert_close(s_frame, rt_frame, atol=atol, rtol=0)

@pytest.mark.slow
@pytest.mark.parametrize(
"format",
(
Expand All @@ -1457,8 +1470,9 @@ def test_video_encoder_round_trip(self, tmp_path, format, method):
pytest.param("webm", marks=pytest.mark.slow),
),
)
def test_against_to_file(self, tmp_path, format):
# Test that to_file and to_tensor produce the same results
@pytest.mark.parametrize("method", ("to_tensor", "to_file_like"))
def test_against_to_file(self, tmp_path, format, method):
# Test that to_file, to_tensor, and to_file_like produce the same results
ffmpeg_version = get_ffmpeg_major_version()
if format == "webm" and (
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
Expand All @@ -1470,11 +1484,25 @@ def test_against_to_file(self, tmp_path, format):

encoded_file = tmp_path / f"output.{format}"
encode_video_to_file(frames=source_frames, filename=str(encoded_file), **params)
encoded_tensor = encode_video_to_tensor(source_frames, format=format, **params)

if method == "to_tensor":
encoded_output = encode_video_to_tensor(
source_frames, format=format, **params
)
else: # to_file_like
file_like = io.BytesIO()
encode_video_to_file_like(
frames=source_frames,
file_like=file_like,
format=format,
**params,
)
file_like.seek(0)
encoded_output = file_like

torch.testing.assert_close(
self.decode(encoded_file).data,
self.decode(encoded_tensor).data,
self.decode(encoded_output).data,
atol=0,
rtol=0,
)
Expand Down Expand Up @@ -1557,6 +1585,82 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
ff_frame, enc_frame, percentage=percentage, atol=2
)

def test_to_file_like_custom_file_object(self):
    """Round-trip through a custom object that only offers write and seek."""

    class CustomFileObject:
        # Thin wrapper around an in-memory buffer; it deliberately exposes
        # nothing beyond write(), seek() and a getter for the written bytes.
        def __init__(self):
            self._buffer = io.BytesIO()

        def write(self, data):
            return self._buffer.write(data)

        def seek(self, offset, whence=0):
            return self._buffer.seek(offset, whence)

        def get_encoded_data(self):
            return self._buffer.getvalue()

    expected_frames = self.decode(TEST_SRC_2_720P.path).data
    sink = CustomFileObject()
    encode_video_to_file_like(
        expected_frames, frame_rate=30, crf=0, format="mp4", file_like=sink
    )
    round_trip = self.decode(sink.get_encoded_data())

    torch.testing.assert_close(round_trip.data, expected_frames, atol=2, rtol=0)

def test_to_file_like_real_file(self, tmp_path):
    """Round-trip through an on-disk file opened in binary write mode."""
    expected_frames = self.decode(TEST_SRC_2_720P.path).data
    output_path = tmp_path / "test_file_like.mp4"

    with open(output_path, "wb") as sink:
        encode_video_to_file_like(
            expected_frames, frame_rate=30, crf=0, format="mp4", file_like=sink
        )
    # Decode only once the file has been closed by the ``with`` block.
    round_trip = self.decode(str(output_path))

    torch.testing.assert_close(round_trip.data, expected_frames, atol=2, rtol=0)

def test_to_file_like_bad_methods(self):
    """Objects lacking write() or seek() must be rejected with RuntimeError."""
    frames = self.decode(TEST_SRC_2_720P.path).data

    class NoWriteMethod:
        def seek(self, offset, whence=0):
            return 0

    class NoSeekMethod:
        def write(self, data):
            return len(data)

    # The missing write() case runs first, then the missing seek() case.
    rejected_cases = (
        (NoWriteMethod, "File like object must implement a write method"),
        (NoSeekMethod, "File like object must implement a seek method"),
    )
    for bad_type, expected_error in rejected_cases:
        with pytest.raises(RuntimeError, match=expected_error):
            encode_video_to_file_like(
                frames, frame_rate=30, format="mp4", file_like=bad_type()
            )


if __name__ == "__main__":
pytest.main()
Loading