Skip to content

Commit

Permalink
[TensorIR] fix region cover check (apache#9810)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy authored Dec 29, 2021
1 parent 75cd670 commit 654a687
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/tir/schedule/state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ bool ProducerCoversConsumer(const Array<PrimExpr>& buffer_shape,
arith::IntSet produced = arith::Intersect({produced_region[i], buffer_size});
arith::IntSet consumed = arith::Intersect({consumed_region[i], buffer_size});
PrimExpr produced_min = analyzer->Simplify(produced.min());
PrimExpr produced_max = analyzer->Simplify(produced.max() - produced_min + 1);
PrimExpr produced_max = analyzer->Simplify(produced.max());
PrimExpr consumed_min = analyzer->Simplify(consumed.min());
PrimExpr consumed_max = analyzer->Simplify(consumed.max() - consumed_min + 1);
PrimExpr consumed_max = analyzer->Simplify(consumed.max());
if (!analyzer->CanProve((produced_min <= consumed_min) && (consumed_max <= produced_max))) {
return false;
}
Expand Down
23 changes: 23 additions & 0 deletions tests/python/unittest/test_tir_schedule_state_cached_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,18 @@ def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None:
)


@T.prim_func
def uncovered_producer_region(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]):
for i in range(120):
with T.block("producer"):
vi = T.axis.S((0, 120), i)
A[vi] = 1.0
for i in range(120):
with T.block("consumer"):
vi = T.axis.S((8, 128), i + 8)
B[vi] = A[vi]


# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg


Expand Down Expand Up @@ -757,5 +769,16 @@ def test_non_perfect_tiling_cache():
# pylint: enable=protected-access


def test_uncovered_producer_region():
s = tir.ScheduleState(uncovered_producer_region, debug_mask="all")
# pylint: disable=protected-access
assert s._get_cached_flags(_get_block(s, "consumer")) == CachedFlags(
affine_binding=True,
region_cover=False,
stage_pipeline=True,
)
# pylint: enable=protected-access


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 654a687

Please sign in to comment.