Skip to content

Commit 9f7bac4

Browse files
authored
[Refactor] Backup Analyzer to get the appropriate arith information (#1311)
* [Refactor] Update Vectorization Functions to Accept Analyzer Parameter
  - Modified `VectorizeLoop` and related functions to accept an `arith::Analyzer` parameter, enhancing their capability to perform analysis during vectorization.
  - Updated multiple instances in `copy.cc`, `fill.cc`, `parallel.cc`, and layout inference files to utilize the new analyzer parameter for improved performance and correctness.
  - Ensured consistency across vectorization logic by integrating the analyzer into existing workflows, facilitating better optimization opportunities.
* [Fix] Corrected PostOrderVisit call in loop_vectorize.cc
  - Updated the PostOrderVisit function to analyze the body of the loop node instead of the node itself, ensuring proper handling of nested loops during vectorization analysis.
* fix
* lint fix
* fix
1 parent 721baed commit 9f7bac4

File tree

8 files changed

+87
-46
lines changed

8 files changed

+87
-46
lines changed

3rdparty/tvm

Submodule tvm updated from bc31e7a to cd2b2b6

src/op/copy.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
852852
auto par_op = ParallelOp(transformed_loop);
853853

854854
if (is_cpu_target) {
855-
vectorized_thread_loop = VectorizeLoop(transformed_loop);
855+
vectorized_thread_loop = VectorizeLoop(transformed_loop, analyzer);
856856
} else {
857857
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
858858
InferLevel::kFree};
@@ -865,7 +865,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
865865
auto thread_var = T.thread_var;
866866
auto thread_loop =
867867
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
868-
vectorized_thread_loop = VectorizeLoop(thread_loop);
868+
vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
869869
}
870870

871871
if (par_op->GetPredicate(T.thread_var).defined()) {

src/op/fill.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,15 +207,15 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
207207
InferLevel::kFree);
208208
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
209209
par_op->GetLoopLayout());
210-
auto vectorized_thread_loop = VectorizeLoop(thread_loop);
210+
auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
211211
if (par_op->GetPredicate(T.thread_var).defined()) {
212212
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
213213
vectorized_thread_loop);
214214
}
215215
return vectorized_thread_loop;
216216
} else if (dst.scope() == "local") {
217217
auto init_loop = MakeSIMTLoop(analyzer);
218-
auto vectorized_thread_loop = VectorizeLoop(init_loop);
218+
auto vectorized_thread_loop = VectorizeLoop(init_loop, analyzer);
219219
return vectorized_thread_loop;
220220
} else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" ||
221221
dst.scope() == "global") {
@@ -225,7 +225,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
225225
InferLevel::kFree);
226226
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
227227
par_op->GetLoopLayout());
228-
auto vectorized_thread_loop = VectorizeLoop(thread_loop);
228+
auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
229229
if (par_op->GetPredicate(T.thread_var).defined()) {
230230
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
231231
vectorized_thread_loop);

src/op/parallel.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
452452
// As the pass will do post processing to the layout
453453
auto maybe_remapped_root_ =
454454
IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
455-
int vector_size = GetVectorizeSize(maybe_remapped_root_);
456-
455+
int vector_size = GetVectorizeSize(maybe_remapped_root_, T.analyzer);
457456
DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n';
458457

459458
PrimExpr loop_total_size = 1;

src/transform/layout_inference.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <tvm/tir/utils.h>
1313

1414
#include <algorithm>
15+
#include <memory>
1516
#include <queue>
1617

