Skip to content

Commit

Permalink
[REFACTOR] Migrate use of Block to SeqStmt.
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jan 6, 2020
1 parent 3adcd8a commit 99936b2
Show file tree
Hide file tree
Showing 32 changed files with 167 additions and 141 deletions.
5 changes: 3 additions & 2 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1071,8 +1071,8 @@ class SeqStmt : public Stmt {
/*!
* \brief Construct a flattened sequence statement.
*
* \note This function can return the element if there
* is only one element in the sequence.
* \note This function can directly return an element
* if it is the only element in the sequence.
* \param seq_args The list of arguments to be flattened.
* \tparam Args arguments
* \return The constructed statement
Expand All @@ -1092,6 +1092,7 @@ class SeqStmt : public Stmt {
: seq_(seq) {}

void operator()(size_t i, const Stmt& stmt) const {
if (!stmt.defined()) return;
if (auto* op = stmt.as<SeqStmtNode>()) {
operator()(0, op->seq);
} else if (auto* op = stmt.as<ProducerConsumer>()) {
Expand Down
8 changes: 3 additions & 5 deletions python/tvm/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from .. import _api_internal as _tvm_internal
from .. import expr as _expr
from .. import make as _make
from .. import stmt as _stmt

from .. import api as _api
from .. import ir_pass as _ir_pass

Expand All @@ -48,11 +50,7 @@ def concat_list_to_block(lst):
n = len(lst)
if n == 1:
return lst[0]
body = lst[n - 1]
for i in range(1, n):
stmt = lst[n - 1 - i]
body = _make.Block(stmt, body)
return body
return _stmt.SeqStmt(lst)


def visit_list_to_block(visit, lst):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _pop_seq(self):
seq = self._seq_stack.pop()
if not seq or callable(seq[-1]):
seq.append(_make.Evaluate(0))
seqwrap = lambda x : x[0] if len(x) == 1 else _stmt.SeqStmt(list(reversed(x)))
seqwrap = lambda x: x[0] if len(x) == 1 else _stmt.SeqStmt(list(reversed(x)))
ret_seq = [seq[-1]]

for s in reversed(seq[:-1]):
Expand Down
13 changes: 6 additions & 7 deletions python/tvm/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,16 +395,15 @@ def stmt_seq(*args):
stmt : Stmt
The combined statement.
"""
return SeqStmt(args)

"""
ret = None
ret = []
for value in args:
if not isinstance(value, Stmt):
value = Evaluate(value)
ret = value if ret is None else Block(ret, value)
return ret if ret else Evaluate(0)
"""
ret.append(value)
if len(ret) == 1:
return ret[0]
return SeqStmt(ret)


def stmt_list(stmt):
"""Make list of stmt from blocks.
Expand Down
8 changes: 4 additions & 4 deletions src/op/compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,8 @@ void MakeReduction(const ComputeOpNode* op,
provides.emplace_back(Provide::make(
t->op, t->value_index, update_value[i], args));
}
*init = Block::make(inits);
*provide = Block::make(provides);
*init = SeqStmt::Flatten(inits);
*provide = SeqStmt::Flatten(provides);
if (!is_one(reduce->condition)) {
*provide = IfThenElse::make(reduce->condition, *provide);
}
Expand Down Expand Up @@ -382,7 +382,7 @@ Stmt MakeComputeStmt(const ComputeOpNode* self,
if (debug_keep_trivial_loop) {
provide = MergeNest(common, provide);
} else {
provide = MergeNest(common, Block::make(init, provide));
provide = MergeNest(common, SeqStmt::Flatten(init, provide));
}
// run substitution in the on the full nest, because loop condition
// could depend on outer loops.
Expand All @@ -392,7 +392,7 @@ Stmt MakeComputeStmt(const ComputeOpNode* self,
for (size_t i = 0; i < self->body.size(); ++i) {
provides.emplace_back(MakeProvide(self, stage->op.output(i)));
}
Stmt provide = Block::make(provides);
Stmt provide = SeqStmt::Flatten(provides);
provide = MergeNest(n.main_nest, provide);
// run substitution in the on the full nest, because loop condition
// could depend on outer loops.
Expand Down
4 changes: 2 additions & 2 deletions src/op/cross_thread_reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ Stmt MakeCrossThreadReduction(
stage->op, idx,
Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
}
Stmt assign_body = Block::make(assigns);
Stmt assign_body = SeqStmt::Flatten(assigns);
assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
Stmt body = Block::make(reduce_body, assign_body);
Stmt body = SeqStmt::Flatten(reduce_body, assign_body);
for (size_t idx = size; idx != 0; --idx) {
body = Allocate::make(
res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
Expand Down
2 changes: 1 addition & 1 deletion src/op/tensor_compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ Stmt TensorComputeOpNode::BuildProvide(
update = MergeNest(binder.asserts(), update);
update = op::Substitute(update, n.main_vmap);
update = MergeNest(update_nest, update);
return MergeNest(common, Block::make(init, update));
return MergeNest(common, SeqStmt::Flatten(init, update));
} else {
// When init op is not available, use body op for reset in the first iter.
CHECK(this->intrin->body.defined())
Expand Down
2 changes: 1 addition & 1 deletion src/op/tensorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
update = MergeNest(binder.asserts(), update);
update = Substitute(update, n.main_vmap);
update = MergeNest(update_nest, update);
return MergeNest(common, Block::make(init, update));
return MergeNest(common, SeqStmt::Flatten(init, update));
} else {
// When init op is not available, use body op for reset in the first iter.
CHECK(intrin->body.defined())
Expand Down
2 changes: 1 addition & 1 deletion src/pass/arg_binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
AssertStmt::make(arith::ComputeReduce<ir::And>(conds, Expr()),
stride_err_msg.str(), Evaluate::make(0));
check = IfThenElse::make(Not::make(is_null), check, Stmt());
asserts_.emplace_back(Block::make(check, Evaluate::make(0)));
asserts_.emplace_back(SeqStmt({check, Evaluate::make(0)}));
}
} else if (buffer->buffer_type == kAutoBroadcast) {
DataType stype = buffer->DefaultIndexType();
Expand Down
24 changes: 7 additions & 17 deletions src/pass/coproc_sync.cc
Original file line number Diff line number Diff line change
Expand Up @@ -655,24 +655,14 @@ class CoProcSyncInserter : public StmtMutator {
}

Stmt VisitStmt(const Stmt& stmt) final {
Stmt before, after;
auto it = insert_before_.find(stmt.get());
if (it != insert_before_.end()) {
before = MergeSeq(std::vector<Stmt>(
it->second.rbegin(), it->second.rend()));
}
it = insert_after_.find(stmt.get());
if (it != insert_after_.end()) {
after = MergeSeq(it->second);
}
auto it_before = insert_before_.find(stmt.get());
auto it_after = insert_after_.find(stmt.get());
Stmt new_stmt = StmtMutator::VisitStmt(stmt);
if (before.defined()) {
new_stmt = Block::make(before, new_stmt);
}
if (after.defined()) {
new_stmt = Block::make(new_stmt, after);
}
return new_stmt;

return SeqStmt::Flatten(
it_before != insert_before_.end() ? it_before->second : std::vector<Stmt>(),
new_stmt,
it_after != insert_after_.end() ? it_after->second : std::vector<Stmt>());
}

private:
Expand Down
6 changes: 3 additions & 3 deletions src/pass/inject_double_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class DoubleBufferInjector : public StmtExprMutator {
}
Stmt loop = For::make(
outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api,
MergeSeq(loop_seq));
SeqStmt::Flatten(loop_seq));
// tail
std::vector<Stmt> tail_seq;
Stmt tail_body = StripDoubleBufferWrite()(old_loop->body);
Expand All @@ -158,9 +158,9 @@ class DoubleBufferInjector : public StmtExprMutator {
IfThenElse::make(idx < old_loop->extent,
Substitute(tail_body, vmap)));
}
stmt = Block::make(loop, MergeSeq(tail_seq));
stmt = SeqStmt::Flatten(loop, tail_seq);
}
stmt = Block::make(MergeSeq(it->second), stmt);
stmt = SeqStmt::Flatten(it->second, stmt);
}
it = loop_allocs_.find(op);
if (it != loop_allocs_.end()) {
Expand Down
2 changes: 1 addition & 1 deletion src/pass/inject_prefetch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class PrefetchInjector : public StmtMutator {
vectorized_.erase(iter_var);

Stmt prefetch = Prefetch::make(ts->op, ts->value_index, ts->dtype, region);
return Block::make(prefetch, op->body);
return SeqStmt({prefetch, op->body});
}
return ret;
}
Expand Down
9 changes: 4 additions & 5 deletions src/pass/inject_virtual_thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -454,12 +454,11 @@ class VTInjector : public StmtExprMutator {
// only unroll if number of vthreads are small
if (max_loop_depth_ == 0 && num_threads_ < 16) {
// do unrolling if it is inside innermost content.
Stmt blk = Substitute(stmt, {{var_, make_zero(var_.dtype())}});
for (int i = 1; i < num_threads_; ++i) {
blk = Block::make(
blk, Substitute(stmt, {{var_, make_const(var_.dtype(), i)}}));
Array<Stmt> seq;
for (int i = 0; i < num_threads_; ++i) {
seq.push_back(Substitute(stmt, {{var_, make_const(var_.dtype(), i)}}));
}
return blk;
return SeqStmt::Flatten(seq);
} else {
// insert a for loop
Var idx(var_->name_hint + ".s", var_->dtype);
Expand Down
14 changes: 5 additions & 9 deletions src/pass/ir_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
CHECK(is_no_op(n->rest));
n->rest = body;
body = Stmt(n);
} else if (const auto* seq = s.as<SeqStmtNode>()) {
auto n = make_object<SeqStmtNode>(*seq);
CHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1]));
n->seq.Set(n->size() - 1, body);
body = Stmt(n);
} else if (const auto* assert_ = s.as<AssertStmt>()) {
auto n = make_object<AssertStmt>(*assert_);
CHECK(is_no_op(n->body));
Expand All @@ -80,14 +85,5 @@ Stmt MergeNest(const std::vector<std::vector<Stmt> >& nest, Stmt body) {
return body;
}

Stmt MergeSeq(const std::vector<Stmt>& seq) {
if (seq.size() == 0) return Evaluate::make(0);
Stmt body = seq[0];
for (size_t i = 1; i < seq.size(); ++i) {
body = Block::make(body, seq[i]);
}
return body;
}

} // namespace ir
} // namespace tvm
7 changes: 0 additions & 7 deletions src/pass/ir_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,6 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body);
*/
Stmt MergeNest(const std::vector<std::vector<Stmt> >& nest, Stmt body);

/*!
* \brief combine sequence of operations.
* \param seq The sequence.
* \return The combined Stmt
*/
Stmt MergeSeq(const std::vector<Stmt>& seq);

/*!
* \brief update array with an unary function
* \param arr array
Expand Down
55 changes: 52 additions & 3 deletions src/pass/lift_attr_scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,59 @@ class AttrScopeLifter : public StmtMutator {
seq[1].same_as(op->rest)) {
return GetRef<Stmt>(op);
}
return MergeSeq(seq);
return SeqStmt::Flatten(seq);
}

Stmt VisitStmt_(const SeqStmtNode* op) final {
return StmtMutator::VisitSeqStmt_(op, true);
// remember the decorations.
std::vector<ObjectRef> attr_node;
std::vector<Expr> attr_value;

auto fmutate = [&](const Stmt& s) {
attr_node_ = ObjectRef();
attr_value_ = Expr();
Stmt ret = this->VisitStmt(s);
attr_node.push_back(attr_node_);
attr_value.push_back(attr_value_);
return ret;
};
Stmt ret = StmtMutator::VisitSeqStmt_(op, true, fmutate);
if (attr_node.size() == 0) return ret;

op = ret.as<SeqStmtNode>();
CHECK(op != nullptr);
Array<Stmt> reorg;
// check if all decorations are common.
for (size_t begin = 0; begin < attr_node.size();) {
size_t end = begin + 1;
while (end < attr_node.size() &&
attr_node[end].same_as(attr_node[begin]) &&
ValueSame(attr_value[end], attr_value[begin])) {
++end;
}
// covers everything
// lift attr to parent.
if (begin == 0 && end == attr_node.size()) {
attr_node_ = attr_node[0];
attr_value_ = attr_value[0];
return ret;
}
// construct subsegments.
Array<Stmt> seq;
for (size_t i = begin; i < end; ++i) {
seq.push_back(op->seq[i]);
}
Stmt stmt = SeqStmt::Flatten(seq);
if (attr_node[begin].defined()) {
stmt = AttrStmt::make(
attr_node[begin], attr_key_, attr_value[begin], stmt);
}
reorg.push_back(stmt);
begin = end;
}
attr_node_ = ObjectRef();
attr_value_ = Expr();
return SeqStmt::Flatten(reorg);
}

Stmt VisitStmt_(const IfThenElse* op) final {
Expand Down Expand Up @@ -151,7 +199,7 @@ class AttrScopeLifter : public StmtMutator {
}
}

std::vector<Stmt> MutateSeq(const std::vector<Stmt>& seq) {
std::vector<Stmt> MutateSeq(const Array<Stmt>& seq) {
std::vector<Stmt> res_seq;
ObjectRef curr_node;
Expr curr_value;
Expand Down Expand Up @@ -201,6 +249,7 @@ class AttrScopeLifter : public StmtMutator {
// value comparison that also compares content of int constant
static bool ValueSame(const Expr& a, const Expr& b) {
if (a.same_as(b)) return true;
if (!a.defined() || !b.defined()) return false;
if (a->type_index() != b->type_index()) return false;
if (a.dtype() != b.dtype()) return false;
if (const IntImm* op = a.as<IntImm>()) {
Expand Down
13 changes: 1 addition & 12 deletions src/pass/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,16 +414,6 @@ LoopPartitioner::GetIntervalAndCondset(const Partition &partitions,
return std::make_pair(interval, cond_set);
}

Stmt AppendStmts(const Stmt& a, const Stmt& b) {
if (!a.defined()) {
return b;
} else if (!b.defined()) {
return a;
} else {
return Block::make(a, b);
}
}

/*
* Tries to recursively partition the range of the variable (given by var) of
* the for loop (given by node and stmt) into a
Expand Down Expand Up @@ -589,8 +579,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
}
}
}
s = AppendStmts(pre_stmt, mid_stmt);
s = AppendStmts(s, post_stmt);
s = SeqStmt::Flatten(pre_stmt, mid_stmt, post_stmt);
} else {
Expr cond = const_true();
if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin);
Expand Down
Loading

0 comments on commit 99936b2

Please sign in to comment.