From 654a687e5c67d1ccaaf54ffb8da79af3c113ce03 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 30 Dec 2021 03:09:45 +0800 Subject: [PATCH] [TensorIR] fix region cover check (#9810) --- src/tir/schedule/state.cc | 4 ++-- .../test_tir_schedule_state_cached_flags.py | 23 +++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index faeb0b9907d7..04b7dd5ea2af 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -104,9 +104,9 @@ bool ProducerCoversConsumer(const Array& buffer_shape, arith::IntSet produced = arith::Intersect({produced_region[i], buffer_size}); arith::IntSet consumed = arith::Intersect({consumed_region[i], buffer_size}); PrimExpr produced_min = analyzer->Simplify(produced.min()); - PrimExpr produced_max = analyzer->Simplify(produced.max() - produced_min + 1); + PrimExpr produced_max = analyzer->Simplify(produced.max()); PrimExpr consumed_min = analyzer->Simplify(consumed.min()); - PrimExpr consumed_max = analyzer->Simplify(consumed.max() - consumed_min + 1); + PrimExpr consumed_max = analyzer->Simplify(consumed.max()); if (!analyzer->CanProve((produced_min <= consumed_min) && (consumed_max <= produced_max))) { return false; } diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py index d86af72fca93..e88eacdb453b 100644 --- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py +++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py @@ -353,6 +353,18 @@ def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None: ) +@T.prim_func +def uncovered_producer_region(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]): + for i in range(120): + with T.block("producer"): + vi = T.axis.S((0, 120), i) + A[vi] = 1.0 + for i in range(120): + with T.block("consumer"): + vi = T.axis.S((8, 128), i + 8) + B[vi] = A[vi] + + # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -757,5 +769,16 @@ def test_non_perfect_tiling_cache(): # pylint: enable=protected-access +def test_uncovered_producer_region(): + s = tir.ScheduleState(uncovered_producer_region, debug_mask="all") + # pylint: disable=protected-access + assert s._get_cached_flags(_get_block(s, "consumer")) == CachedFlags( + affine_binding=True, + region_cover=False, + stage_pipeline=True, + ) + # pylint: enable=protected-access + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))