From 0143d4ef2060d77b98d0202e66ff4651d5227022 Mon Sep 17 00:00:00 2001 From: Dmitry Rogozhkin Date: Wed, 26 Feb 2025 01:56:45 +0000 Subject: [PATCH 1/5] Make device interface generic Fixes: #605 Changes: * Device interface made device-agnostic by introducing `class DeviceInterface`, from which specific backends should inherit their device-specific implementations * Implemented `CudaDevice` derived from `DeviceInterface` * Created device interface registration mechanism (`registerDeviceInterface`) * Created device interface creation mechanism (`createDeviceInterface`) These changes allow replacing the CUDA-specific code in `VideoDecoder.cpp` and `VideoDecoderOps.cpp` with device-agnostic code. Signed-off-by: Dmitry Rogozhkin --- src/torchcodec/_core/CMakeLists.txt | 3 +- src/torchcodec/_core/CPUOnlyDevice.cpp | 45 --------------- src/torchcodec/_core/CudaDevice.cpp | 48 +++++++--------- src/torchcodec/_core/CudaDevice.h | 33 +++++++++++ src/torchcodec/_core/DeviceInterface.cpp | 56 +++++++++++++++++++ src/torchcodec/_core/DeviceInterface.h | 58 +++++++++++++------- src/torchcodec/_core/SingleStreamDecoder.cpp | 49 +++++++---------- src/torchcodec/_core/SingleStreamDecoder.h | 5 +- src/torchcodec/_core/custom_ops.cpp | 13 +---- 9 files changed, 174 insertions(+), 136 deletions(-) delete mode 100644 src/torchcodec/_core/CPUOnlyDevice.cpp create mode 100644 src/torchcodec/_core/CudaDevice.h create mode 100644 src/torchcodec/_core/DeviceInterface.cpp diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index abec1d21..1f42e24d 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -60,6 +60,7 @@ function(make_torchcodec_libraries set(decoder_sources AVIOContextHolder.cpp FFMPEGCommon.cpp + DeviceInterface.cpp SingleStreamDecoder.cpp # TODO: lib name should probably not be "*_decoder*" now that it also # contains an encoder @@ -68,8 +69,6 @@ function(make_torchcodec_libraries if(ENABLE_CUDA) list(APPEND 
decoder_sources CudaDevice.cpp) - else() - list(APPEND decoder_sources CPUOnlyDevice.cpp) endif() set(decoder_library_dependencies diff --git a/src/torchcodec/_core/CPUOnlyDevice.cpp b/src/torchcodec/_core/CPUOnlyDevice.cpp deleted file mode 100644 index 1d5b477d..00000000 --- a/src/torchcodec/_core/CPUOnlyDevice.cpp +++ /dev/null @@ -1,45 +0,0 @@ -#include -#include "src/torchcodec/_core/DeviceInterface.h" - -namespace facebook::torchcodec { - -// This file is linked with the CPU-only version of torchcodec. -// So all functions will throw an error because they should only be called if -// the device is not CPU. - -[[noreturn]] void throwUnsupportedDeviceError(const torch::Device& device) { - TORCH_CHECK( - device.type() != torch::kCPU, - "Device functions should only be called if the device is not CPU.") - TORCH_CHECK(false, "Unsupported device: " + device.str()); -} - -void convertAVFrameToFrameOutputOnCuda( - const torch::Device& device, - [[maybe_unused]] const SingleStreamDecoder::VideoStreamOptions& - videoStreamOptions, - [[maybe_unused]] UniqueAVFrame& avFrame, - [[maybe_unused]] SingleStreamDecoder::FrameOutput& frameOutput, - [[maybe_unused]] std::optional preAllocatedOutputTensor) { - throwUnsupportedDeviceError(device); -} - -void initializeContextOnCuda( - const torch::Device& device, - [[maybe_unused]] AVCodecContext* codecContext) { - throwUnsupportedDeviceError(device); -} - -void releaseContextOnCuda( - const torch::Device& device, - [[maybe_unused]] AVCodecContext* codecContext) { - throwUnsupportedDeviceError(device); -} - -std::optional findCudaCodec( - const torch::Device& device, - [[maybe_unused]] const AVCodecID& codecId) { - throwUnsupportedDeviceError(device); -} - -} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CudaDevice.cpp b/src/torchcodec/_core/CudaDevice.cpp index fd8be9de..98714ee5 100644 --- a/src/torchcodec/_core/CudaDevice.cpp +++ b/src/torchcodec/_core/CudaDevice.cpp @@ -4,7 +4,7 @@ #include #include 
-#include "src/torchcodec/_core/DeviceInterface.h" +#include "src/torchcodec/_core/CudaDevice.h" #include "src/torchcodec/_core/FFMPEGCommon.h" #include "src/torchcodec/_core/SingleStreamDecoder.h" @@ -16,6 +16,10 @@ extern "C" { namespace facebook::torchcodec { namespace { +bool g_cuda = registerDeviceInterface("cuda", [](const std::string& device) { + return new CudaDevice(device); +}); + // We reuse cuda contexts across VideoDeoder instances. This is because // creating a cuda context is expensive. The cache mechanism is as follows: // 1. There is a cache of size MAX_CONTEXTS_PER_GPU_IN_CACHE cuda contexts for @@ -158,39 +162,29 @@ AVBufferRef* getCudaContext(const torch::Device& device) { device, nonNegativeDeviceIndex, type); #endif } +} // namespace -void throwErrorIfNonCudaDevice(const torch::Device& device) { - TORCH_CHECK( - device.type() != torch::kCPU, - "Device functions should only be called if the device is not CPU.") - if (device.type() != torch::kCUDA) { - throw std::runtime_error("Unsupported device: " + device.str()); +CudaDevice::CudaDevice(const std::string& device) : DeviceInterface(device) { + if (device_.type() != torch::kCUDA) { + throw std::runtime_error("Unsupported device: " + device_.str()); } } -} // namespace -void releaseContextOnCuda( - const torch::Device& device, - AVCodecContext* codecContext) { - throwErrorIfNonCudaDevice(device); - addToCacheIfCacheHasCapacity(device, codecContext); +void CudaDevice::releaseContext(AVCodecContext* codecContext) { + addToCacheIfCacheHasCapacity(device_, codecContext); } -void initializeContextOnCuda( - const torch::Device& device, - AVCodecContext* codecContext) { - throwErrorIfNonCudaDevice(device); +void CudaDevice::initializeContext(AVCodecContext* codecContext) { // It is important for pytorch itself to create the cuda context. If ffmpeg // creates the context it may not be compatible with pytorch. // This is a dummy tensor to initialize the cuda context. 
torch::Tensor dummyTensorForCudaInitialization = torch::empty( - {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device)); - codecContext->hw_device_ctx = getCudaContext(device); + {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_)); + codecContext->hw_device_ctx = getCudaContext(device_); return; } -void convertAVFrameToFrameOutputOnCuda( - const torch::Device& device, +void CudaDevice::convertAVFrameToFrameOutput( const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, UniqueAVFrame& avFrame, SingleStreamDecoder::FrameOutput& frameOutput, @@ -217,11 +211,11 @@ void convertAVFrameToFrameOutputOnCuda( "x3, got ", shape); } else { - dst = allocateEmptyHWCTensor(height, width, videoStreamOptions.device); + dst = allocateEmptyHWCTensor(height, width, device_); } // Use the user-requested GPU for running the NPP kernel. - c10::cuda::CUDAGuard deviceGuard(device); + c10::cuda::CUDAGuard deviceGuard(device_); NppiSize oSizeROI = {width, height}; Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]}; @@ -249,7 +243,7 @@ void convertAVFrameToFrameOutputOnCuda( // output. 
at::cuda::CUDAEvent nppDoneEvent; at::cuda::CUDAStream nppStreamWrapper = - c10::cuda::getStreamFromExternal(nppGetStream(), device.index()); + c10::cuda::getStreamFromExternal(nppGetStream(), device_.index()); nppDoneEvent.record(nppStreamWrapper); nppDoneEvent.block(at::cuda::getCurrentCUDAStream()); @@ -264,11 +258,7 @@ void convertAVFrameToFrameOutputOnCuda( // we have to do this because of an FFmpeg bug where hardware decoding is not // appropriately set, so we just go off and find the matching codec for the CUDA // device -std::optional findCudaCodec( - const torch::Device& device, - const AVCodecID& codecId) { - throwErrorIfNonCudaDevice(device); - +std::optional CudaDevice::findCodec(const AVCodecID& codecId) { void* i = nullptr; const AVCodec* codec = nullptr; while ((codec = av_codec_iterate(&i)) != nullptr) { diff --git a/src/torchcodec/_core/CudaDevice.h b/src/torchcodec/_core/CudaDevice.h new file mode 100644 index 00000000..14bed9c8 --- /dev/null +++ b/src/torchcodec/_core/CudaDevice.h @@ -0,0 +1,33 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +#include "src/torchcodec/_core/DeviceInterface.h" + +namespace facebook::torchcodec { + +class CudaDevice : public DeviceInterface { + public: + CudaDevice(const std::string& device); + + virtual ~CudaDevice(){}; + + std::optional findCodec(const AVCodecID& codecId) override; + + void initializeContext(AVCodecContext* codecContext) override; + + void convertAVFrameToFrameOutput( + const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, + UniqueAVFrame& avFrame, + SingleStreamDecoder::FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor = + std::nullopt) override; + + void releaseContext(AVCodecContext* codecContext) override; +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/DeviceInterface.cpp b/src/torchcodec/_core/DeviceInterface.cpp new file mode 100644 index 00000000..a990cb65 --- /dev/null +++ b/src/torchcodec/_core/DeviceInterface.cpp @@ -0,0 +1,56 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include "src/torchcodec/_core/DeviceInterface.h" +#include +#include + +namespace facebook::torchcodec { + +namespace { +std::mutex g_interface_mutex; +std::map g_interface_map; + +std::string getDeviceType(const std::string& device) { + size_t pos = device.find(':'); + if (pos == std::string::npos) { + return device; + } + return device.substr(0, pos); +} + +} // namespace + +bool registerDeviceInterface( + const std::string deviceType, + CreateDeviceInterfaceFn createInterface) { + std::scoped_lock lock(g_interface_mutex); + TORCH_CHECK( + g_interface_map.find(deviceType) == g_interface_map.end(), + "Device interface already registered for ", + deviceType); + g_interface_map.insert({deviceType, createInterface}); + return true; +} + +std::shared_ptr createDeviceInterface( + const std::string device) { + // TODO: remove once DeviceInterface for CPU is implemented + if (device == "cpu") { + return nullptr; + } + + std::scoped_lock lock(g_interface_mutex); + std::string deviceType = getDeviceType(device); + TORCH_CHECK( + g_interface_map.find(deviceType) != g_interface_map.end(), + "Unsupported device: ", + device); + + return std::shared_ptr(g_interface_map[deviceType](device)); +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 352b83d3..906e8181 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -7,6 +7,7 @@ #pragma once #include +#include #include #include #include @@ -23,25 +24,42 @@ namespace facebook::torchcodec { // deviceFunction(device, ...); // } -// Initialize the hardware device that is specified in `device`. Some builds -// support CUDA and others only support CPU. 
-void initializeContextOnCuda( - const torch::Device& device, - AVCodecContext* codecContext); - -void convertAVFrameToFrameOutputOnCuda( - const torch::Device& device, - const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, - UniqueAVFrame& avFrame, - SingleStreamDecoder::FrameOutput& frameOutput, - std::optional preAllocatedOutputTensor = std::nullopt); - -void releaseContextOnCuda( - const torch::Device& device, - AVCodecContext* codecContext); - -std::optional findCudaCodec( - const torch::Device& device, - const AVCodecID& codecId); +class DeviceInterface { + public: + DeviceInterface(const std::string& device) : device_(device) {} + + virtual ~DeviceInterface(){}; + + torch::Device& device() { + return device_; + }; + + virtual std::optional findCodec(const AVCodecID& codecId) = 0; + + // Initialize the hardware device that is specified in `device`. Some builds + // support CUDA and others only support CPU. + virtual void initializeContext(AVCodecContext* codecContext) = 0; + + virtual void convertAVFrameToFrameOutput( + const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, + UniqueAVFrame& avFrame, + SingleStreamDecoder::FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor = std::nullopt) = 0; + + virtual void releaseContext(AVCodecContext* codecContext) = 0; + + protected: + torch::Device device_; +}; + +using CreateDeviceInterfaceFn = + std::function; + +bool registerDeviceInterface( + const std::string deviceType, + const CreateDeviceInterfaceFn createInterface); + +std::shared_ptr createDeviceInterface( + const std::string device); } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index b7438f19..f535a2d9 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -95,11 +95,8 @@ SingleStreamDecoder::SingleStreamDecoder( SingleStreamDecoder::~SingleStreamDecoder() { for 
(auto& [streamIndex, streamInfo] : streamInfos_) { auto& device = streamInfo.videoStreamOptions.device; - if (device.type() == torch::kCPU) { - } else if (device.type() == torch::kCUDA) { - releaseContextOnCuda(device, streamInfo.codecContext.get()); - } else { - TORCH_CHECK(false, "Invalid device type: " + device.str()); + if (device) { + device->releaseContext(streamInfo.codecContext.get()); } } } @@ -389,7 +386,7 @@ torch::Tensor SingleStreamDecoder::getKeyFrameIndices() { void SingleStreamDecoder::addStream( int streamIndex, AVMediaType mediaType, - const torch::Device& device, + DeviceInterface* device, std::optional ffmpegThreadCount) { TORCH_CHECK( activeStreamIndex_ == NO_ACTIVE_STREAM, @@ -427,10 +424,12 @@ void SingleStreamDecoder::addStream( // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within // addStream() which is supposed to be generic - if (mediaType == AVMEDIA_TYPE_VIDEO && device.type() == torch::kCUDA) { - avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( - findCudaCodec(device, streamInfo.stream->codecpar->codec_id) - .value_or(avCodec)); + if (mediaType == AVMEDIA_TYPE_VIDEO) { + if (device) { + avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( + device->findCodec(streamInfo.stream->codecpar->codec_id) + .value_or(avCodec)); + } } AVCodecContext* codecContext = avcodec_alloc_context3(avCodec); @@ -445,8 +444,10 @@ void SingleStreamDecoder::addStream( streamInfo.codecContext->pkt_timebase = streamInfo.stream->time_base; // TODO_CODE_QUALITY same as above. 
- if (mediaType == AVMEDIA_TYPE_VIDEO && device.type() == torch::kCUDA) { - initializeContextOnCuda(device, codecContext); + if (mediaType == AVMEDIA_TYPE_VIDEO) { + if (device) { + device->initializeContext(codecContext); + } } retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr); @@ -472,15 +473,10 @@ void SingleStreamDecoder::addStream( void SingleStreamDecoder::addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions) { - TORCH_CHECK( - videoStreamOptions.device.type() == torch::kCPU || - videoStreamOptions.device.type() == torch::kCUDA, - "Invalid device type: " + videoStreamOptions.device.str()); - addStream( streamIndex, AVMEDIA_TYPE_VIDEO, - videoStreamOptions.device, + videoStreamOptions.device.get(), videoStreamOptions.ffmpegThreadCount); auto& streamMetadata = @@ -1221,20 +1217,15 @@ SingleStreamDecoder::convertAVFrameToFrameOutput( formatContext_->streams[activeStreamIndex_]->time_base); if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput); - } else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) { + } else if (!streamInfo.videoStreamOptions.device) { convertAVFrameToFrameOutputOnCPU( avFrame, frameOutput, preAllocatedOutputTensor); - } else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) { - convertAVFrameToFrameOutputOnCuda( - streamInfo.videoStreamOptions.device, + } else if (streamInfo.videoStreamOptions.device) { + streamInfo.videoStreamOptions.device->convertAVFrameToFrameOutput( streamInfo.videoStreamOptions, avFrame, frameOutput, preAllocatedOutputTensor); - } else { - TORCH_CHECK( - false, - "Invalid device type: " + streamInfo.videoStreamOptions.device.str()); } return frameOutput; } @@ -1573,8 +1564,10 @@ SingleStreamDecoder::FrameBatchOutput::FrameBatchOutput( videoStreamOptions, streamMetadata); int height = frameDims.height; int width = frameDims.width; - data = allocateEmptyHWCTensor( - height, width, 
videoStreamOptions.device, numFrames); + torch::Device device = (videoStreamOptions.device) + ? videoStreamOptions.device->device() + : torch::kCPU; + data = allocateEmptyHWCTensor(height, width, device, numFrames); } torch::Tensor allocateEmptyHWCTensor( diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index f712cdbb..9bac9ba6 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -16,6 +16,7 @@ #include "src/torchcodec/_core/FFMPEGCommon.h" namespace facebook::torchcodec { +class DeviceInterface; // The SingleStreamDecoder class can be used to decode video frames to Tensors. // Note that SingleStreamDecoder is not thread-safe. @@ -138,7 +139,7 @@ class SingleStreamDecoder { std::optional height; std::optional colorConversionLibrary; // By default we use CPU for decoding for both C++ and python users. - torch::Device device = torch::kCPU; + std::shared_ptr device; }; struct AudioStreamOptions { @@ -459,7 +460,7 @@ class SingleStreamDecoder { void addStream( int streamIndex, AVMediaType mediaType, - const torch::Device& device = torch::kCPU, + DeviceInterface* device = nullptr, std::optional ffmpegThreadCount = std::nullopt); // Returns the "best" stream index for a given media type. 
The "best" is diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 596412a8..6f520cfe 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -11,6 +11,7 @@ #include "c10/core/SymIntArrayRef.h" #include "c10/util/Exception.h" #include "src/torchcodec/_core/AVIOBytesContext.h" +#include "src/torchcodec/_core/DeviceInterface.h" #include "src/torchcodec/_core/Encoder.h" #include "src/torchcodec/_core/SingleStreamDecoder.h" @@ -242,16 +243,8 @@ void _add_video_stream( } } if (device.has_value()) { - if (device.value() == "cpu") { - videoStreamOptions.device = torch::Device(torch::kCPU); - } else if (device.value().rfind("cuda", 0) == 0) { // starts with "cuda" - std::string deviceStr(device.value()); - videoStreamOptions.device = torch::Device(deviceStr); - } else { - throw std::runtime_error( - "Invalid device=" + std::string(device.value()) + - ". device must be either cpu or cuda."); - } + videoStreamOptions.device = + createDeviceInterface(std::string(device.value())); } auto videoDecoder = unwrapTensorToGetDecoder(decoder); From afdca1febdec7baeb6e2850e8a21f21233fbd201 Mon Sep 17 00:00:00 2001 From: Dmitry Rogozhkin Date: Tue, 1 Apr 2025 20:46:44 +0000 Subject: [PATCH 2/5] Use std::unique_ptr to create DeviceInterface Signed-off-by: Dmitry Rogozhkin --- src/torchcodec/_core/DeviceInterface.cpp | 4 ++-- src/torchcodec/_core/DeviceInterface.h | 2 +- test/VideoDecoderTest.cpp | 1 + 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/_core/DeviceInterface.cpp b/src/torchcodec/_core/DeviceInterface.cpp index a990cb65..a67302b9 100644 --- a/src/torchcodec/_core/DeviceInterface.cpp +++ b/src/torchcodec/_core/DeviceInterface.cpp @@ -36,7 +36,7 @@ bool registerDeviceInterface( return true; } -std::shared_ptr createDeviceInterface( +std::unique_ptr createDeviceInterface( const std::string device) { // TODO: remove once DeviceInterface for CPU is implemented if (device == 
"cpu") { @@ -50,7 +50,7 @@ std::shared_ptr createDeviceInterface( "Unsupported device: ", device); - return std::shared_ptr(g_interface_map[deviceType](device)); + return std::unique_ptr(g_interface_map[deviceType](device)); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 906e8181..9e5e00d0 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -59,7 +59,7 @@ bool registerDeviceInterface( const std::string deviceType, const CreateDeviceInterfaceFn createInterface); -std::shared_ptr createDeviceInterface( +std::unique_ptr createDeviceInterface( const std::string device); } // namespace facebook::torchcodec diff --git a/test/VideoDecoderTest.cpp b/test/VideoDecoderTest.cpp index a7ad4c6d..1937ff97 100644 --- a/test/VideoDecoderTest.cpp +++ b/test/VideoDecoderTest.cpp @@ -5,6 +5,7 @@ // LICENSE file in the root directory of this source tree. #include "src/torchcodec/_core/AVIOBytesContext.h" +#include "src/torchcodec/_core/DeviceInterface.h" #include "src/torchcodec/_core/SingleStreamDecoder.h" #include From 3edf0ae824ac9afb7ac0d745977ac95a4c1bc761 Mon Sep 17 00:00:00 2001 From: Dmitry Rogozhkin Date: Tue, 1 Apr 2025 23:21:23 +0000 Subject: [PATCH 3/5] Create DeviceInterface in addStream Signed-off-by: Dmitry Rogozhkin --- src/torchcodec/_core/CudaDevice.cpp | 8 ++--- src/torchcodec/_core/CudaDevice.h | 2 +- src/torchcodec/_core/DeviceInterface.cpp | 31 ++++++++++++++++--- src/torchcodec/_core/DeviceInterface.h | 10 +++--- src/torchcodec/_core/SingleStreamDecoder.cpp | 32 ++++++++++---------- src/torchcodec/_core/SingleStreamDecoder.h | 6 ++-- src/torchcodec/_core/custom_ops.cpp | 3 +- 7 files changed, 58 insertions(+), 34 deletions(-) diff --git a/src/torchcodec/_core/CudaDevice.cpp b/src/torchcodec/_core/CudaDevice.cpp index 98714ee5..f407f5a5 100644 --- a/src/torchcodec/_core/CudaDevice.cpp +++ b/src/torchcodec/_core/CudaDevice.cpp 
@@ -16,9 +16,9 @@ extern "C" { namespace facebook::torchcodec { namespace { -bool g_cuda = registerDeviceInterface("cuda", [](const std::string& device) { - return new CudaDevice(device); -}); +bool g_cuda = registerDeviceInterface( + torch::kCUDA, + [](const torch::Device& device) { return new CudaDevice(device); }); // We reuse cuda contexts across VideoDeoder instances. This is because // creating a cuda context is expensive. The cache mechanism is as follows: @@ -164,7 +164,7 @@ AVBufferRef* getCudaContext(const torch::Device& device) { } } // namespace -CudaDevice::CudaDevice(const std::string& device) : DeviceInterface(device) { +CudaDevice::CudaDevice(const torch::Device& device) : DeviceInterface(device) { if (device_.type() != torch::kCUDA) { throw std::runtime_error("Unsupported device: " + device_.str()); } diff --git a/src/torchcodec/_core/CudaDevice.h b/src/torchcodec/_core/CudaDevice.h index 14bed9c8..e40a7748 100644 --- a/src/torchcodec/_core/CudaDevice.h +++ b/src/torchcodec/_core/CudaDevice.h @@ -12,7 +12,7 @@ namespace facebook::torchcodec { class CudaDevice : public DeviceInterface { public: - CudaDevice(const std::string& device); + CudaDevice(const torch::Device& device); virtual ~CudaDevice(){}; diff --git a/src/torchcodec/_core/DeviceInterface.cpp b/src/torchcodec/_core/DeviceInterface.cpp index a67302b9..7612334b 100644 --- a/src/torchcodec/_core/DeviceInterface.cpp +++ b/src/torchcodec/_core/DeviceInterface.cpp @@ -12,7 +12,7 @@ namespace facebook::torchcodec { namespace { std::mutex g_interface_mutex; -std::map g_interface_map; +std::map g_interface_map; std::string getDeviceType(const std::string& device) { size_t pos = device.find(':'); @@ -25,7 +25,7 @@ std::string getDeviceType(const std::string& device) { } // namespace bool registerDeviceInterface( - const std::string deviceType, + torch::DeviceType deviceType, CreateDeviceInterfaceFn createInterface) { std::scoped_lock lock(g_interface_mutex); TORCH_CHECK( @@ -36,15 +36,36 @@ bool 
registerDeviceInterface( return true; } -std::unique_ptr createDeviceInterface( - const std::string device) { +torch::Device createTorchDevice(const std::string device) { // TODO: remove once DeviceInterface for CPU is implemented if (device == "cpu") { - return nullptr; + return torch::kCPU; } std::scoped_lock lock(g_interface_mutex); std::string deviceType = getDeviceType(device); + auto deviceInterface = std::find_if( + g_interface_map.begin(), + g_interface_map.end(), + [&](const std::pair& arg) { + return device.rfind( + torch::DeviceTypeName(arg.first, /*lcase*/ true), 0) == 0; + }); + TORCH_CHECK( + deviceInterface != g_interface_map.end(), "Unsupported device: ", device); + + return torch::Device(device); +} + +std::unique_ptr createDeviceInterface( + const torch::Device& device) { + auto deviceType = device.type(); + // TODO: remove once DeviceInterface for CPU is implemented + if (deviceType == torch::kCPU) { + return nullptr; + } + + std::scoped_lock lock(g_interface_mutex); TORCH_CHECK( g_interface_map.find(deviceType) != g_interface_map.end(), "Unsupported device: ", diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 9e5e00d0..558a11ed 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -26,7 +26,7 @@ namespace facebook::torchcodec { class DeviceInterface { public: - DeviceInterface(const std::string& device) : device_(device) {} + DeviceInterface(const torch::Device& device) : device_(device) {} virtual ~DeviceInterface(){}; @@ -53,13 +53,15 @@ class DeviceInterface { }; using CreateDeviceInterfaceFn = - std::function; + std::function; bool registerDeviceInterface( - const std::string deviceType, + torch::DeviceType deviceType, const CreateDeviceInterfaceFn createInterface); +torch::Device createTorchDevice(const std::string device); + std::unique_ptr createDeviceInterface( - const std::string device); + const torch::Device& device); } // namespace 
facebook::torchcodec diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index f535a2d9..5b98cbab 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -94,9 +94,9 @@ SingleStreamDecoder::SingleStreamDecoder( SingleStreamDecoder::~SingleStreamDecoder() { for (auto& [streamIndex, streamInfo] : streamInfos_) { - auto& device = streamInfo.videoStreamOptions.device; - if (device) { - device->releaseContext(streamInfo.codecContext.get()); + auto& deviceInterface = streamInfo.deviceInterface; + if (deviceInterface) { + deviceInterface->releaseContext(streamInfo.codecContext.get()); } } } @@ -386,7 +386,7 @@ torch::Tensor SingleStreamDecoder::getKeyFrameIndices() { void SingleStreamDecoder::addStream( int streamIndex, AVMediaType mediaType, - DeviceInterface* device, + const torch::Device& device, std::optional ffmpegThreadCount) { TORCH_CHECK( activeStreamIndex_ == NO_ACTIVE_STREAM, @@ -414,6 +414,7 @@ void SingleStreamDecoder::addStream( streamInfo.timeBase = formatContext_->streams[activeStreamIndex_]->time_base; streamInfo.stream = formatContext_->streams[activeStreamIndex_]; streamInfo.avMediaType = mediaType; + streamInfo.deviceInterface = createDeviceInterface(device); // This should never happen, checking just to be safe. TORCH_CHECK( @@ -425,9 +426,10 @@ void SingleStreamDecoder::addStream( // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within // addStream() which is supposed to be generic if (mediaType == AVMEDIA_TYPE_VIDEO) { - if (device) { + if (streamInfo.deviceInterface) { avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( - device->findCodec(streamInfo.stream->codecpar->codec_id) + streamInfo.deviceInterface + ->findCodec(streamInfo.stream->codecpar->codec_id) .value_or(avCodec)); } } @@ -445,8 +447,8 @@ void SingleStreamDecoder::addStream( // TODO_CODE_QUALITY same as above. 
if (mediaType == AVMEDIA_TYPE_VIDEO) { - if (device) { - device->initializeContext(codecContext); + if (streamInfo.deviceInterface) { + streamInfo.deviceInterface->initializeContext(codecContext); } } @@ -476,7 +478,7 @@ void SingleStreamDecoder::addVideoStream( addStream( streamIndex, AVMEDIA_TYPE_VIDEO, - videoStreamOptions.device.get(), + videoStreamOptions.device, videoStreamOptions.ffmpegThreadCount); auto& streamMetadata = @@ -1217,11 +1219,11 @@ SingleStreamDecoder::convertAVFrameToFrameOutput( formatContext_->streams[activeStreamIndex_]->time_base); if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput); - } else if (!streamInfo.videoStreamOptions.device) { + } else if (!streamInfo.deviceInterface) { convertAVFrameToFrameOutputOnCPU( avFrame, frameOutput, preAllocatedOutputTensor); - } else if (streamInfo.videoStreamOptions.device) { - streamInfo.videoStreamOptions.device->convertAVFrameToFrameOutput( + } else if (streamInfo.deviceInterface) { + streamInfo.deviceInterface->convertAVFrameToFrameOutput( streamInfo.videoStreamOptions, avFrame, frameOutput, @@ -1564,10 +1566,8 @@ SingleStreamDecoder::FrameBatchOutput::FrameBatchOutput( videoStreamOptions, streamMetadata); int height = frameDims.height; int width = frameDims.width; - torch::Device device = (videoStreamOptions.device) - ? 
videoStreamOptions.device->device() - : torch::kCPU; - data = allocateEmptyHWCTensor(height, width, device, numFrames); + data = allocateEmptyHWCTensor( + height, width, videoStreamOptions.device, numFrames); } torch::Tensor allocateEmptyHWCTensor( diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 9bac9ba6..253bf2e1 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -139,7 +139,7 @@ class SingleStreamDecoder { std::optional height; std::optional colorConversionLibrary; // By default we use CPU for decoding for both C++ and python users. - std::shared_ptr device; + torch::Device device = torch::kCPU; }; struct AudioStreamOptions { @@ -358,6 +358,8 @@ class SingleStreamDecoder { // Used to know whether a new FilterGraphContext or UniqueSwsContext should // be created before decoding a new frame. DecodedFrameContext prevFrameContext; + + std::unique_ptr deviceInterface; }; // -------------------------------------------------------------------------- @@ -460,7 +462,7 @@ class SingleStreamDecoder { void addStream( int streamIndex, AVMediaType mediaType, - DeviceInterface* device = nullptr, + const torch::Device& device = torch::kCPU, std::optional ffmpegThreadCount = std::nullopt); // Returns the "best" stream index for a given media type. 
The "best" is diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 6f520cfe..05a6390d 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -243,8 +243,7 @@ void _add_video_stream( } } if (device.has_value()) { - videoStreamOptions.device = - createDeviceInterface(std::string(device.value())); + videoStreamOptions.device = createTorchDevice(std::string(device.value())); } auto videoDecoder = unwrapTensorToGetDecoder(decoder); From 32d9598c914406ff67f1520a8770c91846d8d3fc Mon Sep 17 00:00:00 2001 From: Dmitry Rogozhkin Date: Tue, 1 Apr 2025 23:44:40 +0000 Subject: [PATCH 4/5] Drop releaseContext Signed-off-by: Dmitry Rogozhkin --- src/torchcodec/_core/CudaDevice.cpp | 17 +++++++++++------ src/torchcodec/_core/CudaDevice.h | 5 +++-- src/torchcodec/_core/DeviceInterface.h | 2 -- src/torchcodec/_core/SingleStreamDecoder.cpp | 9 --------- src/torchcodec/_core/SingleStreamDecoder.h | 2 -- 5 files changed, 14 insertions(+), 21 deletions(-) diff --git a/src/torchcodec/_core/CudaDevice.cpp b/src/torchcodec/_core/CudaDevice.cpp index f407f5a5..5bde4106 100644 --- a/src/torchcodec/_core/CudaDevice.cpp +++ b/src/torchcodec/_core/CudaDevice.cpp @@ -53,7 +53,7 @@ torch::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) { void addToCacheIfCacheHasCapacity( const torch::Device& device, - AVCodecContext* codecContext) { + AVBufferRef* hwContext) { torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device); if (static_cast(deviceIndex) >= MAX_CUDA_GPUS) { return; @@ -64,8 +64,7 @@ void addToCacheIfCacheHasCapacity( MAX_CONTEXTS_PER_GPU_IN_CACHE) { return; } - g_cached_hw_device_ctxs[deviceIndex].push_back(codecContext->hw_device_ctx); - codecContext->hw_device_ctx = nullptr; + g_cached_hw_device_ctxs[deviceIndex].push_back(av_buffer_ref(hwContext)); } AVBufferRef* getFromCache(const torch::Device& device) { @@ -170,17 +169,23 @@ CudaDevice::CudaDevice(const 
torch::Device& device) : DeviceInterface(device) { } } -void CudaDevice::releaseContext(AVCodecContext* codecContext) { - addToCacheIfCacheHasCapacity(device_, codecContext); +CudaDevice::~CudaDevice() { + if (ctx_) { + addToCacheIfCacheHasCapacity(device_, ctx_); + av_buffer_unref(&ctx_); + } } void CudaDevice::initializeContext(AVCodecContext* codecContext) { + TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized"); + // It is important for pytorch itself to create the cuda context. If ffmpeg // creates the context it may not be compatible with pytorch. // This is a dummy tensor to initialize the cuda context. torch::Tensor dummyTensorForCudaInitialization = torch::empty( {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_)); - codecContext->hw_device_ctx = getCudaContext(device_); + ctx_ = getCudaContext(device_); + codecContext->hw_device_ctx = av_buffer_ref(ctx_); return; } diff --git a/src/torchcodec/_core/CudaDevice.h b/src/torchcodec/_core/CudaDevice.h index e40a7748..0ed53859 100644 --- a/src/torchcodec/_core/CudaDevice.h +++ b/src/torchcodec/_core/CudaDevice.h @@ -14,7 +14,7 @@ class CudaDevice : public DeviceInterface { public: CudaDevice(const torch::Device& device); - virtual ~CudaDevice(){}; + virtual ~CudaDevice(); std::optional findCodec(const AVCodecID& codecId) override; @@ -27,7 +27,8 @@ class CudaDevice : public DeviceInterface { std::optional preAllocatedOutputTensor = std::nullopt) override; - void releaseContext(AVCodecContext* codecContext) override; + private: + AVBufferRef* ctx_ = nullptr; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 558a11ed..a5b0e365 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -46,8 +46,6 @@ class DeviceInterface { SingleStreamDecoder::FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt) = 0; - virtual void 
releaseContext(AVCodecContext* codecContext) = 0; - protected: torch::Device device_; }; diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 5b98cbab..4993cd81 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -92,15 +92,6 @@ SingleStreamDecoder::SingleStreamDecoder( initializeDecoder(); } -SingleStreamDecoder::~SingleStreamDecoder() { - for (auto& [streamIndex, streamInfo] : streamInfos_) { - auto& deviceInterface = streamInfo.deviceInterface; - if (deviceInterface) { - deviceInterface->releaseContext(streamInfo.codecContext.get()); - } - } -} - void SingleStreamDecoder::initializeDecoder() { TORCH_CHECK(!initialized_, "Attempted double initialization."); diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 253bf2e1..5962f254 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -23,8 +23,6 @@ class DeviceInterface; // Do not call non-const APIs concurrently on the same object. 
class SingleStreamDecoder { public: - ~SingleStreamDecoder(); - // -------------------------------------------------------------------------- // CONSTRUCTION API // -------------------------------------------------------------------------- From 79c84c65a9bac4efe4aae10c211d8b821086332d Mon Sep 17 00:00:00 2001 From: Dmitry Rogozhkin Date: Fri, 4 Apr 2025 16:37:24 +0000 Subject: [PATCH 5/5] Move device interface from stream to decoder Signed-off-by: Dmitry Rogozhkin --- src/torchcodec/_core/SingleStreamDecoder.cpp | 18 +++++++++--------- src/torchcodec/_core/SingleStreamDecoder.h | 3 +-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 4993cd81..c7c714da 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -405,7 +405,8 @@ void SingleStreamDecoder::addStream( streamInfo.timeBase = formatContext_->streams[activeStreamIndex_]->time_base; streamInfo.stream = formatContext_->streams[activeStreamIndex_]; streamInfo.avMediaType = mediaType; - streamInfo.deviceInterface = createDeviceInterface(device); + + deviceInterface_ = createDeviceInterface(device); // This should never happen, checking just to be safe. TORCH_CHECK( @@ -417,10 +418,9 @@ void SingleStreamDecoder::addStream( // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within // addStream() which is supposed to be generic if (mediaType == AVMEDIA_TYPE_VIDEO) { - if (streamInfo.deviceInterface) { + if (deviceInterface_) { avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( - streamInfo.deviceInterface - ->findCodec(streamInfo.stream->codecpar->codec_id) + deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id) .value_or(avCodec)); } } @@ -438,8 +438,8 @@ void SingleStreamDecoder::addStream( // TODO_CODE_QUALITY same as above. 
if (mediaType == AVMEDIA_TYPE_VIDEO) { - if (streamInfo.deviceInterface) { - streamInfo.deviceInterface->initializeContext(codecContext); + if (deviceInterface_) { + deviceInterface_->initializeContext(codecContext); } } @@ -1210,11 +1210,11 @@ SingleStreamDecoder::convertAVFrameToFrameOutput( formatContext_->streams[activeStreamIndex_]->time_base); if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput); - } else if (!streamInfo.deviceInterface) { + } else if (!deviceInterface_) { convertAVFrameToFrameOutputOnCPU( avFrame, frameOutput, preAllocatedOutputTensor); - } else if (streamInfo.deviceInterface) { - streamInfo.deviceInterface->convertAVFrameToFrameOutput( + } else { + deviceInterface_->convertAVFrameToFrameOutput( streamInfo.videoStreamOptions, avFrame, frameOutput, diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 5962f254..4879a3b7 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -356,8 +356,6 @@ class SingleStreamDecoder { // Used to know whether a new FilterGraphContext or UniqueSwsContext should // be created before decoding a new frame. DecodedFrameContext prevFrameContext; - - std::unique_ptr deviceInterface; }; // -------------------------------------------------------------------------- @@ -494,6 +492,7 @@ class SingleStreamDecoder { SeekMode seekMode_; ContainerMetadata containerMetadata_; UniqueDecodingAVFormatContext formatContext_; + std::unique_ptr deviceInterface_; std::map streamInfos_; const int NO_ACTIVE_STREAM = -2; int activeStreamIndex_ = NO_ACTIVE_STREAM;