Skip to content

Commit

Permalink
[FIX] Bug fix for a floormod rewrite simplify rule (apache#8852)
Browse files Browse the repository at this point in the history
* Update rewrite_simplify.cc

* Update test_arith_rewrite_simplify.py

* Update test_arith_rewrite_simplify.py

* Update test_arith_rewrite_simplify.py
  • Loading branch information
jcf94 authored and Andrew Zhao Luo committed Sep 1, 2021
1 parent 9a68712 commit c105a1b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
16 changes: 10 additions & 6 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -858,14 +858,18 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
ModularSet bmod = analyzer_->modular_set(b1.Eval());
int64_t ramp_min = floordiv(bmod->base, c2val);
int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val);
if (bmod->coeff % c2val == 0) {
if (ramp_min == ramp_max) {
if (ramp_min == ramp_max) {
// If c2 divides the coefficient of b1
if (bmod->coeff % c2val == 0) {
return ramp(floormod(bmod->base, c2), c1, lanes).Eval();
} else {
return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval();
}
} else if (c2val % bmod->coeff == 0 && ramp_min == ramp_max) {
return ramp(floormod(b1, c2), c1, lanes).Eval();
// If all indices can be guaranteed to settle inside a coeff range
if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) {
return ramp(floormod(b1, c2), c1, lanes).Eval();
}
}
if (bmod->coeff % c2val == 0) {
return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval();
}
}
}
Expand Down
22 changes: 14 additions & 8 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,16 @@ def test_vector_simplify():
ck.verify(
fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)),
fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)),
)
) # Example negative case: x = 15; [60, 61, 62, 63, 64] / 64 = [0, 0, 0, 0, 1]
ck.verify(
fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)),
fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)),
)
) # Example negative case: x = 15; [63, 64, 65, 66] / 64 = [0, 1, 1, 1]
ck.verify(
fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
)
) # Example negative case: x = 9; [63, 64, 65, 66] / 64 = [0, 1, 1, 1]

# floor mod
ck.verify(flm(y.astype("int32x2"), x.astype("int32x2")), flm(y, x).astype("int32x2"))
ck.verify(flm(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Broadcast(flm(x, 2), 4))
Expand All @@ -136,16 +137,21 @@ def test_vector_simplify():
flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x * 8, 64), 2, 4)
)
ck.verify(
flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), tvm.tir.Ramp(flm(x * 4, 64), 1, 5)
)
flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)),
flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)),
) # Example negative case: x = 15; [60, 61, 62, 63, 64] % 64 = [60, 61, 62, 63, 0]
ck.verify(
flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)),
tvm.tir.Ramp(flm(x * 4 + 3, 64), 1, 4),
)
flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)),
) # Example negative case: x = 15; [63, 64, 65, 66] % 64 = [63, 0, 1, 2]
ck.verify(
flm(tvm.tir.Ramp(x * 2, 1, 8), tvm.tir.Broadcast(20, 8)),
flm(tvm.tir.Ramp(x * 2, 1, 8), tvm.tir.Broadcast(20, 8)),
) # Example negative case: x = 9; [18, 19, 20, ..., 25] % 20 = [18, 19, 0, 1, ..., 5]
ck.verify(
flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
)
) # Example negative case: x = 9; [63, 64, 65, 66] % 64 = [63, 0, 1, 2]

# Min/Max rules
vx = te.var("vx", dtype="int32x2")
Expand Down

0 comments on commit c105a1b

Please sign in to comment.