@@ -45,7 +45,7 @@ struct VectorizePlanResult {
4545 PrimExpr condition;
4646};
4747
48- class VectorizeFindGlobalAccess : public arith ::IRVisitorWithAnalyzer {
48+ class VectorizeFindGlobalAccess : public StmtExprVisitor {
4949public:
5050 VectorizeFindGlobalAccess () = default ;
5151
@@ -60,19 +60,20 @@ class VectorizeFindGlobalAccess : public arith::IRVisitorWithAnalyzer {
6060 void VisitStmt_ (const BufferStoreNode *node) final {
6161 if (node->buffer .scope () == " global" )
6262 has_global_access_ = true ;
63- return arith::IRVisitorWithAnalyzer ::VisitStmt_ (node);
63+ return StmtExprVisitor ::VisitStmt_ (node);
6464 }
6565
6666 void VisitExpr_ (const BufferLoadNode *node) final {
6767 if (node->buffer .scope () == " global" )
6868 has_global_access_ = true ;
69- return arith::IRVisitorWithAnalyzer ::VisitExpr_ (node);
69+ return StmtExprVisitor ::VisitExpr_ (node);
7070 }
7171};
7272
73- class VectorizePlanner : public arith ::IRVisitorWithAnalyzer {
73+ class VectorizePlanner : public arith ::IRMutatorWithAnalyzer {
7474public:
75- VectorizePlanner () = default ;
75+ explicit VectorizePlanner (arith::Analyzer *analyzer)
76+ : arith::IRMutatorWithAnalyzer(analyzer) {}
7677
7778 int Plan (const For &node) {
7879 tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current ();
@@ -92,21 +93,31 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
9293 }
9394
9495private:
95- void VisitStmt_ (const ForNode *node) final {
96+ Stmt VisitStmt_ (const ForNode *node) final {
9697 inner_for_ = node;
97- auto extent_ptr = as_const_int (analyzer_.Simplify (node->extent ));
98- // Here I disable dynamic shape completely,
99- // In order to do it, the Planner should accept an analyzer with
100- // arithmetic info outside to prove the dividiblity of vector size
101- if (!extent_ptr) {
102- vector_size_ = 1 ;
103- return ;
98+ bool contains_nested_for = false ;
99+ // Must analysis vectorization on the innermost loop
100+ PostOrderVisit (Downcast<Stmt>(node->body ), [&](const ObjectRef &obj) {
101+ if (obj.as <ForNode>()) {
102+ contains_nested_for = true ;
103+ }
104+ });
105+
106+ if (!contains_nested_for) {
107+ auto extent_ptr = as_const_int (analyzer_->Simplify (node->extent ));
108+ // Here I disable dynamic shape completely,
109+ // In order to do it, the Planner should accept an analyzer with
110+ // arithmetic info outside to prove the dividiblity of vector size
111+ if (!extent_ptr) {
112+ vector_size_ = 1 ;
113+ return ffi::GetRef<Stmt>(node);
114+ }
115+ vector_size_ = arith::ZeroAwareGCD (vector_size_, *extent_ptr);
104116 }
105- vector_size_ = arith::ZeroAwareGCD (vector_size_, *extent_ptr);
106- arith::IRVisitorWithAnalyzer::VisitStmt_ (node);
117+ return arith::IRMutatorWithAnalyzer::VisitStmt_ (node);
107118 }
108119
109- void VisitExpr_ (const BufferLoadNode *node) final {
120+ PrimExpr VisitExpr_ (const BufferLoadNode *node) final {
110121 if (node->buffer .scope () == " shared" || node->buffer .scope () == " global" ||
111122 node->buffer .scope () == " shared.dyn" )
112123 has_nonlocal_memory_access_ = true ;
@@ -115,43 +126,44 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
115126 // constant buffer that tl hack to use as local register.
116127 auto boundary_check = node->buffer ->shape [0 ].as <IntImmNode>();
117128 if (boundary_check && boundary_check->value == 1 ) {
118- return arith::IRVisitorWithAnalyzer ::VisitExpr_ (node);
129+ return arith::IRMutatorWithAnalyzer ::VisitExpr_ (node);
119130 }
120131 }
121132 UpdateVectorSize (node->indices , node->buffer );
133+ return arith::IRMutatorWithAnalyzer::VisitExpr_ (node);
122134 }
123135
124- void VisitStmt_ (const BufferStoreNode *node) final {
136+ Stmt VisitStmt_ (const BufferStoreNode *node) final {
125137 if (node->buffer .scope () == " shared" || node->buffer .scope () == " global" ||
126138 node->buffer .scope () == " shared.dyn" )
127139 has_nonlocal_memory_access_ = true ;
128140 UpdateVectorSize (node->indices , node->buffer );
129- return arith::IRVisitorWithAnalyzer::VisitExpr (node-> value );
141+ return arith::IRMutatorWithAnalyzer::VisitStmt_ (node);
130142 }
131143
132- void VisitStmt_ (const IfThenElseNode *node) final {
144+ Stmt VisitStmt_ (const IfThenElseNode *node) final {
133145 CheckConditionVectorized (node->condition );
134- return arith::IRVisitorWithAnalyzer ::VisitStmt_ (node);
146+ return arith::IRMutatorWithAnalyzer ::VisitStmt_ (node);
135147 }
136148
137- void VisitExpr_ (const CallNode *node) final {
149+ PrimExpr VisitExpr_ (const CallNode *node) final {
138150 if (node->op == builtin::if_then_else ()) {
139151 CheckConditionVectorized (node->args [0 ]);
140152 } else if (node->op == builtin::call_extern ()) {
141153 // do not vectorize extern calls
142154 vector_size_ = 1 ;
143155 }
144- return arith::IRVisitorWithAnalyzer ::VisitExpr_ (node);
156+ return arith::IRMutatorWithAnalyzer ::VisitExpr_ (node);
145157 }
146158
147159 void CheckConditionVectorized (const PrimExpr &cond) {
148160 // TODO: perform some checks here
149161 }
150162
151- void VisitExpr_ (const CastNode *node) final {
163+ PrimExpr VisitExpr_ (const CastNode *node) final {
152164 vector_size_ = arith::ZeroAwareGCD (
153165 vector_load_bits_max_ / node->dtype .bits (), vector_size_);
154- return arith::IRVisitorWithAnalyzer ::VisitExpr_ (node);
166+ return arith::IRMutatorWithAnalyzer ::VisitExpr_ (node);
155167 }
156168
157169 void UpdateVectorSize (const Array<PrimExpr> indices, const Buffer &buffer) {
@@ -171,19 +183,16 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
171183 for (int i = 0 ; i < indices.size (); ++i) {
172184 elem_offset += indices[i] * strides[i];
173185 }
174-
175186 // 2. If element offset is independent with loop_var, ignore it
176- if (CanProveIndependent (elem_offset, inner_for_->loop_var , & analyzer_)) {
187+ if (CanProveIndependent (elem_offset, inner_for_->loop_var , analyzer_)) {
177188 return ;
178189 }
179-
180190 // 3. Tight vectorize bound
181191 vector_size_ = arith::ZeroAwareGCD (vector_size_, vector_load_bits_max_ /
182192 buffer->dtype .bits ());
183-
184193 // 4. Try to vectorize buffer load
185194 while (!IndiceCanVectorize (elem_offset, inner_for_->loop_var ,
186- inner_for_->extent , vector_size_, & analyzer_)) {
195+ inner_for_->extent , vector_size_, analyzer_)) {
187196 vector_size_ /= 2 ;
188197 }
189198 }
@@ -235,7 +244,14 @@ class VectorizeRewriter : public StmtExprMutator {
235244 const int vector_size_;
236245};
237246
238- int GetVectorizeSize (const For &loop) { return VectorizePlanner ().Plan (loop); }
247+ int GetVectorizeSize (const For &loop) {
248+ arith::Analyzer analyzer;
249+ return VectorizePlanner (&analyzer).Plan (loop);
250+ }
251+
252+ int GetVectorizeSize (const For &loop, arith::Analyzer *analyzer) {
253+ return VectorizePlanner (analyzer).Plan (loop);
254+ }
239255
240256bool CanProveIndependent (const PrimExpr &expr, Var var,
241257 arith::Analyzer *analyzer) {
@@ -274,10 +290,10 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
274290 if (!analyzer->CanProveEqual (FloorMod (iter_var_size, target_size_for_iter),
275291 0 ))
276292 return false ;
277-
293+ auto simplified_expr = analyzer-> Simplify ( Substitute (expr, {{var, zero}}));
278294 // The base offset must be divisible
279- if (!analyzer->CanProveEqual (
280- FloorMod ( Substitute (expr, {{var, zero}}), target_size_for_expr), 0 )) {
295+ if (!analyzer->CanProveEqual (FloorMod (simplified_expr, target_size_for_expr),
296+ zero )) {
281297 return false ;
282298 }
283299
@@ -308,7 +324,20 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
308324
309325For VectorizeLoop (const For &loop, int vectorize_hint) {
310326 if (vectorize_hint <= 0 ) {
311- VectorizePlanner planner;
327+ arith::Analyzer analyzer;
328+ VectorizePlanner planner (&analyzer);
329+ vectorize_hint = planner.Plan (loop);
330+ }
331+ if (vectorize_hint == 1 )
332+ return loop;
333+ auto rewriter = VectorizeRewriter (vectorize_hint);
334+ return Downcast<For>(rewriter (loop));
335+ }
336+
337+ For VectorizeLoop (const For &loop, arith::Analyzer *analyzer,
338+ int vectorize_hint) {
339+ if (vectorize_hint <= 0 ) {
340+ VectorizePlanner planner (analyzer);
312341 vectorize_hint = planner.Plan (loop);
313342 }
314343 if (vectorize_hint == 1 )
0 commit comments