Skip to content

Commit

Permalink
[REFACTOR] Remove Block
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jan 6, 2020
1 parent 99936b2 commit 4798284
Show file tree
Hide file tree
Showing 23 changed files with 1 addition and 283 deletions.
22 changes: 0 additions & 22 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1120,28 +1120,6 @@ class SeqStmt : public Stmt {
TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode);
};

/*!
* \brief A sequence of statements.
*/
class Block : public StmtNode {
public:
/*! \brief The first statement. */
Stmt first;
/*! \brief The restof statments. */
Stmt rest;

void VisitAttrs(AttrVisitor* v) {
v->Visit("first", &first);
v->Visit("rest", &rest);
}

TVM_DLL static Stmt make(Stmt first, Stmt rest);
TVM_DLL static Stmt make(const std::vector<Stmt> &stmts);

static constexpr const char* _type_key = "Block";
TVM_DECLARE_FINAL_OBJECT_INFO(Block, StmtNode);
};

/*!
* \brief IfThenElse statment.
*/
Expand Down
4 changes: 0 additions & 4 deletions include/tvm/ir_functor_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const Object* op, Args ...) {
Expand All @@ -277,7 +276,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(Provide);
IR_STMT_FUNCTOR_DISPATCH(Realize);
IR_STMT_FUNCTOR_DISPATCH(Prefetch);
IR_STMT_FUNCTOR_DISPATCH(Block);
IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
IR_STMT_FUNCTOR_DISPATCH(Evaluate);
return vtable;
Expand Down Expand Up @@ -410,7 +408,6 @@ class TVM_DLL StmtVisitor :
void VisitStmt_(const Provide* op) override;
void VisitStmt_(const Realize* op) override;
void VisitStmt_(const Prefetch* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const Evaluate* op) override;
};
Expand Down Expand Up @@ -505,7 +502,6 @@ class TVM_DLL StmtMutator :
Stmt VisitStmt_(const Provide* op) override;
Stmt VisitStmt_(const Realize* op) override;
Stmt VisitStmt_(const Prefetch* op) override;
Stmt VisitStmt_(const Block* op) override;
Stmt VisitStmt_(const SeqStmtNode* op) override;
Stmt VisitStmt_(const Evaluate* op) override;
/*!
Expand Down
20 changes: 0 additions & 20 deletions python/tvm/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,23 +288,6 @@ def __init__(self,
bounds, condition, body)


@register_node
class Block(Stmt):
"""Block node.
Parameters
----------
first : Stmt
The first statement.
rest : Stmt
The following statement.
"""
def __init__(self, first, rest):
self.__init_handle_by_constructor__(
_make.Block, first, rest)


@register_node
class SeqStmt(Stmt):
"""Sequence of statements.
Expand Down Expand Up @@ -422,12 +405,9 @@ def stmt_list(stmt):
for x in stmt:
res += stmt_list(x)
return res
elif isinstance(stmt, Block):
return stmt_list(stmt.first) + stmt_list(stmt.rest)
if isinstance(stmt, ProducerConsumer):
return stmt_list(stmt.body)
return [stmt]


_make.stmt_list = stmt_list
_make.stmt_seq = stmt_seq
3 changes: 0 additions & 3 deletions src/api/api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,6 @@ REGISTER_MAKE(IfThenElse);
REGISTER_MAKE(Evaluate);

// overloaded, needs special handling
TVM_REGISTER_GLOBAL("make.Block")
.set_body_typed(static_cast<Stmt (*)(Stmt, Stmt)>(Block::make));

// has default args
TVM_REGISTER_GLOBAL("make.Allocate")
.set_body_typed<Stmt(VarExpr, DataType, Array<Expr>, Expr, Stmt)>([](
Expand Down
5 changes: 0 additions & 5 deletions src/codegen/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -875,11 +875,6 @@ void CodeGenC::VisitStmt_(const IfThenElse* op) {
stream << "}\n";
}

void CodeGenC::VisitStmt_(const Block* op) {
PrintStmt(op->first);
if (op->rest.defined()) PrintStmt(op->rest);
}

void CodeGenC::VisitStmt_(const SeqStmtNode* op) {
for (Stmt stmt : op->seq) {
PrintStmt(stmt);
Expand Down
1 change: 0 additions & 1 deletion src/codegen/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ class CodeGenC :
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
/*!
Expand Down
7 changes: 0 additions & 7 deletions src/codegen/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1214,13 +1214,6 @@ void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
this->VisitStmt(op->body);
}

void CodeGenLLVM::VisitStmt_(const Block* op) {
this->VisitStmt(op->first);
if (op->rest.defined()) {
this->VisitStmt(op->rest);
}
}

void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) {
for (Stmt stmt : op->seq) {
this->VisitStmt(stmt);
Expand Down
1 change: 0 additions & 1 deletion src/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ class CodeGenLLVM :
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
Expand Down
7 changes: 0 additions & 7 deletions src/codegen/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -638,13 +638,6 @@ void CodeGenSPIRV::VisitStmt_(const LetStmt* op) {
this->VisitStmt(op->body);
}

void CodeGenSPIRV::VisitStmt_(const Block* op) {
VisitStmt(op->first);
if (op->rest.defined()) {
this->VisitStmt(op->rest);
}
}

void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) {
for (Stmt stmt : op->seq) {
PrintStmt(stmt);
Expand Down
1 change: 0 additions & 1 deletion src/codegen/spirv/codegen_spirv.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ class CodeGenSPIRV:
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const LetStmt* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
Expand Down
5 changes: 0 additions & 5 deletions src/codegen/stackvm/codegen_stackvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,11 +417,6 @@ void CodeGenStackVM::VisitStmt_(const For* op) {
this->SetOperand(backward_jump, loop_head - label_bjump);
}

void CodeGenStackVM::VisitStmt_(const Block* op) {
this->Push(op->first);
if (op->rest.defined()) this->Push(op->rest);
}

void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) {
for (Stmt stmt : op->seq) {
this->Push(stmt);
Expand Down
1 change: 0 additions & 1 deletion src/codegen/stackvm/codegen_stackvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ class CodeGenStackVM
void VisitStmt_(const AttrStmt* op) final;
void VisitStmt_(const AssertStmt* op) final;
void VisitStmt_(const Evaluate* op) final;
void VisitStmt_(const Block* op) final;
void VisitStmt_(const SeqStmtNode* op) final;
void VisitStmt_(const ProducerConsumer* op) final;

Expand Down
5 changes: 0 additions & 5 deletions src/contrib/hybrid/codegen_hybrid.cc
Original file line number Diff line number Diff line change
Expand Up @@ -389,11 +389,6 @@ void CodeGenHybrid::VisitStmt_(const IfThenElse* op) {
}
}

void CodeGenHybrid::VisitStmt_(const Block* op) {
PrintStmt(op->first);
if (op->rest.defined()) PrintStmt(op->rest);
}

void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) {
for (Stmt stmt : op->seq) {
PrintStmt(stmt);
Expand Down
1 change: 0 additions & 1 deletion src/contrib/hybrid/codegen_hybrid.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ class CodeGenHybrid :
void VisitStmt_(const AttrStmt* op) override;
void VisitStmt_(const AssertStmt* op) override;
void VisitStmt_(const Evaluate* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const ProducerConsumer* op) override;
/*!
Expand Down
35 changes: 0 additions & 35 deletions src/lang/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,33 +510,6 @@ SeqStmt::SeqStmt(Array<Stmt> seq) {
data_ = std::move(node);
}

Stmt Block::make(Stmt first, Stmt rest) {
CHECK(first.defined());
CHECK(rest.defined());
ObjectPtr<Block> node = make_object<Block>();

// canonicalize.
if (const Block* b = first.as<Block>()) {
node->first = b->first;
node->rest = Block::make(b->rest, rest);
} else {
node->first = std::move(first);
node->rest = std::move(rest);
}
return Stmt(node);
}

Stmt Block::make(const std::vector<Stmt>& stmts) {
if (stmts.empty()) {
return Stmt();
}
Stmt result = stmts.back();
for (size_t i = stmts.size() - 1; i != 0; --i) {
result = Block::make(stmts[i - 1], result);
}
return result;
}

Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) {
CHECK(condition.defined());
CHECK(then_case.defined());
Expand Down Expand Up @@ -1037,13 +1010,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
}
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<Block>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const Block*>(node.get());
p->Print(op->first);
if (op->rest.defined()) p->Print(op->rest);
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SeqStmtNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const SeqStmtNode*>(node.get());
Expand Down Expand Up @@ -1226,7 +1192,6 @@ TVM_REGISTER_NODE_TYPE(Provide);
TVM_REGISTER_NODE_TYPE(Allocate);
TVM_REGISTER_NODE_TYPE(Free);
TVM_REGISTER_NODE_TYPE(Realize);
TVM_REGISTER_NODE_TYPE(Block);
TVM_REGISTER_NODE_TYPE(SeqStmtNode);
TVM_REGISTER_NODE_TYPE(IfThenElse);
TVM_REGISTER_NODE_TYPE(Evaluate);
Expand Down
16 changes: 1 addition & 15 deletions src/pass/inject_virtual_thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -356,21 +356,7 @@ class VTInjector : public StmtExprMutator {
return IfThenElse::make(condition, then_case, else_case);
}
}
// Block
Stmt VisitStmt_(const Block* op) final {
CHECK_EQ(max_loop_depth_, 0);
Stmt first = this->VisitStmt(op->first);
int temp = max_loop_depth_;
max_loop_depth_ = 0;
Stmt rest = this->VisitStmt(op->rest);
max_loop_depth_ = std::max(max_loop_depth_, temp);
if (first.same_as(op->first) &&
rest.same_as(op->rest)) {
return GetRef<Stmt>(op);
} else {
return Block::make(first, rest);
}
}

// Seq
Stmt VisitStmt_(const SeqStmtNode* op) final {
CHECK_EQ(max_loop_depth_, 0);
Expand Down
7 changes: 0 additions & 7 deletions src/pass/ir_deep_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,6 @@ class IRDeepCompare :
if (CompareRegion(op->bounds, rhs->bounds) != 0) return;
}

void VisitStmt_(const Block* op, const Stmt& other) final {
const Block* rhs = other.as<Block>();
if (CompareStmt(op->first, rhs->first) != 0) return;
if (CompareStmt(op->rest, rhs->rest) != 0) return;
}


void VisitStmt_(const SeqStmtNode* op, const Stmt& other) final {
const SeqStmtNode* rhs = other.as<SeqStmtNode>();
if (CompareValue(op->size(), rhs->size()) != 0) return;
Expand Down
19 changes: 0 additions & 19 deletions src/pass/ir_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,6 @@ void StmtVisitor::VisitStmt_(const Prefetch* op) {
});
}

void StmtVisitor::VisitStmt_(const Block* op) {
this->VisitStmt(op->first);
this->VisitStmt(op->rest);
}

void StmtVisitor::VisitStmt_(const SeqStmtNode* op) {
VisitArray(op->seq, [this](const Stmt& s) {
this->VisitStmt(s);
Expand Down Expand Up @@ -486,20 +481,6 @@ Stmt StmtMutator::VisitStmt_(const Prefetch* op) {
}
}

Stmt StmtMutator::VisitStmt_(const Block* op) {
Stmt first = this->VisitStmt(op->first);
Stmt rest = this->VisitStmt(op->rest);
if (first.same_as(op->first) &&
rest.same_as(op->rest)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->first = std::move(first);
n->rest = std::move(rest);
return Stmt(n);
}
}

Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) {
Array<Stmt> seq = Internal::Mutate(this, op->seq);
if (seq.same_as(op->seq)) {
Expand Down
5 changes: 0 additions & 5 deletions src/pass/ir_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,6 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
CHECK(!n->else_case.defined());
n->then_case = body;
body = Stmt(n);
} else if (const auto* block = s.as<Block>()) {
auto n = make_object<Block>(*block);
CHECK(is_no_op(n->rest));
n->rest = body;
body = Stmt(n);
} else if (const auto* seq = s.as<SeqStmtNode>()) {
auto n = make_object<SeqStmtNode>(*seq);
CHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1]));
Expand Down
Loading

0 comments on commit 4798284

Please sign in to comment.