Skip to content

Commit

Permalink
[TIR][Arith] Use TryCompare to narrow inequalities if possible (#13024)
Browse files Browse the repository at this point in the history
Prior to this commit, the result of TryCompare would only be used if
it could definitively prove a conditional to be either true or false.
For example, if it is known that `0 <= i`, a conditional of `i <= 0`
would be left as-is.

This commit introduces rewrite rules to preferentially simplify
into more restrictive conditions.  Using the same example, if it is
known that `0 <= i`, a conditional of `i <= 0` would be simplified
into `i == 0`.  Similarly, if it is known that `0 <= i`, a
conditional of `i != 0` would be simplified into `0 < i`.

Because this change does not introduce significant overhead, as the
results of `RewriteSimplifier::Impl::TryCompare` are already
available, this change is enabled for all use cases and does not
require a call to `RewriteSimplifier::SetEnabledExtensions`.
  • Loading branch information
Lunderberg authored Nov 4, 2022
1 parent de8a79d commit ccb7d07
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 32 deletions.
144 changes: 119 additions & 25 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/tir/op.h>

#include <algorithm>
#include <utility>

#include "../target/datatype/registry.h"
#include "conjunctive_normal_form.h"
Expand Down Expand Up @@ -1384,80 +1385,164 @@ Optional<PrimExpr> RewriteSimplifier::Impl::TryMatchLiteralConstraint(const Prim
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<EQNode>();
EQ ret = Downcast<EQ>(IRMutatorWithAnalyzer::VisitExpr_(op));
op = ret.get();

if (auto const_res = TryConstFold<EQ>(op->a, op->b)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

return ApplyRewriteRules(ret);
}

PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) {
// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<IntImm> c1;
PVar<int> lanes;

// vector rule
if (op->dtype.lanes() != 1) {
if (ret->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), broadcast(x == y, lanes));
}

if (IsIndexType(op->a.dtype())) {
CompareResult result = TryCompare(op->a, op->b);
if (IsIndexType(ret->a.dtype())) {
CompareResult result = TryCompare(ret->a, ret->b);
if (result == CompareResult::kEQ) {
return make_const(op->dtype, true);
return make_const(ret->dtype, true);
} else if (result == CompareResult::kNE || result == CompareResult::kGT ||
result == CompareResult::kLT) {
return make_const(op->dtype, false);
return make_const(ret->dtype, false);
}
TVM_TRY_REWRITE(c1 == x, x == c1);

TVM_TRY_REWRITE(x - c1 == 0, x == c1);
TVM_TRY_REWRITE(c1 - x == 0, x == c1);
TVM_TRY_REWRITE(x + c1 == 0, x == 0 - c1);
TVM_TRY_REWRITE(x * y == 0, x == 0 || y == 0);
}
return ret;
return std::move(ret);
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NENode* op) {
return this->VisitExpr(Not(op->a == op->b));
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<NENode>();

if (auto const_res = TryConstFold<NE>(op->a, op->b)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

if (IsIndexType(op->a.dtype())) {
CompareResult result = TryCompare(op->a, op->b);
if (result == CompareResult::kNE || result == CompareResult::kGT ||
result == CompareResult::kLT) {
return make_const(op->dtype, true);
} else if (result == CompareResult::kEQ) {
return make_const(op->dtype, false);
} else if (result == CompareResult::kGE) {
// Known: a >= b
//
// a != b
// (a < b) or (b < a)
// False or (b < a)
// b < a
return ApplyRewriteRules(LT(op->b, op->a));
} else if (result == CompareResult::kLE) {
// Known: a <= b
//
// a != b
// (a < b) or (b < a)
// (a < b) or False
// a < b
return ApplyRewriteRules(LT(op->a, op->b));
}
}

return ApplyRewriteRules(Not(ApplyRewriteRules(EQ(op->a, op->b))));
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LENode* op) {
return this->VisitExpr(Not(op->b < op->a));
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<LENode>();
ICHECK(op);

if (auto const_res = TryConstFold<LE>(op->a, op->b)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

// Check for applicable rewrites before attempting to prove/disprove
// the inequality. This preserves earlier behavior, where (A<=B*x)
// simplifies to (ceildiv(A,B)<=x) when (A%B!=0). Performing the
// TryCompare first would simplify to the equivalent
// (floordiv(A,B)<x) in these cases instead.
ret = ApplyRewriteRules(Not(ApplyRewriteRules(LT(op->b, op->a))));

if (auto op = ret.as<LENode>(); op && IsIndexType(op->a.dtype())) {
CompareResult result = TryCompare(op->a, op->b);
if (result == CompareResult::kLE || result == CompareResult::kLT ||
result == CompareResult::kEQ) {
return make_const(op->dtype, true);
} else if (result == CompareResult::kGT) {
return make_const(op->dtype, false);
} else if (result == CompareResult::kNE) {
// Known: a != b
//
// a <= b
// (a < b) or (a == b)
// (a < b) or False
// a < b
return ApplyRewriteRules(LT(op->a, op->b));
} else if (result == CompareResult::kGE) {
// Known: a >= b
//
// a <= b
// (a < b) or (a == b)
// False or (a == b)
// a == b
return ApplyRewriteRules(EQ(op->a, op->b));
}
}

return ret;
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GTNode* op) {
return this->VisitExpr(op->b < op->a);
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GENode* op) {
return this->VisitExpr(Not(op->a < op->b));
return this->VisitExpr(op->b <= op->a);
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<LTNode>();
LT node = Downcast<LT>(IRMutatorWithAnalyzer::VisitExpr_(op));
op = node.get();

if (auto const_res = TryConstFold<LT>(op->a, op->b)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
if (auto match = TryMatchLiteralConstraint(node)) return match.value();

return ApplyRewriteRules(node);
}

PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) {
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, s1, s2;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVar<int> lanes;

// vector rule
if (op->dtype.lanes() != 1) {
if (ret->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes), broadcast(x < y, lanes));
TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes), broadcast(x < y, lanes));
}

if (IsIndexType(op->a.dtype())) {
CompareResult result = TryCompare(op->a, op->b);
if (IsIndexType(ret->a.dtype())) {
CompareResult result = TryCompare(ret->a, ret->b);
if (result == CompareResult::kLT) {
return make_const(op->dtype, true);
return make_const(ret->dtype, true);
}
if (result == CompareResult::kEQ || result == CompareResult::kGT ||
result == CompareResult::kGE) {
return make_const(op->dtype, false);
return make_const(ret->dtype, false);
}

// clang-format off
Expand Down Expand Up @@ -1561,19 +1646,22 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
TVM_TRY_REWRITE(x - c1 < 0, x < c1);
// clang-format on
}
return ret;
return std::move(ret);
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<NotNode>();
if (auto const_res = TryConstFold<Not>(op->a)) return const_res.value();
Not ret = Downcast<Not>(IRMutatorWithAnalyzer::VisitExpr_(op));
if (auto const_res = TryConstFold<Not>(ret->a)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

return ApplyRewriteRules(ret);
}

PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(Not ret) {
// Pattern var to match any expression
PVar<PrimExpr> x, y;
PVar<int> lanes;
if (op->dtype.lanes() != 1) {
if (ret->dtype.lanes() != 1) {
TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes));
}

Expand All @@ -1586,7 +1674,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) {
TVM_TRY_REWRITE(!(x != y), x == y);
TVM_TRY_RECURSIVE_REWRITE(!(x || y), (!x) && (!y));
TVM_TRY_RECURSIVE_REWRITE(!(x && y), (!x) || (!y));
return ret;
return std::move(ret);
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
Expand Down Expand Up @@ -1762,6 +1850,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) {

TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2);
TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2);

