Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class GraphViewer;
} // namespace onnxruntime
namespace onnxruntime {

struct ComputationCapacity;
struct ComputeCapability;
class KernelRegistry;
class KernelRegistryManager;

Expand Down Expand Up @@ -50,7 +50,7 @@ class IExecutionProvider {
have overlap, and it's ONNXRuntime's responsibility to do the partition
and decide whether a node will be assigned to <*this> execution provider.
*/
virtual std::vector<std::unique_ptr<ComputationCapacity>>
virtual std::vector<std::unique_ptr<ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const KernelRegistry*>& kernel_registries) const;

Expand Down
29 changes: 0 additions & 29 deletions onnxruntime/core/framework/computation_capacity.h

This file was deleted.

33 changes: 33 additions & 0 deletions onnxruntime/core/framework/compute_capability.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/graph/indexed_sub_graph.h"

namespace onnxruntime {
// A structure encodes a subgraph and the method to run it.
struct ComputeCapability {
// The subgraph that an XP can execute, it could contain a single node
// or multiple nodes.
std::unique_ptr<IndexedSubGraph> sub_graph;

// When an execution provider fuses a subgraph into a kernel, it passes
// a kernel create function to onnxruntime so the runtime can create the
// compute kernel for the subgraph. Otherwise onnxruntime will search
// kernels in pre-defined kernel registry provided by XP.
KernelCreateFn fuse_kernel_function;

// TODO: if there is a FusedKernelFn attached, onnxruntime will generate
// the default KernelDefinition for it, according to the OpSchema it
// auto-generates. An execution provider can further set some advanced
// fields on kernel definition, such as memory placement / in-place
// annotation.
ComputeCapability() : sub_graph(nullptr), fuse_kernel_function(nullptr) {}

ComputeCapability(std::unique_ptr<IndexedSubGraph> t_sub_graph,
KernelCreateFn t_kernel_func)
: sub_graph(std::move(t_sub_graph)),
fuse_kernel_function(t_kernel_func) {}
};
} // namespace onnxruntime
8 changes: 4 additions & 4 deletions onnxruntime/core/framework/execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include "core/framework/execution_provider.h"

#include "core/graph/graph_viewer.h"
#include "core/framework/computation_capacity.h"
#include "core/framework/compute_capability.h"
#include "core/framework/kernel_registry_manager.h"
#include "core/framework/op_kernel.h"
#include "core/framework/kernel_registry.h"
Expand All @@ -24,16 +24,16 @@ AllocatorPtr IExecutionProvider::GetAllocator(int id, ONNXRuntimeMemType mem_typ
return nullptr;
}

std::vector<std::unique_ptr<ComputationCapacity>>
std::vector<std::unique_ptr<ComputeCapability>>
IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
const std::vector<const KernelRegistry*>& kernel_registries) const {
std::vector<std::unique_ptr<ComputationCapacity>> result;
std::vector<std::unique_ptr<ComputeCapability>> result;
for (auto& node : graph.Nodes()) {
for (auto registry : kernel_registries) {
if (registry->TryFindKernel(node, Type()) != nullptr) {
std::unique_ptr<IndexedSubGraph> sub_graph = std::make_unique<IndexedSubGraph>();
sub_graph->nodes.push_back(node.Index());
result.push_back(std::make_unique<ComputationCapacity>(std::move(sub_graph), nullptr));
result.push_back(std::make_unique<ComputeCapability>(std::move(sub_graph), nullptr));
}
}
}
Expand Down
22 changes: 11 additions & 11 deletions onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "core/framework/kernel_registry_manager.h"
#include "core/graph/function.h"
#include "core/graph/graph_viewer.h"
#include "core/framework/computation_capacity.h"
#include "core/framework/compute_capability.h"
#include "core/framework/kernel_registry_manager.h"
#include "core/framework/execution_providers.h"
#include "core/framework/kernel_registry.h"
Expand Down Expand Up @@ -66,29 +66,29 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const {
for (auto& provider : providers_) {
auto capability_results = provider->GetCapability(GraphViewer(graph), kernel_registries);
int count = 0;
for (auto& capacity : capability_results) {
if (nullptr == capacity || nullptr == capacity->sub_graph_) {
for (auto& capability : capability_results) {
if (nullptr == capability || nullptr == capability->sub_graph) {
continue;
}
if (nullptr == capacity->sub_graph_->GetMetaDef()) {
if (nullptr == capability->sub_graph->GetMetaDef()) {
// The <provider> can run a single node in the <graph> if not using meta-defs.
// A fused kernel is not supported in this case.
ONNXRUNTIME_ENFORCE(1 == capacity->sub_graph_->nodes.size());
ONNXRUNTIME_ENFORCE(capacity->fuse_kernel_function_ == nullptr);
ONNXRUNTIME_ENFORCE(1 == capability->sub_graph->nodes.size());
ONNXRUNTIME_ENFORCE(capability->fuse_kernel_function == nullptr);

auto node = graph.GetNode(capacity->sub_graph_->nodes[0]);
auto node = graph.GetNode(capability->sub_graph->nodes[0]);
if (nullptr != node && node->GetExecutionProviderType().empty()) {
node->SetExecutionProviderType(provider->Type());
}
} else {
// The <provider> can run a fused <sub_graph> in the <graph>.
//
// Add fused node into <graph>
ONNXRUNTIME_ENFORCE(nullptr != capacity->sub_graph_->GetMetaDef());
std::string node_name = provider->Type() + "_" + capacity->sub_graph_->GetMetaDef()->name + "_" + std::to_string(count++);
auto& fused_node = graph.FuseSubGraph(std::move(capacity->sub_graph_), node_name);
ONNXRUNTIME_ENFORCE(nullptr != capability->sub_graph->GetMetaDef());
std::string node_name = provider->Type() + "_" + capability->sub_graph->GetMetaDef()->name + "_" + std::to_string(count++);
auto& fused_node = graph.FuseSubGraph(std::move(capability->sub_graph), node_name);
fused_node.SetExecutionProviderType(provider->Type());
auto fused_kernel_func = capacity->fuse_kernel_function_;
auto fused_kernel_func = capability->fuse_kernel_function;
if (fused_kernel_func != nullptr) {
// build the kernel definition on the fly, and register it to the fused_kernel_regisitry.
KernelDefBuilder builder;
Expand Down
10 changes: 6 additions & 4 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "core/framework/op_kernel.h"
#include "core/framework/kernel_registry.h"
#include "contrib_ops/contrib_kernels.h"
#include "core/framework/computation_capacity.h"
#include "core/framework/compute_capability.h"

namespace onnxruntime {

Expand Down Expand Up @@ -484,14 +484,16 @@ static void RegisterCPUKernels(std::function<void(KernelCreateInfo&&)> create_fn
}

std::shared_ptr<KernelRegistry> CPUExecutionProvider::GetKernelRegistry() const {
static std::shared_ptr<KernelRegistry> kernel_registry = std::make_shared<KernelRegistry>(RegisterCPUKernels);
static std::shared_ptr<KernelRegistry>
kernel_registry = std::make_shared<KernelRegistry>(RegisterCPUKernels);
return kernel_registry;
}

std::vector<std::unique_ptr<ComputationCapacity>>
std::vector<std::unique_ptr<ComputeCapability>>
CPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
const std::vector<const KernelRegistry*>& kernel_registries) const {
std::vector<std::unique_ptr<ComputationCapacity>> result = IExecutionProvider::GetCapability(graph, kernel_registries);
std::vector<std::unique_ptr<ComputeCapability>>
result = IExecutionProvider::GetCapability(graph, kernel_registries);

for (auto& rule : fuse_rules_) {
rule(graph, result);
Expand Down
11 changes: 7 additions & 4 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ struct CPUExecutionProviderInfo {
CPUExecutionProviderInfo() = default;
};

using FuseRuleFn = std::function<void(const onnxruntime::GraphViewer&, std::vector<std::unique_ptr<ComputationCapacity>>&)>;
using FuseRuleFn = std::function<void(const onnxruntime::GraphViewer&, std::vector<std::unique_ptr<ComputeCapability>>&)>;

// Logical device representation.
class CPUExecutionProvider : public IExecutionProvider {
public:
explicit CPUExecutionProvider(const CPUExecutionProviderInfo& info) {
DeviceAllocatorRegistrationInfo device_info({ONNXRuntimeMemTypeDefault, [](int) { return std::make_unique<CPUAllocator>(); }, std::numeric_limits<size_t>::max()});
DeviceAllocatorRegistrationInfo device_info({ONNXRuntimeMemTypeDefault, [](int) {
return std::make_unique<CPUAllocator>(); }, std::numeric_limits<size_t>::max()});
#ifdef USE_JEMALLOC
ONNXRUNTIME_UNUSED_PARAMETER(info);
//JEMalloc already has memory pool, so just use device allocator.
Expand All @@ -45,7 +47,7 @@ class CPUExecutionProvider : public IExecutionProvider {
return onnxruntime::kCpuExecutionProvider;
}

virtual std::vector<std::unique_ptr<ComputationCapacity>>
virtual std::vector<std::unique_ptr<ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph,
const std::vector<const KernelRegistry*>& kernel_registries) const override;

Expand All @@ -54,8 +56,9 @@ class CPUExecutionProvider : public IExecutionProvider {
ONNXRUNTIME_ENFORCE(strcmp(dst.Location().name, CPU) == 0);

// Todo: support copy with different devices.
if (strcmp(src.Location().name, CPU) != 0)
if (strcmp(src.Location().name, CPU) != 0) {
ONNXRUNTIME_NOT_IMPLEMENTED("copy from ", src.Location().name, " is not implemented");
}

// no really copy needed if is copy to cpu.
dst.ShallowCopy(src);
Expand Down
14 changes: 7 additions & 7 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include "cuda_fence.h"
#include "cuda_allocator.h"
#include "core/framework/kernel_registry.h"
#include "core/framework/computation_capacity.h"
#include "core/framework/compute_capability.h"

using namespace onnxruntime::common;

Expand Down Expand Up @@ -62,12 +62,12 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&streams_[kCudaStreamCopyIn], cudaStreamNonBlocking));
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&streams_[kCudaStreamCopyOut], cudaStreamNonBlocking));

DeviceAllocatorRegistrationInfo default_allocator_info({ONNXRuntimeMemTypeDefault,
[](int id) { return std::make_unique<CUDAAllocator>(id); }, std::numeric_limits<size_t>::max()});
DeviceAllocatorRegistrationInfo default_allocator_info(
{ONNXRuntimeMemTypeDefault, [](int id) { return std::make_unique<CUDAAllocator>(id); }, std::numeric_limits<size_t>::max()});
InsertAllocator(CreateAllocator(default_allocator_info, device_id_));

DeviceAllocatorRegistrationInfo pinned_allocator_info({ONNXRuntimeMemTypeCPUOutput,
[](int) { return std::make_unique<CUDAPinnedAllocator>(); }, std::numeric_limits<size_t>::max()});
DeviceAllocatorRegistrationInfo pinned_allocator_info(
{ONNXRuntimeMemTypeCPUOutput, [](int) { return std::make_unique<CUDAPinnedAllocator>(); }, std::numeric_limits<size_t>::max()});
InsertAllocator(CreateAllocator(pinned_allocator_info, device_id_));
}

Expand Down Expand Up @@ -824,10 +824,10 @@ bool CUDAExecutionProvider::RNNNeedFallbackToCPU(const onnxruntime::Node& node,
return false;
}

std::vector<std::unique_ptr<ComputationCapacity>>
std::vector<std::unique_ptr<ComputeCapability>>
CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
const std::vector<const KernelRegistry*>& kernel_registries) const {
std::vector<std::unique_ptr<ComputationCapacity>> result = IExecutionProvider::GetCapability(graph, kernel_registries);
std::vector<std::unique_ptr<ComputeCapability>> result = IExecutionProvider::GetCapability(graph, kernel_registries);

for (auto& node : graph.Nodes()) {
bool fallback_to_cpu_provider = false;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/cuda_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class CUDAExecutionProvider : public IExecutionProvider {

virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;

virtual std::vector<std::unique_ptr<ComputationCapacity>>
virtual std::vector<std::unique_ptr<ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph,
const std::vector<const KernelRegistry*>& kernel_registries) const override;
private:
Expand Down
14 changes: 8 additions & 6 deletions onnxruntime/test/framework/inference_session_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "core/framework/op_kernel.h"
#include "core/framework/session_state.h"
#include "core/graph/graph_viewer.h"
#include "core/framework/computation_capacity.h"
#include "core/framework/compute_capability.h"
#include "core/graph/model.h"
#include "core/graph/op.h"
#include "core/providers/cpu/cpu_execution_provider.h"
Expand Down Expand Up @@ -67,16 +67,17 @@ void RegisterOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
class FuseExecutionProvider : public IExecutionProvider {
public:
explicit FuseExecutionProvider() {
DeviceAllocatorRegistrationInfo device_info({ONNXRuntimeMemTypeDefault, [](int) { return std::make_unique<CPUAllocator>(); }, std::numeric_limits<size_t>::max()});
DeviceAllocatorRegistrationInfo device_info({ONNXRuntimeMemTypeDefault,
[](int) { return std::make_unique<CPUAllocator>(); }, std::numeric_limits<size_t>::max()});
InsertAllocator(std::shared_ptr<IArenaAllocator>(
std::make_unique<DummyArena>(device_info.factory(0))));
}

std::vector<std::unique_ptr<ComputationCapacity>>
std::vector<std::unique_ptr<ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph,
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override {
// Fuse two add into one.
std::vector<std::unique_ptr<ComputationCapacity>> result;
std::vector<std::unique_ptr<ComputeCapability>> result;
std::unique_ptr<IndexedSubGraph> sub_graph = std::make_unique<IndexedSubGraph>();
for (auto& node : graph.Nodes()) {
sub_graph->nodes.push_back(node.Index());
Expand All @@ -89,12 +90,13 @@ class FuseExecutionProvider : public IExecutionProvider {
meta_def->since_version = 1;
meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL;
sub_graph->SetMetaDef(meta_def);
result.push_back(std::make_unique<ComputationCapacity>(std::move(sub_graph), nullptr));
result.push_back(std::make_unique<ComputeCapability>(std::move(sub_graph), nullptr));
return result;
}

std::shared_ptr<::onnxruntime::KernelRegistry> GetKernelRegistry() const override {
static std::shared_ptr<::onnxruntime::KernelRegistry> kernel_registry = std::make_shared<::onnxruntime::KernelRegistry>(RegisterOperatorKernels);
static std::shared_ptr<::onnxruntime::KernelRegistry>
kernel_registry = std::make_shared<::onnxruntime::KernelRegistry>(RegisterOperatorKernels);
return kernel_registry;
}

Expand Down
15 changes: 10 additions & 5 deletions onnxruntime/test/tvm/tvm_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <tvm/runtime/ndarray.h>
#include "core/codegen/tvm/tvm_kernel.h"
#include "core/framework/execution_provider.h"
#include "core/framework/computation_capacity.h"
#include "core/framework/compute_capability.h"
#include "core/graph/graph_viewer.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/session/inference_session.h"
Expand All @@ -23,7 +23,10 @@ tvm::Schedule DefaultTVMScheduleGenerator(const TVMGraph& tvm_graph) {
return tvm::create_schedule(args);
}

tvm::runtime::Module BuildStackVMDefaultModule(tvm::Schedule schedule, tvm::BuildConfig config, tvm::Array<tvm::Tensor> tvm_args, std::vector<std::string>& target_func_names) {
tvm::runtime::Module BuildStackVMDefaultModule(tvm::Schedule schedule,
tvm::BuildConfig config,
tvm::Array<tvm::Tensor> tvm_args,
std::vector<std::string>& target_func_names) {
auto target = tvm::target::stackvm();
std::string func_name = "func";
auto args = tvm::Array<tvm::Tensor>(tvm_args);
Expand Down Expand Up @@ -70,7 +73,7 @@ class UnionSet {
std::vector<int> farthers_;
};

void FuseAdd(const onnxruntime::GraphViewer& graph, std::vector<std::unique_ptr<ComputationCapacity>>& capacities) {
void FuseAdd(const onnxruntime::GraphViewer& graph, std::vector<std::unique_ptr<ComputeCapability>>& capacities) {
std::vector<onnxruntime::NodeIndex> add_nodes;
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Add") {
Expand Down Expand Up @@ -136,8 +139,10 @@ void FuseAdd(const onnxruntime::GraphViewer& graph, std::vector<std::unique_ptr<
sub_graph->SetMetaDef(meta_def);
//TODO:set fuse kernel func;
capacities.push_back(
std::make_unique<ComputationCapacity>(std::move(sub_graph),
[](const OpKernelInfo& info) -> OpKernel* { return new TVMFuseAddKernels<DefaultTVMScheduleGenerator, BuildStackVMDefaultModule>(info); }));
std::make_unique<ComputeCapability>(
std::move(sub_graph),
[](const OpKernelInfo& info) -> OpKernel* {
return new TVMFuseAddKernels<DefaultTVMScheduleGenerator, BuildStackVMDefaultModule>(info); }));
}
}
}
Expand Down