1718
#include "../layout/utils.h"
@@ -85,6 +86,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
8586
auto &next = infer_list_[cur_infer_id];
8687
auto iter_var = thread_var_vec_[cur_infer_id];
8788
auto thread_bounds = thread_bounds_vec_[cur_infer_id];
89+
arith::Analyzer *cur_analyzer = analyzer_vec_[cur_infer_id].get();
8890
auto buffer_oob = buffer_oob_vec_[cur_infer_id];
8991
// Double-check that 'next' is valid
9092
ICHECK(next.defined()) << "infer_list_[" << cur_infer_id
@@ -108,7 +110,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
108110
// Run InferLayout
109111
auto updates =
110112
next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
111-
&analyzer_, buffer_oob},
113+
cur_analyzer, buffer_oob},
112114
level);
113115
// Process the returned updates
114116
for (const auto &[buffer, layout] : updates) {
@@ -266,6 +268,9 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
266268
ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
267269
<< "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
268270
"length.";
271+
ICHECK_EQ(analyzer_vec_.size(), infer_list_.size())
272+
<< "Size mismatch: analyzer_vec_ and infer_list_ must match in "
273+
"length.";
269274
ICHECK_EQ(buffer_oob_vec_.size(), infer_list_.size())
270275
<< "Size mismatch: buffer_oob_vec_ and infer_list_ must match in "
271276
"length.";
@@ -452,6 +457,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
452457
} else {
453458
thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
454459
}
460+
analyzer_vec_.push_back(analyzer_.Clone());
455461

456462
// Compute buffer oob for each buffer in the op
457463
if (const auto *copy = p.as<CopyNode>()) {
@@ -542,6 +548,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
542548
} else {
543549
thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
544550
}
551+
analyzer_vec_.push_back(analyzer_.Clone());
545552
buffer_oob_vec_.push_back(false);
546553
} else {
547554
IRVisitorWithAnalyzer::VisitStmt(op->body);
@@ -683,6 +690,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
683690
IterVarType::kDataPar);
684691
std::vector<IterVar> thread_var_vec_;
685692
std::vector<Range> thread_bounds_vec_;
693+
std::vector<std::unique_ptr<arith::Analyzer>> analyzer_vec_;
686694
std::vector<bool> buffer_oob_vec_;
687695
Target target_;
688696
LayoutMap annotated_layout_map_;
@@ -1024,7 +1032,7 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
10241032
});
10251033

10261034
if ((has_non_local || has_cast_operations) && !has_reducer) {
1027-
for_node = VectorizeLoop(for_node);
1035+
for_node = VectorizeLoop(for_node, analyzer_);
10281036
}
10291037

10301038
if (result_.predicate_map.count(root) && parallel_loop) {

src/transform/legalize_vectorized_loop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class LoopVectorizedLegalizer : IRMutatorWithAnalyzer {
7373
// Change the loop kind from vectorized to serial
7474
for_node.CopyOnWrite()->kind = ForKind::kSerial;
7575
// Apply vectorization transformation to the loop
76-
return VectorizeLoop(for_node);
76+
return VectorizeLoop(for_node, analyzer_);
7777
}
7878
};
7979

src/transform/loop_vectorize.cc

Lines changed: 64 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ struct VectorizePlanResult {
4545
PrimExpr condition;
4646
};
4747

