@@ -20,6 +20,7 @@ namespace Lotus {
2020class PlannerImpl {
2121 private:
2222 const SessionState* p_session_state_;
23+ const ISequentialPlannerContext* p_context_;
2324 SequentialExecutionPlan* plan_;
2425
2526 // MLValueInfo: Auxiliary information about an MLValue used only during plan-generation:
@@ -171,23 +172,23 @@ class PlannerImpl {
171172
172173 bool SameSize (const LotusIR::NodeArg& arg1, const LotusIR::NodeArg& arg2) {
173174 if ((!arg1.Exists ()) || (!arg2.Exists ())) return false ;
174- auto p_shape1 = arg1. Shape ( );
175- auto p_shape2 = arg2. Shape ( );
175+ auto p_shape1 = p_context_-> GetShape (arg1 );
176+ auto p_shape2 = p_context_-> GetShape (arg2 );
176177 // If the shapes are unknown, we conservatively assume they may be of different size.
177178 if ((nullptr == p_shape1) || (nullptr == p_shape2)) return false ;
178179 return SameSize (*p_shape1, arg1.Type (), *p_shape2, arg2.Type ());
179180 }
180181
181182 // Find if freelist contains a buffer of the same size as output_arg
182183 bool FindReusableTensor (const LotusIR::NodeArg& output_arg, MLValueIndex* reusable_tensor) {
183- auto p_required_buffer_shape = output_arg. Shape ( );
184+ auto p_required_buffer_shape = p_context_-> GetShape (output_arg );
184185 if (nullptr == p_required_buffer_shape) return false ;
185186 auto required_buffer_type = output_arg.Type ();
186187
187188 for (auto it = freelist_.begin (); it != freelist_.end (); ++it) {
188189 auto reusable = it->ml_value ;
189190 auto p_node_arg = ml_value_info_.at (reusable).p_def_site ;
190- auto p_available_buffer_shape = p_node_arg-> Shape ( );
191+ auto p_available_buffer_shape = p_context_-> GetShape (*p_node_arg );
191192 if (nullptr != p_available_buffer_shape) {
192193 auto available_buffer_type = p_node_arg->Type ();
193194 if (SameSize (*p_available_buffer_shape, available_buffer_type, *p_required_buffer_shape, required_buffer_type)) {
@@ -341,8 +342,9 @@ class PlannerImpl {
341342 }
342343
343344 public:
344- Status CreatePlan (const SessionState& session_state, SequentialExecutionPlan* plan) {
345+ Status CreatePlan (const SessionState& session_state, const ISequentialPlannerContext& context, SequentialExecutionPlan* plan) {
345346 p_session_state_ = &session_state;
347+ p_context_ = &context;
346348 plan_ = plan;
347349
348350 auto p_graph = p_session_state_->GetGraph ();
@@ -375,9 +377,10 @@ class PlannerImpl {
375377 }
376378};
377379
378- Status SequentialPlanner::CreatePlan (const SessionState& session_state, SequentialExecutionPlan* plan) {
380+ Status SequentialPlanner::CreatePlan (const SessionState& session_state, const ISequentialPlannerContext& context,
381+ SequentialExecutionPlan* plan) {
379382 PlannerImpl planner;
380- return planner.CreatePlan (session_state, plan);
383+ return planner.CreatePlan (session_state, context, plan);
381384}
382385
383386Status AllocationPlanner::CreatePlan (AllocationPlannerType allocation_planner_type,
0 commit comments