From 823c8a325209f02e952d5f9107d423fb86a2dcf1 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 10:04:55 +0100 Subject: [PATCH 01/15] Let get_frames_at_indices op return a 3-tuple instead of single Tensor --- benchmarks/decoders/benchmark_decoders.py | 4 ++-- src/torchcodec/_samplers/video_clip_sampler.py | 2 +- src/torchcodec/decoders/_core/VideoDecoder.cpp | 4 ++-- src/torchcodec/decoders/_core/VideoDecoderOps.cpp | 6 +++--- src/torchcodec/decoders/_core/VideoDecoderOps.h | 2 +- src/torchcodec/decoders/_core/video_decoder_ops.py | 8 ++++++-- test/decoders/test_video_decoder_ops.py | 4 ++-- 7 files changed, 17 insertions(+), 13 deletions(-) diff --git a/benchmarks/decoders/benchmark_decoders.py b/benchmarks/decoders/benchmark_decoders.py index 1c542505..761f269f 100644 --- a/benchmarks/decoders/benchmark_decoders.py +++ b/benchmarks/decoders/benchmark_decoders.py @@ -209,7 +209,7 @@ def get_frames_from_video(self, video_file, pts_list): best_video_stream = metadata["bestVideoStreamIndex"] indices_list = [int(pts * average_fps) for pts in pts_list] frames = [] - frames = get_frames_at_indices( + frames, *_ = get_frames_at_indices( decoder, stream_index=best_video_stream, frame_indices=indices_list ) return frames @@ -226,7 +226,7 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode): best_video_stream = metadata["bestVideoStreamIndex"] frames = [] indices_list = list(range(numFramesToDecode)) - frames = get_frames_at_indices( + frames, *_ = get_frames_at_indices( decoder, stream_index=best_video_stream, frame_indices=indices_list ) return frames diff --git a/src/torchcodec/_samplers/video_clip_sampler.py b/src/torchcodec/_samplers/video_clip_sampler.py index 1440edae..4900be53 100644 --- a/src/torchcodec/_samplers/video_clip_sampler.py +++ b/src/torchcodec/_samplers/video_clip_sampler.py @@ -240,7 +240,7 @@ def _get_clips_for_index_based_sampling( clip_start_idx + i * index_based_sampler_args.video_frame_dilation for i in 
range(index_based_sampler_args.frames_per_clip) ] - frames = get_frames_at_indices( + frames, *_ = get_frames_at_indices( video_decoder, stream_index=metadata_json["bestVideoStreamIndex"], frame_indices=batch_indexes, diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 58f94635..132957c6 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1050,8 +1050,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { output.frames[f] = singleOut.frame; } - // Note that for now we ignore the pts and duration parts of the output, - // because they're never used in any caller. + output.ptsSeconds[f] = singleOut.ptsSeconds; + output.durationSeconds[f] = singleOut.durationSeconds; } output.frames = MaybePermuteHWC2CHW(options, output.frames); return output; diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 70f4afdc..0be871a3 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -40,7 +40,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)"); m.def( - "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> Tensor"); + "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? 
step=None) -> (Tensor, Tensor, Tensor)"); m.def( @@ -218,7 +218,7 @@ OpsDecodedOutput get_frame_at_index( return makeOpsDecodedOutput(result); } -at::Tensor get_frames_at_indices( +OpsBatchDecodedOutput get_frames_at_indices( at::Tensor& decoder, int64_t stream_index, at::IntArrayRef frame_indices) { @@ -226,7 +226,7 @@ at::Tensor get_frames_at_indices( std::vector frameIndicesVec( frame_indices.begin(), frame_indices.end()); auto result = videoDecoder->getFramesAtIndices(stream_index, frameIndicesVec); - return result.frames; + return makeOpsBatchDecodedOutput(result); } OpsBatchDecodedOutput get_frames_in_range( diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 7e9621e9..5b442025 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -87,7 +87,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder); // Return the frames at a given index for a given stream as a single stacked // Tensor. 
-at::Tensor get_frames_at_indices( +OpsBatchDecodedOutput get_frames_at_indices( at::Tensor& decoder, int64_t stream_index, at::IntArrayRef frame_indices); diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index bf170086..01de6ad6 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -190,9 +190,13 @@ def get_frames_at_indices_abstract( *, stream_index: int, frame_indices: List[int], -) -> torch.Tensor: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] - return torch.empty(image_size) + return ( + torch.empty(image_size), + torch.empty([], dtype=torch.float), + torch.empty([], dtype=torch.float), + ) @register_fake("torchcodec_ns::get_frames_in_range") diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index cc7b5011..0ac06f65 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -116,7 +116,7 @@ def test_get_frames_at_indices(self): decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder) add_video_stream(decoder) - frames0and180 = get_frames_at_indices( + frames0and180, *_ = get_frames_at_indices( decoder, stream_index=3, frame_indices=[0, 180] ) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) @@ -425,7 +425,7 @@ def test_color_conversion_library_with_dimension_order( assert frames.shape[1:] == expected_shape assert_tensor_equal(frames[0], frame0_ref) - frames = get_frames_at_indices( + frames, *_ = get_frames_at_indices( decoder, stream_index=stream_index, frame_indices=[0, 1, 3, 4] ) assert frames.shape[1:] == expected_shape From 61b493758c8b2b90ec44a1976ca3d12989e0123b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 10:21:38 +0100 Subject: [PATCH 02/15] Add deduplication logic --- 
.../decoders/_core/VideoDecoder.cpp | 26 ++++++++++++++----- src/torchcodec/decoders/_core/VideoDecoder.h | 3 ++- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 132957c6..a1e1afc8 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1030,7 +1030,8 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex( VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( int streamIndex, - const std::vector& frameIndices) { + const std::vector& frameIndices, + const bool sortIndices) { validateUserProvidedStreamIndex(streamIndex); validateScannedAllStreams("getFramesAtIndices"); @@ -1039,21 +1040,32 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( const auto& options = stream.options; BatchDecodedOutput output(frameIndices.size(), options, streamMetadata); + auto previousFrameIndex = -1; for (auto f = 0; f < frameIndices.size(); ++f) { auto frameIndex = frameIndices[f]; if (frameIndex < 0 || frameIndex >= stream.allFrames.size()) { throw std::runtime_error( "Invalid frame index=" + std::to_string(frameIndex)); } - DecodedOutput singleOut = - getFrameAtIndex(streamIndex, frameIndex, output.frames[f]); - if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { - output.frames[f] = singleOut.frame; + if ((f > 0) && (frameIndex == previousFrameIndex)) { + // Avoid decoding the same frame twice + output.frames[f].copy_(output.frames[f - 1]); + output.ptsSeconds[f] = output.ptsSeconds[f - 1]; + output.durationSeconds[f] = output.durationSeconds[f - 1]; + } else { + DecodedOutput singleOut = + getFrameAtIndex(streamIndex, frameIndex, output.frames[f]); + if (options.colorConversionLibrary == + ColorConversionLibrary::FILTERGRAPH) { + output.frames[f] = singleOut.frame; + } + output.ptsSeconds[f] = singleOut.ptsSeconds; + output.durationSeconds[f] = 
singleOut.durationSeconds; } - output.ptsSeconds[f] = singleOut.ptsSeconds; - output.durationSeconds[f] = singleOut.durationSeconds; + previousFrameIndex = frameIndex; } output.frames = MaybePermuteHWC2CHW(options, output.frames); + return output; } diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 2adbfac6..6c401a70 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -241,7 +241,8 @@ class VideoDecoder { // Tensor. BatchDecodedOutput getFramesAtIndices( int streamIndex, - const std::vector& frameIndices); + const std::vector& frameIndices, + const bool sortIndices = false); // Returns frames within a given range for a given stream as a single stacked // Tensor. The range is defined by [start, stop). The values retrieved from // the range are: From f7a70ba53a3d66b36e2a868ba68cd3d48f3340a8 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 11:45:36 +0100 Subject: [PATCH 03/15] Added sorting logic --- .../decoders/_core/VideoDecoder.cpp | 44 ++++++++++++------- .../decoders/_core/VideoDecoderOps.cpp | 8 ++-- .../decoders/_core/VideoDecoderOps.h | 3 +- .../decoders/_core/video_decoder_ops.py | 1 + test/decoders/test_video_decoder_ops.py | 33 ++++++++++++++ 5 files changed, 70 insertions(+), 19 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index a1e1afc8..eb07ed8e 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1040,32 +1040,46 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( const auto& options = stream.options; BatchDecodedOutput output(frameIndices.size(), options, streamMetadata); - auto previousFrameIndex = -1; + std::vector argsort(frameIndices.size()); + for (size_t i = 0; i < argsort.size(); ++i) { + argsort[i] = i; + } + if (sortIndices) { + std::sort( + 
argsort.begin(), argsort.end(), [&frameIndices](size_t a, size_t b) { + return frameIndices[a] < frameIndices[b]; + }); + } + + auto previousIndexInVideo = -1; for (auto f = 0; f < frameIndices.size(); ++f) { - auto frameIndex = frameIndices[f]; - if (frameIndex < 0 || frameIndex >= stream.allFrames.size()) { + auto indexInOutput = argsort[f]; + auto indexInVideo = frameIndices[argsort[f]]; + if (indexInVideo < 0 || indexInVideo >= stream.allFrames.size()) { throw std::runtime_error( - "Invalid frame index=" + std::to_string(frameIndex)); + "Invalid frame index=" + std::to_string(indexInVideo)); } - if ((f > 0) && (frameIndex == previousFrameIndex)) { + if ((f > 0) && (indexInVideo == previousIndexInVideo)) { // Avoid decoding the same frame twice - output.frames[f].copy_(output.frames[f - 1]); - output.ptsSeconds[f] = output.ptsSeconds[f - 1]; - output.durationSeconds[f] = output.durationSeconds[f - 1]; + auto previousIndexInOutput = argsort[f - 1]; + output.frames[indexInOutput].copy_(output.frames[previousIndexInOutput]); + output.ptsSeconds[indexInOutput] = + output.ptsSeconds[previousIndexInOutput]; + output.durationSeconds[indexInOutput] = + output.durationSeconds[previousIndexInOutput]; } else { - DecodedOutput singleOut = - getFrameAtIndex(streamIndex, frameIndex, output.frames[f]); + DecodedOutput singleOut = getFrameAtIndex( + streamIndex, indexInVideo, output.frames[indexInOutput]); if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { - output.frames[f] = singleOut.frame; + output.frames[indexInOutput] = singleOut.frame; } - output.ptsSeconds[f] = singleOut.ptsSeconds; - output.durationSeconds[f] = singleOut.durationSeconds; + output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds; + output.durationSeconds[indexInOutput] = singleOut.durationSeconds; } - previousFrameIndex = frameIndex; + previousIndexInVideo = indexInVideo; } output.frames = MaybePermuteHWC2CHW(options, output.frames); - return output; } diff --git 
a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 0be871a3..a48a2113 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -40,7 +40,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)"); m.def( - "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)"); + "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices, bool sort_indices=False) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)"); m.def( @@ -221,11 +221,13 @@ OpsDecodedOutput get_frame_at_index( OpsBatchDecodedOutput get_frames_at_indices( at::Tensor& decoder, int64_t stream_index, - at::IntArrayRef frame_indices) { + at::IntArrayRef frame_indices, + bool sort_indices) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); std::vector frameIndicesVec( frame_indices.begin(), frame_indices.end()); - auto result = videoDecoder->getFramesAtIndices(stream_index, frameIndicesVec); + auto result = videoDecoder->getFramesAtIndices( + stream_index, frameIndicesVec, sort_indices); return makeOpsBatchDecodedOutput(result); } diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 5b442025..2ec49d94 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -90,7 +90,8 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder); OpsBatchDecodedOutput get_frames_at_indices( at::Tensor& decoder, int64_t stream_index, - at::IntArrayRef frame_indices); + at::IntArrayRef frame_indices, + bool sort_indices = false); // Return the frames inside a range as a single stacked Tensor. 
The range is // defined as [start, stop). diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 01de6ad6..6335f997 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -190,6 +190,7 @@ def get_frames_at_indices_abstract( *, stream_index: int, frame_indices: List[int], + sort_indices: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] return ( diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 0ac06f65..dc875c83 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -124,6 +124,39 @@ def test_get_frames_at_indices(self): assert_tensor_equal(frames0and180[0], reference_frame0) assert_tensor_equal(frames0and180[1], reference_frame180) + @pytest.mark.parametrize("sort_indices", (False, True)) + def test_get_frames_at_indices_with_sort(self, sort_indices): + decoder = create_from_file(str(NASA_VIDEO.path)) + _add_video_stream(decoder) + scan_all_streams_to_update_metadata(decoder) + stream_index = 3 + + frame_indices = [2, 0, 1, 0, 2] + + expected_frames = [ + get_frame_at_index( + decoder, stream_index=stream_index, frame_index=frame_index + )[0] + for frame_index in frame_indices + ] + + frames, *_ = get_frames_at_indices( + decoder, + stream_index=stream_index, + frame_indices=frame_indices, + sort_indices=sort_indices, + ) + for frame, expected_frame in zip(frames, expected_frames): + assert_tensor_equal(frame, expected_frame) + + # first and last frame should be equal, at index 2. We then modify the + # first frame and assert that it's now different from the last frame. + # This ensures a copy was properly made during the de-duplication logic. 
+ assert_tensor_equal(frames[0], frames[-1]) + frames[0] += 20 + with pytest.raises(AssertionError): + assert_tensor_equal(frames[0], frames[-1]) + def test_get_frames_in_range(self): decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder) From 133c2131cd2c51afb65954fa333ac39112601c99 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 11:49:34 +0100 Subject: [PATCH 04/15] minor opt --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index eb07ed8e..739be558 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1040,11 +1040,13 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( const auto& options = stream.options; BatchDecodedOutput output(frameIndices.size(), options, streamMetadata); - std::vector argsort(frameIndices.size()); - for (size_t i = 0; i < argsort.size(); ++i) { - argsort[i] = i; - } + std::vector argsort; + if (sortIndices) { + argsort.resize(frameIndices.size()); + for (size_t i = 0; i < argsort.size(); ++i) { + argsort[i] = i; + } std::sort( argsort.begin(), argsort.end(), [&frameIndices](size_t a, size_t b) { return frameIndices[a] < frameIndices[b]; @@ -1053,15 +1055,15 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( auto previousIndexInVideo = -1; for (auto f = 0; f < frameIndices.size(); ++f) { - auto indexInOutput = argsort[f]; - auto indexInVideo = frameIndices[argsort[f]]; + auto indexInOutput = sortIndices ? 
argsort[f] : f; + auto indexInVideo = frameIndices[indexInOutput]; if (indexInVideo < 0 || indexInVideo >= stream.allFrames.size()) { throw std::runtime_error( "Invalid frame index=" + std::to_string(indexInVideo)); } if ((f > 0) && (indexInVideo == previousIndexInVideo)) { // Avoid decoding the same frame twice - auto previousIndexInOutput = argsort[f - 1]; + auto previousIndexInOutput = sortIndices ? argsort[f - 1] : f - 1; output.frames[indexInOutput].copy_(output.frames[previousIndexInOutput]); output.ptsSeconds[indexInOutput] = output.ptsSeconds[previousIndexInOutput]; From f391582d867b0ed30dee560a08a6ed3683824a0f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 11:59:11 +0100 Subject: [PATCH 05/15] Comments --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 739be558..7a179efe 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1040,8 +1040,11 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( const auto& options = stream.options; BatchDecodedOutput output(frameIndices.size(), options, streamMetadata); + // if frameIndices is [13, 10, 12, 11] + // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want + // to use to decode the frames + // and argsort is [ 1, 3, 2, 0] std::vector argsort; - if (sortIndices) { argsort.resize(frameIndices.size()); for (size_t i = 0; i < argsort.size(); ++i) { From d475890cf11735f0fba3480a20e34b79f8ac08be Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 13:32:38 +0100 Subject: [PATCH 06/15] scaffolding --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 7 +++++++ src/torchcodec/decoders/_core/VideoDecoder.h | 7 +++++++ .../decoders/_core/VideoDecoderOps.cpp | 17 +++++++++++++++++ src/torchcodec/decoders/_core/VideoDecoderOps.h | 
10 ++++++++-- src/torchcodec/decoders/_core/__init__.py | 1 + .../decoders/_core/video_decoder_ops.py | 16 ++++++++++++++++ 6 files changed, 56 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 7a179efe..90276f79 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1088,6 +1088,13 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( return output; } +VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtPtss( + int streamIndex, + const std::vector& framePtss, + const bool sortPtss) { + return getFramesAtIndices(streamIndex, framePtss, sortPtss); + } + VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( int streamIndex, int64_t start, diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 6c401a70..7f34ff12 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -223,6 +223,7 @@ class VideoDecoder { // i.e. it will be returned when this function is called with seconds=5.0 or // seconds=5.999, etc. DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds); + DecodedOutput getFrameAtIndex( int streamIndex, int64_t frameIndex, @@ -243,6 +244,12 @@ class VideoDecoder { int streamIndex, const std::vector& frameIndices, const bool sortIndices = false); + + BatchDecodedOutput getFramesAtPtss( + int streamIndex, + const std::vector& framePtss, + const bool sortPtss = false); + // Returns frames within a given range for a given stream as a single stacked // Tensor. The range is defined by [start, stop). 
The values retrieved from // the range are: diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index a48a2113..76b457c9 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -41,6 +41,8 @@ TORCH_LIBRARY(torchcodec_ns, m) { "get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices, bool sort_indices=False) -> (Tensor, Tensor, Tensor)"); + m.def( + "get_frames_at_ptss(Tensor(a!) decoder, *, int stream_index, int[] frame_ptss, bool sort_ptss=False) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)"); m.def( @@ -209,6 +211,20 @@ OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds) { return makeOpsDecodedOutput(result); } +OpsBatchDecodedOutput get_frames_at_ptss( + at::Tensor& decoder, + int64_t stream_index, + at::IntArrayRef frame_ptss, + bool sort_ptss) { + auto videoDecoder = unwrapTensorToGetDecoder(decoder); + std::vector framePtssVec( + frame_ptss.begin(), frame_ptss.end()); + auto result = videoDecoder->getFramesAtPtss( + stream_index, framePtssVec, sort_ptss); + return makeOpsBatchDecodedOutput(result); +} + + OpsDecodedOutput get_frame_at_index( at::Tensor& decoder, int64_t stream_index, @@ -485,6 +501,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl("get_frame_at_pts", &get_frame_at_pts); m.impl("get_frame_at_index", &get_frame_at_index); m.impl("get_frames_at_indices", &get_frames_at_indices); + m.impl("get_frames_at_ptss", &get_frames_at_ptss); m.impl("get_frames_in_range", &get_frames_in_range); m.impl("get_frames_by_pts_in_range", &get_frames_by_pts_in_range); m.impl("_test_frame_pts_equality", &_test_frame_pts_equality); diff --git 
a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 2ec49d94..0da1edd0 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -75,6 +75,13 @@ using OpsBatchDecodedOutput = std::tuple; // given timestamp T has T >= PTS and T < PTS + Duration. OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds); +// Return the frames at given ptss for a given stream +OpsBatchDecodedOutput get_frames_at_ptss( + at::Tensor& decoder, + int64_t stream_index, + at::IntArrayRef frame_ptss, + bool sort_ptss = false); + // Return the frame that is visible at a given index in the video. OpsDecodedOutput get_frame_at_index( at::Tensor& decoder, @@ -85,8 +92,7 @@ OpsDecodedOutput get_frame_at_index( // duration as tensors. OpsDecodedOutput get_next_frame(at::Tensor& decoder); -// Return the frames at a given index for a given stream as a single stacked -// Tensor. +// Return the frames at given indices for a given stream OpsBatchDecodedOutput get_frames_at_indices( at::Tensor& decoder, int64_t stream_index, diff --git a/src/torchcodec/decoders/_core/__init__.py b/src/torchcodec/decoders/_core/__init__.py index bd761fe1..543cb74f 100644 --- a/src/torchcodec/decoders/_core/__init__.py +++ b/src/torchcodec/decoders/_core/__init__.py @@ -22,6 +22,7 @@ get_frame_at_index, get_frame_at_pts, get_frames_at_indices, + get_frames_at_ptss, get_frames_by_pts_in_range, get_frames_in_range, get_json_metadata, diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 6335f997..2f0036ab 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -71,6 +71,7 @@ def load_torchcodec_extension(): get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default 
get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default +get_frames_at_ptss = torch.ops.torchcodec_ns.get_frames_at_ptss.default get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default @@ -171,6 +172,21 @@ def get_frame_at_pts_abstract( torch.empty([], dtype=torch.float), ) +@register_fake("torchcodec_ns::get_frames_at_ptss") +def get_frames_at_pts_abstract( + decoder: torch.Tensor, + *, + stream_index: int, + frame_ptss: List[int], + sort_ptss: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + image_size = [get_ctx().new_dynamic_size() for _ in range(4)] + return ( + torch.empty(image_size), + torch.empty([], dtype=torch.float), + torch.empty([], dtype=torch.float), + ) + @register_fake("torchcodec_ns::get_frame_at_index") def get_frame_at_index_abstract( From 14e287604e05d17d3f59987b7a3d289ebe5b5516 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 14:16:39 +0100 Subject: [PATCH 07/15] Added logic --- .../decoders/_core/VideoDecoder.cpp | 39 +++++++++++++++++-- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- .../decoders/_core/VideoDecoderOps.cpp | 12 +++--- .../decoders/_core/VideoDecoderOps.h | 2 +- .../decoders/_core/video_decoder_ops.py | 3 +- 5 files changed, 45 insertions(+), 13 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 90276f79..f71f2c6e 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1090,10 +1090,43 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtPtss( int streamIndex, - const std::vector& framePtss, + const std::vector& framePtss, const bool sortPtss) { - return 
getFramesAtIndices(streamIndex, framePtss, sortPtss); - } + validateUserProvidedStreamIndex(streamIndex); + validateScannedAllStreams("getFramesAtPtss"); + + // The frame displayed at timestamp t and the one displayed at timestamp `t + + // eps` are probably the same frame, with the same index. The easiest way to + // avoid decoding that unique frame twice is to convert the input timestamps + // to indices, and leverage the de-duplication logic of getFramesAtIndices. + + const auto& streamMetadata = containerMetadata_.streams[streamIndex]; + const auto& stream = streams_[streamIndex]; + double minSeconds = streamMetadata.minPtsSecondsFromScan.value(); + double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value(); + + std::vector frameIndices(framePtss.size()); + for (auto i = 0; i < framePtss.size(); ++i) { + auto framePts = framePtss[i]; + TORCH_CHECK( + framePts >= minSeconds && framePts < maxSeconds, + "frame pts is " + std::to_string(framePts) + "; must be in range [" + + std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) + + ")."); + + auto it = std::lower_bound( + stream.allFrames.begin(), + stream.allFrames.end(), + framePts, + [&stream](const FrameInfo& info, double start) { + return ptsToSeconds(info.nextPts, stream.timeBase) <= start; + }); + int64_t frameIndex = it - stream.allFrames.begin(); + frameIndices[i] = frameIndex; + } + + return getFramesAtIndices(streamIndex, frameIndices, sortPtss); +} VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( int streamIndex, diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 7f34ff12..02d9c10c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -247,7 +247,7 @@ class VideoDecoder { BatchDecodedOutput getFramesAtPtss( int streamIndex, - const std::vector& framePtss, + const std::vector& framePtss, const bool sortPtss = false); // Returns frames within a given range 
for a given stream as a single stacked diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 76b457c9..dd6db3b4 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -42,7 +42,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices, bool sort_indices=False) -> (Tensor, Tensor, Tensor)"); m.def( - "get_frames_at_ptss(Tensor(a!) decoder, *, int stream_index, int[] frame_ptss, bool sort_ptss=False) -> (Tensor, Tensor, Tensor)"); + "get_frames_at_ptss(Tensor(a!) decoder, *, int stream_index, float[] frame_ptss, bool sort_ptss=False) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)"); m.def( @@ -214,17 +214,15 @@ OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds) { OpsBatchDecodedOutput get_frames_at_ptss( at::Tensor& decoder, int64_t stream_index, - at::IntArrayRef frame_ptss, + at::ArrayRef frame_ptss, bool sort_ptss) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); - std::vector framePtssVec( - frame_ptss.begin(), frame_ptss.end()); - auto result = videoDecoder->getFramesAtPtss( - stream_index, framePtssVec, sort_ptss); + std::vector framePtssVec(frame_ptss.begin(), frame_ptss.end()); + auto result = + videoDecoder->getFramesAtPtss(stream_index, framePtssVec, sort_ptss); return makeOpsBatchDecodedOutput(result); } - OpsDecodedOutput get_frame_at_index( at::Tensor& decoder, int64_t stream_index, diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 0da1edd0..167387f0 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -79,7 +79,7 @@ OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double 
seconds); OpsBatchDecodedOutput get_frames_at_ptss( at::Tensor& decoder, int64_t stream_index, - at::IntArrayRef frame_ptss, + at::ArrayRef frame_ptss, bool sort_ptss = false); // Return the frame that is visible at a given index in the video. diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 2f0036ab..b80f9122 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -172,12 +172,13 @@ def get_frame_at_pts_abstract( torch.empty([], dtype=torch.float), ) + @register_fake("torchcodec_ns::get_frames_at_ptss") def get_frames_at_pts_abstract( decoder: torch.Tensor, *, stream_index: int, - frame_ptss: List[int], + frame_ptss: List[float], sort_ptss: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] From 9c9e4627d2b37da2b7a509d50d3ee897fa29d8dd Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 14:21:38 +0100 Subject: [PATCH 08/15] Added test --- test/decoders/test_video_decoder_ops.py | 32 +++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index dc875c83..f404ed9c 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -27,6 +27,7 @@ get_frame_at_index, get_frame_at_pts, get_frames_at_indices, + get_frames_at_ptss, get_frames_by_pts_in_range, get_frames_in_range, get_json_metadata, @@ -157,6 +158,37 @@ def test_get_frames_at_indices_with_sort(self, sort_indices): with pytest.raises(AssertionError): assert_tensor_equal(frames[0], frames[-1]) + @pytest.mark.parametrize("sort_ptss", (False, True)) + def test_get_frames_at_ptss_with_sort(self, sort_ptss): + decoder = create_from_file(str(NASA_VIDEO.path)) + _add_video_stream(decoder) + scan_all_streams_to_update_metadata(decoder) + stream_index = 3 
+ + frame_ptss = [2, 0, 1, 0 + 1e-3, 2 + 1e-3] + + expected_frames = [ + get_frame_at_pts(decoder, seconds=pts)[0] for pts in frame_ptss + ] + + frames, *_ = get_frames_at_ptss( + decoder, + stream_index=stream_index, + frame_ptss=frame_ptss, + sort_ptss=sort_ptss, + ) + for frame, expected_frame in zip(frames, expected_frames): + assert_tensor_equal(frame, expected_frame) + + # # first and last frame should be equal, at pts=2 [+ eps]. We then + # modify the # first frame and assert that it's now different from the + # last frame. # This ensures a copy was properly made during the + # de-duplication logic. + assert_tensor_equal(frames[0], frames[-1]) + frames[0] += 20 + with pytest.raises(AssertionError): + assert_tensor_equal(frames[0], frames[-1]) + def test_get_frames_in_range(self): decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder) From b8284cc84b575709c87a431bfb339ea7a943dbfd Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 17:17:45 +0100 Subject: [PATCH 09/15] Remove parameter, just sort if not already sorted --- .../decoders/_core/VideoDecoder.cpp | 20 ++++++++++--------- src/torchcodec/decoders/_core/VideoDecoder.h | 3 +-- .../decoders/_core/VideoDecoderOps.cpp | 8 +++----- .../decoders/_core/VideoDecoderOps.h | 3 +-- .../decoders/_core/video_decoder_ops.py | 1 - test/decoders/test_video_decoder_ops.py | 4 +--- 6 files changed, 17 insertions(+), 22 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 7a179efe..416d96a9 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1030,8 +1030,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex( VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( int streamIndex, - const std::vector& frameIndices, - const bool sortIndices) { + const std::vector& frameIndices) { 
validateUserProvidedStreamIndex(streamIndex); validateScannedAllStreams("getFramesAtIndices"); @@ -1040,12 +1039,15 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( const auto& options = stream.options; BatchDecodedOutput output(frameIndices.size(), options, streamMetadata); - // if frameIndices is [13, 10, 12, 11] - // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want - // to use to decode the frames - // and argsort is [ 1, 3, 2, 0] + auto indicesAreSorted = + std::is_sorted(frameIndices.begin(), frameIndices.end()); + std::vector argsort; - if (sortIndices) { + if (!indicesAreSorted) { + // if frameIndices is [13, 10, 12, 11] + // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want + // to use to decode the frames + // and argsort is [ 1, 3, 2, 0] argsort.resize(frameIndices.size()); for (size_t i = 0; i < argsort.size(); ++i) { argsort[i] = i; @@ -1058,7 +1060,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( auto previousIndexInVideo = -1; for (auto f = 0; f < frameIndices.size(); ++f) { - auto indexInOutput = sortIndices ? argsort[f] : f; + auto indexInOutput = indicesAreSorted ? f : argsort[f]; auto indexInVideo = frameIndices[indexInOutput]; if (indexInVideo < 0 || indexInVideo >= stream.allFrames.size()) { throw std::runtime_error( @@ -1066,7 +1068,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( } if ((f > 0) && (indexInVideo == previousIndexInVideo)) { // Avoid decoding the same frame twice - auto previousIndexInOutput = sortIndices ? argsort[f - 1] : f - 1; + auto previousIndexInOutput = indicesAreSorted ? 
f - 1 : argsort[f - 1]; output.frames[indexInOutput].copy_(output.frames[previousIndexInOutput]); output.ptsSeconds[indexInOutput] = output.ptsSeconds[previousIndexInOutput]; diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 6c401a70..2adbfac6 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -241,8 +241,7 @@ class VideoDecoder { // Tensor. BatchDecodedOutput getFramesAtIndices( int streamIndex, - const std::vector& frameIndices, - const bool sortIndices = false); + const std::vector& frameIndices); // Returns frames within a given range for a given stream as a single stacked // Tensor. The range is defined by [start, stop). The values retrieved from // the range are: diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index a48a2113..0be871a3 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -40,7 +40,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)"); m.def( - "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices, bool sort_indices=False) -> (Tensor, Tensor, Tensor)"); + "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? 
step=None) -> (Tensor, Tensor, Tensor)"); m.def( @@ -221,13 +221,11 @@ OpsDecodedOutput get_frame_at_index( OpsBatchDecodedOutput get_frames_at_indices( at::Tensor& decoder, int64_t stream_index, - at::IntArrayRef frame_indices, - bool sort_indices) { + at::IntArrayRef frame_indices) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); std::vector frameIndicesVec( frame_indices.begin(), frame_indices.end()); - auto result = videoDecoder->getFramesAtIndices( - stream_index, frameIndicesVec, sort_indices); + auto result = videoDecoder->getFramesAtIndices(stream_index, frameIndicesVec); return makeOpsBatchDecodedOutput(result); } diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 2ec49d94..5b442025 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -90,8 +90,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder); OpsBatchDecodedOutput get_frames_at_indices( at::Tensor& decoder, int64_t stream_index, - at::IntArrayRef frame_indices, - bool sort_indices = false); + at::IntArrayRef frame_indices); // Return the frames inside a range as a single stacked Tensor. The range is // defined as [start, stop). 
diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 6335f997..01de6ad6 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -190,7 +190,6 @@ def get_frames_at_indices_abstract( *, stream_index: int, frame_indices: List[int], - sort_indices: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] return ( diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index dc875c83..bbd9fe4e 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -124,8 +124,7 @@ def test_get_frames_at_indices(self): assert_tensor_equal(frames0and180[0], reference_frame0) assert_tensor_equal(frames0and180[1], reference_frame180) - @pytest.mark.parametrize("sort_indices", (False, True)) - def test_get_frames_at_indices_with_sort(self, sort_indices): + def test_get_frames_at_indices_unsorted_indices(self): decoder = create_from_file(str(NASA_VIDEO.path)) _add_video_stream(decoder) scan_all_streams_to_update_metadata(decoder) @@ -144,7 +143,6 @@ def test_get_frames_at_indices_with_sort(self, sort_indices): decoder, stream_index=stream_index, frame_indices=frame_indices, - sort_indices=sort_indices, ) for frame, expected_frame in zip(frames, expected_frames): assert_tensor_equal(frame, expected_frame) From 4dda5b78b7ddf039c7f94774db80a4b6bc0b0798 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 11:10:32 +0100 Subject: [PATCH 10/15] Rename --- .../decoders/_core/VideoDecoder.cpp | 6 ++--- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- .../decoders/_core/VideoDecoderOps.cpp | 27 +++++++++---------- .../decoders/_core/VideoDecoderOps.h | 2 +- src/torchcodec/decoders/_core/__init__.py | 2 +- .../decoders/_core/video_decoder_ops.py | 6 ++--- test/decoders/test_video_decoder_ops.py 
| 6 ++--- 7 files changed, 25 insertions(+), 26 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 8eefc2f1..17354706 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1090,11 +1090,11 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( return output; } -VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtPtss( +VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( int streamIndex, - const std::vector& framePtss){ + const std::vector& framePtss) { validateUserProvidedStreamIndex(streamIndex); - validateScannedAllStreams("getFramesAtPtss"); + validateScannedAllStreams("getFramesDisplayedByTimestamps"); // The frame displayed at timestamp t and the one displayed at timestamp `t + // eps` are probably the same frame, with the same index. The easiest way to diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index c2f19816..5eab70bf 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -244,7 +244,7 @@ class VideoDecoder { int streamIndex, const std::vector& frameIndices); - BatchDecodedOutput getFramesAtPtss( + BatchDecodedOutput getFramesDisplayedByTimestamps( int streamIndex, const std::vector& framePtss); diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 8f168be9..fbc739b3 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -41,12 +41,12 @@ TORCH_LIBRARY(torchcodec_ns, m) { "get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_at_indices(Tensor(a!) 
decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)"); - m.def( - "get_frames_at_ptss(Tensor(a!) decoder, *, int stream_index, float[] frame_ptss) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)"); + m.def( + "get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] frame_ptss) -> (Tensor, Tensor, Tensor)"); m.def("get_json_metadata(Tensor(a!) decoder) -> str"); m.def("get_container_json_metadata(Tensor(a!) decoder) -> str"); m.def( @@ -211,17 +211,6 @@ OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds) { return makeOpsDecodedOutput(result); } -OpsBatchDecodedOutput get_frames_at_ptss( - at::Tensor& decoder, - int64_t stream_index, - at::ArrayRef frame_ptss) { - auto videoDecoder = unwrapTensorToGetDecoder(decoder); - std::vector framePtssVec(frame_ptss.begin(), frame_ptss.end()); - auto result = - videoDecoder->getFramesAtPtss(stream_index, framePtssVec); - return makeOpsBatchDecodedOutput(result); -} - OpsDecodedOutput get_frame_at_index( at::Tensor& decoder, int64_t stream_index, @@ -253,6 +242,16 @@ OpsBatchDecodedOutput get_frames_in_range( stream_index, start, stop, step.value_or(1)); return makeOpsBatchDecodedOutput(result); } +OpsBatchDecodedOutput get_frames_by_pts( + at::Tensor& decoder, + int64_t stream_index, + at::ArrayRef frame_ptss) { + auto videoDecoder = unwrapTensorToGetDecoder(decoder); + std::vector framePtssVec(frame_ptss.begin(), frame_ptss.end()); + auto result = + videoDecoder->getFramesDisplayedByTimestamps(stream_index, framePtssVec); + return makeOpsBatchDecodedOutput(result); +} OpsBatchDecodedOutput get_frames_by_pts_in_range( at::Tensor& decoder, @@ -496,9 +495,9 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { 
m.impl("get_frame_at_pts", &get_frame_at_pts); m.impl("get_frame_at_index", &get_frame_at_index); m.impl("get_frames_at_indices", &get_frames_at_indices); - m.impl("get_frames_at_ptss", &get_frames_at_ptss); m.impl("get_frames_in_range", &get_frames_in_range); m.impl("get_frames_by_pts_in_range", &get_frames_by_pts_in_range); + m.impl("get_frames_by_pts", &get_frames_by_pts); m.impl("_test_frame_pts_equality", &_test_frame_pts_equality); m.impl( "scan_all_streams_to_update_metadata", diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 3b3d3615..12cce81c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -76,7 +76,7 @@ using OpsBatchDecodedOutput = std::tuple; OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds); // Return the frames at given ptss for a given stream -OpsBatchDecodedOutput get_frames_at_ptss( +OpsBatchDecodedOutput get_frames_by_pts( at::Tensor& decoder, int64_t stream_index, at::ArrayRef frame_ptss); diff --git a/src/torchcodec/decoders/_core/__init__.py b/src/torchcodec/decoders/_core/__init__.py index 543cb74f..a1ac9a47 100644 --- a/src/torchcodec/decoders/_core/__init__.py +++ b/src/torchcodec/decoders/_core/__init__.py @@ -22,7 +22,7 @@ get_frame_at_index, get_frame_at_pts, get_frames_at_indices, - get_frames_at_ptss, + get_frames_by_pts, get_frames_by_pts_in_range, get_frames_in_range, get_json_metadata, diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 3a3b36c3..8f6c3c08 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -71,7 +71,7 @@ def load_torchcodec_extension(): get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default get_frames_at_indices = 
torch.ops.torchcodec_ns.get_frames_at_indices.default -get_frames_at_ptss = torch.ops.torchcodec_ns.get_frames_at_ptss.default +get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default @@ -173,8 +173,8 @@ def get_frame_at_pts_abstract( ) -@register_fake("torchcodec_ns::get_frames_at_ptss") -def get_frames_at_pts_abstract( +@register_fake("torchcodec_ns::get_frames_by_pts") +def get_frames_by_pts_abstract( decoder: torch.Tensor, *, stream_index: int, diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 153c4905..84cc560e 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -27,7 +27,7 @@ get_frame_at_index, get_frame_at_pts, get_frames_at_indices, - get_frames_at_ptss, + get_frames_by_pts, get_frames_by_pts_in_range, get_frames_in_range, get_json_metadata, @@ -156,7 +156,7 @@ def test_get_frames_at_indices_unsorted_indices(self): with pytest.raises(AssertionError): assert_tensor_equal(frames[0], frames[-1]) - def test_get_frames_at_ptss(self): + def test_get_frames_by_pts(self): decoder = create_from_file(str(NASA_VIDEO.path)) _add_video_stream(decoder) scan_all_streams_to_update_metadata(decoder) @@ -168,7 +168,7 @@ def test_get_frames_at_ptss(self): get_frame_at_pts(decoder, seconds=pts)[0] for pts in frame_ptss ] - frames, *_ = get_frames_at_ptss( + frames, *_ = get_frames_by_pts( decoder, stream_index=stream_index, frame_ptss=frame_ptss, From 7d266239f00fb31f89c1acb8d709f0dae7a771fa Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 11:42:55 +0100 Subject: [PATCH 11/15] Fix last frame request --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 1 + test/decoders/test_video_decoder_ops.py | 3 ++- 2 files changed, 
3 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 17354706..b6e4f6ce 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1123,6 +1123,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( return ptsToSeconds(info.nextPts, stream.timeBase) <= start; }); int64_t frameIndex = it - stream.allFrames.begin(); + frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1); frameIndices[i] = frameIndex; } diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 84cc560e..adc1392f 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -162,7 +162,8 @@ def test_get_frames_by_pts(self): scan_all_streams_to_update_metadata(decoder) stream_index = 3 - frame_ptss = [2, 0, 1, 0 + 1e-3, 2 + 1e-3] + # Note: 13.01 should give the last video frame for the NASA video + frame_ptss = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3] expected_frames = [ get_frame_at_pts(decoder, seconds=pts)[0] for pts in frame_ptss From 3a8839d63086926829f731b18916737183d79bb2 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 11:49:03 +0100 Subject: [PATCH 12/15] Better fix --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index b6e4f6ce..42a42b38 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1117,13 +1117,12 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( auto it = std::lower_bound( stream.allFrames.begin(), - stream.allFrames.end(), + stream.allFrames.end() - 1, framePts, [&stream](const FrameInfo& info, double start) { return 
ptsToSeconds(info.nextPts, stream.timeBase) <= start; }); int64_t frameIndex = it - stream.allFrames.begin(); - frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1); frameIndices[i] = frameIndex; } From d43dd9120d227ad922079830c05b3c8fa3662de4 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 15:41:57 +0100 Subject: [PATCH 13/15] Use timestamps as parameter name --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 8 ++++---- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- src/torchcodec/decoders/_core/VideoDecoderOps.cpp | 8 ++++---- src/torchcodec/decoders/_core/VideoDecoderOps.h | 2 +- src/torchcodec/decoders/_core/video_decoder_ops.py | 2 +- test/decoders/test_video_decoder_ops.py | 6 +++--- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 42a42b38..bc3f1198 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1092,7 +1092,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( int streamIndex, - const std::vector& framePtss) { + const std::vector& timestamps) { validateUserProvidedStreamIndex(streamIndex); validateScannedAllStreams("getFramesDisplayedByTimestamps"); @@ -1106,9 +1106,9 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( double minSeconds = streamMetadata.minPtsSecondsFromScan.value(); double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value(); - std::vector frameIndices(framePtss.size()); - for (auto i = 0; i < framePtss.size(); ++i) { - auto framePts = framePtss[i]; + std::vector frameIndices(timestamps.size()); + for (auto i = 0; i < timestamps.size(); ++i) { + auto framePts = timestamps[i]; TORCH_CHECK( framePts >= minSeconds && framePts < maxSeconds, "frame pts is " + std::to_string(framePts) + "; 
must be in range [" + diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 5eab70bf..c0f489ce 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -246,7 +246,7 @@ class VideoDecoder { BatchDecodedOutput getFramesDisplayedByTimestamps( int streamIndex, - const std::vector& framePtss); + const std::vector& timestamps); // Returns frames within a given range for a given stream as a single stacked // Tensor. The range is defined by [start, stop). The values retrieved from diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index fbc739b3..6b91853c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -46,7 +46,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)"); m.def( - "get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] frame_ptss) -> (Tensor, Tensor, Tensor)"); + "get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)"); m.def("get_json_metadata(Tensor(a!) decoder) -> str"); m.def("get_container_json_metadata(Tensor(a!) 
decoder) -> str"); m.def( @@ -245,11 +245,11 @@ OpsBatchDecodedOutput get_frames_in_range( OpsBatchDecodedOutput get_frames_by_pts( at::Tensor& decoder, int64_t stream_index, - at::ArrayRef frame_ptss) { + at::ArrayRef timestamps) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); - std::vector framePtssVec(frame_ptss.begin(), frame_ptss.end()); + std::vector timestampsVec(timestamps.begin(), timestamps.end()); auto result = - videoDecoder->getFramesDisplayedByTimestamps(stream_index, framePtssVec); + videoDecoder->getFramesDisplayedByTimestamps(stream_index, timestampsVec); return makeOpsBatchDecodedOutput(result); } diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 12cce81c..eac489ce 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -79,7 +79,7 @@ OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds); OpsBatchDecodedOutput get_frames_by_pts( at::Tensor& decoder, int64_t stream_index, - at::ArrayRef frame_ptss); + at::ArrayRef timestamps); // Return the frame that is visible at a given index in the video. 
OpsDecodedOutput get_frame_at_index( diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 8f6c3c08..d4102ae5 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -178,7 +178,7 @@ def get_frames_by_pts_abstract( decoder: torch.Tensor, *, stream_index: int, - frame_ptss: List[float], + timestamps: List[float], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] return ( diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index adc1392f..8535abee 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -163,16 +163,16 @@ def test_get_frames_by_pts(self): stream_index = 3 # Note: 13.01 should give the last video frame for the NASA video - frame_ptss = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3] + timestamps = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3] expected_frames = [ - get_frame_at_pts(decoder, seconds=pts)[0] for pts in frame_ptss + get_frame_at_pts(decoder, seconds=pts)[0] for pts in timestamps ] frames, *_ = get_frames_by_pts( decoder, stream_index=stream_index, - frame_ptss=frame_ptss, + timestamps=timestamps, ) for frame, expected_frame in zip(frames, expected_frames): assert_tensor_equal(frame, expected_frame) From 8dd9b0a67ad78614f37e93bed2b5634a478c3b3b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 17:02:07 +0100 Subject: [PATCH 14/15] Address comments --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index bc3f1198..5f6a6d29 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1100,6 +1100,8 @@ VideoDecoder::BatchDecodedOutput 
VideoDecoder::getFramesDisplayedByTimestamps( // eps` are probably the same frame, with the same index. The easiest way to // avoid decoding that unique frame twice is to convert the input timestamps // to indices, and leverage the de-duplication logic of getFramesAtIndices. + // This means this function requires a scan. + // TODO: longer term, we should implement this without requiring a scan const auto& streamMetadata = containerMetadata_.streams[streamIndex]; const auto& stream = streams_[streamIndex]; @@ -1119,8 +1121,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( stream.allFrames.begin(), stream.allFrames.end() - 1, framePts, - [&stream](const FrameInfo& info, double start) { - return ptsToSeconds(info.nextPts, stream.timeBase) <= start; + [&stream](const FrameInfo& info, double framePts) { + return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts; }); int64_t frameIndex = it - stream.allFrames.begin(); frameIndices[i] = frameIndex; From 52dd9edd5c237f69eb61a0e351013929705dc428 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 17:31:03 +0100 Subject: [PATCH 15/15] Use max() approach, and fix comment in tests --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 8 +++++++- test/decoders/test_video_decoder_ops.py | 8 ++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 5f6a6d29..add9c9be 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1119,12 +1119,18 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( auto it = std::lower_bound( stream.allFrames.begin(), - stream.allFrames.end() - 1, + stream.allFrames.end(), framePts, [&stream](const FrameInfo& info, double framePts) { return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts; }); int64_t frameIndex = it - stream.allFrames.begin(); 
+ // If the frame index is equal to the size of allFrames, that means we + // couldn't match the pts value to the pts value of a NEXT FRAME. And + // that means that this timestamp falls during the time between when the + // last frame is displayed, and the video ends. Hence, it should map to the + // index of the last frame. + frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1); frameIndices[i] = frameIndex; } diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 8535abee..0ed68146 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -177,10 +177,10 @@ def test_get_frames_by_pts(self): for frame, expected_frame in zip(frames, expected_frames): assert_tensor_equal(frame, expected_frame) - # # first and last frame should be equal, at pts=2 [+ eps]. We then - # modify the # first frame and assert that it's now different from the - # last frame. # This ensures a copy was properly made during the - # de-duplication logic. + # first and last frame should be equal, at pts=2 [+ eps]. We then modify + # the first frame and assert that it's now different from the last + # frame. This ensures a copy was properly made during the de-duplication + # logic. assert_tensor_equal(frames[0], frames[-1]) frames[0] += 20 with pytest.raises(AssertionError):