diff --git a/3rdparty/tvm b/3rdparty/tvm index bc31e7ad9..cd2b2b601 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit bc31e7ad9f9fafd7659dfabafe359fd55a0ffc1e +Subproject commit cd2b2b6013d155b5822300b0a0740fa65320dd9e diff --git a/src/op/copy.cc b/src/op/copy.cc index 5d3529044..2584abced 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -852,7 +852,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, auto par_op = ParallelOp(transformed_loop); if (is_cpu_target) { - vectorized_thread_loop = VectorizeLoop(transformed_loop); + vectorized_thread_loop = VectorizeLoop(transformed_loop, analyzer); } else { std::vector levels = {InferLevel::kCommon, InferLevel::kStrict, InferLevel::kFree}; @@ -865,7 +865,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, auto thread_var = T.thread_var; auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout); - vectorized_thread_loop = VectorizeLoop(thread_loop); + vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer); } if (par_op->GetPredicate(T.thread_var).defined()) { diff --git a/src/op/fill.cc b/src/op/fill.cc index 83b0842dc..93b3bca07 100644 --- a/src/op/fill.cc +++ b/src/op/fill.cc @@ -207,7 +207,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { InferLevel::kFree); auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, par_op->GetLoopLayout()); - auto vectorized_thread_loop = VectorizeLoop(thread_loop); + auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer); if (par_op->GetPredicate(T.thread_var).defined()) { return IfThenElse(par_op->GetPredicate(T.thread_var).value(), vectorized_thread_loop); @@ -215,7 +215,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return vectorized_thread_loop; } else if (dst.scope() == "local") { auto init_loop = MakeSIMTLoop(analyzer); - auto vectorized_thread_loop = VectorizeLoop(init_loop); + auto vectorized_thread_loop = VectorizeLoop(init_loop, analyzer); return vectorized_thread_loop; } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" || dst.scope() == "global") { @@ -225,7 +225,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { InferLevel::kFree); auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, par_op->GetLoopLayout()); - auto vectorized_thread_loop = VectorizeLoop(thread_loop); + auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer); if (par_op->GetPredicate(T.thread_var).defined()) { return IfThenElse(par_op->GetPredicate(T.thread_var).value(), vectorized_thread_loop); diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 81777aa53..0d09cc129 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -452,8 +452,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, // As the pass will do post processing to the layout auto maybe_remapped_root_ = IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map); - int vector_size = GetVectorizeSize(maybe_remapped_root_); - + int vector_size = GetVectorizeSize(maybe_remapped_root_, T.analyzer); DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n'; PrimExpr loop_total_size = 1; diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index bd726b3db..be98b284d 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -12,6 +12,7 @@ #include #include +#include #include #include "../layout/utils.h" @@ -85,6 +86,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { auto &next = infer_list_[cur_infer_id]; auto iter_var = thread_var_vec_[cur_infer_id]; auto thread_bounds = thread_bounds_vec_[cur_infer_id]; + arith::Analyzer *cur_analyzer = analyzer_vec_[cur_infer_id].get(); auto buffer_oob = buffer_oob_vec_[cur_infer_id]; // Double-check that 'next' is valid ICHECK(next.defined()) << "infer_list_[" << cur_infer_id @@ -108,7 +110,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { // Run InferLayout auto updates = next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map, - &analyzer_, buffer_oob}, + cur_analyzer, buffer_oob}, level); // Process the returned updates for (const auto &[buffer, layout] : updates) { @@ -266,6 +268,9 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size()) << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in " "length."; + ICHECK_EQ(analyzer_vec_.size(), infer_list_.size()) + << "Size mismatch: analyzer_vec_ and infer_list_ must match in " + "length."; ICHECK_EQ(buffer_oob_vec_.size(), infer_list_.size()) << "Size mismatch: buffer_oob_vec_ and infer_list_ must match in " "length."; @@ -452,6 +457,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { } else { thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1)); } + analyzer_vec_.push_back(analyzer_.Clone()); // Compute buffer oob for each buffer in the op if (const auto *copy = p.as()) { @@ -542,6 +548,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { } else { thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1)); } + analyzer_vec_.push_back(analyzer_.Clone()); buffer_oob_vec_.push_back(false); } else { IRVisitorWithAnalyzer::VisitStmt(op->body); @@ -683,6 +690,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { IterVarType::kDataPar); std::vector thread_var_vec_; std::vector thread_bounds_vec_; + std::vector> analyzer_vec_; std::vector buffer_oob_vec_; Target target_; LayoutMap annotated_layout_map_; @@ -1024,7 +1032,7 @@ class LayoutInferencer : public IRMutatorWithAnalyzer { }); if ((has_non_local || has_cast_operations) && !has_reducer) { - for_node = VectorizeLoop(for_node); + for_node = VectorizeLoop(for_node, analyzer_); } if (result_.predicate_map.count(root) && parallel_loop) { diff --git a/src/transform/legalize_vectorized_loop.cc b/src/transform/legalize_vectorized_loop.cc index aa461784a..4fd4ab91f 100644 --- a/src/transform/legalize_vectorized_loop.cc +++ b/src/transform/legalize_vectorized_loop.cc @@ -73,7 +73,7 @@ class LoopVectorizedLegalizer : IRMutatorWithAnalyzer { // Change the loop kind from vectorized to serial for_node.CopyOnWrite()->kind = ForKind::kSerial; // Apply vectorization transformation to the loop - return VectorizeLoop(for_node); + return VectorizeLoop(for_node, analyzer_); } }; diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index 45283d905..e8a18b004 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -45,7 +45,7 @@ struct VectorizePlanResult { PrimExpr condition; }; -class VectorizeFindGlobalAccess : public arith::IRVisitorWithAnalyzer { +class VectorizeFindGlobalAccess : public StmtExprVisitor { public: VectorizeFindGlobalAccess() = default; @@ -60,19 +60,20 @@ class VectorizeFindGlobalAccess : public arith::IRVisitorWithAnalyzer { void VisitStmt_(const BufferStoreNode *node) final { if (node->buffer.scope() == "global") has_global_access_ = true; - return arith::IRVisitorWithAnalyzer::VisitStmt_(node); + return StmtExprVisitor::VisitStmt_(node); } void VisitExpr_(const BufferLoadNode *node) final { if (node->buffer.scope() == "global") has_global_access_ = true; - return arith::IRVisitorWithAnalyzer::VisitExpr_(node); + return StmtExprVisitor::VisitExpr_(node); } }; -class VectorizePlanner : public arith::IRVisitorWithAnalyzer { +class VectorizePlanner : public arith::IRMutatorWithAnalyzer { public: - VectorizePlanner() = default; + explicit VectorizePlanner(arith::Analyzer *analyzer) + : arith::IRMutatorWithAnalyzer(analyzer) {} int Plan(const For &node) { tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); @@ -92,21 +93,31 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer { } private: - void VisitStmt_(const ForNode *node) final { + Stmt VisitStmt_(const ForNode *node) final { inner_for_ = node; - auto extent_ptr = as_const_int(analyzer_.Simplify(node->extent)); - // Here I disable dynamic shape completely, - // In order to do it, the Planner should accept an analyzer with - // arithmetic info outside to prove the dividiblity of vector size - if (!extent_ptr) { - vector_size_ = 1; - return; + bool contains_nested_for = false; + // Must analysis vectorization on the innermost loop + PostOrderVisit(Downcast(node->body), [&](const ObjectRef &obj) { + if (obj.as()) { + contains_nested_for = true; + } + }); + + if (!contains_nested_for) { + auto extent_ptr = as_const_int(analyzer_->Simplify(node->extent)); + // Here I disable dynamic shape completely, + // In order to do it, the Planner should accept an analyzer with + // arithmetic info outside to prove the dividiblity of vector size + if (!extent_ptr) { + vector_size_ = 1; + return ffi::GetRef(node); + } + vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr); } - vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr); - arith::IRVisitorWithAnalyzer::VisitStmt_(node); + return arith::IRMutatorWithAnalyzer::VisitStmt_(node); } - void VisitExpr_(const BufferLoadNode *node) final { + PrimExpr VisitExpr_(const BufferLoadNode *node) final { if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" || node->buffer.scope() == "shared.dyn") has_nonlocal_memory_access_ = true; @@ -115,43 +126,44 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer { // constant buffer that tl hack to use as local register. auto boundary_check = node->buffer->shape[0].as(); if (boundary_check && boundary_check->value == 1) { - return arith::IRVisitorWithAnalyzer::VisitExpr_(node); + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } } UpdateVectorSize(node->indices, node->buffer); + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } - void VisitStmt_(const BufferStoreNode *node) final { + Stmt VisitStmt_(const BufferStoreNode *node) final { if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" || node->buffer.scope() == "shared.dyn") has_nonlocal_memory_access_ = true; UpdateVectorSize(node->indices, node->buffer); - return arith::IRVisitorWithAnalyzer::VisitExpr(node->value); + return arith::IRMutatorWithAnalyzer::VisitStmt_(node); } - void VisitStmt_(const IfThenElseNode *node) final { + Stmt VisitStmt_(const IfThenElseNode *node) final { CheckConditionVectorized(node->condition); - return arith::IRVisitorWithAnalyzer::VisitStmt_(node); + return arith::IRMutatorWithAnalyzer::VisitStmt_(node); } - void VisitExpr_(const CallNode *node) final { + PrimExpr VisitExpr_(const CallNode *node) final { if (node->op == builtin::if_then_else()) { CheckConditionVectorized(node->args[0]); } else if (node->op == builtin::call_extern()) { // do not vectorize extern calls vector_size_ = 1; } - return arith::IRVisitorWithAnalyzer::VisitExpr_(node); + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } void CheckConditionVectorized(const PrimExpr &cond) { // TODO: perform some checks here } - void VisitExpr_(const CastNode *node) final { + PrimExpr VisitExpr_(const CastNode *node) final { vector_size_ = arith::ZeroAwareGCD( vector_load_bits_max_ / node->dtype.bits(), vector_size_); - return arith::IRVisitorWithAnalyzer::VisitExpr_(node); + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } void UpdateVectorSize(const Array indices, const Buffer &buffer) { @@ -171,19 +183,16 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer { for (int i = 0; i < indices.size(); ++i) { elem_offset += indices[i] * strides[i]; } - // 2. If element offset is independent with loop_var, ignore it - if (CanProveIndependent(elem_offset, inner_for_->loop_var, &analyzer_)) { + if (CanProveIndependent(elem_offset, inner_for_->loop_var, analyzer_)) { return; } - // 3. Tight vectorize bound vector_size_ = arith::ZeroAwareGCD(vector_size_, vector_load_bits_max_ / buffer->dtype.bits()); - // 4. Try to vectorize buffer load while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var, - inner_for_->extent, vector_size_, &analyzer_)) { + inner_for_->extent, vector_size_, analyzer_)) { vector_size_ /= 2; } } @@ -235,7 +244,14 @@ class VectorizeRewriter : public StmtExprMutator { const int vector_size_; }; -int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); } +int GetVectorizeSize(const For &loop) { + arith::Analyzer analyzer; + return VectorizePlanner(&analyzer).Plan(loop); +} + +int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer) { + return VectorizePlanner(analyzer).Plan(loop); +} bool CanProveIndependent(const PrimExpr &expr, Var var, arith::Analyzer *analyzer) { @@ -274,10 +290,10 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var, if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_size_for_iter), 0)) return false; - + auto simplified_expr = analyzer->Simplify(Substitute(expr, {{var, zero}})); // The base offset must be divisible - if (!analyzer->CanProveEqual( - FloorMod(Substitute(expr, {{var, zero}}), target_size_for_expr), 0)) { + if (!analyzer->CanProveEqual(FloorMod(simplified_expr, target_size_for_expr), + zero)) { return false; } @@ -308,7 +324,20 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var, For VectorizeLoop(const For &loop, int vectorize_hint) { if (vectorize_hint <= 0) { - VectorizePlanner planner; + arith::Analyzer analyzer; + VectorizePlanner planner(&analyzer); + vectorize_hint = planner.Plan(loop); + } + if (vectorize_hint == 1) + return loop; + auto rewriter = VectorizeRewriter(vectorize_hint); + return Downcast(rewriter(loop)); +} + +For VectorizeLoop(const For &loop, arith::Analyzer *analyzer, + int vectorize_hint) { + if (vectorize_hint <= 0) { + VectorizePlanner planner(analyzer); vectorize_hint = planner.Plan(loop); } if (vectorize_hint == 1) diff --git a/src/transform/loop_vectorize.h b/src/transform/loop_vectorize.h index 4ab20c668..a63c4b450 100644 --- a/src/transform/loop_vectorize.h +++ b/src/transform/loop_vectorize.h @@ -35,8 +35,13 @@ using namespace tir; int GetVectorizeSize(const For &loop); +int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer); + For VectorizeLoop(const For &loop, int vectorize_hint = -1); +For VectorizeLoop(const For &loop, arith::Analyzer *analyzer, + int vectorize_hint = -1); + // Can prove expr is independent with var, i.e. the value of expr doesn't change // when var changes bool CanProveIndependent(const PrimExpr &expr, Var var,