Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith] Simplifications for floormod(x, 2) #13936

Merged
merged 10 commits into from
Apr 4, 2023
5 changes: 5 additions & 0 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,11 @@ class IterMapRewriter : public ExprMutator {
PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs);

static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
if (sign < 0 && is_const_int(rhs->extent, 2)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add a test case that covers this code path?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can do, and added! This was a case that was caught by other unit tests, since the additional RewriteSimplifier rules prevented DetectIterMap from recognizing some patterns after they had been simplified, but a specific unit test for this case is better than needing to track it down later.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE: this actually is a bug, see #14571

This makes me wonder whether other rules also lead to regressions. Please do check.

lhs->base -= rhs->scale;
sign = 1;
}

tir::ExprDeepEqual equal;
for (size_t i = 0; i < lhs->args.size(); ++i) {
IterSplitExpr lvalue = lhs->args[i];
Expand Down
38 changes: 33 additions & 5 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2),
c2.Eval()->value > 0);

TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) + floormod(x, 2), floordiv(x + 1, 2));

// canonicalization rule
// will try rewrite again after canonicalization.

TVM_TRY_RECURSIVE_REWRITE(matches_one_of(x + (c1 - y), (c1 - y) + x), (x - y) + c1);
TVM_TRY_RECURSIVE_REWRITE(matches_one_of(x + c1 + y, x + (c1 + y)), (x + y) + c1);
TVM_TRY_RECURSIVE_REWRITE(matches_one_of((x + c1) + y, x + (c1 + y), x + (y + c1)),
(x + y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x + max(y, z), max(y, z) + x);
TVM_TRY_RECURSIVE_REWRITE(x + min(y, z), min(y, z) + x);

Expand Down Expand Up @@ -454,6 +458,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
TVM_TRY_REWRITE_IF(floordiv(x - y, c1) * c1 - x, 0 - floormod(x - y, c1) - y,
c1.Eval()->value != 0);

TVM_TRY_RECURSIVE_REWRITE(
floordiv(x + c1, 2) - floordiv(x + c2, 2),
floormod(x, 2) * (floormod(c1, 2) - floormod(c2, 2)) + (floordiv(c1, 2) - floordiv(c2, 2)));
TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) - floordiv(x + c2, 2),
floormod(x, 2) * (0 - floormod(c2, 2)) - floordiv(c2, 2));
TVM_TRY_RECURSIVE_REWRITE(floordiv(x + c1, 2) - floordiv(x, 2),
floormod(x, 2) * floormod(c1, 2) + floordiv(c1, 2));

TVM_TRY_REWRITE_IF(
x * c2 - floordiv(x, c1) * c3, floormod(x, c1) * c2,
c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
Expand All @@ -473,6 +485,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
floordiv(x - y, c1) * c3 - x * c2, (0 - floormod(x - y, c1) - y) * c2,
c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);

TVM_TRY_RECURSIVE_REWRITE(floordiv(x + 1, 2) - floormod(x, 2), floordiv(x, 2));

TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x + c2, c3),
floordiv(floormod(x + floormod(c2, c3), c3) + (c1 - c2), c3),
c3.Eval()->value > 0);
Expand All @@ -483,6 +497,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
// will try rewrite again after canonicalization.
TVM_TRY_REWRITE(x - c1, x + (0 - c1));
TVM_TRY_RECURSIVE_REWRITE((x + c1) - y, (x - y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x - (y + c1), (x - y) + (0 - c1));
TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y);
TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1));
} else if (op->dtype.is_float()) {
Expand Down Expand Up @@ -862,6 +877,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
TVM_TRY_REWRITE(floordiv(x, x), OneWithTypeLike(x));
TVM_TRY_REWRITE(matches_one_of(floordiv(x * c1, x), floordiv(c1 * x, x)), c1);

TVM_TRY_REWRITE(floordiv(floormod(x, 2) + 1, 2), floormod(x, 2));

// Rules involving 2-operands.
TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
Expand Down Expand Up @@ -973,6 +990,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2),
c2.Eval()->value > 0);

TVM_TRY_RECURSIVE_REWRITE_IF(floormod(x + c1, 2), floormod(x, 2) * (-1) + 1,
floormod(c1.Eval()->value, 2) == 1);
TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

