Add sort and dedup logic in C++ with new getFramesDisplayedByTimestamps method / core API #282

Merged (17 commits) on Oct 23, 2024
47 changes: 47 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -1090,6 +1090,53 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
return output;
}

VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps(
@NicolasHug (Member, Author) commented on Oct 23, 2024:

For the function name, I based it on the existing getFramesDisplayedByTimestampsInRange(). I'm happy to bikeshed, but not on this PR please.
Note that our naming is becoming a bit inconsistent, e.g. the Python names, ops names and C++ names are not always aligned. We probably want to clean that up, but that's for later.

We'll also have to rethink our public VideoDecoder method names very soon, because the two new "get_frames_..." methods that we recently added may conflict with existing names.

int streamIndex,
const std::vector<double>& timestamps) {
validateUserProvidedStreamIndex(streamIndex);
A contributor commented:

Maybe add a TODO saying that, long term, we should not require scanning the file for time-based frame extraction.

validateScannedAllStreams("getFramesDisplayedByTimestamps");

// The frame displayed at timestamp t and the one displayed at timestamp `t +
// eps` are probably the same frame, with the same index. The easiest way to
// avoid decoding that unique frame twice is to convert the input timestamps
// to indices, and leverage the de-duplication logic of getFramesAtIndices.
Comment on lines +1096 to +1102
A contributor commented:

The long term thing to do here is to not require a scan for this function.

Scan should only be required for index-based functions where indexes need to be exact.

The index-based function can then call this function if the de-duping is done here and this doesn't need a scan.

That said, you can merge this as-is and do the long-term thing later.

@NicolasHug (Member, Author) replied on Oct 23, 2024:

@ahmadsharif1 do you think it will be possible to correctly de-dup pts-based frame queries without going through indices?
That is currently the main reason we're converting to indices here.

A contributor replied:

This is what I mean by explicit seeking modes (approximate versus exact) being a big conceptual change. :) It might only require changing a few dozen lines of code, but it will require changes in many places, and we'll always have to be aware of which mode is active.

At the moment, we're implicitly exact when we need to do things with indices. I think we should just keep being that way as needed, and when we implement the different modes, we can reason about what changes to make holistically.

A contributor replied:

It should be possible to dedup based on time, but it needs to be done on the fly -- i.e. when you decode a frame, you see the extent of that frame on the timeline and can make copies of it for the timepoints that the user specified. It could be a bit more involved than that, because you don't know the extent of a frame by looking at the frame itself -- you need to read (not decode) the next frame.

Doing it holistically in a future PR sounds good.
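The conversion-to-indices idea discussed above can be sketched in a few lines of Python. This is a hypothetical standalone helper for illustration only (build_decode_plan is not part of the codebase): once timestamps become indices, duplicates collapse and each frame is decoded exactly once.

```python
# Hypothetical sketch of the sort-and-dedup plan that index-based decoding
# relies on: decode each unique frame once, in stream order, then fan the
# decoded frames back out to the originally requested positions.
def build_decode_plan(requested):
    decode_order = sorted(set(requested))          # unique indices, stream order
    slot_of = {frame: s for s, frame in enumerate(decode_order)}
    output_slot = [slot_of[i] for i in requested]  # where each request reads from
    return decode_order, output_slot

# build_decode_plan([2, 0, 1, 0, 5, 2]) -> ([0, 1, 2, 5], [2, 0, 1, 0, 3, 2])
```

Requests 0 and 3 (both index 0) and requests 0 and 5 (both index 2) share a decode slot, which is exactly why the outputs must be copies rather than aliases.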

// This means this function requires a scan.
// TODO: longer term, we should implement this without requiring a scan

const auto& streamMetadata = containerMetadata_.streams[streamIndex];
const auto& stream = streams_[streamIndex];
double minSeconds = streamMetadata.minPtsSecondsFromScan.value();
double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value();

std::vector<int64_t> frameIndices(timestamps.size());
for (size_t i = 0; i < timestamps.size(); ++i) {
auto framePts = timestamps[i];
TORCH_CHECK(
framePts >= minSeconds && framePts < maxSeconds,
"frame pts is " + std::to_string(framePts) + "; must be in range [" +
std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) +
").");

auto it = std::lower_bound(
stream.allFrames.begin(),
stream.allFrames.end(),
framePts,
[&stream](const FrameInfo& info, double framePts) {
return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts;
});
int64_t frameIndex = it - stream.allFrames.begin();
// If the frame index is larger than the size of allFrames, that means we
// couldn't match the pts value to the pts value of a NEXT FRAME. And
// that means that this timestamp falls during the time between when the
// last frame is displayed, and the video ends. Hence, it should map to the
// index of the last frame.
frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1);
frameIndices[i] = frameIndex;
}

return getFramesAtIndices(streamIndex, frameIndices);
}
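The std::lower_bound mapping above can be mirrored with Python's bisect_right. This is a standalone sketch under the same assumption as the C++ code (next_pts[i] holds the pts at which frame i stops being displayed); it omits the min/max range check that the real function performs:

```python
import bisect

# Frame i is displayed during [pts[i], next_pts[i]); the frame displayed at
# time t is therefore the first one whose next_pts is strictly greater than t.
def pts_to_frame_index(next_pts, t):
    index = bisect.bisect_right(next_pts, t)
    # Timestamps falling between the last frame's pts and the end of the
    # video map to the last frame.
    return min(index, len(next_pts) - 1)
```

With next_pts = [1.0, 2.0, 3.0], queries 0.5 and 0.999 both map to frame 0, a query at exactly 1.0 maps to frame 1, and 5.0 clamps to the last frame (index 2).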

VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
int streamIndex,
int64_t start,
6 changes: 6 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.h
@@ -223,6 +223,7 @@ class VideoDecoder {
// i.e. it will be returned when this function is called with seconds=5.0 or
// seconds=5.999, etc.
DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds);

DecodedOutput getFrameAtIndex(
int streamIndex,
int64_t frameIndex,
@@ -242,6 +243,11 @@
BatchDecodedOutput getFramesAtIndices(
int streamIndex,
const std::vector<int64_t>& frameIndices);

BatchDecodedOutput getFramesDisplayedByTimestamps(
int streamIndex,
const std::vector<double>& timestamps);

// Returns frames within a given range for a given stream as a single stacked
// Tensor. The range is defined by [start, stop). The values retrieved from
// the range are:
13 changes: 13 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -45,6 +45,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)");
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
m.def(
@@ -240,6 +242,16 @@ OpsBatchDecodedOutput get_frames_in_range(
stream_index, start, stop, step.value_or(1));
return makeOpsBatchDecodedOutput(result);
}
OpsBatchDecodedOutput get_frames_by_pts(
at::Tensor& decoder,
int64_t stream_index,
at::ArrayRef<double> timestamps) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
std::vector<double> timestampsVec(timestamps.begin(), timestamps.end());
auto result =
videoDecoder->getFramesDisplayedByTimestamps(stream_index, timestampsVec);
return makeOpsBatchDecodedOutput(result);
}

OpsBatchDecodedOutput get_frames_by_pts_in_range(
at::Tensor& decoder,
@@ -485,6 +497,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
m.impl("get_frames_at_indices", &get_frames_at_indices);
m.impl("get_frames_in_range", &get_frames_in_range);
m.impl("get_frames_by_pts_in_range", &get_frames_by_pts_in_range);
m.impl("get_frames_by_pts", &get_frames_by_pts);
m.impl("_test_frame_pts_equality", &_test_frame_pts_equality);
m.impl(
"scan_all_streams_to_update_metadata",
9 changes: 7 additions & 2 deletions src/torchcodec/decoders/_core/VideoDecoderOps.h
@@ -75,6 +75,12 @@ using OpsBatchDecodedOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
// given timestamp T has T >= PTS and T < PTS + Duration.
OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds);

// Return the frames displayed at the given pts values for a given stream
OpsBatchDecodedOutput get_frames_by_pts(
at::Tensor& decoder,
int64_t stream_index,
at::ArrayRef<double> timestamps);

// Return the frame that is visible at a given index in the video.
OpsDecodedOutput get_frame_at_index(
at::Tensor& decoder,
@@ -85,8 +91,7 @@ OpsDecodedOutput get_frame_at_index(
// duration as tensors.
OpsDecodedOutput get_next_frame(at::Tensor& decoder);

// Return the frames at a given index for a given stream as a single stacked
// Tensor.
// Return the frames at given indices for a given stream
OpsBatchDecodedOutput get_frames_at_indices(
at::Tensor& decoder,
int64_t stream_index,
1 change: 1 addition & 0 deletions src/torchcodec/decoders/_core/__init__.py
@@ -22,6 +22,7 @@
get_frame_at_index,
get_frame_at_pts,
get_frames_at_indices,
get_frames_by_pts,
get_frames_by_pts_in_range,
get_frames_in_range,
get_json_metadata,
16 changes: 16 additions & 0 deletions src/torchcodec/decoders/_core/video_decoder_ops.py
@@ -71,6 +71,7 @@ def load_torchcodec_extension():
get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default
get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default
get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default
get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default
get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default
@@ -172,6 +173,21 @@ def get_frame_at_pts_abstract(
)


@register_fake("torchcodec_ns::get_frames_by_pts")
def get_frames_by_pts_abstract(
decoder: torch.Tensor,
*,
stream_index: int,
timestamps: List[float],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
return (
torch.empty(image_size),
torch.empty([], dtype=torch.float),
torch.empty([], dtype=torch.float),
)


@register_fake("torchcodec_ns::get_frame_at_index")
def get_frame_at_index_abstract(
decoder: torch.Tensor, *, stream_index: int, frame_index: int
31 changes: 31 additions & 0 deletions test/decoders/test_video_decoder_ops.py
@@ -27,6 +27,7 @@
get_frame_at_index,
get_frame_at_pts,
get_frames_at_indices,
get_frames_by_pts,
get_frames_by_pts_in_range,
get_frames_in_range,
get_json_metadata,
@@ -155,6 +156,36 @@ def test_get_frames_at_indices_unsorted_indices(self):
with pytest.raises(AssertionError):
assert_tensor_equal(frames[0], frames[-1])

def test_get_frames_by_pts(self):
decoder = create_from_file(str(NASA_VIDEO.path))
_add_video_stream(decoder)
scan_all_streams_to_update_metadata(decoder)
stream_index = 3

# Note: 13.01 should give the last video frame for the NASA video
timestamps = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3]

expected_frames = [
get_frame_at_pts(decoder, seconds=pts)[0] for pts in timestamps
]

frames, *_ = get_frames_by_pts(
decoder,
stream_index=stream_index,
timestamps=timestamps,
)
for frame, expected_frame in zip(frames, expected_frames):
assert_tensor_equal(frame, expected_frame)

# first and last frame should be equal, at pts=2 [+ eps]. We then modify
# the first frame and assert that it's now different from the last
# frame. This ensures a copy was properly made during the de-duplication
# logic.
assert_tensor_equal(frames[0], frames[-1])
frames[0] += 20
with pytest.raises(AssertionError):
assert_tensor_equal(frames[0], frames[-1])
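The copy semantics that this test checks can be illustrated with a toy sketch. Nothing here is torchcodec API; plain lists stand in for decoded frame tensors:

```python
# Toy illustration of why de-duplicated outputs must be copies: two requests
# that map to the same decoded frame must not alias the same buffer.
def copied_outputs(decoded, output_slot):
    # One independent copy per output slot.
    return [list(decoded[s]) for s in output_slot]

decoded = [[10], [20], [30]]   # one pretend "frame" per unique index
frames = copied_outputs(decoded, output_slot=[2, 0, 2])

frames[0][0] += 20             # mutate the first output...
assert frames[0] != frames[2]  # ...its duplicate at slot 2 is unaffected
```

If copied_outputs returned decoded[s] directly, the mutation would leak into every output slot sharing that frame, which is exactly the aliasing bug the test above guards against.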

def test_get_frames_in_range(self):
decoder = create_from_file(str(NASA_VIDEO.path))
scan_all_streams_to_update_metadata(decoder)