feat: Add support for CUDA streams in GNN plugin #4012

Merged
@@ -25,7 +25,7 @@ class BoostTrackBuilding final : public Acts::TrackBuildingBase {
std::vector<std::vector<int>> operator()(
std::any nodes, std::any edges, std::any edge_weights,
std::vector<int> &spacepointIDs,
- torch::Device device = torch::Device(torch::kCPU)) override;
+ const ExecutionContext &execContext = {}) override;
torch::Device device() const override { return m_device; };

private:
@@ -35,7 +35,7 @@ class OnnxEdgeClassifier final : public Acts::EdgeClassificationBase {

std::tuple<std::any, std::any, std::any, std::any> operator()(
std::any nodeFeatures, std::any edgeIndex, std::any edgeFeatures = {},
- torch::Device device = torch::Device(torch::kCPU)) override;
+ const ExecutionContext &execContext = {}) override;

Config config() const { return m_cfg; }
torch::Device device() const override { return m_device; };
@@ -39,7 +39,7 @@ class OnnxMetricLearning final : public Acts::GraphConstructionBase {
std::tuple<std::any, std::any, std::any> operator()(
std::vector<float>& inputValues, std::size_t numNodes,
const std::vector<std::uint64_t>& moduleIds,
- torch::Device device = torch::Device(torch::kCPU)) override;
+ const ExecutionContext& execContext = {}) override;

Config config() const { return m_cfg; }
torch::Device device() const override { return m_device; };
20 changes: 13 additions & 7 deletions Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp
@@ -13,13 +13,20 @@
#include <exception>
#include <vector>

+ #include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>

namespace Acts {

/// Error that is thrown if no edges are found
struct NoEdgesError : std::exception {};

+ /// Capture the context of the execution
+ struct ExecutionContext {
+   torch::Device device{torch::kCPU};
+   std::optional<c10::cuda::CUDAStream> stream;
+ };

// TODO maybe replace std::any with some kind of variant<unique_ptr<torch>,
// unique_ptr<onnx>>?
// TODO maybe replace input for GraphConstructionBase with some kind of
@@ -34,13 +41,12 @@ class GraphConstructionBase {
/// then gives the number of features
/// @param moduleIds Module IDs of the features (used for module-map-like
/// graph construction)
- /// @param device Which GPU device to pick. Not relevant for CPU-only builds
- ///
+ /// @param execContext Device & stream information
/// @return (node_features, edge_features, edge_index)
virtual std::tuple<std::any, std::any, std::any> operator()(
std::vector<float> &inputValues, std::size_t numNodes,
const std::vector<std::uint64_t> &moduleIds,
- torch::Device device = torch::Device(torch::kCPU)) = 0;
+ const ExecutionContext &execContext = {}) = 0;

virtual torch::Device device() const = 0;

@@ -54,12 +60,12 @@ class EdgeClassificationBase {
/// @param nodeFeatures Node tensor with shape (n_nodes, n_node_features)
/// @param edgeIndex Edge-index tensor with shape (2, n_edges)
/// @param edgeFeatures Edge-feature tensor with shape (n_edges, n_edge_features)
- /// @param device Which GPU device to pick. Not relevant for CPU-only builds
+ /// @param execContext Device & stream information
///
/// @return (node_features, edge_features, edge_index, edge_scores)
virtual std::tuple<std::any, std::any, std::any, std::any> operator()(
std::any nodeFeatures, std::any edgeIndex, std::any edgeFeatures = {},
- torch::Device device = torch::Device(torch::kCPU)) = 0;
+ const ExecutionContext &execContext = {}) = 0;

virtual torch::Device device() const = 0;

@@ -74,13 +80,13 @@ class TrackBuildingBase {
/// @param edgeIndex Edge-index tensor with shape (2, n_edges)
/// @param edgeScores Scores of the previous edge classification phase
/// @param spacepointIDs IDs of the nodes (must have size=n_nodes)
- /// @param device Which GPU device to pick. Not relevant for CPU-only builds
+ /// @param execContext Device & stream information
///
/// @return tracks (as vectors of node-IDs)
virtual std::vector<std::vector<int>> operator()(
std::any nodeFeatures, std::any edgeIndex, std::any edgeScores,
std::vector<int> &spacepointIDs,
- torch::Device device = torch::Device(torch::kCPU)) = 0;
+ const ExecutionContext &execContext = {}) = 0;

virtual torch::Device device() const = 0;

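For reference, the typical way a stage implementation consumes the new ExecutionContext is sketched below. This is a hypothetical helper (runOnContext is not part of this PR); it mirrors the guard pattern used in the Torch stages further down and assumes a CUDA-enabled build where <c10/cuda/CUDAGuard.h> is available.

#include <optional>

#include <c10/cuda/CUDAGuard.h>  // c10::cuda::CUDAGuard, c10::cuda::CUDAStreamGuard

#include "Acts/Plugins/ExaTrkX/Stages.hpp"

// Pin all libtorch work issued in this scope to the device and (optional)
// CUDA stream carried by the caller-provided ExecutionContext.
void runOnContext(const Acts::ExecutionContext &execContext) {
  std::optional<c10::cuda::CUDAGuard> deviceGuard;
  std::optional<c10::cuda::CUDAStreamGuard> streamGuard;
  if (execContext.device.is_cuda() && execContext.stream.has_value()) {
    deviceGuard.emplace(execContext.device.index());
    streamGuard.emplace(*execContext.stream);
  }
  // ... tensor operations issued here are enqueued on execContext.stream,
  // or run synchronously on the host when the device is the CPU ...
}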
@@ -42,7 +42,7 @@ class TorchEdgeClassifier final : public Acts::EdgeClassificationBase {

std::tuple<std::any, std::any, std::any, std::any> operator()(
std::any nodeFeatures, std::any edgeIndex, std::any edgeFeatures = {},
- torch::Device device = torch::Device(torch::kCPU)) override;
+ const ExecutionContext &execContext = {}) override;

Config config() const { return m_cfg; }
torch::Device device() const override { return m_device; };
@@ -44,7 +44,7 @@ class TorchMetricLearning final : public Acts::GraphConstructionBase {
std::tuple<std::any, std::any, std::any> operator()(
std::vector<float> &inputValues, std::size_t numNodes,
const std::vector<std::uint64_t> &moduleIds,
- torch::Device device = torch::Device(torch::kCPU)) override;
+ const ExecutionContext &execContext = {}) override;

Config config() const { return m_cfg; }
torch::Device device() const override { return m_device; };
2 changes: 1 addition & 1 deletion Plugins/ExaTrkX/src/BoostTrackBuilding.cpp
@@ -48,7 +48,7 @@ namespace Acts {

std::vector<std::vector<int>> BoostTrackBuilding::operator()(
std::any /*nodes*/, std::any edges, std::any weights,
- std::vector<int>& spacepointIDs, torch::Device) {
+ std::vector<int>& spacepointIDs, const ExecutionContext& execContext) {
ACTS_DEBUG("Start track building");
const auto edgeTensor = std::any_cast<torch::Tensor>(edges).to(torch::kCPU);
const auto edgeWeightTensor =
16 changes: 11 additions & 5 deletions Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp
@@ -39,11 +39,18 @@ std::vector<std::vector<int>> ExaTrkXPipeline::run(
std::vector<float> &features, const std::vector<std::uint64_t> &moduleIds,
std::vector<int> &spacepointIDs, const ExaTrkXHook &hook,
ExaTrkXTiming *timing) const {
+ ExecutionContext ctx;
+ ctx.device = m_graphConstructor->device();
+ #ifndef ACTS_EXATRKX_CPUONLY
+ if (ctx.device.type() == torch::kCUDA) {
+   ctx.stream = c10::cuda::getStreamFromPool(ctx.device.index());
+ }
+ #endif

try {
auto t0 = std::chrono::high_resolution_clock::now();
auto [nodeFeatures, edgeIndex, edgeFeatures] =
- (*m_graphConstructor)(features, spacepointIDs.size(), moduleIds,
-                       m_graphConstructor->device());
+ (*m_graphConstructor)(features, spacepointIDs.size(), moduleIds, ctx);
auto t1 = std::chrono::high_resolution_clock::now();

if (timing != nullptr) {
@@ -59,7 +66,7 @@ std::vector<std::vector<int>> ExaTrkXPipeline::run(
t0 = std::chrono::high_resolution_clock::now();
auto [newNodeFeatures, newEdgeIndex, newEdgeFeatures, newEdgeScores] =
(*edgeClassifier)(std::move(nodeFeatures), std::move(edgeIndex),
- std::move(edgeFeatures), edgeClassifier->device());
+ std::move(edgeFeatures), ctx);
t1 = std::chrono::high_resolution_clock::now();

if (timing != nullptr) {
@@ -76,8 +83,7 @@

t0 = std::chrono::high_resolution_clock::now();
auto res = (*m_trackBuilder)(std::move(nodeFeatures), std::move(edgeIndex),
- std::move(edgeScores), spacepointIDs,
- m_trackBuilder->device());
+ std::move(edgeScores), spacepointIDs, ctx);
t1 = std::chrono::high_resolution_clock::now();

if (timing != nullptr) {
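As a standalone illustration of the stream handling introduced above (hypothetical code, not part of this diff): a stream borrowed from libtorch's per-device pool, combined with a CUDAStreamGuard, makes all work issued in its scope run on that stream instead of the default stream, which is the point of giving each pipeline invocation its own stream.

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>

void streamSketch() {
  torch::Device device(torch::kCUDA, 0);
  // Borrow a stream from the pool; concurrent callers receive different streams.
  c10::cuda::CUDAStream stream =
      c10::cuda::getStreamFromPool(/*isHighPriority=*/false, device.index());
  c10::cuda::CUDAStreamGuard guard(stream);
  // Kernels launched while the guard is alive are enqueued on `stream`.
  torch::Tensor x = torch::rand({4096, 16}, device);
  torch::Tensor scores = torch::sigmoid(torch::mm(x, x.t()));
  stream.synchronize();  // block the host until this stream's work has finished
}

With one pool stream per run() call, two pipeline invocations issued from different host threads no longer serialize on the device's default stream.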
3 changes: 2 additions & 1 deletion Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp
@@ -90,7 +90,8 @@ std::ostream &operator<<(std::ostream &os, Ort::Value &v) {

std::tuple<std::any, std::any, std::any, std::any>
OnnxEdgeClassifier::operator()(std::any inputNodes, std::any inputEdges,
- std::any inEdgeFeatures, torch::Device) {
+ std::any inEdgeFeatures,
+ const ExecutionContext & /*unused*/) {
auto torchDevice = torch::kCPU;
Ort::MemoryInfo memoryInfo("Cpu", OrtArenaAllocator, /*device_id*/ 0,
OrtMemTypeDefault);
12 changes: 9 additions & 3 deletions Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp
@@ -67,17 +67,23 @@ TorchEdgeClassifier::~TorchEdgeClassifier() {}

std::tuple<std::any, std::any, std::any, std::any>
TorchEdgeClassifier::operator()(std::any inNodeFeatures, std::any inEdgeIndex,
- std::any inEdgeFeatures, torch::Device device) {
- decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4, t5;
+ std::any inEdgeFeatures,
+ const ExecutionContext& execContext) {
+ const auto& device = execContext.device;
+ decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4;
t0 = std::chrono::high_resolution_clock::now();
ACTS_DEBUG("Start edge classification, use " << device);
c10::InferenceMode guard(true);

// add a protection to avoid calling for kCPU
- #ifndef ACTS_EXATRKX_CPUONLY
+ #ifdef ACTS_EXATRKX_CPUONLY
+ assert(device == torch::Device(torch::kCPU));
+ #else
std::optional<c10::cuda::CUDAGuard> device_guard;
+ std::optional<c10::cuda::CUDAStreamGuard> streamGuard;
if (device.is_cuda()) {
device_guard.emplace(device.index());
+ streamGuard.emplace(execContext.stream.value());
}
#endif

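One caveat worth noting for the chrono-based timings above: CUDA launches are asynchronous, so host-side timestamps only bracket the enqueue unless the stream is synchronized first. A hypothetical helper (not part of this diff) that measures completed work on a given stream could look like this:

#include <chrono>

#include <c10/cuda/CUDAStream.h>

// Run `work`, then wait for the given CUDA stream before taking the second
// timestamp, so the result reflects finished GPU work rather than enqueue time.
template <typename Work>
double millisecondsOnStream(Work &&work, c10::cuda::CUDAStream stream) {
  auto t0 = std::chrono::high_resolution_clock::now();
  work();
  stream.synchronize();
  auto t1 = std::chrono::high_resolution_clock::now();
  return std::chrono::duration<double, std::milli>(t1 - t0).count();
}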
10 changes: 8 additions & 2 deletions Plugins/ExaTrkX/src/TorchMetricLearning.cpp
@@ -69,15 +69,21 @@ TorchMetricLearning::~TorchMetricLearning() {}

std::tuple<std::any, std::any, std::any> TorchMetricLearning::operator()(
std::vector<float> &inputValues, std::size_t numNodes,
- const std::vector<std::uint64_t> & /*moduleIds*/, torch::Device device) {
+ const std::vector<std::uint64_t> & /*moduleIds*/,
+ const ExecutionContext &execContext) {
+ const auto &device = execContext.device;
ACTS_DEBUG("Start graph construction");
c10::InferenceMode guard(true);

// add a protection to avoid calling for kCPU
- #ifndef ACTS_EXATRKX_CPUONLY
+ #ifdef ACTS_EXATRKX_CPUONLY
+ assert(device == torch::Device(torch::kCPU));
+ #else
std::optional<c10::cuda::CUDAGuard> device_guard;
+ std::optional<c10::cuda::CUDAStreamGuard> streamGuard;
if (device.is_cuda()) {
device_guard.emplace(device.index());
+ streamGuard.emplace(execContext.stream.value());
}
#endif
