@@ -119,6 +119,19 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
119119 const AllocateNode *alloc{nullptr };
120120 };
121121
/*!
 * \brief Per-statement attribute recorded while the access pattern is built.
 *
 * Entries are stored in stmt_attrs_, keyed by the statement's Object pointer.
 */
struct StmtAttr {
  // Value of scope_level_ at the time the statement was recorded.
  size_t level = 0;
};
126+
127+ void UpdateStmtAttr (const Object *stmt, size_t level) {
128+ if (stmt_attrs_.find (stmt) == stmt_attrs_.end ()) {
129+ stmt_attrs_[stmt] = StmtAttr{level};
130+ } else {
131+ stmt_attrs_[stmt].level = level;
132+ }
133+ }
134+
122135 void VisitStmt_ (const AllocateNode *op) final {
123136 size_t level = scope_.size ();
124137 const VarNode *buf = op->buffer_var .get ();
@@ -137,13 +150,14 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
137150 if (it != alloc_info_.end () && it->second .alloc ) {
138151 ICHECK_LT (it->second .level , scope_.size ());
139152 if (IsAppropriateSharedMemory (GetRef<Var>(buf))) {
140- scope_[it-> second . level ].touched .push_back (buf);
153+ scope_[scope_. size () - 1 ].touched .push_back (buf);
141154 }
142155 }
143156 StmtEntry e = scope_.back ();
144157 scope_.pop_back ();
145158 if (e.touched .size () != 0 ) {
146159 e.stmt = op;
160+ UpdateStmtAttr (op, scope_level_);
147161 linear_seq_.push_back (e);
148162 }
149163 }
@@ -156,6 +170,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
156170 scope_.pop_back ();
157171 if (e.touched .size () != 0 ) {
158172 e.stmt = op;
173+ UpdateStmtAttr (op, scope_level_);
159174 linear_seq_.push_back (e);
160175 }
161176 }
@@ -169,7 +184,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
169184 ICHECK_LT (it->second .level , scope_.size ())
170185 << " Load memory in places other than store." ;
171186 if (IsAppropriateSharedMemory (GetRef<Var>(buf))) {
172- scope_[it-> second . level ].touched .push_back (buf);
187+ scope_[scope_. size () - 1 ].touched .push_back (buf);
173188 }
174189 }
175190 }
@@ -180,7 +195,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
180195 if (it != alloc_info_.end () && it->second .alloc ) {
181196 ICHECK_LT (it->second .level , scope_.size ());
182197 if (IsAppropriateSharedMemory (GetRef<Var>(buf))) {
183- scope_[it-> second . level ].touched .push_back (buf);
198+ scope_[scope_. size () - 1 ].touched .push_back (buf);
184199 }
185200 }
186201 }
@@ -189,6 +204,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
189204 scope_.push_back (StmtEntry ());
190205 StmtEntry e;
191206 e.stmt = op;
207+ UpdateStmtAttr (op, scope_level_);
192208 int64_t begin_index = static_cast <int64_t >(linear_seq_.size ());
193209 // before scope.
194210 linear_seq_.push_back (e);
@@ -226,7 +242,15 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
226242
227243 void VisitStmt_ (const IfThenElseNode *op) final { VisitNewScope (op); }
228244
229- void VisitStmt_ (const ForNode *op) final { VisitNewScope (op); }
245+ void VisitStmt_ (const ForNode *op) final {
246+ if (op->body ->IsInstance <SeqStmtNode>()) {
247+ scope_level_++;
248+ VisitNewScope (op);
249+ scope_level_--;
250+ } else {
251+ VisitNewScope (op);
252+ }
253+ }
230254
231255 void VisitStmt_ (const WhileNode *op) final { VisitNewScope (op); }
232256
@@ -236,6 +260,8 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
236260 std::vector<StmtEntry> linear_seq_;
237261 // The storage scope of each buffer
238262 std::unordered_map<const VarNode *, AllocEntry> alloc_info_;
263+ // The attribute of each statement
264+ std::unordered_map<const Object *, StmtAttr> stmt_attrs_;
239265
240266private:
241267 // Wrapper function to determine if the shared memory allocation for a
@@ -251,6 +277,8 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
251277 bool in_thread_env_{false };
252278 // The scope stack.
253279 std::vector<StmtEntry> scope_;
280+ // Nesting depth of the enclosing For loops whose body is a SeqStmt.
281+ size_t scope_level_{0 };
254282};
255283
256284/* !
@@ -279,8 +307,8 @@ class SharedMemoryRewriter : public StmtExprMutator {
279307 bool verbose = false ) {
280308 SharedMemLinearAccessPatternFinder finder (is_dynamic, verbose);
281309 finder (stmt);
282- this ->LivenessAnalysis (finder.linear_seq_ );
283- this ->PlanMemory (finder.linear_seq_ );
310+ this ->LivenessAnalysis (finder.linear_seq_ , finder. stmt_attrs_ );
311+ this ->PlanMemory (finder.linear_seq_ , finder. stmt_attrs_ );
284312 }
285313
286314private:
@@ -491,6 +519,7 @@ class SharedMemoryRewriter : public StmtExprMutator {
491519 }
492520
493521 using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry;
522+ using StmtAttr = SharedMemLinearAccessPatternFinder::StmtAttr;
494523 struct StorageEntry {
495524 // The constant size of the buffer in bits, only used if it is constant
496525 uint64_t const_nbits{0 };
@@ -515,7 +544,9 @@ class SharedMemoryRewriter : public StmtExprMutator {
515544 * \brief Liveness analysis to find gen and kill point of each variable.
516545 * \param seq the linear pattern of storage access
517546 */
518- void LivenessAnalysis (const std::vector<StmtEntry> &seq) {
547+ void LivenessAnalysis (
548+ const std::vector<StmtEntry> &seq,
549+ const std::unordered_map<const Object *, StmtAttr> &stmt_attrs) {
519550 // find kill point, do a reverse linear scan.
520551 std::unordered_set<const VarNode *> touched;
521552 for (size_t i = seq.size (); i != 0 ; --i) {
@@ -543,17 +574,174 @@ class SharedMemoryRewriter : public StmtExprMutator {
543574 }
544575
545576 if (verbose_) {
546- LOG (DEBUG) << " Liveness Analysis Results for "
577+ std::vector<const Object *> stmt_keys;
578+ for (const auto &stmt_entry : seq) {
579+ auto stmt = stmt_entry.stmt ;
580+ if (std::find (stmt_keys.begin (), stmt_keys.end (), stmt) ==
581+ stmt_keys.end ()) {
582+ stmt_keys.push_back (stmt);
583+ }
584+ }
585+ LOG (DEBUG) << " Before reorder kill points, Liveness Analysis Results for "
547586 << (is_dynamic_ ? " Dynamic" : " Static" ) << " Shared Memory:" ;
548- for (const auto &pair : event_map_) {
549- const Object *stmt_obj = pair.first ;
550- const EventEntry &entry = pair.second ;
587+ for (const auto &stmt_key : stmt_keys) {
588+ auto it = event_map_.find (stmt_key);
589+ if (it == event_map_.end ())
590+ continue ;
591+
592+ const EventEntry &entry = it->second ;
593+ if (entry.gen .empty () && entry.kill .empty ())
594+ continue ;
595+ ICHECK (stmt_attrs.count (stmt_key))
596+ << " stmt_key = " << stmt_key->GetTypeKey ();
597+ auto level = stmt_attrs.at (stmt_key).level ;
598+ LOG (DEBUG) << " Statement: " << stmt_key->GetTypeKey ()
599+ << " (scope_level: " << level << " )" ;
600+
601+ std::stringstream gen_vars_ss;
602+ bool x_generated = false ;
603+ for (const VarNode *var : entry.gen ) {
604+ gen_vars_ss << var->name_hint << " " ;
605+ if (var->name_hint == " x" ) {
606+ x_generated = true ;
607+ }
608+ }
609+ if (!entry.gen .empty ()) {
610+ std::string gen_log_msg = " GEN: " + gen_vars_ss.str ();
611+ if (x_generated) {
612+ gen_log_msg += " <-- Buffer 'x' generated" ;
613+ }
614+ LOG (DEBUG) << gen_log_msg;
615+ }
616+
617+ std::stringstream kill_vars_ss;
618+ bool x_killed = false ;
619+ for (const VarNode *var : entry.kill ) {
620+ kill_vars_ss << var->name_hint << " " ;
621+ if (var->name_hint == " x" ) {
622+ x_killed = true ;
623+ }
624+ }
625+ if (!entry.kill .empty ()) {
626+ std::string kill_log_msg = " KILL: " + kill_vars_ss.str ();
627+ if (x_killed) {
628+ kill_log_msg += " <-- Buffer 'x' killed" ;
629+ }
630+ LOG (DEBUG) << kill_log_msg;
631+ }
632+ }
633+ LOG (DEBUG) << " End of Liveness Analysis Results." ;
634+ }
635+
636+ // Reorder kill points:
637+ // For each buffer, if its kill statement is at a deeper scope level than
638+ // its gen statement, we need to move the kill point to the end of the gen
639+ // statement's scope level. This ensures proper memory deallocation at the
640+ // right scope boundary.
641+ std::vector<StmtEntry> gen_kill_seq;
642+ for (const auto &stmt_entry : seq) {
643+ // Keep only the statements that have at least one gen or kill event.
644+ if (event_map_[stmt_entry.stmt ].gen .size () > 0 ||
645+ event_map_[stmt_entry.stmt ].kill .size () > 0 ) {
646+ gen_kill_seq.push_back (stmt_entry);
647+ }
648+ }
551649
552- if (entry.gen .empty () && entry.kill .empty ()) {
553- continue ; // Skip statements with no gen/kill events for brevity
650+ for (auto &event_pair : event_map_) {
651+ const Object *stmt = event_pair.first ;
652+ EventEntry &event = event_pair.second ;
653+
654+ // Skip if no kill points to process
655+ if (event.kill .empty ())
656+ continue ;
657+
658+ // Get scope level of current statement
659+ ICHECK (stmt_attrs.count (stmt));
660+ int kill_level = stmt_attrs.at (stmt).level ;
661+
662+ std::unordered_set<const VarNode *> visited_buffers;
663+
664+ // For each killed buffer, find its gen statement and check scope levels
665+ for (auto it = event.kill .begin (); it != event.kill .end ();) {
666+ const VarNode *buffer = *it;
667+ bool found_gen = false ;
668+ int gen_level = 0 ;
669+
670+ // Find the gen statement for this buffer
671+ for (const auto &gen_pair : event_map_) {
672+ const auto &gen_event = gen_pair.second ;
673+ if (std::find (gen_event.gen .begin (), gen_event.gen .end (), buffer) !=
674+ gen_event.gen .end ()) {
675+ found_gen = true ;
676+ gen_level = stmt_attrs.at (gen_pair.first ).level ;
677+ break ;
678+ }
679+ }
680+
681+ if (found_gen && kill_level > gen_level) {
682+ if (visited_buffers.count (buffer)) {
683+ ++it;
684+ continue ;
685+ }
686+ // Need to move kill point - remove from current event
687+ it = event.kill .erase (it);
688+
689+ // Find the last statement at gen_level and add kill point there
690+ // Find the last statement at gen_level in the sequence
691+ const Object *last_stmt_at_level = nullptr ;
692+ auto stmt_it = gen_kill_seq.begin ();
693+ for (; stmt_it != gen_kill_seq.end (); ++stmt_it) {
694+ if (stmt_it->stmt == stmt) {
695+ break ;
696+ }
697+ }
698+ // start from current statement and find the last statement at
699+ // gen_level
700+
701+ for (; stmt_it != gen_kill_seq.end (); ++stmt_it) {
702+ // Check if next statement has different level
703+ auto next_it = stmt_it + 1 ;
704+ if (next_it == gen_kill_seq.end () ||
705+ stmt_attrs.at (next_it->stmt ).level == gen_level) {
706+ last_stmt_at_level = stmt_it->stmt ;
707+ break ;
708+ }
709+ }
710+ if (last_stmt_at_level) {
711+ event_map_[last_stmt_at_level].kill .push_back (buffer);
712+ visited_buffers.insert (buffer);
713+ }
714+ } else {
715+ ++it;
554716 }
717+ }
718+ }
555719
556- LOG (DEBUG) << " Statement: " << stmt_obj->GetTypeKey ();
720+ std::vector<const Object *> stmt_keys;
721+ for (const auto &stmt_entry : seq) {
722+ auto stmt = stmt_entry.stmt ;
723+ if (std::find (stmt_keys.begin (), stmt_keys.end (), stmt) ==
724+ stmt_keys.end ()) {
725+ stmt_keys.push_back (stmt);
726+ }
727+ }
728+
729+ if (verbose_) {
730+ LOG (DEBUG) << " Liveness Analysis Results for "
731+ << (is_dynamic_ ? " Dynamic" : " Static" ) << " Shared Memory:" ;
732+ for (const auto &stmt_key : stmt_keys) {
733+ auto it = event_map_.find (stmt_key);
734+ if (it == event_map_.end ())
735+ continue ;
736+
737+ const EventEntry &entry = it->second ;
738+ if (entry.gen .empty () && entry.kill .empty ())
739+ continue ;
740+ ICHECK (stmt_attrs.count (stmt_key))
741+ << " stmt_key = " << stmt_key->GetTypeKey ();
742+ auto level = stmt_attrs.at (stmt_key).level ;
743+ LOG (DEBUG) << " Statement: " << stmt_key->GetTypeKey ()
744+ << " (scope_level: " << level << " )" ;
557745
558746 std::stringstream gen_vars_ss;
559747 bool x_generated = false ;
@@ -596,7 +784,9 @@ class SharedMemoryRewriter : public StmtExprMutator {
596784 * \param seq the linear pattern of storage access
597785 * \param alloc_info
598786 */
599- void PlanMemory (const std::vector<StmtEntry> &seq) {
787+ void
788+ PlanMemory (const std::vector<StmtEntry> &seq,
789+ const std::unordered_map<const Object *, StmtAttr> &stmt_attrs) {
600790 std::unordered_set<const VarNode *> inplace_flag;
601791
602792 for (size_t i = 0 ; i < seq.size (); ++i) {
0 commit comments