Skip to content

Commit abf2ee0

Browse files
committed
Merged PR 1137: Merge rama/optarg to master
Extend static memory planner to support optional arguments in nodes.
1 parent bb3737f commit abf2ee0

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

lotus/core/framework/allocation_planner.cc

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
#include "core/graph/utils.h"
88
#include "core/framework/data_types.h"
99

10-
/* TODO: Need to address
11-
- placement of constant tensors (e.g., weights) in shared
12-
- ensure output buffers are allocated in appropriate arena
13-
- handle optional inputs
10+
/*
11+
TODO:
12+
- (Not for milestone 1)
13+
- handle different types of devices:
14+
- identify placement of all tensors and ml-values
15+
- insert copies between different devices as required
1416
*/
1517

1618
namespace Lotus {
@@ -210,10 +212,13 @@ class PlannerImpl {
210212

211213
for (SequentialExecutionPlan::NodeExecutionPlan& step : execution_plan) {
212214
auto pnode = graph.GetNode(step.node_index);
213-
for (auto node_input : pnode->InputDefs())
214-
UseCount(node_input->Name())++;
215+
for (auto node_input : pnode->InputDefs()) {
216+
if (node_input->Exists())
217+
UseCount(node_input->Name())++;
218+
}
215219
for (auto node_output : pnode->OutputDefs())
216-
ProcessDef(node_output);
220+
if (node_output->Exists())
221+
ProcessDef(node_output);
217222
}
218223

219224
for (auto graph_output : graph.GetOutputs()) {
@@ -249,6 +254,7 @@ class PlannerImpl {
249254
// determine allocation for outputs of pnode
250255
int output_arg_num = 0;
251256
for (auto node_output : pnode->OutputDefs()) {
257+
if (!node_output->Exists()) continue;
252258
auto current = index(node_output->Name());
253259
AllocPlan(current).value_type = GetMLDataType(*node_output);
254260
MLValueIndex reused;
@@ -269,17 +275,21 @@ class PlannerImpl {
269275
}
270276
// determine if inputs of *pnode can be freed:
271277
for (auto node_input : pnode->InputDefs()) {
272-
auto& sym = node_input->Name();
273-
auto original = Buffer(index(sym));
274-
if (0 == --UseCount(original))
275-
freelist_.push_front(FreeBufferInfo(original, program_counter));
278+
if (node_input->Exists()) {
279+
auto& sym = node_input->Name();
280+
auto original = Buffer(index(sym));
281+
if (0 == --UseCount(original))
282+
freelist_.push_front(FreeBufferInfo(original, program_counter));
283+
}
276284
}
277285
// determine if any outputs of *pnode are unused and can be freed:
278286
for (auto node_output : pnode->OutputDefs()) {
279-
auto& sym = node_output->Name();
280-
auto original = Buffer(index(sym));
281-
if (0 == UseCount(original))
282-
freelist_.push_front(FreeBufferInfo(original, program_counter));
287+
if (node_output->Exists()) {
288+
auto& sym = node_output->Name();
289+
auto original = Buffer(index(sym));
290+
if (0 == UseCount(original))
291+
freelist_.push_front(FreeBufferInfo(original, program_counter));
292+
}
283293
}
284294
}
285295
}

0 commit comments

Comments
 (0)