Skip to content

Commit

Permalink
[TIR] Add TIR While node (#7425)
Browse files Browse the repository at this point in the history
* add while node

* update visitors

* binary search lowering works

* llvm codegen working

* cuda codegen working

* nms updated to use while loop

* add missing upper bound check too

* add mandelbrot test

* add gpu mandel

commit ee2363b
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Jan 29 11:44:02 2021 +0900

    enable extern lib offload for nvptx

* rename test

* run black

* add doc

* add collatz test

* add while + vectorize test

* simplify bin search

* Add special case visit method to storage_access.cc

* disallow while loop inside vectorized loop

* disallow trivial condition since we do not have break

* error out in CoprocSync for now

* error out LiftAttrScope for now

* add placeholder to inject_vpthread

* refactor to use MakeAttach

* handle WhileNode in InplaceOpVerifier

* error out in InjectVirtualThread

* try handle WhileNode in StoragePlanRewriter

* remove WhileNode visitor from storage rewrite

* add while loop storage rewrite test

* update tests

* move test_vectorize_while_fail to  test_tir_transform_vectorize.py
  • Loading branch information
masahi authored Mar 3, 2021
1 parent 3a02e0b commit cf36aa6
Show file tree
Hide file tree
Showing 23 changed files with 695 additions and 23 deletions.
47 changes: 47 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,53 @@ class For : public Stmt {
TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
};

/*!
* \brief A While loop
*
* \code
*
* while (condition)
* body
*
* \endcode
*/
class WhileNode : public StmtNode {
public:
/*! \brief The termination condition. */
PrimExpr condition;
/*! \brief The body of the while loop. */
Stmt body;

void VisitAttrs(AttrVisitor* v) {
v->Visit("condition", &condition);
v->Visit("body", &body);
v->Visit("span", &span);
}

bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const {
return equal.DefEqual(condition, other->condition) && equal.DefEqual(body, other->body);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(condition);
hash_reduce.DefHash(body);
}

static constexpr const char* _type_key = "tir.While";
TVM_DECLARE_FINAL_OBJECT_INFO(WhileNode, StmtNode);
};

/*!
 * \brief Managed reference to WhileNode.
 * \sa WhileNode
 */
class While : public Stmt {
 public:
  /*!
   * \brief Construct a While statement.
   * \param condition The loop condition.
   * \param body The loop body.
   * \param span The source span attached to the statement; defaults to an empty Span.
   */
  TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());

  TVM_DEFINE_OBJECT_REF_METHODS(While, Stmt, WhileNode);
};

/*!
* \brief A prefetch hint for a buffer
*/
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const AttrStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
Expand All @@ -111,6 +112,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(AttrStmtNode);
IR_STMT_FUNCTOR_DISPATCH(IfThenElseNode);
IR_STMT_FUNCTOR_DISPATCH(ForNode);
IR_STMT_FUNCTOR_DISPATCH(WhileNode);
IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
IR_STMT_FUNCTOR_DISPATCH(StoreNode);
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
Expand Down Expand Up @@ -152,6 +154,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const WhileNode* op) override;
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
Expand Down Expand Up @@ -245,6 +248,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const IfThenElseNode* op) override;
Stmt VisitStmt_(const LetStmtNode* op) override;
Stmt VisitStmt_(const ForNode* op) override;
Stmt VisitStmt_(const WhileNode* op) override;
Stmt VisitStmt_(const AllocateNode* op) override;
Stmt VisitStmt_(const StoreNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
Expand Down
29 changes: 29 additions & 0 deletions python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,35 @@ def _exit_cb():

return WithScope(loop_var, _exit_cb)

def while_loop(self, condition):
    """Open a scope whose emitted statements become the body of a while loop.

    Parameters
    ----------
    condition : Expr
        The loop condition. It is passed unchanged to the While statement
        that is emitted when the scope exits.

    Returns
    -------
    loop_scope : With.Scope
        The while scope. It is created with ``None`` as its value.

    Examples
    --------
    .. code-block:: python

        ib = tvm.tir.ir_builder.create()
        iterations = ib.allocate("int32", (1,), name="iterations", scope="local")
        with ib.while_loop(iterations[0] < 10):
            iterations[0] += 1
    """
    # Start a fresh statement list that collects everything emitted in the scope.
    self._seq_stack.append([])

    def _emit_while():
        # On scope exit, wrap the collected statements into a single While node.
        loop_body = self._pop_seq()
        self.emit(_stmt.While(condition, loop_body))

    return WithScope(None, _emit_while)

def if_scope(self, cond):
"""Create an if scope.
Expand Down
25 changes: 25 additions & 0 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,31 @@ def __init__(
)


@tvm._ffi.register_object("tir.While")
class While(Stmt):
    """While node.

    Parameters
    ----------
    condition : PrimExpr
        The loop condition.

    body : Stmt
        The body statement.

    span : Optional[Span]
        The location of this statement in the source code.
    """

    def __init__(self, condition, body, span=None):
        # Forward the arguments, in order, to the registered constructor.
        self.__init_handle_by_constructor__(_ffi_api.While, condition, body, span)


@tvm._ffi.register_object("tir.Store")
class Store(Stmt):
"""Store node.
Expand Down
28 changes: 17 additions & 11 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def nms_inner_loop(ib, j):
offset_j = j * 4
num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx)

