Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,9 @@ TVM_DLL const Op& vscale();
*/
TVM_DLL const Op& get_active_lane_mask();

/*! \brief Annotate a predicate not be considered as target condition of loop partition. */
TVM_DLL const Op& ignore_loop_partition();

/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1909,6 +1909,7 @@ def wrapped(*args, **kwargs):
anylist_setitem_call_packed = _op_wrapper(_tir_op.anylist_setitem_call_packed)
anylist_setitem_call_cpacked = _op_wrapper(_tir_op.anylist_setitem_call_cpacked)
vscale = _op_wrapper(_tir_op.vscale)
ignore_loop_partition = _op_wrapper(_tir_op.ignore_loop_partition)


def _dtype_forward(func):
Expand Down Expand Up @@ -2261,4 +2262,5 @@ def wrapped(*args, **kwargs):
"vscale",
"get_active_lane_mask",
"call_kernel",
"ignore_loop_partition",
]
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
from .op import start_profile_intrinsic, end_profile_intrinsic
from .op import vscale, get_active_lane_mask, get_vscale_expr
from .op import dp4a
from .op import ignore_loop_partition
from .generic import add, subtract, multiply

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3581,6 +3581,18 @@ def get_vscale_expr(dtype: Union[str, tvm.DataType], min_size: int = 128) -> Pri
return min_size // dtype.bits * vscale()


def ignore_loop_partition(predicate) -> PrimExpr:
"""
Annotate a predicate not be considered as target condition of loop partition.

Parameters
----------
predicate : PrimExpr
The annotated predicate expression.
"""
return call_intrin("bool", "tir.ignore_loop_partition", predicate)


# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore
Expand Down
6 changes: 6 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,12 @@ TIR_DEFINE_BUILTIN_FUNC(get_active_lane_mask)
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
Integer(ScriptDtypePrintLocation::kFirst));

