From 15be2bfff5cdbf14d969abc1c7406df9ab5ee18e Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 1 Jan 2020 16:31:22 -0800 Subject: [PATCH] migrate lower_warp memory --- src/pass/lower_warp_memory.cc | 66 +++++++++++++++++------------------ tests/cpp/ir_visitor_test.cc | 1 + 2 files changed, 34 insertions(+), 33 deletions(-) diff --git a/src/pass/lower_warp_memory.cc b/src/pass/lower_warp_memory.cc index 0749127b905b..2d24ec425f81 100644 --- a/src/pass/lower_warp_memory.cc +++ b/src/pass/lower_warp_memory.cc @@ -26,8 +26,7 @@ // Thanks to Andrew Adams and Vinod Grover for // explaining the concept of warp shuffle. #include -#include -#include +#include #include #include #include "ir_util.h" @@ -75,7 +74,7 @@ namespace ir { // Visitor to find m in pattern // store warp_mem[m * warp_index + (warp_size * m) * y + x] -class WarpStoreCoeffFinder : private IRVisitor { +class WarpStoreCoeffFinder : private StmtVisitor { public: WarpStoreCoeffFinder(const Variable* buffer, Var warp_index, @@ -86,13 +85,13 @@ class WarpStoreCoeffFinder : private IRVisitor { } // find the warp co-efficient in the statement given the warp size int Find(const Stmt& stmt) { - this->Visit(stmt); + this->VisitStmt(stmt); return warp_coeff_; } private: /// Visitor implementation - void Visit_(const Store *op) final { + void VisitStmt_(const Store *op) final { if (op->buffer_var.get() == buffer_) { if (op->value.dtype().lanes() == 1) { UpdatePattern(op->index); @@ -104,7 +103,7 @@ class WarpStoreCoeffFinder : private IRVisitor { UpdatePattern(base); } } else { - IRVisitor::Visit_(op); + StmtVisitor::VisitStmt_(op); } } @@ -141,14 +140,14 @@ class WarpStoreCoeffFinder : private IRVisitor { // Visitor to find the warp index -class WarpIndexFinder : private IRVisitor { +class WarpIndexFinder : private StmtVisitor { public: explicit WarpIndexFinder(int warp_size) : warp_size_(warp_size) { } // find the warp co-efficient in the statement given the warp size IterVar Find(const Stmt& stmt) { - this->Visit(stmt); + this->VisitStmt(stmt); CHECK(warp_index_.defined()) << "Cannot find warp index(threadIdx.x) within the scope of warp memory"; return warp_index_; @@ -156,7 +155,7 @@ class WarpIndexFinder : private IRVisitor { private: /// Visitor implementation - void Visit_(const AttrStmt *op) final { + void VisitStmt_(const AttrStmt *op) final { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); if (iv->thread_tag == "threadIdx.x") { @@ -177,7 +176,7 @@ class WarpIndexFinder : private IRVisitor { } } } - IRVisitor::Visit_(op); + StmtVisitor::VisitStmt_(op); } // warp size int warp_size_{0}; @@ -185,13 +184,13 @@ class WarpIndexFinder : private IRVisitor { IterVar warp_index_{nullptr}; }; // Mutator to change the read pattern -class WarpAccessRewriter : protected IRMutator { +class WarpAccessRewriter : protected StmtExprMutator { public: explicit WarpAccessRewriter(int warp_size, arith::Analyzer* analyzer) : warp_size_(warp_size), analyzer_(analyzer) {} // Rewrite the allocate statement which transforms // warp memory to local memory. - Stmt Rewrite(const Allocate* op, const Stmt& stmt) { + Stmt Rewrite(const Allocate* op) { buffer_ = op->buffer_var.get(); int alloc_size = op->constant_allocation_size(); CHECK_GT(alloc_size, 0) @@ -208,27 +207,27 @@ class WarpAccessRewriter : protected IRMutator { op->dtype, {make_const(DataType::Int(32), alloc_size / warp_size_)}, op->condition, - this->Mutate(op->body)); + this->VisitStmt(op->body)); } protected: - Expr Mutate_(const Variable* op, const Expr& expr) { + Expr Mutate_(const Variable* op) { CHECK(op != buffer_) << "Cannot access address of warp memory directly"; - return IRMutator::Mutate_(op, expr); + return StmtExprMutator::VisitExpr_(op); } - Stmt Mutate_(const Store* op, const Stmt& stmt) { + Stmt VisitStmt_(const Store* op) { if (op->buffer_var.get() == buffer_) { Expr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); return Store::make(op->buffer_var, op->value, local_index, op->predicate); } else { - return IRMutator::Mutate_(op, stmt); + return StmtExprMutator::VisitStmt_(op); } } - Expr Mutate_(const Load* op, const Expr& expr) { + Expr Mutate_(const Load* op) { if (op->buffer_var.get() == buffer_) { Expr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); @@ -243,7 +242,7 @@ class WarpAccessRewriter : protected IRMutator { {load_value, group}, Call::Intrinsic); } else { - return IRMutator::Mutate_(op, expr); + return StmtExprMutator::VisitExpr_(op); } } // Split the index to the two component @@ -297,18 +296,18 @@ class WarpAccessRewriter : protected IRMutator { // Bind bound information of variables to make analyzer more effective // TODO(tqchen): consider a pass to inline the bound info into the expr // so analysis can be context independent. -class BindVarBoundInfo : public IRVisitor { +class BindVarBoundInfo : public StmtVisitor { public: explicit BindVarBoundInfo(arith::Analyzer* analyzer) : analyzer_(analyzer) {} - void Visit_(const For* op) final { + void VisitStmt_(const For* op) final { const Var& loop_var = op->loop_var; analyzer_->Bind(loop_var, Range::make_by_min_extent(op->min, op->extent)); - IRVisitor::Visit_(op); + StmtVisitor::VisitStmt_(op); } - void Visit_(const AttrStmt* op) { + void VisitStmt_(const AttrStmt* op) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); @@ -319,7 +318,7 @@ class BindVarBoundInfo : public IRVisitor { analyzer_->Bind(iv->var, dom); } } - IRVisitor::Visit_(op); + StmtVisitor::VisitStmt_(op); } protected: @@ -330,7 +329,7 @@ class BindVarBoundInfo : public IRVisitor { }; // Mutator to change the read pattern -class WarpMemoryRewriter : private IRMutator { +class WarpMemoryRewriter : private StmtMutator { public: explicit WarpMemoryRewriter(int warp_size) : warp_size_(warp_size) { @@ -338,36 +337,37 @@ class WarpMemoryRewriter : private IRMutator { Stmt Rewrite(Stmt stmt) { if (warp_size_ == 1) return stmt; - BindVarBoundInfo(&analyzer_).Visit(stmt); - stmt = this->Mutate(stmt); + BindVarBoundInfo binder(&analyzer_); + binder(stmt); + stmt = operator()(std::move(stmt)); stmt = CanonicalSimplify(stmt); return stmt; } private: - Stmt Mutate_(const Allocate* op, const Stmt& stmt) { + Stmt VisitStmt_(const Allocate* op) { if (warp_buffer_.count(op->buffer_var.get())) { WarpAccessRewriter rewriter(warp_size_, &analyzer_); - return rewriter.Rewrite(op, stmt); + return rewriter.Rewrite(op); } else { - return IRMutator::Mutate_(op, stmt); + return StmtMutator::VisitStmt_(op); } } - Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) { + Stmt VisitStmt_(const AttrStmt* op) { using runtime::StorageScope; if (op->attr_key == attr::storage_scope) { const Variable* buf = op->node.as(); StorageScope scope = StorageScope::make(op->value.as()->value); if (scope.rank == runtime::StorageRank::kWarp) { warp_buffer_.insert(buf); - Stmt ret = IRMutator::Mutate_(op, stmt); + Stmt ret = StmtMutator::VisitStmt_(op); op = ret.as(); return AttrStmt::make( op->node, op->attr_key, StringImm::make("local"), op->body); } } - return IRMutator::Mutate_(op, stmt); + return StmtMutator::VisitStmt_(op); } int warp_size_{0}; diff --git a/tests/cpp/ir_visitor_test.cc b/tests/cpp/ir_visitor_test.cc index 4282a0026ee6..1f34b2549d0d 100644 --- a/tests/cpp/ir_visitor_test.cc +++ b/tests/cpp/ir_visitor_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include TEST(IRVisitor, CountVar) {