Skip to content

Commit 689ee52

Browse files
authored
[Enhancement] Refactored buffer detection logic in warp_specialized_rewriter.cc (#685)
- Renamed TMAFinder to ProducerBufferDetector and improved handling of CallNode and BufferLoadNode. - This change aims to enhance code maintainability and performance by more accurately tracking producer buffer usage.
1 parent adcba27 commit 689ee52

File tree

1 file changed

+48
-20
lines changed

1 file changed

+48
-20
lines changed

src/transform/warp_specialized_rewriter.cc

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,69 +23,97 @@ using arith::IRVisitorWithAnalyzer;
2323

2424
enum class Role { kConsumer, kProducer, kBoth };
2525

26-
class TMAFinder : public StmtExprVisitor {
26+
class ProducerBufferDetector : public StmtExprVisitor {
2727
public:
28-
void clear() { has_tma_load_ = false; }
28+
ProducerBufferDetector(
29+
std::unordered_set<const BufferNode *> cur_producer_buffers)
30+
: cur_producer_buffers_(cur_producer_buffers) {}
31+
32+
void clear() { has_producer_buffer_ = false; }
2933

3034
void VisitExpr_(const CallNode *call) final {
3135
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
32-
has_tma_load_ = true;
36+
has_producer_buffer_ = true;
3337
}
38+
StmtExprVisitor::VisitExpr_(call);
3439
}
3540

36-
bool has_tma_load_ = false;
41+
void VisitExpr_(const BufferLoadNode *op) final {
42+
if (cur_producer_buffers_.count(op->buffer.get())) {
43+
has_producer_buffer_ = true;
44+
}
45+
StmtExprVisitor::VisitExpr_(op);
46+
}
47+
48+
bool has_producer_buffer_ = false;
49+
std::unordered_set<const BufferNode *> cur_producer_buffers_;
3750
};
3851

3952
class ProducerUsedBufferFinder : public StmtExprVisitor {
4053
public:
4154
auto FindProducerusedBuffer(Stmt stmt) {
42-
VisitStmt(stmt);
43-
return used_in_producer_cond_;
55+
producer_buffers_.clear();
56+
std::unordered_set<const BufferNode *> last_producer_buffers_;
57+
for (;;) {
58+
VisitStmt(stmt);
59+
if (producer_buffers_ == last_producer_buffers_) {
60+
break;
61+
}
62+
last_producer_buffers_ = producer_buffers_;
63+
}
64+
return producer_buffers_;
4465
}
4566

4667
void InsertBuffer(const PrimExpr &expr) {
4768
// Find the buffer that is used in the condition
4869
VarUseDefAnalyzer usage(Array<Var>{});
4970
usage(expr);
5071
for (const auto &buffer : usage.buffer_use_count_) {
51-
used_in_producer_cond_.insert(buffer.first);
72+
producer_buffers_.insert(buffer.first);
5273
}
5374
}
5475

5576
void VisitStmt_(const IfThenElseNode *op) final {
56-
TMAFinder tma_finder;
57-
tma_finder(op->then_case);
77+
ProducerBufferDetector producer_buffer_detector(producer_buffers_);
78+
producer_buffer_detector(op->then_case);
5879
if (op->else_case.defined()) {
59-
tma_finder(op->else_case.value());
80+
producer_buffer_detector(op->else_case.value());
6081
}
61-
if (tma_finder.has_tma_load_) {
82+
if (producer_buffer_detector.has_producer_buffer_) {
6283
InsertBuffer(op->condition);
6384
}
6485
StmtExprVisitor::VisitStmt_(op);
6586
}
6687

6788
void VisitStmt_(const ForNode *op) final {
68-
TMAFinder tma_finder;
69-
tma_finder(op->body);
70-
if (tma_finder.has_tma_load_) {
89+
ProducerBufferDetector producer_buffer_detector(producer_buffers_);
90+
producer_buffer_detector(op->body);
91+
if (producer_buffer_detector.has_producer_buffer_) {
7192
InsertBuffer(op->min);
7293
InsertBuffer(op->extent);
7394
}
7495
StmtExprVisitor::VisitStmt_(op);
7596
}
7697

98+
void VisitStmt_(const BufferStoreNode *op) final {
99+
if (producer_buffers_.count(op->buffer.get())) {
100+
InsertBuffer(op->value);
101+
}
102+
StmtExprVisitor::VisitStmt_(op);
103+
}
104+
77105
void VisitExpr_(const CallNode *op) final {
78106
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
79107
for (auto arg : op->args) {
80108
if (auto buffer_load = arg.as<BufferLoadNode>()) {
81-
used_in_producer_cond_.insert(buffer_load->buffer.get());
109+
producer_buffers_.insert(buffer_load->buffer.get());
82110
}
83111
}
84112
}
85113
}
86114

87115
private:
88-
std::unordered_set<const BufferNode *> used_in_producer_cond_;
116+
std::unordered_set<const BufferNode *> producer_buffers_;
89117
};
90118

91119
class WarpSpecializedRoleMarker : public StmtVisitor {
@@ -95,7 +123,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
95123

96124
void Prepare(const Stmt &stmt) {
97125
ProducerUsedBufferFinder finder;
98-
used_in_producer_cond_ = finder.FindProducerusedBuffer(stmt);
126+
producer_buffers_ = finder.FindProducerusedBuffer(stmt);
99127
}
100128

101129
Role GetRole(const StmtNode *stmt) const {
@@ -123,7 +151,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
123151
void VisitStmt_(const BufferStoreNode *op) final {
124152
bool is_shared_store =
125153
op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
126-
if (used_in_producer_cond_.count(op->buffer.get())) {
154+
if (producer_buffers_.count(op->buffer.get())) {
127155
SetRole(op, Role::kBoth);
128156
return;
129157
}
@@ -207,7 +235,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
207235
std::unordered_map<const StmtNode *, Role> map_;
208236
bool has_simt_copy_ = false;
209237
bool has_bulk_copy_ = false;
210-
std::unordered_set<const BufferNode *> used_in_producer_cond_;
238+
std::unordered_set<const BufferNode *> producer_buffers_;
211239
};
212240

213241
static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
@@ -1112,7 +1140,7 @@ class WarpSpecializedRewriter : public StmtExprMutator {
11121140

11131141
auto inc_reg_stmt = Evaluate(0);
11141142
auto dec_reg_stmt = Evaluate(0);
1115-
if (dec_reg >= 0 && inc_reg >= 0) {
1143+
if (dec_reg >= 0 && inc_reg >= 0 && !marker.HasSimtCopy()) {
11161144
inc_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
11171145
{inc_reg == 0 ? 240 : inc_reg, 1}));
11181146
dec_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),

0 commit comments

Comments
 (0)