diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
index 20a4e380..02c1c942 100644
--- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
+++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
@@ -20,7 +20,7 @@ void convertAVFrameToDecodedOutputOnCuda(
     AVCodecContext* codecContext,
     VideoDecoder::RawDecodedOutput& rawOutput,
     VideoDecoder::DecodedOutput& output,
-    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+    torch::Tensor preAllocatedOutputTensor) {
   throwUnsupportedDeviceError(device);
 }
 
diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp
index b15684a2..9f2e7634 100644
--- a/src/torchcodec/decoders/_core/CudaDevice.cpp
+++ b/src/torchcodec/decoders/_core/CudaDevice.cpp
@@ -154,18 +154,6 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
 #endif
 }
 
-torch::Tensor allocateDeviceTensor(
-    at::IntArrayRef shape,
-    torch::Device device,
-    const torch::Dtype dtype = torch::kUInt8) {
-  return torch::empty(
-      shape,
-      torch::TensorOptions()
-          .dtype(dtype)
-          .layout(torch::kStrided)
-          .device(device));
-}
-
 void throwErrorIfNonCudaDevice(const torch::Device& device) {
   TORCH_CHECK(
       device.type() != torch::kCPU,
@@ -202,7 +190,7 @@ void convertAVFrameToDecodedOutputOnCuda(
     AVCodecContext* codecContext,
     VideoDecoder::RawDecodedOutput& rawOutput,
     VideoDecoder::DecodedOutput& output,
-    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+    torch::Tensor preAllocatedOutputTensor) {
   AVFrame* src = rawOutput.frame.get();
 
   TORCH_CHECK(
@@ -213,22 +201,6 @@ void convertAVFrameToDecodedOutputOnCuda(
   int height = options.height.value_or(codecContext->height);
   NppiSize oSizeROI = {width, height};
   Npp8u* input[2] = {src->data[0], src->data[1]};
-  torch::Tensor& dst = output.frame;
-  if (preAllocatedOutputTensor.has_value()) {
-    dst = preAllocatedOutputTensor.value();
-    auto shape = dst.sizes();
-    TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == height) && (shape[1] == width) &&
-            (shape[2] == 3),
-        "Expected tensor of shape ",
-        height,
-        "x",
-        width,
-        "x3, got ",
-        shape);
-  } else {
-    dst = allocateDeviceTensor({height, width, 3}, options.device);
-  }
 
   // Use the user-requested GPU for running the NPP kernel.
   c10::cuda::CUDAGuard deviceGuard(device);
@@ -238,10 +210,11 @@ void convertAVFrameToDecodedOutputOnCuda(
   NppStatus status = nppiNV12ToRGB_8u_P2C3R(
       input,
       src->linesize[0],
-      static_cast<Npp8u*>(dst.data_ptr()),
-      dst.stride(0),
+      static_cast<Npp8u*>(preAllocatedOutputTensor.data_ptr()),
+      preAllocatedOutputTensor.stride(0),
       oSizeROI);
   TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");
+  output.frame = preAllocatedOutputTensor;
   // Make the pytorch stream wait for the npp kernel to finish before using the
   // output.
   at::cuda::CUDAEvent nppDoneEvent;
diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h
index 772bdfe6..65fb6ce4 100644
--- a/src/torchcodec/decoders/_core/DeviceInterface.h
+++ b/src/torchcodec/decoders/_core/DeviceInterface.h
@@ -38,7 +38,7 @@ void convertAVFrameToDecodedOutputOnCuda(
     AVCodecContext* codecContext,
     VideoDecoder::RawDecodedOutput& rawOutput,
    VideoDecoder::DecodedOutput& output,
-    std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+    torch::Tensor preAllocatedOutputTensor);
 
 void releaseContextOnCuda(
     const torch::Device& device,
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 91566328..7280a00a 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -187,18 +187,34 @@ VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions(
   }
 }
 
-VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
-    int64_t numFrames,
-    const VideoStreamDecoderOptions& options,
-    const StreamMetadata& metadata)
-    : frames(torch::empty(
-          {numFrames,
-           options.height.value_or(*metadata.height),
-           options.width.value_or(*metadata.width),
-           3},
-          at::TensorOptions(options.device).dtype(torch::kUInt8))),
-      ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
-      durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {}
+torch::Tensor VideoDecoder::allocateEmptyHWCTensorForStream(
+    int streamIndex,
+    std::optional<int64_t> numFrames) {
+  auto metadata = containerMetadata_.streams[streamIndex];
+  auto options = streams_[streamIndex].options;
+  auto height = options.height.value_or(*metadata.height);
+  auto width = options.width.value_or(*metadata.width);
+
+  auto tensorOptions = torch::TensorOptions()
+                           .dtype(torch::kUInt8)
+                           .layout(torch::kStrided)
+                           .device(options.device.type());
+  if (numFrames.has_value()) {
+    return torch::empty({numFrames.value(), height, width, 3}, tensorOptions);
+  } else {
+    return torch::empty({height, width, 3}, tensorOptions);
+  }
+}
+
+VideoDecoder::BatchDecodedOutput VideoDecoder::allocateBatchDecodedOutput(
+    int streamIndex,
+    int64_t numFrames) {
+  BatchDecodedOutput output;
+  output.frames = allocateEmptyHWCTensorForStream(streamIndex, numFrames);
+  output.ptsSeconds = torch::empty({numFrames}, {torch::kFloat64});
+  output.durationSeconds = torch::empty({numFrames}, {torch::kFloat64});
+  return output;
+}
 
 VideoDecoder::VideoDecoder() {}
@@ -841,7 +857,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
 
 VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
     VideoDecoder::RawDecodedOutput& rawOutput,
-    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+    torch::Tensor preAllocatedOutputTensor) {
   // Convert the frame to tensor.
   DecodedOutput output;
   int streamIndex = rawOutput.streamIndex;
@@ -875,7 +891,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
 }
 
 // Note [preAllocatedOutputTensor with swscale and filtergraph]:
-// Callers may pass a pre-allocated tensor, where the output frame tensor will
+// Callers must pass a pre-allocated tensor, where the output frame tensor will
 // be stored. This parameter is honored in any case, but it only leads to a
 // speed-up when swscale is used. With swscale, we can tell ffmpeg to place the
 // decoded frame directly into `preAllocatedtensor.data_ptr()`. We haven't yet
@@ -886,50 +902,25 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
 void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
     VideoDecoder::RawDecodedOutput& rawOutput,
     DecodedOutput& output,
-    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+    torch::Tensor preAllocatedOutputTensor) {
   int streamIndex = rawOutput.streamIndex;
   AVFrame* frame = rawOutput.frame.get();
   auto& streamInfo = streams_[streamIndex];
-  torch::Tensor tensor;
   if (output.streamType == AVMEDIA_TYPE_VIDEO) {
     if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      int width = streamInfo.options.width.value_or(frame->width);
-      int height = streamInfo.options.height.value_or(frame->height);
-      if (preAllocatedOutputTensor.has_value()) {
-        tensor = preAllocatedOutputTensor.value();
-        auto shape = tensor.sizes();
-        TORCH_CHECK(
-            (shape.size() == 3) && (shape[0] == height) &&
-                (shape[1] == width) && (shape[2] == 3),
-            "Expected tensor of shape ",
-            height,
-            "x",
-            width,
-            "x3, got ",
-            shape);
-      } else {
-        tensor = torch::empty(
-            {height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
-      }
-      rawOutput.data = tensor.data_ptr();
+      rawOutput.data = preAllocatedOutputTensor.data_ptr();
       convertFrameToBufferUsingSwsScale(rawOutput);
-
-      output.frame = tensor;
     } else if (
         streamInfo.colorConversionLibrary ==
         ColorConversionLibrary::FILTERGRAPH) {
-      tensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
-      if (preAllocatedOutputTensor.has_value()) {
-        preAllocatedOutputTensor.value().copy_(tensor);
-        output.frame = preAllocatedOutputTensor.value();
-      } else {
-        output.frame = tensor;
-      }
+      auto tmpTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+      preAllocatedOutputTensor.copy_(tmpTensor);
     } else {
       throw std::runtime_error(
          "Invalid color conversion library: " +
          std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
     }
+    output.frame = preAllocatedOutputTensor;
   } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
     // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
@@ -971,8 +962,11 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
         return seconds >= frameStartTime && seconds < frameEndTime;
       });
   // Convert the frame to tensor.
-  auto output = convertAVFrameToDecodedOutput(rawOutput);
-  output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+  auto streamIndex = rawOutput.streamIndex;
+  auto preAllocatedOutputTensor = allocateEmptyHWCTensorForStream(streamIndex);
+  auto output =
+      convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
+  output.frame = MaybePermuteHWC2CHW(streamIndex, output.frame);
   return output;
 }
 
@@ -1009,7 +1003,9 @@ void VideoDecoder::validateFrameIndex(
 VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
     int streamIndex,
     int64_t frameIndex) {
-  auto output = getFrameAtIndexInternal(streamIndex, frameIndex);
+  auto preAllocatedOutputTensor = allocateEmptyHWCTensorForStream(streamIndex);
+  auto output = getFrameAtIndexInternal(
+      streamIndex, frameIndex, preAllocatedOutputTensor);
   output.frame = MaybePermuteHWC2CHW(streamIndex, output.frame);
   return output;
 }
@@ -1017,7 +1013,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
 VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal(
     int streamIndex,
     int64_t frameIndex,
-    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+    torch::Tensor preAllocatedOutputTensor) {
   validateUserProvidedStreamIndex(streamIndex);
   validateScannedAllStreams("getFrameAtIndex");
 
@@ -1057,7 +1053,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
   const auto& streamMetadata = containerMetadata_.streams[streamIndex];
   const auto& stream = streams_[streamIndex];
   const auto& options = stream.options;
-  BatchDecodedOutput output(frameIndices.size(), options, streamMetadata);
+  BatchDecodedOutput output =
+      allocateBatchDecodedOutput(streamIndex, frameIndices.size());
 
   auto previousIndexInVideo = -1;
   for (auto f = 0; f < frameIndices.size(); ++f) {
@@ -1149,8 +1146,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
       step > 0, "Step must be greater than 0; is " + std::to_string(step));
 
   int64_t numOutputFrames = std::ceil((stop - start) / double(step));
-  const auto& options = stream.options;
-  BatchDecodedOutput output(numOutputFrames, options, streamMetadata);
+  BatchDecodedOutput output =
+      allocateBatchDecodedOutput(streamIndex, numOutputFrames);
 
   for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
     DecodedOutput singleOut =
@@ -1189,9 +1186,6 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
       "; must be less than or equal to " + std::to_string(maxSeconds) +
       ").");
 
-  const auto& stream = streams_[streamIndex];
-  const auto& options = stream.options;
-
   // Special case needed to implement a half-open range. At first glance, this
   // may seem unnecessary, as our search for stopFrame can return the end, and
   // we don't include stopFramIndex in our output. However, consider the
@@ -1210,7 +1204,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   // values of the intervals will map to the same frame indices below. Hence, we
   // need this special case below.
   if (startSeconds == stopSeconds) {
-    BatchDecodedOutput output(0, options, streamMetadata);
+    BatchDecodedOutput output = allocateBatchDecodedOutput(streamIndex, 0);
     output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
     return output;
   }
@@ -1226,6 +1220,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   // 2. In order to establish if the start of an interval maps to a particular
   // frame, we need to figure out if it is ordered after the frame's pts, but
   // before the next frames's pts.
+  const auto& stream = streams_[streamIndex];
   auto startFrame = std::lower_bound(
       stream.allFrames.begin(),
       stream.allFrames.end(),
@@ -1245,7 +1240,8 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   int64_t startFrameIndex = startFrame - stream.allFrames.begin();
   int64_t stopFrameIndex = stopFrame - stream.allFrames.begin();
   int64_t numFrames = stopFrameIndex - startFrameIndex;
-  BatchDecodedOutput output(numFrames, options, streamMetadata);
+  BatchDecodedOutput output =
+      allocateBatchDecodedOutput(streamIndex, numFrames);
   for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
     DecodedOutput singleOut =
         getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
@@ -1267,13 +1263,17 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
 }
 
 VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux() {
-  auto output = getNextFrameOutputNoDemuxInternal();
-  output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+  auto rawOutput = getNextRawDecodedOutputNoDemux();
+  auto streamIndex = rawOutput.streamIndex;
+  auto preAllocatedOutputTensor = allocateEmptyHWCTensorForStream(streamIndex);
+  auto output =
+      convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
+  output.frame = MaybePermuteHWC2CHW(streamIndex, output.frame);
   return output;
 }
 
 VideoDecoder::DecodedOutput VideoDecoder::getNextFrameOutputNoDemuxInternal(
-    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+    torch::Tensor preAllocatedOutputTensor) {
   auto rawOutput = getNextRawDecodedOutputNoDemux();
   return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
 }
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index ce4a0cc1..0cdce613 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -157,8 +157,6 @@ class VideoDecoder {
       int streamIndex,
      const AudioStreamDecoderOptions& options = AudioStreamDecoderOptions());
 
-  torch::Tensor MaybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
-
   // ---- SINGLE FRAME SEEK AND DECODING API ----
   // Places the cursor at the first frame on or after the position in seconds.
   // Calling getNextFrameOutputNoDemuxInternal() will return the first frame at
@@ -232,17 +230,16 @@ class VideoDecoder {
   DecodedOutput getFrameAtIndexInternal(
       int streamIndex,
       int64_t frameIndex,
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+      torch::Tensor preAllocatedOutputTensor);
+
   struct BatchDecodedOutput {
     torch::Tensor frames;
     torch::Tensor ptsSeconds;
     torch::Tensor durationSeconds;
-
-    explicit BatchDecodedOutput(
-        int64_t numFrames,
-        const VideoStreamDecoderOptions& options,
-        const StreamMetadata& metadata);
   };
+  BatchDecodedOutput allocateBatchDecodedOutput(
+      int streamIndex,
+      int64_t numFrames);
 
   // Returns frames at the given indices for a given stream as a single stacked
   // Tensor.
   BatchDecodedOutput getFramesAtIndices(
@@ -301,6 +298,14 @@ class VideoDecoder {
 
   double getPtsSecondsForFrame(int streamIndex, int64_t frameIndex);
 
+  // --------------------------------------------------------------------------
+  // Tensor (frames) manipulation APIs
+  // --------------------------------------------------------------------------
+  torch::Tensor MaybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
+  torch::Tensor allocateEmptyHWCTensorForStream(
+      int streamIndex,
+      std::optional<int64_t> numFrames = std::nullopt);
+
  private:
   struct FrameInfo {
     int64_t pts = 0;
@@ -385,14 +390,14 @@ class VideoDecoder {
   void convertFrameToBufferUsingSwsScale(RawDecodedOutput& rawOutput);
   DecodedOutput convertAVFrameToDecodedOutput(
       RawDecodedOutput& rawOutput,
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+      torch::Tensor preAllocatedOutputTensor);
   void convertAVFrameToDecodedOutputOnCPU(
       RawDecodedOutput& rawOutput,
       DecodedOutput& output,
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+      torch::Tensor preAllocatedOutputTensor);
 
   DecodedOutput getNextFrameOutputNoDemuxInternal(
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+      torch::Tensor preAllocatedOutputTensor);
 
   DecoderOptions options_;
   ContainerMetadata containerMetadata_;
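
For reviewers who want to see the allocation pattern this patch enforces in isolation, here is a minimal standalone libtorch sketch. It is not part of the diff: `allocateEmptyHWCTensor`, the 1280x720 dimensions, and the `fill_` call are illustrative stand-ins for the decoder's private helpers and the actual decode step. It mirrors `allocateEmptyHWCTensorForStream`: one uint8 HWC tensor per frame, or one NHWC tensor per batch, allocated up front, with every subsequent write landing in a view of that tensor.

```cpp
#include <torch/torch.h>

#include <optional>

// Hypothetical free-function mirror of the private
// VideoDecoder::allocateEmptyHWCTensorForStream(): height/width stand in for
// the stream's (possibly user-overridden) dimensions, `device` for
// options.device.
torch::Tensor allocateEmptyHWCTensor(
    int64_t height,
    int64_t width,
    torch::Device device,
    std::optional<int64_t> numFrames = std::nullopt) {
  auto tensorOptions = torch::TensorOptions()
                           .dtype(torch::kUInt8)
                           .layout(torch::kStrided)
                           .device(device);
  if (numFrames.has_value()) {
    // Batched output, as built by allocateBatchDecodedOutput(): NHWC.
    return torch::empty({numFrames.value(), height, width, 3}, tensorOptions);
  }
  // Single frame: HWC.
  return torch::empty({height, width, 3}, tensorOptions);
}

int main() {
  // Batch path: one allocation up front; each decoded frame is written into
  // output.frames[f] rather than into a fresh per-frame tensor.
  torch::Tensor frames = allocateEmptyHWCTensor(720, 1280, torch::kCPU, 8);
  for (int64_t f = 0; f < frames.size(0); ++f) {
    torch::Tensor dst = frames[f]; // a view into the batch tensor, no copy
    dst.fill_(0); // stand-in for the decode + color-conversion write
  }

  // Single-frame path: the public APIs now allocate the output tensor and
  // pass it down, so the conversion functions never allocate themselves.
  torch::Tensor frame = allocateEmptyHWCTensor(720, 1280, torch::kCPU);
  TORCH_CHECK(frame.dim() == 3 && frame.size(2) == 3);
  return 0;
}
```

The design point the diff encodes is that allocation decisions now live in one place (the `VideoDecoder` helpers) while the CPU and CUDA conversion paths only ever fill a caller-provided tensor. That is what lets the swscale path write straight into `data_ptr()` and the batch path fill `output.frames[f]` views in place; the filtergraph path still goes through a `copy_`, as noted in the comment block above.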