Merged
4 changes: 2 additions & 2 deletions benchmarks/decoders/benchmark_decoders.py
@@ -209,7 +209,7 @@ def get_frames_from_video(self, video_file, pts_list):
best_video_stream = metadata["bestVideoStreamIndex"]
indices_list = [int(pts * average_fps) for pts in pts_list]
frames = []
frames = get_frames_at_indices(
frames, *_ = get_frames_at_indices(
decoder, stream_index=best_video_stream, frame_indices=indices_list
)
return frames
@@ -226,7 +226,7 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
best_video_stream = metadata["bestVideoStreamIndex"]
frames = []
indices_list = list(range(numFramesToDecode))
frames = get_frames_at_indices(
frames, *_ = get_frames_at_indices(
decoder, stream_index=best_video_stream, frame_indices=indices_list
)
return frames
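The call-site changes in this file follow from the op's new schema: `get_frames_at_indices` now returns a `(frames, pts_seconds, duration_seconds)` tuple rather than a bare frames tensor, so callers that only need the frames unpack with `frames, *_ = ...`. A minimal sketch of the pattern (the stub below is illustrative, not the real API):

```python
def get_frames_at_indices_stub(decoder, *, stream_index, frame_indices):
    # Stand-in for the real op: returns (frames, pts_seconds, duration_seconds).
    frames = [f"frame-{i}" for i in frame_indices]
    pts_seconds = [i / 30.0 for i in frame_indices]
    duration_seconds = [1 / 30.0] * len(frame_indices)
    return frames, pts_seconds, duration_seconds

# Callers that only care about the frames discard the rest of the tuple:
frames, *_ = get_frames_at_indices_stub(None, stream_index=3, frame_indices=[0, 180])
```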
2 changes: 1 addition & 1 deletion src/torchcodec/_samplers/video_clip_sampler.py
@@ -240,7 +240,7 @@ def _get_clips_for_index_based_sampling(
clip_start_idx + i * index_based_sampler_args.video_frame_dilation
for i in range(index_based_sampler_args.frames_per_clip)
]
frames = get_frames_at_indices(
frames, *_ = get_frames_at_indices(
video_decoder,
stream_index=metadata_json["bestVideoStreamIndex"],
frame_indices=batch_indexes,
51 changes: 42 additions & 9 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -1034,24 +1034,57 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
validateUserProvidedStreamIndex(streamIndex);
validateScannedAllStreams("getFramesAtIndices");

auto indicesAreSorted =
std::is_sorted(frameIndices.begin(), frameIndices.end());

std::vector<size_t> argsort;
if (!indicesAreSorted) {
Contributor: Digging a bit, I think we're probably better off not checking whether the sequence is already sorted, and just always sorting. Modern implementations of std::sort seem to be Introsort, which was designed to be nearly linear on an already sorted sequence. I'm also fine if we commit this as is; we can always investigate more later if it becomes important.

Contributor: Nit: I like to keep object definitions as close as possible to their first use, so I'd prefer to see lines 1037-1040 appear after the sorting, right before we use output in the loop. (I just ran into this while trying to find the definition of output when reading the loop.)

// if frameIndices is [13, 10, 12, 11]
// when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
// to use to decode the frames
// and argsort is [ 1, 3, 2, 0]
argsort.resize(frameIndices.size());
for (size_t i = 0; i < argsort.size(); ++i) {
argsort[i] = i;
}
std::sort(
argsort.begin(), argsort.end(), [&frameIndices](size_t a, size_t b) {
return frameIndices[a] < frameIndices[b];
});
}

const auto& streamMetadata = containerMetadata_.streams[streamIndex];
const auto& stream = streams_[streamIndex];
const auto& options = stream.options;
BatchDecodedOutput output(frameIndices.size(), options, streamMetadata);

auto previousIndexInVideo = -1;
for (auto f = 0; f < frameIndices.size(); ++f) {
auto frameIndex = frameIndices[f];
if (frameIndex < 0 || frameIndex >= stream.allFrames.size()) {
auto indexInOutput = indicesAreSorted ? f : argsort[f];
auto indexInVideo = frameIndices[indexInOutput];
if (indexInVideo < 0 || indexInVideo >= stream.allFrames.size()) {
throw std::runtime_error(
"Invalid frame index=" + std::to_string(frameIndex));
"Invalid frame index=" + std::to_string(indexInVideo));
}
DecodedOutput singleOut =
getFrameAtIndex(streamIndex, frameIndex, output.frames[f]);
if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
output.frames[f] = singleOut.frame;
if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
// Avoid decoding the same frame twice
auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1];
output.frames[indexInOutput].copy_(output.frames[previousIndexInOutput]);
output.ptsSeconds[indexInOutput] =
output.ptsSeconds[previousIndexInOutput];
output.durationSeconds[indexInOutput] =
output.durationSeconds[previousIndexInOutput];
} else {
DecodedOutput singleOut = getFrameAtIndex(
streamIndex, indexInVideo, output.frames[indexInOutput]);
if (options.colorConversionLibrary ==
ColorConversionLibrary::FILTERGRAPH) {
output.frames[indexInOutput] = singleOut.frame;
}
Comment on lines +1080 to +1083

Contributor: Why is this needed when we are passing output.frames[indexInOutput] in getFrameAtIndex?

Contributor (author): It's because the pre-allocated buffer isn't used with filtergraph, only with swscale. (Just a note that this is not something introduced in this PR; you'll see the same pattern in other callers.)

Contributor: This seems like the wrong behavior. When a user passes in a pre-allocated tensor, it should work with either color conversion library. I guess this behavior was introduced by the PR that added pre-allocated tensors; that PR should have done it for either color conversion library. Maybe make that change first before merging this PR? That would be my vote, because otherwise the caller has to think about which color conversion library was used.

Contributor (author): I agree with you that the current behavior is potentially surprising. I am open to re-working this eventually, but I want us to acknowledge that #266 (and #277) significantly improve the existing code-base. They're not perfect, but they're clear improvements.
output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds;
output.durationSeconds[indexInOutput] = singleOut.durationSeconds;
}
// Note that for now we ignore the pts and duration parts of the output,
// because they're never used in any caller.
previousIndexInVideo = indexInVideo;
}
output.frames = MaybePermuteHWC2CHW(options, output.frames);
return output;
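The sorting and de-duplication logic above boils down to: build an argsort of the requested indices, decode in ascending index order (cheaper for a video decoder, since it avoids seeking backwards), write each result into its original output slot, and copy the previous result when an index repeats. A rough Python sketch of that control flow, where decode_one stands in for getFrameAtIndex (an illustration of the approach, not the C++ API):

```python
def frames_at_indices(decode_one, frame_indices):
    # argsort: output positions ordered by ascending frame index,
    # e.g. frame_indices [13, 10, 12, 11] -> argsort [1, 3, 2, 0]
    argsort = sorted(range(len(frame_indices)), key=lambda i: frame_indices[i])

    out = [None] * len(frame_indices)
    prev_index = prev_pos = None
    for pos in argsort:
        idx = frame_indices[pos]
        if idx == prev_index:
            # Same frame requested again: copy the previous result
            # instead of decoding the frame a second time.
            out[pos] = list(out[prev_pos])
        else:
            out[pos] = decode_one(idx)
        prev_index, prev_pos = idx, pos
    return out
```

Decoding in sorted order means each distinct frame is decoded exactly once, while results still land at the positions the caller asked for.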
6 changes: 3 additions & 3 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -40,7 +40,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> Tensor");
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
m.def(
@@ -218,15 +218,15 @@ OpsDecodedOutput get_frame_at_index(
return makeOpsDecodedOutput(result);
}

at::Tensor get_frames_at_indices(
OpsBatchDecodedOutput get_frames_at_indices(
at::Tensor& decoder,
int64_t stream_index,
at::IntArrayRef frame_indices) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
std::vector<int64_t> frameIndicesVec(
frame_indices.begin(), frame_indices.end());
auto result = videoDecoder->getFramesAtIndices(stream_index, frameIndicesVec);
return result.frames;
return makeOpsBatchDecodedOutput(result);
}

OpsBatchDecodedOutput get_frames_in_range(
2 changes: 1 addition & 1 deletion src/torchcodec/decoders/_core/VideoDecoderOps.h
@@ -87,7 +87,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder);

// Return the frames at a given index for a given stream as a single stacked
// Tensor.
at::Tensor get_frames_at_indices(
OpsBatchDecodedOutput get_frames_at_indices(
at::Tensor& decoder,
int64_t stream_index,
at::IntArrayRef frame_indices);
8 changes: 6 additions & 2 deletions src/torchcodec/decoders/_core/video_decoder_ops.py
@@ -190,9 +190,13 @@ def get_frames_at_indices_abstract(
*,
stream_index: int,
frame_indices: List[int],
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
return torch.empty(image_size)
return (
torch.empty(image_size),
torch.empty([], dtype=torch.float),
torch.empty([], dtype=torch.float),
)


@register_fake("torchcodec_ns::get_frames_in_range")
35 changes: 33 additions & 2 deletions test/decoders/test_video_decoder_ops.py
@@ -116,14 +116,45 @@ def test_get_frames_at_indices(self):
decoder = create_from_file(str(NASA_VIDEO.path))
scan_all_streams_to_update_metadata(decoder)
add_video_stream(decoder)
frames0and180 = get_frames_at_indices(
frames0and180, *_ = get_frames_at_indices(
decoder, stream_index=3, frame_indices=[0, 180]
)
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
reference_frame180 = NASA_VIDEO.get_frame_by_name("time6.000000")
assert_tensor_equal(frames0and180[0], reference_frame0)
assert_tensor_equal(frames0and180[1], reference_frame180)

def test_get_frames_at_indices_unsorted_indices(self):
decoder = create_from_file(str(NASA_VIDEO.path))
_add_video_stream(decoder)
scan_all_streams_to_update_metadata(decoder)
stream_index = 3

frame_indices = [2, 0, 1, 0, 2]

expected_frames = [
get_frame_at_index(
decoder, stream_index=stream_index, frame_index=frame_index
)[0]
for frame_index in frame_indices
]

frames, *_ = get_frames_at_indices(
decoder,
stream_index=stream_index,
frame_indices=frame_indices,
)
for frame, expected_frame in zip(frames, expected_frames):
assert_tensor_equal(frame, expected_frame)

# first and last frame should be equal, at index 2. We then modify the
# first frame and assert that it's now different from the last frame.
# This ensures a copy was properly made during the de-duplication logic.
assert_tensor_equal(frames[0], frames[-1])
frames[0] += 20
with pytest.raises(AssertionError):
assert_tensor_equal(frames[0], frames[-1])
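The mutation check at the end of this test rests on a general aliasing property: if de-duplication handed back the same underlying buffer instead of a copy (the copy_ in the C++ above), mutating one entry would silently mutate the other. A small standalone illustration of the difference:

```python
shared = [1, 2, 3]
aliased = [shared, shared]         # both entries reference one buffer
copied = [shared, list(shared)]    # second entry is an independent copy

shared[0] += 20  # mutate through the original reference

assert aliased[0] == aliased[1]    # alias: the change shows up in both entries
assert copied[0] != copied[1]      # copy: the entries now differ
```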

def test_get_frames_in_range(self):
decoder = create_from_file(str(NASA_VIDEO.path))
scan_all_streams_to_update_metadata(decoder)
@@ -425,7 +456,7 @@ def test_color_conversion_library_with_dimension_order(
assert frames.shape[1:] == expected_shape
assert_tensor_equal(frames[0], frame0_ref)

frames = get_frames_at_indices(
frames, *_ = get_frames_at_indices(
decoder, stream_index=stream_index, frame_indices=[0, 1, 3, 4]
)
assert frames.shape[1:] == expected_shape