Kezhan/execute graph refactoring #1553

Status: Merged (21 commits, merged on Aug 14, 2019)

Commits
d546f7a  checking execution provider logic updated. (linkerzhang, Aug 1, 2019)
8c1b47d  Merge branch 'master' of https://github.com/Microsoft/onnxruntime int… (linkerzhang, Aug 2, 2019)
69a66f0  Merge branch 'master' of https://github.com/Microsoft/onnxruntime int… (linkerzhang, Aug 2, 2019)
398b64c  fix the logic of copy input and output. (linkerzhang, Aug 2, 2019)
26e69ad  Merge branch 'master' of https://github.com/Microsoft/onnxruntime int… (linkerzhang, Aug 5, 2019)
7d03eaf  update (linkerzhang, Aug 6, 2019)
749eab5  Merge branch 'master' of https://github.com/Microsoft/onnxruntime int… (linkerzhang, Aug 6, 2019)
fc48324  Merge branch 'master' of https://github.com/Microsoft/onnxruntime int… (linkerzhang, Aug 6, 2019)
53eb9f4  update (linkerzhang, Aug 6, 2019)
783ef3c  Merge branch 'master' of https://github.com/Microsoft/onnxruntime int… (linkerzhang, Aug 7, 2019)
d7b0d2b  Merge branch 'master' of https://github.com/Microsoft/onnxruntime int… (linkerzhang, Aug 7, 2019)
9e3c407  update (linkerzhang, Aug 7, 2019)
62cc6a1  update (linkerzhang, Aug 7, 2019)
567aecc  update (linkerzhang, Aug 7, 2019)
f14ee8f  update (linkerzhang, Aug 8, 2019)
e679733  Merge branch 'master' of https://github.com/Microsoft/onnxruntime int… (linkerzhang, Aug 8, 2019)
ae1bb3d  sync and update (linkerzhang, Aug 9, 2019)
1cae5ac  Merge branch 'master' of https://github.com/Microsoft/onnxruntime int… (linkerzhang, Aug 9, 2019)
4c79280  fix ngraph failure. (linkerzhang, Aug 9, 2019)
46f3b19  fix comments (linkerzhang, Aug 9, 2019)
749fb88  Merge branch 'master' of https://github.com/Microsoft/onnxruntime int… (linkerzhang, Aug 9, 2019)
Files changed
include/onnxruntime/core/framework/allocator.h (4 additions, 0 deletions)

@@ -72,6 +72,10 @@ struct OrtDevice {
DeviceId device_id;
};

inline bool operator==(const OrtDevice& left, const OrtDevice& other) {
return left.Id() == other.Id() && left.MemType() == other.MemType() && left.Type() == other.Type();
}

Review thread on the operator== above:
  Member: is memcmp more efficient?
  Member: That would assume that object padding is always the same.
  Member: @yuslepukhin For the same compiler, same build, isn't it?

struct OrtAllocatorInfo {
// use string for name, so we could have customized allocator in execution provider.
const char* name;
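As a side note on the thread above: memcmp compares the full object representation, including padding bytes, so it can report two logically equal devices as different unless padding is guaranteed to be zero-filled. A minimal, self-contained sketch of the concern (PaddedDevice is a hypothetical layout chosen to force padding; it is not the real OrtDevice):

```cpp
#include <cstring>
#include <iostream>

// Hypothetical layout chosen to force padding between members; the real
// OrtDevice's layout may differ.
struct PaddedDevice {
  int8_t device_type;
  int8_t mem_type;
  int32_t id;  // on typical ABIs, 2 padding bytes precede this member
};

int main() {
  PaddedDevice a, b;  // padding bytes are left indeterminate
  a.device_type = b.device_type = 0;
  a.mem_type = b.mem_type = 0;
  a.id = b.id = 1;

  const bool memberwise = a.device_type == b.device_type &&
                          a.mem_type == b.mem_type && a.id == b.id;
  const bool bytewise = std::memcmp(&a, &b, sizeof a) == 0;

  // memberwise is always true here; bytewise may be false, because memcmp
  // also compares the indeterminate padding bytes.
  std::cout << "memberwise=" << memberwise << " bytewise=" << bytewise << '\n';
}
```

So the member-wise operator== added in this diff is the conservative choice; memcmp would only be safe if every OrtDevice were created from zero-filled storage.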
onnxruntime/core/framework/allocation_planner.cc (29 additions, 19 deletions)

@@ -347,6 +347,10 @@ class PlannerImpl {

Status ComputeUseCounts() {
// Note: for every ml-value, its definition must appear before all its uses in a topological sort of a valid model
std::unordered_set<std::string> graph_inputs;
for (auto& graph_input : graph_viewer_.GetInputsIncludingInitializers()) {
graph_inputs.insert(graph_input->Name());
}

for (auto graph_input : graph_viewer_.GetInputs()) {
OrtValueIndex index = Index(graph_input->Name());
@@ -371,15 +375,7 @@
for (SequentialExecutionPlan::NodeExecutionPlan& step : plan_.execution_plan) {
auto pnode = graph_viewer_.GetNode(step.node_index);
if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the node ", step.node_index);
for (auto node_input : pnode->InputDefs()) {
if (node_input->Exists())
UseCount(node_input->Name())++;
}

for (auto node_input : pnode->ImplicitInputDefs()) {
if (node_input->Exists())
UseCount(node_input->Name())++;
}
// Identify where each output of this node should be allocated.
// This is determined by the opkernel bound to the node.
const KernelCreateInfo* kernel_create_info = nullptr;
@@ -394,31 +390,45 @@
if (!pnode->Name().empty()) errormsg << " (node " << pnode->Name() << ")";
return Status(ONNXRUNTIME, FAIL, errormsg.str());
}

auto exec_provider = execution_providers_.Get(*pnode);
if (exec_provider == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the execution provider ",
pnode->GetExecutionProviderType());
}

auto& default_allocator_info = exec_provider->GetAllocator(0, OrtMemTypeDefault)->Info();
// increment UseCount and add location information if applicable for the provided input def
auto process_input = [&graph_inputs, &exec_provider, &p_kernelDef, this](const NodeArg& input, size_t arg_idx) {
const auto& name = input.Name();
UseCount(name)++;

// If it's a graph input or outer scope node arg, set its plan.
// NOTE: Copy nodes should have already been added if a graph input is fed as input
// to nodes assigned to different providers.
if (graph_inputs.find(name) != graph_inputs.cend() ||
std::find_if(outer_scope_node_args_.cbegin(), outer_scope_node_args_.cend(),
[&name](const NodeArg* value) {
return value && value->Name() == name;
}) != outer_scope_node_args_.cend()) {
OrtValueIndex index = Index(name);
plan_.SetLocation(static_cast<size_t>(index),
exec_provider->GetAllocator(0, p_kernelDef->InputMemoryType(arg_idx))->Info());
}

return Status::OK();
};

ORT_RETURN_IF_ERROR(Node::ForEachWithIndex(pnode->InputDefs(), process_input));
ORT_RETURN_IF_ERROR(Node::ForEachWithIndex(pnode->ImplicitInputDefs(), process_input));

auto outputs = pnode->OutputDefs();
auto num_outputs = outputs.size();

for (size_t i = 0; i < num_outputs; ++i) {
auto* node_output = outputs[i];
if (!node_output->Exists()) continue;
OrtValueIndex index = Index(node_output->Name());
ProcessDef(index, node_output);
++UseCount(index);
if (strcmp(default_allocator_info.name, CPU) != 0) {
// By default, outputs of this node are allocated on the default device allocator,
// except for outputs marked for allocation in MemoryType:
auto memory_type = p_kernelDef->OutputMemoryType(i);
plan_.SetLocation(static_cast<size_t>(index), memory_type == OrtMemTypeDefault
? default_allocator_info
: exec_provider->GetAllocator(0, memory_type)->Info());
}
plan_.SetLocation(static_cast<size_t>(index), exec_provider->GetAllocator(0, p_kernelDef->OutputMemoryType(i))->Info());
}
// if sync is needed, mark allocation plan as create_fence_if_async=true
// note that the input arg may come from an execution provider (i.e. CPU) that does not support async,
onnxruntime/core/framework/feeds_fetches_manager.h (1 addition, 2 deletions)

@@ -48,9 +48,8 @@ struct FeedsFetchesInfo {
class FeedsFetchesManager {
public:
struct MLValueCopyInfo {
int allocation_device_id = 0;
OrtDevice target_device;
const IExecutionProvider* allocation_provider = nullptr;
const IExecutionProvider* copy_provider = nullptr;
};

static Status Create(const std::vector<std::string>& feed_names, const std::vector<std::string>& output_names,
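For orientation, a self-contained before/after sketch of this struct change (the types below are simplified stand-ins, not the real ORT headers): placement is now described by the allocation plan's OrtDevice rather than a bare device id plus a separate copy provider, and the copy itself is routed through the DataTransferManager (see the CopyMLValue change in utils.cc below).

```cpp
#include <cstdint>

// Simplified stand-ins for the real ORT types; illustration only.
struct DeviceStandIn { int8_t type = 0; int8_t mem_type = 0; int16_t id = 0; };
struct ProviderStandIn {};

// Before this PR: a raw device id plus both an allocating and a copying
// provider, with provider-type strings consulted elsewhere to decide copies.
struct MLValueCopyInfoBefore {
  int allocation_device_id = 0;
  const ProviderStandIn* allocation_provider = nullptr;
  const ProviderStandIn* copy_provider = nullptr;
};

// After: the planned target device carries the placement information; only
// the allocating provider remains.
struct MLValueCopyInfoAfter {
  DeviceStandIn target_device;
  const ProviderStandIn* allocation_provider = nullptr;
};
```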
onnxruntime/core/framework/graph_partitioner.cc (0 additions, 4 deletions)

@@ -175,10 +175,6 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f
//prepare the func kernel
KernelDefBuilder builder;
BuildFusedKernelDef(builder, *node);
if (node->GetExecutionProviderType() == onnxruntime::kNGraphExecutionProvider || node->GetExecutionProviderType() == onnxruntime::kNnapiExecutionProvider) {
builder.SetDefaultInputsMemoryType(OrtMemTypeCPUInput);
builder.SetDefaultOutputMemoryType(OrtMemTypeCPUOutput);
}
ORT_RETURN_IF_ERROR(fused_kernel_registry->Register(
builder, static_cast<KernelCreatePtrFn>([](const OpKernelInfo& info) -> OpKernel* { return new FunctionKernel(info); })));
}
onnxruntime/core/framework/session_state.h (4 additions, 3 deletions)

@@ -141,17 +141,18 @@ class SessionState {
* \param p_node0 Nullable
* \param kci0 Nullable
*/
NodeInfo(size_t index0, const onnxruntime::Node* p_node0, const KernelCreateInfo* kci0)
NodeInfo(size_t index0, const onnxruntime::Node* p_node0, const KernelCreateInfo* kci0, const OrtDevice& device0)
: index(index0),
p_node(p_node0),
kci(kci0) {
}
kci(kci0),
device(&device0) {}

size_t index;
// Nullable
const onnxruntime::Node* p_node = nullptr;
// Nullable
const KernelCreateInfo* kci = nullptr;
const OrtDevice* device = nullptr;
};

using NameNodeInfoMapType = std::unordered_map<std::string, std::vector<NodeInfo>>;
onnxruntime/core/framework/session_state_initializer.cc (17 additions, 3 deletions)

@@ -351,6 +351,8 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph
if (implicit_inputs && implicit_inputs->empty()) {
implicit_inputs = nullptr;
}
const auto* exec_plan = session_state.GetExecutionPlan();
const auto& name_to_id = session_state.GetOrtValueNameIdxMap();

for (auto& node : graph.Nodes()) {
// note that KernelCreateInfo may not exist for custom kernel
@@ -365,7 +367,11 @@
return Status::OK();
}

SessionState::NodeInfo node_info(index, &node, kci);
int arg_index;
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(arg.Name(), arg_index));
const auto& device = exec_plan->GetLocation(arg_index).device;

SessionState::NodeInfo node_info(index, &node, kci, device);

if (IsArgNameInInputsOutputs(arg.Name(), graph_inputs)) {
ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(arg.Name(), node_info));
Expand Down Expand Up @@ -397,8 +403,13 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph
// copy to/from CPU to go through the control flow nodes where possible/applicable.
// the processing for the subgraph where the implicit input is consumed will do the real check on whether any
// copy to a different device is required
SessionState::NodeInfo node_info(std::numeric_limits<size_t>::max(), &node, kci);
for (const auto& input_def : node_implicit_inputs) {
int arg_index;
//Question: the implicit input may not be found in this session state name to id map, but in parent session state name to id map.
//@Scott
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(input_def->Name(), arg_index));
auto& device = exec_plan->GetLocation(arg_index).device;
SessionState::NodeInfo node_info(std::numeric_limits<size_t>::max(), &node, kci, device);
ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(input_def->Name(), node_info));
}
}
@@ -413,7 +424,6 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph

auto& input_map = session_state.GetInputNodeInfoMap();
auto end_map = input_map.cend();
SessionState::NodeInfo empty_node_info(std::numeric_limits<size_t>::max(), nullptr, nullptr);

for (const auto& graph_input : graph_inputs) {
const auto& name = graph_input->Name();
Expand All @@ -422,6 +432,10 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph
// utils::CopyOneInputAcrossDevices will use the input OrtValue as is given we don't believe it's used anywhere.
LOGS(session_state.Logger(), INFO) << (graph.IsSubgraph() ? "Subgraph" : "Graph") << " input with name "
<< name << " is not used by any node.";
int arg_index;
ORT_RETURN_IF_ERROR(name_to_id.GetIdx(name, arg_index));
auto& device = exec_plan->GetLocation(arg_index).device;
SessionState::NodeInfo empty_node_info(std::numeric_limits<size_t>::max(), nullptr, nullptr, device);
ORT_RETURN_IF_ERROR(session_state.AddInputNameToNodeInfoMapping(name, empty_node_info));
}
}
onnxruntime/core/framework/utils.cc (37 additions, 79 deletions)

@@ -23,9 +23,18 @@ AllocatorPtr GetAllocator(const SessionState& session_state, const OrtAllocatorI
return session_state.GetExecutionProviders().GetAllocator(allocator_info);
}

common::Status AllocateHelper(const IExecutionProvider& execution_provider, int device_id, const Tensor& fetched_tensor,
bool ProviderIsCpuBased(const std::string& provider_type) {
return provider_type == onnxruntime::kCpuExecutionProvider ||
provider_type == onnxruntime::kMklDnnExecutionProvider ||
provider_type == onnxruntime::kNGraphExecutionProvider ||
provider_type == onnxruntime::kNupharExecutionProvider ||
provider_type == onnxruntime::kOpenVINOExecutionProvider ||
provider_type == onnxruntime::kNnapiExecutionProvider;
}

common::Status AllocateHelper(const IExecutionProvider& execution_provider, const OrtDevice& device, const Tensor& fetched_tensor,
OrtValue& output_mlvalue) {
auto allocator = execution_provider.GetAllocator(device_id, OrtMemTypeDefault);
auto allocator = execution_provider.GetAllocator(device.Id(), OrtMemTypeDefault);
if (!allocator) {
return Status(common::ONNXRUNTIME, common::FAIL, "invalid allocator");
}
@@ -62,20 +71,20 @@ static Status CopyMLValue(const DataTransferManager& data_transfer_mgr,
const FeedsFetchesManager::MLValueCopyInfo& copy_info,
const OrtValue& source_mlvalue,
OrtValue& target_mlvalue) {
if (copy_info.copy_provider == nullptr) {
if (copy_info.allocation_provider == nullptr){
target_mlvalue = source_mlvalue;
} else {
auto& source_tensor = source_mlvalue.Get<Tensor>();
return Status::OK();
}

if (!target_mlvalue.IsAllocated()) {
ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.allocation_device_id,
source_tensor, target_mlvalue));
}
auto& source_tensor = source_mlvalue.Get<Tensor>();
if (!target_mlvalue.IsAllocated()) {
ORT_RETURN_IF_ERROR(utils::AllocateHelper(*copy_info.allocation_provider, copy_info.target_device,
source_tensor, target_mlvalue));
}

Tensor* p_output_tensor = target_mlvalue.GetMutable<Tensor>();
Tensor* p_output_tensor = target_mlvalue.GetMutable<Tensor>();

ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor));
}
ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(source_tensor, *p_output_tensor));

return Status::OK();
}
@@ -86,8 +95,6 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
FeedsFetchesManager::MLValueCopyInfo& copy_info) {
needed_copy = false;

//TODO: make it configurable
const int target_device_id = 0;
std::vector<SessionState::NodeInfo> node_info_vec;
ORT_RETURN_IF_ERROR(session_state.GetInputNodeInfo(input_name, node_info_vec));

@@ -111,51 +118,23 @@
break;
}

auto& required_provider_type = GetNodeInputProviderType(node_info);
auto& input_tensor = orig_mlvalue.Get<Tensor>();
auto& input_tensor_loc = input_tensor.Location();

auto* p_input_provider = exec_providers.Get(input_tensor_loc);
if (!p_input_provider) {
p_input_provider = exec_providers.Get(onnxruntime::kCpuExecutionProvider);
ORT_ENFORCE(p_input_provider);
}

//no copy for nGraph
if (required_provider_type == onnxruntime::kNGraphExecutionProvider) {
new_mlvalue = orig_mlvalue;
break;
}

auto input_provider_type = p_input_provider->Type();
if (input_provider_type == required_provider_type && input_tensor_loc.mem_type == OrtMemTypeDefault) {
new_mlvalue = orig_mlvalue;
break;
}

// If a node requires input on cpu and input tensor is allocated with pinned memory allocator, don't do copy
if (required_provider_type == onnxruntime::kCpuExecutionProvider &&
input_tensor_loc.mem_type == OrtMemTypeCPU) {
auto& required_device = *node_info.device;
auto& input_tensor_device = orig_mlvalue.Get<Tensor>().Location().device;
if (required_device == input_tensor_device) {
// No copy needed for same device.
new_mlvalue = orig_mlvalue;
break;
}

auto& required_provider_type = GetNodeInputProviderType(node_info);
auto* required_provider = exec_providers.Get(required_provider_type);
ORT_ENFORCE(required_provider);

auto* p_copy_provider = (required_provider_type != onnxruntime::kCpuExecutionProvider)
? required_provider
: p_input_provider;

copy_info.allocation_device_id = target_device_id;
copy_info.target_device = required_device;
copy_info.allocation_provider = required_provider;
copy_info.copy_provider = p_copy_provider;

ORT_RETURN_IF_ERROR(CopyMLValue(session_state.GetDataTransferMgr(), copy_info, orig_mlvalue, new_mlvalue));

needed_copy = true;

// } loop of node_info_vec
} while (false);

return Status::OK();
@@ -344,43 +323,26 @@ static common::Status CopyOutputsAcrossDevices(const SessionState& session_state
continue;
}

auto& fetched_tensor = fetched_mlvalue.Get<Tensor>();
auto& fetched_tensor_location = fetched_tensor.Location();
auto* p_fetched_provider = execution_providers.Get(fetched_tensor_location);
if (!p_fetched_provider) {
p_fetched_provider = cpu_execution_provider;
}

auto fetched_provider_type = p_fetched_provider->Type();
auto& output_mlvalue = user_fetches[idx];

const IExecutionProvider* p_output_provider = nullptr;

auto target_device = OrtDevice();
auto& output_mlvalue = user_fetches[idx];
if (output_mlvalue.IsAllocated()) {
Tensor* p_output_tensor = output_mlvalue.GetMutable<Tensor>();
target_device = p_output_tensor->Location().device;
p_output_provider = execution_providers.Get(p_output_tensor->Location());
}
auto fetch_result_device = fetched_mlvalue.Get<Tensor>().Location().device;
if (target_device == fetch_result_device) {
user_fetches[idx] = fetched_mlvalue;
continue;
}

if (!p_output_provider) {
p_output_provider = cpu_execution_provider;
}

auto output_provider_type = p_output_provider->Type();

if (fetched_provider_type == output_provider_type ||
(p_output_provider == cpu_execution_provider && fetched_tensor_location.mem_type == OrtMemTypeCPUOutput)) {
user_fetches[idx] = fetched_mlvalue;
continue;
}

needed_copy = true;

auto* p_copy_provider = (fetched_provider_type != onnxruntime::kCpuExecutionProvider)
? p_fetched_provider
: p_output_provider;

const int device_id = 0; // TODO: As per comment in the copy input code, make this configurable.
FeedsFetchesManager::MLValueCopyInfo copy_info{device_id, p_output_provider, p_copy_provider};
FeedsFetchesManager::MLValueCopyInfo copy_info{target_device, p_output_provider};
ORT_RETURN_IF_ERROR(CopyMLValue(session_state.GetDataTransferMgr(), copy_info, fetched_mlvalue, output_mlvalue));

if (copiers) {
@@ -410,11 +372,7 @@ static common::Status CachedCopyOutputsAcrossDevices(

static DeviceCopyCheck CheckExecutionProviders(const ExecutionProviders& execution_providers) {
for (const auto& execution_provider : execution_providers) {
if (execution_provider->Type() != onnxruntime::kCpuExecutionProvider &&
execution_provider->Type() != onnxruntime::kMklDnnExecutionProvider &&
execution_provider->Type() != onnxruntime::kNGraphExecutionProvider &&
execution_provider->Type() != onnxruntime::kNupharExecutionProvider &&
execution_provider->Type() != onnxruntime::kOpenVINOExecutionProvider) {
if (!ProviderIsCpuBased(execution_provider->Type())) {
return DeviceCopyCheck::Unknown;
}
}
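Taken together, the utils.cc changes replace string comparisons of provider types with direct comparisons of planned OrtDevice values. A self-contained toy model of the new input-copy decision (simplified stand-ins, not the real ORT API):

```cpp
#include <iostream>
#include <optional>

// Simplified stand-in for OrtDevice; 0 = CPU, 1 = GPU in this toy model.
struct Device { int type = 0; int mem_type = 0; int id = 0; };

bool operator==(const Device& a, const Device& b) {
  return a.type == b.type && a.mem_type == b.mem_type && a.id == b.id;
}

struct CopyInfo { Device target_device; };  // where the consumer needs the value

// Mirrors the refactored decision: alias the value when the producing and
// consuming devices already match, otherwise plan a copy to the target.
std::optional<CopyInfo> PlanInputCopy(const Device& required, const Device& actual) {
  if (required == actual) return std::nullopt;  // no copy needed
  return CopyInfo{required};
}

int main() {
  const Device cpu{0, 0, 0}, gpu{1, 0, 0};
  std::cout << "CPU tensor into CPU node, copy? "
            << PlanInputCopy(cpu, cpu).has_value() << '\n';  // 0
  std::cout << "CPU tensor into GPU node, copy? "
            << PlanInputCopy(gpu, cpu).has_value() << '\n';  // 1
}
```

This is the same aliasing shortcut CopyOneInputAcrossDevices and CopyOutputsAcrossDevices now take when the required and actual devices compare equal.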
onnxruntime/core/providers/cuda/cuda_execution_provider.cc (5 additions, 0 deletions)

@@ -72,6 +72,11 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in
DeviceAllocatorRegistrationInfo pinned_allocator_info(
{OrtMemTypeCPUOutput, [](int) { return std::make_unique<CUDAPinnedAllocator>(0, CUDA_PINNED); }, std::numeric_limits<size_t>::max()});
InsertAllocator(CreateAllocator(pinned_allocator_info, device_id_));

// TODO: this is actually used for the cuda kernels which explicitly ask for inputs from CPU.
// This will be refactored/removed when allocator and execution provider are decoupled.
DeviceAllocatorRegistrationInfo cpu_allocator_info({OrtMemTypeCPUInput, [](int) { return std::make_unique<CPUAllocator>(std::make_unique<OrtAllocatorInfo>("CUDA_CPU", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUInput)); }, std::numeric_limits<size_t>::max()});
InsertAllocator(CreateAllocator(cpu_allocator_info));
}

CUDAExecutionProvider::~CUDAExecutionProvider() {
onnxruntime/test/providers/provider_test_utils.cc (0 additions, 1 deletion)

@@ -440,7 +440,6 @@ void OpTester::Run(ExpectResult expect_result,
std::unordered_map<std::string, OrtValue> feeds;
std::vector<std::string> output_names;
FillFeedsAndOutputNames(feeds, output_names);

// Run the model
SessionOptions so;
so.session_logid = op_;