Skip to content

Commit

Permalink
migrate lower_warp memory
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jan 2, 2020
1 parent 2776c44 commit 15be2bf
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 33 deletions.
66 changes: 33 additions & 33 deletions src/pass/lower_warp_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@
// Thanks to Andrew Adams and Vinod Grover for
// explaining the concept of warp shuffle.
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "ir_util.h"
Expand Down Expand Up @@ -75,7 +74,7 @@ namespace ir {

// Visitor to find m in pattern
// store warp_mem[m * warp_index + (warp_size * m) * y + x]
class WarpStoreCoeffFinder : private IRVisitor {
class WarpStoreCoeffFinder : private StmtVisitor {
public:
WarpStoreCoeffFinder(const Variable* buffer,
Var warp_index,
Expand All @@ -86,13 +85,13 @@ class WarpStoreCoeffFinder : private IRVisitor {
}
// Find the warp coefficient in the statement given the warp size.
int Find(const Stmt& stmt) {
this->Visit(stmt);
this->VisitStmt(stmt);
return warp_coeff_;
}

private:
/// Visitor implementation
void Visit_(const Store *op) final {
void VisitStmt_(const Store *op) final {
if (op->buffer_var.get() == buffer_) {
if (op->value.dtype().lanes() == 1) {
UpdatePattern(op->index);
Expand All @@ -104,7 +103,7 @@ class WarpStoreCoeffFinder : private IRVisitor {
UpdatePattern(base);
}
} else {
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}
}

Expand Down Expand Up @@ -141,22 +140,22 @@ class WarpStoreCoeffFinder : private IRVisitor {


// Visitor to find the warp index
class WarpIndexFinder : private IRVisitor {
class WarpIndexFinder : private StmtVisitor {
public:
explicit WarpIndexFinder(int warp_size)
: warp_size_(warp_size) {
}
// Find the warp index (threadIdx.x) in the statement given the warp size.
IterVar Find(const Stmt& stmt) {
this->Visit(stmt);
this->VisitStmt(stmt);
CHECK(warp_index_.defined())
<< "Cannot find warp index(threadIdx.x) within the scope of warp memory";
return warp_index_;
}

private:
/// Visitor implementation
void Visit_(const AttrStmt *op) final {
void VisitStmt_(const AttrStmt *op) final {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
Expand All @@ -177,21 +176,21 @@ class WarpIndexFinder : private IRVisitor {
}
}
}
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}
// warp size
int warp_size_{0};
// the warp index
IterVar warp_index_{nullptr};
};
// Mutator to change the read pattern
class WarpAccessRewriter : protected IRMutator {
class WarpAccessRewriter : protected StmtExprMutator {
public:
explicit WarpAccessRewriter(int warp_size, arith::Analyzer* analyzer)
: warp_size_(warp_size), analyzer_(analyzer) {}
// Rewrite the allocate statement which transforms
// warp memory to local memory.
Stmt Rewrite(const Allocate* op, const Stmt& stmt) {
Stmt Rewrite(const Allocate* op) {
buffer_ = op->buffer_var.get();
int alloc_size = op->constant_allocation_size();
CHECK_GT(alloc_size, 0)
Expand All @@ -208,27 +207,27 @@ class WarpAccessRewriter : protected IRMutator {
op->dtype,
{make_const(DataType::Int(32), alloc_size / warp_size_)},
op->condition,
this->Mutate(op->body));
this->VisitStmt(op->body));
}

protected:
Expr Mutate_(const Variable* op, const Expr& expr) {
Expr Mutate_(const Variable* op) {
CHECK(op != buffer_)
<< "Cannot access address of warp memory directly";
return IRMutator::Mutate_(op, expr);
return StmtExprMutator::VisitExpr_(op);
}

Stmt Mutate_(const Store* op, const Stmt& stmt) {
Stmt VisitStmt_(const Store* op) {
if (op->buffer_var.get() == buffer_) {
Expr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index);
return Store::make(op->buffer_var, op->value, local_index, op->predicate);
} else {
return IRMutator::Mutate_(op, stmt);
return StmtExprMutator::VisitStmt_(op);
}
}

Expr Mutate_(const Load* op, const Expr& expr) {
Expr Mutate_(const Load* op) {
if (op->buffer_var.get() == buffer_) {
Expr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index);
Expand All @@ -243,7 +242,7 @@ class WarpAccessRewriter : protected IRMutator {
{load_value, group},
Call::Intrinsic);
} else {
return IRMutator::Mutate_(op, expr);
return StmtExprMutator::VisitExpr_(op);
}
}
// Split the index into its two components: the local index and the group.
Expand Down Expand Up @@ -297,18 +296,18 @@ class WarpAccessRewriter : protected IRMutator {
// Bind bound information of variables to make analyzer more effective
// TODO(tqchen): consider a pass to inline the bound info into the expr
// so analysis can be context independent.
class BindVarBoundInfo : public IRVisitor {
class BindVarBoundInfo : public StmtVisitor {
public:
explicit BindVarBoundInfo(arith::Analyzer* analyzer)
: analyzer_(analyzer) {}

void Visit_(const For* op) final {
void VisitStmt_(const For* op) final {
const Var& loop_var = op->loop_var;
analyzer_->Bind(loop_var, Range::make_by_min_extent(op->min, op->extent));
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}

void Visit_(const AttrStmt* op) {
void VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
Expand All @@ -319,7 +318,7 @@ class BindVarBoundInfo : public IRVisitor {
analyzer_->Bind(iv->var, dom);
}
}
IRVisitor::Visit_(op);
StmtVisitor::VisitStmt_(op);
}

protected:
Expand All @@ -330,44 +329,45 @@ class BindVarBoundInfo : public IRVisitor {
};

// Mutator to rewrite allocations in warp storage scope into local memory
class WarpMemoryRewriter : private IRMutator {
class WarpMemoryRewriter : private StmtMutator {
public:
explicit WarpMemoryRewriter(int warp_size)
: warp_size_(warp_size) {
}

Stmt Rewrite(Stmt stmt) {
if (warp_size_ == 1) return stmt;
BindVarBoundInfo(&analyzer_).Visit(stmt);
stmt = this->Mutate(stmt);
BindVarBoundInfo binder(&analyzer_);
binder(stmt);
stmt = operator()(std::move(stmt));
stmt = CanonicalSimplify(stmt);
return stmt;
}

private:
Stmt Mutate_(const Allocate* op, const Stmt& stmt) {
Stmt VisitStmt_(const Allocate* op) {
if (warp_buffer_.count(op->buffer_var.get())) {
WarpAccessRewriter rewriter(warp_size_, &analyzer_);
return rewriter.Rewrite(op, stmt);
return rewriter.Rewrite(op);
} else {
return IRMutator::Mutate_(op, stmt);
return StmtMutator::VisitStmt_(op);
}
}

Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) {
Stmt VisitStmt_(const AttrStmt* op) {
using runtime::StorageScope;
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
if (scope.rank == runtime::StorageRank::kWarp) {
warp_buffer_.insert(buf);
Stmt ret = IRMutator::Mutate_(op, stmt);
Stmt ret = StmtMutator::VisitStmt_(op);
op = ret.as<AttrStmt>();
return AttrStmt::make(
op->node, op->attr_key, StringImm::make("local"), op->body);
}
}
return IRMutator::Mutate_(op, stmt);
return StmtMutator::VisitStmt_(op);
}

int warp_size_{0};
Expand Down
1 change: 1 addition & 0 deletions tests/cpp/ir_visitor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_pass.h>

TEST(IRVisitor, CountVar) {
Expand Down

0 comments on commit 15be2bf

Please sign in to comment.