|
6 | 6 | #include "core/framework/kernel_registry_manager.h" |
7 | 7 | #include "core/graph/function.h" |
8 | 8 | #include "core/graph/graph_viewer.h" |
9 | | -#include "core/framework/computation_capacity.h" |
| 9 | +#include "core/framework/compute_capability.h" |
10 | 10 | #include "core/framework/kernel_registry_manager.h" |
11 | 11 | #include "core/framework/execution_providers.h" |
12 | 12 | #include "core/framework/kernel_registry.h" |
@@ -66,29 +66,29 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const { |
66 | 66 | for (auto& provider : providers_) { |
67 | 67 | auto capability_results = provider->GetCapability(GraphViewer(graph), kernel_registries); |
68 | 68 | int count = 0; |
69 | | - for (auto& capacity : capability_results) { |
70 | | - if (nullptr == capacity || nullptr == capacity->sub_graph_) { |
| 69 | + for (auto& capability : capability_results) { |
| 70 | + if (nullptr == capability || nullptr == capability->sub_graph) { |
71 | 71 | continue; |
72 | 72 | } |
73 | | - if (nullptr == capacity->sub_graph_->GetMetaDef()) { |
| 73 | + if (nullptr == capability->sub_graph->GetMetaDef()) { |
74 | 74 | // The <provider> can run a single node in the <graph> if not using meta-defs. |
75 | 75 | // A fused kernel is not supported in this case. |
76 | | - ONNXRUNTIME_ENFORCE(1 == capacity->sub_graph_->nodes.size()); |
77 | | - ONNXRUNTIME_ENFORCE(capacity->fuse_kernel_function_ == nullptr); |
| 76 | + ONNXRUNTIME_ENFORCE(1 == capability->sub_graph->nodes.size()); |
| 77 | + ONNXRUNTIME_ENFORCE(capability->fuse_kernel_function == nullptr); |
78 | 78 |
|
79 | | - auto node = graph.GetNode(capacity->sub_graph_->nodes[0]); |
| 79 | + auto node = graph.GetNode(capability->sub_graph->nodes[0]); |
80 | 80 | if (nullptr != node && node->GetExecutionProviderType().empty()) { |
81 | 81 | node->SetExecutionProviderType(provider->Type()); |
82 | 82 | } |
83 | 83 | } else { |
84 | 84 | // The <provider> can run a fused <sub_graph> in the <graph>. |
85 | 85 | // |
86 | 86 | // Add fused node into <graph> |
87 | | - ONNXRUNTIME_ENFORCE(nullptr != capacity->sub_graph_->GetMetaDef()); |
88 | | - std::string node_name = provider->Type() + "_" + capacity->sub_graph_->GetMetaDef()->name + "_" + std::to_string(count++); |
89 | | - auto& fused_node = graph.FuseSubGraph(std::move(capacity->sub_graph_), node_name); |
| 87 | + ONNXRUNTIME_ENFORCE(nullptr != capability->sub_graph->GetMetaDef()); |
| 88 | + std::string node_name = provider->Type() + "_" + capability->sub_graph->GetMetaDef()->name + "_" + std::to_string(count++); |
| 89 | + auto& fused_node = graph.FuseSubGraph(std::move(capability->sub_graph), node_name); |
90 | 90 | fused_node.SetExecutionProviderType(provider->Type()); |
91 | | - auto fused_kernel_func = capacity->fuse_kernel_function_; |
| 91 | + auto fused_kernel_func = capability->fuse_kernel_function; |
92 | 92 | if (fused_kernel_func != nullptr) { |
93 | 93 | // build the kernel definition on the fly, and register it to the fused_kernel_regisitry. |
94 | 94 | KernelDefBuilder builder; |
|
0 commit comments