Skip to content

Commit

Permalink
[TIR][Arith] Support negative coeff in ModularSet
Browse files Browse the repository at this point in the history
Prior to this commit, any use of negative coefficients in `ModularSet`
would result in an error.  This included cases where a constraint is
being entered, such as `floormod(i, -2)==0` appearing as the condition
of an if/else block.  These negative indices can also arise as
intermediate simplification steps produced by `CanonicalSimplifier`,
such as `floormod(-i,2)` being canonicalized to `floormod(i,-2)`.

This commit adds support for negative coefficients in `ModularSet`,
using the same sign convention as is used by `CanonicalSimplifier` for
negative denominators, and adds unit tests to verify that sign
convention.
  • Loading branch information
Lunderberg committed Oct 14, 2022
1 parent b389d4d commit b76c70a
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/arith/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,18 @@ struct ModularSetAnalyzer::Entry {
Entry() = default;

Entry(int64_t coeff, int64_t base) {
ICHECK_GE(coeff, 0);
if (coeff < 0) {
// `analyzer->canonical_simplify()` can generate expressions with
// negative coefficients (e.g. simplifying `floormod(-i, 2)`
// into `floormod(i, -2) * -1`). When this happens, the
// ModularSet may enter a constraint based on this expression.
//
// Handling a negative coeff uses the same sign convention as
// canonical_simplify, requiring that
// `floormod(var, coeff) == -floormod(var, -coeff)`.
coeff *= -1;
base *= -1;
}
this->coeff = coeff;
if (coeff != 0) {
base = base % coeff;
Expand Down
4 changes: 4 additions & 0 deletions tests/python/unittest/test_arith_canonical_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def test_split_index_simplify():
# cannot simplify mixed case, unless we canonicalize into one mode.
ck.verify(tdiv(x, 6) * 2 + tmod(fld(x, 3), 2), tdiv(x, 6) * 2 + tmod(fld(x, 3), 2))

ck.verify(tmod(-x, 2), tmod(x, -2) * -1)


def test_div_simplify():
ck = CanonicalChecker()
Expand Down Expand Up @@ -129,6 +131,8 @@ def test_floormod_simplify():
ck.verify(flm(flm((x * 4) + y - 466036, 24528) - 24512, 16), flm((x * 4) + y + 12, 16))
ck.verify(flm(flm((x * 4), 16), 8), flm(x, 2) * 4)

ck.verify(flm(-x, 2), flm(x, -2) * -1)


def test_canonical_mixed():
ck = CanonicalChecker()
Expand Down
29 changes: 29 additions & 0 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,5 +816,34 @@ def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32):
A[0] = (i != 30) or (j == 0)


class TestConditionalFloorMod(BaseBeforeAfter):
"""A regression test for negative floormod denominator
Previously, simplifying this function could throw an error. First, the
`canonical_simplify` would rewrite `floormod(0-i,2)` to the equivalent
`floormod(i,-2)`. Then, the rewrite_simplifier would enter a
constrained context in which `floormod(i,-2)==1`. Passing this
expression to `ModularSet::EnterConstraint`, which previously did not
support a negative value for the second argument, threw an error.
The analogous failure mode never occurred for `truncmod`, because
`truncmod(0-i,2)` would be canonicalized to `truncmod(i, -2) * -1`, and
the pattern matching in `ModularSet` didn't recognize the constant
factor.
This failure mode was resolved by supporting negative arguments in
`ModularSet`, using the same sign convention as is used by
`canonical_simplify`.
"""

def before(A: T.Buffer[1, "bool"], i: T.int32):
if T.floormod(0 - i, 2) == 0:
A[0] = T.floormod(i, 2) == 0

def expected(A: T.Buffer[1, "bool"], i: T.int32):
if T.floormod(i, -2) == 0:
A[0] = True


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit b76c70a

Please sign in to comment.