TIR_DEFINE_BUILTIN_FUNC(ignore_loop_partition)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
Integer(ScriptDtypePrintLocation::kNone));

} // namespace builtin
} // namespace tir
} // namespace tvm
51 changes: 36 additions & 15 deletions src/tir/transforms/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ class CandidateSelector final : public StmtExprVisitor {
: partition_const_loop_(partition_const_loop) {}

void VisitStmt_(const ForNode* op) final {
// always treat var with hint to be partitioned
const VarNode* var = op->loop_var.get();
if (partition_hint_vars.count(var)) {
candidates.insert(GetRef<Stmt>(op));
StmtExprVisitor::VisitStmt_(op);
return;
}
// partition const loop when sets partition_const_loop_
if (!is_const_int(op->min) || !is_const_int(op->extent) || partition_const_loop_) {
// always treat var with hint to be partitioned
const VarNode* var = op->loop_var.get();
if (partition_hint_vars.count(var)) {
candidates.insert(GetRef<Stmt>(op));
StmtExprVisitor::VisitStmt_(op);
return;
}
record_.insert({var, false});
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var) && !no_split_) {
Expand All @@ -126,14 +126,14 @@ class CandidateSelector final : public StmtExprVisitor {
const IterVarNode* iv = op->node.as<IterVarNode>();
ICHECK(iv);
Var var = iv->var;
// always treat var with hint to be partitioned
if (partition_hint_vars.count(var.get())) {
candidates.insert(GetRef<Stmt>(op));
StmtExprVisitor::VisitStmt_(op);
return;
}
runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag);
if ((scope.rank == 0) && (!is_const_int(op->value) || partition_const_loop_)) {
// always treat var with hint to be partitioned
if (partition_hint_vars.count(var.get())) {
candidates.insert(GetRef<Stmt>(op));
StmtExprVisitor::VisitStmt_(op);
return;
}
record_.insert({var.get(), false});
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var.get()) && !no_split_) {
Expand Down Expand Up @@ -262,6 +262,8 @@ class PartitionFinder : public StmtExprVisitor {
void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::likely())) {
DeduceCondition(op->args[0]);
} else if (op->op.same_as(builtin::ignore_loop_partition())) {
return;
} else {
StmtExprVisitor::VisitExpr_(op);
}
Expand All @@ -287,6 +289,22 @@ class PartitionFinder : public StmtExprVisitor {
// cond is true within interval
partitions[{cond, true}] = interval;
}

if (interval.IsNothing()) {
// `DeduceBound` do not support NE now, thus when
// deduce l==r failed, just only try (l<=r && l>=r)
if (const EQNode* op = cond.as<EQNode>()) {
IntSet part1 = DeduceBound(current_var_, GE(op->a, op->b), hint_map_, relax_map_);
IntSet part2 = DeduceBound(current_var_, LE(op->a, op->b), hint_map_, relax_map_);
interval = arith::Intersect({part1, part2});
if (!interval.IsNothing()) {
// cond is true within interval
partitions[{cond, true}] = interval;
return;
}
}
}

PrimExpr inverse_cond = InverseCond(cond);
if (inverse_cond.defined()) {
IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
Expand Down Expand Up @@ -469,6 +487,7 @@ std::pair<IntSet, ExpressionSet> LoopPartitioner::GetIntervalAndCondset(
if (kv.first.second == cond_value) {
arith::IntervalSet interval = Downcast<arith::IntervalSet>(kv.second);
arith::IntervalSet intersection = arith::Intersect(&analyzer_, interval, for_interval);

if (!intersection->IsEmpty()) {
sets.push_back(kv.second);
cond_set.insert(kv.first.first);
Expand Down Expand Up @@ -625,8 +644,7 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
}();

if (middle_interval.IsNothing() && opt_cond_value == false) {
// Return loop directly as it can be simplified.
return stmt;
return Stmt();
}

if (!opt_cond_value.has_value()) {
Expand Down Expand Up @@ -750,6 +768,9 @@ class RemoveLikelyTagsAndHints : public StmtExprMutator {
if (op->op.same_as(builtin::likely())) {
ICHECK_EQ(op->args.size(), 1);
return StmtExprMutator::VisitExpr(op->args[0]);
} else if (op->op.same_as(builtin::ignore_loop_partition())) {
ICHECK_EQ(op->args.size(), 1);
return StmtExprMutator::VisitExpr(op->args[0]);
} else {
return StmtExprMutator::VisitExpr_(op);
}
Expand Down
87 changes: 85 additions & 2 deletions tests/python/tir-transform/test_tir_transform_loop_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,11 +570,12 @@ def test_explicit_partition_hint():
tvm.ir.assert_structural_equal(mod["main"], partitioned_concat)


def partition_from_scheduled_tir(prim_func, pass_cfg):
def partition_from_scheduled_tir(prim_func, pass_cfg, do_flatten=True):
with tvm.transform.PassContext(config=pass_cfg):
mod = IRModule.from_expr(prim_func.with_attr("global_symbol", "main"))
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
mod = tvm.tir.transform.FlattenBuffer()(mod)
if do_flatten:
mod = tvm.tir.transform.FlattenBuffer()(mod)
mod = tvm.tir.transform.LoopPartition()(mod)
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.RemoveNoOp()(mod)
Expand Down Expand Up @@ -1037,6 +1038,29 @@ def concat_five_buffers_with_equalities_expected(
T_concat_1[i0 * 129 + 129] = buffer_e_1[i0]


@T.prim_func
def nested_partition_with_single_points(A: T.Buffer[(25,), "int32"]):
for i in T.serial(5, annotations={"pragma_loop_partition_hint": 1}):
if i == 1:
for j in T.serial(5, annotations={"pragma_loop_partition_hint": 1}):
if j > 2:
A[i * 5 + j] = i * 5 + j
else:
for j in T.serial(5, annotations={"pragma_loop_partition_hint": 1}):
if j > 2:
A[i * 5 + j] = i * 15 + j


@T.prim_func
def nested_partition_with_single_points_expected(A: T.Buffer[(25,), "int32"]):
for j in range(2):
A[j + 3] = j + 3
for j in range(2):
A[j + 8] = j + 8
for i, j in T.grid(3, 2):
A[i * 5 + j + 13] = i * 15 + j + 33


@pytest.mark.parametrize(
"origin,expected",
[
Expand All @@ -1045,6 +1069,7 @@ def concat_five_buffers_with_equalities_expected(
(concat_func_end_point_equality, concat_func_end_point_equality_expected),
(concat_func_edge_equalities, concat_func_edge_equalities_expected),
(concat_five_buffers_with_equalities, concat_five_buffers_with_equalities_expected),
(nested_partition_with_single_points, nested_partition_with_single_points_expected),
],
)
def test_single_point_partition(origin, expected):
Expand All @@ -1062,5 +1087,63 @@ def test_single_point_partition(origin, expected):
tvm.ir.assert_structural_equal(mod["main"], expected)


def test_equation_on_floordiv():
@T.prim_func
def before(A: T.Buffer[(2, 2, 20), "int32"]):
for i in T.serial(5, annotations={"pragma_loop_partition_hint": 1}):
if i == 1:
for vv in T.vectorized(640, annotations={"pragma_loop_partition_hint": 1}):
if i * 2 + vv // 320 == 3:
A[i - 1, i * 2 + vv // 320 - 3, vv % 320 // 16] = 1

@T.prim_func
def expected(A: T.Buffer[(2, 2, 20), "int32"]):
for vv in T.vectorized(320):
A[0, 0, vv // 16] = 1

expected = expected.with_attr({"global_symbol": "main"})
after = partition_from_scheduled_tir(
before.with_attr("global_symbol", "main"), {}, do_flatten=False
)
tvm.ir.assert_structural_equal(after["main"], expected)


def test_ignore_loop_partition_hint():
"""Skip unroll body and prologue for pipeline case"""

@T.prim_func
def before(A: T.Buffer[(10), "float32"], D: T.Buffer[(10), "float32"]):
B = T.decl_buffer([2], "float32")
C = T.decl_buffer([2], "float32")
for i in T.serial(12, annotations={"pragma_loop_partition_hint": 1}):
if T.ignore_loop_partition(i < 10):
B[i % 2] = A[i] + 1.0
if T.ignore_loop_partition(1 <= i and i < 11):
C[(i - 1) % 2] = B[(i - 1) % 2] + 2.0
if 2 <= i:
D[i - 2] = C[i % 2] + 3.0

@T.prim_func
def expected(A: T.Buffer[(10), "float32"], D: T.Buffer[(10), "float32"]):
B = T.decl_buffer([2], "float32")
C = T.decl_buffer([2], "float32")
for i in range(2):
B[i] = A[i] + 1.0
if i == 1:
C[i - 1] = B[i - 1] + 2.0
for i in T.serial(10):
if i < 8:
B[i % 2] = A[i + 2] + 1.0
if i < 9:
C[(i + 1) % 2] = B[(i + 1) % 2] + 2.0
D[i] = C[i % 2] + 3.0

expected = expected.with_attr({"global_symbol": "main"})
after = partition_from_scheduled_tir(
before.with_attr({"global_symbol": "main"}), {}, do_flatten=False
)
tvm.ir.assert_structural_equal(after["main"], expected)


if __name__ == "__main__":
tvm.testing.main()
Loading