From 92c7ff67a9ce26044eaf7d7822f7af1601d83abd Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 30 Dec 2019 14:41:35 -0800 Subject: [PATCH 1/4] [REFACTOR][IR] Introduce SeqStmt to replace Block ir::Block was used to represent a sequence of Stmts in the original low-level IR. The nested ir::Block structure is not really friendly for recursive visits, especially when the statements are unrolled. This PR introduce a SeqStmt that directly stores a sequence of statements in an Array container. The new SeqStmt will be used as a replacement of the original Block structure. --- include/tvm/expr_operator.h | 3 + include/tvm/ir.h | 98 +++++++++++++++++++ include/tvm/ir_functor_ext.h | 19 ++++ include/tvm/node/container.h | 8 ++ python/tvm/ir_builder.py | 10 +- python/tvm/stmt.py | 32 +++++- src/api/api_ir.cc | 6 ++ src/codegen/codegen_c.cc | 62 ++++++------ src/codegen/codegen_c.h | 1 + src/codegen/llvm/codegen_llvm.cc | 6 ++ src/codegen/llvm/codegen_llvm.h | 1 + src/codegen/spirv/codegen_spirv.cc | 6 ++ src/codegen/spirv/codegen_spirv.h | 1 + src/codegen/stackvm/codegen_stackvm.cc | 70 +++++++------ src/codegen/stackvm/codegen_stackvm.h | 5 +- src/contrib/hybrid/codegen_hybrid.cc | 68 +++++++------ src/contrib/hybrid/codegen_hybrid.h | 1 + src/lang/ir.cc | 15 +++ src/pass/inject_virtual_thread.cc | 12 +++ src/pass/ir_deep_compare.cc | 9 ++ src/pass/ir_functor.cc | 63 ++++++++++++ src/pass/lift_attr_scope.cc | 4 + src/pass/loop_partition.cc | 12 +++ src/pass/remove_no_op.cc | 32 ++++++ src/pass/unroll_loop.cc | 17 ++++ tests/cpp/ir_functor_test.cc | 32 ++++++ tests/python/unittest/test_ir_builder.py | 4 +- tests/python/unittest/test_pass_equal.py | 1 + .../unittest/test_pass_inject_vthread.py | 4 +- .../python/unittest/test_pass_remove_no_op.py | 6 +- tests/python/unittest/test_pass_unroll.py | 12 +-- 31 files changed, 510 insertions(+), 110 deletions(-) diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index 41e7aa5b7796..a73edb428cba 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -667,6 +667,9 @@ inline bool is_no_op(const Stmt& stmt) { if (const auto* op = stmt.as()) { return is_const(op->value); } + if (const auto* op = stmt.as()) { + return op->seq.size() == 0; + } return false; } diff --git a/include/tvm/ir.h b/include/tvm/ir.h index c55a4695de4d..02c5aa439d7d 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -1021,6 +1021,104 @@ class Realize : public StmtNode { TVM_DECLARE_FINAL_OBJECT_INFO(Realize, StmtNode); }; +/*! + * \brief The container of seq statement. + * Represent a sequence of statements. + */ +class SeqStmtNode : public StmtNode { + public: + /*! \brief internal sequence content. */ + Array seq; + + /*! \return get the size of the sequence */ + size_t size() const { + return seq.size(); + } + /*! + * \brief Get the index-th element in the sequence. + */ + Stmt operator[](size_t index) const { + return seq[index]; + } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("seq", &seq); + } + + static constexpr const char* _type_key = "SeqStmt"; + TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode); +}; + +/*! \brief Sequence statement. */ +class SeqStmt : public Stmt { + public: + /*! + * \brief Construct SeqStmt. + * \param seq The sequence. + */ + TVM_DLL explicit SeqStmt(Array seq); + + /*! \return get the size of the sequence */ + size_t size() const { + return operator->()->size(); + } + /*! + * \brief Get the index-th element in the sequence. + */ + Stmt operator[](size_t index) const { + return (*(operator->()))[index]; + } + /*! + * \brief Construct a flattened sequence statement. + * + * \note This function can return the element if there + * is only one element in the sequence. + * \param seq_args The list of arguments to be flattened. + * \tparam Args arguments + * \return The constructed statement + */ + template + static Stmt Flatten(Args&&... seq_args) { + Array seq; + runtime::detail::for_each( + Flattener(&seq), std::forward(seq_args)...); + if (seq.size() == 1) return seq[0]; + return SeqStmt(seq); + } + /*! \brief Helper class to flatten sequence of arguments into Array. */ + class Flattener { + public: + explicit Flattener(Array* seq) + : seq_(seq) {} + + void operator()(size_t i, const Stmt& stmt) const { + if (auto* op = stmt.as()) { + operator()(0, op->seq); + } else if (auto* op = stmt.as()) { + if (!op->is_producer) { + operator()(0, op->body); + } else { + seq_->push_back(stmt); + } + } else { + seq_->push_back(stmt); + } + } + + template + void operator()(size_t i, const T& seq) const { + for (auto v : seq) { + this->operator()(0, v); + } + } + + private: + Array* seq_; + }; + + TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode); +}; + /*! * \brief A sequence of statements. */ diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index 4d821c2c4236..5462e8799f61 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -254,6 +254,7 @@ class StmtFunctor { virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmtDefault_(const Object* op, Args ...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); @@ -277,6 +278,7 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(Realize); IR_STMT_FUNCTOR_DISPATCH(Prefetch); IR_STMT_FUNCTOR_DISPATCH(Block); + IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode); IR_STMT_FUNCTOR_DISPATCH(Evaluate); return vtable; } @@ -409,6 +411,7 @@ class TVM_DLL StmtVisitor : void VisitStmt_(const Realize* op) override; void VisitStmt_(const Prefetch* op) override; void VisitStmt_(const Block* op) override; + void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const Evaluate* op) override; }; @@ -503,7 +506,23 @@ class TVM_DLL StmtMutator : Stmt VisitStmt_(const Realize* op) override; Stmt VisitStmt_(const Prefetch* op) override; Stmt VisitStmt_(const Block* op) override; + Stmt VisitStmt_(const SeqStmtNode* op) override; Stmt VisitStmt_(const Evaluate* op) override; + /*! + * \brief Alternative advance method for SeqStmtNode. + * + * This function can be called when a child class override + * VisitStmt_(const SeqStmtNode*) to introduce + * the special behavior to visit + * + * \param op The sequence. + * \param flatten_before_visit Whether to flatten the sequence before visit. + * \param fmutate The mutate function, can be nullptr, which defaults to Visit. + * \return The mutated result. + */ + Stmt VisitSeqStmt_(const SeqStmtNode* op, + bool flatten_before_visit, + std::function fmutate = nullptr); // internal helper. class Internal; }; diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index 0d5ab376d50a..7686a96c19d3 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -271,6 +271,14 @@ class Array : public ObjectRef { ArrayNode* n = this->CopyOnWrite(); n->data.push_back(item); } + /*! + * \brief Resize the array. + * \param size The new size. + */ + inline void resize(size_t size) { + ArrayNode* n = this->CopyOnWrite(); + n->data.resize(size); + } /*! * \brief set i-th element of the array. * \param i The index diff --git a/python/tvm/ir_builder.py b/python/tvm/ir_builder.py index d7f41a2669bf..2a87871377b1 100644 --- a/python/tvm/ir_builder.py +++ b/python/tvm/ir_builder.py @@ -120,14 +120,16 @@ def _pop_seq(self): seq = self._seq_stack.pop() if not seq or callable(seq[-1]): seq.append(_make.Evaluate(0)) - stmt = seq[-1] + seqwrap = lambda x : x[0] if len(x) == 1 else _stmt.SeqStmt(list(reversed(x))) + ret_seq = [seq[-1]] + for s in reversed(seq[:-1]): if callable(s): - stmt = s(stmt) + ret_seq = [s(seqwrap(ret_seq))] else: assert isinstance(s, _stmt.Stmt) - stmt = _make.Block(s, stmt) - return stmt + ret_seq.append(s) + return seqwrap(ret_seq) def emit(self, stmt): """Emit a statement to the end of current scope. diff --git a/python/tvm/stmt.py b/python/tvm/stmt.py index fc7f6e2cb173..d16b71c54ba4 100644 --- a/python/tvm/stmt.py +++ b/python/tvm/stmt.py @@ -305,6 +305,26 @@ def __init__(self, first, rest): _make.Block, first, rest) +@register_node +class SeqStmt(Stmt): + """Sequence of statements. + + Parameters + ---------- + seq : List[Stmt] + The statements + """ + def __init__(self, seq): + self.__init_handle_by_constructor__( + _make.SeqStmt, seq) + + def __getitem__(self, i): + return self.seq[i] + + def __len__(self): + return len(self.seq) + + @register_node class IfThenElse(Stmt): """IfThenElse node. @@ -374,6 +394,9 @@ def stmt_seq(*args): ------- stmt : Stmt The combined statement. + """ + return SeqStmt(args) + """ ret = None for value in args: @@ -381,7 +404,7 @@ def stmt_seq(*args): value = Evaluate(value) ret = value if ret is None else Block(ret, value) return ret if ret else Evaluate(0) - + """ def stmt_list(stmt): """Make list of stmt from blocks. @@ -395,7 +418,12 @@ def stmt_list(stmt): stmt_list : list of Stmt The unpacked list of statements """ - if isinstance(stmt, Block): + if isinstance(stmt, SeqStmt): + res = [] + for x in stmt: + res += stmt_list(x) + return res + elif isinstance(stmt, Block): return stmt_list(stmt.first) + stmt_list(stmt.rest) if isinstance(stmt, ProducerConsumer): return stmt_list(stmt.body) diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 2987b9ef39c1..4fec5234f955 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -63,6 +63,12 @@ TVM_REGISTER_GLOBAL("make._cast") TVM_REGISTER_GLOBAL("make._range_by_min_extent") .set_body_typed(Range::make_by_min_extent); + +TVM_REGISTER_GLOBAL("make.SeqStmt") +.set_body_typed([](Array seq) { + return SeqStmt(std::move(seq)); +}); + TVM_REGISTER_GLOBAL("make.For") .set_body_typed([]( VarExpr loop_var, Expr min, Expr extent, diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 4b95e2caf1aa..2c24e07eb5d6 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -405,22 +405,22 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N } } -void CodeGenC::VisitExpr_(const IntImm *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const IntImm* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } -void CodeGenC::VisitExpr_(const UIntImm *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const UIntImm* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } -void CodeGenC::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const FloatImm* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } -void CodeGenC::VisitExpr_(const StringImm *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const StringImm* op, std::ostream& os) { // NOLINT(*) os << "\"" << op->value << "\""; } template inline void PrintBinaryExpr(const T* op, - const char *opstr, + const char* opstr, std::ostream& os, // NOLINT(*) CodeGenC* p) { if (op->dtype.lanes() == 1) { @@ -443,7 +443,7 @@ inline void PrintBinaryExpr(const T* op, } inline void PrintBinaryIntrinsic(const Call* op, - const char *opstr, + const char* opstr, std::ostream& os, // NOLINT(*) CodeGenC* p) { if (op->dtype.lanes() == 1) { @@ -457,65 +457,65 @@ inline void PrintBinaryIntrinsic(const Call* op, p->PrintVecBinaryOp(opstr, op->dtype, op->args[0], op->args[1], os); } } -void CodeGenC::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Cast* op, std::ostream& os) { // NOLINT(*) std::stringstream value; this->PrintExpr(op->value, value); os << CastFromTo(value.str(), op->value.dtype(), op->dtype); } -void CodeGenC::VisitExpr_(const Variable *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Variable* op, std::ostream& os) { // NOLINT(*) os << GetVarID(op); } -void CodeGenC::VisitExpr_(const Add *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Add* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "+", os, this); } -void CodeGenC::VisitExpr_(const Sub *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Sub* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "-", os, this); } -void CodeGenC::VisitExpr_(const Mul *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Mul* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "*", os, this); } -void CodeGenC::VisitExpr_(const Div *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Div* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "/", os, this); } -void CodeGenC::VisitExpr_(const Mod *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Mod* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "%", os, this); } -void CodeGenC::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Min* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "min", os, this); } -void CodeGenC::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Max* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "max", os, this); } -void CodeGenC::VisitExpr_(const EQ *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const EQ* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "==", os, this); } -void CodeGenC::VisitExpr_(const NE *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const NE* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "!=", os, this); } -void CodeGenC::VisitExpr_(const LT *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const LT* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "<", os, this); } -void CodeGenC::VisitExpr_(const LE *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const LE* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "<=", os, this); } -void CodeGenC::VisitExpr_(const GT *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const GT* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, ">", os, this); } -void CodeGenC::VisitExpr_(const GE *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const GE* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, ">=", os, this); } -void CodeGenC::VisitExpr_(const And *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const And* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "&&", os, this); } -void CodeGenC::VisitExpr_(const Or *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Or* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "||", os, this); } -void CodeGenC::VisitExpr_(const Not *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Not* op, std::ostream& os) { // NOLINT(*) os << '!'; PrintExpr(op->a, os); } -void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Call* op, std::ostream& os) { // NOLINT(*) if (op->call_type == Call::Extern || op->call_type == Call::PureExtern) { os << op->name << "("; @@ -875,12 +875,18 @@ void CodeGenC::VisitStmt_(const IfThenElse* op) { stream << "}\n"; } -void CodeGenC::VisitStmt_(const Block *op) { +void CodeGenC::VisitStmt_(const Block* op) { PrintStmt(op->first); if (op->rest.defined()) PrintStmt(op->rest); } -void CodeGenC::VisitStmt_(const Evaluate *op) { +void CodeGenC::VisitStmt_(const SeqStmtNode* op) { + for (Stmt stmt : op->seq) { + PrintStmt(stmt); + } +} + +void CodeGenC::VisitStmt_(const Evaluate* op) { if (is_const(op->value)) return; const Call* call = op->value.as(); if (call) { @@ -906,7 +912,7 @@ void CodeGenC::VisitStmt_(const Evaluate *op) { } } -void CodeGenC::VisitStmt_(const ProducerConsumer *op) { +void CodeGenC::VisitStmt_(const ProducerConsumer* op) { PrintStmt(op->body); } diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index b8d357051998..1773bcbfd88f 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -141,6 +141,7 @@ class CodeGenC : void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const Block* op) override; + void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const ProducerConsumer* op) override; /*! * Print Type represetnation of type t. diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index 94ad8b76c9c9..9461eee0ee8c 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -1221,6 +1221,12 @@ void CodeGenLLVM::VisitStmt_(const Block* op) { } } +void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) { + for (Stmt stmt : op->seq) { + this->VisitStmt(stmt); + } +} + void CodeGenLLVM::VisitStmt_(const Evaluate* op) { MakeValue(op->value); } diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index 08c836adf9d0..56a710392497 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -141,6 +141,7 @@ class CodeGenLLVM : void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const LetStmt* op) override; void VisitStmt_(const Block* op) override; + void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const ProducerConsumer* op) override; diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index 7800e47319e0..79363fcbbb40 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -645,6 +645,12 @@ void CodeGenSPIRV::VisitStmt_(const Block* op) { } } +void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) { + for (Stmt stmt : op->seq) { + this->VisitStmt(stmt); + } +} + void CodeGenSPIRV::VisitStmt_(const Evaluate* op) { MakeValue(op->value); } diff --git a/src/codegen/spirv/codegen_spirv.h b/src/codegen/spirv/codegen_spirv.h index 3d16377271c4..b2f3fc3ad99e 100644 --- a/src/codegen/spirv/codegen_spirv.h +++ b/src/codegen/spirv/codegen_spirv.h @@ -99,6 +99,7 @@ class CodeGenSPIRV: void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const LetStmt* op) override; void VisitStmt_(const Block* op) override; + void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const ProducerConsumer* op) override; diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index 9482b2cd649e..e0cda5d03f94 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -268,59 +268,59 @@ void CodeGenStackVM::PushCast(DataType dst, DataType src) { } } -void CodeGenStackVM::VisitExpr_(const StringImm *op) { +void CodeGenStackVM::VisitExpr_(const StringImm* op) { int sid = this->GetStrID(op->value); this->PushOp(StackVM::PUSH_I64, sid); } -void CodeGenStackVM::VisitExpr_(const IntImm *op) { +void CodeGenStackVM::VisitExpr_(const IntImm* op) { CHECK(op->value >= std::numeric_limits::min() && op->value <= std::numeric_limits::max()) << "Int constant exceed bound"; this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); } -void CodeGenStackVM::VisitExpr_(const UIntImm *op) { +void CodeGenStackVM::VisitExpr_(const UIntImm* op) { CHECK(op->value <= std::numeric_limits::max()) << "Int constant exceed bound"; this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); } -void CodeGenStackVM::VisitExpr_(const FloatImm *op) { +void CodeGenStackVM::VisitExpr_(const FloatImm* op) { LOG(FATAL) << "Float Imm is not supported"; } -void CodeGenStackVM::VisitExpr_(const Variable *op) { +void CodeGenStackVM::VisitExpr_(const Variable* op) { int vid = this->GetVarID(op); this->PushOp(StackVM::LOAD_HEAP, vid); } -void CodeGenStackVM::VisitExpr_(const Cast *op) { +void CodeGenStackVM::VisitExpr_(const Cast* op) { this->Push(op->value); PushCast(op->dtype, op->value.dtype()); } -void CodeGenStackVM::VisitExpr_(const Add *op) { +void CodeGenStackVM::VisitExpr_(const Add* op) { PushBinary(StackVM::ADD_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const Sub *op) { +void CodeGenStackVM::VisitExpr_(const Sub* op) { PushBinary(StackVM::SUB_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const Mul *op) { +void CodeGenStackVM::VisitExpr_(const Mul* op) { PushBinary(StackVM::MUL_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const Div *op) { +void CodeGenStackVM::VisitExpr_(const Div* op) { PushBinary(StackVM::DIV_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const Mod *op) { +void CodeGenStackVM::VisitExpr_(const Mod* op) { PushBinary(StackVM::MOD_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const Min *op) { +void CodeGenStackVM::VisitExpr_(const Min* op) { this->Push(op->a); this->Push(op->b); this->PushOp(StackVM::PUSH_VALUE, -1); @@ -329,7 +329,7 @@ void CodeGenStackVM::VisitExpr_(const Min *op) { this->PushOp(StackVM::SELECT); } -void CodeGenStackVM::VisitExpr_(const Max *op) { +void CodeGenStackVM::VisitExpr_(const Max* op) { this->Push(op->a); this->Push(op->b); this->PushOp(StackVM::PUSH_VALUE, 0); @@ -338,34 +338,34 @@ void CodeGenStackVM::VisitExpr_(const Max *op) { this->PushOp(StackVM::SELECT); } -void CodeGenStackVM::VisitExpr_(const EQ *op) { +void CodeGenStackVM::VisitExpr_(const EQ* op) { PushBinary(StackVM::EQ_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const LE *op) { +void CodeGenStackVM::VisitExpr_(const LE* op) { PushBinary(StackVM::LE_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const NE *op) { +void CodeGenStackVM::VisitExpr_(const NE* op) { PushBinary(StackVM::EQ_I64, op->a, op->b); this->PushOp(StackVM::NOT); } -void CodeGenStackVM::VisitExpr_(const LT *op) { +void CodeGenStackVM::VisitExpr_(const LT* op) { PushBinary(StackVM::LT_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const GE *op) { +void CodeGenStackVM::VisitExpr_(const GE* op) { PushBinary(StackVM::LT_I64, op->a, op->b); this->PushOp(StackVM::NOT); } -void CodeGenStackVM::VisitExpr_(const GT *op) { +void CodeGenStackVM::VisitExpr_(const GT* op) { PushBinary(StackVM::LE_I64, op->a, op->b); this->PushOp(StackVM::NOT); } -void CodeGenStackVM::VisitExpr_(const And *op) { +void CodeGenStackVM::VisitExpr_(const And* op) { this->Push(op->a); int64_t pc_jump = this->GetPC(); int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_FALSE, 0); @@ -375,7 +375,7 @@ void CodeGenStackVM::VisitExpr_(const And *op) { this->SetOperand(opr_index, diff); } -void CodeGenStackVM::VisitExpr_(const Or *op) { +void CodeGenStackVM::VisitExpr_(const Or* op) { this->Push(op->a); int64_t pc_jump = this->GetPC(); int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_TRUE, 0); @@ -389,11 +389,11 @@ void CodeGenStackVM::VisitExpr_(const Not* op) { this->PushOp(StackVM::NOT); } -void CodeGenStackVM::VisitStmt_(const ProducerConsumer *op) { +void CodeGenStackVM::VisitStmt_(const ProducerConsumer* op) { this->Push(op->body); } -void CodeGenStackVM::VisitStmt_(const For *op) { +void CodeGenStackVM::VisitStmt_(const For* op) { CHECK(is_zero(op->min)); int vid = this->AllocVarID(op->loop_var.get()); this->PushOp(StackVM::PUSH_I64, 0); @@ -417,11 +417,17 @@ void CodeGenStackVM::VisitStmt_(const For *op) { this->SetOperand(backward_jump, loop_head - label_bjump); } -void CodeGenStackVM::VisitStmt_(const Block *op) { +void CodeGenStackVM::VisitStmt_(const Block* op) { this->Push(op->first); if (op->rest.defined()) this->Push(op->rest); } +void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) { + for (Stmt stmt : op->seq) { + this->Push(stmt); + } +} + void CodeGenStackVM::VisitStmt_(const Evaluate *ev) { if (is_const(ev->value)) return; const Call* op = ev->value.as(); @@ -444,7 +450,7 @@ void CodeGenStackVM::VisitStmt_(const Evaluate *ev) { } } -void CodeGenStackVM::VisitStmt_(const IfThenElse *op) { +void CodeGenStackVM::VisitStmt_(const IfThenElse* op) { this->Push(op->condition); int64_t label_ejump = this->GetPC(); int64_t else_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0); @@ -466,29 +472,29 @@ void CodeGenStackVM::VisitStmt_(const IfThenElse *op) { } } -void CodeGenStackVM::VisitStmt_(const LetStmt *op) { +void CodeGenStackVM::VisitStmt_(const LetStmt* op) { this->Push(op->value); int64_t vid = this->AllocVarID(op->var.get()); this->PushOp(StackVM::STORE_HEAP, static_cast(vid)); this->Push(op->body); } -void CodeGenStackVM::VisitExpr_(const Ramp *op) { +void CodeGenStackVM::VisitExpr_(const Ramp* op) { LOG(FATAL) << "Ramp is not supported"; } -void CodeGenStackVM::VisitExpr_(const Broadcast *op) { +void CodeGenStackVM::VisitExpr_(const Broadcast* op) { LOG(FATAL) << "Broadcast is not supported"; } -void CodeGenStackVM::VisitExpr_(const Select *op) { +void CodeGenStackVM::VisitExpr_(const Select* op) { this->Push(op->true_value); this->Push(op->false_value); this->Push(op->condition); this->PushOp(StackVM::SELECT); } -void CodeGenStackVM::VisitStmt_(const AssertStmt *op) { +void CodeGenStackVM::VisitStmt_(const AssertStmt* op) { if (const auto* str = op->message.as()) { int sid = this->GetStrID(str->value); this->Push(op->condition); @@ -497,11 +503,11 @@ void CodeGenStackVM::VisitStmt_(const AssertStmt *op) { this->Push(op->body); } -void CodeGenStackVM::VisitStmt_(const AttrStmt *op) { +void CodeGenStackVM::VisitStmt_(const AttrStmt* op) { this->Push(op->body); } -void CodeGenStackVM::VisitExpr_(const Let *op) { +void CodeGenStackVM::VisitExpr_(const Let* op) { this->Push(op->value); int64_t vid = this->AllocVarID(op->var.get()); this->PushOp(StackVM::STORE_HEAP, static_cast(vid)); diff --git a/src/codegen/stackvm/codegen_stackvm.h b/src/codegen/stackvm/codegen_stackvm.h index dcae072c102d..63f354d4bffc 100644 --- a/src/codegen/stackvm/codegen_stackvm.h +++ b/src/codegen/stackvm/codegen_stackvm.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -149,6 +149,7 @@ class CodeGenStackVM void VisitStmt_(const AssertStmt* op) final; void VisitStmt_(const Evaluate* op) final; void VisitStmt_(const Block* op) final; + void VisitStmt_(const SeqStmtNode* op) final; void VisitStmt_(const ProducerConsumer* op) final; private: diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 301602fb8238..d88716bc343d 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -76,24 +76,24 @@ void CodeGenHybrid::PrintType(DataType t, std::ostream &os) { os << t.bits(); } -void CodeGenHybrid::VisitExpr_(const IntImm *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const IntImm* op, std::ostream& os) { // NOLINT(*) os << op->value; } -void CodeGenHybrid::VisitExpr_(const UIntImm *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const UIntImm* op, std::ostream& os) { // NOLINT(*) PrintType(op->dtype, os); os << "(" << op->value << ")"; } -void CodeGenHybrid::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const FloatImm* op, std::ostream& os) { // NOLINT(*) PrintType(op->dtype, os); os << "(" << std::setprecision(20) << op->value << ")"; } -void CodeGenHybrid::VisitExpr_(const StringImm *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const StringImm* op, std::ostream& os) { // NOLINT(*) os << "'" << op->value << "'"; } template inline void PrintBinaryExpr(const T* op, - const char *opstr, + const char* opstr, std::ostream& os, // NOLINT(*) CodeGenHybrid* p) { CHECK(op->dtype.lanes() == 1) << "vec bin op not implemented"; @@ -115,7 +115,7 @@ inline void PrintBinaryExpr(const T* op, } inline void PrintBinaryIntrinsitc(const Call* op, - const char *opstr, + const char* opstr, std::ostream& os, // NOLINT(*) CodeGenHybrid* p) { CHECK(op->dtype.lanes() == 1) << "vec bin intrin not implemented"; @@ -127,7 +127,7 @@ inline void PrintBinaryIntrinsitc(const Call* op, os << ')'; } -void CodeGenHybrid::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Cast* op, std::ostream& os) { // NOLINT(*) if (op->dtype == op->value.dtype()) { PrintExpr(op->value, stream); } else { @@ -138,76 +138,76 @@ void CodeGenHybrid::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*) } } -void CodeGenHybrid::VisitExpr_(const Variable *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Variable* op, std::ostream& os) { // NOLINT(*) os << GetVarID(op); } -void CodeGenHybrid::VisitExpr_(const Add *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Add* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "+", os, this); } -void CodeGenHybrid::VisitExpr_(const Sub *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Sub* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "-", os, this); } -void CodeGenHybrid::VisitExpr_(const Mul *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Mul* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "*", os, this); } -void CodeGenHybrid::VisitExpr_(const Div *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Div* op, std::ostream& os) { // NOLINT(*) if (op->dtype.is_int()) PrintBinaryExpr(op, "//", os, this); else PrintBinaryExpr(op, "/", os, this); } -void CodeGenHybrid::VisitExpr_(const FloorDiv *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const FloorDiv* op, std::ostream& os) { // NOLINT(*) if (op->dtype.is_int()) PrintBinaryExpr(op, "//", os, this); else PrintBinaryExpr(op, "/", os, this); } -void CodeGenHybrid::VisitExpr_(const Mod *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Mod* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "%", os, this); } -void CodeGenHybrid::VisitExpr_(const FloorMod *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const FloorMod* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "%", os, this); } -void CodeGenHybrid::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Min* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "min", os, this); } -void CodeGenHybrid::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Max* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "max", os, this); } -void CodeGenHybrid::VisitExpr_(const EQ *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const EQ* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "==", os, this); } -void CodeGenHybrid::VisitExpr_(const NE *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const NE* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "!=", os, this); } -void CodeGenHybrid::VisitExpr_(const LT *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const LT* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "<", os, this); } -void CodeGenHybrid::VisitExpr_(const LE *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const LE* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "<=", os, this); } -void CodeGenHybrid::VisitExpr_(const GT *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const GT* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, ">", os, this); } -void CodeGenHybrid::VisitExpr_(const GE *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const GE* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, ">=", os, this); } -void CodeGenHybrid::VisitExpr_(const And *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const And* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "&&", os, this); } -void CodeGenHybrid::VisitExpr_(const Or *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Or* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "||", os, this); } -void CodeGenHybrid::VisitExpr_(const Not *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Not* op, std::ostream& os) { // NOLINT(*) os << "not "; PrintExpr(op->a, os); } -void CodeGenHybrid::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Call* op, std::ostream& os) { // NOLINT(*) if (op->call_type == Call::Halide) { os << GetTensorID(op->func, op->value_index); os << "["; @@ -313,7 +313,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmt* op) { } } -void CodeGenHybrid::VisitStmt_(const Realize *op) { +void CodeGenHybrid::VisitStmt_(const Realize* op) { CHECK(alloc_storage_scope_.count(op->func)); if (!alloc_storage_scope_[op->func].empty()) { PrintIndent(); @@ -389,19 +389,25 @@ void CodeGenHybrid::VisitStmt_(const IfThenElse* op) { } } -void CodeGenHybrid::VisitStmt_(const Block *op) { +void CodeGenHybrid::VisitStmt_(const Block* op) { PrintStmt(op->first); if (op->rest.defined()) PrintStmt(op->rest); } -void CodeGenHybrid::VisitStmt_(const Evaluate *op) { +void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) { + for (Stmt stmt : op->seq) { + PrintStmt(stmt); + } +} + +void CodeGenHybrid::VisitStmt_(const Evaluate* op) { if (is_const(op->value)) return; std::string str = PrintExpr(op->value); if (!str.empty()) stream << str << "\n"; } -void CodeGenHybrid::VisitStmt_(const ProducerConsumer *op) { +void CodeGenHybrid::VisitStmt_(const ProducerConsumer* op) { PrintStmt(op->body); } diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 647ef77fc534..8fb2dcd1cae6 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -132,6 +132,7 @@ class CodeGenHybrid : void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const Block* op) override; + void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const ProducerConsumer* op) override; /*! * \brief Print Type represetnation of type t. diff --git a/src/lang/ir.cc b/src/lang/ir.cc index 5b410d1e3741..08167affac00 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -504,6 +504,12 @@ Stmt Prefetch::make(FunctionRef func, int value_index, DataType dtype, Region bo return Stmt(node); } +SeqStmt::SeqStmt(Array seq) { + auto node = make_object(); + node->seq = std::move(seq); + data_ = std::move(node); +} + Stmt Block::make(Stmt first, Stmt rest) { CHECK(first.defined()); CHECK(rest.defined()); @@ -1038,6 +1044,14 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) if (op->rest.defined()) p->Print(op->rest); }); +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + for (Stmt stmt : op->seq) { + p->Print(stmt); + } + }); + TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); @@ -1213,6 +1227,7 @@ TVM_REGISTER_NODE_TYPE(Allocate); TVM_REGISTER_NODE_TYPE(Free); TVM_REGISTER_NODE_TYPE(Realize); TVM_REGISTER_NODE_TYPE(Block); +TVM_REGISTER_NODE_TYPE(SeqStmtNode); TVM_REGISTER_NODE_TYPE(IfThenElse); TVM_REGISTER_NODE_TYPE(Evaluate); diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc index 202a5c27bd8b..3a50ac35914a 100644 --- a/src/pass/inject_virtual_thread.cc +++ b/src/pass/inject_virtual_thread.cc @@ -371,6 +371,18 @@ class VTInjector : public StmtExprMutator { return Block::make(first, rest); } } + // Seq + Stmt VisitStmt_(const SeqStmtNode* op) final { + CHECK_EQ(max_loop_depth_, 0); + auto fmutate = [this](const Stmt& s) { + int temp = max_loop_depth_; + max_loop_depth_ = 0; + Stmt ret = this->VisitStmt(s); + max_loop_depth_ = std::max(max_loop_depth_, temp); + return ret; + }; + return StmtMutator::VisitSeqStmt_(op, false, fmutate); + } // Allocate Stmt VisitStmt_(const Allocate* op) final { if (op->new_expr.defined() && !vt_loop_injected_) { diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc index 6a61d5e402f9..0c6cb17422bc 100644 --- a/src/pass/ir_deep_compare.cc +++ b/src/pass/ir_deep_compare.cc @@ -185,6 +185,15 @@ class IRDeepCompare : if (CompareStmt(op->rest, rhs->rest) != 0) return; } + + void VisitStmt_(const SeqStmtNode* op, const Stmt& other) final { + const SeqStmtNode* rhs = other.as(); + if (CompareValue(op->size(), rhs->size()) != 0) return; + for (size_t i = 0; i < op->size(); ++i) { + if (CompareStmt(op->seq[i], rhs->seq[i]) != 0) return; + } + } + void VisitStmt_(const Evaluate* op, const Stmt& other) final { const Evaluate* rhs = other.as(); CompareExpr(op->value, rhs->value); diff --git a/src/pass/ir_functor.cc b/src/pass/ir_functor.cc index efc43a2ffa3d..a8bba375fb6a 100644 --- a/src/pass/ir_functor.cc +++ b/src/pass/ir_functor.cc @@ -214,6 +214,12 @@ void StmtVisitor::VisitStmt_(const Block* op) { this->VisitStmt(op->rest); } +void StmtVisitor::VisitStmt_(const SeqStmtNode* op) { + VisitArray(op->seq, [this](const Stmt& s) { + this->VisitStmt(s); + }); +} + void StmtVisitor::VisitStmt_(const Evaluate* op) { this->VisitExpr(op->value); } @@ -504,6 +510,63 @@ Stmt StmtMutator::VisitStmt_(const Block* op) { } } +Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) { + Array seq = Internal::Mutate(this, op->seq); + if (seq.same_as(op->seq)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->seq = std::move(seq); + return Stmt(n); + } +} + +// advanced visit function for seqstmt. +Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, + bool flatten_before_visit, + std::function fmutate) { + if (flatten_before_visit) { + // Pass 1, check if we need to flatten. + bool need_flatten = false; + for (size_t i = 0; i < op->seq.size(); ++i) { + Stmt tmp = (*op)[i]; + if (tmp.as()) need_flatten = true; + } + flatten_before_visit = need_flatten; + } + // function to run the visit. + auto frunvisit = [&](const SeqStmtNode* op) { + Array seq = + fmutate != nullptr ? + MutateArray(op->seq, fmutate, allow_copy_on_write_) : + Internal::Mutate(this, op->seq); + if (seq.same_as(op->seq)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->seq = std::move(seq); + return Stmt(n); + } + }; + if (flatten_before_visit) { + Array seq; + SeqStmt::Flattener flattener(&seq); + flattener(0, op->seq); + // NOTE: If copy on write is allowed + // the assignment to seq below will + // destruct the original seq. + // + // Such destruction removes duplicated reference + // count to children and still enables COW for + // child Stmt. + ObjectPtr n = CopyOnWrite(op); + n->seq = std::move(seq); + return frunvisit(n.operator->()); + } else { + return frunvisit(op); + } +} + Stmt StmtMutator::VisitStmt_(const AssertStmt* op) { Expr condition = this->VisitExpr(op->condition); Expr message = this->VisitExpr(op->message); diff --git a/src/pass/lift_attr_scope.cc b/src/pass/lift_attr_scope.cc index eeed10ebe3ae..0e2c5fea390b 100644 --- a/src/pass/lift_attr_scope.cc +++ b/src/pass/lift_attr_scope.cc @@ -88,6 +88,10 @@ class AttrScopeLifter : public StmtMutator { return MergeSeq(seq); } + Stmt VisitStmt_(const SeqStmtNode* op) final { + return StmtMutator::VisitSeqStmt_(op, true); + } + Stmt VisitStmt_(const IfThenElse* op) final { if (!op->else_case.defined()) { return StmtMutator::VisitStmt_(op); diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 11cf57490450..2a94915f65e2 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -116,6 +116,18 @@ class CandidateSelector final : public StmtExprVisitor { no_split_ = no_split_ || temp; } + void VisitStmt_(const SeqStmtNode* op) final { + bool init_no_split = no_split_; + for (Stmt stmt : op->seq) { + // erase the no split state of before visiting the next one. + bool temp = init_no_split; + std::swap(temp, no_split_); + this->VisitStmt(stmt); + // restore the no split flag. + no_split_ = no_split_ || temp; + } + } + void VisitExpr_(const Call* op) final { if (op->is_intrinsic(Call::likely)) { in_likely_ = true; diff --git a/src/pass/remove_no_op.cc b/src/pass/remove_no_op.cc index 8b0b57c10984..afe14cf68aa0 100644 --- a/src/pass/remove_no_op.cc +++ b/src/pass/remove_no_op.cc @@ -93,6 +93,7 @@ class NoOpRemover : public StmtMutator { if (HasSideEffect(op->value)) return GetRef(op); return Evaluate::make(0); } + Stmt VisitStmt_(const Block* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); @@ -105,6 +106,37 @@ class NoOpRemover : public StmtMutator { } } + Stmt VisitStmt_(const SeqStmtNode* op) final { + Stmt ret = StmtMutator::VisitSeqStmt_(op, true); + op = ret.as(); + CHECK(op != nullptr); + bool need_compact = false; + for (size_t i = 0; i < op->size(); ++i) { + if (is_no_op(op->seq[i])) need_compact = true; + } + if (need_compact) { + auto n = CopyOnWrite(op); + size_t top = 0; + for (size_t i = 0; i < n->seq.size(); ++i) { + if (!is_no_op(n->seq[i])) { + n->seq.Set(top++, n->seq[i]); + } + } + if (top == 1) { + return n->seq[0]; + } else { + n->seq.resize(top); + return Stmt(n); + } + } else { + if (op->size() == 1) { + return op->seq[0]; + } else { + return ret; + } + } + } + private: Stmt MakeEvaluate(Expr value) { if (HasSideEffect(value)) { diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc index 9fc87f3a0d6b..577984550f60 100644 --- a/src/pass/unroll_loop.cc +++ b/src/pass/unroll_loop.cc @@ -142,6 +142,23 @@ class LoopUnroller : public StmtExprMutator { } } + Stmt VisitStmt_(const SeqStmtNode* op) final { + auto fmutate = [this](const Stmt& s) { + int step_count = step_count_; + int unroll_depth = unroll_depth_; + int normal_loop_depth = normal_loop_depth_; + step_count_ = 0; + unroll_depth_ = 0; + normal_loop_depth_ = 0; + Stmt ret = this->VisitStmt(s); + step_count_ += step_count; + normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_); + unroll_depth_ = std::max(unroll_depth_, unroll_depth); + return ret; + }; + return StmtMutator::VisitSeqStmt_(op, false, fmutate); + } + Stmt Unroll(const For* op) { int value = GetExtent(op); // For loop must have a constant integer extent diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index a10ddd413c20..a37f6f97d920 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -155,6 +155,9 @@ TEST(IRF, StmtMutator) { Expr VisitExpr_(const Add* op) final { return op->a; } + Stmt VisitStmt_(const SeqStmtNode* op) final { + return StmtMutator::VisitSeqStmt_(op, true); + } Expr VisitExpr(const Expr& expr) final { return ExprMutator::VisitExpr(expr); } @@ -219,6 +222,35 @@ TEST(IRF, StmtMutator) { auto res = v(std::move(body)); CHECK(res.as()->value.as()->args[0].same_as(x)); } + { + auto body = fmakealloc(); + Stmt body2 = Evaluate::make(1); + auto* ref2 = body2.get(); + auto* extentptr = body.as()->extents.get(); + // construct a recursive SeqStmt. + body = SeqStmt({body}); + body = SeqStmt({body, body2}); + body = SeqStmt({body, body2}); + body = v(std::move(body)); + // the seq get flattened + CHECK(body.as()->size() == 3); + CHECK(body.as()->seq[0].as()->extents.get() == extentptr); + CHECK(body.as()->seq[1].get() == ref2); + } + + { + // Cannot cow because of bref + auto body = fmakealloc(); + Stmt body2 = Evaluate::make(1); + auto* extentptr = body.as()->extents.get(); + // construct a recursive SeqStmt. + body = SeqStmt({body}); + auto bref = body; + body = SeqStmt({body, body2}); + body = v(std::move(body)); + // the seq get flattened + CHECK(body.as()->seq[0].as()->extents.get() != extentptr); + } } int main(int argc, char ** argv) { diff --git a/tests/python/unittest/test_ir_builder.py b/tests/python/unittest/test_ir_builder.py index c910c62424f0..8b9da90c914c 100644 --- a/tests/python/unittest/test_ir_builder.py +++ b/tests/python/unittest/test_ir_builder.py @@ -34,8 +34,8 @@ def test_for(): body = body.body assert isinstance(body, tvm.stmt.For) body = body.body - assert isinstance(body, tvm.stmt.Block) - assert isinstance(body.rest, tvm.stmt.For) + assert isinstance(body, tvm.stmt.SeqStmt) + assert isinstance(body[1], tvm.stmt.For) def test_if(): ib = tvm.ir_builder.create() diff --git a/tests/python/unittest/test_pass_equal.py b/tests/python/unittest/test_pass_equal.py index 8bd491bb5c8e..1f5bb9cba9a9 100644 --- a/tests/python/unittest/test_pass_equal.py +++ b/tests/python/unittest/test_pass_equal.py @@ -53,6 +53,7 @@ def func2(): A[i] = A[i] + 1 with ib.for_range(0, 10, name="j") as j: A[j] = A[j] + 2 + A[j] = A[j] + 2 return ib.get() assert tvm.ir_pass.Equal(func1(), func1()) diff --git a/tests/python/unittest/test_pass_inject_vthread.py b/tests/python/unittest/test_pass_inject_vthread.py index cb2fa201aff4..a3d059787ab8 100644 --- a/tests/python/unittest/test_pass_inject_vthread.py +++ b/tests/python/unittest/test_pass_inject_vthread.py @@ -92,8 +92,8 @@ def test_vthread_if_then_else(): B[i] = A[i * nthread + tx] + 2 stmt = ib.get() stmt = tvm.ir_pass.InjectVirtualThread(stmt) - assert stmt.body.body.body.first.else_case != None - assert stmt.body.body.body.rest.else_case == None + assert stmt.body.body.body[0].else_case != None + assert stmt.body.body.body[1].else_case == None if __name__ == "__main__": test_vthread_extern() diff --git a/tests/python/unittest/test_pass_remove_no_op.py b/tests/python/unittest/test_pass_remove_no_op.py index d1ea450b53de..d287b8591fb3 100644 --- a/tests/python/unittest/test_pass_remove_no_op.py +++ b/tests/python/unittest/test_pass_remove_no_op.py @@ -16,6 +16,9 @@ # under the License. import tvm +def nop(): + return tvm.stmt.Evaluate(0) + def test_remove_no_op(): i = tvm.var('i') j = tvm.var('j') @@ -37,12 +40,13 @@ def test_remove_no_op(): store = tvm.make.Store(Ab.data, tvm.make.Load(dtype, Ab.data, i) + 1, i + 1) - stmt2 = tvm.make.Block(stmt, store) + stmt2 = tvm.stmt.SeqStmt([nop(), tvm.stmt.SeqStmt([store, nop()])]) assert(tvm.ir_pass.RemoveNoOp(stmt2) == store) # remove zero extent loop stmt3 = tvm.make.For(i, 0, 0, 0, 0, store) ret = tvm.ir_pass.RemoveNoOp(stmt3) assert(isinstance(ret, tvm.stmt.Evaluate)) + if __name__ == "__main__": test_remove_no_op() diff --git a/tests/python/unittest/test_pass_unroll.py b/tests/python/unittest/test_pass_unroll.py index fc8d0dc5f5c8..c94ffe0bde14 100644 --- a/tests/python/unittest/test_pass_unroll.py +++ b/tests/python/unittest/test_pass_unroll.py @@ -43,13 +43,13 @@ def test_unroll_loop(): ib.scope_attr(tvm.const(0, "int32"), "pragma_auto_unroll_max_step", 16) ib.emit(stmt) wrapped = ib.get() - wrapped = tvm.make.Block(wrapped, stmt) + wrapped = tvm.stmt.SeqStmt([wrapped, stmt]) assert isinstance(ret, tvm.stmt.For) ret = tvm.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False) - assert isinstance(ret.first, tvm.stmt.For) - assert ret.first.for_type == tvm.stmt.For.Unrolled - assert isinstance(ret.rest, tvm.stmt.For) - assert ret.rest.for_type != tvm.stmt.For.Unrolled + assert isinstance(ret[0], tvm.stmt.For) + assert ret[0].for_type == tvm.stmt.For.Unrolled + assert isinstance(ret[1], tvm.stmt.For) + assert ret[1].for_type != tvm.stmt.For.Unrolled def test_unroll_fake_loop(): ib = tvm.ir_builder.create() @@ -65,7 +65,7 @@ def test_unroll_fake_loop(): stmt = ib.get() ret = tvm.ir_pass.UnrollLoop(stmt, 8, 0, 1, True) - assert isinstance(ret.first, tvm.stmt.Store) + assert isinstance(ret[0], tvm.stmt.Store) def test_unroll_single_count_loops(): n = tvm.var('n') From 0b913461437891ff767ed36593502eeb77f4e155 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 5 Jan 2020 16:04:46 -0800 Subject: [PATCH 2/4] [REFACTOR] Migrate use of Block to SeqStmt. --- include/tvm/ir.h | 5 +- python/tvm/hybrid/parser.py | 8 +-- python/tvm/ir_builder.py | 2 +- python/tvm/stmt.py | 13 ++--- src/op/compute_op.cc | 8 +-- src/op/cross_thread_reduction.cc | 4 +- src/op/tensor_compute_op.cc | 2 +- src/op/tensorize.cc | 2 +- src/pass/arg_binder.cc | 2 +- src/pass/coproc_sync.cc | 24 +++----- src/pass/inject_double_buffer.cc | 6 +- src/pass/inject_prefetch.cc | 2 +- src/pass/inject_virtual_thread.cc | 9 ++- src/pass/ir_util.cc | 14 ++--- src/pass/ir_util.h | 7 --- src/pass/lift_attr_scope.cc | 55 ++++++++++++++++++- src/pass/loop_partition.cc | 13 +---- src/pass/lower_thread_allreduce.cc | 10 ++-- src/pass/lower_tvm_builtin.cc | 17 +++--- src/pass/make_api.cc | 2 +- src/pass/remove_no_op.cc | 2 +- src/pass/storage_sync.cc | 6 +- src/pass/unroll_loop.cc | 10 +--- src/schedule/schedule_ops.cc | 2 +- tests/python/unittest/test_hybrid_script.py | 28 +++++----- .../python/unittest/test_lang_constructor.py | 6 -- tests/python/unittest/test_lang_tensor.py | 4 +- .../unittest/test_pass_lift_attr_scope.py | 19 +++++++ .../unittest/test_pass_loop_partition.py | 16 +++--- .../python/unittest/test_pass_storage_sync.py | 4 +- .../unittest/test_schedule_schedule_ops.py | 2 +- vta/python/vta/ir_pass.py | 4 +- 32 files changed, 167 insertions(+), 141 deletions(-) diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 02c5aa439d7d..03c66a80bfd7 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -1071,8 +1071,8 @@ class SeqStmt : public Stmt { /*! * \brief Construct a flattened sequence statement. * - * \note This function can return the element if there - * is only one element in the sequence. + * \note This function can directly return an element + * if it is the only element in the sequence. * \param seq_args The list of arguments to be flattened. * \tparam Args arguments * \return The constructed statement @@ -1092,6 +1092,7 @@ class SeqStmt : public Stmt { : seq_(seq) {} void operator()(size_t i, const Stmt& stmt) const { + if (!stmt.defined()) return; if (auto* op = stmt.as()) { operator()(0, op->seq); } else if (auto* op = stmt.as()) { diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 816a0e1d7ad3..7e5659a8e9bb 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -37,6 +37,8 @@ from .. import _api_internal as _tvm_internal from .. import expr as _expr from .. import make as _make +from .. import stmt as _stmt + from .. import api as _api from .. import ir_pass as _ir_pass @@ -48,11 +50,7 @@ def concat_list_to_block(lst): n = len(lst) if n == 1: return lst[0] - body = lst[n - 1] - for i in range(1, n): - stmt = lst[n - 1 - i] - body = _make.Block(stmt, body) - return body + return _stmt.SeqStmt(lst) def visit_list_to_block(visit, lst): diff --git a/python/tvm/ir_builder.py b/python/tvm/ir_builder.py index 2a87871377b1..bf41c98a7bdd 100644 --- a/python/tvm/ir_builder.py +++ b/python/tvm/ir_builder.py @@ -120,7 +120,7 @@ def _pop_seq(self): seq = self._seq_stack.pop() if not seq or callable(seq[-1]): seq.append(_make.Evaluate(0)) - seqwrap = lambda x : x[0] if len(x) == 1 else _stmt.SeqStmt(list(reversed(x))) + seqwrap = lambda x: x[0] if len(x) == 1 else _stmt.SeqStmt(list(reversed(x))) ret_seq = [seq[-1]] for s in reversed(seq[:-1]): diff --git a/python/tvm/stmt.py b/python/tvm/stmt.py index d16b71c54ba4..3073e86cd890 100644 --- a/python/tvm/stmt.py +++ b/python/tvm/stmt.py @@ -395,16 +395,15 @@ def stmt_seq(*args): stmt : Stmt The combined statement. """ - return SeqStmt(args) - - """ - ret = None + ret = [] for value in args: if not isinstance(value, Stmt): value = Evaluate(value) - ret = value if ret is None else Block(ret, value) - return ret if ret else Evaluate(0) - """ + ret.append(value) + if len(ret) == 1: + return ret[0] + return SeqStmt(ret) + def stmt_list(stmt): """Make list of stmt from blocks. diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 939327890fec..6146284554b4 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -337,8 +337,8 @@ void MakeReduction(const ComputeOpNode* op, provides.emplace_back(Provide::make( t->op, t->value_index, update_value[i], args)); } - *init = Block::make(inits); - *provide = Block::make(provides); + *init = SeqStmt::Flatten(inits); + *provide = SeqStmt::Flatten(provides); if (!is_one(reduce->condition)) { *provide = IfThenElse::make(reduce->condition, *provide); } @@ -382,7 +382,7 @@ Stmt MakeComputeStmt(const ComputeOpNode* self, if (debug_keep_trivial_loop) { provide = MergeNest(common, provide); } else { - provide = MergeNest(common, Block::make(init, provide)); + provide = MergeNest(common, SeqStmt::Flatten(init, provide)); } // run substitution in the on the full nest, because loop condition // could depend on outer loops. @@ -392,7 +392,7 @@ Stmt MakeComputeStmt(const ComputeOpNode* self, for (size_t i = 0; i < self->body.size(); ++i) { provides.emplace_back(MakeProvide(self, stage->op.output(i))); } - Stmt provide = Block::make(provides); + Stmt provide = SeqStmt::Flatten(provides); provide = MergeNest(n.main_nest, provide); // run substitution in the on the full nest, because loop condition // could depend on outer loops. diff --git a/src/op/cross_thread_reduction.cc b/src/op/cross_thread_reduction.cc index 4a3aa54ccc6d..ab56fc9657d2 100644 --- a/src/op/cross_thread_reduction.cc +++ b/src/op/cross_thread_reduction.cc @@ -100,10 +100,10 @@ Stmt MakeCrossThreadReduction( stage->op, idx, Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args); } - Stmt assign_body = Block::make(assigns); + Stmt assign_body = SeqStmt::Flatten(assigns); assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body); assign_body = MergeNest(op::MakeIfNest(conds), assign_body); - Stmt body = Block::make(reduce_body, assign_body); + Stmt body = SeqStmt::Flatten(reduce_body, assign_body); for (size_t idx = size; idx != 0; --idx) { body = Allocate::make( res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc index 76ecf3417d36..a6252df05246 100644 --- a/src/op/tensor_compute_op.cc +++ b/src/op/tensor_compute_op.cc @@ -242,7 +242,7 @@ Stmt TensorComputeOpNode::BuildProvide( update = MergeNest(binder.asserts(), update); update = op::Substitute(update, n.main_vmap); update = MergeNest(update_nest, update); - return MergeNest(common, Block::make(init, update)); + return MergeNest(common, SeqStmt::Flatten(init, update)); } else { // When init op is not available, use body op for reset in the first iter. CHECK(this->intrin->body.defined()) diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index f6fa00dad859..0df8e889efeb 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -478,7 +478,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, update = MergeNest(binder.asserts(), update); update = Substitute(update, n.main_vmap); update = MergeNest(update_nest, update); - return MergeNest(common, Block::make(init, update)); + return MergeNest(common, SeqStmt::Flatten(init, update)); } else { // When init op is not available, use body op for reset in the first iter. CHECK(intrin->body.defined()) diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index e4ff9cb457a5..a0ddcd98b260 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -240,7 +240,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, AssertStmt::make(arith::ComputeReduce(conds, Expr()), stride_err_msg.str(), Evaluate::make(0)); check = IfThenElse::make(Not::make(is_null), check, Stmt()); - asserts_.emplace_back(Block::make(check, Evaluate::make(0))); + asserts_.emplace_back(SeqStmt({check, Evaluate::make(0)})); } } else if (buffer->buffer_type == kAutoBroadcast) { DataType stype = buffer->DefaultIndexType(); diff --git a/src/pass/coproc_sync.cc b/src/pass/coproc_sync.cc index 924305199628..33af959f102e 100644 --- a/src/pass/coproc_sync.cc +++ b/src/pass/coproc_sync.cc @@ -655,24 +655,14 @@ class CoProcSyncInserter : public StmtMutator { } Stmt VisitStmt(const Stmt& stmt) final { - Stmt before, after; - auto it = insert_before_.find(stmt.get()); - if (it != insert_before_.end()) { - before = MergeSeq(std::vector( - it->second.rbegin(), it->second.rend())); - } - it = insert_after_.find(stmt.get()); - if (it != insert_after_.end()) { - after = MergeSeq(it->second); - } + auto it_before = insert_before_.find(stmt.get()); + auto it_after = insert_after_.find(stmt.get()); Stmt new_stmt = StmtMutator::VisitStmt(stmt); - if (before.defined()) { - new_stmt = Block::make(before, new_stmt); - } - if (after.defined()) { - new_stmt = Block::make(new_stmt, after); - } - return new_stmt; + + return SeqStmt::Flatten( + it_before != insert_before_.end() ? it_before->second : std::vector(), + new_stmt, + it_after != insert_after_.end() ? it_after->second : std::vector()); } private: diff --git a/src/pass/inject_double_buffer.cc b/src/pass/inject_double_buffer.cc index 84b2f705e995..0158a949da53 100644 --- a/src/pass/inject_double_buffer.cc +++ b/src/pass/inject_double_buffer.cc @@ -147,7 +147,7 @@ class DoubleBufferInjector : public StmtExprMutator { } Stmt loop = For::make( outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api, - MergeSeq(loop_seq)); + SeqStmt::Flatten(loop_seq)); // tail std::vector tail_seq; Stmt tail_body = StripDoubleBufferWrite()(old_loop->body); @@ -158,9 +158,9 @@ class DoubleBufferInjector : public StmtExprMutator { IfThenElse::make(idx < old_loop->extent, Substitute(tail_body, vmap))); } - stmt = Block::make(loop, MergeSeq(tail_seq)); + stmt = SeqStmt::Flatten(loop, tail_seq); } - stmt = Block::make(MergeSeq(it->second), stmt); + stmt = SeqStmt::Flatten(it->second, stmt); } it = loop_allocs_.find(op); if (it != loop_allocs_.end()) { diff --git a/src/pass/inject_prefetch.cc b/src/pass/inject_prefetch.cc index 73e1dc9a3e30..73725c20583c 100644 --- a/src/pass/inject_prefetch.cc +++ b/src/pass/inject_prefetch.cc @@ -59,7 +59,7 @@ class PrefetchInjector : public StmtMutator { vectorized_.erase(iter_var); Stmt prefetch = Prefetch::make(ts->op, ts->value_index, ts->dtype, region); - return Block::make(prefetch, op->body); + return SeqStmt({prefetch, op->body}); } return ret; } diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc index 3a50ac35914a..848aa8f8e34e 100644 --- a/src/pass/inject_virtual_thread.cc +++ b/src/pass/inject_virtual_thread.cc @@ -454,12 +454,11 @@ class VTInjector : public StmtExprMutator { // only unroll if number of vthreads are small if (max_loop_depth_ == 0 && num_threads_ < 16) { // do unrolling if it is inside innermost content. - Stmt blk = Substitute(stmt, {{var_, make_zero(var_.dtype())}}); - for (int i = 1; i < num_threads_; ++i) { - blk = Block::make( - blk, Substitute(stmt, {{var_, make_const(var_.dtype(), i)}})); + Array seq; + for (int i = 0; i < num_threads_; ++i) { + seq.push_back(Substitute(stmt, {{var_, make_const(var_.dtype(), i)}})); } - return blk; + return SeqStmt::Flatten(seq); } else { // insert a for loop Var idx(var_->name_hint + ".s", var_->dtype); diff --git a/src/pass/ir_util.cc b/src/pass/ir_util.cc index cdc708ce5faf..7d7b15d18a4c 100644 --- a/src/pass/ir_util.cc +++ b/src/pass/ir_util.cc @@ -56,6 +56,11 @@ Stmt MergeNest(const std::vector& nest, Stmt body) { CHECK(is_no_op(n->rest)); n->rest = body; body = Stmt(n); + } else if (const auto* seq = s.as()) { + auto n = make_object(*seq); + CHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1])); + n->seq.Set(n->size() - 1, body); + body = Stmt(n); } else if (const auto* assert_ = s.as()) { auto n = make_object(*assert_); CHECK(is_no_op(n->body)); @@ -80,14 +85,5 @@ Stmt MergeNest(const std::vector >& nest, Stmt body) { return body; } -Stmt MergeSeq(const std::vector& seq) { - if (seq.size() == 0) return Evaluate::make(0); - Stmt body = seq[0]; - for (size_t i = 1; i < seq.size(); ++i) { - body = Block::make(body, seq[i]); - } - return body; -} - } // namespace ir } // namespace tvm diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h index 0f8bb990c2d3..900d6d59853a 100644 --- a/src/pass/ir_util.h +++ b/src/pass/ir_util.h @@ -47,13 +47,6 @@ Stmt MergeNest(const std::vector& nest, Stmt body); */ Stmt MergeNest(const std::vector >& nest, Stmt body); -/*! - * \brief combine sequence of operations. - * \param seq The sequence. - * \return The combined Stmt - */ -Stmt MergeSeq(const std::vector& seq); - /*! * \brief update array with an unary function * \param arr array diff --git a/src/pass/lift_attr_scope.cc b/src/pass/lift_attr_scope.cc index 0e2c5fea390b..6d998a6acc3a 100644 --- a/src/pass/lift_attr_scope.cc +++ b/src/pass/lift_attr_scope.cc @@ -85,11 +85,59 @@ class AttrScopeLifter : public StmtMutator { seq[1].same_as(op->rest)) { return GetRef(op); } - return MergeSeq(seq); + return SeqStmt::Flatten(seq); } Stmt VisitStmt_(const SeqStmtNode* op) final { - return StmtMutator::VisitSeqStmt_(op, true); + // remember the decorations. + std::vector attr_node; + std::vector attr_value; + + auto fmutate = [&](const Stmt& s) { + attr_node_ = ObjectRef(); + attr_value_ = Expr(); + Stmt ret = this->VisitStmt(s); + attr_node.push_back(attr_node_); + attr_value.push_back(attr_value_); + return ret; + }; + Stmt ret = StmtMutator::VisitSeqStmt_(op, true, fmutate); + if (attr_node.size() == 0) return ret; + + op = ret.as(); + CHECK(op != nullptr); + Array reorg; + // check if all decorations are common. + for (size_t begin = 0; begin < attr_node.size();) { + size_t end = begin + 1; + while (end < attr_node.size() && + attr_node[end].same_as(attr_node[begin]) && + ValueSame(attr_value[end], attr_value[begin])) { + ++end; + } + // covers everything + // lift attr to parent. + if (begin == 0 && end == attr_node.size()) { + attr_node_ = attr_node[0]; + attr_value_ = attr_value[0]; + return ret; + } + // construct subsegments. + Array seq; + for (size_t i = begin; i < end; ++i) { + seq.push_back(op->seq[i]); + } + Stmt stmt = SeqStmt::Flatten(seq); + if (attr_node[begin].defined()) { + stmt = AttrStmt::make( + attr_node[begin], attr_key_, attr_value[begin], stmt); + } + reorg.push_back(stmt); + begin = end; + } + attr_node_ = ObjectRef(); + attr_value_ = Expr(); + return SeqStmt::Flatten(reorg); } Stmt VisitStmt_(const IfThenElse* op) final { @@ -151,7 +199,7 @@ class AttrScopeLifter : public StmtMutator { } } - std::vector MutateSeq(const std::vector& seq) { + std::vector MutateSeq(const Array& seq) { std::vector res_seq; ObjectRef curr_node; Expr curr_value; @@ -201,6 +249,7 @@ class AttrScopeLifter : public StmtMutator { // value comparison that also compares content of int constant static bool ValueSame(const Expr& a, const Expr& b) { if (a.same_as(b)) return true; + if (!a.defined() || !b.defined()) return false; if (a->type_index() != b->type_index()) return false; if (a.dtype() != b.dtype()) return false; if (const IntImm* op = a.as()) { diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 2a94915f65e2..09559c1bb8b2 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -414,16 +414,6 @@ LoopPartitioner::GetIntervalAndCondset(const Partition &partitions, return std::make_pair(interval, cond_set); } -Stmt AppendStmts(const Stmt& a, const Stmt& b) { - if (!a.defined()) { - return b; - } else if (!b.defined()) { - return a; - } else { - return Block::make(a, b); - } -} - /* * Tries to recursively partition the range of the variable (given by var) of * the for loop (given by node and stmt) into a @@ -589,8 +579,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, } } } - s = AppendStmts(pre_stmt, mid_stmt); - s = AppendStmts(s, post_stmt); + s = SeqStmt::Flatten(pre_stmt, mid_stmt, post_stmt); } else { Expr cond = const_true(); if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin); diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index f77443a1b4c3..4712bccb415a 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -185,7 +185,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Var buffer_var = Downcast(call->args[2+size+i]); stores[i] = Store::make(buffer_var, values[i], 0, pred); } - return Block::make(stores); + return SeqStmt::Flatten(stores); } // Whether the threadIdx.x is involved in reduction. if (vred[0].scope.dim_index == 0) { @@ -218,7 +218,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { {Expr(group_extent), Expr(reduce_extent)}, pred, Evaluate::make(0)); } - return MergeSeq(seq); + return SeqStmt::Flatten(seq); } // make allreduce. Stmt MakeBufAllreduce(const CommReducerNode *combiner, @@ -252,7 +252,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t i = 0; i < size; ++i) { stores[i] = Store::make(shared_bufs[i], ret[i], buf_index, const_true()); } - return Block::make(stores); + return SeqStmt::Flatten(stores); }; // Step one, check for if (reduce_align > reduce_extent) { @@ -280,11 +280,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { seq.emplace_back(SyncThread("warp")); } if (in_warp_seq.size() != 0) { - Stmt warp_body = MergeSeq(in_warp_seq); + Stmt warp_body = SeqStmt::Flatten(in_warp_seq); seq.emplace_back(IfThenElse::make(in_warp_cond, warp_body)); seq.emplace_back(SyncThread("shared")); } - return MergeSeq(seq); + return SeqStmt::Flatten(seq); } // Flatten the thread index. // Also return a warp number, diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc index 0cc9ea21a61a..c0b98793c7f9 100644 --- a/src/pass/lower_tvm_builtin.cc +++ b/src/pass/lower_tvm_builtin.cc @@ -72,11 +72,14 @@ class BuiltinLower : public StmtExprMutator { auto stmt = StmtExprMutator::VisitStmt(s); CHECK_EQ(run_shape_stack_, 0); CHECK_EQ(run_array_stack_, 0); - while (prep_seq_.size() != 0) { - stmt = Block::make(prep_seq_.back(), stmt); - prep_seq_.pop_back(); + + if (prep_seq_.size() != 0) { + Stmt ret = SeqStmt::Flatten(prep_seq_, stmt); + prep_seq_.clear(); + return ret; + } else { + return stmt; } - return stmt; } Stmt VisitStmt_(const Allocate* op) { @@ -107,12 +110,12 @@ class BuiltinLower : public StmtExprMutator { intrinsic::tvm_throw_last_error, {}, Call::Intrinsic)); - Stmt body = Block::make( + Stmt body = SeqStmt({ IfThenElse::make(Call::make(DataType::Bool(1), intrinsic::tvm_handle_is_null, {op->buffer_var}, Call::PureIntrinsic), throw_last_error), - op->body); + op->body}); Stmt alloca = LetStmt::make( op->buffer_var, @@ -133,7 +136,7 @@ class BuiltinLower : public StmtExprMutator { op->buffer_var}, Call::Extern); Stmt free_stmt = IfThenElse::make(free_op != make_zero(DataType::Int(32)), throw_last_error); - body = Block::make(alloca, free_stmt); + body = SeqStmt({alloca, free_stmt}); body = AttrStmt::make( op->buffer_var, attr::storage_alignment, make_const(DataType::Int(32), runtime::kTempAllocaAlignment), diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc index 274383f86691..f065502db6b4 100644 --- a/src/pass/make_api.cc +++ b/src/pass/make_api.cc @@ -189,7 +189,7 @@ LoweredFunc MakeAPI(Stmt body, DataType::Int(32), intrinsic::tvm_call_packed, {StringImm::make(runtime::symbol::tvm_set_device), device_type, device_id}, Call::Intrinsic))); - body = Block::make(set_device, body); + body = SeqStmt({set_device, body}); } n->body = MergeNest( {seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); diff --git a/src/pass/remove_no_op.cc b/src/pass/remove_no_op.cc index afe14cf68aa0..59ec113fcaff 100644 --- a/src/pass/remove_no_op.cc +++ b/src/pass/remove_no_op.cc @@ -150,7 +150,7 @@ class NoOpRemover : public StmtMutator { for (Expr e : values) { if (HasSideEffect(e)) { if (stmt.defined()) { - stmt = Block::make(stmt, Evaluate::make(e)); + stmt = SeqStmt({stmt, Evaluate::make(e)}); } else { stmt = Evaluate::make(e); } diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc index 6ace4f7f85b4..85cf2b92f9e4 100644 --- a/src/pass/storage_sync.cc +++ b/src/pass/storage_sync.cc @@ -216,7 +216,7 @@ class ThreadSyncInserter : public StmtExprMutator { } // Mutate after query, to avoid stmt change. auto ret = StmtExprMutator::VisitStmt(stmt); - ret = Block::make(barrier, ret); + ret = SeqStmt({barrier, ret}); return ret; } else { return StmtExprMutator::VisitStmt(stmt); @@ -313,10 +313,10 @@ class ThreadSyncInserter : public StmtExprMutator { rw_stats_.clear(); Stmt kinit = Evaluate::make( Call::make(DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, Call::Intrinsic)); - body = Block::make(kinit, body); + body = SeqStmt({kinit, body}); body = AttrStmt::make( op->node, op->attr_key, op->value, body); - return Block::make(prep, body); + return SeqStmt({prep, body}); } Stmt MakeGlobalBarrier() { CHECK(sync_scope_.rank == StorageRank::kGlobal); diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc index 577984550f60..269bf0933c0b 100644 --- a/src/pass/unroll_loop.cc +++ b/src/pass/unroll_loop.cc @@ -166,17 +166,13 @@ class LoopUnroller : public StmtExprMutator { if (value == 0) return Evaluate::make(0); Stmt body = op->body; Map vmap; - Stmt unrolled; + Array unrolled; for (int i = 0; i < value; ++i) { vmap.Set(op->loop_var, op->min + make_const(op->loop_var.dtype(), i)); Stmt step = Substitute(body, vmap); - if (unrolled.defined()) { - unrolled = Block::make(unrolled, step); - } else { - unrolled = step; - } + unrolled.push_back(step); } - return unrolled; + return SeqStmt::Flatten(unrolled); } private: diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index b177d6f8d22f..2d494522b211 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -53,7 +53,7 @@ Stmt MakePipeline(const Stage& s, if (consumer.defined() && !is_no_op(consumer)) { consumer = ProducerConsumer::make(s->op, false, consumer); - pipeline = Block::make(producer, consumer); + pipeline = SeqStmt({producer, consumer}); } pipeline = s->op->BuildRealize(s, dom_map, pipeline); // use attribute to mark scope of the operation. diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 1f101a1e92e8..c3c40cf740ad 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -123,12 +123,12 @@ def test_outer_product(): assert ibody.extent.name == 'm' #Check loop body jblock = ibody.body - assert isinstance(jblock, tvm.stmt.Block) - jbody = jblock.first + assert isinstance(jblock, tvm.stmt.SeqStmt) + jbody = jblock[0] assert isinstance(jbody, tvm.stmt.AssertStmt) assert isinstance(jbody.message, tvm.expr.StringImm) assert jbody.message.value == "index out of range!" - jbody = jblock.rest + jbody = jblock[1] assert isinstance(jbody, tvm.stmt.Provide) assert jbody.func.name == 'c' assert len(jbody.args) == 2 @@ -191,12 +191,12 @@ def fanout(n, a): assert abody.func.name == 'sigma' #Check i loop body rbody = abody.body - assert isinstance(rbody.first, tvm.stmt.Provide) - assert rbody.first.func.name == 'sigma' - assert len(rbody.first.args) == 1 - assert rbody.first.args[0].value == 0 + assert isinstance(rbody[0], tvm.stmt.Provide) + assert rbody[0].func.name == 'sigma' + assert len(rbody[0].args) == 1 + assert rbody[0].args[0].value == 0 #Check fanout loop - jloop = rbody.rest.first + jloop = rbody[1] assert jloop.loop_var.name == 'j' assert jloop.min.value == 0 assert jloop.extent.value == 3 @@ -214,7 +214,7 @@ def fanout(n, a): assert value.b.name == 'a' assert len(value.b.args) == 1 assert tvm.ir_pass.Equal(value.b.args[0], ir.loop_var + jloop.loop_var) - divide= rbody.rest.rest.first + divide= rbody[2] assert isinstance(divide, tvm.stmt.Provide) assert len(divide.args) == 1 assert divide.args[0].value == 0 @@ -224,7 +224,7 @@ def fanout(n, a): assert len(value.a.args) == 1 assert value.a.args[0].value == 0 assert abs(value.b.value - (1 / 3.0)) < 1e-5 - write = rbody.rest.rest.rest + write = rbody[3] assert isinstance(write, tvm.stmt.Provide) assert write.func.name == 'b' assert write.value.name == 'sigma' @@ -257,9 +257,9 @@ def looptype(a, b, c): ir = d.op.body except: return - iloop = ir.first - jloop = ir.rest.first - kloop = ir.rest.rest + iloop = ir[0] + jloop = ir[1] + kloop = ir[2] assert iloop.for_type == tvm.stmt.For.Parallel assert jloop.for_type == tvm.stmt.For.Vectorized assert kloop.for_type == tvm.stmt.For.Unrolled @@ -802,7 +802,7 @@ def sum_array(inputs): inputs = [] for i in range(n): inputs.append(tvm.placeholder((10,), name='t%s' % i, dtype='float32')) - + out = sum_array(tvm.convert(inputs)) assert len(out.op.inputs) == n diff --git a/tests/python/unittest/test_lang_constructor.py b/tests/python/unittest/test_lang_constructor.py index a0d39f2daffe..fe329494e24e 100644 --- a/tests/python/unittest/test_lang_constructor.py +++ b/tests/python/unittest/test_lang_constructor.py @@ -146,12 +146,6 @@ def test_stmt_constructor(): assert isinstance(x, tvm.stmt.AttrStmt) assert x.value.value == 1 - x = tvm.stmt.Block(tvm.stmt.Evaluate(11), - nop) - assert isinstance(x, tvm.stmt.Block) - assert x.first.value.value == 11 - assert x.rest == nop - x = tvm.stmt.AssertStmt(tvm.const(1, "uint1"), tvm.convert("hellow"), nop) diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index 44aca3b324bb..7e9f59bf348d 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -171,8 +171,8 @@ def intrin_func(ins, outs): s = tvm.create_schedule(C.op) stmt = tvm.lower(s, [A, B, C], simple_mode=True) - assert isinstance(stmt.body.body.body.first, tvm.stmt.Evaluate) - assert isinstance(stmt.body.body.body.rest.body, tvm.stmt.Evaluate) + assert isinstance(stmt.body.body.body[0], tvm.stmt.Evaluate) + assert isinstance(stmt.body.body.body[1].body, tvm.stmt.Evaluate) def test_tensor_scan(): m = tvm.var("m") diff --git a/tests/python/unittest/test_pass_lift_attr_scope.py b/tests/python/unittest/test_pass_lift_attr_scope.py index d786ca8c8108..b281e17bc633 100644 --- a/tests/python/unittest/test_pass_lift_attr_scope.py +++ b/tests/python/unittest/test_pass_lift_attr_scope.py @@ -31,10 +31,29 @@ def test_coproc_lift(): with ib.for_range(0, 10, name="j") as j: ib.scope_attr(cp, "coproc_uop_scope", value) A[j] = A[j] + 2 + A[j] = A[j] + 3 + A[j] = A[j] + 3 body = ib.get() body = tvm.ir_pass.LiftAttrScope(body, "coproc_uop_scope") assert body.body.body.node == cp + # only able to lift to the common pattern of the last two fors. + ib = tvm.ir_builder.create() + A = ib.allocate("float32", n, name="A", scope="global") + with ib.for_range(0, n, name="i") as i: + with ib.for_range(0, 10, name="j") as j: + A[j] = A[j] + 1 + with ib.for_range(0, 10, name="j") as j: + ib.scope_attr(cp, "coproc_uop_scope", value) + A[i] = A[i] + 1 + with ib.for_range(0, 10, name="j") as j: + ib.scope_attr(cp, "coproc_uop_scope", value) + A[i] = A[i] + 2 + + body = ib.get() + body = tvm.ir_pass.LiftAttrScope(body, "coproc_uop_scope") + assert body.body.body.body[1].node == cp + assert len(body.body.body.body) == 2 if __name__ == "__main__": test_coproc_lift() diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py index 021709506754..c58b2f6dd298 100644 --- a/tests/python/unittest/test_pass_loop_partition.py +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -64,7 +64,7 @@ def test_basic(): stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.ir_pass.LoopPartition(stmt, False) stmt = tvm.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body.body.body.first)) + assert('if' not in str(stmt.body.body.body[0])) def test_const_loop(): n = 21 @@ -79,7 +79,7 @@ def test_const_loop(): stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.ir_pass.LoopPartition(stmt, True) stmt = tvm.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body.body.body.first)) + assert('if' not in str(stmt.body.body.body[0])) def test_multi_loop(): ib = tvm.ir_builder.create() @@ -95,7 +95,7 @@ def test_multi_loop(): stmt = ib.get() stmt = tvm.ir_pass.LoopPartition(stmt, False) stmt = tvm.ir_pass.Simplify(stmt) - assert(not any(collect_visit(stmt.body.first, lambda x: isinstance(x, tvm.stmt.IfThenElse)))) + assert(not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.stmt.IfThenElse)))) def test_multi_if(): ib = tvm.ir_builder.create() @@ -115,7 +115,7 @@ def test_multi_if(): stmt = ib.get() stmt = tvm.ir_pass.LoopPartition(stmt, False) stmt = tvm.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body.first)) + assert('if' not in str(stmt.body[0])) def test_thread_axis(): m = tvm.var('m') @@ -134,7 +134,7 @@ def test_thread_axis(): stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.ir_pass.LoopPartition(stmt, False) stmt = tvm.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body.body.body.first)) + assert('if' not in str(stmt.body.body.body[0])) def test_vectorize(): n = tvm.var('n') @@ -169,7 +169,7 @@ def test_condition(): stmt = ib.get() stmt = tvm.ir_pass.LoopPartition(stmt, False) stmt = tvm.ir_pass.Simplify(stmt) - assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select)))) + assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.expr.Select)))) def test_condition_EQ(): ib = tvm.ir_builder.create() @@ -181,7 +181,7 @@ def test_condition_EQ(): stmt = ib.get() stmt = tvm.ir_pass.LoopPartition(stmt, True) stmt = tvm.ir_pass.Simplify(stmt) - assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select)))) + assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.expr.Select)))) def test_thread_axis2(): n = tvm.convert(4096) @@ -197,7 +197,7 @@ def test_thread_axis2(): s[C].bind(bx, tvm.thread_axis("blockIdx.x")) s[C].bind(tx, tvm.thread_axis("threadIdx.x")) stmt = lower(s, [A, B]) - for_body = stmt.body.body.body.body.body.first + for_body = stmt.body.body.body.body.body[0] assert('threadIdx' not in str(for_body.extent)) def test_everything_during_deduction(): diff --git a/tests/python/unittest/test_pass_storage_sync.py b/tests/python/unittest/test_pass_storage_sync.py index 5c7024abe518..3202d7b7d3a8 100644 --- a/tests/python/unittest/test_pass_storage_sync.py +++ b/tests/python/unittest/test_pass_storage_sync.py @@ -119,10 +119,10 @@ def __check_list(tvm_array, py_list): stmt = ib.get() stmt = tvm.ir_pass.CoProcSync(stmt) - slist = tvm.make.stmt_list(stmt.first.body.body) + slist = tvm.make.stmt_list(stmt[0].body.body) push_st = slist[2] slist = tvm.make.stmt_list(slist[-1]) - pop_st = slist[0].body.first + pop_st = slist[0].body[0] assert(push_st.value.name == "cop.coproc_dep_push") assert(__check_list(push_st.value.args, [2,3])) diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 5275aec4db90..b10224376dfa 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -71,7 +71,7 @@ def test_schedule_scan(): s = tvm.create_schedule(res.op) s = s.normalize() ir = tvm.lower(s, [s_state], simple_mode=True) - assert not hasattr(ir.body.body.body.body.rest.body.body.rest.body, "condition") + assert not hasattr(ir.body.body.body.body[1].body.body[1].body, "condition") bounds = tvm.schedule.InferBound(s) assert(bounds[res.op.scan_axis].min.value == 1) stmt = tvm.schedule.ScheduleOps(s, bounds) diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py index 12ef7daac731..dbce9a7b9102 100644 --- a/vta/python/vta/ir_pass.py +++ b/vta/python/vta/ir_pass.py @@ -135,7 +135,7 @@ def _do_fold(stmt): if body == stmt.body: return stmt ends = list(reversed(ends)) - body = tvm.make.stmt_seq(*(begins + [body] + ends)) + body = tvm.stmt.stmt_seq(*(begins + [body] + ends)) return tvm.make.AttrStmt( stmt.node, stmt.attr_key, stmt.value, body) return None @@ -307,7 +307,7 @@ def _do_fold(stmt): success[0] = True sync = tvm.make.Call( "int32", "vta.coproc_sync", [], tvm.expr.Call.Intrinsic, None, 0) - return tvm.make.Block(stmt.body, tvm.make.Evaluate(sync)) + return tvm.stmt.SeqStmt([stmt.body, tvm.make.Evaluate(sync)]) if _match_pragma(stmt, "trim_loop"): op = stmt.body assert isinstance(op, tvm.stmt.For) From 0f308b6f9dd0faed786041d39cb0b5d30272d559 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 5 Jan 2020 16:21:04 -0800 Subject: [PATCH 3/4] [REFACTOR] Remove Block --- include/tvm/ir.h | 22 -------- include/tvm/ir_functor_ext.h | 4 -- python/tvm/stmt.py | 20 ------- src/api/api_ir.cc | 3 -- src/codegen/codegen_c.cc | 5 -- src/codegen/codegen_c.h | 1 - src/codegen/llvm/codegen_llvm.cc | 7 --- src/codegen/llvm/codegen_llvm.h | 1 - src/codegen/spirv/codegen_spirv.cc | 7 --- src/codegen/spirv/codegen_spirv.h | 1 - src/codegen/stackvm/codegen_stackvm.cc | 5 -- src/codegen/stackvm/codegen_stackvm.h | 1 - src/contrib/hybrid/codegen_hybrid.cc | 5 -- src/contrib/hybrid/codegen_hybrid.h | 1 - src/lang/ir.cc | 35 ------------ src/pass/inject_virtual_thread.cc | 16 +----- src/pass/ir_deep_compare.cc | 7 --- src/pass/ir_functor.cc | 19 ------- src/pass/ir_util.cc | 5 -- src/pass/lift_attr_scope.cc | 75 -------------------------- src/pass/loop_partition.cc | 10 ---- src/pass/remove_no_op.cc | 12 ----- src/pass/unroll_loop.cc | 22 -------- 23 files changed, 1 insertion(+), 283 deletions(-) diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 03c66a80bfd7..9c3172160097 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -1120,28 +1120,6 @@ class SeqStmt : public Stmt { TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode); }; -/*! - * \brief A sequence of statements. - */ -class Block : public StmtNode { - public: - /*! \brief The first statement. */ - Stmt first; - /*! \brief The restof statments. */ - Stmt rest; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("first", &first); - v->Visit("rest", &rest); - } - - TVM_DLL static Stmt make(Stmt first, Stmt rest); - TVM_DLL static Stmt make(const std::vector &stmts); - - static constexpr const char* _type_key = "Block"; - TVM_DECLARE_FINAL_OBJECT_INFO(Block, StmtNode); -}; - /*! * \brief IfThenElse statment. */ diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index 5462e8799f61..6cc6d702c7cd 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -253,7 +253,6 @@ class StmtFunctor { virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmtDefault_(const Object* op, Args ...) { @@ -277,7 +276,6 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(Provide); IR_STMT_FUNCTOR_DISPATCH(Realize); IR_STMT_FUNCTOR_DISPATCH(Prefetch); - IR_STMT_FUNCTOR_DISPATCH(Block); IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode); IR_STMT_FUNCTOR_DISPATCH(Evaluate); return vtable; @@ -410,7 +408,6 @@ class TVM_DLL StmtVisitor : void VisitStmt_(const Provide* op) override; void VisitStmt_(const Realize* op) override; void VisitStmt_(const Prefetch* op) override; - void VisitStmt_(const Block* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const Evaluate* op) override; }; @@ -505,7 +502,6 @@ class TVM_DLL StmtMutator : Stmt VisitStmt_(const Provide* op) override; Stmt VisitStmt_(const Realize* op) override; Stmt VisitStmt_(const Prefetch* op) override; - Stmt VisitStmt_(const Block* op) override; Stmt VisitStmt_(const SeqStmtNode* op) override; Stmt VisitStmt_(const Evaluate* op) override; /*! diff --git a/python/tvm/stmt.py b/python/tvm/stmt.py index 3073e86cd890..64628d1d4198 100644 --- a/python/tvm/stmt.py +++ b/python/tvm/stmt.py @@ -288,23 +288,6 @@ def __init__(self, bounds, condition, body) -@register_node -class Block(Stmt): - """Block node. - - Parameters - ---------- - first : Stmt - The first statement. - - rest : Stmt - The following statement. - """ - def __init__(self, first, rest): - self.__init_handle_by_constructor__( - _make.Block, first, rest) - - @register_node class SeqStmt(Stmt): """Sequence of statements. @@ -422,12 +405,9 @@ def stmt_list(stmt): for x in stmt: res += stmt_list(x) return res - elif isinstance(stmt, Block): - return stmt_list(stmt.first) + stmt_list(stmt.rest) if isinstance(stmt, ProducerConsumer): return stmt_list(stmt.body) return [stmt] _make.stmt_list = stmt_list -_make.stmt_seq = stmt_seq diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 4fec5234f955..034405f1a7f0 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -169,9 +169,6 @@ REGISTER_MAKE(IfThenElse); REGISTER_MAKE(Evaluate); // overloaded, needs special handling -TVM_REGISTER_GLOBAL("make.Block") - .set_body_typed(static_cast(Block::make)); - // has default args TVM_REGISTER_GLOBAL("make.Allocate") .set_body_typed([]( diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 2c24e07eb5d6..a3f145994f2c 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -875,11 +875,6 @@ void CodeGenC::VisitStmt_(const IfThenElse* op) { stream << "}\n"; } -void CodeGenC::VisitStmt_(const Block* op) { - PrintStmt(op->first); - if (op->rest.defined()) PrintStmt(op->rest); -} - void CodeGenC::VisitStmt_(const SeqStmtNode* op) { for (Stmt stmt : op->seq) { PrintStmt(stmt); diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index 1773bcbfd88f..eae1e4961b77 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -140,7 +140,6 @@ class CodeGenC : void VisitStmt_(const AttrStmt* op) override; void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const Evaluate* op) override; - void VisitStmt_(const Block* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const ProducerConsumer* op) override; /*! diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index 9461eee0ee8c..b0d86a9f66ce 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -1214,13 +1214,6 @@ void CodeGenLLVM::VisitStmt_(const LetStmt* op) { this->VisitStmt(op->body); } -void CodeGenLLVM::VisitStmt_(const Block* op) { - this->VisitStmt(op->first); - if (op->rest.defined()) { - this->VisitStmt(op->rest); - } -} - void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) { for (Stmt stmt : op->seq) { this->VisitStmt(stmt); diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index 56a710392497..076ffb2af588 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -140,7 +140,6 @@ class CodeGenLLVM : void VisitStmt_(const AttrStmt* op) override; void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const LetStmt* op) override; - void VisitStmt_(const Block* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const ProducerConsumer* op) override; diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index 79363fcbbb40..0709965d0e8b 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -638,13 +638,6 @@ void CodeGenSPIRV::VisitStmt_(const LetStmt* op) { this->VisitStmt(op->body); } -void CodeGenSPIRV::VisitStmt_(const Block* op) { - VisitStmt(op->first); - if (op->rest.defined()) { - this->VisitStmt(op->rest); - } -} - void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) { for (Stmt stmt : op->seq) { this->VisitStmt(stmt); diff --git a/src/codegen/spirv/codegen_spirv.h b/src/codegen/spirv/codegen_spirv.h index b2f3fc3ad99e..5cd88c9f267a 100644 --- a/src/codegen/spirv/codegen_spirv.h +++ b/src/codegen/spirv/codegen_spirv.h @@ -98,7 +98,6 @@ class CodeGenSPIRV: void VisitStmt_(const AttrStmt* op) override; void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const LetStmt* op) override; - void VisitStmt_(const Block* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const ProducerConsumer* op) override; diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index e0cda5d03f94..23bb008a0e7e 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -417,11 +417,6 @@ void CodeGenStackVM::VisitStmt_(const For* op) { this->SetOperand(backward_jump, loop_head - label_bjump); } -void CodeGenStackVM::VisitStmt_(const Block* op) { - this->Push(op->first); - if (op->rest.defined()) this->Push(op->rest); -} - void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) { for (Stmt stmt : op->seq) { this->Push(stmt); diff --git a/src/codegen/stackvm/codegen_stackvm.h b/src/codegen/stackvm/codegen_stackvm.h index 63f354d4bffc..7a4c0ab797fd 100644 --- a/src/codegen/stackvm/codegen_stackvm.h +++ b/src/codegen/stackvm/codegen_stackvm.h @@ -148,7 +148,6 @@ class CodeGenStackVM void VisitStmt_(const AttrStmt* op) final; void VisitStmt_(const AssertStmt* op) final; void VisitStmt_(const Evaluate* op) final; - void VisitStmt_(const Block* op) final; void VisitStmt_(const SeqStmtNode* op) final; void VisitStmt_(const ProducerConsumer* op) final; diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index d88716bc343d..00b2c230c5bb 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -389,11 +389,6 @@ void CodeGenHybrid::VisitStmt_(const IfThenElse* op) { } } -void CodeGenHybrid::VisitStmt_(const Block* op) { - PrintStmt(op->first); - if (op->rest.defined()) PrintStmt(op->rest); -} - void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) { for (Stmt stmt : op->seq) { PrintStmt(stmt); diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 8fb2dcd1cae6..27c97c73e333 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -131,7 +131,6 @@ class CodeGenHybrid : void VisitStmt_(const AttrStmt* op) override; void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const Evaluate* op) override; - void VisitStmt_(const Block* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const ProducerConsumer* op) override; /*! diff --git a/src/lang/ir.cc b/src/lang/ir.cc index 08167affac00..de047f330630 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -510,33 +510,6 @@ SeqStmt::SeqStmt(Array seq) { data_ = std::move(node); } -Stmt Block::make(Stmt first, Stmt rest) { - CHECK(first.defined()); - CHECK(rest.defined()); - ObjectPtr node = make_object(); - - // canonicalize. - if (const Block* b = first.as()) { - node->first = b->first; - node->rest = Block::make(b->rest, rest); - } else { - node->first = std::move(first); - node->rest = std::move(rest); - } - return Stmt(node); -} - -Stmt Block::make(const std::vector& stmts) { - if (stmts.empty()) { - return Stmt(); - } - Stmt result = stmts.back(); - for (size_t i = stmts.size() - 1; i != 0; --i) { - result = Block::make(stmts[i - 1], result); - } - return result; -} - Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) { CHECK(condition.defined()); CHECK(then_case.defined()); @@ -1037,13 +1010,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) } }); -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->Print(op->first); - if (op->rest.defined()) p->Print(op->rest); - }); - TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& node, NodePrinter* p) { auto* op = static_cast(node.get()); @@ -1226,7 +1192,6 @@ TVM_REGISTER_NODE_TYPE(Provide); TVM_REGISTER_NODE_TYPE(Allocate); TVM_REGISTER_NODE_TYPE(Free); TVM_REGISTER_NODE_TYPE(Realize); -TVM_REGISTER_NODE_TYPE(Block); TVM_REGISTER_NODE_TYPE(SeqStmtNode); TVM_REGISTER_NODE_TYPE(IfThenElse); TVM_REGISTER_NODE_TYPE(Evaluate); diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc index 848aa8f8e34e..0887a83c1a48 100644 --- a/src/pass/inject_virtual_thread.cc +++ b/src/pass/inject_virtual_thread.cc @@ -356,21 +356,7 @@ class VTInjector : public StmtExprMutator { return IfThenElse::make(condition, then_case, else_case); } } - // Block - Stmt VisitStmt_(const Block* op) final { - CHECK_EQ(max_loop_depth_, 0); - Stmt first = this->VisitStmt(op->first); - int temp = max_loop_depth_; - max_loop_depth_ = 0; - Stmt rest = this->VisitStmt(op->rest); - max_loop_depth_ = std::max(max_loop_depth_, temp); - if (first.same_as(op->first) && - rest.same_as(op->rest)) { - return GetRef(op); - } else { - return Block::make(first, rest); - } - } + // Seq Stmt VisitStmt_(const SeqStmtNode* op) final { CHECK_EQ(max_loop_depth_, 0); diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc index 0c6cb17422bc..bbee9eeb7c8a 100644 --- a/src/pass/ir_deep_compare.cc +++ b/src/pass/ir_deep_compare.cc @@ -179,13 +179,6 @@ class IRDeepCompare : if (CompareRegion(op->bounds, rhs->bounds) != 0) return; } - void VisitStmt_(const Block* op, const Stmt& other) final { - const Block* rhs = other.as(); - if (CompareStmt(op->first, rhs->first) != 0) return; - if (CompareStmt(op->rest, rhs->rest) != 0) return; - } - - void VisitStmt_(const SeqStmtNode* op, const Stmt& other) final { const SeqStmtNode* rhs = other.as(); if (CompareValue(op->size(), rhs->size()) != 0) return; diff --git a/src/pass/ir_functor.cc b/src/pass/ir_functor.cc index a8bba375fb6a..dddf90eb47aa 100644 --- a/src/pass/ir_functor.cc +++ b/src/pass/ir_functor.cc @@ -209,11 +209,6 @@ void StmtVisitor::VisitStmt_(const Prefetch* op) { }); } -void StmtVisitor::VisitStmt_(const Block* op) { - this->VisitStmt(op->first); - this->VisitStmt(op->rest); -} - void StmtVisitor::VisitStmt_(const SeqStmtNode* op) { VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); @@ -496,20 +491,6 @@ Stmt StmtMutator::VisitStmt_(const Prefetch* op) { } } -Stmt StmtMutator::VisitStmt_(const Block* op) { - Stmt first = this->VisitStmt(op->first); - Stmt rest = this->VisitStmt(op->rest); - if (first.same_as(op->first) && - rest.same_as(op->rest)) { - return GetRef(op); - } else { - auto n = CopyOnWrite(op); - n->first = std::move(first); - n->rest = std::move(rest); - return Stmt(n); - } -} - Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) { Array seq = Internal::Mutate(this, op->seq); if (seq.same_as(op->seq)) { diff --git a/src/pass/ir_util.cc b/src/pass/ir_util.cc index 7d7b15d18a4c..8956a4d11e7c 100644 --- a/src/pass/ir_util.cc +++ b/src/pass/ir_util.cc @@ -51,11 +51,6 @@ Stmt MergeNest(const std::vector& nest, Stmt body) { CHECK(!n->else_case.defined()); n->then_case = body; body = Stmt(n); - } else if (const auto* block = s.as()) { - auto n = make_object(*block); - CHECK(is_no_op(n->rest)); - n->rest = body; - body = Stmt(n); } else if (const auto* seq = s.as()) { auto n = make_object(*seq); CHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1])); diff --git a/src/pass/lift_attr_scope.cc b/src/pass/lift_attr_scope.cc index 6d998a6acc3a..4f2df7b22b09 100644 --- a/src/pass/lift_attr_scope.cc +++ b/src/pass/lift_attr_scope.cc @@ -75,19 +75,6 @@ class AttrScopeLifter : public StmtMutator { } } - Stmt VisitStmt_(const Block* op) final { - std::vector seq; - FlattenSeq(op->first, &seq); - FlattenSeq(op->rest, &seq); - seq = MutateSeq(seq); - if (seq.size() == 2 && - seq[0].same_as(op->first) && - seq[1].same_as(op->rest)) { - return GetRef(op); - } - return SeqStmt::Flatten(seq); - } - Stmt VisitStmt_(const SeqStmtNode* op) final { // remember the decorations. std::vector attr_node; @@ -184,68 +171,6 @@ class AttrScopeLifter : public StmtMutator { } private: - void FlattenSeq(Stmt s, std::vector* res) { - if (const Block* op = s.as()) { - FlattenSeq(op->first, res); - FlattenSeq(op->rest, res); - } else if (const ProducerConsumer* op = s.as()) { - if (!op->is_producer) { - FlattenSeq(op->body, res); - } else { - res->emplace_back(s); - } - } else { - res->emplace_back(s); - } - } - - std::vector MutateSeq(const Array& seq) { - std::vector res_seq; - ObjectRef curr_node; - Expr curr_value; - Stmt curr_stmt; - for (const Stmt & stmt : seq) { - attr_node_ = ObjectRef(); - attr_value_ = Expr(); - Stmt rest = this->VisitStmt(stmt); - if (attr_node_.defined() && - attr_value_.defined() && - curr_node.defined() && - curr_value.defined() && - attr_node_.same_as(curr_node) && - ValueSame(attr_value_, curr_value)) { - curr_stmt = Block::make(curr_stmt, rest); - } else { - if (curr_stmt.defined()) { - if (curr_node.defined()) { - curr_stmt = AttrStmt::make( - curr_node, attr_key_, curr_value, curr_stmt); - } - res_seq.push_back(curr_stmt); - } - curr_stmt = rest; - curr_node = attr_node_; - curr_value = attr_value_; - } - } - - if (curr_stmt.defined()) { - // keep attr_node_, attr_node_ - if (res_seq.size() == 0) { - return {curr_stmt}; - } - if (curr_node.defined()) { - curr_stmt = AttrStmt::make( - curr_node, attr_key_, curr_value, curr_stmt); - } - res_seq.push_back(curr_stmt); - // reset - attr_node_ = ObjectRef(); - attr_value_ = Expr(); - } - return res_seq; - } - // value comparison that also compares content of int constant static bool ValueSame(const Expr& a, const Expr& b) { if (a.same_as(b)) return true; diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 09559c1bb8b2..aa8ebe1eb19b 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -106,16 +106,6 @@ class CandidateSelector final : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); } - void VisitStmt_(const Block* op) final { - bool temp = no_split_; - this->VisitStmt(op->first); - // erase the no split state of first when visit rest. - std::swap(temp, no_split_); - this->VisitStmt(op->rest); - // restore the no split flag. - no_split_ = no_split_ || temp; - } - void VisitStmt_(const SeqStmtNode* op) final { bool init_no_split = no_split_; for (Stmt stmt : op->seq) { diff --git a/src/pass/remove_no_op.cc b/src/pass/remove_no_op.cc index 59ec113fcaff..68918708b9fb 100644 --- a/src/pass/remove_no_op.cc +++ b/src/pass/remove_no_op.cc @@ -94,18 +94,6 @@ class NoOpRemover : public StmtMutator { return Evaluate::make(0); } - Stmt VisitStmt_(const Block* op) final { - Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as(); - if (is_no_op(op->first)) { - return op->rest; - } else if (is_no_op(op->rest)) { - return op->first; - } else { - return stmt; - } - } - Stmt VisitStmt_(const SeqStmtNode* op) final { Stmt ret = StmtMutator::VisitSeqStmt_(op, true); op = ret.as(); diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc index 269bf0933c0b..7826a9b3a666 100644 --- a/src/pass/unroll_loop.cc +++ b/src/pass/unroll_loop.cc @@ -120,28 +120,6 @@ class LoopUnroller : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } - Stmt VisitStmt_(const Block* op) final { - Stmt first = this->VisitStmt(op->first); - // cleanup state - int step_count = step_count_; - int unroll_depth = unroll_depth_; - int normal_loop_depth = normal_loop_depth_; - step_count_ = 0; - unroll_depth_ = 0; - normal_loop_depth_ = 0; - // work on rest part - Stmt rest = this->VisitStmt(op->rest); - step_count_ += step_count; - normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_); - unroll_depth_ = std::max(unroll_depth_, unroll_depth); - if (first.same_as(op->first) && - rest.same_as(op->rest)) { - return GetRef(op); - } else { - return Block::make(first, rest); - } - } - Stmt VisitStmt_(const SeqStmtNode* op) final { auto fmutate = [this](const Stmt& s) { int step_count = step_count_; From 6a228dc10af32c1168e7a960ada45b0fdcab8bf6 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 5 Jan 2020 19:51:57 -0800 Subject: [PATCH 4/4] Add more comments per yizhi's comment --- include/tvm/ir.h | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 9c3172160097..b1cefff1e90e 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -1069,10 +1069,19 @@ class SeqStmt : public Stmt { return (*(operator->()))[index]; } /*! - * \brief Construct a flattened sequence statement. + * \brief Construct a sequence statement by flattening + * all the arrays and sequences in the arguments + * recursively. + * + * - When an argument is nullptr, it will be ignored. + * - When an argument is an array or a SeqStmt, it will be flattened recursively. + * - When an argument is a consumer block in ProducerConsumer, the consumer + * tag will be dropped as such information is not useful in lowering. + * - A normal Stmt will be appended to the end of the sequence. * * \note This function can directly return an element * if it is the only element in the sequence. + * * \param seq_args The list of arguments to be flattened. * \tparam Args arguments * \return The constructed statement @@ -1096,6 +1105,7 @@ class SeqStmt : public Stmt { if (auto* op = stmt.as()) { operator()(0, op->seq); } else if (auto* op = stmt.as()) { + // NOTE: The consumer block annotation was not as useful and can be safely dropped. if (!op->is_producer) { operator()(0, op->body); } else {