TVM_TRY_RECURSIVE_REWRITE(x < y || x == y, x <= y);
TVM_TRY_RECURSIVE_REWRITE(x < y || y == x, x <= y);
TVM_TRY_RECURSIVE_REWRITE(x == y || x < y, x <= y);
TVM_TRY_RECURSIVE_REWRITE(y == x || x < y, x <= y);

return ret;
}

Expand Down
21 changes: 21 additions & 0 deletions src/arith/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,27 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
*/
Optional<PrimExpr> TryMatchLiteralConstraint(const PrimExpr& expr) const;

/*! \brief Rewrite rules for Less Than comparisons
*
* These are separate from the VisitExpr_(const LTNode*) method, as
* they may required from rewrites of LT or LE.
*/
PrimExpr ApplyRewriteRules(LT node);

/*! \brief Rewrite rules for Equal comparisons
*
* These are separate from the VisitExpr_(const EQNode*) method, as
* they may required from rewrites of LE or NE.
*/
PrimExpr ApplyRewriteRules(EQ node);

/*! \brief Rewrite rules for Equal comparisons
*
* These are separate from the VisitExpr_(const EQNode*) method, as
* they may required from rewrites of LT, LE, or NE.
*/
PrimExpr ApplyRewriteRules(Not node);

private:
CompareResult TryCompareUsingKnownInequalities(const PrimExpr& x, const PrimExpr& y);
CompareResult TryCompareUsingConstIntBounds(const PrimExpr& x, const PrimExpr y);
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,7 @@ def test_cmp_simplify():
ck.verify(fld(x, 2) <= -1, tvm.tir.LE(x, -1))

