
Commit c8de21c

Add sort and dedup logic in C++ to getFramesAtIndices (#280)
1 parent f72e39a commit c8de21c

7 files changed: +88 -20 lines changed

benchmarks/decoders/benchmark_decoders.py

Lines changed: 2 additions & 2 deletions
@@ -209,7 +209,7 @@ def get_frames_from_video(self, video_file, pts_list):
         best_video_stream = metadata["bestVideoStreamIndex"]
         indices_list = [int(pts * average_fps) for pts in pts_list]
         frames = []
-        frames = get_frames_at_indices(
+        frames, *_ = get_frames_at_indices(
             decoder, stream_index=best_video_stream, frame_indices=indices_list
         )
         return frames
@@ -226,7 +226,7 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
         best_video_stream = metadata["bestVideoStreamIndex"]
         frames = []
         indices_list = list(range(numFramesToDecode))
-        frames = get_frames_at_indices(
+        frames, *_ = get_frames_at_indices(
            decoder, stream_index=best_video_stream, frame_indices=indices_list
         )
         return frames
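
With this change, get_frames_at_indices returns a (frames, pts_seconds, duration_seconds) tuple instead of a bare frames tensor, which is why the call sites above now unpack with `frames, *_ = ...`. Below is a minimal sketch of the new calling convention, not code from this commit: it assumes the core ops are importable from torchcodec.decoders._core (as the module layout in this diff suggests), and the video path and stream index are placeholders.

```python
# Sketch only: illustrates the tuple now returned by get_frames_at_indices.
from torchcodec.decoders._core import (
    add_video_stream,
    create_from_file,
    get_frames_at_indices,
    scan_all_streams_to_update_metadata,
)

decoder = create_from_file("my_video.mp4")  # hypothetical path
scan_all_streams_to_update_metadata(decoder)
add_video_stream(decoder)

# Indices may be unsorted and contain duplicates; decoding happens in sorted
# order internally, but results come back in the order requested here.
frames, pts_seconds, duration_seconds = get_frames_at_indices(
    decoder, stream_index=0, frame_indices=[2, 0, 1, 0, 2]
)  # stream_index=0 is a placeholder; real code should use the best video stream
print(frames.shape, pts_seconds, duration_seconds)
```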

src/torchcodec/_samplers/video_clip_sampler.py

Lines changed: 1 addition & 1 deletion
@@ -240,7 +240,7 @@ def _get_clips_for_index_based_sampling(
                 clip_start_idx + i * index_based_sampler_args.video_frame_dilation
                 for i in range(index_based_sampler_args.frames_per_clip)
             ]
-            frames = get_frames_at_indices(
+            frames, *_ = get_frames_at_indices(
                 video_decoder,
                 stream_index=metadata_json["bestVideoStreamIndex"],
                 frame_indices=batch_indexes,

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 42 additions & 9 deletions
@@ -1034,24 +1034,57 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
   validateUserProvidedStreamIndex(streamIndex);
   validateScannedAllStreams("getFramesAtIndices");
 
+  auto indicesAreSorted =
+      std::is_sorted(frameIndices.begin(), frameIndices.end());
+
+  std::vector<size_t> argsort;
+  if (!indicesAreSorted) {
+    // if frameIndices is [13, 10, 12, 11]
+    // when sorted, it's  [10, 11, 12, 13] <-- this is the sorted order we want
+    //                                         to use to decode the frames
+    // and argsort is     [ 1,  3,  2,  0]
+    argsort.resize(frameIndices.size());
+    for (size_t i = 0; i < argsort.size(); ++i) {
+      argsort[i] = i;
+    }
+    std::sort(
+        argsort.begin(), argsort.end(), [&frameIndices](size_t a, size_t b) {
+          return frameIndices[a] < frameIndices[b];
+        });
+  }
+
   const auto& streamMetadata = containerMetadata_.streams[streamIndex];
   const auto& stream = streams_[streamIndex];
   const auto& options = stream.options;
   BatchDecodedOutput output(frameIndices.size(), options, streamMetadata);
 
+  auto previousIndexInVideo = -1;
   for (auto f = 0; f < frameIndices.size(); ++f) {
-    auto frameIndex = frameIndices[f];
-    if (frameIndex < 0 || frameIndex >= stream.allFrames.size()) {
+    auto indexInOutput = indicesAreSorted ? f : argsort[f];
+    auto indexInVideo = frameIndices[indexInOutput];
+    if (indexInVideo < 0 || indexInVideo >= stream.allFrames.size()) {
       throw std::runtime_error(
-          "Invalid frame index=" + std::to_string(frameIndex));
+          "Invalid frame index=" + std::to_string(indexInVideo));
     }
-    DecodedOutput singleOut =
-        getFrameAtIndex(streamIndex, frameIndex, output.frames[f]);
-    if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-      output.frames[f] = singleOut.frame;
+    if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
+      // Avoid decoding the same frame twice
+      auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1];
+      output.frames[indexInOutput].copy_(output.frames[previousIndexInOutput]);
+      output.ptsSeconds[indexInOutput] =
+          output.ptsSeconds[previousIndexInOutput];
+      output.durationSeconds[indexInOutput] =
+          output.durationSeconds[previousIndexInOutput];
+    } else {
+      DecodedOutput singleOut = getFrameAtIndex(
+          streamIndex, indexInVideo, output.frames[indexInOutput]);
+      if (options.colorConversionLibrary ==
+          ColorConversionLibrary::FILTERGRAPH) {
+        output.frames[indexInOutput] = singleOut.frame;
+      }
+      output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds;
+      output.durationSeconds[indexInOutput] = singleOut.durationSeconds;
     }
-    // Note that for now we ignore the pts and duration parts of the output,
-    // because they're never used in any caller.
+    previousIndexInVideo = indexInVideo;
   }
   output.frames = MaybePermuteHWC2CHW(options, output.frames);
   return output;
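
The bookkeeping above is independent of the decoder plumbing: argsort the requested indices, decode in ascending frame order so the decoder only ever moves forward, de-duplicate repeated indices by copying the previously decoded frame, and write each result back to its originally requested output position. Below is a rough Python rendering of that logic; `decode_frame` and `frames_at_indices` are hypothetical stand-ins for illustration, not part of torchcodec.

```python
# Sketch of the sort + dedup bookkeeping added above, written in Python for
# clarity. The real implementation works on BatchDecodedOutput tensors in C++.
from typing import Callable, List

import torch


def frames_at_indices(
    frame_indices: List[int],
    decode_frame: Callable[[int], torch.Tensor],
) -> List[torch.Tensor]:
    indices_are_sorted = all(
        a <= b for a, b in zip(frame_indices, frame_indices[1:])
    )
    # argsort holds output positions in ascending order of requested index,
    # e.g. frame_indices = [13, 10, 12, 11]  ->  argsort = [1, 3, 2, 0].
    argsort = (
        list(range(len(frame_indices)))
        if indices_are_sorted
        else sorted(range(len(frame_indices)), key=lambda i: frame_indices[i])
    )

    output: List[torch.Tensor] = [torch.empty(0)] * len(frame_indices)
    previous_index_in_video = -1
    for f, index_in_output in enumerate(argsort):
        index_in_video = frame_indices[index_in_output]
        if f > 0 and index_in_video == previous_index_in_video:
            # Duplicate request: copy the frame decoded on the previous step
            # instead of decoding it again (mirrors output.frames[...].copy_()).
            output[index_in_output] = output[argsort[f - 1]].clone()
        else:
            output[index_in_output] = decode_frame(index_in_video)
        previous_index_in_video = index_in_video
    return output


# Tiny demo with a fake decoder whose "frame" is just the index as a tensor.
def fake_decode(i: int) -> torch.Tensor:
    return torch.tensor([float(i)])


frames = frames_at_indices([2, 0, 1, 0, 2], fake_decode)
assert [int(t.item()) for t in frames] == [2, 0, 1, 0, 2]
```

Decoding in ascending order presumably avoids backward seeks in the underlying stream, and the duplicate check means each distinct frame index is decoded at most once.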

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 3 additions & 3 deletions
@@ -40,7 +40,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
   m.def(
       "get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)");
   m.def(
-      "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> Tensor");
+      "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
   m.def(
       "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
   m.def(
@@ -218,15 +218,15 @@ OpsDecodedOutput get_frame_at_index(
   return makeOpsDecodedOutput(result);
 }
 
-at::Tensor get_frames_at_indices(
+OpsBatchDecodedOutput get_frames_at_indices(
     at::Tensor& decoder,
     int64_t stream_index,
     at::IntArrayRef frame_indices) {
   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
   std::vector<int64_t> frameIndicesVec(
       frame_indices.begin(), frame_indices.end());
   auto result = videoDecoder->getFramesAtIndices(stream_index, frameIndicesVec);
-  return result.frames;
+  return makeOpsBatchDecodedOutput(result);
 }
 
 OpsBatchDecodedOutput get_frames_in_range(

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder);
 
 // Return the frames at a given index for a given stream as a single stacked
 // Tensor.
-at::Tensor get_frames_at_indices(
+OpsBatchDecodedOutput get_frames_at_indices(
     at::Tensor& decoder,
     int64_t stream_index,
     at::IntArrayRef frame_indices);

src/torchcodec/decoders/_core/video_decoder_ops.py

Lines changed: 6 additions & 2 deletions
@@ -190,9 +190,13 @@ def get_frames_at_indices_abstract(
     *,
     stream_index: int,
     frame_indices: List[int],
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
-    return torch.empty(image_size)
+    return (
+        torch.empty(image_size),
+        torch.empty([], dtype=torch.float),
+        torch.empty([], dtype=torch.float),
+    )
 
 
 @register_fake("torchcodec_ns::get_frames_in_range")

test/decoders/test_video_decoder_ops.py

Lines changed: 33 additions & 2 deletions
@@ -116,14 +116,45 @@ def test_get_frames_at_indices(self):
         decoder = create_from_file(str(NASA_VIDEO.path))
         scan_all_streams_to_update_metadata(decoder)
         add_video_stream(decoder)
-        frames0and180 = get_frames_at_indices(
+        frames0and180, *_ = get_frames_at_indices(
             decoder, stream_index=3, frame_indices=[0, 180]
         )
         reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
         reference_frame180 = NASA_VIDEO.get_frame_by_name("time6.000000")
         assert_tensor_equal(frames0and180[0], reference_frame0)
         assert_tensor_equal(frames0and180[1], reference_frame180)
 
+    def test_get_frames_at_indices_unsorted_indices(self):
+        decoder = create_from_file(str(NASA_VIDEO.path))
+        _add_video_stream(decoder)
+        scan_all_streams_to_update_metadata(decoder)
+        stream_index = 3
+
+        frame_indices = [2, 0, 1, 0, 2]
+
+        expected_frames = [
+            get_frame_at_index(
+                decoder, stream_index=stream_index, frame_index=frame_index
+            )[0]
+            for frame_index in frame_indices
+        ]
+
+        frames, *_ = get_frames_at_indices(
+            decoder,
+            stream_index=stream_index,
+            frame_indices=frame_indices,
+        )
+        for frame, expected_frame in zip(frames, expected_frames):
+            assert_tensor_equal(frame, expected_frame)
+
+        # first and last frame should be equal, at index 2. We then modify the
+        # first frame and assert that it's now different from the last frame.
+        # This ensures a copy was properly made during the de-duplication logic.
+        assert_tensor_equal(frames[0], frames[-1])
+        frames[0] += 20
+        with pytest.raises(AssertionError):
+            assert_tensor_equal(frames[0], frames[-1])
+
     def test_get_frames_in_range(self):
         decoder = create_from_file(str(NASA_VIDEO.path))
         scan_all_streams_to_update_metadata(decoder)
@@ -425,7 +456,7 @@ def test_color_conversion_library_with_dimension_order(
         assert frames.shape[1:] == expected_shape
         assert_tensor_equal(frames[0], frame0_ref)
 
-        frames = get_frames_at_indices(
+        frames, *_ = get_frames_at_indices(
             decoder, stream_index=stream_index, frame_indices=[0, 1, 3, 4]
         )
         assert frames.shape[1:] == expected_shape