48-
class VectorizeFindGlobalAccess : public arith::IRVisitorWithAnalyzer {
48+
class VectorizeFindGlobalAccess : public StmtExprVisitor {
4949
public:
5050
VectorizeFindGlobalAccess() = default;
5151

@@ -60,19 +60,20 @@ class VectorizeFindGlobalAccess : public arith::IRVisitorWithAnalyzer {
6060
void VisitStmt_(const BufferStoreNode *node) final {
6161
if (node->buffer.scope() == "global")
6262
has_global_access_ = true;
63-
return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
63+
return StmtExprVisitor::VisitStmt_(node);
6464
}
6565

6666
void VisitExpr_(const BufferLoadNode *node) final {
6767
if (node->buffer.scope() == "global")
6868
has_global_access_ = true;
69-
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
69+
return StmtExprVisitor::VisitExpr_(node);
7070
}
7171
};
7272

73-
class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
73+
class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
7474
public:
75-
VectorizePlanner() = default;
75+
explicit VectorizePlanner(arith::Analyzer *analyzer)
76+
: arith::IRMutatorWithAnalyzer(analyzer) {}
7677

7778
int Plan(const For &node) {
7879
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
@@ -92,21 +93,31 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
9293
}
9394

9495
private:
95-
void VisitStmt_(const ForNode *node) final {
96+
Stmt VisitStmt_(const ForNode *node) final {
9697
inner_for_ = node;
97-
auto extent_ptr = as_const_int(analyzer_.Simplify(node->extent));
98-
// Here I disable dynamic shape completely,
99-
// In order to do it, the Planner should accept an analyzer with
100-
// arithmetic info outside to prove the dividiblity of vector size
101-
if (!extent_ptr) {
102-
vector_size_ = 1;
103-
return;
98+
bool contains_nested_for = false;
99+
// Vectorization must be analyzed on the innermost loop
100+
PostOrderVisit(Downcast<Stmt>(node->body), [&](const ObjectRef &obj) {
101+
if (obj.as<ForNode>()) {
102+
contains_nested_for = true;
103+
}
104+
});
105+
106+
if (!contains_nested_for) {
107+
auto extent_ptr = as_const_int(analyzer_->Simplify(node->extent));
108+
// Here I disable dynamic shape completely,
109+
// In order to do it, the Planner should accept an analyzer with
110+
// arithmetic info outside to prove the dividiblity of vector size
111+
if (!extent_ptr) {
112+
vector_size_ = 1;
113+
return ffi::GetRef<Stmt>(node);
114+
}
115+
vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
104116
}
105-
vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
106-
arith::IRVisitorWithAnalyzer::VisitStmt_(node);
117+
return arith::IRMutatorWithAnalyzer::VisitStmt_(node);
107118
}
108119

109-
void VisitExpr_(const BufferLoadNode *node) final {
120+
PrimExpr VisitExpr_(const BufferLoadNode *node) final {
110121
if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
111122
node->buffer.scope() == "shared.dyn")
112123
has_nonlocal_memory_access_ = true;
@@ -115,43 +126,44 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
115126
// constant buffer that tl hack to use as local register.
116127
auto boundary_check = node->buffer->shape[0].as<IntImmNode>();
117128
if (boundary_check && boundary_check->value == 1) {
118-
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
129+
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
119130
}
120131
}
121132
UpdateVectorSize(node->indices, node->buffer);
133+
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
122134
}
123135

124-
void VisitStmt_(const BufferStoreNode *node) final {
136+
Stmt VisitStmt_(const BufferStoreNode *node) final {
125137
if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
126138
node->buffer.scope() == "shared.dyn")
127139
has_nonlocal_memory_access_ = true;
128140
UpdateVectorSize(node->indices, node->buffer);
129-
return arith::IRVisitorWithAnalyzer::VisitExpr(node->value);
141+
return arith::IRMutatorWithAnalyzer::VisitStmt_(node);
130142
}
131143

132-
void VisitStmt_(const IfThenElseNode *node) final {
144+
Stmt VisitStmt_(const IfThenElseNode *node) final {
133145
CheckConditionVectorized(node->condition);
134-
return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
146+
return arith::IRMutatorWithAnalyzer::VisitStmt_(node);
135147
}
136148

137-
void VisitExpr_(const CallNode *node) final {
149+
PrimExpr VisitExpr_(const CallNode *node) final {
138150
if (node->op == builtin::if_then_else()) {
139151
CheckConditionVectorized(node->args[0]);
140152
} else if (node->op == builtin::call_extern()) {
141153
// do not vectorize extern calls
142154
vector_size_ = 1;
143155
}
144-
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
156+
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
145157
}
146158

147159
void CheckConditionVectorized(const PrimExpr &cond) {
148160
// TODO: perform some checks here
149161
}
150162

151-
void VisitExpr_(const CastNode *node) final {
163+
PrimExpr VisitExpr_(const CastNode *node) final {
152164
vector_size_ = arith::ZeroAwareGCD(
153165
vector_load_bits_max_ / node->dtype.bits(), vector_size_);
154-
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
166+
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
155167
}
156168

157169
void UpdateVectorSize(const Array<PrimExpr> indices, const Buffer &buffer) {
@@ -171,19 +183,16 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
171183
for (int i = 0; i < indices.size(); ++i) {
172184
elem_offset += indices[i] * strides[i];
173185
}
174-
175186
// 2. If element offset is independent with loop_var, ignore it
176-
if (CanProveIndependent(elem_offset, inner_for_->loop_var, &analyzer_)) {
187+
if (CanProveIndependent(elem_offset, inner_for_->loop_var, analyzer_)) {
177188
return;
178189
}
179-
180190
// 3. Tight vectorize bound
181191
vector_size_ = arith::ZeroAwareGCD(vector_size_, vector_load_bits_max_ /
182192
buffer->dtype.bits());
183-
184193
// 4. Try to vectorize buffer load
185194
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
186-
inner_for_->extent, vector_size_, &analyzer_)) {
195+
inner_for_->extent, vector_size_, analyzer_)) {
187196
vector_size_ /= 2;
188197
}
189198
}
@@ -235,7 +244,14 @@ class VectorizeRewriter : public StmtExprMutator {
235244
const int vector_size_;
236245
};
237246

238-
int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }
247+
int GetVectorizeSize(const For &loop) {
248+
arith::Analyzer analyzer;
249+
return VectorizePlanner(&analyzer).Plan(loop);
250+
}
251+
252+
int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer) {
253+
return VectorizePlanner(analyzer).Plan(loop);
254+
}
239255

240256
bool CanProveIndependent(const PrimExpr &expr, Var var,
241257
arith::Analyzer *analyzer) {
@@ -274,10 +290,10 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
274290
if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_size_for_iter),
275291
0))
276292
return false;
277-
293+
auto simplified_expr = analyzer->Simplify(Substitute(expr, {{var, zero}}));
278294
// The base offset must be divisible
279-
if (!analyzer->CanProveEqual(
280-
FloorMod(Substitute(expr, {{var, zero}}), target_size_for_expr), 0)) {
295+
if (!analyzer->CanProveEqual(FloorMod(simplified_expr, target_size_for_expr),
296+
zero)) {
281297
return false;
282298
}
283299

@@ -308,7 +324,20 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
308324

309325
For VectorizeLoop(const For &loop, int vectorize_hint) {
310326
if (vectorize_hint <= 0) {
311-
VectorizePlanner planner;
327+
arith::Analyzer analyzer;
328+
VectorizePlanner planner(&analyzer);
329+
vectorize_hint = planner.Plan(loop);
330+
}
331+
if (vectorize_hint == 1)
332+
return loop;
333+
auto rewriter = VectorizeRewriter(vectorize_hint);
334+
return Downcast<For>(rewriter(loop));
335+
}
336+
337+
For VectorizeLoop(const For &loop, arith::Analyzer *analyzer,
338+
int vectorize_hint) {
339+
if (vectorize_hint <= 0) {
340+
VectorizePlanner planner(analyzer);
312341
vectorize_hint = planner.Plan(loop);
313342
}
314343
if (vectorize_hint == 1)

src/transform/loop_vectorize.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,13 @@ using namespace tir;
3535

3636
int GetVectorizeSize(const For &loop);
3737

38+
int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer);
39+
3840
For VectorizeLoop(const For &loop, int vectorize_hint = -1);
3941

42+
For VectorizeLoop(const For &loop, arith::Analyzer *analyzer,
43+
int vectorize_hint = -1);
44+
4045
// Can prove expr is independent with var, i.e. the value of expr doesn't change
4146
// when var changes
4247
bool CanProveIndependent(const PrimExpr &expr, Var var,

0 commit comments

Comments
 (0)