diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index f39ce4b05643..2c01b9143155 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -245,6 +245,23 @@ class ConstIntBoundAnalyzer::Impl } Entry VisitExpr_(const FloorModNode* op) final { + /* let a / b = x + y, where x is integer, y \in [0, 1) + * floormod(a, b) = a - floordiv(a, b) * b + * floordiv(a, b) = x + * floormod(a, b) = a - floordiv(a, b) * b + * = a - x * b + * = a - (a / b - y) * b + * = a - a + y * b + * = y * b + * note that 0 <= y < 1 + * when b > 0, 0 <= b * y < b + * 0 <= b * y <= b - 1 + * when b < 0, b < b * y <= 0 + * b + 1 <= b * y <= 0 + * In all cases, min(0, b + 1) <= b * y <= max(0, b - 1) + * min(0, b_min + 1) <= b * y <= max(0, b_max - 1) + * That is, min(0, b_min + 1) <= floormod(a, b) <= max(0, b_max - 1) + */ Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); if (b.min_value > 0) { @@ -259,9 +276,11 @@ class ConstIntBoundAnalyzer::Impl } } else { ICHECK(!b.is_const(0)) << "floormod by zero"; - // mod by negative value is rare, - // and we just use the simpliest rule. - return Everything(op->dtype); + int64_t b_min_cap = InfAwareAdd(b.min_value, 1); + int64_t b_max_cap = InfAwareAdd(b.max_value, -1); + return Intersect(MakeBound(std::min(static_cast(0), b_min_cap), + std::max(static_cast(0), b_max_cap)), + Everything(op->dtype)); } } diff --git a/tests/python/unittest/test_arith_const_int_bound.py b/tests/python/unittest/test_arith_const_int_bound.py index badbcbcf1bb3..84fc7fd64614 100644 --- a/tests/python/unittest/test_arith_const_int_bound.py +++ b/tests/python/unittest/test_arith_const_int_bound.py @@ -303,6 +303,17 @@ def test_let_bound(): assert bd.max_value == 2 +def test_floormod_negative_divisor(): + analyzer = tvm.arith.Analyzer() + flm, fld = tvm.te.floormod, tvm.te.floordiv + a, b = te.var("a"), te.var("b") + analyzer.update(a, tvm.arith.ConstIntBound(0, 6)) + analyzer.update(b, tvm.arith.ConstIntBound(-5, 7)) + bd = analyzer.const_int_bound(flm(a, b)) + assert bd.min_value == -4 + assert bd.max_value == 6 + + if __name__ == "__main__": test_let_bound() test_dtype_bound() @@ -318,3 +329,4 @@ def test_let_bound(): test_shift_and_bound() test_mix_index_bound() test_size_var_bound() + test_floormod_negative_divisor()