Skip to content

Commit f4cdc80

Browse files
committed
Merged PR 1227: memory allocation planner testing infrastructure
(a) Introduced PlannerContext to decouple the planner from how shape-inference results are stored. This also allows testing the planner with simulated shape-inference results (since shape inference will not be ready for a while). (b) Introduced testing infrastructure for the planner (to build multi-node graphs). Related work items: #141
1 parent e9da32c commit f4cdc80

File tree

3 files changed

+269
-188
lines changed

3 files changed

+269
-188
lines changed

lotus/core/framework/allocation_planner.cc

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ namespace Lotus {
2020
class PlannerImpl {
2121
private:
2222
const SessionState* p_session_state_;
23+
const ISequentialPlannerContext* p_context_;
2324
SequentialExecutionPlan* plan_;
2425

2526
// MLValueInfo: Auxiliary information about an MLValue used only during plan-generation:
@@ -171,23 +172,23 @@ class PlannerImpl {
171172

172173
// Returns true only when both node-args exist, both have a shape known to the planner
// context, and the shape/type overload of SameSize reports them as equal in size.
bool SameSize(const LotusIR::NodeArg& arg1, const LotusIR::NodeArg& arg2) {
  // A missing argument can never be matched against anything.
  if (!(arg1.Exists() && arg2.Exists())) {
    return false;
  }

  const auto* lhs_shape = p_context_->GetShape(arg1);
  const auto* rhs_shape = p_context_->GetShape(arg2);

  // An unknown shape is treated conservatively: the sizes may differ.
  if (lhs_shape == nullptr || rhs_shape == nullptr) {
    return false;
  }

  return SameSize(*lhs_shape, arg1.Type(), *rhs_shape, arg2.Type());
}
180181

181182
// Find if freelist contains a buffer of the same size as output_arg
182183
bool FindReusableTensor(const LotusIR::NodeArg& output_arg, MLValueIndex* reusable_tensor) {
183-
auto p_required_buffer_shape = output_arg.Shape();
184+
auto p_required_buffer_shape = p_context_->GetShape(output_arg);
184185
if (nullptr == p_required_buffer_shape) return false;
185186
auto required_buffer_type = output_arg.Type();
186187

187188
for (auto it = freelist_.begin(); it != freelist_.end(); ++it) {
188189
auto reusable = it->ml_value;
189190
auto p_node_arg = ml_value_info_.at(reusable).p_def_site;
190-
auto p_available_buffer_shape = p_node_arg->Shape();
191+
auto p_available_buffer_shape = p_context_->GetShape(*p_node_arg);
191192
if (nullptr != p_available_buffer_shape) {
192193
auto available_buffer_type = p_node_arg->Type();
193194
if (SameSize(*p_available_buffer_shape, available_buffer_type, *p_required_buffer_shape, required_buffer_type)) {
@@ -341,8 +342,9 @@ class PlannerImpl {
341342
}
342343

343344
public:
344-
Status CreatePlan(const SessionState& session_state, SequentialExecutionPlan* plan) {
345+
Status CreatePlan(const SessionState& session_state, const ISequentialPlannerContext& context, SequentialExecutionPlan* plan) {
345346
p_session_state_ = &session_state;
347+
p_context_ = &context;
346348
plan_ = plan;
347349

348350
auto p_graph = p_session_state_->GetGraph();
@@ -375,9 +377,10 @@ class PlannerImpl {
375377
}
376378
};
377379

378-
Status SequentialPlanner::CreatePlan(const SessionState& session_state, SequentialExecutionPlan* plan) {
380+
// Creates a fresh PlannerImpl and forwards the session state, the planner context and
// the output plan to it, returning the Status it produces.
Status SequentialPlanner::CreatePlan(const SessionState& session_state,
                                     const ISequentialPlannerContext& context,
                                     SequentialExecutionPlan* plan) {
  PlannerImpl impl;
  return impl.CreatePlan(session_state, context, plan);
}
382385

383386
Status AllocationPlanner::CreatePlan(AllocationPlannerType allocation_planner_type,

lotus/core/framework/allocation_planner.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,35 @@ struct SequentialExecutionPlan {
9494
std::vector<MLValueIndex> to_be_freed;
9595
};
9696

97+
// ISequentialPlannerContext abstracts how the planner accesses information (such as inferred shape)
98+
// to do the planning.
99+
class ISequentialPlannerContext {
100+
public:
101+
virtual const TensorShapeProto* GetShape(const LotusIR::NodeArg& arg) const = 0;
102+
};
103+
104+
class SequentialPlannerContext : public ISequentialPlannerContext {
105+
public:
106+
virtual const TensorShapeProto* GetShape(const LotusIR::NodeArg& arg) const override {
107+
(arg);
108+
// Once shape-inference is in place, we can return the result of shape-inference
109+
return nullptr;
110+
}
111+
};
112+
97113
class SequentialPlanner {
98114
public:
115+
// This API allows user to provide a custom planner context. Currently, this is used
116+
// primarily for testing.
99117
static Status CreatePlan(const SessionState& session_state,
118+
const ISequentialPlannerContext& context,
100119
SequentialExecutionPlan* plan);
120+
121+
// This uses a standard planner context and is meant to be the primary API for creating a plan.
122+
static Status CreatePlan(const SessionState& session_state, SequentialExecutionPlan* plan) {
123+
SequentialPlannerContext context;
124+
return CreatePlan(session_state, context, plan);
125+
}
101126
};
102127

103128
/*

0 commit comments

Comments
 (0)