@@ -847,7 +847,8 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
847847}
848848
849849VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput (
850- VideoDecoder::RawDecodedOutput& rawOutput) {
850+ VideoDecoder::RawDecodedOutput& rawOutput,
851+ std::optional<torch::Tensor> preAllocatedOutputTensor) {
851852 // Convert the frame to tensor.
852853 DecodedOutput output;
853854 int streamIndex = rawOutput.streamIndex ;
@@ -862,8 +863,10 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
862863 output.durationSeconds = ptsToSeconds (
863864 getDuration (frame), formatContext_->streams [streamIndex]->time_base );
864865 if (streamInfo.options .device .type () == torch::kCPU ) {
865- convertAVFrameToDecodedOutputOnCPU (rawOutput, output);
866+ convertAVFrameToDecodedOutputOnCPU (
867+ rawOutput, output, preAllocatedOutputTensor);
866868 } else if (streamInfo.options .device .type () == torch::kCUDA ) {
869+ // TODO: handle pre-allocated output tensor
867870 convertAVFrameToDecodedOutputOnCuda (
868871 streamInfo.options .device ,
869872 streamInfo.options ,
@@ -879,16 +882,24 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
879882
880883void VideoDecoder::convertAVFrameToDecodedOutputOnCPU (
881884 VideoDecoder::RawDecodedOutput& rawOutput,
882- DecodedOutput& output) {
885+ DecodedOutput& output,
886+ std::optional<torch::Tensor> preAllocatedOutputTensor) {
883887 int streamIndex = rawOutput.streamIndex ;
884888 AVFrame* frame = rawOutput.frame .get ();
885889 auto & streamInfo = streams_[streamIndex];
886890 if (output.streamType == AVMEDIA_TYPE_VIDEO) {
887891 if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
888- int width = streamInfo.options .width .value_or (frame->width );
889- int height = streamInfo.options .height .value_or (frame->height );
890- torch::Tensor tensor = torch::empty (
891- {height, width, 3 }, torch::TensorOptions ().dtype ({torch::kUInt8 }));
892+ torch::Tensor tensor;
893+ if (preAllocatedOutputTensor.has_value ()) {
894+ // TODO: check shape of preAllocatedOutputTensor?
895+ tensor = preAllocatedOutputTensor.value ();
896+ } else {
897+ int width = streamInfo.options .width .value_or (frame->width );
898+ int height = streamInfo.options .height .value_or (frame->height );
899+ tensor = torch::empty (
900+ {height, width, 3 }, torch::TensorOptions ().dtype ({torch::kUInt8 }));
901+ }
902+
892903 rawOutput.data = tensor.data_ptr <uint8_t >();
893904 convertFrameToBufferUsingSwsScale (rawOutput);
894905
@@ -981,7 +992,8 @@ void VideoDecoder::validateFrameIndex(
981992
982993VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex (
983994 int streamIndex,
984- int64_t frameIndex) {
995+ int64_t frameIndex,
996+ std::optional<torch::Tensor> preAllocatedOutputTensor) {
985997 validateUserProvidedStreamIndex (streamIndex);
986998 validateScannedAllStreams (" getFrameAtIndex" );
987999
@@ -990,7 +1002,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
9901002
9911003 int64_t pts = stream.allFrames [frameIndex].pts ;
9921004 setCursorPtsInSeconds (ptsToSeconds (pts, stream.timeBase ));
993- return getNextDecodedOutputNoDemux ();
1005+ return getNextDecodedOutputNoDemux (preAllocatedOutputTensor );
9941006}
9951007
9961008VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices (
@@ -1062,8 +1074,10 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
10621074 BatchDecodedOutput output (numOutputFrames, options, streamMetadata);
10631075
10641076 for (int64_t i = start, f = 0 ; i < stop; i += step, ++f) {
1065- DecodedOutput singleOut = getFrameAtIndex (streamIndex, i);
1066- output.frames [f] = singleOut.frame ;
1077+ DecodedOutput singleOut = getFrameAtIndex (streamIndex, i, output.frames [f]);
1078+ if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
1079+ output.frames [f] = singleOut.frame ;
1080+ }
10671081 output.ptsSeconds [f] = singleOut.ptsSeconds ;
10681082 output.durationSeconds [f] = singleOut.durationSeconds ;
10691083 }
@@ -1155,8 +1169,10 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
11551169 int64_t numFrames = stopFrameIndex - startFrameIndex;
11561170 BatchDecodedOutput output (numFrames, options, streamMetadata);
11571171 for (int64_t i = startFrameIndex, f = 0 ; i < stopFrameIndex; ++i, ++f) {
1158- DecodedOutput singleOut = getFrameAtIndex (streamIndex, i);
1159- output.frames [f] = singleOut.frame ;
1172+ DecodedOutput singleOut = getFrameAtIndex (streamIndex, i, output.frames [f]);
1173+ if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
1174+ output.frames [f] = singleOut.frame ;
1175+ }
11601176 output.ptsSeconds [f] = singleOut.ptsSeconds ;
11611177 output.durationSeconds [f] = singleOut.durationSeconds ;
11621178 }
@@ -1173,9 +1189,10 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
11731189 return rawOutput;
11741190}
11751191
1176- VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux () {
1192+ VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux (
1193+ std::optional<torch::Tensor> preAllocatedOutputTensor) {
11771194 auto rawOutput = getNextRawDecodedOutputNoDemux ();
1178- return convertAVFrameToDecodedOutput (rawOutput);
1195+ return convertAVFrameToDecodedOutput (rawOutput, preAllocatedOutputTensor );
11791196}
11801197
11811198void VideoDecoder::setCursorPtsInSeconds (double seconds) {
0 commit comments