diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 0da8e55be023..2ae2877b2f92 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1339,6 +1339,12 @@ constexpr const char* hand_threaded = "hand_threaded"; * if (mask & 2) the write region should be detected. */ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access"; + +/*! + * \brief Mark that the loop should be partitioned. + */ +constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint"; + /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index 97f5b6f90a70..c4b83e05706d 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -98,7 +98,13 @@ class CandidateSelector final : public StmtExprVisitor { void VisitStmt_(const ForNode* op) final { // partition const loop when sets partition_const_loop_ if (!is_const_int(op->min) || !is_const_int(op->extent) || partition_const_loop_) { + // always treat var with hint to be partitioned const VarNode* var = op->loop_var.get(); + if (partition_hint_vars.count(var)) { + candidates.insert(GetRef(op)); + StmtExprVisitor::VisitStmt_(op); + return; + } record_.insert({var, false}); StmtExprVisitor::VisitStmt_(op); if (record_.at(var) && !no_split_) { @@ -117,6 +123,12 @@ class CandidateSelector final : public StmtExprVisitor { Var var = iv->var; runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); if ((scope.rank == 0) && (!is_const_int(op->value) || partition_const_loop_)) { + // always treat var with hint to be partitioned + if (partition_hint_vars.count(var.get())) { + candidates.insert(GetRef(op)); + StmtExprVisitor::VisitStmt_(op); + return; + } record_.insert({var.get(), false}); StmtExprVisitor::VisitStmt_(op); if (record_.at(var.get()) && !no_split_) { @@ -125,6 +137,15 @@ class CandidateSelector final : public StmtExprVisitor { record_.erase(var.get()); return; } + } else if (op->attr_key == attr::pragma_loop_partition_hint) { + const VarNode* var = nullptr; + if (op->node->IsInstance()) { + var = op->node.as(); + } else if (op->node->IsInstance()) { + var = op->node.as()->var.get(); + } + ICHECK(var); + partition_hint_vars.insert(var); } StmtExprVisitor::VisitStmt_(op); } @@ -162,6 +183,7 @@ class CandidateSelector final : public StmtExprVisitor { } std::unordered_set candidates; + std::unordered_set partition_hint_vars; private: bool in_likely_{false}; @@ -170,15 +192,28 @@ class CandidateSelector final : public StmtExprVisitor { std::unordered_map record_; }; +// Finder try best to find partitions for hinted vars +#define DEFINE_PARTITION_FINDER_VISIT_CMP_OP(OpNodeT) \ + void VisitExpr_(const OpNodeT* op) final { \ + if (has_partition_hint_) { \ + DeduceCondition(GetRef(op)); \ + return; \ + } \ + StmtExprVisitor::VisitExpr_(op); \ + } + // Populate partitions data structure, i.e., for a specific variable, -// find an interval in which each condition -// (currently, "likely" conditions) has fixed true or false value +// find an interval in which each condition has fixed true or false value class PartitionFinder : public StmtExprVisitor { public: explicit PartitionFinder(Var current_var, const std::unordered_map& hint_map, - const std::unordered_map& relax_map) - : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) { + const std::unordered_map& relax_map, + bool has_partition_hint) + : current_var_(current_var), + has_partition_hint_(has_partition_hint), + hint_map_(hint_map), + relax_map_(relax_map) { for (const auto& kv : hint_map) { out_vars_.insert(kv.first); } @@ -218,33 +253,43 @@ class PartitionFinder : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::likely())) { - PrimExpr cond = op->args[0]; - if (UsesVar(cond, [this](const VarNode* var) { return var == current_var_.get(); })) { - // For cond, find out the interval, if exists, in which we can prove that cond is - // true. Also find the interval, if exists, in which we can prove that cond is - // false. - IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_); - if (!interval.IsNothing()) { - // cond is true within interval - partitions[{cond, true}] = interval; - } - PrimExpr inverse_cond = InverseCond(cond); - if (inverse_cond.defined()) { - IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_); - if (!interval.IsNothing()) { - // cond is false within interval - partitions[{cond, false}] = interval; - } - } - } + DeduceCondition(op->args[0]); } else { StmtExprVisitor::VisitExpr_(op); } } + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(GENode); + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(GTNode); + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(LENode); + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(LTNode); + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(EQNode); + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(NENode); + Partition partitions; private: + void DeduceCondition(const PrimExpr& cond) { + // For cond, find out the interval, if exists, in which we can prove that cond is + // true. Also find the interval, if exists, in which we can prove that cond is + // false. + if (UsesVar(cond, [this](const VarNode* var) { return var == current_var_.get(); })) { + IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_); + if (!interval.IsNothing()) { + // cond is true within interval + partitions[{cond, true}] = interval; + } + PrimExpr inverse_cond = InverseCond(cond); + if (inverse_cond.defined()) { + IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_); + if (!interval.IsNothing()) { + // cond is false within interval + partitions[{cond, false}] = interval; + } + } + } + } + PrimExpr InverseCond(const PrimExpr& cond) { PrimExpr inverse_cond; if (const LTNode* op = cond.as()) { @@ -270,6 +315,7 @@ class PartitionFinder : public StmtExprVisitor { } Var current_var_; + bool has_partition_hint_; std::unordered_set out_vars_; std::unordered_map hint_map_; std::unordered_map relax_map_; @@ -472,7 +518,8 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim // include hint of var. hint_map_.insert({var.get(), IntSet::Interval(min, max)}); - PartitionFinder finder(var, hint_map_, relax_map_); + bool has_partition_hint_ = selector.partition_hint_vars.count(var.get()); + PartitionFinder finder(var, hint_map_, relax_map_, has_partition_hint_); finder(body); hint_map_.erase(var.get()); @@ -601,7 +648,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b } } -class RemoveLikelyTags : public StmtExprMutator { +class RemoveLikelyTagsAndHints : public StmtExprMutator { public: PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::likely())) { @@ -611,12 +658,19 @@ class RemoveLikelyTags : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::pragma_loop_partition_hint) { + return VisitStmt(op->body); + } + return StmtExprMutator::VisitStmt_(op); + } }; Stmt LoopPartition(Stmt stmt, bool partition_const_loop, bool no_unroll_loop_with_extent_one) { stmt = LoopPartitioner(partition_const_loop, no_unroll_loop_with_extent_one) .VisitAndMutate(std::move(stmt)); - stmt = RemoveLikelyTags()(std::move(stmt)); + stmt = RemoveLikelyTagsAndHints()(std::move(stmt)); return stmt; } diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index c632f744bb81..a219b8d96457 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -17,6 +17,8 @@ import tvm import tvm.testing from tvm import te +from tvm import tir +from tvm.script import ty import numpy @@ -434,7 +436,6 @@ def test_conv_tiling(): oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16) bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.tir.transform.LoopPartition()(mod) @@ -538,6 +539,33 @@ def test_simple_rfactor(): assert not tvm.ir.structural_equal(stmt1.body, stmt2.body) +@tvm.script.tir +def partitioned_concat(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, [16], dtype="float32") + B = tir.match_buffer(b, [16], dtype="float32") + C = tir.match_buffer(c, [32], dtype="float32") + for i in tir.serial(0, 16): + tir.store(C.data, i, tir.load("float32", A.data, i), True) + for i in tir.serial(0, 16): + tir.store(C.data, i + 16, tir.load("float32", B.data, i + 16), True) + + +def test_explicit_partition_hint(): + A = te.placeholder((16,), name="A") + B = te.placeholder((16,), name="B") + C = te.compute((32,), lambda i: te.if_then_else(i < 16, A[i], B[i]), name="C") + s = te.create_schedule(C.op) + s.normalize() + s[C].pragma(s[C].op.axis[0], "loop_partition_hint") + mod = tvm.driver.build_module.schedule_to_module(s, [A, B, C], "main", None) + with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): + mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.tir.transform.LoopPartition()(mod) + mod = tvm.tir.transform.Simplify()(mod) + assert tvm.ir.structural_equal(mod["main"], partitioned_concat) + + if __name__ == "__main__": test_basic() test_const_loop() @@ -559,3 +587,4 @@ def test_simple_rfactor(): test_double_splitting_with_indivisible_factors() test_multilevel_splitting_with_indivisble_factors() test_simple_rfactor() + test_explicit_partition_hint()