Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
linkerzhang committed Aug 6, 2019
1 parent 26e69ad commit 7d03eaf
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions onnxruntime/core/framework/allocation_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,10 @@ class PlannerImpl {

Status ComputeUseCounts() {
// Note: for every ml-value, its definition must appear before all its uses in a topological sort of a valid model
std::unordered_set<std::string> graph_inputs;
for (auto& graph_input : graph_viewer_.GetInputsIncludingInitializers()) {
graph_inputs.insert(graph_input->Name());
}

for (auto graph_input : graph_viewer_.GetInputs()) {
OrtValueIndex index = Index(graph_input->Name());
Expand All @@ -368,15 +372,7 @@ class PlannerImpl {
for (SequentialExecutionPlan::NodeExecutionPlan& step : plan_.execution_plan) {
auto pnode = graph_viewer_.GetNode(step.node_index);
if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the node ", step.node_index);
for (auto node_input : pnode->InputDefs()) {
if (node_input->Exists())
UseCount(node_input->Name())++;
}

for (auto node_input : pnode->ImplicitInputDefs()) {
if (node_input->Exists())
UseCount(node_input->Name())++;
}
// Identify where each output of this node should be allocated.
// This is determined by the opkernel bound to the node.
const KernelCreateInfo* kernel_create_info = nullptr;
Expand All @@ -391,17 +387,34 @@ class PlannerImpl {
if (!pnode->Name().empty()) errormsg << " (node " << pnode->Name() << ")";
return Status(ONNXRUNTIME, FAIL, errormsg.str());
}

auto exec_provider = execution_providers_.Get(*pnode);
if (exec_provider == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Can not find the execution provider ",
pnode->GetExecutionProviderType());
}

auto inputs = pnode->InputDefs();
auto num_inputs = inputs.size();
for (size_t i = 0; i < num_inputs; ++i) {
if (inputs[i]->Exists()) {
UseCount(inputs[i]->Name())++;
if (graph_inputs.end() != graph_inputs.find(inputs[i]->Name())) {
// If it's a graph input, set its plan.
// NOTE: Copy nodes should have already been added if a graph input is fed as inputs of nodes assigned to different providers.
OrtValueIndex index = Index(inputs[i]->Name());
plan_.SetLocation(static_cast<size_t>(index), exec_provider->GetAllocator(0, p_kernelDef->InputMemoryType(i))->Info());
}
}
}

for (auto node_input : pnode->ImplicitInputDefs()) {
if (node_input->Exists())
UseCount(node_input->Name())++;
}

auto& default_allocator_info = exec_provider->GetAllocator(0, OrtMemTypeDefault)->Info();
auto outputs = pnode->OutputDefs();
auto num_outputs = outputs.size();

for (size_t i = 0; i < num_outputs; ++i) {
auto* node_output = outputs[i];
if (!node_output->Exists()) continue;
Expand Down

0 comments on commit 7d03eaf

Please sign in to comment.