with ib.for_range(0, num_iter_per_thread) as _k:
with ib.for_range(0, num_iter_per_thread, name="_k") as _k:
k = j + 1 + _k * nthread_tx + tx
offset_k = k * 4

Expand Down Expand Up @@ -555,16 +555,22 @@ def nms_inner_loop(ib, j):

with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
# Apply nms
with ib.for_range(0, nkeep) as j:
# Proceed to the inner loop if the box j is still valid
with ib.if_scope(out_scores[i, j] > -1.0):
with ib.if_scope(max_output_size > 0):
# No need to do more iteration if we have already reached max_output_size
# boxes
# TODO(masahi): Add TIR while loop to realize early exit from the outer loop
with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
nms_inner_loop(ib, j)
with ib.else_scope():
with ib.if_scope(max_output_size > 0):
# No need to do more iteration if we have already reached max_output_size boxes
box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local")
box_idx[0] = 0
with ib.while_loop(
tvm.tir.all(box_idx[0] < nkeep, num_valid_boxes_local[0] < max_output_size)
):
# Proceed to the inner loop if the box with id box_idx is still valid
with ib.if_scope(out_scores[i, box_idx[0]] > -1.0):
nms_inner_loop(ib, box_idx[0])
box_idx[0] += 1

with ib.else_scope():
with ib.for_range(0, nkeep, name="j") as j:
# Proceed to the inner loop if the box j is still valid
with ib.if_scope(out_scores[i, j] > -1.0):
nms_inner_loop(ib, j)

with ib.if_scope(tx + 0 == 0):
Expand Down
1 change: 1 addition & 0 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const SeqStmtNode* op) override;
Doc VisitStmt_(const EvaluateNode* op) override;
Doc VisitStmt_(const ForNode* op) override;
Doc VisitStmt_(const WhileNode* op) override;
Doc VisitStmt_(const PrefetchNode* op) override;
Doc VisitStmtDefault_(const Object* op) override;

Expand Down
7 changes: 7 additions & 0 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,13 @@ Doc TIRTextPrinter::VisitStmt_(const ForNode* op) {
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const WhileNode* op) {
  // Render the header `while (<condition>)` first, then append the printed body.
  Doc header;
  header << "while (" << Print(op->condition) << ")";
  Doc loop_body = PrintBody(op->body);
  header << loop_body;
  return header;
}

Doc TIRTextPrinter::VisitStmt_(const PrefetchNode* op) {
Doc doc;
doc << "prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")";
Expand Down
14 changes: 14 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,20 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) {
llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body);
}

