@@ -23,69 +23,97 @@ using arith::IRVisitorWithAnalyzer;
2323
2424enum class Role { kConsumer , kProducer , kBoth };
2525
26- class TMAFinder : public StmtExprVisitor {
26+ class ProducerBufferDetector : public StmtExprVisitor {
2727public:
28- void clear () { has_tma_load_ = false ; }
28+ ProducerBufferDetector (
29+ std::unordered_set<const BufferNode *> cur_producer_buffers)
30+ : cur_producer_buffers_(cur_producer_buffers) {}
31+
32+ void clear () { has_producer_buffer_ = false ; }
2933
3034 void VisitExpr_ (const CallNode *call) final {
3135 if (call->op .same_as (tma_load ()) || call->op .same_as (tma_load_im2col ())) {
32- has_tma_load_ = true ;
36+ has_producer_buffer_ = true ;
3337 }
38+ StmtExprVisitor::VisitExpr_ (call);
3439 }
3540
36- bool has_tma_load_ = false ;
41+ void VisitExpr_ (const BufferLoadNode *op) final {
42+ if (cur_producer_buffers_.count (op->buffer .get ())) {
43+ has_producer_buffer_ = true ;
44+ }
45+ StmtExprVisitor::VisitExpr_ (op);
46+ }
47+
48+ bool has_producer_buffer_ = false ;
49+ std::unordered_set<const BufferNode *> cur_producer_buffers_;
3750};
3851
3952class ProducerUsedBufferFinder : public StmtExprVisitor {
4053public:
4154 auto FindProducerusedBuffer (Stmt stmt) {
42- VisitStmt (stmt);
43- return used_in_producer_cond_;
55+ producer_buffers_.clear ();
56+ std::unordered_set<const BufferNode *> last_producer_buffers_;
57+ for (;;) {
58+ VisitStmt (stmt);
59+ if (producer_buffers_ == last_producer_buffers_) {
60+ break ;
61+ }
62+ last_producer_buffers_ = producer_buffers_;
63+ }
64+ return producer_buffers_;
4465 }
4566
4667 void InsertBuffer (const PrimExpr &expr) {
4768 // Find the buffer that is used in the condition
4869 VarUseDefAnalyzer usage (Array<Var>{});
4970 usage (expr);
5071 for (const auto &buffer : usage.buffer_use_count_ ) {
51- used_in_producer_cond_ .insert (buffer.first );
72+ producer_buffers_ .insert (buffer.first );
5273 }
5374 }
5475
5576 void VisitStmt_ (const IfThenElseNode *op) final {
56- TMAFinder tma_finder ;
57- tma_finder (op->then_case );
77+ ProducerBufferDetector producer_buffer_detector (producer_buffers_) ;
78+ producer_buffer_detector (op->then_case );
5879 if (op->else_case .defined ()) {
59- tma_finder (op->else_case .value ());
80+ producer_buffer_detector (op->else_case .value ());
6081 }
61- if (tma_finder. has_tma_load_ ) {
82+ if (producer_buffer_detector. has_producer_buffer_ ) {
6283 InsertBuffer (op->condition );
6384 }
6485 StmtExprVisitor::VisitStmt_ (op);
6586 }
6687
6788 void VisitStmt_ (const ForNode *op) final {
68- TMAFinder tma_finder ;
69- tma_finder (op->body );
70- if (tma_finder. has_tma_load_ ) {
89+ ProducerBufferDetector producer_buffer_detector (producer_buffers_) ;
90+ producer_buffer_detector (op->body );
91+ if (producer_buffer_detector. has_producer_buffer_ ) {
7192 InsertBuffer (op->min );
7293 InsertBuffer (op->extent );
7394 }
7495 StmtExprVisitor::VisitStmt_ (op);
7596 }
7697
98+ void VisitStmt_ (const BufferStoreNode *op) final {
99+ if (producer_buffers_.count (op->buffer .get ())) {
100+ InsertBuffer (op->value );
101+ }
102+ StmtExprVisitor::VisitStmt_ (op);
103+ }
104+
77105 void VisitExpr_ (const CallNode *op) final {
78106 if (op->op .same_as (tma_load ()) || op->op .same_as (tma_load_im2col ())) {
79107 for (auto arg : op->args ) {
80108 if (auto buffer_load = arg.as <BufferLoadNode>()) {
81- used_in_producer_cond_ .insert (buffer_load->buffer .get ());
109+ producer_buffers_ .insert (buffer_load->buffer .get ());
82110 }
83111 }
84112 }
85113 }
86114
87115private:
88- std::unordered_set<const BufferNode *> used_in_producer_cond_ ;
116+ std::unordered_set<const BufferNode *> producer_buffers_ ;
89117};
90118
91119class WarpSpecializedRoleMarker : public StmtVisitor {
@@ -95,7 +123,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
95123
96124 void Prepare (const Stmt &stmt) {
97125 ProducerUsedBufferFinder finder;
98- used_in_producer_cond_ = finder.FindProducerusedBuffer (stmt);
126+ producer_buffers_ = finder.FindProducerusedBuffer (stmt);
99127 }
100128
101129 Role GetRole (const StmtNode *stmt) const {
@@ -123,7 +151,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
123151 void VisitStmt_ (const BufferStoreNode *op) final {
124152 bool is_shared_store =
125153 op->buffer .scope () == " shared.dyn" || op->buffer .scope () == " shared" ;
126- if (used_in_producer_cond_ .count (op->buffer .get ())) {
154+ if (producer_buffers_ .count (op->buffer .get ())) {
127155 SetRole (op, Role::kBoth );
128156 return ;
129157 }
@@ -207,7 +235,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
207235 std::unordered_map<const StmtNode *, Role> map_;
208236 bool has_simt_copy_ = false ;
209237 bool has_bulk_copy_ = false ;
210- std::unordered_set<const BufferNode *> used_in_producer_cond_ ;
238+ std::unordered_set<const BufferNode *> producer_buffers_ ;
211239};
212240
213241static PrimExpr makeGetBarrier (PrimExpr barrier_id) {
@@ -1112,7 +1140,7 @@ class WarpSpecializedRewriter : public StmtExprMutator {
11121140
11131141 auto inc_reg_stmt = Evaluate (0 );
11141142 auto dec_reg_stmt = Evaluate (0 );
1115- if (dec_reg >= 0 && inc_reg >= 0 ) {
1143+ if (dec_reg >= 0 && inc_reg >= 0 && !marker. HasSimtCopy () ) {
11161144 inc_reg_stmt = Evaluate (Call (DataType::Handle (), set_max_nreg (),
11171145 {inc_reg == 0 ? 240 : inc_reg, 1 }));
11181146 dec_reg_stmt = Evaluate (Call (DataType::Handle (), set_max_nreg (),
0 commit comments