From 823c8a325209f02e952d5f9107d423fb86a2dcf1 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 10:04:55 +0100 Subject: [PATCH 01/27] Let get_frames_at_indices op return a 3-tuple instead of single Tensor --- benchmarks/decoders/benchmark_decoders.py | 4 ++-- src/torchcodec/_samplers/video_clip_sampler.py | 2 +- src/torchcodec/decoders/_core/VideoDecoder.cpp | 4 ++-- src/torchcodec/decoders/_core/VideoDecoderOps.cpp | 6 +++--- src/torchcodec/decoders/_core/VideoDecoderOps.h | 2 +- src/torchcodec/decoders/_core/video_decoder_ops.py | 8 ++++++-- test/decoders/test_video_decoder_ops.py | 4 ++-- 7 files changed, 17 insertions(+), 13 deletions(-) diff --git a/benchmarks/decoders/benchmark_decoders.py b/benchmarks/decoders/benchmark_decoders.py index 1c542505..761f269f 100644 --- a/benchmarks/decoders/benchmark_decoders.py +++ b/benchmarks/decoders/benchmark_decoders.py @@ -209,7 +209,7 @@ def get_frames_from_video(self, video_file, pts_list): best_video_stream = metadata["bestVideoStreamIndex"] indices_list = [int(pts * average_fps) for pts in pts_list] frames = [] - frames = get_frames_at_indices( + frames, *_ = get_frames_at_indices( decoder, stream_index=best_video_stream, frame_indices=indices_list ) return frames @@ -226,7 +226,7 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode): best_video_stream = metadata["bestVideoStreamIndex"] frames = [] indices_list = list(range(numFramesToDecode)) - frames = get_frames_at_indices( + frames, *_ = get_frames_at_indices( decoder, stream_index=best_video_stream, frame_indices=indices_list ) return frames diff --git a/src/torchcodec/_samplers/video_clip_sampler.py b/src/torchcodec/_samplers/video_clip_sampler.py index 1440edae..4900be53 100644 --- a/src/torchcodec/_samplers/video_clip_sampler.py +++ b/src/torchcodec/_samplers/video_clip_sampler.py @@ -240,7 +240,7 @@ def _get_clips_for_index_based_sampling( clip_start_idx + i * index_based_sampler_args.video_frame_dilation for i in 
range(index_based_sampler_args.frames_per_clip) ] - frames = get_frames_at_indices( + frames, *_ = get_frames_at_indices( video_decoder, stream_index=metadata_json["bestVideoStreamIndex"], frame_indices=batch_indexes, diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 58f94635..132957c6 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1050,8 +1050,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { output.frames[f] = singleOut.frame; } - // Note that for now we ignore the pts and duration parts of the output, - // because they're never used in any caller. + output.ptsSeconds[f] = singleOut.ptsSeconds; + output.durationSeconds[f] = singleOut.durationSeconds; } output.frames = MaybePermuteHWC2CHW(options, output.frames); return output; diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 70f4afdc..0be871a3 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -40,7 +40,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)"); m.def( - "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> Tensor"); + "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? 
step=None) -> (Tensor, Tensor, Tensor)"); m.def( @@ -218,7 +218,7 @@ OpsDecodedOutput get_frame_at_index( return makeOpsDecodedOutput(result); } -at::Tensor get_frames_at_indices( +OpsBatchDecodedOutput get_frames_at_indices( at::Tensor& decoder, int64_t stream_index, at::IntArrayRef frame_indices) { @@ -226,7 +226,7 @@ at::Tensor get_frames_at_indices( std::vector frameIndicesVec( frame_indices.begin(), frame_indices.end()); auto result = videoDecoder->getFramesAtIndices(stream_index, frameIndicesVec); - return result.frames; + return makeOpsBatchDecodedOutput(result); } OpsBatchDecodedOutput get_frames_in_range( diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 7e9621e9..5b442025 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -87,7 +87,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder); // Return the frames at a given index for a given stream as a single stacked // Tensor. 
-at::Tensor get_frames_at_indices( +OpsBatchDecodedOutput get_frames_at_indices( at::Tensor& decoder, int64_t stream_index, at::IntArrayRef frame_indices); diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index bf170086..01de6ad6 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -190,9 +190,13 @@ def get_frames_at_indices_abstract( *, stream_index: int, frame_indices: List[int], -) -> torch.Tensor: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] - return torch.empty(image_size) + return ( + torch.empty(image_size), + torch.empty([], dtype=torch.float), + torch.empty([], dtype=torch.float), + ) @register_fake("torchcodec_ns::get_frames_in_range") diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index cc7b5011..0ac06f65 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -116,7 +116,7 @@ def test_get_frames_at_indices(self): decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder) add_video_stream(decoder) - frames0and180 = get_frames_at_indices( + frames0and180, *_ = get_frames_at_indices( decoder, stream_index=3, frame_indices=[0, 180] ) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) @@ -425,7 +425,7 @@ def test_color_conversion_library_with_dimension_order( assert frames.shape[1:] == expected_shape assert_tensor_equal(frames[0], frame0_ref) - frames = get_frames_at_indices( + frames, *_ = get_frames_at_indices( decoder, stream_index=stream_index, frame_indices=[0, 1, 3, 4] ) assert frames.shape[1:] == expected_shape From 61b493758c8b2b90ec44a1976ca3d12989e0123b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 10:21:38 +0100 Subject: [PATCH 02/27] Add deduplication logic --- 
.../decoders/_core/VideoDecoder.cpp | 26 ++++++++++++++----- src/torchcodec/decoders/_core/VideoDecoder.h | 3 ++- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 132957c6..a1e1afc8 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1030,7 +1030,8 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex( VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( int streamIndex, - const std::vector& frameIndices) { + const std::vector& frameIndices, + const bool sortIndices) { validateUserProvidedStreamIndex(streamIndex); validateScannedAllStreams("getFramesAtIndices"); @@ -1039,21 +1040,32 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( const auto& options = stream.options; BatchDecodedOutput output(frameIndices.size(), options, streamMetadata); + auto previousFrameIndex = -1; for (auto f = 0; f < frameIndices.size(); ++f) { auto frameIndex = frameIndices[f]; if (frameIndex < 0 || frameIndex >= stream.allFrames.size()) { throw std::runtime_error( "Invalid frame index=" + std::to_string(frameIndex)); } - DecodedOutput singleOut = - getFrameAtIndex(streamIndex, frameIndex, output.frames[f]); - if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { - output.frames[f] = singleOut.frame; + if ((f > 0) && (frameIndex == previousFrameIndex)) { + // Avoid decoding the same frame twice + output.frames[f].copy_(output.frames[f - 1]); + output.ptsSeconds[f] = output.ptsSeconds[f - 1]; + output.durationSeconds[f] = output.durationSeconds[f - 1]; + } else { + DecodedOutput singleOut = + getFrameAtIndex(streamIndex, frameIndex, output.frames[f]); + if (options.colorConversionLibrary == + ColorConversionLibrary::FILTERGRAPH) { + output.frames[f] = singleOut.frame; + } + output.ptsSeconds[f] = singleOut.ptsSeconds; + output.durationSeconds[f] = 
singleOut.durationSeconds; } - output.ptsSeconds[f] = singleOut.ptsSeconds; - output.durationSeconds[f] = singleOut.durationSeconds; + previousFrameIndex = frameIndex; } output.frames = MaybePermuteHWC2CHW(options, output.frames); + return output; } diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 2adbfac6..6c401a70 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -241,7 +241,8 @@ class VideoDecoder { // Tensor. BatchDecodedOutput getFramesAtIndices( int streamIndex, - const std::vector& frameIndices); + const std::vector& frameIndices, + const bool sortIndices = false); // Returns frames within a given range for a given stream as a single stacked // Tensor. The range is defined by [start, stop). The values retrieved from // the range are: From f7a70ba53a3d66b36e2a868ba68cd3d48f3340a8 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 11:45:36 +0100 Subject: [PATCH 03/27] Added sorting logic --- .../decoders/_core/VideoDecoder.cpp | 44 ++++++++++++------- .../decoders/_core/VideoDecoderOps.cpp | 8 ++-- .../decoders/_core/VideoDecoderOps.h | 3 +- .../decoders/_core/video_decoder_ops.py | 1 + test/decoders/test_video_decoder_ops.py | 33 ++++++++++++++ 5 files changed, 70 insertions(+), 19 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index a1e1afc8..eb07ed8e 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1040,32 +1040,46 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( const auto& options = stream.options; BatchDecodedOutput output(frameIndices.size(), options, streamMetadata); - auto previousFrameIndex = -1; + std::vector argsort(frameIndices.size()); + for (size_t i = 0; i < argsort.size(); ++i) { + argsort[i] = i; + } + if (sortIndices) { + std::sort( + 
argsort.begin(), argsort.end(), [&frameIndices](size_t a, size_t b) { + return frameIndices[a] < frameIndices[b]; + }); + } + + auto previousIndexInVideo = -1; for (auto f = 0; f < frameIndices.size(); ++f) { - auto frameIndex = frameIndices[f]; - if (frameIndex < 0 || frameIndex >= stream.allFrames.size()) { + auto indexInOutput = argsort[f]; + auto indexInVideo = frameIndices[argsort[f]]; + if (indexInVideo < 0 || indexInVideo >= stream.allFrames.size()) { throw std::runtime_error( - "Invalid frame index=" + std::to_string(frameIndex)); + "Invalid frame index=" + std::to_string(indexInVideo)); } - if ((f > 0) && (frameIndex == previousFrameIndex)) { + if ((f > 0) && (indexInVideo == previousIndexInVideo)) { // Avoid decoding the same frame twice - output.frames[f].copy_(output.frames[f - 1]); - output.ptsSeconds[f] = output.ptsSeconds[f - 1]; - output.durationSeconds[f] = output.durationSeconds[f - 1]; + auto previousIndexInOutput = argsort[f - 1]; + output.frames[indexInOutput].copy_(output.frames[previousIndexInOutput]); + output.ptsSeconds[indexInOutput] = + output.ptsSeconds[previousIndexInOutput]; + output.durationSeconds[indexInOutput] = + output.durationSeconds[previousIndexInOutput]; } else { - DecodedOutput singleOut = - getFrameAtIndex(streamIndex, frameIndex, output.frames[f]); + DecodedOutput singleOut = getFrameAtIndex( + streamIndex, indexInVideo, output.frames[indexInOutput]); if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { - output.frames[f] = singleOut.frame; + output.frames[indexInOutput] = singleOut.frame; } - output.ptsSeconds[f] = singleOut.ptsSeconds; - output.durationSeconds[f] = singleOut.durationSeconds; + output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds; + output.durationSeconds[indexInOutput] = singleOut.durationSeconds; } - previousFrameIndex = frameIndex; + previousIndexInVideo = indexInVideo; } output.frames = MaybePermuteHWC2CHW(options, output.frames); - return output; } diff --git 
a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 0be871a3..a48a2113 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -40,7 +40,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)"); m.def( - "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)"); + "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices, bool sort_indices=False) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)"); m.def( @@ -221,11 +221,13 @@ OpsDecodedOutput get_frame_at_index( OpsBatchDecodedOutput get_frames_at_indices( at::Tensor& decoder, int64_t stream_index, - at::IntArrayRef frame_indices) { + at::IntArrayRef frame_indices, + bool sort_indices) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); std::vector frameIndicesVec( frame_indices.begin(), frame_indices.end()); - auto result = videoDecoder->getFramesAtIndices(stream_index, frameIndicesVec); + auto result = videoDecoder->getFramesAtIndices( + stream_index, frameIndicesVec, sort_indices); return makeOpsBatchDecodedOutput(result); } diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 5b442025..2ec49d94 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -90,7 +90,8 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder); OpsBatchDecodedOutput get_frames_at_indices( at::Tensor& decoder, int64_t stream_index, - at::IntArrayRef frame_indices); + at::IntArrayRef frame_indices, + bool sort_indices = false); // Return the frames inside a range as a single stacked Tensor. 
The range is // defined as [start, stop). diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 01de6ad6..6335f997 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -190,6 +190,7 @@ def get_frames_at_indices_abstract( *, stream_index: int, frame_indices: List[int], + sort_indices: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] return ( diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 0ac06f65..dc875c83 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -124,6 +124,39 @@ def test_get_frames_at_indices(self): assert_tensor_equal(frames0and180[0], reference_frame0) assert_tensor_equal(frames0and180[1], reference_frame180) + @pytest.mark.parametrize("sort_indices", (False, True)) + def test_get_frames_at_indices_with_sort(self, sort_indices): + decoder = create_from_file(str(NASA_VIDEO.path)) + _add_video_stream(decoder) + scan_all_streams_to_update_metadata(decoder) + stream_index = 3 + + frame_indices = [2, 0, 1, 0, 2] + + expected_frames = [ + get_frame_at_index( + decoder, stream_index=stream_index, frame_index=frame_index + )[0] + for frame_index in frame_indices + ] + + frames, *_ = get_frames_at_indices( + decoder, + stream_index=stream_index, + frame_indices=frame_indices, + sort_indices=sort_indices, + ) + for frame, expected_frame in zip(frames, expected_frames): + assert_tensor_equal(frame, expected_frame) + + # first and last frame should be equal, at index 2. We then modify the + # first frame and assert that it's now different from the last frame. + # This ensures a copy was properly made during the de-duplication logic. 
+ assert_tensor_equal(frames[0], frames[-1]) + frames[0] += 20 + with pytest.raises(AssertionError): + assert_tensor_equal(frames[0], frames[-1]) + def test_get_frames_in_range(self): decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder) From 133c2131cd2c51afb65954fa333ac39112601c99 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 11:49:34 +0100 Subject: [PATCH 04/27] minor opt --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index eb07ed8e..739be558 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1040,11 +1040,13 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( const auto& options = stream.options; BatchDecodedOutput output(frameIndices.size(), options, streamMetadata); - std::vector argsort(frameIndices.size()); - for (size_t i = 0; i < argsort.size(); ++i) { - argsort[i] = i; - } + std::vector argsort; + if (sortIndices) { + argsort.resize(frameIndices.size()); + for (size_t i = 0; i < argsort.size(); ++i) { + argsort[i] = i; + } std::sort( argsort.begin(), argsort.end(), [&frameIndices](size_t a, size_t b) { return frameIndices[a] < frameIndices[b]; @@ -1053,15 +1055,15 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( auto previousIndexInVideo = -1; for (auto f = 0; f < frameIndices.size(); ++f) { - auto indexInOutput = argsort[f]; - auto indexInVideo = frameIndices[argsort[f]]; + auto indexInOutput = sortIndices ? 
argsort[f] : f; + auto indexInVideo = frameIndices[indexInOutput]; if (indexInVideo < 0 || indexInVideo >= stream.allFrames.size()) { throw std::runtime_error( "Invalid frame index=" + std::to_string(indexInVideo)); } if ((f > 0) && (indexInVideo == previousIndexInVideo)) { // Avoid decoding the same frame twice - auto previousIndexInOutput = argsort[f - 1]; + auto previousIndexInOutput = sortIndices ? argsort[f - 1] : f - 1; output.frames[indexInOutput].copy_(output.frames[previousIndexInOutput]); output.ptsSeconds[indexInOutput] = output.ptsSeconds[previousIndexInOutput]; From f391582d867b0ed30dee560a08a6ed3683824a0f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 11:59:11 +0100 Subject: [PATCH 05/27] Comments --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 739be558..7a179efe 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1040,8 +1040,11 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( const auto& options = stream.options; BatchDecodedOutput output(frameIndices.size(), options, streamMetadata); + // if frameIndices is [13, 10, 12, 11] + // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want + // to use to decode the frames + // and argsort is [ 1, 3, 2, 0] std::vector argsort; - if (sortIndices) { argsort.resize(frameIndices.size()); for (size_t i = 0; i < argsort.size(); ++i) { From d475890cf11735f0fba3480a20e34b79f8ac08be Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 13:32:38 +0100 Subject: [PATCH 06/27] scaffolding --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 7 +++++++ src/torchcodec/decoders/_core/VideoDecoder.h | 7 +++++++ .../decoders/_core/VideoDecoderOps.cpp | 17 +++++++++++++++++ src/torchcodec/decoders/_core/VideoDecoderOps.h | 
10 ++++++++-- src/torchcodec/decoders/_core/__init__.py | 1 + .../decoders/_core/video_decoder_ops.py | 16 ++++++++++++++++ 6 files changed, 56 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 7a179efe..90276f79 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1088,6 +1088,13 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( return output; } +VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtPtss( + int streamIndex, + const std::vector& framePtss, + const bool sortPtss) { + return getFramesAtIndices(streamIndex, framePtss, sortPtss); + } + VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( int streamIndex, int64_t start, diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 6c401a70..7f34ff12 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -223,6 +223,7 @@ class VideoDecoder { // i.e. it will be returned when this function is called with seconds=5.0 or // seconds=5.999, etc. DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds); + DecodedOutput getFrameAtIndex( int streamIndex, int64_t frameIndex, @@ -243,6 +244,12 @@ class VideoDecoder { int streamIndex, const std::vector& frameIndices, const bool sortIndices = false); + + BatchDecodedOutput getFramesAtPtss( + int streamIndex, + const std::vector& framePtss, + const bool sortPtss = false); + // Returns frames within a given range for a given stream as a single stacked // Tensor. The range is defined by [start, stop). 
The values retrieved from // the range are: diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index a48a2113..76b457c9 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -41,6 +41,8 @@ TORCH_LIBRARY(torchcodec_ns, m) { "get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices, bool sort_indices=False) -> (Tensor, Tensor, Tensor)"); + m.def( + "get_frames_at_ptss(Tensor(a!) decoder, *, int stream_index, int[] frame_ptss, bool sort_ptss=False) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)"); m.def( @@ -209,6 +211,20 @@ OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds) { return makeOpsDecodedOutput(result); } +OpsBatchDecodedOutput get_frames_at_ptss( + at::Tensor& decoder, + int64_t stream_index, + at::IntArrayRef frame_ptss, + bool sort_ptss) { + auto videoDecoder = unwrapTensorToGetDecoder(decoder); + std::vector framePtssVec( + frame_ptss.begin(), frame_ptss.end()); + auto result = videoDecoder->getFramesAtPtss( + stream_index, framePtssVec, sort_ptss); + return makeOpsBatchDecodedOutput(result); +} + + OpsDecodedOutput get_frame_at_index( at::Tensor& decoder, int64_t stream_index, @@ -485,6 +501,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl("get_frame_at_pts", &get_frame_at_pts); m.impl("get_frame_at_index", &get_frame_at_index); m.impl("get_frames_at_indices", &get_frames_at_indices); + m.impl("get_frames_at_ptss", &get_frames_at_ptss); m.impl("get_frames_in_range", &get_frames_in_range); m.impl("get_frames_by_pts_in_range", &get_frames_by_pts_in_range); m.impl("_test_frame_pts_equality", &_test_frame_pts_equality); diff --git 
a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 2ec49d94..0da1edd0 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -75,6 +75,13 @@ using OpsBatchDecodedOutput = std::tuple; // given timestamp T has T >= PTS and T < PTS + Duration. OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds); +// Return the frames at given ptss for a given stream +OpsBatchDecodedOutput get_frames_at_ptss( + at::Tensor& decoder, + int64_t stream_index, + at::IntArrayRef frame_ptss, + bool sort_ptss = false); + // Return the frame that is visible at a given index in the video. OpsDecodedOutput get_frame_at_index( at::Tensor& decoder, @@ -85,8 +92,7 @@ OpsDecodedOutput get_frame_at_index( // duration as tensors. OpsDecodedOutput get_next_frame(at::Tensor& decoder); -// Return the frames at a given index for a given stream as a single stacked -// Tensor. +// Return the frames at given indices for a given stream OpsBatchDecodedOutput get_frames_at_indices( at::Tensor& decoder, int64_t stream_index, diff --git a/src/torchcodec/decoders/_core/__init__.py b/src/torchcodec/decoders/_core/__init__.py index bd761fe1..543cb74f 100644 --- a/src/torchcodec/decoders/_core/__init__.py +++ b/src/torchcodec/decoders/_core/__init__.py @@ -22,6 +22,7 @@ get_frame_at_index, get_frame_at_pts, get_frames_at_indices, + get_frames_at_ptss, get_frames_by_pts_in_range, get_frames_in_range, get_json_metadata, diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 6335f997..2f0036ab 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -71,6 +71,7 @@ def load_torchcodec_extension(): get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default 
get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default +get_frames_at_ptss = torch.ops.torchcodec_ns.get_frames_at_ptss.default get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default @@ -171,6 +172,21 @@ def get_frame_at_pts_abstract( torch.empty([], dtype=torch.float), ) +@register_fake("torchcodec_ns::get_frames_at_ptss") +def get_frames_at_pts_abstract( + decoder: torch.Tensor, + *, + stream_index: int, + frame_ptss: List[int], + sort_ptss: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + image_size = [get_ctx().new_dynamic_size() for _ in range(4)] + return ( + torch.empty(image_size), + torch.empty([], dtype=torch.float), + torch.empty([], dtype=torch.float), + ) + @register_fake("torchcodec_ns::get_frame_at_index") def get_frame_at_index_abstract( From 14e287604e05d17d3f59987b7a3d289ebe5b5516 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 14:16:39 +0100 Subject: [PATCH 07/27] Added logic --- .../decoders/_core/VideoDecoder.cpp | 39 +++++++++++++++++-- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- .../decoders/_core/VideoDecoderOps.cpp | 12 +++--- .../decoders/_core/VideoDecoderOps.h | 2 +- .../decoders/_core/video_decoder_ops.py | 3 +- 5 files changed, 45 insertions(+), 13 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 90276f79..f71f2c6e 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1090,10 +1090,43 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtPtss( int streamIndex, - const std::vector& framePtss, + const std::vector& framePtss, const bool sortPtss) { - return 
getFramesAtIndices(streamIndex, framePtss, sortPtss); - } + validateUserProvidedStreamIndex(streamIndex); + validateScannedAllStreams("getFramesAtPtss"); + + // The frame displayed at timestamp t and the one displayed at timestamp `t + + // eps` are probably the same frame, with the same index. The easiest way to + // avoid decoding that unique frame twice is to convert the input timestamps + // to indices, and leverage the de-duplication logic of getFramesAtIndices. + + const auto& streamMetadata = containerMetadata_.streams[streamIndex]; + const auto& stream = streams_[streamIndex]; + double minSeconds = streamMetadata.minPtsSecondsFromScan.value(); + double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value(); + + std::vector frameIndices(framePtss.size()); + for (auto i = 0; i < framePtss.size(); ++i) { + auto framePts = framePtss[i]; + TORCH_CHECK( + framePts >= minSeconds && framePts < maxSeconds, + "frame pts is " + std::to_string(framePts) + "; must be in range [" + + std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) + + ")."); + + auto it = std::lower_bound( + stream.allFrames.begin(), + stream.allFrames.end(), + framePts, + [&stream](const FrameInfo& info, double start) { + return ptsToSeconds(info.nextPts, stream.timeBase) <= start; + }); + int64_t frameIndex = it - stream.allFrames.begin(); + frameIndices[i] = frameIndex; + } + + return getFramesAtIndices(streamIndex, frameIndices, sortPtss); +} VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( int streamIndex, diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 7f34ff12..02d9c10c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -247,7 +247,7 @@ class VideoDecoder { BatchDecodedOutput getFramesAtPtss( int streamIndex, - const std::vector& framePtss, + const std::vector& framePtss, const bool sortPtss = false); // Returns frames within a given range 
for a given stream as a single stacked diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 76b457c9..dd6db3b4 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -42,7 +42,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices, bool sort_indices=False) -> (Tensor, Tensor, Tensor)"); m.def( - "get_frames_at_ptss(Tensor(a!) decoder, *, int stream_index, int[] frame_ptss, bool sort_ptss=False) -> (Tensor, Tensor, Tensor)"); + "get_frames_at_ptss(Tensor(a!) decoder, *, int stream_index, float[] frame_ptss, bool sort_ptss=False) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)"); m.def( @@ -214,17 +214,15 @@ OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds) { OpsBatchDecodedOutput get_frames_at_ptss( at::Tensor& decoder, int64_t stream_index, - at::IntArrayRef frame_ptss, + at::ArrayRef frame_ptss, bool sort_ptss) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); - std::vector framePtssVec( - frame_ptss.begin(), frame_ptss.end()); - auto result = videoDecoder->getFramesAtPtss( - stream_index, framePtssVec, sort_ptss); + std::vector framePtssVec(frame_ptss.begin(), frame_ptss.end()); + auto result = + videoDecoder->getFramesAtPtss(stream_index, framePtssVec, sort_ptss); return makeOpsBatchDecodedOutput(result); } - OpsDecodedOutput get_frame_at_index( at::Tensor& decoder, int64_t stream_index, diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 0da1edd0..167387f0 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -79,7 +79,7 @@ OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double 
seconds); OpsBatchDecodedOutput get_frames_at_ptss( at::Tensor& decoder, int64_t stream_index, - at::IntArrayRef frame_ptss, + at::ArrayRef frame_ptss, bool sort_ptss = false); // Return the frame that is visible at a given index in the video. diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 2f0036ab..b80f9122 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -172,12 +172,13 @@ def get_frame_at_pts_abstract( torch.empty([], dtype=torch.float), ) + @register_fake("torchcodec_ns::get_frames_at_ptss") def get_frames_at_pts_abstract( decoder: torch.Tensor, *, stream_index: int, - frame_ptss: List[int], + frame_ptss: List[float], sort_ptss: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] From 9c9e4627d2b37da2b7a509d50d3ee897fa29d8dd Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 14:21:38 +0100 Subject: [PATCH 08/27] Added test --- test/decoders/test_video_decoder_ops.py | 32 +++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index dc875c83..f404ed9c 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -27,6 +27,7 @@ get_frame_at_index, get_frame_at_pts, get_frames_at_indices, + get_frames_at_ptss, get_frames_by_pts_in_range, get_frames_in_range, get_json_metadata, @@ -157,6 +158,37 @@ def test_get_frames_at_indices_with_sort(self, sort_indices): with pytest.raises(AssertionError): assert_tensor_equal(frames[0], frames[-1]) + @pytest.mark.parametrize("sort_ptss", (False, True)) + def test_get_frames_at_ptss_with_sort(self, sort_ptss): + decoder = create_from_file(str(NASA_VIDEO.path)) + _add_video_stream(decoder) + scan_all_streams_to_update_metadata(decoder) + stream_index = 3 
+ + frame_ptss = [2, 0, 1, 0 + 1e-3, 2 + 1e-3] + + expected_frames = [ + get_frame_at_pts(decoder, seconds=pts)[0] for pts in frame_ptss + ] + + frames, *_ = get_frames_at_ptss( + decoder, + stream_index=stream_index, + frame_ptss=frame_ptss, + sort_ptss=sort_ptss, + ) + for frame, expected_frame in zip(frames, expected_frames): + assert_tensor_equal(frame, expected_frame) + + # first and last frame should be equal, at pts=2 [+ eps]. We then + # modify the first frame and assert that it's now different from the + # last frame. This ensures a copy was properly made during the + # de-duplication logic. + assert_tensor_equal(frames[0], frames[-1]) + frames[0] += 20 + with pytest.raises(AssertionError): + assert_tensor_equal(frames[0], frames[-1]) + def test_get_frames_in_range(self): decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder) From 2bce920bbfad12e4b5bfb6e52205d6629585f9e2 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 17:08:50 +0100 Subject: [PATCH 09/27] Use C++ decoding APIs in sampler --- src/torchcodec/samplers/_common.py | 19 ------ src/torchcodec/samplers/_index_based.py | 82 ++++++++----------------- src/torchcodec/samplers/_time_based.py | 82 ++++++++----------------- test/samplers/test_samplers.py | 2 +- 4 files changed, 53 insertions(+), 132 deletions(-) diff --git a/src/torchcodec/samplers/_common.py b/src/torchcodec/samplers/_common.py index 46bf3b18..bcf8f675 100644 --- a/src/torchcodec/samplers/_common.py +++ b/src/torchcodec/samplers/_common.py @@ -1,8 +1,5 @@ from typing import Callable, Union -import torch -from torchcodec import Frame, FrameBatch - _LIST_OF_INT_OR_FLOAT = Union[list[int], list[float]] @@ -42,22 +39,6 @@ def _error_policy( } -def _chunk_list(lst, chunk_size): - # return list of sublists of length chunk_size - return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] - - -def _to_framebatch(frames: list[Frame]) -> FrameBatch: - # IMPORTANT: see 
other IMPORTANT note in _decode_all_clips_indices and - # _decode_all_clips_timestamps - data = torch.stack([frame.data for frame in frames]) - pts_seconds = torch.tensor([frame.pts_seconds for frame in frames]) - duration_seconds = torch.tensor([frame.duration_seconds for frame in frames]) - return FrameBatch( - data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds - ) - - def _validate_common_params(*, decoder, num_frames_per_clip, policy): if len(decoder) < 1: raise ValueError( diff --git a/src/torchcodec/samplers/_index_based.py b/src/torchcodec/samplers/_index_based.py index 25e4bd32..d33b6279 100644 --- a/src/torchcodec/samplers/_index_based.py +++ b/src/torchcodec/samplers/_index_based.py @@ -1,14 +1,13 @@ -from typing import List, Literal, Optional +from typing import Literal, Optional import torch -from torchcodec import Frame, FrameBatch +from torchcodec import FrameBatch from torchcodec.decoders import VideoDecoder +from torchcodec.decoders._core import get_frames_at_indices from torchcodec.samplers._common import ( - _chunk_list, _POLICY_FUNCTION_TYPE, _POLICY_FUNCTIONS, - _to_framebatch, _validate_common_params, ) @@ -117,51 +116,6 @@ def _build_all_clips_indices( return all_clips_indices -def _decode_all_clips_indices( - decoder: VideoDecoder, all_clips_indices: list[int], num_frames_per_clip: int -) -> list[FrameBatch]: - # This takes the list of all the frames to decode (in arbitrary order), - # decode all the frames, and then packs them into clips of length - # num_frames_per_clip. 
- # - # To avoid backwards seeks (which are slow), we: - # - sort all the frame indices to be decoded - # - dedup them - # - decode all unique frames in sorted order - # - re-assemble the decoded frames back to their original order - # - # TODO: Write this in C++ so we can avoid the copies that happen in `_to_framebatch` - - all_clips_indices_sorted, argsort = zip( - *sorted((frame_index, i) for (i, frame_index) in enumerate(all_clips_indices)) - ) - previous_decoded_frame = None - all_decoded_frames = [None] * len(all_clips_indices) - for i, j in enumerate(argsort): - frame_index = all_clips_indices_sorted[i] - if ( - previous_decoded_frame is not None # then we know i > 0 - and frame_index == all_clips_indices_sorted[i - 1] - ): - # Avoid decoding the same frame twice. - # IMPORTANT: this is only correct because a copy of the frame will - # happen within `_to_framebatch` when we call torch.stack. - # If a copy isn't made, the same underlying memory will be used for - # the 2 consecutive frames. When we re-write this, we should make - # sure to explicitly copy the data. - decoded_frame = previous_decoded_frame - else: - decoded_frame = decoder.get_frame_at(index=frame_index) - previous_decoded_frame = decoded_frame - all_decoded_frames[j] = decoded_frame - - all_clips: list[list[Frame]] = _chunk_list( - all_decoded_frames, chunk_size=num_frames_per_clip - ) - - return [_to_framebatch(clip) for clip in all_clips] - - def _generic_index_based_sampler( kind: Literal["random", "regular"], decoder: VideoDecoder, @@ -174,7 +128,7 @@ def _generic_index_based_sampler( # Important note: sampling_range_end defines the upper bound of where a clip # can *start*, not where a clip can end. 
policy: Literal["repeat_last", "wrap", "error"], -) -> List[FrameBatch]: +) -> FrameBatch: _validate_common_params( decoder=decoder, @@ -221,11 +175,27 @@ def _generic_index_based_sampler( num_frames_in_video=len(decoder), policy_fun=_POLICY_FUNCTIONS[policy], ) - return _decode_all_clips_indices( - decoder, - all_clips_indices=all_clips_indices, - num_frames_per_clip=num_frames_per_clip, + + frames, pts_seconds, duration_seconds = get_frames_at_indices( + decoder._decoder, + stream_index=decoder.stream_index, + frame_indices=all_clips_indices, + sort_indices=True, + ) + last_3_dims = frames.shape[-3:] + out = FrameBatch( + data=frames.view(num_clips, num_frames_per_clip, *last_3_dims), + pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip), + duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip), ) + return [ + FrameBatch( + out.data[i], + out.pts_seconds[i], + out.duration_seconds[i], + ) + for i in range(out.data.shape[0]) + ] def clips_at_random_indices( @@ -237,7 +207,7 @@ def clips_at_random_indices( sampling_range_start: int = 0, sampling_range_end: Optional[int] = None, # interval is [start, end). policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", -) -> List[FrameBatch]: +) -> FrameBatch: return _generic_index_based_sampler( kind="random", decoder=decoder, @@ -259,7 +229,7 @@ def clips_at_regular_indices( sampling_range_start: int = 0, sampling_range_end: Optional[int] = None, # interval is [start, end). 
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", -) -> List[FrameBatch]: +) -> FrameBatch: return _generic_index_based_sampler( kind="regular", diff --git a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py index f890d216..c10d7913 100644 --- a/src/torchcodec/samplers/_time_based.py +++ b/src/torchcodec/samplers/_time_based.py @@ -2,13 +2,11 @@ import torch -from torchcodec import Frame, FrameBatch -from torchcodec.decoders import VideoDecoder +from torchcodec import FrameBatch +from torchcodec.decoders._core import get_frames_at_ptss from torchcodec.samplers._common import ( - _chunk_list, _POLICY_FUNCTION_TYPE, _POLICY_FUNCTIONS, - _to_framebatch, _validate_common_params, ) @@ -147,51 +145,6 @@ def _build_all_clips_timestamps( return all_clips_timestamps -def _decode_all_clips_timestamps( - decoder: VideoDecoder, all_clips_timestamps: list[float], num_frames_per_clip: int -) -> list[FrameBatch]: - # This is 99% the same as _decode_all_clips_indices. The only change is the - # call to .get_frame_displayed_at(pts) instead of .get_frame_at(idx) - - all_clips_timestamps_sorted, argsort = zip( - *sorted( - (frame_index, i) for (i, frame_index) in enumerate(all_clips_timestamps) - ) - ) - previous_decoded_frame = None - all_decoded_frames = [None] * len(all_clips_timestamps) - for i, j in enumerate(argsort): - frame_pts_seconds = all_clips_timestamps_sorted[i] - if ( - previous_decoded_frame is not None # then we know i > 0 - and frame_pts_seconds == all_clips_timestamps_sorted[i - 1] - ): - # Avoid decoding the same frame twice. - # Unfortunatly this is unlikely to lead to speed-up as-is: it's - # pretty unlikely that 2 pts will be the same since pts are float - # contiguous values. Theoretically the dedup can still happen, but - # it would be much more efficient to implement it at the frame index - # level. We should do that once we implement that in C++. - # See also https://github.com/pytorch/torchcodec/issues/256. 
- # - # IMPORTANT: this is only correct because a copy of the frame will - # happen within `_to_framebatch` when we call torch.stack. - # If a copy isn't made, the same underlying memory will be used for - # the 2 consecutive frames. When we re-write this, we should make - # sure to explicitly copy the data. - decoded_frame = previous_decoded_frame - else: - decoded_frame = decoder.get_frame_displayed_at(seconds=frame_pts_seconds) - previous_decoded_frame = decoded_frame - all_decoded_frames[j] = decoded_frame - - all_clips: list[list[Frame]] = _chunk_list( - all_decoded_frames, chunk_size=num_frames_per_clip - ) - - return [_to_framebatch(clip) for clip in all_clips] - - def _generic_time_based_sampler( kind: Literal["random", "regular"], decoder, @@ -204,7 +157,7 @@ def _generic_time_based_sampler( sampling_range_start: Optional[float], sampling_range_end: Optional[float], # interval is [start, end). policy: str = "repeat_last", -) -> List[FrameBatch]: +) -> FrameBatch: # Note: *everywhere*, sampling_range_end denotes the upper bound of where a # clip can start. This is an *open* upper bound, i.e. we will make sure no # clip starts exactly at (or above) sampling_range_end. 
@@ -246,6 +199,7 @@ def _generic_time_based_sampler( sampling_range_end, # excluded seconds_between_clip_starts, ) + num_clips = len(clip_start_seconds) all_clips_timestamps = _build_all_clips_timestamps( clip_start_seconds=clip_start_seconds, @@ -255,11 +209,27 @@ def _generic_time_based_sampler( policy_fun=_POLICY_FUNCTIONS[policy], ) - return _decode_all_clips_timestamps( - decoder, - all_clips_timestamps=all_clips_timestamps, - num_frames_per_clip=num_frames_per_clip, + frames, pts_seconds, duration_seconds = get_frames_at_ptss( + decoder._decoder, + stream_index=decoder.stream_index, + frame_ptss=all_clips_timestamps, + sort_ptss=True, ) + last_3_dims = frames.shape[-3:] + + out = FrameBatch( + data=frames.view(num_clips, num_frames_per_clip, *last_3_dims), + pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip), + duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip), + ) + return [ + FrameBatch( + out.data[i], + out.pts_seconds[i], + out.duration_seconds[i], + ) + for i in range(out.data.shape[0]) + ] def clips_at_random_timestamps( @@ -272,7 +242,7 @@ def clips_at_random_timestamps( sampling_range_start: Optional[float] = None, sampling_range_end: Optional[float] = None, # interval is [start, end). policy: str = "repeat_last", -) -> List[FrameBatch]: +) -> FrameBatch: return _generic_time_based_sampler( kind="random", decoder=decoder, @@ -296,7 +266,7 @@ def clips_at_regular_timestamps( sampling_range_start: Optional[float] = None, sampling_range_end: Optional[float] = None, # interval is [start, end). 
policy: str = "repeat_last", -) -> List[FrameBatch]: +) -> FrameBatch: return _generic_time_based_sampler( kind="regular", decoder=decoder, diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 3149a541..bae359f8 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -130,7 +130,7 @@ def test_time_based_sampler(sampler, seconds_between_frames): if sampler.func is clips_at_regular_timestamps: seconds_between_clip_starts = sampler.keywords["seconds_between_clip_starts"] expected_seconds_between_clip_starts = torch.tensor( - [seconds_between_clip_starts] * (len(clips) - 1), dtype=torch.float + [seconds_between_clip_starts] * (len(clips) - 1), dtype=torch.float64 ) _assert_regular_sampler( clips=clips, From b8284cc84b575709c87a431bfb339ea7a943dbfd Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 22 Oct 2024 17:17:45 +0100 Subject: [PATCH 10/27] Remove parameter, just sort if not already sorted --- .../decoders/_core/VideoDecoder.cpp | 20 ++++++++++--------- src/torchcodec/decoders/_core/VideoDecoder.h | 3 +-- .../decoders/_core/VideoDecoderOps.cpp | 8 +++----- .../decoders/_core/VideoDecoderOps.h | 3 +-- .../decoders/_core/video_decoder_ops.py | 1 - test/decoders/test_video_decoder_ops.py | 4 +--- 6 files changed, 17 insertions(+), 22 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 7a179efe..416d96a9 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1030,8 +1030,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex( VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( int streamIndex, - const std::vector& frameIndices, - const bool sortIndices) { + const std::vector& frameIndices) { validateUserProvidedStreamIndex(streamIndex); validateScannedAllStreams("getFramesAtIndices"); @@ -1040,12 +1039,15 @@ VideoDecoder::BatchDecodedOutput 
VideoDecoder::getFramesAtIndices( const auto& options = stream.options; BatchDecodedOutput output(frameIndices.size(), options, streamMetadata); - // if frameIndices is [13, 10, 12, 11] - // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want - // to use to decode the frames - // and argsort is [ 1, 3, 2, 0] + auto indicesAreSorted = + std::is_sorted(frameIndices.begin(), frameIndices.end()); + std::vector argsort; - if (sortIndices) { + if (!indicesAreSorted) { + // if frameIndices is [13, 10, 12, 11] + // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want + // to use to decode the frames + // and argsort is [ 1, 3, 2, 0] argsort.resize(frameIndices.size()); for (size_t i = 0; i < argsort.size(); ++i) { argsort[i] = i; @@ -1058,7 +1060,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( auto previousIndexInVideo = -1; for (auto f = 0; f < frameIndices.size(); ++f) { - auto indexInOutput = sortIndices ? argsort[f] : f; + auto indexInOutput = indicesAreSorted ? f : argsort[f]; auto indexInVideo = frameIndices[indexInOutput]; if (indexInVideo < 0 || indexInVideo >= stream.allFrames.size()) { throw std::runtime_error( @@ -1066,7 +1068,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( } if ((f > 0) && (indexInVideo == previousIndexInVideo)) { // Avoid decoding the same frame twice - auto previousIndexInOutput = sortIndices ? argsort[f - 1] : f - 1; + auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1]; output.frames[indexInOutput].copy_(output.frames[previousIndexInOutput]); output.ptsSeconds[indexInOutput] = output.ptsSeconds[previousIndexInOutput]; diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 6c401a70..2adbfac6 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -241,8 +241,7 @@ class VideoDecoder { // Tensor. 
BatchDecodedOutput getFramesAtIndices( int streamIndex, - const std::vector& frameIndices, - const bool sortIndices = false); + const std::vector& frameIndices); // Returns frames within a given range for a given stream as a single stacked // Tensor. The range is defined by [start, stop). The values retrieved from // the range are: diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index a48a2113..0be871a3 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -40,7 +40,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)"); m.def( - "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices, bool sort_indices=False) -> (Tensor, Tensor, Tensor)"); + "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? 
step=None) -> (Tensor, Tensor, Tensor)"); m.def( @@ -221,13 +221,11 @@ OpsDecodedOutput get_frame_at_index( OpsBatchDecodedOutput get_frames_at_indices( at::Tensor& decoder, int64_t stream_index, - at::IntArrayRef frame_indices, - bool sort_indices) { + at::IntArrayRef frame_indices) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); std::vector frameIndicesVec( frame_indices.begin(), frame_indices.end()); - auto result = videoDecoder->getFramesAtIndices( - stream_index, frameIndicesVec, sort_indices); + auto result = videoDecoder->getFramesAtIndices(stream_index, frameIndicesVec); return makeOpsBatchDecodedOutput(result); } diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 2ec49d94..5b442025 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -90,8 +90,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder); OpsBatchDecodedOutput get_frames_at_indices( at::Tensor& decoder, int64_t stream_index, - at::IntArrayRef frame_indices, - bool sort_indices = false); + at::IntArrayRef frame_indices); // Return the frames inside a range as a single stacked Tensor. The range is // defined as [start, stop). 
diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 6335f997..01de6ad6 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -190,7 +190,6 @@ def get_frames_at_indices_abstract( *, stream_index: int, frame_indices: List[int], - sort_indices: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] return ( diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index dc875c83..bbd9fe4e 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -124,8 +124,7 @@ def test_get_frames_at_indices(self): assert_tensor_equal(frames0and180[0], reference_frame0) assert_tensor_equal(frames0and180[1], reference_frame180) - @pytest.mark.parametrize("sort_indices", (False, True)) - def test_get_frames_at_indices_with_sort(self, sort_indices): + def test_get_frames_at_indices_unsorted_indices(self): decoder = create_from_file(str(NASA_VIDEO.path)) _add_video_stream(decoder) scan_all_streams_to_update_metadata(decoder) @@ -144,7 +143,6 @@ def test_get_frames_at_indices_with_sort(self, sort_indices): decoder, stream_index=stream_index, frame_indices=frame_indices, - sort_indices=sort_indices, ) for frame, expected_frame in zip(frames, expected_frames): assert_tensor_equal(frame, expected_frame) From 4dda5b78b7ddf039c7f94774db80a4b6bc0b0798 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 11:10:32 +0100 Subject: [PATCH 11/27] Rename --- .../decoders/_core/VideoDecoder.cpp | 6 ++--- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- .../decoders/_core/VideoDecoderOps.cpp | 27 +++++++++---------- .../decoders/_core/VideoDecoderOps.h | 2 +- src/torchcodec/decoders/_core/__init__.py | 2 +- .../decoders/_core/video_decoder_ops.py | 6 ++--- test/decoders/test_video_decoder_ops.py 
| 6 ++--- 7 files changed, 25 insertions(+), 26 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 8eefc2f1..17354706 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1090,11 +1090,11 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( return output; } -VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtPtss( +VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( int streamIndex, - const std::vector& framePtss){ + const std::vector& framePtss) { validateUserProvidedStreamIndex(streamIndex); - validateScannedAllStreams("getFramesAtPtss"); + validateScannedAllStreams("getFramesDisplayedByTimestamps"); // The frame displayed at timestamp t and the one displayed at timestamp `t + // eps` are probably the same frame, with the same index. The easiest way to diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index c2f19816..5eab70bf 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -244,7 +244,7 @@ class VideoDecoder { int streamIndex, const std::vector& frameIndices); - BatchDecodedOutput getFramesAtPtss( + BatchDecodedOutput getFramesDisplayedByTimestamps( int streamIndex, const std::vector& framePtss); diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 8f168be9..fbc739b3 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -41,12 +41,12 @@ TORCH_LIBRARY(torchcodec_ns, m) { "get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_at_indices(Tensor(a!) 
decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)"); - m.def( - "get_frames_at_ptss(Tensor(a!) decoder, *, int stream_index, float[] frame_ptss) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)"); + m.def( + "get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] frame_ptss) -> (Tensor, Tensor, Tensor)"); m.def("get_json_metadata(Tensor(a!) decoder) -> str"); m.def("get_container_json_metadata(Tensor(a!) decoder) -> str"); m.def( @@ -211,17 +211,6 @@ OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds) { return makeOpsDecodedOutput(result); } -OpsBatchDecodedOutput get_frames_at_ptss( - at::Tensor& decoder, - int64_t stream_index, - at::ArrayRef frame_ptss) { - auto videoDecoder = unwrapTensorToGetDecoder(decoder); - std::vector framePtssVec(frame_ptss.begin(), frame_ptss.end()); - auto result = - videoDecoder->getFramesAtPtss(stream_index, framePtssVec); - return makeOpsBatchDecodedOutput(result); -} - OpsDecodedOutput get_frame_at_index( at::Tensor& decoder, int64_t stream_index, @@ -253,6 +242,16 @@ OpsBatchDecodedOutput get_frames_in_range( stream_index, start, stop, step.value_or(1)); return makeOpsBatchDecodedOutput(result); } +OpsBatchDecodedOutput get_frames_by_pts( + at::Tensor& decoder, + int64_t stream_index, + at::ArrayRef frame_ptss) { + auto videoDecoder = unwrapTensorToGetDecoder(decoder); + std::vector framePtssVec(frame_ptss.begin(), frame_ptss.end()); + auto result = + videoDecoder->getFramesDisplayedByTimestamps(stream_index, framePtssVec); + return makeOpsBatchDecodedOutput(result); +} OpsBatchDecodedOutput get_frames_by_pts_in_range( at::Tensor& decoder, @@ -496,9 +495,9 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { 
m.impl("get_frame_at_pts", &get_frame_at_pts); m.impl("get_frame_at_index", &get_frame_at_index); m.impl("get_frames_at_indices", &get_frames_at_indices); - m.impl("get_frames_at_ptss", &get_frames_at_ptss); m.impl("get_frames_in_range", &get_frames_in_range); m.impl("get_frames_by_pts_in_range", &get_frames_by_pts_in_range); + m.impl("get_frames_by_pts", &get_frames_by_pts); m.impl("_test_frame_pts_equality", &_test_frame_pts_equality); m.impl( "scan_all_streams_to_update_metadata", diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 3b3d3615..12cce81c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -76,7 +76,7 @@ using OpsBatchDecodedOutput = std::tuple; OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds); // Return the frames at given ptss for a given stream -OpsBatchDecodedOutput get_frames_at_ptss( +OpsBatchDecodedOutput get_frames_by_pts( at::Tensor& decoder, int64_t stream_index, at::ArrayRef frame_ptss); diff --git a/src/torchcodec/decoders/_core/__init__.py b/src/torchcodec/decoders/_core/__init__.py index 543cb74f..a1ac9a47 100644 --- a/src/torchcodec/decoders/_core/__init__.py +++ b/src/torchcodec/decoders/_core/__init__.py @@ -22,7 +22,7 @@ get_frame_at_index, get_frame_at_pts, get_frames_at_indices, - get_frames_at_ptss, + get_frames_by_pts, get_frames_by_pts_in_range, get_frames_in_range, get_json_metadata, diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 3a3b36c3..8f6c3c08 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -71,7 +71,7 @@ def load_torchcodec_extension(): get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default get_frames_at_indices = 
torch.ops.torchcodec_ns.get_frames_at_indices.default -get_frames_at_ptss = torch.ops.torchcodec_ns.get_frames_at_ptss.default +get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default @@ -173,8 +173,8 @@ def get_frame_at_pts_abstract( ) -@register_fake("torchcodec_ns::get_frames_at_ptss") -def get_frames_at_pts_abstract( +@register_fake("torchcodec_ns::get_frames_by_pts") +def get_frames_by_pts_abstract( decoder: torch.Tensor, *, stream_index: int, diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 153c4905..84cc560e 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -27,7 +27,7 @@ get_frame_at_index, get_frame_at_pts, get_frames_at_indices, - get_frames_at_ptss, + get_frames_by_pts, get_frames_by_pts_in_range, get_frames_in_range, get_json_metadata, @@ -156,7 +156,7 @@ def test_get_frames_at_indices_unsorted_indices(self): with pytest.raises(AssertionError): assert_tensor_equal(frames[0], frames[-1]) - def test_get_frames_at_ptss(self): + def test_get_frames_by_pts(self): decoder = create_from_file(str(NASA_VIDEO.path)) _add_video_stream(decoder) scan_all_streams_to_update_metadata(decoder) @@ -168,7 +168,7 @@ def test_get_frames_at_ptss(self): get_frame_at_pts(decoder, seconds=pts)[0] for pts in frame_ptss ] - frames, *_ = get_frames_at_ptss( + frames, *_ = get_frames_by_pts( decoder, stream_index=stream_index, frame_ptss=frame_ptss, From 7d266239f00fb31f89c1acb8d709f0dae7a771fa Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 11:42:55 +0100 Subject: [PATCH 12/27] Fix last frame request --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 1 + test/decoders/test_video_decoder_ops.py | 3 ++- 2 files changed, 
3 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 17354706..b6e4f6ce 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1123,6 +1123,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( return ptsToSeconds(info.nextPts, stream.timeBase) <= start; }); int64_t frameIndex = it - stream.allFrames.begin(); + frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1); frameIndices[i] = frameIndex; } diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 84cc560e..adc1392f 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -162,7 +162,8 @@ def test_get_frames_by_pts(self): scan_all_streams_to_update_metadata(decoder) stream_index = 3 - frame_ptss = [2, 0, 1, 0 + 1e-3, 2 + 1e-3] + # Note: 13.01 should give the last video frame for the NASA video + frame_ptss = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3] expected_frames = [ get_frame_at_pts(decoder, seconds=pts)[0] for pts in frame_ptss From 3a8839d63086926829f731b18916737183d79bb2 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 11:49:03 +0100 Subject: [PATCH 13/27] Better fix --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index b6e4f6ce..42a42b38 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1117,13 +1117,12 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( auto it = std::lower_bound( stream.allFrames.begin(), - stream.allFrames.end(), + stream.allFrames.end() - 1, framePts, [&stream](const FrameInfo& info, double start) { return 
ptsToSeconds(info.nextPts, stream.timeBase) <= start; }); int64_t frameIndex = it - stream.allFrames.begin(); - frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1); frameIndices[i] = frameIndex; } From a76a6add85f3793ae9583244ef92daf6f2dc3e6a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 12:42:00 +0100 Subject: [PATCH 14/27] Clean up --- src/torchcodec/_frame.py | 10 ++++ src/torchcodec/samplers/_index_based.py | 11 +---- src/torchcodec/samplers/_time_based.py | 17 ++----- test/samplers/test_samplers.py | 63 ++++++++++++------------- 4 files changed, 46 insertions(+), 55 deletions(-) diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py index c847f57b..273f837f 100644 --- a/src/torchcodec/_frame.py +++ b/src/torchcodec/_frame.py @@ -61,5 +61,15 @@ def __iter__(self) -> Iterator[Union[Tensor, float]]: for field in dataclasses.fields(self): yield getattr(self, field.name) + def __getitem__(self, key): + return FrameBatch( + self.data[key], + self.pts_seconds[key], + self.duration_seconds[key], + ) + + def __len__(self): + return len(self.data) + def __repr__(self): return _frame_repr(self) diff --git a/src/torchcodec/samplers/_index_based.py b/src/torchcodec/samplers/_index_based.py index d33b6279..67bb1201 100644 --- a/src/torchcodec/samplers/_index_based.py +++ b/src/torchcodec/samplers/_index_based.py @@ -180,22 +180,13 @@ def _generic_index_based_sampler( decoder._decoder, stream_index=decoder.stream_index, frame_indices=all_clips_indices, - sort_indices=True, ) last_3_dims = frames.shape[-3:] - out = FrameBatch( + return FrameBatch( data=frames.view(num_clips, num_frames_per_clip, *last_3_dims), pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip), duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip), ) - return [ - FrameBatch( - out.data[i], - out.pts_seconds[i], - out.duration_seconds[i], - ) - for i in range(out.data.shape[0]) - ] def clips_at_random_indices( diff --git 
a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py index c10d7913..10fcd2e3 100644 --- a/src/torchcodec/samplers/_time_based.py +++ b/src/torchcodec/samplers/_time_based.py @@ -1,9 +1,9 @@ -from typing import List, Literal, Optional +from typing import Literal, Optional import torch from torchcodec import FrameBatch -from torchcodec.decoders._core import get_frames_at_ptss +from torchcodec.decoders._core import get_frames_by_pts from torchcodec.samplers._common import ( _POLICY_FUNCTION_TYPE, _POLICY_FUNCTIONS, @@ -209,27 +209,18 @@ def _generic_time_based_sampler( policy_fun=_POLICY_FUNCTIONS[policy], ) - frames, pts_seconds, duration_seconds = get_frames_at_ptss( + frames, pts_seconds, duration_seconds = get_frames_by_pts( decoder._decoder, stream_index=decoder.stream_index, frame_ptss=all_clips_timestamps, - sort_ptss=True, ) last_3_dims = frames.shape[-3:] - out = FrameBatch( + return FrameBatch( data=frames.view(num_clips, num_frames_per_clip, *last_3_dims), pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip), duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip), ) - return [ - FrameBatch( - out.data[i], - out.pts_seconds[i], - out.duration_seconds[i], - ) - for i in range(out.data.shape[0]) - ] def clips_at_random_timestamps( diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index bae359f8..31375330 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -25,23 +25,21 @@ def _assert_output_type_and_shapes( video, clips, expected_num_clips, num_frames_per_clip ): - assert isinstance(clips, list) - assert len(clips) == expected_num_clips - assert all(isinstance(clip, FrameBatch) for clip in clips) - expected_clip_data_shape = ( + assert isinstance(clips, FrameBatch) + # assert len(clips) == expected_num_clips + # assert all(isinstance(clip, FrameBatch) for clip in clips) + expected_clips_data_shape = ( + expected_num_clips, num_frames_per_clip, 3, 
video.height, video.width, ) - assert all(clip.data.shape == expected_clip_data_shape for clip in clips) + assert clips.data.shape == expected_clips_data_shape def _assert_regular_sampler(clips, expected_seconds_between_clip_starts=None): - # assert regular spacing between sampled clips - seconds_between_clip_starts = torch.tensor( - [clip.pts_seconds[0] for clip in clips] - ).diff() + seconds_between_clip_starts = clips.pts_seconds[:, 0].diff() if expected_seconds_between_clip_starts is not None: # This can only be asserted with the time-based sampler, where @@ -88,10 +86,7 @@ def test_index_based_sampler(sampler, num_indices_between_frames): # Check the num_indices_between_frames parameter by asserting that the # "time" difference between frames in a clip is the same as the "index" # distance. - - avg_distance_between_frames_seconds = torch.concat( - [clip.pts_seconds.diff() for clip in clips] - ).mean() + avg_distance_between_frames_seconds = clips.pts_seconds.diff(dim=1).mean() assert avg_distance_between_frames_seconds == pytest.approx( num_indices_between_frames / decoder.metadata.average_fps, abs=1e-5 ) @@ -140,10 +135,8 @@ def test_time_based_sampler(sampler, seconds_between_frames): expected_seconds_between_frames = ( seconds_between_frames or 1 / decoder.metadata.average_fps ) - avg_seconds_between_frames_seconds = torch.concat( - [clip.pts_seconds.diff() for clip in clips] - ).mean() - assert avg_seconds_between_frames_seconds == pytest.approx( + avg_seconds_between_frames = clips.pts_seconds.diff(dim=1).mean() + assert avg_seconds_between_frames == pytest.approx( expected_seconds_between_frames, abs=0.05 ) @@ -208,8 +201,8 @@ def test_sampling_range( else pytest.raises(AssertionError, match="Tensor-likes are not") ) with cm: - for clip in clips: - assert_tensor_equal(clip.data, clips[0].data) + for clip_data in clips.data: + assert_tensor_equal(clip_data, clips.data[0]) @pytest.mark.parametrize("sampler", (clips_at_random_indices, 
clips_at_regular_indices)) @@ -236,11 +229,11 @@ def test_sampling_range_negative(sampler): ) # There is only one unique clip in clips_1... - for clip in clips_1: - assert_tensor_equal(clip.data, clips_1[0].data) + for clip_data in clips_1.data: + assert_tensor_equal(clip_data, clips_1.data[0]) # ... and it's the same that's in clips_2 - for clip in clips_2: - assert_tensor_equal(clip.data, clips_1[0].data) + for clip_data in clips_2.data: + assert_tensor_equal(clip_data, clips_1.data[0]) @pytest.mark.parametrize( @@ -284,7 +277,8 @@ def test_sampling_range_default_behavior_random_sampler(sampler): policy="error", ) - last_clip_start_default = max([clip.pts_seconds[0] for clip in clips_default]) + # last_clip_start_default = max([clip.pts_seconds[0] for clip in clips_default]) + last_clip_start_default = clips_default.pts_seconds[:, 0].max() # with manual sampling_range_end value set to last frame / end of video clips_manual = sampler( @@ -294,7 +288,7 @@ def test_sampling_range_default_behavior_random_sampler(sampler): sampling_range_start=sampling_range_start, sampling_range_end=1000, ) - last_clip_start_manual = max([clip.pts_seconds[0] for clip in clips_manual]) + last_clip_start_manual = clips_manual.pts_seconds[:, 0].max() assert last_clip_start_manual - last_clip_start_default > 0.3 @@ -382,22 +376,27 @@ def test_random_sampler_randomness(sampler): # Assert the clip starts aren't sorted, to make sure we haven't messed up # the implementation. (This may fail if we're unlucky, but we hard-coded a # seed, so it will always pass.) 
- clip_starts = [clip.pts_seconds.item() for clip in clips_1] + # clip_starts = [clip.pts_seconds.item() for clip in clips_1] + clip_starts = clips_1.pts_seconds[:, 0].tolist() assert sorted(clip_starts) != clip_starts # Call the same sampler again with the same seed, expect same results torch.manual_seed(0) clips_2 = sampler(decoder, num_clips=num_clips) - for clip_1, clip_2 in zip(clips_1, clips_2): - assert_tensor_equal(clip_1.data, clip_2.data) - assert_tensor_equal(clip_1.pts_seconds, clip_2.pts_seconds) - assert_tensor_equal(clip_1.duration_seconds, clip_2.duration_seconds) + for clip_1_data, clip_2_data in zip(clips_1.data, clips_2.data): + assert_tensor_equal(clip_1_data, clip_2_data) + for clip_1_pts, clip_2_pts in zip(clips_1.pts_seconds, clips_2.pts_seconds): + assert_tensor_equal(clip_1_pts, clip_2_pts) + for clip_1_duration, clip_2_duration in zip( + clips_1.duration_seconds, clips_2.duration_seconds + ): + assert_tensor_equal(clip_1_duration, clip_2_duration) # Call with a different seed, expect different results torch.manual_seed(1) clips_3 = sampler(decoder, num_clips=num_clips) with pytest.raises(AssertionError, match="Tensor-likes are not"): - assert_tensor_equal(clips_1[0].data, clips_3[0].data) + assert_tensor_equal(clips_1.data[0], clips_3.data[0]) # Make sure we didn't alter the builtin Python RNG builtin_random_state_end = random.getstate() @@ -427,7 +426,7 @@ def test_sample_at_regular_indices_num_clips_large(num_clips, sampling_range_siz assert len(clips) == num_clips - clip_starts_seconds = torch.tensor([clip.pts_seconds[0] for clip in clips]) + clip_starts_seconds = clips.pts_seconds[:, 0] assert len(torch.unique(clip_starts_seconds)) == sampling_range_size # Assert clips starts are ordered, i.e. 
the start indices don't just "wrap From 14825293eb134d39d07b739d9896e3ba7b774ce9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 14:18:44 +0100 Subject: [PATCH 15/27] Frame and FrameBatch improvements --- src/torchcodec/_frame.py | 47 ++++++++++++- test/test_frame_dataclasses.py | 121 +++++++++++++++++++++++++++++++++ 2 files changed, 165 insertions(+), 3 deletions(-) create mode 100644 test/test_frame_dataclasses.py diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py index c847f57b..e6013ba8 100644 --- a/src/torchcodec/_frame.py +++ b/src/torchcodec/_frame.py @@ -38,6 +38,13 @@ class Frame(Iterable): duration_seconds: float """The duration of the frame, in seconds (float).""" + def __post_init__(self): + if not self.data.ndim == 3: + raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }") + + self.pts_seconds = float(self.pts_seconds) + self.duration_seconds = float(self.duration_seconds) + def __iter__(self) -> Iterator[Union[Tensor, float]]: for field in dataclasses.fields(self): yield getattr(self, field.name) @@ -57,9 +64,43 @@ class FrameBatch(Iterable): duration_seconds: Tensor """The duration of the frame, in seconds (1-D ``torch.Tensor`` of floats).""" - def __iter__(self) -> Iterator[Union[Tensor, float]]: - for field in dataclasses.fields(self): - yield getattr(self, field.name) + def __post_init__(self): + if self.data.ndim < 4: + raise ValueError( + f"data must be at least 4-dimensional. Got {self.data.shape = } " + "For 3-dimensional data, create a Frame object instead." + ) + + leading_dims = self.data.shape[:-3] + if not (leading_dims == self.pts_seconds.shape == self.duration_seconds.shape): + raise ValueError( + "Tried to create a FrameBatch but the leading dimensions of the inputs do not match. " + f"Got {self.data.shape = } so we expected the shape of pts_seconds and " + f"duration_seconds to be {leading_dims = }, but got " + f"{self.pts_seconds.shape = } and {self.duration_seconds.shape = }." 
+ ) + + def __iter__(self) -> Union[Iterator["FrameBatch"], Iterator[Frame]]: + cls = Frame if self.data.ndim == 4 else FrameBatch + for data, pts_seconds, duration_seconds in zip( + self.data, self.pts_seconds, self.duration_seconds + ): + yield cls( + data=data, + pts_seconds=pts_seconds, + duration_seconds=duration_seconds, + ) + + def __getitem__(self, key) -> Union["FrameBatch", Frame]: + cls = Frame if self.data.ndim == 4 else FrameBatch + return cls( + self.data[key], + self.pts_seconds[key], + self.duration_seconds[key], + ) + + def __len__(self): + return len(self.data) def __repr__(self): return _frame_repr(self) diff --git a/test/test_frame_dataclasses.py b/test/test_frame_dataclasses.py new file mode 100644 index 00000000..9b79b882 --- /dev/null +++ b/test/test_frame_dataclasses.py @@ -0,0 +1,121 @@ +import pytest +import torch +from torchcodec import Frame, FrameBatch + + +def test_frame_unpacking(): + data, pts_seconds, duration_seconds = Frame(torch.rand(3, 4, 5), 2, 3) # noqa + + +def test_frame_error(): + with pytest.raises(ValueError, match="data must be 3-dimensional"): + Frame( + data=torch.rand(1, 2), + pts_seconds=1, + duration_seconds=1, + ) + with pytest.raises(ValueError, match="data must be 3-dimensional"): + Frame( + data=torch.rand(1, 2, 3, 4), + pts_seconds=1, + duration_seconds=1, + ) + + +def test_framebatch_error(): + with pytest.raises(ValueError, match="data must be at least 4-dimensional"): + FrameBatch( + data=torch.rand(1, 2, 3), + pts_seconds=torch.rand(1), + duration_seconds=torch.rand(1), + ) + + with pytest.raises( + ValueError, match="leading dimensions of the inputs do not match" + ): + FrameBatch( + data=torch.rand(3, 4, 2, 1), + pts_seconds=torch.rand(3), # ok + duration_seconds=torch.rand(2), # bad + ) + + with pytest.raises( + ValueError, match="leading dimensions of the inputs do not match" + ): + FrameBatch( + data=torch.rand(3, 4, 2, 1), + pts_seconds=torch.rand(2), # bad + duration_seconds=torch.rand(3), # ok + ) + 
+ with pytest.raises( + ValueError, match="leading dimensions of the inputs do not match" + ): + FrameBatch( + data=torch.rand(5, 3, 4, 2, 1), + pts_seconds=torch.rand(5, 3), # ok + duration_seconds=torch.rand(5, 2), # bad + ) + + with pytest.raises( + ValueError, match="leading dimensions of the inputs do not match" + ): + FrameBatch( + data=torch.rand(5, 3, 4, 2, 1), + pts_seconds=torch.rand(5, 2), # bad + duration_seconds=torch.rand(5, 3), # ok + ) + + +def test_framebatch_iteration(): + T, N, C, H, W = 7, 6, 3, 2, 4 + + fb = FrameBatch( + data=torch.rand(T, N, C, H, W), + pts_seconds=torch.rand(T, N), + duration_seconds=torch.rand(T, N), + ) + + for sub_fb in fb: + assert isinstance(sub_fb, FrameBatch) + assert sub_fb.data.shape == (N, C, H, W) + assert sub_fb.pts_seconds.shape == (N,) + assert sub_fb.duration_seconds.shape == (N,) + for frame in sub_fb: + assert isinstance(frame, Frame) + assert frame.data.shape == (C, H, W) + assert isinstance(frame.pts_seconds, float) + assert isinstance(frame.duration_seconds, float) + + # Check unpacking behavior + first_sub_fb, *_ = fb + assert isinstance(first_sub_fb, FrameBatch) + + +def test_framebatch_indexing(): + T, N, C, H, W = 7, 6, 3, 2, 4 + + fb = FrameBatch( + data=torch.rand(T, N, C, H, W), + pts_seconds=torch.rand(T, N), + duration_seconds=torch.rand(T, N), + ) + + for i in range(len(fb)): + assert isinstance(fb[i], FrameBatch) + assert fb[i].data.shape == (N, C, H, W) + assert fb[i].pts_seconds.shape == (N,) + assert fb[i].duration_seconds.shape == (N,) + for j in range(len(fb[i])): + assert isinstance(fb[i][j], Frame) + assert fb[i][j].data.shape == (C, H, W) + assert isinstance(fb[i][j].pts_seconds, float) + assert isinstance(fb[i][j].duration_seconds, float) + + fb_fancy = fb[torch.arange(3)] + assert isinstance(fb_fancy, FrameBatch) + assert fb_fancy.data.shape == (3, N, C, H, W) + + fb_fancy = fb[[[0], [1]]] # select T=0 and N=1. 
+ assert isinstance(fb_fancy, FrameBatch) + assert fb_fancy.data.shape == (1, C, H, W) From 46612374677ff4dbbea5fdf2b175bc1090c4d1fa Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 14:28:56 +0100 Subject: [PATCH 16/27] Fix mypy? --- src/torchcodec/_frame.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py index e6013ba8..b9542df4 100644 --- a/src/torchcodec/_frame.py +++ b/src/torchcodec/_frame.py @@ -41,7 +41,6 @@ class Frame(Iterable): def __post_init__(self): if not self.data.ndim == 3: raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }") - self.pts_seconds = float(self.pts_seconds) self.duration_seconds = float(self.duration_seconds) @@ -92,12 +91,21 @@ def __iter__(self) -> Union[Iterator["FrameBatch"], Iterator[Frame]]: ) def __getitem__(self, key) -> Union["FrameBatch", Frame]: - cls = Frame if self.data.ndim == 4 else FrameBatch - return cls( - self.data[key], - self.pts_seconds[key], - self.duration_seconds[key], - ) + data = self.data[key] + pts_seconds = self.pts_seconds[key] + duration_seconds = self.duration_seconds[key] + if self.data.ndim == 4: + return Frame( + data=data, + pts_seconds=float(pts_seconds.item()), + duration_seconds=float(duration_seconds.item()), + ) + else: + return FrameBatch( + data=data, + pts_seconds=pts_seconds, + duration_seconds=duration_seconds, + ) def __len__(self): return len(self.data) From d43dd9120d227ad922079830c05b3c8fa3662de4 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 15:41:57 +0100 Subject: [PATCH 17/27] Use timestamps as parameter name --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 8 ++++---- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- src/torchcodec/decoders/_core/VideoDecoderOps.cpp | 8 ++++---- src/torchcodec/decoders/_core/VideoDecoderOps.h | 2 +- src/torchcodec/decoders/_core/video_decoder_ops.py | 2 +- test/decoders/test_video_decoder_ops.py 
| 6 +++--- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 42a42b38..bc3f1198 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1092,7 +1092,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( int streamIndex, - const std::vector& framePtss) { + const std::vector& timestamps) { validateUserProvidedStreamIndex(streamIndex); validateScannedAllStreams("getFramesDisplayedByTimestamps"); @@ -1106,9 +1106,9 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( double minSeconds = streamMetadata.minPtsSecondsFromScan.value(); double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value(); - std::vector frameIndices(framePtss.size()); - for (auto i = 0; i < framePtss.size(); ++i) { - auto framePts = framePtss[i]; + std::vector frameIndices(timestamps.size()); + for (auto i = 0; i < timestamps.size(); ++i) { + auto framePts = timestamps[i]; TORCH_CHECK( framePts >= minSeconds && framePts < maxSeconds, "frame pts is " + std::to_string(framePts) + "; must be in range [" + diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 5eab70bf..c0f489ce 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -246,7 +246,7 @@ class VideoDecoder { BatchDecodedOutput getFramesDisplayedByTimestamps( int streamIndex, - const std::vector& framePtss); + const std::vector& timestamps); // Returns frames within a given range for a given stream as a single stacked // Tensor. The range is defined by [start, stop). 
The values retrieved from diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index fbc739b3..6b91853c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -46,7 +46,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)"); m.def( - "get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] frame_ptss) -> (Tensor, Tensor, Tensor)"); + "get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)"); m.def("get_json_metadata(Tensor(a!) decoder) -> str"); m.def("get_container_json_metadata(Tensor(a!) decoder) -> str"); m.def( @@ -245,11 +245,11 @@ OpsBatchDecodedOutput get_frames_in_range( OpsBatchDecodedOutput get_frames_by_pts( at::Tensor& decoder, int64_t stream_index, - at::ArrayRef frame_ptss) { + at::ArrayRef timestamps) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); - std::vector framePtssVec(frame_ptss.begin(), frame_ptss.end()); + std::vector timestampsVec(timestamps.begin(), timestamps.end()); auto result = - videoDecoder->getFramesDisplayedByTimestamps(stream_index, framePtssVec); + videoDecoder->getFramesDisplayedByTimestamps(stream_index, timestampsVec); return makeOpsBatchDecodedOutput(result); } diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 12cce81c..eac489ce 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -79,7 +79,7 @@ OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds); OpsBatchDecodedOutput get_frames_by_pts( at::Tensor& decoder, int64_t stream_index, - at::ArrayRef frame_ptss); + at::ArrayRef timestamps); // Return the frame that is visible at a given index in the video. 
OpsDecodedOutput get_frame_at_index( diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 8f6c3c08..d4102ae5 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -178,7 +178,7 @@ def get_frames_by_pts_abstract( decoder: torch.Tensor, *, stream_index: int, - frame_ptss: List[float], + timestamps: List[float], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] return ( diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index adc1392f..8535abee 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -163,16 +163,16 @@ def test_get_frames_by_pts(self): stream_index = 3 # Note: 13.01 should give the last video frame for the NASA video - frame_ptss = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3] + timestamps = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3] expected_frames = [ - get_frame_at_pts(decoder, seconds=pts)[0] for pts in frame_ptss + get_frame_at_pts(decoder, seconds=pts)[0] for pts in timestamps ] frames, *_ = get_frames_by_pts( decoder, stream_index=stream_index, - frame_ptss=frame_ptss, + timestamps=timestamps, ) for frame, expected_frame in zip(frames, expected_frames): assert_tensor_equal(frame, expected_frame) From f2feab9a7f63420457ba04c337ff7d0eacc3ad23 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 16:42:27 +0100 Subject: [PATCH 18/27] better --- test/samplers/test_samplers.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 31375330..9b163101 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -383,14 +383,11 @@ def test_random_sampler_randomness(sampler): # Call the same sampler again with the same seed, expect same results 
torch.manual_seed(0) clips_2 = sampler(decoder, num_clips=num_clips) - for clip_1_data, clip_2_data in zip(clips_1.data, clips_2.data): - assert_tensor_equal(clip_1_data, clip_2_data) - for clip_1_pts, clip_2_pts in zip(clips_1.pts_seconds, clips_2.pts_seconds): - assert_tensor_equal(clip_1_pts, clip_2_pts) - for clip_1_duration, clip_2_duration in zip( - clips_1.duration_seconds, clips_2.duration_seconds - ): - assert_tensor_equal(clip_1_duration, clip_2_duration) + + for clip_1, clip_2 in zip(clips_1, clips_2): + assert_tensor_equal(clip_1.data, clip_2.data) + assert_tensor_equal(clip_1.pts_seconds, clip_2.pts_seconds) + assert_tensor_equal(clip_1.duration_seconds, clip_2.duration_seconds) # Call with a different seed, expect different results torch.manual_seed(1) From 92b79549cbd2a1ad94709965df2198bedce161e3 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 16:56:50 +0100 Subject: [PATCH 19/27] minor refac --- src/torchcodec/_frame.py | 10 ---------- src/torchcodec/samplers/_common.py | 19 +++++++++++++++++++ src/torchcodec/samplers/_index_based.py | 13 ++++++++----- src/torchcodec/samplers/_time_based.py | 14 ++++++++------ test/samplers/test_samplers.py | 2 -- 5 files changed, 35 insertions(+), 23 deletions(-) diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py index 9aee1036..b9542df4 100644 --- a/src/torchcodec/_frame.py +++ b/src/torchcodec/_frame.py @@ -110,15 +110,5 @@ def __getitem__(self, key) -> Union["FrameBatch", Frame]: def __len__(self): return len(self.data) - def __getitem__(self, key): - return FrameBatch( - self.data[key], - self.pts_seconds[key], - self.duration_seconds[key], - ) - - def __len__(self): - return len(self.data) - def __repr__(self): return _frame_repr(self) diff --git a/src/torchcodec/samplers/_common.py b/src/torchcodec/samplers/_common.py index bcf8f675..93dceb75 100644 --- a/src/torchcodec/samplers/_common.py +++ b/src/torchcodec/samplers/_common.py @@ -1,5 +1,8 @@ from typing import Callable, 
Union +from torch import Tensor +from torchcodec import FrameBatch + _LIST_OF_INT_OR_FLOAT = Union[list[int], list[float]] @@ -53,3 +56,19 @@ def _validate_common_params(*, decoder, num_frames_per_clip, policy): raise ValueError( f"Invalid policy ({policy}). Supported values are {_POLICY_FUNCTIONS.keys()}." ) + + +def _make_5d_framebatch( + *, + data: Tensor, + pts_seconds: Tensor, + duration_seconds: Tensor, + num_clips: int, + num_frames_per_clip: int, +) -> FrameBatch: + last_3_dims = data.shape[-3:] + return FrameBatch( + data=data.view(num_clips, num_frames_per_clip, *last_3_dims), + pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip), + duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip), + ) diff --git a/src/torchcodec/samplers/_index_based.py b/src/torchcodec/samplers/_index_based.py index 67bb1201..a16a1292 100644 --- a/src/torchcodec/samplers/_index_based.py +++ b/src/torchcodec/samplers/_index_based.py @@ -6,6 +6,7 @@ from torchcodec.decoders import VideoDecoder from torchcodec.decoders._core import get_frames_at_indices from torchcodec.samplers._common import ( + _make_5d_framebatch, _POLICY_FUNCTION_TYPE, _POLICY_FUNCTIONS, _validate_common_params, @@ -176,16 +177,18 @@ def _generic_index_based_sampler( policy_fun=_POLICY_FUNCTIONS[policy], ) + # TODO: Use public method of decoder, when it exists frames, pts_seconds, duration_seconds = get_frames_at_indices( decoder._decoder, stream_index=decoder.stream_index, frame_indices=all_clips_indices, ) - last_3_dims = frames.shape[-3:] - return FrameBatch( - data=frames.view(num_clips, num_frames_per_clip, *last_3_dims), - pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip), - duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip), + return _make_5d_framebatch( + data=frames, + pts_seconds=pts_seconds, + duration_seconds=duration_seconds, + num_clips=num_clips, + num_frames_per_clip=num_frames_per_clip, ) diff --git a/src/torchcodec/samplers/_time_based.py 
b/src/torchcodec/samplers/_time_based.py index 10fcd2e3..db57a9a5 100644 --- a/src/torchcodec/samplers/_time_based.py +++ b/src/torchcodec/samplers/_time_based.py @@ -5,6 +5,7 @@ from torchcodec import FrameBatch from torchcodec.decoders._core import get_frames_by_pts from torchcodec.samplers._common import ( + _make_5d_framebatch, _POLICY_FUNCTION_TYPE, _POLICY_FUNCTIONS, _validate_common_params, @@ -209,17 +210,18 @@ def _generic_time_based_sampler( policy_fun=_POLICY_FUNCTIONS[policy], ) + # TODO: Use public method of decoder, when it exists frames, pts_seconds, duration_seconds = get_frames_by_pts( decoder._decoder, stream_index=decoder.stream_index, frame_ptss=all_clips_timestamps, ) - last_3_dims = frames.shape[-3:] - - return FrameBatch( - data=frames.view(num_clips, num_frames_per_clip, *last_3_dims), - pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip), - duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip), + return _make_5d_framebatch( + data=frames, + pts_seconds=pts_seconds, + duration_seconds=duration_seconds, + num_clips=num_clips, + num_frames_per_clip=num_frames_per_clip, ) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 9b163101..3147cc92 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -26,8 +26,6 @@ def _assert_output_type_and_shapes( video, clips, expected_num_clips, num_frames_per_clip ): assert isinstance(clips, FrameBatch) - # assert len(clips) == expected_num_clips - # assert all(isinstance(clip, FrameBatch) for clip in clips) expected_clips_data_shape = ( expected_num_clips, num_frames_per_clip, From 8dd9b0a67ad78614f37e93bed2b5634a478c3b3b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 17:02:07 +0100 Subject: [PATCH 20/27] Address comments --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp 
b/src/torchcodec/decoders/_core/VideoDecoder.cpp index bc3f1198..5f6a6d29 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1100,6 +1100,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( // eps` are probably the same frame, with the same index. The easiest way to // avoid decoding that unique frame twice is to convert the input timestamps // to indices, and leverage the de-duplication logic of getFramesAtIndices. + // This means this function requires a scan. + // TODO: longer term, we should implement this without requiring a scan const auto& streamMetadata = containerMetadata_.streams[streamIndex]; const auto& stream = streams_[streamIndex]; @@ -1119,8 +1121,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( stream.allFrames.begin(), stream.allFrames.end() - 1, framePts, - [&stream](const FrameInfo& info, double start) { - return ptsToSeconds(info.nextPts, stream.timeBase) <= start; + [&stream](const FrameInfo& info, double framePts) { + return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts; }); int64_t frameIndex = it - stream.allFrames.begin(); frameIndices[i] = frameIndex; From 326860618ec4bc09e062dd8d0f0bac30cb10315e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 17:05:20 +0100 Subject: [PATCH 21/27] fix --- src/torchcodec/samplers/_time_based.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py index db57a9a5..e9d485aa 100644 --- a/src/torchcodec/samplers/_time_based.py +++ b/src/torchcodec/samplers/_time_based.py @@ -214,7 +214,7 @@ def _generic_time_based_sampler( frames, pts_seconds, duration_seconds = get_frames_by_pts( decoder._decoder, stream_index=decoder.stream_index, - frame_ptss=all_clips_timestamps, + timestamps=all_clips_timestamps, ) return _make_5d_framebatch( data=frames, From 
c03294bbf95486f3f12bc9bb6746f26ddb3a54da Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 24 Oct 2024 10:36:52 +0100 Subject: [PATCH 22/27] Fix binary search of getFramesDisplayedByTimestamps --- .../decoders/_core/VideoDecoder.cpp | 9 +--- test/decoders/test_video_decoder_ops.py | 45 +++++++++++++++++++ 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index add9c9be..9243365c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1119,21 +1119,14 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( auto it = std::lower_bound( stream.allFrames.begin(), - stream.allFrames.end(), + stream.allFrames.end() - 1, framePts, [&stream](const FrameInfo& info, double framePts) { return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts; }); int64_t frameIndex = it - stream.allFrames.begin(); - // If the frame index is larger than the size of allFrames, that means we - // couldn't match the pts value to the pts value of a NEXT FRAME. And - // that means that this timestamp falls during the time between when the - // last frame is displayed, and the video ends. Hence, it should map to the - // index of the last frame. 
- frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1); frameIndices[i] = frameIndex; } - return getFramesAtIndices(streamIndex, frameIndices); } diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 0ed68146..5ebd7830 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -186,6 +186,51 @@ def test_get_frames_by_pts(self): with pytest.raises(AssertionError): assert_tensor_equal(frames[0], frames[-1]) + def test_pts_apis_against_index_ref(self): + # Get all frames in the video, then query all frames with all time-based + # APIs exactly where those frames are supposed to start. We assert that + # we get the expected frame. + decoder = create_from_file(str(NASA_VIDEO.path)) + scan_all_streams_to_update_metadata(decoder) + add_video_stream(decoder) + + metadata = get_json_metadata(decoder) + metadata_dict = json.loads(metadata) + num_frames = metadata_dict["numFrames"] + assert num_frames == 390 + + stream_index = 3 + _, all_pts_seconds_ref, _ = zip( + *[ + get_frame_at_index( + decoder, stream_index=stream_index, frame_index=frame_index + ) + for frame_index in range(num_frames) + ] + ) + all_pts_seconds_ref = torch.tensor(all_pts_seconds_ref) + + assert len(all_pts_seconds_ref.unique() == len(all_pts_seconds_ref)) + + _, pts_seconds, _ = zip( + *[get_frame_at_pts(decoder, seconds=pts) for pts in all_pts_seconds_ref] + ) + pts_seconds = torch.tensor(pts_seconds) + assert_tensor_equal(pts_seconds, all_pts_seconds_ref) + + _, pts_seconds, _ = get_frames_by_pts_in_range( + decoder, + stream_index=stream_index, + start_seconds=0, + stop_seconds=all_pts_seconds_ref[-1] + 1e-4, + ) + assert_tensor_equal(pts_seconds, all_pts_seconds_ref) + + _, pts_seconds, _ = get_frames_by_pts( + decoder, stream_index=stream_index, timestamps=all_pts_seconds_ref.tolist() + ) + assert_tensor_equal(pts_seconds, all_pts_seconds_ref) + def test_get_frames_in_range(self): decoder = 
create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder) From 5ab33b982d582ce5a8bd2b795a7eae69d396b1e6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 24 Oct 2024 10:39:07 +0100 Subject: [PATCH 23/27] Comment --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 9243365c..97d55c06 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1119,6 +1119,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( auto it = std::lower_bound( stream.allFrames.begin(), + // See https://github.com/pytorch/torchcodec/pull/286 for why the `- 1` + // is needed. stream.allFrames.end() - 1, framePts, [&stream](const FrameInfo& info, double framePts) { From fa374bc0709baa7dc993de69e80ba8626cb438f6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 24 Oct 2024 11:13:40 +0100 Subject: [PATCH 24/27] comment --- test/decoders/test_video_decoder_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 5ebd7830..6ad774b5 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -187,6 +187,7 @@ def test_get_frames_by_pts(self): assert_tensor_equal(frames[0], frames[-1]) def test_pts_apis_against_index_ref(self): + # Non-regression test for https://github.com/pytorch/torchcodec/pull/286 # Get all frames in the video, then query all frames with all time-based # APIs exactly where those frames are supposed to start. We assert that # we get the expected frame. 
From c75417bc1705a74979f195538dafec4be92b0bc5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 24 Oct 2024 11:46:03 +0100 Subject: [PATCH 25/27] Nits --- test/samplers/test_samplers.py | 41 +++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 3147cc92..3f632752 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -26,14 +26,21 @@ def _assert_output_type_and_shapes( video, clips, expected_num_clips, num_frames_per_clip ): assert isinstance(clips, FrameBatch) - expected_clips_data_shape = ( + assert clips.data.shape == ( expected_num_clips, num_frames_per_clip, 3, video.height, video.width, ) - assert clips.data.shape == expected_clips_data_shape + assert clips.pts_seconds.shape == ( + expected_num_clips, + num_frames_per_clip, + ) + assert clips.duration_seconds.shape == ( + expected_num_clips, + num_frames_per_clip, + ) def _assert_regular_sampler(clips, expected_seconds_between_clip_starts=None): @@ -84,10 +91,11 @@ def test_index_based_sampler(sampler, num_indices_between_frames): # Check the num_indices_between_frames parameter by asserting that the # "time" difference between frames in a clip is the same as the "index" # distance. 
- avg_distance_between_frames_seconds = clips.pts_seconds.diff(dim=1).mean() - assert avg_distance_between_frames_seconds == pytest.approx( - num_indices_between_frames / decoder.metadata.average_fps, abs=1e-5 - ) + for clip in clips: + avg_distance_between_frames_seconds = clip.pts_seconds.diff().mean() + assert avg_distance_between_frames_seconds == pytest.approx( + num_indices_between_frames / decoder.metadata.average_fps, abs=1e-5 + ) @pytest.mark.parametrize( @@ -133,10 +141,11 @@ def test_time_based_sampler(sampler, seconds_between_frames): expected_seconds_between_frames = ( seconds_between_frames or 1 / decoder.metadata.average_fps ) - avg_seconds_between_frames = clips.pts_seconds.diff(dim=1).mean() - assert avg_seconds_between_frames == pytest.approx( - expected_seconds_between_frames, abs=0.05 - ) + for clip in clips: + avg_seconds_between_frames = clip.pts_seconds.diff().mean() + assert avg_seconds_between_frames == pytest.approx( + expected_seconds_between_frames, abs=0.05 + ) @pytest.mark.parametrize( @@ -199,8 +208,8 @@ def test_sampling_range( else pytest.raises(AssertionError, match="Tensor-likes are not") ) with cm: - for clip_data in clips.data: - assert_tensor_equal(clip_data, clips.data[0]) + for clip in clips: + assert_tensor_equal(clip.data, clips.data[0]) @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) @@ -227,11 +236,11 @@ def test_sampling_range_negative(sampler): ) # There is only one unique clip in clips_1... - for clip_data in clips_1.data: - assert_tensor_equal(clip_data, clips_1.data[0]) + for clip_1 in clips_1: + assert_tensor_equal(clip_1.data, clips_1.data[0]) # ... 
and it's the same that's in clips_2 - for clip_data in clips_2.data: - assert_tensor_equal(clip_data, clips_1.data[0]) + for clip_2 in clips_2: + assert_tensor_equal(clip_2.data, clips_1.data[0]) @pytest.mark.parametrize( From da59be44409b9be0f8b3d56e7ce9f22a03640053 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 24 Oct 2024 16:05:10 +0100 Subject: [PATCH 26/27] merge fixes --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 5 ++--- test/samplers/test_samplers.py | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 97d55c06..8c9f4363 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1119,9 +1119,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( auto it = std::lower_bound( stream.allFrames.begin(), - // See https://github.com/pytorch/torchcodec/pull/286 for why the `- 1` - // is needed. - stream.allFrames.end() - 1, + stream.allFrames.end(), framePts, [&stream](const FrameInfo& info, double framePts) { return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts; @@ -1129,6 +1127,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( int64_t frameIndex = it - stream.allFrames.begin(); frameIndices[i] = frameIndex; } + return getFramesAtIndices(streamIndex, frameIndices); } diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 3f632752..ca0e5a42 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -383,7 +383,6 @@ def test_random_sampler_randomness(sampler): # Assert the clip starts aren't sorted, to make sure we haven't messed up # the implementation. (This may fail if we're unlucky, but we hard-coded a # seed, so it will always pass.) 
- # clip_starts = [clip.pts_seconds.item() for clip in clips_1] clip_starts = clips_1.pts_seconds[:, 0].tolist() assert sorted(clip_starts) != clip_starts From ad1dd3a073e91ca996b53bf551bf6f69cc1bcf80 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 24 Oct 2024 16:14:06 +0100 Subject: [PATCH 27/27] Simplify even further --- test/samplers/test_samplers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index ca0e5a42..4a12d93c 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -209,7 +209,7 @@ def test_sampling_range( ) with cm: for clip in clips: - assert_tensor_equal(clip.data, clips.data[0]) + assert_tensor_equal(clip.data, clips[0].data) @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) @@ -236,11 +236,11 @@ def test_sampling_range_negative(sampler): ) # There is only one unique clip in clips_1... - for clip_1 in clips_1: - assert_tensor_equal(clip_1.data, clips_1.data[0]) + for clip in clips_1: + assert_tensor_equal(clip.data, clips_1[0].data) # ... and it's the same that's in clips_2 - for clip_2 in clips_2: - assert_tensor_equal(clip_2.data, clips_1.data[0]) + for clip in clips_2: + assert_tensor_equal(clip.data, clips_1[0].data) @pytest.mark.parametrize( @@ -399,7 +399,7 @@ def test_random_sampler_randomness(sampler): torch.manual_seed(1) clips_3 = sampler(decoder, num_clips=num_clips) with pytest.raises(AssertionError, match="Tensor-likes are not"): - assert_tensor_equal(clips_1.data[0], clips_3.data[0]) + assert_tensor_equal(clips_1[0].data, clips_3[0].data) # Make sure we didn't alter the builtin Python RNG builtin_random_state_end = random.getstate()