
Pass pre-allocated tensors in batch APIs to avoid copies #266


Merged
merged 17 commits on Oct 18, 2024
47 changes: 32 additions & 15 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -847,7 +847,8 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
}

VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
VideoDecoder::RawDecodedOutput& rawOutput) {
VideoDecoder::RawDecodedOutput& rawOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
// Convert the frame to tensor.
DecodedOutput output;
int streamIndex = rawOutput.streamIndex;
@@ -862,8 +863,10 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
output.durationSeconds = ptsToSeconds(
getDuration(frame), formatContext_->streams[streamIndex]->time_base);
if (streamInfo.options.device.type() == torch::kCPU) {
convertAVFrameToDecodedOutputOnCPU(rawOutput, output);
convertAVFrameToDecodedOutputOnCPU(
rawOutput, output, preAllocatedOutputTensor);
} else if (streamInfo.options.device.type() == torch::kCUDA) {
// TODO: handle pre-allocated output tensor
Member Author:

This PR is a no-op for CUDA devices. I'm leaving out CUDA pre-allocation because it is strongly tied to #189 and can be treated separately.

convertAVFrameToDecodedOutputOnCuda(
streamInfo.options.device,
streamInfo.options,
@@ -879,16 +882,24 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(

void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
VideoDecoder::RawDecodedOutput& rawOutput,
DecodedOutput& output) {
DecodedOutput& output,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
int streamIndex = rawOutput.streamIndex;
AVFrame* frame = rawOutput.frame.get();
auto& streamInfo = streams_[streamIndex];
if (output.streamType == AVMEDIA_TYPE_VIDEO) {
if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
int width = streamInfo.options.width.value_or(frame->width);
int height = streamInfo.options.height.value_or(frame->height);
torch::Tensor tensor = torch::empty(
{height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
torch::Tensor tensor;
if (preAllocatedOutputTensor.has_value()) {
// TODO: check shape of preAllocatedOutputTensor?
Contributor:

I think we should TORCH_CHECK for height, width, shape, etc. here.
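
For illustration, a minimal sketch of what such a check could look like (hypothetical code, not part of the PR; it assumes the caller-provided tensor is expected in HWC layout, which, as discussed below, is not always the case):

// Hypothetical sketch: validate the caller-provided tensor before writing into it.
// Assumes the expected single-frame layout is HWC.
if (preAllocatedOutputTensor.has_value()) {
  int expectedWidth = streamInfo.options.width.value_or(frame->width);
  int expectedHeight = streamInfo.options.height.value_or(frame->height);
  auto shape = preAllocatedOutputTensor.value().sizes();
  TORCH_CHECK(
      shape.size() == 3 && shape[0] == expectedHeight &&
          shape[1] == expectedWidth && shape[2] == 3,
      "Expected pre-allocated tensor of shape ",
      expectedHeight, "x", expectedWidth, "x3, got ", shape);
}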

Member Author:

I gave this a try, thinking it would be a simple assert like

assert `shape[-3:] == (H, W, 3)`

But it turns out it's not as simple. Some tensors come as HWC while others come as CHW. This is because the pre-allocated batched tensors are allocated as such:

if (options.dimensionOrder == "NHWC") {
frames = torch::empty(
{numFrames,
options.height.value_or(*metadata.height),
options.width.value_or(*metadata.width),
3},
{torch::kUInt8});
} else if (options.dimensionOrder == "NCHW") {
frames = torch::empty(
{numFrames,
3,
options.height.value_or(*metadata.height),
options.width.value_or(*metadata.width)},
torch::TensorOptions()
.memory_format(torch::MemoryFormat::ChannelsLast)
.dtype({torch::kUInt8}));

It then made me realize that everything works, but it's pretty magical. We end up doing the .permute() calls in different places, but I think it would be a lot cleaner if we allocated the batched output only as NHWC, and then permuted the entire NHWC tensor in one go. What we do right now is permute all N of the HWC tensors individually, and that's probably not as efficient (or as clean).

I want to fix this as an immediate follow-up if that's OK. I gave it a try here, but it's not trivial and it might be preferable not to overcomplicate this PR.
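
To illustrate that follow-up, a rough sketch (not part of this PR; it assumes the batch can always be allocated as NHWC and converted with a single permute on the way out, reusing the names from the snippet above):

// Hypothetical sketch: always allocate the batched output as NHWC and decode
// each frame directly into its HWC slice.
frames = torch::empty(
    {numFrames,
     options.height.value_or(*metadata.height),
     options.width.value_or(*metadata.width),
     3},
    torch::TensorOptions().dtype(torch::kUInt8));
// ... decode frame f into frames[f] ...
// Then convert the whole batch in one go if the caller asked for NCHW.
if (options.dimensionOrder == "NCHW") {
  frames = frames.permute({0, 3, 1, 2});
}

A single permute on the NHWC batch returns a view, so the per-frame permutes inside the decode loop would no longer be needed.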

Contributor:

I'm in favor of @NicolasHug's suggestion. The logic he points out is legacy from way back when, and it wasn't necessarily thought through in terms of long-term maintenance and code health. Always doing it one way, and then permuting as needed on the way out, sounds easier and cleaner.

Contributor:

Sounds good to me

tensor = preAllocatedOutputTensor.value();
} else {
int width = streamInfo.options.width.value_or(frame->width);
int height = streamInfo.options.height.value_or(frame->height);
tensor = torch::empty(
{height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
}

rawOutput.data = tensor.data_ptr<uint8_t>();
convertFrameToBufferUsingSwsScale(rawOutput);

@@ -981,7 +992,8 @@ void VideoDecoder::validateFrameIndex(

VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
int streamIndex,
int64_t frameIndex) {
int64_t frameIndex,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
validateUserProvidedStreamIndex(streamIndex);
validateScannedAllStreams("getFrameAtIndex");

@@ -990,7 +1002,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(

int64_t pts = stream.allFrames[frameIndex].pts;
setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
return getNextDecodedOutputNoDemux();
return getNextDecodedOutputNoDemux(preAllocatedOutputTensor);
}

VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
@@ -1062,8 +1074,10 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
BatchDecodedOutput output(numOutputFrames, options, streamMetadata);

for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i);
output.frames[f] = singleOut.frame;
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
output.frames[f] = singleOut.frame;
}
output.ptsSeconds[f] = singleOut.ptsSeconds;
output.durationSeconds[f] = singleOut.durationSeconds;
}
@@ -1155,8 +1169,10 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
int64_t numFrames = stopFrameIndex - startFrameIndex;
BatchDecodedOutput output(numFrames, options, streamMetadata);
for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i);
output.frames[f] = singleOut.frame;
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
output.frames[f] = singleOut.frame;
}
output.ptsSeconds[f] = singleOut.ptsSeconds;
output.durationSeconds[f] = singleOut.durationSeconds;
}
@@ -1173,9 +1189,10 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
return rawOutput;
}

VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() {
VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux(
std::optional<torch::Tensor> preAllocatedOutputTensor) {
auto rawOutput = getNextRawDecodedOutputNoDemux();
return convertAVFrameToDecodedOutput(rawOutput);
return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
}

void VideoDecoder::setCursorPtsInSeconds(double seconds) {
15 changes: 11 additions & 4 deletions src/torchcodec/decoders/_core/VideoDecoder.h
@@ -214,15 +214,19 @@ class VideoDecoder {
};
// Decodes the frame where the current cursor position is. It also advances
// the cursor to the next frame.
DecodedOutput getNextDecodedOutputNoDemux();
DecodedOutput getNextDecodedOutputNoDemux(
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
// Decodes the first frame in any added stream that is visible at a given
// timestamp. Frames in the video have a presentation timestamp and a
// duration. For example, if a frame has presentation timestamp of 5.0s and a
// duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0).
// i.e. it will be returned when this function is called with seconds=5.0 or
// seconds=5.999, etc.
DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds);
DecodedOutput getFrameAtIndex(int streamIndex, int64_t frameIndex);
DecodedOutput getFrameAtIndex(
int streamIndex,
int64_t frameIndex,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
struct BatchDecodedOutput {
torch::Tensor frames;
torch::Tensor ptsSeconds;
@@ -363,10 +367,13 @@ class VideoDecoder {
int streamIndex,
const AVFrame* frame);
void convertFrameToBufferUsingSwsScale(RawDecodedOutput& rawOutput);
DecodedOutput convertAVFrameToDecodedOutput(RawDecodedOutput& rawOutput);
DecodedOutput convertAVFrameToDecodedOutput(
RawDecodedOutput& rawOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
void convertAVFrameToDecodedOutputOnCPU(
RawDecodedOutput& rawOutput,
DecodedOutput& output);
DecodedOutput& output,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

DecoderOptions options_;
ContainerMetadata containerMetadata_;
42 changes: 42 additions & 0 deletions test/decoders/test_video_decoder_ops.py
@@ -27,6 +27,7 @@
get_frame_at_index,
get_frame_at_pts,
get_frames_at_indices,
get_frames_by_pts_in_range,
get_frames_in_range,
get_json_metadata,
get_next_frame,
@@ -383,6 +384,47 @@ def test_color_conversion_library_with_scaling(
swscale_frame0, _, _ = get_next_frame(swscale_decoder)
assert_tensor_equal(filtergraph_frame0, swscale_frame0)

@pytest.mark.parametrize("dimension_order", ("NHWC", "NCHW"))
@pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale"))
def test_color_conversion_library_with_dimension_order(
self, dimension_order, color_conversion_library
):
decoder = create_from_file(str(NASA_VIDEO.path))
_add_video_stream(
decoder,
color_conversion_library=color_conversion_library,
dimension_order=dimension_order,
)
scan_all_streams_to_update_metadata(decoder)

frame0_ref = NASA_VIDEO.get_frame_data_by_index(0)
if dimension_order == "NHWC":
frame0_ref = frame0_ref.permute(1, 2, 0)
expected_shape = frame0_ref.shape

stream_index = 3
frame0, *_ = get_frame_at_index(
decoder, stream_index=stream_index, frame_index=0
)
assert frame0.shape == expected_shape
assert_tensor_equal(frame0, frame0_ref)

frame0, *_ = get_frame_at_pts(decoder, seconds=0.0)
assert frame0.shape == expected_shape
assert_tensor_equal(frame0, frame0_ref)

frames, *_ = get_frames_in_range(
decoder, stream_index=stream_index, start=0, stop=3
)
assert frames.shape[1:] == expected_shape
assert_tensor_equal(frames[0], frame0_ref)

frames, *_ = get_frames_by_pts_in_range(
decoder, stream_index=stream_index, start_seconds=0, stop_seconds=1
)
assert frames.shape[1:] == expected_shape
assert_tensor_equal(frames[0], frame0_ref)

@pytest.mark.parametrize(
"width_scaling_factor,height_scaling_factor",
((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)),