Commit 97e1b6e: update test

masahi committed Nov 10, 2022 (1 parent: 709e1d7)
Showing 1 changed file with 20 additions and 19 deletions: tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -162,14 +162,15 @@ def matmul_relu_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "f
 T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
 T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
 C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-for ax0, ax1 in T.grid(32, 32):
+for ax0_ax1_fused in T.serial(1024):
 with T.block("C_reindex_shared"):
-v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0)
-v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax1)
+v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 32)
+v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32)
 T.reads(C_reindex_shared[v0, v1])
 T.writes(compute[v0, v1])
 T.block_attr({"meta_schedule.cooperative_fetch":4})
 compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0))
+
 # fmt: on
 decision_0 = [
     ("SamplePerfectTile", [4, 1, 1, 1, 2]),
@@ -303,10 +304,10 @@ def matmul_relu_fallback_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128,
 T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
 T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
 C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-for ax0, ax1 in T.grid(32, 128):
+for ax0_ax1_fused in T.serial(4096):
 with T.block("C_reindex_shared"):
-v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0)
-v1 = T.axis.spatial(128, ax1)
+v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused // 128)
+v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
 T.reads(C_reindex_shared[v0, v1])
 T.writes(compute[v0, v1])
 T.block_attr({"meta_schedule.cooperative_fetch":4})
@@ -451,10 +452,10 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3,
 T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
 T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
 conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-for ax0, ax1 in T.grid(16, 16):
+for ax0_ax1_fused in T.serial(256):
 with T.block("conv2d_nhwc_reindex_shared"):
-v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0)
-v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 + ax1)
+v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0_ax1_fused // 16)
+v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 + ax0_ax1_fused % 16)
 T.reads(conv2d_nhwc_reindex_shared[v0, v1])
 T.writes(conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1])
 T.block_attr({"meta_schedule.cooperative_fetch":3})
@@ -617,10 +618,10 @@ def matmul_relu_pipeline_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128,
 T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
 T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
 C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-for ax0, ax1 in T.grid(32, 32):
+for ax0_ax1_fused in T.serial(1024):
 with T.block("C_reindex_shared"):
-v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 32 + ax0)
-v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 32 + ax1)
+v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 32 + ax0_ax1_fused // 32)
+v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 32 + ax0_ax1_fused % 32)
 T.reads(C_reindex_shared[v0, v1])
 T.writes(C[v0, v1])
 T.block_attr({"meta_schedule.cooperative_fetch":3})
@@ -919,11 +920,11 @@ def padded_matmul_relu_0(A: T.Buffer[(127, 127), "float16"], B: T.Buffer[(127, 1
 T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
 T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
 C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-for ax0, ax1 in T.grid(32, 32):
+for ax0_ax1_fused in T.serial(1024):
 with T.block("C_reindex_shared"):
-T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax1 < 127)
-v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0)
-v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax1)
+T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 32 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32 < 127)
+v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 32)
+v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32)
 T.reads(C_reindex_shared[v0, v1])
 T.writes(compute[v0, v1])
 T.block_attr({"meta_schedule.cooperative_fetch":4})
@@ -1063,10 +1064,10 @@ def conv2d_1x1_0(inputs: T.Buffer[(1, 16, 16, 64), "float16"], weight: T.Buffer[
 T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
 T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
 conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-for ax0, ax1 in T.grid(16, 32):
+for ax0_ax1_fused in T.serial(512):
 with T.block("conv2d_nhwc_reindex_shared"):
-v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0)
-v1 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused % 2 * 32 + ax1)
+v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 32)
+v1 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_fused % 32)
 T.reads(conv2d_nhwc_reindex_shared[v0, v1])
 T.writes(conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1])
 T.block_attr({"meta_schedule.cooperative_fetch":2})