Skip to content

Commit f189b76

Browse files
yuanbyu authored and pranavsharma committed
Some small edits and renaming. (#153)
1 parent 0ae9354 commit f189b76

File tree

11 files changed

+89
-73
lines changed

11 files changed

+89
-73
lines changed

include/onnxruntime/core/framework/execution_provider.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class GraphViewer;
1313
} // namespace onnxruntime
1414
namespace onnxruntime {
1515

16-
struct ComputationCapacity;
16+
struct ComputeCapability;
1717
class KernelRegistry;
1818
class KernelRegistryManager;
1919

@@ -50,7 +50,7 @@ class IExecutionProvider {
5050
have overlap, and it's ONNXRuntime's responsibility to do the partition
5151
and decide whether a node will be assigned to <*this> execution provider.
5252
*/
53-
virtual std::vector<std::unique_ptr<ComputationCapacity>>
53+
virtual std::vector<std::unique_ptr<ComputeCapability>>
5454
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
5555
const std::vector<const KernelRegistry*>& kernel_registries) const;
5656

onnxruntime/core/framework/computation_capacity.h

Lines changed: 0 additions & 29 deletions
This file was deleted.
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include "core/common/common.h"
6+
#include "core/graph/indexed_sub_graph.h"
7+
8+
namespace onnxruntime {
9+
// Describes a subgraph together with the method used to run it.
10+
struct ComputeCapability {
11+
// The subgraph that an execution provider (XP) can execute. It may contain
12+
// a single node or multiple nodes.
13+
std::unique_ptr<IndexedSubGraph> sub_graph;
14+
15+
// When an execution provider fuses a subgraph into a kernel, it passes
16+
// a kernel create function to onnxruntime so the runtime can create the
17+
// compute kernel for the subgraph. Otherwise onnxruntime will search for
18+
// kernels in the pre-defined kernel registry provided by the XP.
19+
KernelCreateFn fuse_kernel_function;
20+
21+
// TODO: if there is a FusedKernelFn attached, onnxruntime will generate
22+
// the default KernelDefinition for it, according to the OpSchema it
23+
// auto-generates. An execution provider can further set some advanced
24+
// fields on the kernel definition, such as memory placement / in-place
25+
// annotation.
26+
// Creates an empty capability: no subgraph and no kernel create function.
ComputeCapability() : sub_graph(nullptr), fuse_kernel_function(nullptr) {}
27+
28+
// Takes ownership of t_sub_graph (moved into sub_graph) and stores
// t_kernel_func as the kernel create function.
ComputeCapability(std::unique_ptr<IndexedSubGraph> t_sub_graph,
29+
KernelCreateFn t_kernel_func)
30+
: sub_graph(std::move(t_sub_graph)),
31+
fuse_kernel_function(t_kernel_func) {}
32+
};
33+
} // namespace onnxruntime

onnxruntime/core/framework/execution_provider.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "core/framework/execution_provider.h"
44

55
#include "core/graph/graph_viewer.h"
6-
#include "core/framework/computation_capacity.h"
6+
#include "core/framework/compute_capability.h"
77
#include "core/framework/kernel_registry_manager.h"
88
#include "core/framework/op_kernel.h"
99
#include "core/framework/kernel_registry.h"
@@ -24,16 +24,16 @@ AllocatorPtr IExecutionProvider::GetAllocator(int id, ONNXRuntimeMemType mem_typ
2424
return nullptr;
2525
}
2626

27-
std::vector<std::unique_ptr<ComputationCapacity>>
27+
std::vector<std::unique_ptr<ComputeCapability>>
2828
IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
2929
const std::vector<const KernelRegistry*>& kernel_registries) const {
30-
std::vector<std::unique_ptr<ComputationCapacity>> result;
30+
std::vector<std::unique_ptr<ComputeCapability>> result;
3131
for (auto& node : graph.Nodes()) {
3232
for (auto registry : kernel_registries) {
3333
if (registry->TryFindKernel(node, Type()) != nullptr) {
3434
std::unique_ptr<IndexedSubGraph> sub_graph = std::make_unique<IndexedSubGraph>();
3535
sub_graph->nodes.push_back(node.Index());
36-
result.push_back(std::make_unique<ComputationCapacity>(std::move(sub_graph), nullptr));
36+
result.push_back(std::make_unique<ComputeCapability>(std::move(sub_graph), nullptr));
3737
}
3838
}
3939
}

onnxruntime/core/framework/graph_partitioner.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include "core/framework/kernel_registry_manager.h"
77
#include "core/graph/function.h"
88
#include "core/graph/graph_viewer.h"
9-
#include "core/framework/computation_capacity.h"
9+
#include "core/framework/compute_capability.h"
1010
#include "core/framework/kernel_registry_manager.h"
1111
#include "core/framework/execution_providers.h"
1212
#include "core/framework/kernel_registry.h"
@@ -66,29 +66,29 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const {
6666
for (auto& provider : providers_) {
6767
auto capability_results = provider->GetCapability(GraphViewer(graph), kernel_registries);
6868
int count = 0;
69-
for (auto& capacity : capability_results) {
70-
if (nullptr == capacity || nullptr == capacity->sub_graph_) {
69+
for (auto& capability : capability_results) {
70+
if (nullptr == capability || nullptr == capability->sub_graph) {
7171
continue;
7272
}
73-
if (nullptr == capacity->sub_graph_->GetMetaDef()) {
73+
if (nullptr == capability->sub_graph->GetMetaDef()) {
7474
// The <provider> can run a single node in the <graph> if not using meta-defs.
7575
// A fused kernel is not supported in this case.
76-
ONNXRUNTIME_ENFORCE(1 == capacity->sub_graph_->nodes.size());
77-
ONNXRUNTIME_ENFORCE(capacity->fuse_kernel_function_ == nullptr);
76+
ONNXRUNTIME_ENFORCE(1 == capability->sub_graph->nodes.size());
77+
ONNXRUNTIME_ENFORCE(capability->fuse_kernel_function == nullptr);
7878

79-
auto node = graph.GetNode(capacity->sub_graph_->nodes[0]);
79+
auto node = graph.GetNode(capability->sub_graph->nodes[0]);
8080
if (nullptr != node && node->GetExecutionProviderType().empty()) {
8181
node->SetExecutionProviderType(provider->Type());
8282
}
8383
} else {
8484
// The <provider> can run a fused <sub_graph> in the <graph>.
8585
//
8686
// Add fused node into <graph>
87-
ONNXRUNTIME_ENFORCE(nullptr != capacity->sub_graph_->GetMetaDef());
88-
std::string node_name = provider->Type() + "_" + capacity->sub_graph_->GetMetaDef()->name + "_" + std::to_string(count++);
89-
auto& fused_node = graph.FuseSubGraph(std::move(capacity->sub_graph_), node_name);
87+
ONNXRUNTIME_ENFORCE(nullptr != capability->sub_graph->GetMetaDef());
88+
std::string node_name = provider->Type() + "_" + capability->sub_graph->GetMetaDef()->name + "_" + std::to_string(count++);
89+
auto& fused_node = graph.FuseSubGraph(std::move(capability->sub_graph), node_name);
9090
fused_node.SetExecutionProviderType(provider->Type());
91-
auto fused_kernel_func = capacity->fuse_kernel_function_;
91+
auto fused_kernel_func = capability->fuse_kernel_function;
9292
if (fused_kernel_func != nullptr) {
9393
// Build the kernel definition on the fly, and register it to the fused_kernel_registry.
9494
KernelDefBuilder builder;

onnxruntime/core/providers/cpu/cpu_execution_provider.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include "core/framework/op_kernel.h"
66
#include "core/framework/kernel_registry.h"
77
#include "contrib_ops/contrib_kernels.h"
8-
#include "core/framework/computation_capacity.h"
8+
#include "core/framework/compute_capability.h"
99

1010
namespace onnxruntime {
1111

@@ -484,14 +484,16 @@ static void RegisterCPUKernels(std::function<void(KernelCreateInfo&&)> create_fn
484484
}
485485

486486
std::shared_ptr<KernelRegistry> CPUExecutionProvider::GetKernelRegistry() const {
487-
static std::shared_ptr<KernelRegistry> kernel_registry = std::make_shared<KernelRegistry>(RegisterCPUKernels);
487+
static std::shared_ptr<KernelRegistry>
488+
kernel_registry = std::make_shared<KernelRegistry>(RegisterCPUKernels);
488489
return kernel_registry;
489490
}
490491

491-
std::vector<std::unique_ptr<ComputationCapacity>>
492+
std::vector<std::unique_ptr<ComputeCapability>>
492493
CPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
493494
const std::vector<const KernelRegistry*>& kernel_registries) const {
494-
std::vector<std::unique_ptr<ComputationCapacity>> result = IExecutionProvider::GetCapability(graph, kernel_registries);
495+
std::vector<std::unique_ptr<ComputeCapability>>
496+
result = IExecutionProvider::GetCapability(graph, kernel_registries);
495497

496498
for (auto& rule : fuse_rules_) {
497499
rule(graph, result);

onnxruntime/core/providers/cpu/cpu_execution_provider.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@ struct CPUExecutionProviderInfo {
1919
CPUExecutionProviderInfo() = default;
2020
};
2121

22-
using FuseRuleFn = std::function<void(const onnxruntime::GraphViewer&, std::vector<std::unique_ptr<ComputationCapacity>>&)>;
22+
using FuseRuleFn = std::function<void(const onnxruntime::GraphViewer&, std::vector<std::unique_ptr<ComputeCapability>>&)>;
23+
2324
// Logical device representation.
2425
class CPUExecutionProvider : public IExecutionProvider {
2526
public:
2627
explicit CPUExecutionProvider(const CPUExecutionProviderInfo& info) {
27-
DeviceAllocatorRegistrationInfo device_info({ONNXRuntimeMemTypeDefault, [](int) { return std::make_unique<CPUAllocator>(); }, std::numeric_limits<size_t>::max()});
28+
DeviceAllocatorRegistrationInfo device_info({ONNXRuntimeMemTypeDefault, [](int) {
29+
return std::make_unique<CPUAllocator>(); }, std::numeric_limits<size_t>::max()});
2830
#ifdef USE_JEMALLOC
2931
ONNXRUNTIME_UNUSED_PARAMETER(info);
3032
//JEMalloc already has memory pool, so just use device allocator.
@@ -45,7 +47,7 @@ class CPUExecutionProvider : public IExecutionProvider {
4547
return onnxruntime::kCpuExecutionProvider;
4648
}
4749

48-
virtual std::vector<std::unique_ptr<ComputationCapacity>>
50+
virtual std::vector<std::unique_ptr<ComputeCapability>>
4951
GetCapability(const onnxruntime::GraphViewer& graph,
5052
const std::vector<const KernelRegistry*>& kernel_registries) const override;
5153

@@ -54,8 +56,9 @@ class CPUExecutionProvider : public IExecutionProvider {
5456
ONNXRUNTIME_ENFORCE(strcmp(dst.Location().name, CPU) == 0);
5557

5658
// Todo: support copy with different devices.
57-
if (strcmp(src.Location().name, CPU) != 0)
59+
if (strcmp(src.Location().name, CPU) != 0) {
5860
ONNXRUNTIME_NOT_IMPLEMENTED("copy from ", src.Location().name, " is not implemented");
61+
}
5962

6063
// No real copy is needed when copying to CPU.
6164
dst.ShallowCopy(src);

onnxruntime/core/providers/cuda/cuda_execution_provider.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include "cuda_fence.h"
99
#include "cuda_allocator.h"
1010
#include "core/framework/kernel_registry.h"
11-
#include "core/framework/computation_capacity.h"
11+
#include "core/framework/compute_capability.h"
1212

1313
using namespace onnxruntime::common;
1414

@@ -62,12 +62,12 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in
6262
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&streams_[kCudaStreamCopyIn], cudaStreamNonBlocking));
6363
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&streams_[kCudaStreamCopyOut], cudaStreamNonBlocking));
6464

65-
DeviceAllocatorRegistrationInfo default_allocator_info({ONNXRuntimeMemTypeDefault,
66-
[](int id) { return std::make_unique<CUDAAllocator>(id); }, std::numeric_limits<size_t>::max()});
65+
DeviceAllocatorRegistrationInfo default_allocator_info(
66+
{ONNXRuntimeMemTypeDefault, [](int id) { return std::make_unique<CUDAAllocator>(id); }, std::numeric_limits<size_t>::max()});
6767
InsertAllocator(CreateAllocator(default_allocator_info, device_id_));
6868

69-
DeviceAllocatorRegistrationInfo pinned_allocator_info({ONNXRuntimeMemTypeCPUOutput,
70-
[](int) { return std::make_unique<CUDAPinnedAllocator>(); }, std::numeric_limits<size_t>::max()});
69+
DeviceAllocatorRegistrationInfo pinned_allocator_info(
70+
{ONNXRuntimeMemTypeCPUOutput, [](int) { return std::make_unique<CUDAPinnedAllocator>(); }, std::numeric_limits<size_t>::max()});
7171
InsertAllocator(CreateAllocator(pinned_allocator_info, device_id_));
7272
}
7373

@@ -824,10 +824,10 @@ bool CUDAExecutionProvider::RNNNeedFallbackToCPU(const onnxruntime::Node& node,
824824
return false;
825825
}
826826

827-
std::vector<std::unique_ptr<ComputationCapacity>>
827+
std::vector<std::unique_ptr<ComputeCapability>>
828828
CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
829829
const std::vector<const KernelRegistry*>& kernel_registries) const {
830-
std::vector<std::unique_ptr<ComputationCapacity>> result = IExecutionProvider::GetCapability(graph, kernel_registries);
830+
std::vector<std::unique_ptr<ComputeCapability>> result = IExecutionProvider::GetCapability(graph, kernel_registries);
831831

832832
for (auto& node : graph.Nodes()) {
833833
bool fallback_to_cpu_provider = false;

onnxruntime/core/providers/cuda/cuda_execution_provider.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
9090

9191
virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
9292

93-
virtual std::vector<std::unique_ptr<ComputationCapacity>>
93+
virtual std::vector<std::unique_ptr<ComputeCapability>>
9494
GetCapability(const onnxruntime::GraphViewer& graph,
9595
const std::vector<const KernelRegistry*>& kernel_registries) const override;
9696
private:

onnxruntime/test/framework/inference_session_test.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#include "core/framework/op_kernel.h"
1717
#include "core/framework/session_state.h"
1818
#include "core/graph/graph_viewer.h"
19-
#include "core/framework/computation_capacity.h"
19+
#include "core/framework/compute_capability.h"
2020
#include "core/graph/model.h"
2121
#include "core/graph/op.h"
2222
#include "core/providers/cpu/cpu_execution_provider.h"
@@ -67,16 +67,17 @@ void RegisterOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
6767
class FuseExecutionProvider : public IExecutionProvider {
6868
public:
6969
explicit FuseExecutionProvider() {
70-
DeviceAllocatorRegistrationInfo device_info({ONNXRuntimeMemTypeDefault, [](int) { return std::make_unique<CPUAllocator>(); }, std::numeric_limits<size_t>::max()});
70+
DeviceAllocatorRegistrationInfo device_info({ONNXRuntimeMemTypeDefault,
71+
[](int) { return std::make_unique<CPUAllocator>(); }, std::numeric_limits<size_t>::max()});
7172
InsertAllocator(std::shared_ptr<IArenaAllocator>(
7273
std::make_unique<DummyArena>(device_info.factory(0))));
7374
}
7475

75-
std::vector<std::unique_ptr<ComputationCapacity>>
76+
std::vector<std::unique_ptr<ComputeCapability>>
7677
GetCapability(const onnxruntime::GraphViewer& graph,
7778
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override {
7879
// Fuse two Add nodes into one.
79-
std::vector<std::unique_ptr<ComputationCapacity>> result;
80+
std::vector<std::unique_ptr<ComputeCapability>> result;
8081
std::unique_ptr<IndexedSubGraph> sub_graph = std::make_unique<IndexedSubGraph>();
8182
for (auto& node : graph.Nodes()) {
8283
sub_graph->nodes.push_back(node.Index());
@@ -89,12 +90,13 @@ class FuseExecutionProvider : public IExecutionProvider {
8990
meta_def->since_version = 1;
9091
meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL;
9192
sub_graph->SetMetaDef(meta_def);
92-
result.push_back(std::make_unique<ComputationCapacity>(std::move(sub_graph), nullptr));
93+
result.push_back(std::make_unique<ComputeCapability>(std::move(sub_graph), nullptr));
9394
return result;
9495
}
9596

9697
std::shared_ptr<::onnxruntime::KernelRegistry> GetKernelRegistry() const override {
97-
static std::shared_ptr<::onnxruntime::KernelRegistry> kernel_registry = std::make_shared<::onnxruntime::KernelRegistry>(RegisterOperatorKernels);
98+
static std::shared_ptr<::onnxruntime::KernelRegistry>
99+
kernel_registry = std::make_shared<::onnxruntime::KernelRegistry>(RegisterOperatorKernels);
98100
return kernel_registry;
99101
}
100102

0 commit comments

Comments
 (0)