diff --git a/benchmarks/decoders/benchmark_decoders.py b/benchmarks/decoders/benchmark_decoders.py
index 1c5425050..761f269fd 100644
--- a/benchmarks/decoders/benchmark_decoders.py
+++ b/benchmarks/decoders/benchmark_decoders.py
@@ -209,7 +209,7 @@ def get_frames_from_video(self, video_file, pts_list):
         best_video_stream = metadata["bestVideoStreamIndex"]
         indices_list = [int(pts * average_fps) for pts in pts_list]
         frames = []
-        frames = get_frames_at_indices(
+        frames, *_ = get_frames_at_indices(
             decoder, stream_index=best_video_stream, frame_indices=indices_list
         )
         return frames
@@ -226,7 +226,7 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
         best_video_stream = metadata["bestVideoStreamIndex"]
         frames = []
         indices_list = list(range(numFramesToDecode))
-        frames = get_frames_at_indices(
+        frames, *_ = get_frames_at_indices(
             decoder, stream_index=best_video_stream, frame_indices=indices_list
         )
         return frames
diff --git a/src/torchcodec/_samplers/video_clip_sampler.py b/src/torchcodec/_samplers/video_clip_sampler.py
index 1440edaeb..4900be534 100644
--- a/src/torchcodec/_samplers/video_clip_sampler.py
+++ b/src/torchcodec/_samplers/video_clip_sampler.py
@@ -240,7 +240,7 @@ def _get_clips_for_index_based_sampling(
                 clip_start_idx + i * index_based_sampler_args.video_frame_dilation
                 for i in range(index_based_sampler_args.frames_per_clip)
             ]
-            frames = get_frames_at_indices(
+            frames, *_ = get_frames_at_indices(
                 video_decoder,
                 stream_index=metadata_json["bestVideoStreamIndex"],
                 frame_indices=batch_indexes,
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 58f946358..1794c984b 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -1034,24 +1034,57 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
   validateUserProvidedStreamIndex(streamIndex);
   validateScannedAllStreams("getFramesAtIndices");
 
+  auto indicesAreSorted =
+      std::is_sorted(frameIndices.begin(), frameIndices.end());
+
+  std::vector<size_t> argsort;
+  if (!indicesAreSorted) {
+    // if frameIndices is [13, 10, 12, 11]
+    // when sorted, it's  [10, 11, 12, 13] <-- this is the sorted order we want
+    //                                         to use to decode the frames
+    // and argsort is     [ 1,  3,  2,  0]
+    argsort.resize(frameIndices.size());
+    for (size_t i = 0; i < argsort.size(); ++i) {
+      argsort[i] = i;
+    }
+    std::sort(
+        argsort.begin(), argsort.end(), [&frameIndices](size_t a, size_t b) {
+          return frameIndices[a] < frameIndices[b];
+        });
+  }
+
   const auto& streamMetadata = containerMetadata_.streams[streamIndex];
   const auto& stream = streams_[streamIndex];
   const auto& options = stream.options;
   BatchDecodedOutput output(frameIndices.size(), options, streamMetadata);
 
+  auto previousIndexInVideo = -1;
   for (auto f = 0; f < frameIndices.size(); ++f) {
-    auto frameIndex = frameIndices[f];
-    if (frameIndex < 0 || frameIndex >= stream.allFrames.size()) {
+    auto indexInOutput = indicesAreSorted ? f : argsort[f];
+    auto indexInVideo = frameIndices[indexInOutput];
+    if (indexInVideo < 0 || indexInVideo >= stream.allFrames.size()) {
       throw std::runtime_error(
-          "Invalid frame index=" + std::to_string(frameIndex));
+          "Invalid frame index=" + std::to_string(indexInVideo));
     }
-    DecodedOutput singleOut =
-        getFrameAtIndex(streamIndex, frameIndex, output.frames[f]);
-    if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-      output.frames[f] = singleOut.frame;
+    if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
+      // Avoid decoding the same frame twice
+      auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1];
+      output.frames[indexInOutput].copy_(output.frames[previousIndexInOutput]);
+      output.ptsSeconds[indexInOutput] =
+          output.ptsSeconds[previousIndexInOutput];
+      output.durationSeconds[indexInOutput] =
+          output.durationSeconds[previousIndexInOutput];
+    } else {
+      DecodedOutput singleOut = getFrameAtIndex(
+          streamIndex, indexInVideo, output.frames[indexInOutput]);
+      if (options.colorConversionLibrary ==
+          ColorConversionLibrary::FILTERGRAPH) {
+        output.frames[indexInOutput] = singleOut.frame;
+      }
+      output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds;
+      output.durationSeconds[indexInOutput] = singleOut.durationSeconds;
     }
-    // Note that for now we ignore the pts and duration parts of the output,
-    // because they're never used in any caller.
+    previousIndexInVideo = indexInVideo;
   }
   output.frames = MaybePermuteHWC2CHW(options, output.frames);
   return output;
diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
index 70f4afdc2..0be871a3e 100644
--- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -40,7 +40,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
   m.def(
       "get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)");
   m.def(
-      "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> Tensor");
+      "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
   m.def(
       "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
   m.def(
@@ -218,7 +218,7 @@ OpsDecodedOutput get_frame_at_index(
   return makeOpsDecodedOutput(result);
 }
 
-at::Tensor get_frames_at_indices(
+OpsBatchDecodedOutput get_frames_at_indices(
     at::Tensor& decoder,
     int64_t stream_index,
     at::IntArrayRef frame_indices) {
@@ -226,7 +226,7 @@ at::Tensor get_frames_at_indices(
   std::vector<int64_t> frameIndicesVec(
       frame_indices.begin(), frame_indices.end());
   auto result = videoDecoder->getFramesAtIndices(stream_index, frameIndicesVec);
-  return result.frames;
+  return makeOpsBatchDecodedOutput(result);
 }
 
 OpsBatchDecodedOutput get_frames_in_range(
diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h
index 7e9621e91..5b442025d 100644
--- a/src/torchcodec/decoders/_core/VideoDecoderOps.h
+++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h
@@ -87,7 +87,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder);
 
 // Return the frames at a given index for a given stream as a single stacked
 // Tensor.
-at::Tensor get_frames_at_indices(
+OpsBatchDecodedOutput get_frames_at_indices(
     at::Tensor& decoder,
     int64_t stream_index,
     at::IntArrayRef frame_indices);
diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py
index bf170086e..01de6ad67 100644
--- a/src/torchcodec/decoders/_core/video_decoder_ops.py
+++ b/src/torchcodec/decoders/_core/video_decoder_ops.py
@@ -190,9 +190,13 @@ def get_frames_at_indices_abstract(
     *,
     stream_index: int,
     frame_indices: List[int],
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
-    return torch.empty(image_size)
+    return (
+        torch.empty(image_size),
+        torch.empty([], dtype=torch.float),
+        torch.empty([], dtype=torch.float),
+    )
 
 
 @register_fake("torchcodec_ns::get_frames_in_range")
diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py
index cc7b5011b..bbd9fe4e7 100644
--- a/test/decoders/test_video_decoder_ops.py
+++ b/test/decoders/test_video_decoder_ops.py
@@ -116,7 +116,7 @@ def test_get_frames_at_indices(self):
         decoder = create_from_file(str(NASA_VIDEO.path))
         scan_all_streams_to_update_metadata(decoder)
         add_video_stream(decoder)
-        frames0and180 = get_frames_at_indices(
+        frames0and180, *_ = get_frames_at_indices(
             decoder, stream_index=3, frame_indices=[0, 180]
         )
         reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
@@ -124,6 +124,37 @@ def test_get_frames_at_indices(self):
         assert_tensor_equal(frames0and180[0], reference_frame0)
         assert_tensor_equal(frames0and180[1], reference_frame180)
 
+    def test_get_frames_at_indices_unsorted_indices(self):
+        decoder = create_from_file(str(NASA_VIDEO.path))
+        _add_video_stream(decoder)
+        scan_all_streams_to_update_metadata(decoder)
+        stream_index = 3
+
+        frame_indices = [2, 0, 1, 0, 2]
+
+        expected_frames = [
+            get_frame_at_index(
+                decoder, stream_index=stream_index, frame_index=frame_index
+            )[0]
+            for frame_index in frame_indices
+        ]
+
+        frames, *_ = get_frames_at_indices(
+            decoder,
+            stream_index=stream_index,
+            frame_indices=frame_indices,
+        )
+        for frame, expected_frame in zip(frames, expected_frames):
+            assert_tensor_equal(frame, expected_frame)
+
+        # first and last frame should be equal, at index 2. We then modify the
+        # first frame and assert that it's now different from the last frame.
+        # This ensures a copy was properly made during the de-duplication logic.
+        assert_tensor_equal(frames[0], frames[-1])
+        frames[0] += 20
+        with pytest.raises(AssertionError):
+            assert_tensor_equal(frames[0], frames[-1])
+
     def test_get_frames_in_range(self):
         decoder = create_from_file(str(NASA_VIDEO.path))
         scan_all_streams_to_update_metadata(decoder)
@@ -425,7 +456,7 @@ def test_color_conversion_library_with_dimension_order(
         assert frames.shape[1:] == expected_shape
         assert_tensor_equal(frames[0], frame0_ref)
 
-        frames = get_frames_at_indices(
+        frames, *_ = get_frames_at_indices(
            decoder, stream_index=stream_index, frame_indices=[0, 1, 3, 4]
         )
         assert frames.shape[1:] == expected_shape
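
Note: with this change, get_frames_at_indices returns a (frames, pts_seconds,
duration_seconds) tuple instead of a bare tensor, which is why the callers
above now unpack with `frames, *_ = ...`. Below is a minimal usage sketch of
the new return value and of the unsorted/duplicate-index handling; the file
path "video.mp4" and stream_index=3 are placeholder assumptions (the tests use
the NASA_VIDEO fixture, whose best video stream happens to be index 3):

    import torch

    from torchcodec.decoders._core import (
        add_video_stream,
        create_from_file,
        get_frames_at_indices,
        scan_all_streams_to_update_metadata,
    )

    decoder = create_from_file("video.mp4")  # placeholder path
    scan_all_streams_to_update_metadata(decoder)  # required before index-based ops
    add_video_stream(decoder)

    # Unsorted indices with duplicates: the C++ side argsorts them, decodes in
    # ascending index order (avoiding backward seeks), and copies a repeated
    # frame instead of decoding it a second time.
    frames, pts_seconds, duration_seconds = get_frames_at_indices(
        decoder, stream_index=3, frame_indices=[2, 0, 1, 0, 2]
    )
    assert frames.shape[0] == 5
    assert torch.equal(frames[0], frames[-1])  # same source frame, independent copy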