diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py index c15e86e74..18484a9a6 100644 --- a/src/torchcodec/_core/_metadata.py +++ b/src/torchcodec/_core/_metadata.py @@ -85,17 +85,27 @@ class VideoStreamMetadata(StreamMetadata): def duration_seconds(self) -> Optional[float]: """Duration of the stream in seconds. We try to calculate the duration from the actual frames if a :term:`scan` was performed. Otherwise we - fall back to ``duration_seconds_from_header``. + fall back to ``duration_seconds_from_header``. If that value is also None, + we instead calculate the duration from ``num_frames_from_header`` and + ``average_fps_from_header``. """ if ( - self.end_stream_seconds_from_content is None - or self.begin_stream_seconds_from_content is None + self.end_stream_seconds_from_content is not None + and self.begin_stream_seconds_from_content is not None ): + return ( + self.end_stream_seconds_from_content + - self.begin_stream_seconds_from_content + ) + elif self.duration_seconds_from_header is not None: return self.duration_seconds_from_header - return ( - self.end_stream_seconds_from_content - - self.begin_stream_seconds_from_content - ) + elif ( + self.num_frames_from_header is not None + and self.average_fps_from_header is not None + ): + return self.num_frames_from_header / self.average_fps_from_header + else: + return None @property def begin_stream_seconds(self) -> float: @@ -123,14 +133,22 @@ def end_stream_seconds(self) -> Optional[float]: @property def num_frames(self) -> Optional[int]: - """Number of frames in the stream. This corresponds to - ``num_frames_from_content`` if a :term:`scan` was made, otherwise it - corresponds to ``num_frames_from_header``. + """Number of frames in the stream (int or None). + This corresponds to ``num_frames_from_content`` if a :term:`scan` was made, + otherwise it corresponds to ``num_frames_from_header``. 
If that value is also + None, the count is ``int(average_fps_from_header * duration_seconds_from_header)``. + """ if self.num_frames_from_content is not None: return self.num_frames_from_content - else: + elif self.num_frames_from_header is not None: return self.num_frames_from_header + elif ( + self.average_fps_from_header is not None + and self.duration_seconds_from_header is not None + ): + return int(self.average_fps_from_header * self.duration_seconds_from_header) + else: + return None @property def average_fps(self) -> Optional[float]: diff --git a/test/test_decoders.py b/test/test_decoders.py index 0b7912387..fbe552908 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -6,6 +6,8 @@ import contextlib import gc +import json +from unittest.mock import patch import numpy import pytest @@ -738,6 +740,56 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode): empty_frames.duration_seconds, NASA_VIDEO.empty_duration_seconds ) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + @patch("torchcodec._core._metadata._get_stream_json_metadata") + def test_get_frames_with_missing_num_frames_metadata( + self, mock_get_stream_json_metadata, device, seek_mode + ): + # Create a mock stream_dict to test that initializing VideoDecoder without + # num_frames_from_header and num_frames_from_content calculates num_frames + # using the average_fps and duration_seconds metadata. 
+ mock_stream_dict = { + "averageFpsFromHeader": 29.97003, + "beginStreamSecondsFromContent": 0.0, + "beginStreamSecondsFromHeader": 0.0, + "bitRate": 128783.0, + "codec": "h264", + "durationSecondsFromHeader": 13.013, + "endStreamSecondsFromContent": 13.013, + "width": 480, + "height": 270, + "mediaType": "video", + "numFramesFromHeader": None, + "numFramesFromContent": None, + } + # Set the return value of the mock to be the mock_stream_dict + mock_get_stream_json_metadata.return_value = json.dumps(mock_stream_dict) + + decoder = VideoDecoder( + NASA_VIDEO.path, + stream_index=3, + device=device, + seek_mode=seek_mode, + ) + + assert decoder.metadata.num_frames_from_header is None + assert decoder.metadata.num_frames_from_content is None + assert decoder.metadata.duration_seconds is not None + assert decoder.metadata.average_fps is not None + assert decoder.metadata.num_frames == int( + decoder.metadata.duration_seconds * decoder.metadata.average_fps + ) + assert len(decoder) == 390 + + # Test get_frames_in_range Python logic which uses the num_frames metadata mocked earlier. + # The frame is read at the C++ level. 
+ ref_frames9 = NASA_VIDEO.get_frame_data_by_range( + start=9, stop=10, stream_index=3 + ).to(device) + frames9 = decoder.get_frames_in_range(start=9, stop=10) + assert_frames_equal(ref_frames9, frames9.data) + @pytest.mark.parametrize("dimension_order", ["NCHW", "NHWC"]) @pytest.mark.parametrize( "frame_getter", diff --git a/test/test_metadata.py b/test/test_metadata.py index 7ed9508fd..9f929f5a8 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -119,7 +119,7 @@ def test_get_metadata_audio_file(metadata_getter): @pytest.mark.parametrize( "num_frames_from_header, num_frames_from_content, expected_num_frames", - [(None, 10, 10), (10, None, 10), (None, None, None)], + [(10, 20, 20), (None, 10, 10), (10, None, 10)], ) def test_num_frames_fallback( num_frames_from_header, num_frames_from_content, expected_num_frames @@ -143,6 +143,93 @@ def test_num_frames_fallback( assert metadata.num_frames == expected_num_frames +@pytest.mark.parametrize( + "average_fps_from_header, duration_seconds_from_header, expected_num_frames", + [(60, 10, 600), (60, None, None), (None, 10, None), (None, None, None)], +) +def test_calculate_num_frames_using_fps_and_duration( + average_fps_from_header, duration_seconds_from_header, expected_num_frames +): + """Check that if num_frames_from_content and num_frames_from_header are missing, + `.num_frames` is calculated using average_fps_from_header and duration_seconds_from_header + """ + metadata = VideoStreamMetadata( + duration_seconds_from_header=duration_seconds_from_header, + bit_rate=123, + num_frames_from_header=None, # None to test calculating num_frames + num_frames_from_content=None, # None to test calculating num_frames + begin_stream_seconds_from_header=0, + begin_stream_seconds_from_content=0, + end_stream_seconds_from_content=4, + codec="whatever", + width=123, + height=321, + average_fps_from_header=average_fps_from_header, + stream_index=0, + ) + + assert metadata.num_frames == expected_num_frames + + 
+@pytest.mark.parametrize( + "duration_seconds_from_header, begin_stream_seconds_from_content, end_stream_seconds_from_content, expected_duration_seconds", + [(60, 5, 20, 15), (60, 1, None, 60), (60, None, 1, 60), (None, 0, 10, 10)], +) +def test_duration_seconds_fallback( + duration_seconds_from_header, + begin_stream_seconds_from_content, + end_stream_seconds_from_content, + expected_duration_seconds, +): + """Check that using begin_stream_seconds_from_content and end_stream_seconds_from_content to calculate `.duration_seconds` + has priority. If either value is missing, duration_seconds_from_header is used. + """ + metadata = VideoStreamMetadata( + duration_seconds_from_header=duration_seconds_from_header, + bit_rate=123, + num_frames_from_header=5, + num_frames_from_content=10, + begin_stream_seconds_from_header=0, + begin_stream_seconds_from_content=begin_stream_seconds_from_content, + end_stream_seconds_from_content=end_stream_seconds_from_content, + codec="whatever", + width=123, + height=321, + average_fps_from_header=5, + stream_index=0, + ) + + assert metadata.duration_seconds == expected_duration_seconds + + +@pytest.mark.parametrize( + "num_frames_from_header, average_fps_from_header, expected_duration_seconds", + [(100, 10, 10), (100, None, None), (None, 10, None), (None, None, None)], +) +def test_calculate_duration_seconds_using_fps_and_num_frames( + num_frames_from_header, average_fps_from_header, expected_duration_seconds +): + """Check that duration_seconds is calculated using average_fps_from_header and num_frames_from_header + if duration_seconds_from_header is missing. 
+ """ + metadata = VideoStreamMetadata( + duration_seconds_from_header=None, # None to test calculating duration_seconds + bit_rate=123, + num_frames_from_header=num_frames_from_header, + num_frames_from_content=10, + begin_stream_seconds_from_header=0, + begin_stream_seconds_from_content=None, # None to test calculating duration_seconds + end_stream_seconds_from_content=None, # None to test calculating duration_seconds + codec="whatever", + width=123, + height=321, + average_fps_from_header=average_fps_from_header, + stream_index=0, + ) + assert metadata.duration_seconds_from_header is None + assert metadata.duration_seconds == expected_duration_seconds + + def test_repr(): # Test for calls to print(), str(), etc. Useful to make sure we don't forget # to add additional @properties to __repr__ diff --git a/test/test_samplers.py b/test/test_samplers.py index 72ee108e3..938be0d91 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -592,6 +592,9 @@ def restore_metadata(): with restore_metadata(): decoder.metadata.end_stream_seconds_from_content = None decoder.metadata.duration_seconds_from_header = None + decoder.metadata.num_frames_from_header = ( + None # Set to None to prevent fallback calculation + ) with pytest.raises( ValueError, match="Could not infer stream end from video metadata" ):