From 8dfa33b17ff9ad74773cc184e691b06415a43ded Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 7 Oct 2022 15:23:23 -0500 Subject: [PATCH] [TIR][Arith] Use TryCompare to narrow inequalities if possible Prior to this commit, the result of TryCompare would only be used if it could definitively prove a conditional to be either true or false. For example, if it is known that `0 <= i`, a conditional of `i <= 0` would be left as-is. This commit introduces rewrite rules to preferentially simplify into more restrictive conditions. Using the same example, if it is known that `0 <= i`, a conditional of `i <= 0` would be simplified into `i == 0`. Similarly, if it is known that `0 <= i`, a conditional of `i != 0` would be simplified into `0 < i`. Because this change does not introduce significant overhead, as the results of `RewriteSimplifier::Impl::TryCompare` are already available, this change is enabled for all use cases and does not require a call to `RewriteSimplifier::SetEnabledExtensions`. 
--- src/arith/rewrite_simplify.cc | 144 +++++++++++++++--- src/arith/rewrite_simplify.h | 21 +++ .../unittest/test_arith_rewrite_simplify.py | 2 +- tests/python/unittest/test_index_map.py | 8 +- ..._tir_transform_inject_software_pipeline.py | 4 +- .../unittest/test_tir_transform_simplify.py | 46 ++++++ 6 files changed, 193 insertions(+), 32 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 8e0c62e063602..56be0cd553d90 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -29,6 +29,7 @@ #include #include +#include #include "../target/datatype/registry.h" #include "conjunctive_normal_form.h" @@ -1350,11 +1351,16 @@ Optional RewriteSimplifier::Impl::TryMatchLiteralConstraint(const Prim } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) { - PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); + EQ ret = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); + op = ret.get(); + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); + return ApplyRewriteRules(ret); +} + +PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { // Pattern var to match any expression PVar x, y; // Pattern var match IntImm @@ -1362,32 +1368,106 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) { PVar lanes; // vector rule - if (op->dtype.lanes() != 1) { + if (ret->dtype.lanes() != 1) { TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), broadcast(x == y, lanes)); } - if (IsIndexType(op->a.dtype())) { - CompareResult result = TryCompare(op->a, op->b); + if (IsIndexType(ret->a.dtype())) { + CompareResult result = TryCompare(ret->a, ret->b); if (result == CompareResult::kEQ) { - return make_const(op->dtype, true); + return make_const(ret->dtype, true); } else if (result == CompareResult::kNE || result == CompareResult::kGT || result == CompareResult::kLT) { - return 
make_const(op->dtype, false); + return make_const(ret->dtype, false); } + TVM_TRY_REWRITE(c1 == x, x == c1); + TVM_TRY_REWRITE(x - c1 == 0, x == c1); TVM_TRY_REWRITE(c1 - x == 0, x == c1); TVM_TRY_REWRITE(x + c1 == 0, x == 0 - c1); TVM_TRY_REWRITE(x * y == 0, x == 0 || y == 0); } - return ret; + return std::move(ret); } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NENode* op) { - return this->VisitExpr(Not(op->a == op->b)); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + op = ret.as(); + + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); + if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); + + if (IsIndexType(op->a.dtype())) { + CompareResult result = TryCompare(op->a, op->b); + if (result == CompareResult::kNE || result == CompareResult::kGT || + result == CompareResult::kLT) { + return make_const(op->dtype, true); + } else if (result == CompareResult::kEQ) { + return make_const(op->dtype, false); + } else if (result == CompareResult::kGE) { + // Known: a >= b + // + // a != b + // (a < b) or (b < a) + // False or (b < a) + // b < a + return ApplyRewriteRules(LT(op->b, op->a)); + } else if (result == CompareResult::kLE) { + // Known: a <= b + // + // a != b + // (a < b) or (b < a) + // (a < b) or False + // a < b + return ApplyRewriteRules(LT(op->a, op->b)); + } + } + + return ApplyRewriteRules(Not(ApplyRewriteRules(EQ(op->a, op->b)))); } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LENode* op) { - return this->VisitExpr(Not(op->b < op->a)); + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + op = ret.as(); + ICHECK(op); + + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); + if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); + + // Check for applicable rewrites before attempting to prove/disprove + // the inequality. This preserves earlier behavior, where (A<=B*x) + // simplifies to (ceildiv(A,B)<=x) when (A%B!=0). 
Performing the + // TryCompare first would simplify to the equivalent + // (floordiv(A,B)<x) in these cases instead. + ret = ApplyRewriteRules(Not(ApplyRewriteRules(LT(op->b, op->a)))); + + if (auto op = ret.as(); op && IsIndexType(op->a.dtype())) { + CompareResult result = TryCompare(op->a, op->b); + if (result == CompareResult::kLE || result == CompareResult::kLT || + result == CompareResult::kEQ) { + return make_const(op->dtype, true); + } else if (result == CompareResult::kGT) { + return make_const(op->dtype, false); + } else if (result == CompareResult::kNE) { + // Known: a != b + // + // a <= b + // (a < b) or (a == b) + // (a < b) or False + // a < b + return ApplyRewriteRules(LT(op->a, op->b)); + } else if (result == CompareResult::kGE) { + // Known: a >= b + // + // a <= b + // (a < b) or (a == b) + // False or (a == b) + // a == b + return ApplyRewriteRules(EQ(op->a, op->b)); + } + } + + return ret; } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GTNode* op) { @@ -1395,15 +1475,20 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GTNode* op) { } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GENode* op) { - return this->VisitExpr(Not(op->a < op->b)); + return this->VisitExpr(op->b <= op->a); } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) { - PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); + LT node = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); + op = node.get(); + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); - if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); + if (auto match = TryMatchLiteralConstraint(node)) return match.value(); + + return ApplyRewriteRules(node); +} +PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { // Pattern var to match any expression PVar x, y, z, s1, s2; // Pattern var match IntImm @@ -1411,19 +1496,19 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) { PVar lanes; // vector rule - if (op->dtype.lanes() != 1) { + if (ret->dtype.lanes() != 1) { TVM_TRY_REWRITE(broadcast(x, 
lanes) < broadcast(y, lanes), broadcast(x < y, lanes)); TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes), broadcast(x < y, lanes)); } - if (IsIndexType(op->a.dtype())) { - CompareResult result = TryCompare(op->a, op->b); + if (IsIndexType(ret->a.dtype())) { + CompareResult result = TryCompare(ret->a, ret->b); if (result == CompareResult::kLT) { - return make_const(op->dtype, true); + return make_const(ret->dtype, true); } if (result == CompareResult::kEQ || result == CompareResult::kGT || result == CompareResult::kGE) { - return make_const(op->dtype, false); + return make_const(ret->dtype, false); } // clang-format off @@ -1527,19 +1612,22 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) { TVM_TRY_REWRITE(x - c1 < 0, x < c1); // clang-format on } - return ret; + return std::move(ret); } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { - PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - if (auto const_res = TryConstFold(op->a)) return const_res.value(); + Not ret = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); + if (auto const_res = TryConstFold(ret->a)) return const_res.value(); if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); + return ApplyRewriteRules(ret); +} + +PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(Not ret) { // Pattern var to match any expression PVar x, y; PVar lanes; - if (op->dtype.lanes() != 1) { + if (ret->dtype.lanes() != 1) { TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes)); } @@ -1552,7 +1640,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { TVM_TRY_REWRITE(!(x != y), x == y); TVM_TRY_RECURSIVE_REWRITE(!(x || y), (!x) && (!y)); TVM_TRY_RECURSIVE_REWRITE(!(x && y), (!x) || (!y)); - return ret; + return std::move(ret); } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { @@ -1641,6 +1729,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 
== c2); TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2); + + TVM_TRY_REWRITE(x < y || x == y, x <= y); + TVM_TRY_REWRITE(x < y || y == x, x <= y); + TVM_TRY_REWRITE(x == y || x < y, x <= y); + TVM_TRY_REWRITE(y == x || x < y, x <= y); + return ret; } diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 02c54902153aa..b8e7fcdd94337 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -137,6 +137,27 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { */ Optional TryMatchLiteralConstraint(const PrimExpr& expr) const; + /*! \brief Rewrite rules for Less Than comparisons + * + * These are separate from the VisitExpr_(const LTNode*) method, as + * they may be required by rewrites of LT or LE. + */ + PrimExpr ApplyRewriteRules(LT node); + + /*! \brief Rewrite rules for Equal comparisons + * + * These are separate from the VisitExpr_(const EQNode*) method, as + * they may be required by rewrites of LE or NE. + */ + PrimExpr ApplyRewriteRules(EQ node); + + /*! \brief Rewrite rules for Boolean Not expressions + * + * These are separate from the VisitExpr_(const NotNode*) method, as + * they may be required by rewrites of LT, LE, or NE. 
+ */ + PrimExpr ApplyRewriteRules(Not node); + private: CompareResult TryCompareUsingKnownInequalities(const PrimExpr& x, const PrimExpr& y); CompareResult TryCompareUsingConstIntBounds(const PrimExpr& x, const PrimExpr y); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 77751b1601775..4477e1d9c713c 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -863,7 +863,7 @@ def test_cmp_simplify(): ck.verify(fld(x, 2) <= -1, tvm.tir.LE(x, -1)) ck.verify(fld(x, 4) * 4 < x, tvm.tir.LT(0, flm(x, 4))) - ck.verify(fld(x, 4) * 4 >= x, tvm.tir.LE(flm(x, 4), 0)) + ck.verify(fld(x, 4) * 4 >= x, tvm.tir.EQ(flm(x, 4), 0)) ck.verify(fld(x, 4) * 4 < x + y, tvm.tir.LT(0, flm(x, 4) + y)) ck.verify(fld(x, 4) * 4 < x - y, tvm.tir.LT(y, flm(x, 4))) diff --git a/tests/python/unittest/test_index_map.py b/tests/python/unittest/test_index_map.py index 6882c2b426344..1f54f8d0dcee2 100644 --- a/tests/python/unittest/test_index_map.py +++ b/tests/python/unittest/test_index_map.py @@ -81,7 +81,7 @@ def test_nonbijective_inverse_gives_error(): inverse=lambda i, j: [4 * i + j], pre_shape=[15], post_shape=[4, 4], - padding=lambda i, j: tvm.tir.And(i == 3, j >= 3), + padding=lambda i, j: tvm.tir.And(i == 3, tvm.runtime.convert(3) == j), ), "left_padding": dict( forward=lambda i: [(i + 1) // 4, (i + 1) % 4], @@ -97,7 +97,7 @@ def test_nonbijective_inverse_gives_error(): post_shape=[4, 4], padding=lambda i, j: tvm.tir.Or( tvm.tir.And(i == 0, j < 1), - tvm.tir.And(i == 3, j >= 3), + tvm.tir.And(i == 3, tvm.runtime.convert(3) == j), ), ), "dynamic_size": dict( @@ -126,7 +126,7 @@ def test_nonbijective_inverse_gives_error(): padding=lambda i_outer, j_outer, i_inner, j_inner: tvm.tir.Or( tvm.tir.Or( tvm.tir.And(i_outer == 0, i_inner < 1), - tvm.tir.And(i_outer == 3, i_inner >= 3), + tvm.tir.And(i_outer == 3, tvm.runtime.convert(3) == i_inner), ), 
tvm.tir.Or( tvm.tir.And(j_outer == 0, j_inner < 5), @@ -167,7 +167,7 @@ def test_nonbijective_inverse_gives_error(): inverse=lambda i, j: [i * 4 + j], pre_shape=[3], post_shape=[1, 4], - padding=lambda i, j: 3 <= j, + padding=lambda i, j: tvm.runtime.convert(3) == j, ), } ) diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 9334a4d9e827a..2a4cabc541c65 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -263,7 +263,7 @@ def transformed_three_stage_compute( T.writes(B[0:2, tx, 0]) B[i, tx, 0] = A[tx, i] * T.float32(2) with T.block(): - T.where(1 <= i) + T.where(i == 1) T.reads(B[0:2, tx, 0]) T.writes(C[0:2, tx, 0]) C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2) @@ -1349,7 +1349,7 @@ def ref(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]) -> N with T.attr(0, "async_scope", 1): B[i % 2, tx, 0] = A[tx, i] * T.float32(2) with T.block(): - T.where(1 <= i and i - 1 < 16) + T.where(i == 1 and i - 1 < 16) T.reads(B[(i + 1) % 2, tx, 0]) T.writes(C[(i + 1) % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 1): diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index 6236791a75eee..e47a071db0f7c 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -828,5 +828,51 @@ def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32): A[0] = True +class TestMostRestrictiveConditional(BaseBeforeAfter): + """Preferentially prove part of a compound conditional. + + Even if we cannot prove a conditional as true or false on its own, + proving that a conditional must satisfy a stronger condition may + allow for later rewrites. 
For example, if it is known that `a <= b`, + then `a >= b` cannot be proven, but can be reduced to `a == b`. + """ + + i, j, k = [tvm.tir.Var(name, "int32") for name in "ijk"] + tir_int = tvm.tir.IntImm("int32", 0) + + test_case = tvm.testing.parameter( + (i <= tir_int, tir_int <= i, i == tir_int), + (i <= tir_int, i != tir_int, i < tir_int), + (i != tir_int, i <= tir_int, i < tir_int), + (i != tir_int, tir_int <= i, tir_int < i), + (i <= j, j <= i, j == i), + (i <= j, i != j, i < j), + (i != j, i <= j, i < j), + (i != j, j <= i, j < i), + ) + + @tvm.testing.fixture + def before(self, test_case): + priors, expr_before, _ = test_case + + @T.prim_func + def func(A: T.Buffer[1, "bool"]): + if priors: + A[0] = expr_before + + return func + + @tvm.testing.fixture + def expected(self, test_case): + priors, _, expr_after = test_case + + @T.prim_func + def func(A: T.Buffer[1, "bool"]): + if priors: + A[0] = expr_after + + return func + + if __name__ == "__main__": tvm.testing.main()