From 2c7722f5d2f3dddf1e8db7f305f43751d8f858f6 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Wed, 18 Jun 2025 13:37:00 -0700 Subject: [PATCH 1/8] Calculate num_frames, add mock tests --- src/torchcodec/_core/_metadata.py | 9 +++- test/test_decoders.py | 72 +++++++++++++++++++++++++++++++ test/test_metadata.py | 30 ++++++++++++- 3 files changed, 109 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py index c15e86e74..2ce7cc75f 100644 --- a/src/torchcodec/_core/_metadata.py +++ b/src/torchcodec/_core/_metadata.py @@ -129,8 +129,15 @@ def num_frames(self) -> Optional[int]: """ if self.num_frames_from_content is not None: return self.num_frames_from_content - else: + elif self.num_frames_from_header is not None: return self.num_frames_from_header + elif ( + self.average_fps_from_header is not None + and self.duration_seconds_from_header is not None + ): + return int(self.average_fps_from_header * self.duration_seconds_from_header) + else: + return None @property def average_fps(self) -> Optional[float]: diff --git a/test/test_decoders.py b/test/test_decoders.py index 0b7912387..99c7445ba 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -6,6 +6,8 @@ import contextlib import gc +import json +from unittest.mock import patch import numpy import pytest @@ -738,6 +740,76 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode): empty_frames.duration_seconds, NASA_VIDEO.empty_duration_seconds ) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + @patch("torchcodec._core._metadata._get_stream_json_metadata") + def test_get_frames_with_missing_num_frames_metadata( + self, mock_get_stream_json_metadata, device, seek_mode + ): + # Create a mock stream_dict to test that initializing VideoDecoder without + # num_frames_from_header and num_frames_from_content calculates num_frames + # using the average_fps and duration_seconds metadata. + mock_stream_dict = { + "averageFpsFromHeader": 29.97003, + "beginStreamSecondsFromContent": 0.0, + "beginStreamSecondsFromHeader": 0.0, + "bitRate": 128783.0, + "codec": "h264", + "durationSecondsFromHeader": 13.013, + "endStreamSecondsFromContent": 13.013, + "width": 480, + "height": 270, + "mediaType": "video", + "numFramesFromHeader": None, + "numFramesFromContent": None, + } + # Set the return value of the mock to be the mock_stream_dict + mock_get_stream_json_metadata.return_value = json.dumps(mock_stream_dict) + + decoder = VideoDecoder( + NASA_VIDEO.path, + stream_index=3, + device=device, + seek_mode=seek_mode, + ) + + assert decoder.metadata.num_frames_from_header is None + assert decoder.metadata.num_frames_from_content is None + assert decoder.metadata.duration_seconds is not None + assert decoder.metadata.average_fps is not None + assert decoder.metadata.num_frames == int( + decoder.metadata.duration_seconds * decoder.metadata.average_fps + ) + + # Test get_frames_in_range + ref_frames9 = NASA_VIDEO.get_frame_data_by_range( + start=9, stop=10, stream_index=3 + ).to(device) + frames9 = decoder.get_frames_in_range(start=9, stop=10) + assert_frames_equal(ref_frames9, frames9.data) + + # Test get_frame_at + ref_frame9 = NASA_VIDEO.get_frame_data_by_index(9, stream_index=3).to(device) + frame9 = decoder.get_frame_at(9) + torch.testing.assert_close(ref_frame9, frame9.data) + + # Test get_frames_at + indices = [0, 1, 25, 35] + ref_frames = [ + NASA_VIDEO.get_frame_data_by_index(i, stream_index=3).to(device) + for i in indices + ] + frames = decoder.get_frames_at(indices) + for ref, frame in zip(ref_frames, frames.data): + torch.testing.assert_close(ref, frame) + + # Test get_frames_played_in_range to get all frames + assert decoder.metadata.end_stream_seconds is not None + all_frames = decoder.get_frames_played_in_range( + decoder.metadata.begin_stream_seconds, decoder.metadata.end_stream_seconds + ) + assert_frames_equal(all_frames.data, decoder[:]) + @pytest.mark.parametrize("dimension_order", ["NCHW", "NHWC"]) @pytest.mark.parametrize( "frame_getter", diff --git a/test/test_metadata.py b/test/test_metadata.py index 7ed9508fd..6814cefc7 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -119,7 +119,7 @@ def test_get_metadata_audio_file(metadata_getter): @pytest.mark.parametrize( "num_frames_from_header, num_frames_from_content, expected_num_frames", - [(None, 10, 10), (10, None, 10), (None, None, None)], + [(10, 20, 20), (None, 10, 10), (10, None, 10)], ) def test_num_frames_fallback( num_frames_from_header, num_frames_from_content, expected_num_frames @@ -143,6 +143,34 @@ def test_num_frames_fallback( assert metadata.num_frames == expected_num_frames +@pytest.mark.parametrize( + "average_fps_from_header, duration_seconds_from_header, expected_num_frames", + [(60, 10, 600), (60, None, None), (None, 10, None), (None, None, None)], +) +def test_calculate_num_frames_using_fps_and_duration( + average_fps_from_header, duration_seconds_from_header, expected_num_frames +): + """Check that if num_frames_from_content and num_frames_from_header are missing, + `.num_frames` is calculated using average_fps_from_header and duration_seconds_from_header + """ + metadata = VideoStreamMetadata( + duration_seconds_from_header=duration_seconds_from_header, + bit_rate=123, + num_frames_from_header=None, # None to test calculating num_frames + num_frames_from_content=None, # None to test calculating num_frames + begin_stream_seconds_from_header=0, + begin_stream_seconds_from_content=0, + end_stream_seconds_from_content=4, + codec="whatever", + width=123, + height=321, + average_fps_from_header=average_fps_from_header, + stream_index=0, + ) + + assert metadata.num_frames == expected_num_frames + + def test_repr(): # Test for calls to print(), str(), etc. Useful to make sure we don't forget # to add additional @properties to __repr__ From 8167eb2fbb8bae1e8f5b4fefc5358f665702cd95 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Mon, 23 Jun 2025 07:43:00 -0700 Subject: [PATCH 2/8] Add comment when num_frames is none --- src/torchcodec/_core/_metadata.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py index 2ce7cc75f..e2776b475 100644 --- a/src/torchcodec/_core/_metadata.py +++ b/src/torchcodec/_core/_metadata.py @@ -123,9 +123,10 @@ def end_stream_seconds(self) -> Optional[float]: @property def num_frames(self) -> Optional[int]: - """Number of frames in the stream. This corresponds to - ``num_frames_from_content`` if a :term:`scan` was made, otherwise it - corresponds to ``num_frames_from_header``. + """Number of frames in the stream (int or None). + This corresponds to ``num_frames_from_content`` if a :term:`scan` was made, + otherwise it corresponds to ``num_frames_from_header``. If it is None, + the number of frames is calculated from the duration and the average fps. """ if self.num_frames_from_content is not None: return self.num_frames_from_content From 0ae588f2e6d24a7ebe7aac8cb32e0e82fa603a3d Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Mon, 23 Jun 2025 07:43:39 -0700 Subject: [PATCH 3/8] Use common frames comparison function --- test/test_decoders.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_decoders.py b/test/test_decoders.py index 99c7445ba..e90a2aed6 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -780,6 +780,7 @@ def test_get_frames_with_missing_num_frames_metadata( assert decoder.metadata.num_frames == int( decoder.metadata.duration_seconds * decoder.metadata.average_fps ) + assert len(decoder) == 390 # Test get_frames_in_range ref_frames9 = NASA_VIDEO.get_frame_data_by_range( @@ -791,7 +792,7 @@ def test_get_frames_with_missing_num_frames_metadata( # Test get_frame_at ref_frame9 = NASA_VIDEO.get_frame_data_by_index(9, stream_index=3).to(device) frame9 = decoder.get_frame_at(9) - torch.testing.assert_close(ref_frame9, frame9.data) + assert_frames_equal(ref_frame9, frame9.data) # Test get_frames_at indices = [0, 1, 25, 35] @@ -801,7 +802,7 @@ def test_get_frames_with_missing_num_frames_metadata( ] frames = decoder.get_frames_at(indices) for ref, frame in zip(ref_frames, frames.data): - torch.testing.assert_close(ref, frame) + assert_frames_equal(ref, frame) # Test get_frames_played_in_range to get all frames assert decoder.metadata.end_stream_seconds is not None From e71ab069c8fc940652c4b76e04aea2cf47301482 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Mon, 23 Jun 2025 13:08:07 -0700 Subject: [PATCH 4/8] Remove unnecessary get_frames functions, add comment --- test/test_decoders.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/test/test_decoders.py b/test/test_decoders.py index e90a2aed6..fbe552908 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -782,35 +782,14 @@ def test_get_frames_with_missing_num_frames_metadata( ) assert len(decoder) == 390 - # Test get_frames_in_range + # Test get_frames_in_range Python logic which uses the num_frames metadata mocked earlier. + # The frame is read at the C++ level. ref_frames9 = NASA_VIDEO.get_frame_data_by_range( start=9, stop=10, stream_index=3 ).to(device) frames9 = decoder.get_frames_in_range(start=9, stop=10) assert_frames_equal(ref_frames9, frames9.data) - # Test get_frame_at - ref_frame9 = NASA_VIDEO.get_frame_data_by_index(9, stream_index=3).to(device) - frame9 = decoder.get_frame_at(9) - assert_frames_equal(ref_frame9, frame9.data) - - # Test get_frames_at - indices = [0, 1, 25, 35] - ref_frames = [ - NASA_VIDEO.get_frame_data_by_index(i, stream_index=3).to(device) - for i in indices - ] - frames = decoder.get_frames_at(indices) - for ref, frame in zip(ref_frames, frames.data): - assert_frames_equal(ref, frame) - - # Test get_frames_played_in_range to get all frames - assert decoder.metadata.end_stream_seconds is not None - all_frames = decoder.get_frames_played_in_range( - decoder.metadata.begin_stream_seconds, decoder.metadata.end_stream_seconds - ) - assert_frames_equal(all_frames.data, decoder[:]) - @pytest.mark.parametrize("dimension_order", ["NCHW", "NHWC"]) @pytest.mark.parametrize( "frame_getter", From 38aa16b4edf7e26645da58f439f4dbc5b63e0632 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Mon, 23 Jun 2025 13:12:10 -0700 Subject: [PATCH 5/8] fallback to calculate duration using frames and fps --- src/torchcodec/_core/_metadata.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py index e2776b475..2e254351d 100644 --- a/src/torchcodec/_core/_metadata.py +++ b/src/torchcodec/_core/_metadata.py @@ -88,14 +88,22 @@ def duration_seconds(self) -> Optional[float]: fall back to ``duration_seconds_from_header``. """ if ( - self.end_stream_seconds_from_content is None - or self.begin_stream_seconds_from_content is None + self.end_stream_seconds_from_content is not None + and self.begin_stream_seconds_from_content is not None ): + return ( + self.end_stream_seconds_from_content + - self.begin_stream_seconds_from_content + ) + elif self.duration_seconds_from_header is not None: return self.duration_seconds_from_header - return ( - self.end_stream_seconds_from_content - - self.begin_stream_seconds_from_content - ) + elif ( + self.num_frames_from_header is not None + and self.average_fps_from_header is not None + ): + return self.num_frames_from_header / self.average_fps_from_header + else: + return None @property def begin_stream_seconds(self) -> float: From d6b2ef0eecedd4488981375c892d8ef608ddbd4e Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Mon, 23 Jun 2025 13:13:41 -0700 Subject: [PATCH 6/8] Test duration_seconds metadata --- test/test_metadata.py | 59 +++++++++++++++++++++++++++++++++++++++++++ test/test_samplers.py | 3 +++ 2 files changed, 62 insertions(+) diff --git a/test/test_metadata.py b/test/test_metadata.py index 6814cefc7..732fe9438 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -171,6 +171,65 @@ def test_calculate_num_frames_using_fps_and_duration( assert metadata.num_frames == expected_num_frames +@pytest.mark.parametrize( + "duration_seconds_from_header, begin_stream_seconds_from_content, end_stream_seconds_from_content, expected_duration_seconds", + [(60, 5, 20, 15), (60, 1, None, 60), (None, 0, 10, 10)], +) +def test_duration_seconds_fallback( + duration_seconds_from_header, + begin_stream_seconds_from_content, + end_stream_seconds_from_content, + expected_duration_seconds, +): + """Check that using begin_stream_seconds_from_content and end_stream_seconds_from_content to calculate `.duration_seconds` + has priority. If either value is missing, duration_seconds_from_header is used. + """ + metadata = VideoStreamMetadata( + duration_seconds_from_header=duration_seconds_from_header, + bit_rate=123, + num_frames_from_header=5, + num_frames_from_content=10, + begin_stream_seconds_from_header=0, + begin_stream_seconds_from_content=begin_stream_seconds_from_content, + end_stream_seconds_from_content=end_stream_seconds_from_content, + codec="whatever", + width=123, + height=321, + average_fps_from_header=5, + stream_index=0, + ) + + assert metadata.duration_seconds == expected_duration_seconds + + +@pytest.mark.parametrize( + "num_frames_from_header, average_fps_from_header, expected_duration_seconds", + [(100, 10, 10), (100, None, None), (None, None, None)], +) +def test_calculate_duration_seconds_using_fps_and_num_frames( + num_frames_from_header, average_fps_from_header, expected_duration_seconds +): + """Check that duration_seconds is calculated using average_fps_from_header and num_frames_from_header + if duration_seconds_from_header is missing. + """ + metadata = VideoStreamMetadata( + duration_seconds_from_header=None, # None to test calculating duration_seconds + bit_rate=123, + num_frames_from_header=num_frames_from_header, + num_frames_from_content=10, + begin_stream_seconds_from_header=0, + begin_stream_seconds_from_content=None, # None to test calculating duration_seconds + end_stream_seconds_from_content=None, # None to test calculating duration_seconds + codec="whatever", + width=123, + height=321, + average_fps_from_header=average_fps_from_header, + stream_index=0, + ) + assert metadata.duration_seconds_from_header is None + assert metadata.duration_seconds == expected_duration_seconds + + def test_repr(): # Test for calls to print(), str(), etc. Useful to make sure we don't forget # to add additional @properties to __repr__ diff --git a/test/test_samplers.py b/test/test_samplers.py index 72ee108e3..938be0d91 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -592,6 +592,9 @@ def restore_metadata(): with restore_metadata(): decoder.metadata.end_stream_seconds_from_content = None decoder.metadata.duration_seconds_from_header = None + decoder.metadata.num_frames_from_header = ( + None # Set to none to prevent fallback calculation + ) with pytest.raises( ValueError, match="Could not infer stream end from video metadata" ): From e94fc047859986b849b9903c028f8ce653cfb6d4 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Tue, 24 Jun 2025 06:59:47 -0700 Subject: [PATCH 7/8] Added test cases, updated docstring --- src/torchcodec/_core/_metadata.py | 4 +++- test/test_metadata.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py index 2e254351d..4369adcc4 100644 --- a/src/torchcodec/_core/_metadata.py +++ b/src/torchcodec/_core/_metadata.py @@ -85,7 +85,9 @@ class VideoStreamMetadata(StreamMetadata): def duration_seconds(self) -> Optional[float]: """Duration of the stream in seconds. We try to calculate the duration from the actual frames if a :term:`scan` was performed. Otherwise we - fall back to ``duration_seconds_from_header``. + fall back to ``duration_seconds_from_header``. If that value is None, + we instead calculate the duration from ``num_frames_from_header`` and + ``average_fps_from_header``. """ if ( self.end_stream_seconds_from_content is not None diff --git a/test/test_metadata.py b/test/test_metadata.py index 732fe9438..9f929f5a8 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -173,7 +173,7 @@ def test_calculate_num_frames_using_fps_and_duration( @pytest.mark.parametrize( "duration_seconds_from_header, begin_stream_seconds_from_content, end_stream_seconds_from_content, expected_duration_seconds", - [(60, 5, 20, 15), (60, 1, None, 60), (None, 0, 10, 10)], + [(60, 5, 20, 15), (60, 1, None, 60), (60, None, 1, 60), (None, 0, 10, 10)], ) def test_duration_seconds_fallback( duration_seconds_from_header, @@ -204,7 +204,7 @@ def test_duration_seconds_fallback( @pytest.mark.parametrize( "num_frames_from_header, average_fps_from_header, expected_duration_seconds", - [(100, 10, 10), (100, None, None), (None, None, None)], + [(100, 10, 10), (100, None, None), (None, 10, None), (None, None, None)], ) def test_calculate_duration_seconds_using_fps_and_num_frames( num_frames_from_header, average_fps_from_header, expected_duration_seconds From 3b0d4edbe39aca39fd598aa6d0a8a7431fe95e1e Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Tue, 24 Jun 2025 13:41:02 -0700 Subject: [PATCH 8/8] knits --- src/torchcodec/_core/_metadata.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py index 4369adcc4..18484a9a6 100644 --- a/src/torchcodec/_core/_metadata.py +++ b/src/torchcodec/_core/_metadata.py @@ -85,7 +85,7 @@ class VideoStreamMetadata(StreamMetadata): def duration_seconds(self) -> Optional[float]: """Duration of the stream in seconds. We try to calculate the duration from the actual frames if a :term:`scan` was performed. Otherwise we - fall back to ``duration_seconds_from_header``. If that value is None, + fall back to ``duration_seconds_from_header``. If that value is also None, we instead calculate the duration from ``num_frames_from_header`` and ``average_fps_from_header``. """ @@ -135,8 +135,8 @@ def end_stream_seconds(self) -> Optional[float]: def num_frames(self) -> Optional[int]: """Number of frames in the stream (int or None). This corresponds to ``num_frames_from_content`` if a :term:`scan` was made, - otherwise it corresponds to ``num_frames_from_header``. If it is None, - the number of frames is calculated from the duration and the average fps. + otherwise it corresponds to ``num_frames_from_header``. If that value is also + None, the number of frames is calculated from the duration and the average fps. """ if self.num_frames_from_content is not None: return self.num_frames_from_content