Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] add loop partition hint pragma #9121

Merged
merged 4 commits into from
Sep 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1339,6 +1339,12 @@ constexpr const char* hand_threaded = "hand_threaded";
* if (mask & 2) the write region should be detected.
*/
constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access";

/*!
* \brief Mark that the loop should be partitioned.
*
* The node of the attribute statement is expected to be either the loop
* variable (a Var) or the IterVar that owns it. The LoopPartition pass
* collects the hinted variables, always selects the loops over those
* variables as partition candidates when it considers the loop at all
* (non-constant bounds, or partitioning of constant loops enabled), and
* strips the hint attribute once partitioning has finished.
*/
constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
Hzfengsy marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
Expand Down
106 changes: 80 additions & 26 deletions src/tir/transforms/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,13 @@ class CandidateSelector final : public StmtExprVisitor {
void VisitStmt_(const ForNode* op) final {
// partition const loop when sets partition_const_loop_
if (!is_const_int(op->min) || !is_const_int(op->extent) || partition_const_loop_) {
// always treat var with hint to be partitioned
const VarNode* var = op->loop_var.get();
if (partition_hint_vars.count(var)) {
candidates.insert(GetRef<Stmt>(op));
StmtExprVisitor::VisitStmt_(op);
return;
}
record_.insert({var, false});
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var) && !no_split_) {
Expand All @@ -117,6 +123,12 @@ class CandidateSelector final : public StmtExprVisitor {
Var var = iv->var;
runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag);
if ((scope.rank == 0) && (!is_const_int(op->value) || partition_const_loop_)) {
// always treat var with hint to be partitioned
if (partition_hint_vars.count(var.get())) {
candidates.insert(GetRef<Stmt>(op));
StmtExprVisitor::VisitStmt_(op);
return;
}
record_.insert({var.get(), false});
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var.get()) && !no_split_) {
Expand All @@ -125,6 +137,15 @@ class CandidateSelector final : public StmtExprVisitor {
record_.erase(var.get());
return;
}
} else if (op->attr_key == attr::pragma_loop_partition_hint) {
const VarNode* var = nullptr;
if (op->node->IsInstance<VarNode>()) {
var = op->node.as<VarNode>();
} else if (op->node->IsInstance<IterVarNode>()) {
var = op->node.as<IterVarNode>()->var.get();
}
ICHECK(var);
partition_hint_vars.insert(var);
}
StmtExprVisitor::VisitStmt_(op);
}
Expand Down Expand Up @@ -162,6 +183,7 @@ class CandidateSelector final : public StmtExprVisitor {
}

std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> candidates;
std::unordered_set<const VarNode*> partition_hint_vars;

private:
bool in_likely_{false};
Expand All @@ -170,15 +192,28 @@ class CandidateSelector final : public StmtExprVisitor {
std::unordered_map<const VarNode*, VarIsUsed> record_;
};

// Defines the PartitionFinder visitor for the comparison node type OpNodeT.
// For a loop variable that carries a partition hint, the finder tries its best
// to find partitions: the comparison itself is handed to DeduceCondition and
// its operands are not traversed any further. Without a hint, the default
// StmtExprVisitor traversal is used.
#define DEFINE_PARTITION_FINDER_VISIT_CMP_OP(OpNodeT) \
void VisitExpr_(const OpNodeT* op) final { \
if (has_partition_hint_) { \
DeduceCondition(GetRef<PrimExpr>(op)); \
return; \
} \
StmtExprVisitor::VisitExpr_(op); \
}

