Skip to content

Commit d6cbee5

Browse files
authored
Pass pre-allocate tensors in batch APIs to avoid copies (#266)
1 parent 74961d2 commit d6cbee5

File tree

3 files changed

+85
-19
lines changed

3 files changed

+85
-19
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,8 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
847847
}
848848

849849
VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
850-
VideoDecoder::RawDecodedOutput& rawOutput) {
850+
VideoDecoder::RawDecodedOutput& rawOutput,
851+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
851852
// Convert the frame to tensor.
852853
DecodedOutput output;
853854
int streamIndex = rawOutput.streamIndex;
@@ -862,8 +863,10 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
862863
output.durationSeconds = ptsToSeconds(
863864
getDuration(frame), formatContext_->streams[streamIndex]->time_base);
864865
if (streamInfo.options.device.type() == torch::kCPU) {
865-
convertAVFrameToDecodedOutputOnCPU(rawOutput, output);
866+
convertAVFrameToDecodedOutputOnCPU(
867+
rawOutput, output, preAllocatedOutputTensor);
866868
} else if (streamInfo.options.device.type() == torch::kCUDA) {
869+
// TODO: handle pre-allocated output tensor
867870
convertAVFrameToDecodedOutputOnCuda(
868871
streamInfo.options.device,
869872
streamInfo.options,
@@ -879,16 +882,24 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
879882

880883
void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
881884
VideoDecoder::RawDecodedOutput& rawOutput,
882-
DecodedOutput& output) {
885+
DecodedOutput& output,
886+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
883887
int streamIndex = rawOutput.streamIndex;
884888
AVFrame* frame = rawOutput.frame.get();
885889
auto& streamInfo = streams_[streamIndex];
886890
if (output.streamType == AVMEDIA_TYPE_VIDEO) {
887891
if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
888-
int width = streamInfo.options.width.value_or(frame->width);
889-
int height = streamInfo.options.height.value_or(frame->height);
890-
torch::Tensor tensor = torch::empty(
891-
{height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
892+
torch::Tensor tensor;
893+
if (preAllocatedOutputTensor.has_value()) {
894+
// TODO: check shape of preAllocatedOutputTensor?
895+
tensor = preAllocatedOutputTensor.value();
896+
} else {
897+
int width = streamInfo.options.width.value_or(frame->width);
898+
int height = streamInfo.options.height.value_or(frame->height);
899+
tensor = torch::empty(
900+
{height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
901+
}
902+
892903
rawOutput.data = tensor.data_ptr<uint8_t>();
893904
convertFrameToBufferUsingSwsScale(rawOutput);
894905

@@ -981,7 +992,8 @@ void VideoDecoder::validateFrameIndex(
981992

982993
VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
983994
int streamIndex,
984-
int64_t frameIndex) {
995+
int64_t frameIndex,
996+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
985997
validateUserProvidedStreamIndex(streamIndex);
986998
validateScannedAllStreams("getFrameAtIndex");
987999

@@ -990,7 +1002,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
9901002

9911003
int64_t pts = stream.allFrames[frameIndex].pts;
9921004
setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
993-
return getNextDecodedOutputNoDemux();
1005+
return getNextDecodedOutputNoDemux(preAllocatedOutputTensor);
9941006
}
9951007

9961008
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
@@ -1062,8 +1074,10 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
10621074
BatchDecodedOutput output(numOutputFrames, options, streamMetadata);
10631075

10641076
for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
1065-
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i);
1066-
output.frames[f] = singleOut.frame;
1077+
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
1078+
if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
1079+
output.frames[f] = singleOut.frame;
1080+
}
10671081
output.ptsSeconds[f] = singleOut.ptsSeconds;
10681082
output.durationSeconds[f] = singleOut.durationSeconds;
10691083
}
@@ -1155,8 +1169,10 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
11551169
int64_t numFrames = stopFrameIndex - startFrameIndex;
11561170
BatchDecodedOutput output(numFrames, options, streamMetadata);
11571171
for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
1158-
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i);
1159-
output.frames[f] = singleOut.frame;
1172+
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
1173+
if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
1174+
output.frames[f] = singleOut.frame;
1175+
}
11601176
output.ptsSeconds[f] = singleOut.ptsSeconds;
11611177
output.durationSeconds[f] = singleOut.durationSeconds;
11621178
}
@@ -1173,9 +1189,10 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
11731189
return rawOutput;
11741190
}
11751191

1176-
VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() {
1192+
VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux(
1193+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
11771194
auto rawOutput = getNextRawDecodedOutputNoDemux();
1178-
return convertAVFrameToDecodedOutput(rawOutput);
1195+
return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
11791196
}
11801197

11811198
void VideoDecoder::setCursorPtsInSeconds(double seconds) {

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,15 +214,19 @@ class VideoDecoder {
214214
};
215215
// Decodes the frame where the current cursor position is. It also advances
216216
// the cursor to the next frame.
217-
DecodedOutput getNextDecodedOutputNoDemux();
217+
DecodedOutput getNextDecodedOutputNoDemux(
218+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
218219
// Decodes the first frame in any added stream that is visible at a given
219220
// timestamp. Frames in the video have a presentation timestamp and a
220221
// duration. For example, if a frame has presentation timestamp of 5.0s and a
221222
// duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0).
222223
// i.e. it will be returned when this function is called with seconds=5.0 or
223224
// seconds=5.999, etc.
224225
DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds);
225-
DecodedOutput getFrameAtIndex(int streamIndex, int64_t frameIndex);
226+
DecodedOutput getFrameAtIndex(
227+
int streamIndex,
228+
int64_t frameIndex,
229+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
226230
struct BatchDecodedOutput {
227231
torch::Tensor frames;
228232
torch::Tensor ptsSeconds;
@@ -363,10 +367,13 @@ class VideoDecoder {
363367
int streamIndex,
364368
const AVFrame* frame);
365369
void convertFrameToBufferUsingSwsScale(RawDecodedOutput& rawOutput);
366-
DecodedOutput convertAVFrameToDecodedOutput(RawDecodedOutput& rawOutput);
370+
DecodedOutput convertAVFrameToDecodedOutput(
371+
RawDecodedOutput& rawOutput,
372+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
367373
void convertAVFrameToDecodedOutputOnCPU(
368374
RawDecodedOutput& rawOutput,
369-
DecodedOutput& output);
375+
DecodedOutput& output,
376+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
370377

371378
DecoderOptions options_;
372379
ContainerMetadata containerMetadata_;

test/decoders/test_video_decoder_ops.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
get_frame_at_index,
2828
get_frame_at_pts,
2929
get_frames_at_indices,
30+
get_frames_by_pts_in_range,
3031
get_frames_in_range,
3132
get_json_metadata,
3233
get_next_frame,
@@ -383,6 +384,47 @@ def test_color_conversion_library_with_scaling(
383384
swscale_frame0, _, _ = get_next_frame(swscale_decoder)
384385
assert_tensor_equal(filtergraph_frame0, swscale_frame0)
385386

387+
@pytest.mark.parametrize("dimension_order", ("NHWC", "NCHW"))
388+
@pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale"))
389+
def test_color_conversion_library_with_dimension_order(
390+
self, dimension_order, color_conversion_library
391+
):
392+
decoder = create_from_file(str(NASA_VIDEO.path))
393+
_add_video_stream(
394+
decoder,
395+
color_conversion_library=color_conversion_library,
396+
dimension_order=dimension_order,
397+
)
398+
scan_all_streams_to_update_metadata(decoder)
399+
400+
frame0_ref = NASA_VIDEO.get_frame_data_by_index(0)
401+
if dimension_order == "NHWC":
402+
frame0_ref = frame0_ref.permute(1, 2, 0)
403+
expected_shape = frame0_ref.shape
404+
405+
stream_index = 3
406+
frame0, *_ = get_frame_at_index(
407+
decoder, stream_index=stream_index, frame_index=0
408+
)
409+
assert frame0.shape == expected_shape
410+
assert_tensor_equal(frame0, frame0_ref)
411+
412+
frame0, *_ = get_frame_at_pts(decoder, seconds=0.0)
413+
assert frame0.shape == expected_shape
414+
assert_tensor_equal(frame0, frame0_ref)
415+
416+
frames, *_ = get_frames_in_range(
417+
decoder, stream_index=stream_index, start=0, stop=3
418+
)
419+
assert frames.shape[1:] == expected_shape
420+
assert_tensor_equal(frames[0], frame0_ref)
421+
422+
frames, *_ = get_frames_by_pts_in_range(
423+
decoder, stream_index=stream_index, start_seconds=0, stop_seconds=1
424+
)
425+
assert frames.shape[1:] == expected_shape
426+
assert_tensor_equal(frames[0], frame0_ref)
427+
386428
@pytest.mark.parametrize(
387429
"width_scaling_factor,height_scaling_factor",
388430
((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)),

0 commit comments

Comments (0)