Skip to content

Commit c1eef51

Browse files
authored
[Pipeline] Skip condition expression analysis for global reading (#713)
* Update submodule 'tvm' to commit e11521e6936a827efa334588d29571fbb4620107 * Refactor inject_pipeline.cc to enhance pipeline body rewriting and condition handling - Introduced a new function to replace IfThenElse nodes with their then_case while preserving attributes. - Streamlined the PipelineBodyRewriter to improve buffer access rewriting and async state management. - Enhanced the handling of pipeline loop conditions and added support for predicate conditions in the pipeline body. - Removed obsolete code and improved overall code clarity and maintainability. * lint fix * Refactor return statements in inject_pipeline.cc to remove unnecessary std::move calls - Updated return statements in multiple methods to return objects directly instead of using std::move, improving code clarity and potentially avoiding unnecessary moves. - Ensured consistent handling of BufferStore and BufferLoad nodes during pipeline transformations. * test fix * Enhance global read detection in pipeline planning - Updated the handling of global reads to account for condition expressions within IfThenElse nodes, ensuring accurate identification of global memory accesses. - Introduced a new flag to track whether the visitor is within a condition expression, improving the correctness of buffer access analysis. - Refactored the VisitStmt_ method to properly handle the structure of IfThenElse nodes, enhancing the clarity and maintainability of the code.
1 parent 49d5d80 commit c1eef51

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

src/transform/pipeline_planning.cc

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <tvm/tir/transform.h>
77

88
#include "../target/utils.h"
9+
#include "tvm/ir/expr.h"
910

1011
namespace tvm {
1112
namespace tl {
@@ -81,7 +82,11 @@ class BufferRegionCollector : public StmtExprVisitor {
8182
auto load_region = BufferRegion(load_buffer, region);
8283
reads_.push_back(load_region);
8384

84-
if (op->buffer.scope() == "global") {
85+
if (op->buffer.scope() == "global" && !within_condition_expr_) {
86+
// skip condition expr of if_then_else node
87+
// shared[i] = T.if_then_else(global[i] < n, register_a[i], register_b[i])
88+
// is not a global read shared[i] = T.if_then_else(global[i] < n,
89+
// global_a[i], global_b[i]) is a global read
8590
is_global_read_ = true;
8691
}
8792
}
@@ -103,18 +108,38 @@ class BufferRegionCollector : public StmtExprVisitor {
103108
// because we only care about the buffer itself instead of indices
104109
reads_.push_back(buffer_region);
105110
}
111+
} else if (op->op.same_as(builtin::if_then_else())) {
112+
within_condition_expr_ = true;
113+
this->VisitExpr(op->args[0]);
114+
within_condition_expr_ = false;
115+
for (auto i = 1; i < op->args.size(); i++) {
116+
this->VisitExpr(op->args[i]);
117+
}
106118
} else {
107119
StmtExprVisitor::VisitExpr_(op);
108120
}
109121
}
110122

123+
void VisitStmt_(const IfThenElseNode *op) final {
124+
within_condition_expr_ = true;
125+
this->VisitExpr(op->condition);
126+
within_condition_expr_ = false;
127+
this->VisitStmt(op->then_case);
128+
if (op->else_case.defined()) {
129+
within_condition_expr_ = true;
130+
this->VisitStmt(op->else_case.value());
131+
within_condition_expr_ = false;
132+
}
133+
}
134+
111135
private:
112136
Map<Var, Buffer> buffer_data_to_buffer_;
113137
Array<BufferRegion> reads_;
114138
Array<BufferRegion> writes_;
115139
bool is_global_read_ = false;
116140
bool under_buffer_store_ = false;
117141
bool is_global_copy_pattern_ = false;
142+
bool within_condition_expr_ = false;
118143
};
119144

120145
class PipelinePlanner : public StmtExprMutator {

0 commit comments

Comments
 (0)