// Populate partitions data structure, i.e., for a specific variable,
// find an interval in which each condition
// (currently, "likely" conditions) has fixed true or false value
// find an interval in which each condition has fixed true or false value
class PartitionFinder : public StmtExprVisitor {
public:
explicit PartitionFinder(Var current_var,
const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const VarNode*, IntSet>& relax_map)
: current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) {
const std::unordered_map<const VarNode*, IntSet>& relax_map,
bool has_partition_hint)
: current_var_(current_var),
has_partition_hint_(has_partition_hint),
hint_map_(hint_map),
relax_map_(relax_map) {
for (const auto& kv : hint_map) {
out_vars_.insert(kv.first);
}
Expand Down Expand Up @@ -218,33 +253,43 @@ class PartitionFinder : public StmtExprVisitor {

void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::likely())) {
PrimExpr cond = op->args[0];
if (UsesVar(cond, [this](const VarNode* var) { return var == current_var_.get(); })) {
// For cond, find out the interval, if exists, in which we can prove that cond is
// true. Also find the interval, if exists, in which we can prove that cond is
// false.
IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_);
if (!interval.IsNothing()) {
// cond is true within interval
partitions[{cond, true}] = interval;
}
PrimExpr inverse_cond = InverseCond(cond);
if (inverse_cond.defined()) {
IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
if (!interval.IsNothing()) {
// cond is false within interval
partitions[{cond, false}] = interval;
}
}
}
DeduceCondition(op->args[0]);
} else {
StmtExprVisitor::VisitExpr_(op);
}
}

// Instantiate the hint-aware visitor for every comparison node type, so that
// comparisons on a hinted loop var are considered for partitioning even when
// they are not wrapped in likely().
DEFINE_PARTITION_FINDER_VISIT_CMP_OP(GENode);
DEFINE_PARTITION_FINDER_VISIT_CMP_OP(GTNode);
DEFINE_PARTITION_FINDER_VISIT_CMP_OP(LENode);
DEFINE_PARTITION_FINDER_VISIT_CMP_OP(LTNode);
DEFINE_PARTITION_FINDER_VISIT_CMP_OP(EQNode);
DEFINE_PARTITION_FINDER_VISIT_CMP_OP(NENode);

Partition partitions;

private:
void DeduceCondition(const PrimExpr& cond) {
// For cond, find out the interval, if exists, in which we can prove that cond is
// true. Also find the interval, if exists, in which we can prove that cond is
// false.
if (UsesVar(cond, [this](const VarNode* var) { return var == current_var_.get(); })) {
IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_);
if (!interval.IsNothing()) {
// cond is true within interval
partitions[{cond, true}] = interval;
}
PrimExpr inverse_cond = InverseCond(cond);
if (inverse_cond.defined()) {
IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
if (!interval.IsNothing()) {
// cond is false within interval
partitions[{cond, false}] = interval;
}
}
}
}

PrimExpr InverseCond(const PrimExpr& cond) {
PrimExpr inverse_cond;
if (const LTNode* op = cond.as<LTNode>()) {
Expand All @@ -270,6 +315,7 @@ class PartitionFinder : public StmtExprVisitor {
}

Var current_var_;
bool has_partition_hint_;
std::unordered_set<const VarNode*> out_vars_;
std::unordered_map<const VarNode*, IntSet> hint_map_;
std::unordered_map<const VarNode*, IntSet> relax_map_;
Expand Down Expand Up @@ -472,7 +518,8 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
// include hint of var.
hint_map_.insert({var.get(), IntSet::Interval(min, max)});

PartitionFinder finder(var, hint_map_, relax_map_);
bool has_partition_hint_ = selector.partition_hint_vars.count(var.get());
PartitionFinder finder(var, hint_map_, relax_map_, has_partition_hint_);
finder(body);

hint_map_.erase(var.get());
Expand Down Expand Up @@ -601,7 +648,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b
}
}

class RemoveLikelyTags : public StmtExprMutator {
class RemoveLikelyTagsAndHints : public StmtExprMutator {
public:
PrimExpr VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::likely())) {
Expand All @@ -611,12 +658,19 @@ class RemoveLikelyTags : public StmtExprMutator {
return StmtExprMutator::VisitExpr_(op);
}
}

// Drop partition-hint attribute statements: the hint wrapper is replaced by
// its (recursively visited) body. Every other attribute statement goes through
// the default mutator.
Stmt VisitStmt_(const AttrStmtNode* op) final {
  const bool is_partition_hint = (op->attr_key == attr::pragma_loop_partition_hint);
  if (!is_partition_hint) {
    return StmtExprMutator::VisitStmt_(op);
  }
  return VisitStmt(op->body);
}
};

