From 8e06aa6ae52e3e97446f7b764e09f19595470a41 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 16 Oct 2024 09:41:33 +0100 Subject: [PATCH 01/12] pass a pointer, segfaults --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 14 +++++++++----- src/torchcodec/decoders/_core/VideoDecoder.h | 6 +++++- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 58a605fd..c7f9d0ef 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -878,7 +878,9 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( VideoDecoder::RawDecodedOutput& rawOutput, - DecodedOutput& output) { + DecodedOutput& output, + torch::Tensor* tensor + ) { int streamIndex = rawOutput.streamIndex; AVFrame* frame = rawOutput.frame.get(); auto& streamInfo = streams_[streamIndex]; @@ -886,15 +888,17 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) { int width = streamInfo.options.width.value_or(frame->width); int height = streamInfo.options.height.value_or(frame->height); - torch::Tensor tensor = torch::empty( + auto tmp = torch::empty( {height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8})); - rawOutput.data = tensor.data_ptr(); + tensor = &(tmp); + rawOutput.data = tensor->data_ptr(); convertFrameToBufferUsingSwsScale(rawOutput); if (streamInfo.options.dimensionOrder == "NCHW") { - tensor = tensor.permute({2, 0, 1}); + auto tmp = tensor->permute({2, 0, 1}); + tensor = &(tmp); } - output.frame = tensor; + output.frame = *tensor; } else if ( streamInfo.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 8535a61e..01797e09 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -366,7 +366,11 @@ class VideoDecoder { DecodedOutput convertAVFrameToDecodedOutput(RawDecodedOutput& rawOutput); void convertAVFrameToDecodedOutputOnCPU( RawDecodedOutput& rawOutput, - DecodedOutput& output); + DecodedOutput& output, + // TODO: Unable to use std::optional& tensor = std::nullopt + // on a non-const tensor :( ? + torch::Tensor* tensor = nullptr + ); DecoderOptions options_; ContainerMetadata containerMetadata_; From 025bf270b4a75b7066f5efd2f7e6a8ce369d42dc Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 16 Oct 2024 09:41:38 +0100 Subject: [PATCH 02/12] Revert "pass a pointer, segfaults" This reverts commit 8e06aa6ae52e3e97446f7b764e09f19595470a41. --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 14 +++++--------- src/torchcodec/decoders/_core/VideoDecoder.h | 6 +----- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index c7f9d0ef..58a605fd 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -878,9 +878,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( VideoDecoder::RawDecodedOutput& rawOutput, - DecodedOutput& output, - torch::Tensor* tensor - ) { + DecodedOutput& output) { int streamIndex = rawOutput.streamIndex; AVFrame* frame = rawOutput.frame.get(); auto& streamInfo = streams_[streamIndex]; @@ -888,17 +886,15 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) { int width = streamInfo.options.width.value_or(frame->width); int height = streamInfo.options.height.value_or(frame->height); - auto tmp = torch::empty( + torch::Tensor tensor = torch::empty( {height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8})); - tensor = &(tmp); - rawOutput.data = tensor->data_ptr(); + rawOutput.data = tensor.data_ptr(); convertFrameToBufferUsingSwsScale(rawOutput); if (streamInfo.options.dimensionOrder == "NCHW") { - auto tmp = tensor->permute({2, 0, 1}); - tensor = &(tmp); + tensor = tensor.permute({2, 0, 1}); } - output.frame = *tensor; + output.frame = tensor; } else if ( streamInfo.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 01797e09..8535a61e 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -366,11 +366,7 @@ class VideoDecoder { DecodedOutput convertAVFrameToDecodedOutput(RawDecodedOutput& rawOutput); void convertAVFrameToDecodedOutputOnCPU( RawDecodedOutput& rawOutput, - DecodedOutput& output, - // TODO: Unable to use std::optional& tensor = std::nullopt - // on a non-const tensor :( ? - torch::Tensor* tensor = nullptr - ); + DecodedOutput& output); DecoderOptions options_; ContainerMetadata containerMetadata_; From f83ada99dfc361d77dcebe2c260447d4b78ced5a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 16 Oct 2024 11:55:18 +0100 Subject: [PATCH 03/12] Pre-allocate tensors when possible to avoid copies --- .../decoders/_core/VideoDecoder.cpp | 55 +++++++++++++------ src/torchcodec/decoders/_core/VideoDecoder.h | 16 ++++-- .../decoders/_core/VideoDecoderOps.cpp | 8 ++- 3 files changed, 57 insertions(+), 22 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 58a605fd..07de60e8 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -846,7 +846,8 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter( } VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( - VideoDecoder::RawDecodedOutput& rawOutput) { + VideoDecoder::RawDecodedOutput& rawOutput, + torch::Tensor& preAllocatedOutputTensor) { // Convert the frame to tensor. DecodedOutput output; int streamIndex = rawOutput.streamIndex; @@ -861,8 +862,10 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( output.durationSeconds = ptsToSeconds( getDuration(frame), formatContext_->streams[streamIndex]->time_base); if (streamInfo.options.device.type() == torch::kCPU) { - convertAVFrameToDecodedOutputOnCPU(rawOutput, output); + convertAVFrameToDecodedOutputOnCPU( + rawOutput, output, preAllocatedOutputTensor); } else if (streamInfo.options.device.type() == torch::kCUDA) { + // TODO: handle pre-allocated output tensor convertAVFrameToDecodedOutputOnCuda( streamInfo.options.device, streamInfo.options, @@ -878,16 +881,21 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( VideoDecoder::RawDecodedOutput& rawOutput, - DecodedOutput& output) { + DecodedOutput& output, + torch::Tensor& preAllocatedOutputTensor) { int streamIndex = rawOutput.streamIndex; AVFrame* frame = rawOutput.frame.get(); auto& streamInfo = streams_[streamIndex]; if (output.streamType == AVMEDIA_TYPE_VIDEO) { if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) { - int width = streamInfo.options.width.value_or(frame->width); - int height = streamInfo.options.height.value_or(frame->height); - torch::Tensor tensor = torch::empty( - {height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8})); + torch::Tensor tensor; + if (preAllocatedOutputTensor.numel() != 0) { + // TODO: check shape of preAllocatedOutputTensor? + tensor = preAllocatedOutputTensor; + } else { + tensor = allocateOutputTensorFromRawOutput(rawOutput); + } + rawOutput.data = tensor.data_ptr(); convertFrameToBufferUsingSwsScale(rawOutput); @@ -912,6 +920,16 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( } } +torch::Tensor VideoDecoder::allocateOutputTensorFromRawOutput( + RawDecodedOutput& rawOutput) { + AVFrame* frame = rawOutput.frame.get(); + StreamInfo& streamInfo = streams_[rawOutput.streamIndex]; + int width = streamInfo.options.width.value_or(frame->width); + int height = streamInfo.options.height.value_or(frame->height); + return torch::empty( + {height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8})); +} + VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestampNoDemux( double seconds) { for (auto& [streamIndex, stream] : streams_) { @@ -945,7 +963,8 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestampNoDemux( return seconds >= frameStartTime && seconds < frameEndTime; }); // Convert the frame to tensor. - return convertAVFrameToDecodedOutput(rawOutput); + auto preAllocatedOutputTensor = allocateOutputTensorFromRawOutput(rawOutput); + return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor); } void VideoDecoder::validateUserProvidedStreamIndex(uint64_t streamIndex) { @@ -980,7 +999,8 @@ void VideoDecoder::validateFrameIndex( VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex( int streamIndex, - int64_t frameIndex) { + int64_t frameIndex, + torch::Tensor& preAllocatedOutputTensor) { validateUserProvidedStreamIndex(streamIndex); validateScannedAllStreams("getFrameAtIndex"); @@ -989,7 +1009,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex( int64_t pts = stream.allFrames[frameIndex].pts; setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase)); - return getNextDecodedOutputNoDemux(); + return getNextDecodedOutputNoDemux(preAllocatedOutputTensor); } VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( @@ -1061,8 +1081,9 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( BatchDecodedOutput output(numOutputFrames, options, streamMetadata); for (int64_t i = start, f = 0; i < stop; i += step, ++f) { - DecodedOutput singleOut = getFrameAtIndex(streamIndex, i); - output.frames[f] = singleOut.frame; + auto preAllocatedOutputTensor = output.frames[f]; + DecodedOutput singleOut = + getFrameAtIndex(streamIndex, i, preAllocatedOutputTensor); output.ptsSeconds[f] = singleOut.ptsSeconds; output.durationSeconds[f] = singleOut.durationSeconds; } @@ -1154,8 +1175,9 @@ VideoDecoder::getFramesDisplayedByTimestampInRange( int64_t numFrames = stopFrameIndex - startFrameIndex; BatchDecodedOutput output(numFrames, options, streamMetadata); for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { - DecodedOutput singleOut = getFrameAtIndex(streamIndex, i); - output.frames[f] = singleOut.frame; + auto preAllocatedOutputTensor = output.frames[f]; + DecodedOutput singleOut = + getFrameAtIndex(streamIndex, i, preAllocatedOutputTensor); output.ptsSeconds[f] = singleOut.ptsSeconds; output.durationSeconds[f] = singleOut.durationSeconds; } @@ -1173,9 +1195,10 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() { return rawOutput; } -VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() { +VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux( + torch::Tensor& preAllocatedOutputTensor) { auto rawOutput = getNextRawDecodedOutputNoDemux(); - return convertAVFrameToDecodedOutput(rawOutput); + return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor); } void VideoDecoder::setCursorPtsInSeconds(double seconds) { diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 8535a61e..f0f8bcb8 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -214,7 +214,8 @@ class VideoDecoder { }; // Decodes the frame where the current cursor position is. It also advances // the cursor to the next frame. - DecodedOutput getNextDecodedOutputNoDemux(); + DecodedOutput getNextDecodedOutputNoDemux( + torch::Tensor& preAllocatedOutputTensor); // Decodes the first frame in any added stream that is visible at a given // timestamp. Frames in the video have a presentation timestamp and a // duration. For example, if a frame has presentation timestamp of 5.0s and a @@ -222,7 +223,10 @@ class VideoDecoder { // i.e. it will be returned when this function is called with seconds=5.0 or // seconds=5.999, etc. DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds); - DecodedOutput getFrameAtIndex(int streamIndex, int64_t frameIndex); + DecodedOutput getFrameAtIndex( + int streamIndex, + int64_t frameIndex, + torch::Tensor& preAllocatedOutputTensor); struct BatchDecodedOutput { torch::Tensor frames; torch::Tensor ptsSeconds; @@ -363,10 +367,14 @@ class VideoDecoder { int streamIndex, const AVFrame* frame); void convertFrameToBufferUsingSwsScale(RawDecodedOutput& rawOutput); - DecodedOutput convertAVFrameToDecodedOutput(RawDecodedOutput& rawOutput); + DecodedOutput convertAVFrameToDecodedOutput( + RawDecodedOutput& rawOutput, + torch::Tensor& preAllocatedOutputTensor); void convertAVFrameToDecodedOutputOnCPU( RawDecodedOutput& rawOutput, - DecodedOutput& output); + DecodedOutput& output, + torch::Tensor& preAllocatedOutputTensor); + torch::Tensor allocateOutputTensorFromRawOutput(RawDecodedOutput& rawOutput); DecoderOptions options_; ContainerMetadata containerMetadata_; diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 7da11b39..d4f2e546 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -190,8 +190,10 @@ void seek_to_pts(at::Tensor& decoder, double seconds) { OpsDecodedOutput get_next_frame(at::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); VideoDecoder::DecodedOutput result; + auto preAllocatedOutputTensor = torch::empty({0}); try { - result = videoDecoder->getNextDecodedOutputNoDemux(); + result = + videoDecoder->getNextDecodedOutputNoDemux(preAllocatedOutputTensor); } catch (const VideoDecoder::EndOfFileException& e) { C10_THROW_ERROR(IndexError, e.what()); } @@ -214,7 +216,9 @@ OpsDecodedOutput get_frame_at_index( int64_t stream_index, int64_t frame_index) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); - auto result = videoDecoder->getFrameAtIndex(stream_index, frame_index); + auto preAllocatedOutputTensor = torch::empty({0}); + auto result = videoDecoder->getFrameAtIndex( + stream_index, frame_index, preAllocatedOutputTensor); return makeOpsDecodedOutput(result); } From 72717bd07035d4b14c09654daecb19b6b3396144 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 16 Oct 2024 12:32:04 +0100 Subject: [PATCH 04/12] refac --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 17 +++++------------ src/torchcodec/decoders/_core/VideoDecoder.h | 1 - 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 07de60e8..1f8790a1 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -893,7 +893,10 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( // TODO: check shape of preAllocatedOutputTensor? tensor = preAllocatedOutputTensor; } else { - tensor = allocateOutputTensorFromRawOutput(rawOutput); + int width = streamInfo.options.width.value_or(frame->width); + int height = streamInfo.options.height.value_or(frame->height); + tensor = torch::empty( + {height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8})); } rawOutput.data = tensor.data_ptr(); @@ -920,16 +923,6 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( } } -torch::Tensor VideoDecoder::allocateOutputTensorFromRawOutput( - RawDecodedOutput& rawOutput) { - AVFrame* frame = rawOutput.frame.get(); - StreamInfo& streamInfo = streams_[rawOutput.streamIndex]; - int width = streamInfo.options.width.value_or(frame->width); - int height = streamInfo.options.height.value_or(frame->height); - return torch::empty( - {height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8})); -} - VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestampNoDemux( double seconds) { for (auto& [streamIndex, stream] : streams_) { @@ -963,7 +956,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestampNoDemux( return seconds >= frameStartTime && seconds < frameEndTime; }); // Convert the frame to tensor. - auto preAllocatedOutputTensor = allocateOutputTensorFromRawOutput(rawOutput); + auto preAllocatedOutputTensor = torch::empty({0}); return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor); } diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index f0f8bcb8..ada430e6 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -374,7 +374,6 @@ class VideoDecoder { RawDecodedOutput& rawOutput, DecodedOutput& output, torch::Tensor& preAllocatedOutputTensor); - torch::Tensor allocateOutputTensorFromRawOutput(RawDecodedOutput& rawOutput); DecoderOptions options_; ContainerMetadata containerMetadata_; From 291bc87377e99713ee4997ca81f91ffa40877f13 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 16 Oct 2024 13:33:03 +0100 Subject: [PATCH 05/12] Fix C++ tests --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 4 ++++ src/torchcodec/decoders/_core/VideoDecoder.h | 1 + 2 files changed, 5 insertions(+) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 1f8790a1..64ffcd39 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1188,6 +1188,10 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() { return rawOutput; } +VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() { + auto preAllocatedOutputTensor = torch::empty({0}); + return VideoDecoder::getNextDecodedOutputNoDemux(preAllocatedOutputTensor); +} VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux( torch::Tensor& preAllocatedOutputTensor) { auto rawOutput = getNextRawDecodedOutputNoDemux(); diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index ada430e6..4f08a2c8 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -214,6 +214,7 @@ class VideoDecoder { }; // Decodes the frame where the current cursor position is. It also advances // the cursor to the next frame. + DecodedOutput getNextDecodedOutputNoDemux(); DecodedOutput getNextDecodedOutputNoDemux( torch::Tensor& preAllocatedOutputTensor); // Decodes the first frame in any added stream that is visible at a given From 887ae42506ff1c4610004c3f05eb2485802d585a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 16 Oct 2024 13:36:54 +0100 Subject: [PATCH 06/12] minor simplification --- src/torchcodec/decoders/_core/VideoDecoderOps.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index d4f2e546..8bae507a 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -190,10 +190,8 @@ void seek_to_pts(at::Tensor& decoder, double seconds) { OpsDecodedOutput get_next_frame(at::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); VideoDecoder::DecodedOutput result; - auto preAllocatedOutputTensor = torch::empty({0}); try { - result = - videoDecoder->getNextDecodedOutputNoDemux(preAllocatedOutputTensor); + result = videoDecoder->getNextDecodedOutputNoDemux(); } catch (const VideoDecoder::EndOfFileException& e) { C10_THROW_ERROR(IndexError, e.what()); } From 9418cb317ac18c83a8e275bc1f92f6ce1beaaadd Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 16 Oct 2024 16:40:53 +0100 Subject: [PATCH 07/12] WIP --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 16 ++++++---------- src/torchcodec/decoders/_core/VideoDecoder.h | 9 ++++----- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 64ffcd39..aba53975 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -847,7 +847,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter( VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( VideoDecoder::RawDecodedOutput& rawOutput, - torch::Tensor& preAllocatedOutputTensor) { + std::optional preAllocatedOutputTensor) { // Convert the frame to tensor. DecodedOutput output; int streamIndex = rawOutput.streamIndex; @@ -882,16 +882,16 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( VideoDecoder::RawDecodedOutput& rawOutput, DecodedOutput& output, - torch::Tensor& preAllocatedOutputTensor) { + std::optional preAllocatedOutputTensor) { int streamIndex = rawOutput.streamIndex; AVFrame* frame = rawOutput.frame.get(); auto& streamInfo = streams_[streamIndex]; if (output.streamType == AVMEDIA_TYPE_VIDEO) { if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) { torch::Tensor tensor; - if (preAllocatedOutputTensor.numel() != 0) { + if (preAllocatedOutputTensor.has_value()) { // TODO: check shape of preAllocatedOutputTensor? - tensor = preAllocatedOutputTensor; + tensor = preAllocatedOutputTensor.value(); } else { int width = streamInfo.options.width.value_or(frame->width); int height = streamInfo.options.height.value_or(frame->height); @@ -993,7 +993,7 @@ void VideoDecoder::validateFrameIndex( VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex( int streamIndex, int64_t frameIndex, - torch::Tensor& preAllocatedOutputTensor) { + std::optional preAllocatedOutputTensor) { validateUserProvidedStreamIndex(streamIndex); validateScannedAllStreams("getFrameAtIndex"); @@ -1188,12 +1188,8 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() { return rawOutput; } -VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() { - auto preAllocatedOutputTensor = torch::empty({0}); - return VideoDecoder::getNextDecodedOutputNoDemux(preAllocatedOutputTensor); -} VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux( - torch::Tensor& preAllocatedOutputTensor) { + std::optional preAllocatedOutputTensor){ auto rawOutput = getNextRawDecodedOutputNoDemux(); return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor); } diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 4f08a2c8..761f07df 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -214,9 +214,8 @@ class VideoDecoder { }; // Decodes the frame where the current cursor position is. It also advances // the cursor to the next frame. - DecodedOutput getNextDecodedOutputNoDemux(); DecodedOutput getNextDecodedOutputNoDemux( - torch::Tensor& preAllocatedOutputTensor); + std::optional preAllocatedOutputTensor = std::nullopt); // Decodes the first frame in any added stream that is visible at a given // timestamp. Frames in the video have a presentation timestamp and a // duration. For example, if a frame has presentation timestamp of 5.0s and a @@ -227,7 +226,7 @@ class VideoDecoder { DecodedOutput getFrameAtIndex( int streamIndex, int64_t frameIndex, - torch::Tensor& preAllocatedOutputTensor); + std::optional preAllocatedOutputTensor = std::nullopt); struct BatchDecodedOutput { torch::Tensor frames; torch::Tensor ptsSeconds; @@ -370,11 +369,11 @@ class VideoDecoder { void convertFrameToBufferUsingSwsScale(RawDecodedOutput& rawOutput); DecodedOutput convertAVFrameToDecodedOutput( RawDecodedOutput& rawOutput, - torch::Tensor& preAllocatedOutputTensor); + std::optional preAllocatedOutputTensor = std::nullopt); void convertAVFrameToDecodedOutputOnCPU( RawDecodedOutput& rawOutput, DecodedOutput& output, - torch::Tensor& preAllocatedOutputTensor); + std::optional preAllocatedOutputTensor = std::nullopt); DecoderOptions options_; ContainerMetadata containerMetadata_; From 6a2190c070b5fe9e410977ca488144ed439829ac Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 16 Oct 2024 16:42:48 +0100 Subject: [PATCH 08/12] WIP --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index aba53975..5ef73eeb 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -889,7 +889,8 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( if (output.streamType == AVMEDIA_TYPE_VIDEO) { if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) { torch::Tensor tensor; - if (preAllocatedOutputTensor.has_value()) { + // if (preAllocatedOutputTensor.has_value()) { + if (false) { // TODO: check shape of preAllocatedOutputTensor? tensor = preAllocatedOutputTensor.value(); } else { From c8f2e790e9f2c8551aa2ef63b9619ba1132368e6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 16 Oct 2024 16:49:03 +0100 Subject: [PATCH 09/12] don't use a ref --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 8 +++----- src/torchcodec/decoders/_core/VideoDecoderOps.cpp | 4 +--- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 5ef73eeb..e8b0b2df 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -889,8 +889,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( if (output.streamType == AVMEDIA_TYPE_VIDEO) { if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) { torch::Tensor tensor; - // if (preAllocatedOutputTensor.has_value()) { - if (false) { + if (preAllocatedOutputTensor.has_value()) { // TODO: check shape of preAllocatedOutputTensor? tensor = preAllocatedOutputTensor.value(); } else { @@ -957,8 +956,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestampNoDemux( return seconds >= frameStartTime && seconds < frameEndTime; }); // Convert the frame to tensor. - auto preAllocatedOutputTensor = torch::empty({0}); - return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor); + return convertAVFrameToDecodedOutput(rawOutput); } void VideoDecoder::validateUserProvidedStreamIndex(uint64_t streamIndex) { @@ -1190,7 +1188,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() { } VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux( - std::optional preAllocatedOutputTensor){ + std::optional preAllocatedOutputTensor) { auto rawOutput = getNextRawDecodedOutputNoDemux(); return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor); } diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 57da01ae..70f4afdc 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -214,9 +214,7 @@ OpsDecodedOutput get_frame_at_index( int64_t stream_index, int64_t frame_index) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); - auto preAllocatedOutputTensor = torch::empty({0}); - auto result = videoDecoder->getFrameAtIndex( - stream_index, frame_index, preAllocatedOutputTensor); + auto result = videoDecoder->getFrameAtIndex(stream_index, frame_index); return makeOpsDecodedOutput(result); } From 5113b9c9154aba878ca58fd449032bdc2f7b1b38 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 17 Oct 2024 06:18:48 -0700 Subject: [PATCH 10/12] Avoid temporary variable --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index e8b0b2df..624cf47b 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1073,9 +1073,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( BatchDecodedOutput output(numOutputFrames, options, streamMetadata); for (int64_t i = start, f = 0; i < stop; i += step, ++f) { - auto preAllocatedOutputTensor = output.frames[f]; DecodedOutput singleOut = - getFrameAtIndex(streamIndex, i, preAllocatedOutputTensor); + getFrameAtIndex(streamIndex, i, output.frames[f]); output.ptsSeconds[f] = singleOut.ptsSeconds; output.durationSeconds[f] = singleOut.durationSeconds; } @@ -1167,9 +1166,8 @@ VideoDecoder::getFramesDisplayedByTimestampInRange( int64_t numFrames = stopFrameIndex - startFrameIndex; BatchDecodedOutput output(numFrames, options, streamMetadata); for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { - auto preAllocatedOutputTensor = output.frames[f]; DecodedOutput singleOut = - getFrameAtIndex(streamIndex, i, preAllocatedOutputTensor); + getFrameAtIndex(streamIndex, i, output.frames[f]); output.ptsSeconds[f] = singleOut.ptsSeconds; output.durationSeconds[f] = singleOut.durationSeconds; } From 9387537ca88afe9ea55d13ae7bb4f3ffaa29001c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 17 Oct 2024 06:33:23 -0700 Subject: [PATCH 11/12] Test, and fix --- .../decoders/_core/VideoDecoder.cpp | 12 ++++-- test/decoders/test_video_decoder_ops.py | 43 +++++++++++++++++++ 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 624cf47b..a71b49b9 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1073,8 +1073,10 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( BatchDecodedOutput output(numOutputFrames, options, streamMetadata); for (int64_t i = start, f = 0; i < stop; i += step, ++f) { - DecodedOutput singleOut = - getFrameAtIndex(streamIndex, i, output.frames[f]); + DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]); + if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { + output.frames[f] = singleOut.frame; + } output.ptsSeconds[f] = singleOut.ptsSeconds; output.durationSeconds[f] = singleOut.durationSeconds; } @@ -1166,8 +1168,10 @@ VideoDecoder::getFramesDisplayedByTimestampInRange( int64_t numFrames = stopFrameIndex - startFrameIndex; BatchDecodedOutput output(numFrames, options, streamMetadata); for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { - DecodedOutput singleOut = - getFrameAtIndex(streamIndex, i, output.frames[f]); + DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]); + if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { + output.frames[f] = singleOut.frame; + } output.ptsSeconds[f] = singleOut.ptsSeconds; output.durationSeconds[f] = singleOut.durationSeconds; } diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 1bb28feb..87884026 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -27,6 +27,7 @@ get_frame_at_index, get_frame_at_pts, get_frames_at_indices, + get_frames_by_pts_in_range, get_frames_in_range, get_json_metadata, get_next_frame, @@ -383,6 +384,48 @@ def test_color_conversion_library_with_scaling( swscale_frame0, _, _ = get_next_frame(swscale_decoder) assert_tensor_equal(filtergraph_frame0, swscale_frame0) + @pytest.mark.parametrize("dimension_order", ("NHWC", "NCHW")) + @pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale")) + def test_color_conversion_library_with_dimension_order( + self, dimension_order, color_conversion_library + ): + decoder = create_from_file(str(NASA_VIDEO.path)) + _add_video_stream( + decoder, + color_conversion_library=color_conversion_library, + dimension_order=dimension_order, + ) + scan_all_streams_to_update_metadata(decoder) + + frame0_ref = NASA_VIDEO.get_frame_data_by_index(0) + C, H, W = frame0_ref.shape + if dimension_order == "NHWC": + frame0_ref = frame0_ref.permute(1, 2, 0) + expected_shape = frame0_ref.shape + + stream_index = 3 + frame0, *_ = get_frame_at_index( + decoder, stream_index=stream_index, frame_index=0 + ) + assert frame0.shape == expected_shape + assert_tensor_equal(frame0, frame0_ref) + + frame0, *_ = get_frame_at_pts(decoder, seconds=0.0) + assert frame0.shape == expected_shape + assert_tensor_equal(frame0, frame0_ref) + + frames, *_ = get_frames_in_range( + decoder, stream_index=stream_index, start=0, stop=3 + ) + assert frames.shape[1:] == expected_shape + assert_tensor_equal(frames[0], frame0_ref) + + frames, *_ = get_frames_by_pts_in_range( + decoder, stream_index=stream_index, start_seconds=0, stop_seconds=1 + ) + assert frames.shape[1:] == expected_shape + assert_tensor_equal(frames[0], frame0_ref) + @pytest.mark.parametrize( "width_scaling_factor,height_scaling_factor", ((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)), From e23acb79d8e16e3ee1dcdde28cc09503339be046 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 18 Oct 2024 10:21:16 +0100 Subject: [PATCH 12/12] Update test/decoders/test_video_decoder_ops.py --- test/decoders/test_video_decoder_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 87884026..18782a36 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -398,7 +398,6 @@ def test_color_conversion_library_with_dimension_order( scan_all_streams_to_update_metadata(decoder) frame0_ref = NASA_VIDEO.get_frame_data_by_index(0) - C, H, W = frame0_ref.shape if dimension_order == "NHWC": frame0_ref = frame0_ref.permute(1, 2, 0) expected_shape = frame0_ref.shape