void CodeGenLLVM::VisitStmt_(const WhileNode* op) {
using llvm::BasicBlock;
BasicBlock* while_cond = BasicBlock::Create(*ctx_, "while_cond", function_);
BasicBlock* while_body = BasicBlock::Create(*ctx_, "while_body", function_);
BasicBlock* while_merge = BasicBlock::Create(*ctx_, "while_merge", function_);
builder_->CreateBr(while_cond);
builder_->SetInsertPoint(while_cond);
builder_->CreateCondBr(MakeValue(op->condition), while_body, while_merge);
builder_->SetInsertPoint(while_body);
this->VisitStmt(op->body);
builder_->CreateBr(while_cond);
builder_->SetInsertPoint(while_merge);
}

void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
using llvm::BasicBlock;
llvm::Value* cond = MakeValue(op->condition);
Expand Down
1 change: 1 addition & 0 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
// stmt
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const WhileNode* op) override;
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const AttrStmtNode* op) override;
Expand Down
11 changes: 10 additions & 1 deletion src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,6 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
ICHECK(is_one(op->predicate)) << "Predicated store is not supported";
arith::PVar<PrimExpr> base;


if (arith::ramp(base, 1, t.lanes()).Match(op->index)) {
std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value);
Expand Down Expand Up @@ -899,6 +898,16 @@ void CodeGenC::VisitStmt_(const ForNode* op) {
stream << "}\n";
}

// Emit a C loop for a TIR While node.
//
// The loop is written as `while (1) { if (!(cond)) { break; } body }` so that
// the condition is rendered inside the loop scope. The previous form rendered
// the condition in the middle of a `stream << ...` chain, which has two
// problems:
//  - anything PrintExpr writes to `stream` while rendering the condition lands
//    before or inside the `while (` header, so it is not re-evaluated on each
//    iteration (NOTE(review): PrintExpr is expected to emit helper statements
//    such as SSA temporaries in some configurations - confirm against
//    CodeGenC::PrintExpr);
//  - before C++17 the relative order of the PrintExpr call and the write of
//    "while (" is unspecified.
// Rendering the condition to a string first matches the IfThenElseNode visitor.
void CodeGenC::VisitStmt_(const WhileNode* op) {
  PrintIndent();
  stream << "while (1) {\n";
  int while_scope = BeginScope();
  // Evaluate the condition at the top of every iteration, inside the scope.
  std::string cond = PrintExpr(op->condition);
  PrintIndent();
  stream << "if (!(" << cond << ")) { break; }\n";
  PrintStmt(op->body);
  this->EndScope(while_scope);
  PrintIndent();
  stream << "}\n";
}

void CodeGenC::VisitStmt_(const IfThenElseNode* op) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
Expand Down
1 change: 1 addition & 0 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const WhileNode* op) override;
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const AttrStmtNode* op) override;
Expand Down
32 changes: 32 additions & 0 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,38 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "}\n";
});

// While
While::While(PrimExpr condition, Stmt body, Span span) {
  // Argument validation: the condition must be a defined scalar expression that
  // is not an integer immediate, and the body must be defined.
  ICHECK(condition.defined());
  ICHECK(condition.dtype().is_scalar());
  ICHECK(condition.as<tir::IntImmNode>() == nullptr) << "The condition should not be trivial.";
  ICHECK(body.defined());

  // Build the node and take ownership of the arguments.
  auto n = make_object<WhileNode>();
  n->condition = std::move(condition);
  n->body = std::move(body);
  n->span = std::move(span);
  data_ = std::move(n);
}

// Register the While constructor as the global function "tir.While".
TVM_REGISTER_GLOBAL("tir.While").set_body_typed([](PrimExpr condition, Stmt body, Span span) {
  return While(condition, body, span);
});