// Run loop partitioning on a statement.
// \param stmt The statement to transform.
// \param partition_const_loop Forwarded to LoopPartitioner.
// \param no_unroll_loop_with_extent_one Forwarded to LoopPartitioner.
// \return The partitioned statement, cleaned up by RemoveLikelyTagsAndHints.
Stmt LoopPartition(Stmt stmt, bool partition_const_loop, bool no_unroll_loop_with_extent_one) {
  stmt = LoopPartitioner(partition_const_loop, no_unroll_loop_with_extent_one)
             .VisitAndMutate(std::move(stmt));
  // A single cleanup pass is enough: RemoveLikelyTagsAndHints handles both the
  // likely() calls and the pragma_loop_partition_hint attrs.
  stmt = RemoveLikelyTagsAndHints()(std::move(stmt));
  return stmt;
}

Expand Down
31 changes: 30 additions & 1 deletion tests/python/unittest/test_tir_transform_loop_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import tvm
import tvm.testing
from tvm import te
from tvm import tir
from tvm.script import ty
import numpy


Expand Down Expand Up @@ -434,7 +436,6 @@ def test_conv_tiling():
oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
mod = tvm.tir.transform.LoopPartition()(mod)
Expand Down Expand Up @@ -538,6 +539,33 @@ def test_simple_rfactor():
assert not tvm.ir.structural_equal(stmt1.body, stmt2.body)


# Expected TIR for test_explicit_partition_hint: the single 32-iteration loop
# of the concat-like compute is split into two 16-iteration loops, one storing
# elements loaded from A and one storing elements loaded from B, with no
# remaining conditional.
# NOTE(review): B is declared with 16 elements but is loaded at i + 16. This
# mirrors the compute in test_explicit_partition_hint, which indexes B with the
# output index; presumably harmless because the function is only compared
# structurally and never executed -- confirm.
@tvm.script.tir
def partitioned_concat(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
A = tir.match_buffer(a, [16], dtype="float32")
B = tir.match_buffer(b, [16], dtype="float32")
C = tir.match_buffer(c, [32], dtype="float32")
# First partition: output indices [0, 16) take their values from A.
for i in tir.serial(0, 16):
tir.store(C.data, i, tir.load("float32", A.data, i), True)
# Second partition: output indices [16, 32) take their values from B.
for i in tir.serial(0, 16):
tir.store(C.data, i + 16, tir.load("float32", B.data, i + 16), True)


def test_explicit_partition_hint():
    """A loop tagged with the loop_partition_hint pragma is partitioned.

    Builds a 32-element compute that selects between two placeholders with
    an if_then_else on the output index, tags its axis with the partition
    hint, lowers it and checks that the result is structurally equal to
    partitioned_concat.
    """
    lhs = te.placeholder((16,), name="A")
    rhs = te.placeholder((16,), name="B")
    out = te.compute((32,), lambda i: te.if_then_else(i < 16, lhs[i], rhs[i]), name="C")

    sch = te.create_schedule(out.op)
    sch.normalize()
    # Explicitly request partitioning of the output axis.
    out_stage = sch[out]
    out_stage.pragma(out_stage.op.axis[0], "loop_partition_hint")

    mod = tvm.driver.build_module.schedule_to_module(sch, [lhs, rhs, out], "main", None)
    partition_config = {"tir.LoopPartition": {"partition_const_loop": True}}
    with tvm.transform.PassContext(config=partition_config):
        mod = tvm.tir.transform.StorageFlatten(64)(mod)
        mod = tvm.tir.transform.LoopPartition()(mod)
        mod = tvm.tir.transform.Simplify()(mod)
    assert tvm.ir.structural_equal(mod["main"], partitioned_concat)


if __name__ == "__main__":
test_basic()
test_const_loop()
Expand All @@ -559,3 +587,4 @@ def test_simple_rfactor():
test_double_splitting_with_indivisible_factors()
test_multilevel_splitting_with_indivisble_factors()
test_simple_rfactor()
test_explicit_partition_hint()