ck.verify(fld(x, 4) * 4 < x, tvm.tir.LT(0, flm(x, 4)))
ck.verify(fld(x, 4) * 4 >= x, tvm.tir.LE(flm(x, 4), 0))
ck.verify(fld(x, 4) * 4 >= x, tvm.tir.EQ(flm(x, 4), 0))

ck.verify(fld(x, 4) * 4 < x + y, tvm.tir.LT(0, flm(x, 4) + y))
ck.verify(fld(x, 4) * 4 < x - y, tvm.tir.LT(y, flm(x, 4)))
Expand Down
8 changes: 4 additions & 4 deletions tests/python/unittest/test_index_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_nonbijective_inverse_gives_error():
inverse=lambda i, j: [4 * i + j],
pre_shape=[15],
post_shape=[4, 4],
padding=lambda i, j: tvm.tir.And(i == 3, j >= 3),
padding=lambda i, j: tvm.tir.And(i == 3, tvm.runtime.convert(3) == j),
),
"left_padding": dict(
forward=lambda i: [(i + 1) // 4, (i + 1) % 4],
Expand All @@ -107,7 +107,7 @@ def test_nonbijective_inverse_gives_error():
post_shape=[4, 4],
padding=lambda i, j: tvm.tir.Or(
tvm.tir.And(i == 0, j < 1),
tvm.tir.And(i == 3, j >= 3),
tvm.tir.And(i == 3, tvm.runtime.convert(3) == j),
),
),
"dynamic_size": dict(
Expand Down Expand Up @@ -136,7 +136,7 @@ def test_nonbijective_inverse_gives_error():
padding=lambda i_outer, j_outer, i_inner, j_inner: tvm.tir.Or(
tvm.tir.Or(
tvm.tir.And(i_outer == 0, i_inner < 1),
tvm.tir.And(i_outer == 3, i_inner >= 3),
tvm.tir.And(i_outer == 3, tvm.runtime.convert(3) == i_inner),
),
tvm.tir.Or(
tvm.tir.And(j_outer == 0, j_inner < 5),
Expand Down Expand Up @@ -177,7 +177,7 @@ def test_nonbijective_inverse_gives_error():
inverse=lambda i, j: [i * 4 + j],
pre_shape=[3],
post_shape=[1, 4],
padding=lambda i, j: 3 <= j,
padding=lambda i, j: tvm.runtime.convert(3) == j,
),
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def transformed_three_stage_compute(
T.writes(B[0:2, tx, 0])
B[i, tx, 0] = A[tx, i] * T.float32(2)
with T.block():
T.where(1 <= i)
T.where(i == 1)
T.reads(B[0:2, tx, 0])
T.writes(C[0:2, tx, 0])
C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
Expand Down Expand Up @@ -1349,7 +1349,7 @@ def ref(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]) -> N
with T.attr(0, "async_scope", 1):
B[i % 2, tx, 0] = A[tx, i] * T.float32(2)
with T.block():
T.where(1 <= i and i - 1 < 16)
T.where(i == 1 and i - 1 < 16)
T.reads(B[(i + 1) % 2, tx, 0])
T.writes(C[(i + 1) % 2, tx, 0])
with T.attr(0, "async_commit_queue_scope", 1):
Expand Down
46 changes: 46 additions & 0 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,5 +1003,51 @@ def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32):
A[0] = True


class TestMostRestrictiveConditional(BaseBeforeAfter):
"""Preferentially prove part of a compound conditional.
Even if we cannot prove a conditional as true or false on its own,
proving that a conditional must satisfy a stronger condition may
allow for later rewrites. For example, if it is known that `a <= b`,
then `a >= b` cannot be proven, but can be reduced to `a == b`.
"""

i, j, k = [tvm.tir.Var(name, "int32") for name in "ijk"]
tir_int = tvm.tir.IntImm("int32", 0)

test_case = tvm.testing.parameter(
(i <= tir_int, tir_int <= i, i == tir_int),
(i <= tir_int, i != tir_int, i < tir_int),
(i != tir_int, i <= tir_int, i < tir_int),
(i != tir_int, tir_int <= i, tir_int < i),
(i <= j, j <= i, j == i),
(i <= j, i != j, i < j),
(i != j, i <= j, i < j),
(i != j, j <= i, j < i),
)

@tvm.testing.fixture
def before(self, test_case):
priors, expr_before, _ = test_case

@T.prim_func
def func(A: T.Buffer[1, "bool"]):
if priors:
A[0] = expr_before

return func

@tvm.testing.fixture
def expected(self, test_case):
priors, _, expr_after = test_case

@T.prim_func
def func(A: T.Buffer[1, "bool"]):
if priors:
A[0] = expr_after

return func


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit ccb7d07

Please sign in to comment.