Skip to content

Make device interface generic #606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/torchcodec/_core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ function(make_torchcodec_libraries
set(decoder_sources
AVIOContextHolder.cpp
FFMPEGCommon.cpp
DeviceInterface.cpp
SingleStreamDecoder.cpp
# TODO: lib name should probably not be "*_decoder*" now that it also
# contains an encoder
Expand All @@ -68,8 +69,6 @@ function(make_torchcodec_libraries

if(ENABLE_CUDA)
list(APPEND decoder_sources CudaDevice.cpp)
else()
list(APPEND decoder_sources CPUOnlyDevice.cpp)
endif()

set(decoder_library_dependencies
Expand Down
45 changes: 0 additions & 45 deletions src/torchcodec/_core/CPUOnlyDevice.cpp

This file was deleted.

59 changes: 27 additions & 32 deletions src/torchcodec/_core/CudaDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <torch/types.h>
#include <mutex>

#include "src/torchcodec/_core/DeviceInterface.h"
#include "src/torchcodec/_core/CudaDevice.h"
#include "src/torchcodec/_core/FFMPEGCommon.h"
#include "src/torchcodec/_core/SingleStreamDecoder.h"

Expand All @@ -16,6 +16,10 @@ extern "C" {
namespace facebook::torchcodec {
namespace {

// Registers the CudaDevice factory for torch::kCUDA. The registration runs as
// the initializer of this namespace-scope variable; g_cuda itself only stores
// the value returned by registerDeviceInterface(). The lambda hands back a raw
// pointer created with `new`; createDeviceInterface() wraps the factory's
// result in a std::unique_ptr.
bool g_cuda = registerDeviceInterface(
torch::kCUDA,
[](const torch::Device& device) { return new CudaDevice(device); });

// We reuse cuda contexts across VideoDecoder instances. This is because
// creating a cuda context is expensive. The cache mechanism is as follows:
// 1. There is a cache of size MAX_CONTEXTS_PER_GPU_IN_CACHE cuda contexts for
Expand Down Expand Up @@ -49,7 +53,7 @@ torch::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) {

void addToCacheIfCacheHasCapacity(
const torch::Device& device,
AVCodecContext* codecContext) {
AVBufferRef* hwContext) {
torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
if (static_cast<int>(deviceIndex) >= MAX_CUDA_GPUS) {
return;
Expand All @@ -60,8 +64,7 @@ void addToCacheIfCacheHasCapacity(
MAX_CONTEXTS_PER_GPU_IN_CACHE) {
return;
}
g_cached_hw_device_ctxs[deviceIndex].push_back(codecContext->hw_device_ctx);
codecContext->hw_device_ctx = nullptr;
g_cached_hw_device_ctxs[deviceIndex].push_back(av_buffer_ref(hwContext));
}

AVBufferRef* getFromCache(const torch::Device& device) {
Expand Down Expand Up @@ -158,39 +161,35 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
device, nonNegativeDeviceIndex, type);
#endif
}
} // namespace

void throwErrorIfNonCudaDevice(const torch::Device& device) {
TORCH_CHECK(
device.type() != torch::kCPU,
"Device functions should only be called if the device is not CPU.")
if (device.type() != torch::kCUDA) {
throw std::runtime_error("Unsupported device: " + device.str());
CudaDevice::CudaDevice(const torch::Device& device) : DeviceInterface(device) {
if (device_.type() != torch::kCUDA) {
throw std::runtime_error("Unsupported device: " + device_.str());
}
}
} // namespace

void releaseContextOnCuda(
const torch::Device& device,
AVCodecContext* codecContext) {
throwErrorIfNonCudaDevice(device);
addToCacheIfCacheHasCapacity(device, codecContext);
// Gives up the FFmpeg hardware device context held in ctx_, if any.
CudaDevice::~CudaDevice() {
  if (ctx_ == nullptr) {
    // No context was ever stored in ctx_, so there is nothing to release.
    return;
  }
  // Offer the context to the per-GPU cache first (the cache takes its own
  // reference when it has room), then drop the reference this object holds.
  addToCacheIfCacheHasCapacity(device_, ctx_);
  av_buffer_unref(&ctx_);
}

void initializeContextOnCuda(
const torch::Device& device,
AVCodecContext* codecContext) {
throwErrorIfNonCudaDevice(device);
void CudaDevice::initializeContext(AVCodecContext* codecContext) {
TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized");

// It is important for pytorch itself to create the cuda context. If ffmpeg
// creates the context it may not be compatible with pytorch.
// This is a dummy tensor to initialize the cuda context.
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device));
codecContext->hw_device_ctx = getCudaContext(device);
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
ctx_ = getCudaContext(device_);
codecContext->hw_device_ctx = av_buffer_ref(ctx_);
return;
}

void convertAVFrameToFrameOutputOnCuda(
const torch::Device& device,
void CudaDevice::convertAVFrameToFrameOutput(
const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
UniqueAVFrame& avFrame,
SingleStreamDecoder::FrameOutput& frameOutput,
Expand All @@ -217,11 +216,11 @@ void convertAVFrameToFrameOutputOnCuda(
"x3, got ",
shape);
} else {
dst = allocateEmptyHWCTensor(height, width, videoStreamOptions.device);
dst = allocateEmptyHWCTensor(height, width, device_);
}

// Use the user-requested GPU for running the NPP kernel.
c10::cuda::CUDAGuard deviceGuard(device);
c10::cuda::CUDAGuard deviceGuard(device_);

NppiSize oSizeROI = {width, height};
Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};
Expand Down Expand Up @@ -249,7 +248,7 @@ void convertAVFrameToFrameOutputOnCuda(
// output.
at::cuda::CUDAEvent nppDoneEvent;
at::cuda::CUDAStream nppStreamWrapper =
c10::cuda::getStreamFromExternal(nppGetStream(), device.index());
c10::cuda::getStreamFromExternal(nppGetStream(), device_.index());
nppDoneEvent.record(nppStreamWrapper);
nppDoneEvent.block(at::cuda::getCurrentCUDAStream());

Expand All @@ -264,11 +263,7 @@ void convertAVFrameToFrameOutputOnCuda(
// we have to do this because of an FFmpeg bug where hardware decoding is not
// appropriately set, so we just go off and find the matching codec for the CUDA
// device
std::optional<const AVCodec*> findCudaCodec(
const torch::Device& device,
const AVCodecID& codecId) {
throwErrorIfNonCudaDevice(device);

std::optional<const AVCodec*> CudaDevice::findCodec(const AVCodecID& codecId) {
void* i = nullptr;
const AVCodec* codec = nullptr;
while ((codec = av_codec_iterate(&i)) != nullptr) {
Expand Down
34 changes: 34 additions & 0 deletions src/torchcodec/_core/CudaDevice.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include "src/torchcodec/_core/DeviceInterface.h"

namespace facebook::torchcodec {

class CudaDevice : public DeviceInterface {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The term "device" should be reserved for the concept of a PyTorch device object. Here, the CudaDevice class name is misleading because it doesn't refer to a device; it refers to a device interface. @dvrogozh, do you mind submitting a PR to rename CudaDevice to CudaDeviceInterface (file names and class names)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NicolasHug : no problem, will do.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NicolasHug : filed #626

public:
// Binds the interface to `device`. Throws std::runtime_error if `device` is
// not a CUDA device.
CudaDevice(const torch::Device& device);

// If a hardware device context was initialized, offers it to the context
// cache and releases this object's reference to it.
virtual ~CudaDevice();

// Iterates over FFmpeg's codecs looking for one that matches `codecId` for
// the CUDA device (see CudaDevice.cpp for the matching rules).
std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) override;

// Obtains a CUDA hardware device context, stores it in ctx_, and sets
// codecContext->hw_device_ctx to a new reference to it. Fails a TORCH_CHECK
// if a context has already been initialized on this object.
void initializeContext(AVCodecContext* codecContext) override;

// Converts `avFrame` into `frameOutput` on this interface's device.
// NOTE(review): presumably `preAllocatedOutputTensor`, when provided, is used
// as the destination instead of allocating a new tensor -- confirm against
// the implementation.
void convertAVFrameToFrameOutput(
const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
UniqueAVFrame& avFrame,
SingleStreamDecoder::FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor =
std::nullopt) override;

private:
// FFmpeg hardware device context reference held by this object; null until
// initializeContext() has run.
AVBufferRef* ctx_ = nullptr;
};

} // namespace facebook::torchcodec
77 changes: 77 additions & 0 deletions src/torchcodec/_core/DeviceInterface.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "src/torchcodec/_core/DeviceInterface.h"
#include <map>
#include <mutex>

namespace facebook::torchcodec {

namespace {
// Locked by the functions in this file whenever they read or modify
// g_interface_map.
std::mutex g_interface_mutex;
// Registry of device-interface factories, keyed by torch device type.
std::map<torch::DeviceType, CreateDeviceInterfaceFn> g_interface_map;

// Returns the device-type portion of a device string: everything before the
// first ':' (e.g. "cuda" for "cuda:0"), or the whole string when it contains
// no ':'.
std::string getDeviceType(const std::string& device) {
  // find() yields npos when ':' is absent, and substr(0, npos) is the full
  // string, so both cases collapse into a single expression.
  return device.substr(0, device.find(':'));
}

} // namespace

// Records `createInterface` as the factory for devices of type `deviceType`.
// Fails a TORCH_CHECK if a factory is already registered for that type.
// Returns true on success, which lets callers use the call as the initializer
// of a namespace-scope variable.
bool registerDeviceInterface(
    torch::DeviceType deviceType,
    CreateDeviceInterfaceFn createInterface) {
  std::scoped_lock lock(g_interface_mutex);
  const bool alreadyRegistered = g_interface_map.count(deviceType) != 0;
  TORCH_CHECK(
      !alreadyRegistered,
      "Device interface already registered for ",
      deviceType);
  g_interface_map.emplace(deviceType, createInterface);
  return true;
}

// Converts a device string such as "cpu", "cuda" or "cuda:0" into a
// torch::Device.
//
// "cpu" is always accepted. Any other string is accepted only if its device
// type (the part before the first ':') exactly equals the lower-case name of
// a device type that has a registered DeviceInterface; otherwise a
// TORCH_CHECK failure with "Unsupported device: <device>" is raised.
// torch::Device's own parsing may still reject a malformed string.
torch::Device createTorchDevice(const std::string device) {
  // TODO: remove once DeviceInterface for CPU is implemented
  if (device == "cpu") {
    return torch::kCPU;
  }

  const std::string deviceType = getDeviceType(device);
  bool isRegistered = false;
  {
    // Only the registry scan needs the lock.
    std::scoped_lock lock(g_interface_mutex);
    // Iterate by const reference to the map's own value_type so that no
    // std::function is copied, and compare the parsed type exactly so that a
    // string like "cudafoo" does not match the "cuda" entry.
    for (const auto& entry : g_interface_map) {
      if (torch::DeviceTypeName(entry.first, /*lcase*/ true) == deviceType) {
        isRegistered = true;
        break;
      }
    }
  }
  TORCH_CHECK(isRegistered, "Unsupported device: ", device);

  return torch::Device(device);
}

// Creates the DeviceInterface for `device` using the factory registered for
// its device type. Returns nullptr for CPU devices; fails a TORCH_CHECK with
// "Unsupported device: <device>" when no factory is registered for the type.
std::unique_ptr<DeviceInterface> createDeviceInterface(
    const torch::Device& device) {
  const auto deviceType = device.type();
  // TODO: remove once DeviceInterface for CPU is implemented
  if (deviceType == torch::kCPU) {
    return nullptr;
  }

  std::scoped_lock lock(g_interface_mutex);
  const auto registration = g_interface_map.find(deviceType);
  TORCH_CHECK(
      registration != g_interface_map.end(),
      "Unsupported device: ",
      device);

  // The factory returns a raw pointer; take ownership of it here.
  DeviceInterface* rawInterface = registration->second(device);
  return std::unique_ptr<DeviceInterface>(rawInterface);
}

} // namespace facebook::torchcodec
58 changes: 38 additions & 20 deletions src/torchcodec/_core/DeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#pragma once

#include <torch/types.h>
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
Expand All @@ -23,25 +24,42 @@ namespace facebook::torchcodec {
// deviceFunction(device, ...);
// }

// Initialize the hardware device that is specified in `device`. Some builds
// support CUDA and others only support CPU.
void initializeContextOnCuda(
const torch::Device& device,
AVCodecContext* codecContext);

void convertAVFrameToFrameOutputOnCuda(
const torch::Device& device,
const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
UniqueAVFrame& avFrame,
SingleStreamDecoder::FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

void releaseContextOnCuda(
const torch::Device& device,
AVCodecContext* codecContext);

std::optional<const AVCodec*> findCudaCodec(
const torch::Device& device,
const AVCodecID& codecId);
// Abstract base class for device-specific decoding backends. An instance is
// bound to a single torch::Device at construction time, and subclasses supply
// codec lookup, codec-context initialization and frame conversion for that
// device.
class DeviceInterface {
 public:
  DeviceInterface(const torch::Device& device) : device_(device) {}

  virtual ~DeviceInterface() = default;

  // Returns a mutable reference to the device this interface is bound to.
  torch::Device& device() {
    return device_;
  }

  // Looks up an FFmpeg codec for `codecId` on behalf of this device.
  virtual std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) = 0;

  // Prepares `codecContext` for decoding on the device this interface was
  // constructed with.
  virtual void initializeContext(AVCodecContext* codecContext) = 0;

  // Converts the decoded `avFrame` into `frameOutput`. An optional
  // pre-allocated output tensor may be passed by the caller.
  virtual void convertAVFrameToFrameOutput(
      const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
      UniqueAVFrame& avFrame,
      SingleStreamDecoder::FrameOutput& frameOutput,
      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt) = 0;

 protected:
  // Device this interface operates on; set once by the constructor.
  torch::Device device_;
};

// Factory signature for device interfaces: given a device, returns a raw
// pointer to a newly created DeviceInterface. createDeviceInterface() wraps
// the returned pointer in a std::unique_ptr.
using CreateDeviceInterfaceFn =
std::function<DeviceInterface*(const torch::Device& device)>;

// Registers `createInterface` as the factory for `deviceType`. Fails a
// TORCH_CHECK if a factory is already registered for that device type;
// otherwise returns true.
bool registerDeviceInterface(
torch::DeviceType deviceType,
const CreateDeviceInterfaceFn createInterface);

// Builds a torch::Device from a device string. "cpu" is always accepted; any
// other string must name a device type that has a registered interface,
// otherwise a TORCH_CHECK failure ("Unsupported device: ...") is raised.
torch::Device createTorchDevice(const std::string device);

// Creates the DeviceInterface for `device` through its registered factory.
// Returns nullptr for CPU devices and fails a TORCH_CHECK when no factory is
// registered for the device's type.
std::unique_ptr<DeviceInterface> createDeviceInterface(
const torch::Device& device);

} // namespace facebook::torchcodec
Loading
Loading