Kezhan/execute graph refactoring (#1553)
* update the execution provider checking logic.

* fix the logic of copying inputs and outputs.

* update

* update

* update

* update

* update

* update

* fix ngraph failure.

* fix comments
linkerzhang authored Aug 14, 2019
1 parent b405482 commit bd64ca3
Showing 9 changed files with 97 additions and 111 deletions.
4 changes: 4 additions & 0 deletions include/onnxruntime/core/framework/allocator.h
@@ -72,6 +72,10 @@ struct OrtDevice {
DeviceId device_id;
};

inline bool operator==(const OrtDevice& left, const OrtDevice& other) {
return left.Id() == other.Id() && left.MemType() == other.MemType() && left.Type() == other.Type();
}

struct OrtAllocatorInfo {
// use string for name, so we could have customized allocator in execution provider.
const char* name;
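This `operator==` is the primitive the rest of the commit builds on: copy decisions move from comparing provider-type strings to comparing `OrtDevice` values directly. A minimal sketch of the semantics, using stand-in types rather than the real onnxruntime headers:

```cpp
#include <cstdint>
#include <iostream>

// Stand-in for OrtDevice; the real struct lives in allocator.h above.
struct Device {
  int16_t type = 0;      // e.g. CPU vs. GPU
  int16_t mem_type = 0;  // e.g. default vs. pinned host memory
  int16_t id = 0;        // device ordinal

  int16_t Type() const { return type; }
  int16_t MemType() const { return mem_type; }
  int16_t Id() const { return id; }
};

// Mirrors the new OrtDevice comparison: placements match only when type,
// memory type and ordinal all agree.
inline bool operator==(const Device& left, const Device& other) {
  return left.Id() == other.Id() && left.MemType() == other.MemType() &&
         left.Type() == other.Type();
}

int main() {
  Device cpu{0, 0, 0};
  Device gpu0{1, 0, 0};
  std::cout << (cpu == gpu0) << "\n";             // 0: different device types
  std::cout << (cpu == Device{0, 0, 0}) << "\n";  // 1: same placement
}
```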
48 changes: 29 additions & 19 deletions onnxruntime/core/framework/allocation_planner.cc
@@ -347,6 +347,10 @@ class PlannerImpl {

Status ComputeUseCounts() {
// Note: for every ml-value, its definition must appear before all its uses in a topological sort of a valid model
std::unordered_set<std::string> graph_inputs;
for (auto& graph_input : graph_viewer_.GetInputsIncludingInitializers()) {
graph_inputs.insert(graph_input->Name());
}

for (auto graph_input : graph_viewer_.GetInputs()) {
OrtValueIndex index = Index(graph_input->Name());
@@ -371,15 +375,7 @@
for (SequentialExecutionPlan::NodeExecutionPlan& step : plan_.execution_plan) {
auto pnode = graph_viewer_.GetNode(step.node_index);
if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the node ", step.node_index);
for (auto node_input : pnode->InputDefs()) {
if (node_input->Exists())
UseCount(node_input->Name())++;
}

for (auto node_input : pnode->ImplicitInputDefs()) {
if (node_input->Exists())
UseCount(node_input->Name())++;
}
// Identify where each output of this node should be allocated.
// This is determined by the opkernel bound to the node.
const KernelCreateInfo* kernel_create_info = nullptr;
@@ -394,31 +390,45 @@
if (!pnode->Name().empty()) errormsg << " (node " << pnode->Name() << ")";
return Status(ONNXRUNTIME, FAIL, errormsg.str());
}

auto exec_provider = execution_providers_.Get(*pnode);
if (exec_provider == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the execution provider ",
pnode->GetExecutionProviderType());
}

auto& default_allocator_info = exec_provider->GetAllocator(0, OrtMemTypeDefault)->Info();
// increment UseCount and add location information if applicable for the provided input def
auto process_input = [&graph_inputs, &exec_provider, &p_kernelDef, this](const NodeArg& input, size_t arg_idx) {
const auto& name = input.Name();
UseCount(name)++;

// If it's a graph input or outer scope node arg, set its plan.
// NOTE: Copy nodes should have already been added if a graph input is fed as input
// to nodes assigned to different providers.
if (graph_inputs.find(name) != graph_inputs.cend() ||
std::find_if(outer_scope_node_args_.cbegin(), outer_scope_node_args_.cend(),
[&name](const NodeArg* value) {
return value && value->Name() == name;
}) != outer_scope_node_args_.cend()) {
OrtValueIndex index = Index(name);
plan_.SetLocation(static_cast<size_t>(index),
exec_provider->GetAllocator(0, p_kernelDef->InputMemoryType(arg_idx))->Info());
}

return Status::OK();
};

ORT_RETURN_IF_ERROR(Node::ForEachWithIndex(pnode->InputDefs(), process_input));
ORT_RETURN_IF_ERROR(Node::ForEachWithIndex(pnode->ImplicitInputDefs(), process_input));

auto outputs = pnode->OutputDefs();
auto num_outputs = outputs.size();

for (size_t i = 0; i < num_outputs; ++i) {
auto* node_output = outputs[i];
if (!node_output->Exists()) continue;
OrtValueIndex index = Index(node_output->Name());
ProcessDef(index, node_output);
++UseCount(index);
if (strcmp(default_allocator_info.name, CPU) != 0) {
// By default, outputs of this node are allocated on the default device allocator,
// except for outputs marked for allocation in MemoryType:
auto memory_type = p_kernelDef->OutputMemoryType(i);
plan_.SetLocation(static_cast<size_t>(index), memory_type == OrtMemTypeDefault
? default_allocator_info
: exec_provider->GetAllocator(0, memory_type)->Info());
}
plan_.SetLocation(static_cast<size_t>(index), exec_provider->GetAllocator(0, p_kernelDef->OutputMemoryType(i))->Info());
}
// if sync is needed, mark allocation plan as create_fence_if_async=true
// note that the input arg may come from an execution provider (i.e. CPU) that does not support async,
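The hunk above collapses the two separate input loops into a single `process_input` lambda that both increments use counts and pins graph inputs (and outer-scope args) to the memory type the consuming kernel requests. A simplified sketch of that pass, with stand-in types in place of the real planner classes:

```cpp
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct Location { std::string allocator; };  // stand-in for OrtAllocatorInfo

struct PlannerSketch {
  std::unordered_set<std::string> graph_inputs;    // filled up front, as above
  std::unordered_map<std::string, int> use_count;
  std::unordered_map<std::string, Location> plan;  // value name -> location

  void CountNodeInputs(const std::vector<std::string>& explicit_inputs,
                       const std::vector<std::string>& implicit_inputs,
                       const Location& kernel_input_location) {
    auto process_input = [&](const std::string& name) {
      ++use_count[name];  // every consumer bumps the use count
      // Graph inputs get their location pinned here; copy nodes are assumed
      // to exist already when a feed crosses providers.
      if (graph_inputs.count(name) != 0) plan[name] = kernel_input_location;
    };
    for (const auto& name : explicit_inputs) process_input(name);
    for (const auto& name : implicit_inputs) process_input(name);
  }
};
```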
3 changes: 1 addition & 2 deletions onnxruntime/core/framework/feeds_fetches_manager.h
@@ -48,9 +48,8 @@ struct FeedsFetchesInfo {
class FeedsFetchesManager {
public:
struct MLValueCopyInfo {
int allocation_device_id = 0;
OrtDevice target_device;
const IExecutionProvider* allocation_provider = nullptr;
const IExecutionProvider* copy_provider = nullptr;
};

static Status Create(const std::vector<std::string>& feed_names, const std::vector<std::string>& output_names,
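`MLValueCopyInfo` now records the target `OrtDevice` instead of a device ordinal plus a dedicated copy provider; the data-transfer manager derives the copy path from the device. A before/after sketch with simplified stand-in types (only the field names come from the diff):

```cpp
#include <cstdint>

class IExecutionProvider;  // opaque here
struct OrtDeviceSketch { int16_t type = 0, mem_type = 0, id = 0; };

// Before: the copy was described by a device ordinal plus two providers.
struct MLValueCopyInfoOld {
  int allocation_device_id = 0;
  const IExecutionProvider* allocation_provider = nullptr;
  const IExecutionProvider* copy_provider = nullptr;
};

// After: the target device carries the placement; the DataTransferManager
// chooses the copy path, so no dedicated copy provider is stored.
struct MLValueCopyInfoNew {
  OrtDeviceSketch target_device;
  const IExecutionProvider* allocation_provider = nullptr;
};
```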
4 changes: 0 additions & 4 deletions onnxruntime/core/framework/graph_partitioner.cc
@@ -175,10 +175,6 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f
//prepare the func kernel
KernelDefBuilder builder;
BuildFusedKernelDef(builder, *node);
if (node->GetExecutionProviderType() == onnxruntime::kNGraphExecutionProvider || node->GetExecutionProviderType() == onnxruntime::kNnapiExecutionProvider) {
builder.SetDefaultInputsMemoryType(OrtMemTypeCPUInput);
builder.SetDefaultOutputMemoryType(OrtMemTypeCPUOutput);
}
ORT_RETURN_IF_ERROR(fused_kernel_registry->Register(
builder, static_cast<KernelCreatePtrFn>([](const OpKernelInfo& info) -> OpKernel* { return new FunctionKernel(info); })));
}
7 changes: 4 additions & 3 deletions onnxruntime/core/framework/session_state.h
@@ -141,17 +141,18 @@ class SessionState {
* \param p_node0 Nullable
* \param kci0 Nullable
*/
NodeInfo(size_t index0, const onnxruntime::Node* p_node0, const KernelCreateInfo* kci0)
NodeInfo(size_t index0, const onnxruntime::Node* p_node0, const KernelCreateInfo* kci0, const OrtDevice& device0)
: index(index0),
p_node(p_node0),
kci(kci0) {
}
kci(kci0),
device(&device0) {}

size_t index;
// Nullable
const onnxruntime::Node* p_node = nullptr;
// Nullable
const KernelCreateInfo* kci = nullptr;
const OrtDevice* device = nullptr;
};

using NameNodeInfoMapType = std::unordered_map<std::string, std::vector<NodeInfo>>;
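Note that the new `NodeInfo` stores a non-owning pointer taken from a `const` reference parameter; this is safe only because the `OrtDevice` lives in the session's execution plan, which outlives every `NodeInfo`. A sketch making that lifetime assumption explicit:

```cpp
#include <cstdint>

struct Device { int16_t type = 0, mem_type = 0, id = 0; };

struct NodeInfoSketch {
  // Non-owning: points into storage owned by the execution plan, which is
  // assumed to outlive every NodeInfo (as in the session-state code here).
  const Device* device = nullptr;

  explicit NodeInfoSketch(const Device& d) : device(&d) {}
};
```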
20 changes: 17 additions & 3 deletions onnxruntime/core/framework/session_state_initializer.cc
@@ -351,6 +351,8 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph
if (implicit_inputs && implicit_inputs->empty()) {
implicit_inputs = nullptr;
}
const auto* exec_plan = session_state.GetExecutionPlan();
const auto& name_to_id = session_state.GetOrtValueNameIdxMap();

for (auto& node : graph.Nodes()) {
// note that KernelCreateInfo may not exist for custom kernel
@@ -365,7 +367,11 @@
return Status::OK();
}

SessionState::NodeInfo node_info(index, &node, kci);
int arg_index;
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(arg.Name(), arg_index));
const auto& device = exec_plan->GetLocation(arg_index).device;

SessionState::NodeInfo node_info(index, &node, kci, device);

if (IsArgNameInInputsOutputs(arg.Name(), graph_inputs)) {
ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(arg.Name(), node_info));
@@ -397,8 +403,13 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph
// copy to/from CPU to go through the control flow nodes where possible/applicable.
// the processing for the subgraph where the implicit input is consumed will do the real check on whether any
// copy to a different device is required
SessionState::NodeInfo node_info(std::numeric_limits<size_t>::max(), &node, kci);
for (const auto& input_def : node_implicit_inputs) {
int arg_index;
//Question: the implicit input may not be found in this session state name to id map, but in parent session state name to id map.
//@Scott
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(input_def->Name(), arg_index));
auto& device = exec_plan->GetLocation(arg_index).device;
SessionState::NodeInfo node_info(std::numeric_limits<size_t>::max(), &node, kci, device);
ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(input_def->Name(), node_info));
}
}
@@ -413,7 +424,6 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph

auto& input_map = session_state.GetInputNodeInfoMap();
auto end_map = input_map.cend();
SessionState::NodeInfo empty_node_info(std::numeric_limits<size_t>::max(), nullptr, nullptr);

for (const auto& graph_input : graph_inputs) {
const auto& name = graph_input->Name();
@@ -422,6 +432,10 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph
// utils::CopyOneInputAcrossDevices will use the input OrtValue as is given we don't believe it's used anywhere.
LOGS(session_state.Logger(), INFO) << (graph.IsSubgraph() ? "Subgraph" : "Graph") << " input with name "
<< name << " is not used by any node.";
int arg_index;
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(name, arg_index));
auto& device = exec_plan->GetLocation(arg_index).device;
SessionState::NodeInfo empty_node_info(std::numeric_limits<size_t>::max(), nullptr, nullptr, device);
ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(name, empty_node_info));
}
}
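All three call sites above repeat the same lookup chain: arg name → OrtValue index → planned location → device, which then lands in `NodeInfo`. A condensed sketch of the chain, with simplified stand-ins for `OrtValueNameIdxMap` and the execution plan:

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

struct Device { int16_t type = 0, mem_type = 0, id = 0; };

// Stand-ins for OrtValueNameIdxMap and SequentialExecutionPlan.
struct NameIdxMap {
  std::unordered_map<std::string, int> ids;
  int GetIdx(const std::string& name) const { return ids.at(name); }
};
struct ExecutionPlanSketch {
  std::vector<Device> locations;  // indexed by OrtValue index
  const Device& GetLocation(int idx) const { return locations.at(idx); }
};

// name -> OrtValue index -> planned location -> device, as stored in NodeInfo.
const Device& LookupArgDevice(const NameIdxMap& name_to_id,
                              const ExecutionPlanSketch& plan,
                              const std::string& arg_name) {
  return plan.GetLocation(name_to_id.GetIdx(arg_name));
}
```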
116 changes: 37 additions & 79 deletions onnxruntime/core/framework/utils.cc
@@ -23,9 +23,18 @@ AllocatorPtr GetAllocator(const SessionState& session_state, const OrtAllocatorI
return session_state.GetExecutionProviders().GetAllocator(allocator_info);
}

common::Status AllocateHelper(const IExecutionProvider& execution_provider, int device_id, const Tensor& fetched_tensor,
bool ProviderIsCpuBased(const std::string& provider_type) {
return provider_type == onnxruntime::kCpuExecutionProvider ||
provider_type == onnxruntime::kMklDnnExecutionProvider ||
provider_type == onnxruntime::kNGraphExecutionProvider ||
provider_type == onnxruntime::kNupharExecutionProvider ||
provider_type == onnxruntime::kOpenVINOExecutionProvider ||
provider_type == onnxruntime::kNnapiExecutionProvider;
}

common::Status AllocateHelper(const IExecutionProvider& execution_provider, const OrtDevice& device, const Tensor& fetched_tensor,
OrtValue& output_mlvalue) {
auto allocator = execution_provider.GetAllocator(device_id, OrtMemTypeDefault);
auto allocator = execution_provider.GetAllocator(device.Id(), OrtMemTypeDefault);
if (!allocator) {
return Status(common::ONNXRUNTIME, common::FAIL, "invalid allocator");
}
@@ -62,20 +71,20 @@ static Status CopyMLValue(const DataTransferManager& data_transfer_mgr,
const FeedsFetchesManager::MLValueCopyInfo& copy_info,
const OrtValue& source_mlvalue,
OrtValue& target_mlvalue) {
if (copy_info.copy_provider == nullptr) {
if (copy_info.allocation_provider == nullptr){
target_mlvalue = source_mlvalue;
} else {
auto& source_tensor = source_mlvalue.Get<Tensor>();
return Status::OK();
}

if (!target_mlvalue.IsAllocated()) {
ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.allocation_device_id,
source_tensor, target_mlvalue));
}
auto& source_tensor = source_mlvalue.Get<Tensor>();
if (!target_mlvalue.IsAllocated()) {
ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.target_device,
source_tensor, target_mlvalue));
}

Tensor* p_output_tensor = target_mlvalue.GetMutable<Tensor>();
Tensor* p_output_tensor = target_mlvalue.GetMutable<Tensor>();

ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor));
}
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor));

return Status::OK();
}
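Stripped of the diff interleaving, the refactored `CopyMLValue` treats a null allocation provider as the "no copy" signal set by callers when source and target devices already match; otherwise it allocates on the target device if needed and copies. A behavioral sketch with simplified types (the real transfer goes through `DataTransferManager::CopyTensor`):

```cpp
#include <string>

struct Tensor { std::string device; };
struct Value { Tensor tensor; bool allocated = false; };
struct Provider { std::string device; };

// Mirrors the new control flow: pass through when no provider is given,
// otherwise allocate on the target device (if needed) and copy the bytes.
void CopyValue(const Provider* allocation_provider,
               const std::string& target_device,
               const Value& source, Value& target) {
  if (allocation_provider == nullptr) {
    target = source;  // devices already match: alias, don't copy
    return;
  }
  if (!target.allocated) {
    target.tensor.device = target_device;  // AllocateHelper in the real code
    target.allocated = true;
  }
  // DataTransferManager::CopyTensor(source.tensor, target.tensor) in ORT;
  // the byte transfer itself is elided in this sketch.
}
```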
@@ -86,8 +95,6 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
FeedsFetchesManager::MLValueCopyInfo& copy_info) {
needed_copy = false;

//TODO: make it configurable
const int target_device_id = 0;
std::vector<SessionState::NodeInfo> node_info_vec;
ORT_RETURN_IF_ERROR(session_state.GetInputNodeInfo(input_name, node_info_vec));

@@ -111,51 +118,23 @@
break;
}

auto& required_provider_type = GetNodeInputProviderType(node_info);
auto& input_tensor = orig_mlvalue.Get<Tensor>();
auto& input_tensor_loc = input_tensor.Location();

auto* p_input_provider = exec_providers.Get(input_tensor_loc);
if (!p_input_provider) {
p_input_provider = exec_providers.Get(onnxruntime::kCpuExecutionProvider);
ORT_ENFORCE(p_input_provider);
}

//no copy for nGraph
if (required_provider_type == onnxruntime::kNGraphExecutionProvider) {
new_mlvalue = orig_mlvalue;
break;
}

auto input_provider_type = p_input_provider->Type();
if (input_provider_type == required_provider_type && input_tensor_loc.mem_type == OrtMemTypeDefault) {
new_mlvalue = orig_mlvalue;
break;
}

// If a node requires input on cpu and input tensor is allocated with pinned memory allocator, don't do copy
if (required_provider_type == onnxruntime::kCpuExecutionProvider &&
input_tensor_loc.mem_type == OrtMemTypeCPU) {
auto& required_device = *node_info.device;
auto& input_tensor_device = orig_mlvalue.Get<Tensor>().Location().device;
if (required_device == input_tensor_device) {
// No copy needed for same device.
new_mlvalue = orig_mlvalue;
break;
}

auto& required_provider_type = GetNodeInputProviderType(node_info);
auto* required_provider = exec_providers.Get(required_provider_type);
ORT_ENFORCE(required_provider);

auto* p_copy_provider = (required_provider_type != onnxruntime::kCpuExecutionProvider)
? required_provider
: p_input_provider;

copy_info.allocation_device_id = target_device_id;
copy_info.target_device = required_device;
copy_info.allocation_provider = required_provider;
copy_info.copy_provider = p_copy_provider;

ORT_RETURN_IF_ERROR(CopyMLValue(session_state.GetDataTransferMgr(), copy_info, orig_mlvalue, new_mlvalue));

needed_copy = true;

// } loop of node_info_vec
} while (false);

return Status::OK();
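The rewritten input-copy check is the core of the refactor: the nGraph pass-through, the pinned-CPU-memory case, and the provider-type string comparisons all collapse into one device equality test against the device recorded in `NodeInfo`. A reduced sketch of the decision:

```cpp
#include <cstdint>

struct Device { int16_t type = 0, mem_type = 0, id = 0; };

inline bool operator==(const Device& a, const Device& b) {
  return a.type == b.type && a.mem_type == b.mem_type && a.id == b.id;
}

// True if the feed must be copied before the consuming node can run. The
// single device comparison replaces the old provider-type and memory-type
// special cases (nGraph pass-through, pinned CPU input, ...).
bool InputNeedsCopy(const Device& required_device, const Device& input_device) {
  return !(required_device == input_device);
}
```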
@@ -344,43 +323,26 @@ static common::Status CopyOutputsAcrossDevices(const SessionState& session_state
continue;
}

auto& fetched_tensor = fetched_mlvalue.Get<Tensor>();
auto& fetched_tensor_location = fetched_tensor.Location();
auto* p_fetched_provider = execution_providers.Get(fetched_tensor_location);
if (!p_fetched_provider) {
p_fetched_provider = cpu_execution_provider;
}

auto fetched_provider_type = p_fetched_provider->Type();
auto& output_mlvalue = user_fetches[idx];

const IExecutionProvider* p_output_provider = nullptr;

auto target_device = OrtDevice();
auto& output_mlvalue = user_fetches[idx];
if (output_mlvalue.IsAllocated()) {
Tensor* p_output_tensor = output_mlvalue.GetMutable<Tensor>();
target_device = p_output_tensor->Location().device;
p_output_provider = execution_providers.Get(p_output_tensor->Location());
}
auto fetch_result_device = fetched_mlvalue.Get<Tensor>().Location().device;
if (target_device == fetch_result_device) {
user_fetches[idx] = fetched_mlvalue;
continue;
}

if (!p_output_provider) {
p_output_provider = cpu_execution_provider;
}

auto output_provider_type = p_output_provider->Type();

if (fetched_provider_type == output_provider_type ||
(p_output_provider == cpu_execution_provider && fetched_tensor_location.mem_type == OrtMemTypeCPUOutput)) {
user_fetches[idx] = fetched_mlvalue;
continue;
}

needed_copy = true;

auto* p_copy_provider = (fetched_provider_type != onnxruntime::kCpuExecutionProvider)
? p_fetched_provider
: p_output_provider;

const int device_id = 0; // TODO: As per comment in the copy input code, make this configurable.
FeedsFetchesManager::MLValueCopyInfo copy_info{device_id, p_output_provider, p_copy_provider};
FeedsFetchesManager::MLValueCopyInfo copy_info{target_device, p_output_provider};
ORT_RETURN_IF_ERROR(CopyMLValue(session_state.GetDataTransferMgr(), copy_info, fetched_mlvalue, output_mlvalue));

if (copiers) {
@@ -410,11 +372,7 @@ static common::Status CachedCopyOutputsAcrossDevices(

static DeviceCopyCheck CheckExecutionProviders(const ExecutionProviders& execution_providers) {
for (const auto& execution_provider : execution_providers) {
if (execution_provider->Type() != onnxruntime::kCpuExecutionProvider &&
execution_provider->Type() != onnxruntime::kMklDnnExecutionProvider &&
execution_provider->Type() != onnxruntime::kNGraphExecutionProvider &&
execution_provider->Type() != onnxruntime::kNupharExecutionProvider &&
execution_provider->Type() != onnxruntime::kOpenVINOExecutionProvider) {
if (!ProviderIsCpuBased(execution_provider->Type())) {
return DeviceCopyCheck::Unknown;
}
}
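`ProviderIsCpuBased` centralizes the provider list that `CheckExecutionProviders` previously spelled out inline, and adds NNAPI to it; when every registered provider is CPU-based, the session can rule out device copies wholesale. A sketch of that fast path (the provider-name strings here are illustrative stand-ins, not the exact ORT constants):

```cpp
#include <string>
#include <vector>

enum class DeviceCopyCheck { NoCopy, Unknown };

// Stand-in for the predicate added in utils.cc: every provider listed here
// produces and consumes CPU-addressable memory.
bool ProviderIsCpuBased(const std::string& provider_type) {
  static const std::vector<std::string> cpu_based = {
      "CPU", "MklDnn", "NGraph", "Nuphar", "OpenVINO", "Nnapi"};
  for (const auto& p : cpu_based)
    if (p == provider_type) return true;
  return false;
}

// If any provider might use non-CPU memory, copies can't be ruled out and
// the per-value device checks above still have to run.
DeviceCopyCheck CheckProviders(const std::vector<std::string>& provider_types) {
  for (const auto& type : provider_types)
    if (!ProviderIsCpuBased(type)) return DeviceCopyCheck::Unknown;
  return DeviceCopyCheck::NoCopy;
}
```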
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -72,6 +72,11 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in
DeviceAllocatorRegistrationInfo pinned_allocator_info(
{OrtMemTypeCPUOutput, [](int) { return std::make_unique<CUDAPinnedAllocator>(0, CUDA_PINNED); }, std::numeric_limits<size_t>::max()});
InsertAllocator(CreateAllocator(pinned_allocator_info, device_id_));

// TODO: this is actually used for the cuda kernels which explicitly ask for inputs from CPU.
// This will be refactored/removed when allocator and execution provider are decoupled.
DeviceAllocatorRegistrationInfo cpu_allocator_info({OrtMemTypeCPUInput, [](int) { return std::make_unique<CPUAllocator>(std::make_unique<OrtAllocatorInfo>("CUDA_CPU", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUInput)); }, std::numeric_limits<size_t>::max()});
InsertAllocator(CreateAllocator(cpu_allocator_info));
}

CUDAExecutionProvider::~CUDAExecutionProvider() {
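This registration exists so CUDA kernels that declare `OrtMemTypeCPUInput` arguments can obtain host memory from their own provider; per the TODO it is a stopgap until allocators and execution providers are decoupled. A hedged sketch of the idea, with the allocator registry reduced to a map keyed on memory type:

```cpp
#include <cstddef>
#include <cstdlib>
#include <map>
#include <memory>

enum MemType { Default, CPUInput, CPUOutput };

struct Allocator {
  virtual ~Allocator() = default;
  virtual void* Alloc(std::size_t n) = 0;
};
struct CpuAllocator : Allocator {
  void* Alloc(std::size_t n) override { return std::malloc(n); }
};

// Reduced provider sketch: the real code keys allocators on device id and
// memory type via InsertAllocator/CreateAllocator.
struct CudaProviderSketch {
  std::map<MemType, std::unique_ptr<Allocator>> allocators;

  CudaProviderSketch() {
    // Device and pinned allocators are registered as before (elided); the
    // commit adds a plain CPU allocator for kernels wanting host inputs.
    allocators[CPUInput] = std::make_unique<CpuAllocator>();
  }
};
```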
1 change: 0 additions & 1 deletion onnxruntime/test/providers/provider_test_utils.cc
@@ -448,7 +448,6 @@ void OpTester::Run(ExpectResult expect_result,
std::unordered_map<std::string, OrtValue> feeds;
std::vector<std::string> output_names;
FillFeedsAndOutputNames(feeds, output_names);

// Run the model
SessionOptions so;
so.session_logid = op_;