// Register the WhileNode type.
TVM_REGISTER_NODE_TYPE(WhileNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<WhileNode>([](const ObjectRef& node, ReprPrinter* p) {
      // Prints `while(<condition>){`, then the body at a deeper indent, then `}`.
      const auto* loop = static_cast<const WhileNode*>(node.get());

      p->PrintIndent();
      p->stream << "while(" << loop->condition << "){\n";

      // The body is printed two columns deeper than the loop header.
      p->indent += 2;
      p->Print(loop->body);
      p->indent -= 2;

      p->PrintIndent();
      p->stream << "}\n";
    });

// Store
Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, Span span) {
ICHECK(value.defined());
Expand Down
18 changes: 18 additions & 0 deletions src/tir/ir/stmt_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ void StmtVisitor::VisitStmt_(const ForNode* op) {
this->VisitStmt(op->body);
}

void StmtVisitor::VisitStmt_(const WhileNode* op) {
this->VisitExpr(op->condition);
this->VisitStmt(op->body);
}

void StmtVisitor::VisitStmt_(const AllocateNode* op) {
VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); });
this->VisitStmt(op->body);
Expand Down Expand Up @@ -283,6 +288,19 @@ Stmt StmtMutator::VisitStmt_(const ForNode* op) {
}
}

Stmt StmtMutator::VisitStmt_(const WhileNode* op) {
  // Mutate the condition first, then the body.
  PrimExpr new_condition = this->VisitExpr(op->condition);
  Stmt new_body = this->VisitStmt(op->body);

  // Nothing changed: hand back a reference to the original node.
  const bool unchanged = new_condition.same_as(op->condition) && new_body.same_as(op->body);
  if (unchanged) {
    return GetRef<Stmt>(op);
  }

  // Otherwise store the mutated fields on the node obtained from CopyOnWrite.
  auto node = CopyOnWrite(op);
  node->condition = std::move(new_condition);
  node->body = std::move(new_body);
  return Stmt(node);
}

Stmt StmtMutator::VisitStmt_(const AllocateNode* op) {
Array<PrimExpr> extents = Internal::Mutate(this, op->extents);
Stmt body = this->VisitStmt(op->body);
Expand Down
5 changes: 5 additions & 0 deletions src/tir/transforms/coproc_sync.cc
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,11 @@ class CoProcInstDepDetector : public StmtVisitor {
}
}

// While loops are rejected: no handling is implemented here, so visiting one
// reports a fatal error instead of traversing the loop.
void VisitStmt_(const WhileNode* op) final {
  // TODO(masahi): Do we need a special handling for While nodes?
  LOG(FATAL) << "WhileNode not supported in CoProcSync.";
}

// insert before is stored in reverse order
// the first element is closest to the node.
std::unordered_map<const Object*, std::vector<Stmt> > insert_before_;
Expand Down
7 changes: 7 additions & 0 deletions src/tir/transforms/inject_virtual_thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,13 @@ class VTInjector : public StmtExprMutator {
}
}

// While
// While loops are rejected: no rewriting is implemented for them here, so
// visiting one reports a fatal error.
Stmt VisitStmt_(const WhileNode* op) final {
  // TODO(masahi): What should we do for While nodes?
  LOG(FATAL) << "WhileNode in InjectVirtualThread not supported yet";
  // Placeholder to satisfy the return type; LOG(FATAL) is not expected to
  // return normally.
  return Stmt();
}

// Seq
Stmt VisitStmt_(const SeqStmtNode* op) final {
ICHECK_EQ(max_loop_depth_, 0);
Expand Down
6 changes: 6 additions & 0 deletions src/tir/transforms/lift_attr_scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,12 @@ class AttrScopeLifter : public StmtMutator {
}
}

// While loops are rejected: no handling is implemented for them here, so
// visiting one reports a fatal error.
Stmt VisitStmt_(const WhileNode* op) final {
  // TODO(masahi): Do we need a special handling for While nodes?
  LOG(FATAL) << "WhileNode not supported in LiftAttrScope.";
  // Placeholder to satisfy the return type; LOG(FATAL) is not expected to
  // return normally.
  return Stmt();
}

private:
// value comparison that also compares content of int constant
static bool ValueSame(const PrimExpr& a, const PrimExpr& b) {
Expand Down
Loading

0 comments on commit cf36aa6

Please sign in to comment.