Expand All @@ -983,12 +1002,21 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {

TVM_TRY_REWRITE(matches_one_of(floormod(x * y, y), floormod(y * x, y)), ZeroWithTypeLike(y));

// try modular analysis
if (floormod(x, c1).Match(ret)) {
ModularSet mod = analyzer_->modular_set(x.Eval());
int64_t c1val = c1.Eval()->value;
if (mod->coeff % c1val == 0 && c1val > 0) {
return floormod(mod->base, c1).Eval();
if (c1val > 0) {
// try modular analysis
ModularSet mod = analyzer_->modular_set(x.Eval());
if (mod->coeff % c1val == 0) {
return floormod(mod->base, c1).Eval();
}

// floormod(x,c1) is a no-op when x is already in the
// appropriate range.
ConstIntBound bound = analyzer_->const_int_bound(x.Eval());
if (bound->min_value >= 0 && bound->max_value < c1val) {
return x.Eval();
}
}
}
}
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,18 @@ def test_compound():
assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)}, var_dom([(x, 10), (y, 9)]))


def test_compound_floormod_two():
    """Check iter-map detection of an expression that subtracts floormod(x, 2)."""
    x = tvm.tir.Var("x", "int32")
    floordiv, floormod = tvm.tir.floordiv, tvm.tir.floormod

    # A split with extent 2 that appears with a negative sign is expected
    # to be normalized so that it carries a positive scale.
    expr = floordiv(x, 2) * 2 - floormod(x, 2) + 1
    assert_iter_sum_pattern(
        expect_dict={expr: (8, 0, 1)},
        dom_map=var_dom([(x, 8)]),
    )


def test_predicate():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
Expand Down
33 changes: 33 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,39 @@ class TestFloormodIndex(BaseCompare):
)


class TestFloorModTwo(BaseCompare):
    """Special-case simplifications for FloorMod(expr,2)

    Because FloorMod(expr,2) has only two possible values, it can be
    simplified more aggressively than most FloorMod expressions. Some
    of these have analogues for other denominators (e.g. x%3 + (x+1)%3
    + (x+2)%3 == 0 + 1 + 2), but they don't appear as often and
    require identifying more related terms in order to apply.

    (x + c1)//2 - (x+c2)//2 => (x%2)*( c1%2 - c2%2 ) + (c1//2 - c2//2)
    """

    x, y, z = te.var("x"), te.var("y"), te.var("z")
    test_case = tvm.testing.parameter(
        # Removing offsets from floormod
        TestCase(flm(x + 1, 2), flm(x, 2) * (-1) + 1),
        TestCase(flm(x + 5, 2), flm(x, 2) * (-1) + 1),
        TestCase(flm(x, 2) + flm(x + 1, 2), 1),
        TestCase(flm(x + 1, 2) + flm(x, 2), 1),
        # Difference of floordiv yields floormod
        TestCase(fld(x + 1, 2) - fld(x, 2), flm(x, 2)),
        TestCase(fld(x, 2) - fld(x - 1, 2), flm(x, 2) * -1 + 1),
        TestCase(fld(x + 5, 2) - fld(x - 2, 2), flm(x, 2) + 3),
        TestCase(fld(x + 5, 2) - fld(x - 3, 2), 4),
        TestCase(fld(flm(x, 2) + 1, 2), flm(x, 2)),
        # Sum of floordiv and floormod to yield floordiv
        TestCase(fld(x + 1, 2) - flm(x, 2), fld(x, 2)),
        TestCase(fld(x, 2) + flm(x, 2), fld(x + 1, 2)),
        # Removal of floormod where possible
        # NOTE(review): the third argument is presumably a list of
        # constraints on x that the simplifier may assume -- confirm
        # against the TestCase/BaseCompare definitions.
        TestCase(flm(x + 1, 2) * 8192, x * (-8192) + 8192, [x >= 0, x < 2]),
    )


class TestMinIndex(BaseCompare):
x, y, z = te.var("x"), te.var("y"), te.var("z")
test_case = tvm.testing.parameter(
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_meta_schedule_space_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 5
for ax0_ax1_ax2_ax3_fused in T.serial((i4_0 % 2 + 1) // 2 * 96 + 96):
with T.block("PadInput_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 64 + i4_0 // 2 + ax0_ax1_ax2_ax3_fused % (96 * ((i4_0 % 2 + 1) // 2 + 1)) // 96)
v1 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused // 64 + i4_0 // 2 + ax0_ax1_ax2_ax3_fused % (96 * (i4_0 % 2 + 1)) // 96)
v2 = T.axis.spatial(6, i0_0_i1_0_i2_0_i3_0_fused % 64 // 16 + ax0_ax1_ax2_ax3_fused % 96 // 32)
v3 = T.axis.spatial(512, i6_0 * 32 + ax0_ax1_ax2_ax3_fused % 32)
T.reads(inputs[v0, v1 - 1, v2 - 1, v3])
Expand Down
Loading