Skip to content

Commit

Permalink
Add test case for CacheLocDetector issue.
Browse files Browse the repository at this point in the history
  • Loading branch information
Min Chen committed Nov 11, 2022
1 parent 7be4e10 commit 992d5d3
Showing 1 changed file with 76 additions and 0 deletions.
76 changes: 76 additions & 0 deletions tests/python/unittest/test_tir_schedule_cache_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,33 @@ def elementwise_shape_int64(a: T.handle, c: T.handle) -> None:
C[vi, vj] = B[vi, vj] + 1.0


@T.prim_func
def func_nested_seq(b: T.handle, c: T.handle) -> None:
    # Input workload for the nested-sequence cache_read test.
    # A is an intermediate buffer allocated inside the function; B and C are
    # bound to the handle parameters `b` and `c`. All three are 128x128.
    A = T.alloc_buffer((128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))

    # Producer: block "A" writes the constant 2.0 to every element of A.
    for i, j in T.grid(128, 128):
        with T.block("A"):
            vi, vj = T.axis.remap("SS", [i, j])
            A[vi, vj] = 2.0
    # The 8x8 outer loop nest holds two sibling 16x16 loop nests, i.e. a
    # sequence of statements nested under a loop rather than at function scope.
    for i, j in T.grid(8, 8):
        # First sibling: block "B0" writes 1.0 into the (i, j) tile of B.
        for x, y in T.grid(16, 16):
            with T.block("B0"):
                vi = T.axis.S(128, i * 16 + x)
                vj = T.axis.S(128, j * 16 + y)
                B[vi, vj] = 1.0
        # Second sibling: block "B1" reads A and B and accumulates into B.
        for x, y in T.grid(16, 16):
            with T.block("B1"):
                vi = T.axis.S(128, i * 16 + x)
                vj = T.axis.S(128, j * 16 + y)
                B[vi, vj] = A[vi, vj] + B[vi, vj]
    # Block "C" also reads A, after the nested sequence above.
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = A[vi, vj] * 2.0


@T.prim_func
def access_under_scope(b: T.handle, c: T.handle) -> None:
A = T.alloc_buffer((128, 128))
Expand Down Expand Up @@ -250,6 +277,47 @@ def inplace_call(data_io: T.Buffer[(64), "int32"]):
T.evaluate(T.call_extern("call_impl", data_io.data, dtype=""))


@T.prim_func
def cache_read_nested_seq_target(
    B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]
) -> None :
    # Expected IR that test_cache_read_nested_seq compares against after
    # calling cache_read on block "C" of func_nested_seq.
    # A is the original intermediate buffer; A_global is the cache copy of A
    # that block "C" reads from.
    A = T.alloc_buffer([128, 128], dtype="float32")
    A_global = T.alloc_buffer([128, 128], dtype="float32")
    # Producer block "A": reads nothing, writes 2 to every element of A.
    for i, j in T.grid(128, 128):
        with T.block("A"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads()
            T.writes(A[vi, vj])
            A[vi, vj] = T.float32(2)
    # Nested sequence: blocks "B0" and "B1" under the 8x8 loop nest.
    # "B1" still reads A directly, not A_global.
    for i, j in T.grid(8, 8):
        for x, y in T.grid(16, 16):
            with T.block("B0"):
                vi = T.axis.spatial(128, i * 16 + x)
                vj = T.axis.spatial(128, j * 16 + y)
                T.reads()
                T.writes(B[vi, vj])
                B[vi, vj] = T.float32(1)
        for x, y in T.grid(16, 16):
            with T.block("B1"):
                vi = T.axis.spatial(128, i * 16 + x)
                vj = T.axis.spatial(128, j * 16 + y)
                T.reads(A[vi, vj], B[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj] + B[vi, vj]
    # Cache stage: block "A_global" copies A into A_global element by element.
    # It sits after the 8x8 loop nest and before block "C".
    for ax0, ax1 in T.grid(128, 128):
        with T.block("A_global"):
            v0, v1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v0, v1])
            T.writes(A_global[v0, v1])
            A_global[v0, v1] = A[v0, v1]
    # Consumer block "C" reads the cached copy A_global instead of A.
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_global[vi, vj])
            T.writes(C[vi, vj])
            C[vi, vj] = A_global[vi, vj] * T.float32(2)


########## Expected function after cache_read ##########


Expand Down Expand Up @@ -989,6 +1057,14 @@ def test_cache_inplace():
verify_trace_roundtrip(sch=sch, mod=inplace_call, debug_mask=debug_mask)


def test_cache_read_nested_seq(use_block_name):
    """Check cache_read on block "C" of ``func_nested_seq``.

    Block "C" is referenced by name when ``use_block_name`` is truthy and by
    the object returned from ``get_block`` otherwise. The scheduled module
    must be structurally equal to ``cache_read_nested_seq_target``, and the
    recorded trace must pass ``verify_trace_roundtrip`` on the original
    function.
    """
    schedule = tir.Schedule(func_nested_seq, debug_mask="all")
    if use_block_name:
        consumer = "C"
    else:
        consumer = schedule.get_block("C")
    # Block "C" is both the block whose read is cached and the only consumer
    # of the resulting cache.
    schedule.cache_read(consumer, 0, "global", consumer_blocks=[consumer])
    scheduled_func = schedule.mod["main"]
    tvm.ir.assert_structural_equal(cache_read_nested_seq_target, scheduled_func)
    verify_trace_roundtrip(sch=schedule, mod=func_nested_seq)


########## Testcases for cache_write ##########


Expand Down

0 comments on commit 992d5d3

Please sign in to comment.