Skip to content

Commit

Permalink
Add winograd conv2d nchw search space test.
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh committed Aug 2, 2022
1 parent add3a94 commit 0bb5078
Showing 1 changed file with 152 additions and 0 deletions.
152 changes: 152 additions & 0 deletions tests/python/unittest/test_meta_schedule_space_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Tests for MetaSchedule search space on CUDA"""
from tvm import te, topi, autotvm
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.space_generation import check_sketches, print_sketches
from tvm.meta_schedule.testing.te_workload import create_te_workload
Expand All @@ -25,6 +26,11 @@
def _target():
return Target("nvidia/geforce-rtx-3070")

def _conv2d_winograd_nchw():
data = te.placeholder((1, 64, 224, 224), name="data", dtype="float32")
kernel = te.placeholder((6, 6, 64, 64), name="kernel", dtype="float32")
return te.create_prim_func([data, kernel, topi.cuda.conv2d_winograd.winograd_cuda(cfg=autotvm.ConfigSpace(), data = data, kernel=kernel, strides=(1, 1), padding=(1, 1), dilation=(1, 1), out_dtype="float32", pre_computed=True)])


def test_cuda_c1d():
# fmt: off
Expand Down Expand Up @@ -1216,6 +1222,151 @@ def cbr_0(data: T.Buffer[(1, 224, 224, 3), "float32"], kernel: T.Buffer[(7, 7, 3
expected_decisions=[decision_0],
)

def test_cuda_winograd_nchw_conv2d():
# fmt: off
@T.prim_func
def winograd_nchw_conv2d(data: T.Buffer[(1, 64, 224, 224), "float32"], kernel: T.Buffer[(6, 6, 64, 64), "float32"], output: T.Buffer[(1, 64, 224, 224), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.unroll_explicit":1024})
data_pack = T.alloc_buffer([6, 6, 64, 3136], dtype="float32")
bgemm = T.alloc_buffer([6, 6, 64, 3136], dtype="float32")
inverse_local = T.alloc_buffer([64, 3136, 4, 4], dtype="float32", scope="local")
data_pack_local = T.alloc_buffer([6, 6, 64, 3136], dtype="float32", scope="local")
bgemm_local = T.alloc_buffer([6, 6, 64, 3136], dtype="float32", scope="local")
kernel_shared = T.alloc_buffer([6, 6, 64, 64], dtype="float32", scope="shared")
data_pack_shared = T.alloc_buffer([6, 6, 64, 3136], dtype="float32", scope="shared")
for i2_i3_0_fused_i3_1_fused_0 in T.thread_binding(3136, thread="blockIdx.x"):
for i2_i3_0_fused_i3_1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
for i0 in T.unroll(6):
for i1 in T.unroll(6):
for i4 in T.unroll(6):
for i5 in T.unroll(6):
with T.block("data_pack"):
eps, nu = T.axis.remap("SS", [i0, i1])
ci = T.axis.spatial(64, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) // 3136)
p = T.axis.spatial(3136, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 3136 // 16 * 16 + (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 16)
r_a, r_a_1 = T.axis.remap("RR", [i4, i5])
T.reads(data[p // 3136, ci, p % 3136 // 56 * 4 + r_a - 1, p % 56 * 4 + r_a_1 - 1])
T.writes(data_pack_local[eps, nu, ci, p])
T.block_attr({"schedule_rule":"meta_schedule.winograd_data_pack.nchw.cuda"})
with T.init():
data_pack_local[eps, nu, ci, p] = T.float32(0)
data_pack_local[eps, nu, ci, p] = data_pack_local[eps, nu, ci, p] + T.if_then_else(1 <= p % 3136 // 56 * 4 + r_a and p % 3136 // 56 * 4 + r_a < 225 and 1 <= p % 56 * 4 + r_a_1 and p % 56 * 4 + r_a_1 < 225, data[p // 3136, ci, p % 3136 // 56 * 4 + r_a - 1, p % 56 * 4 + r_a_1 - 1], T.float32(0), dtype="float32") * T.Select(r_a % 6 == 5 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 5 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 0, T.float32(0), T.Select(r_a % 6 == 4 and eps % 6 == 5, T.float32(1.5), T.Select(r_a % 6 == 4 and eps % 6 == 4, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 3, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 2, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 0, T.float32(1), T.Select(r_a % 6 == 3 and eps % 6 == 5, T.float32(-2), T.Select(r_a % 6 == 3 and eps % 6 == 4, T.float32(-0.5), T.Select(r_a % 6 == 3 and eps % 6 == 3, T.float32(2), T.Select(r_a % 6 == 3 and eps % 6 == 2, T.float32(2.5), T.Select(r_a % 6 == 3 and eps % 6 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and eps % 6 == 0, T.float32(1.5), T.Select(r_a % 6 == 2 and eps % 6 == 5, T.float32(-1.5), T.Select(r_a % 6 == 2 and eps % 6 == 4, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 3, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 2, T.float32(0.5), T.Select(r_a % 6 == 2 and eps % 6 == 1, T.float32(-2.5), T.Select(r_a % 6 == 2 and eps % 6 == 0, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 4, T.float32(0.5), T.Select(r_a % 6 == 1 and eps % 6 == 3, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 2, T.float32(-1), T.Select(r_a % 6 == 1 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 0, T.float32(-1.5), T.Select(r_a % 6 == 0 and eps % 6 == 5, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(r_a_1 % 6 == 5 and nu % 6 == 5, T.float32(1), T.Select(r_a_1 % 6 == 5 and nu % 6 == 4, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 3, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 2, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 1, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 0, T.float32(0), T.Select(r_a_1 % 6 == 4 and nu % 6 == 5, T.float32(1.5), T.Select(r_a_1 % 6 == 4 and nu % 6 == 4, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 3, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 2, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 1, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 0, T.float32(1), T.Select(r_a_1 % 6 == 3 and nu % 6 == 5, T.float32(-2), T.Select(r_a_1 % 6 == 3 and nu % 6 == 4, T.float32(-0.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 3, T.float32(2), T.Select(r_a_1 % 6 == 3 and nu % 6 == 2, T.float32(2.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 1, T.float32(0.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 0, T.float32(1.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 5, T.float32(-1.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 4, T.float32(-1), T.Select(r_a_1 % 6 == 2 and nu % 6 == 3, T.float32(-1), T.Select(r_a_1 % 6 == 2 and nu % 6 == 2, T.float32(0.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 1, T.float32(-2.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 0, T.float32(-2), T.Select(r_a_1 % 6 == 1 and nu % 6 == 5, T.float32(1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 4, T.float32(0.5), T.Select(r_a_1 % 6 == 1 and nu % 6 == 3, T.float32(-2), T.Select(r_a_1 % 6 == 1 and nu % 6 == 2, T.float32(-1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 1, T.float32(1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 0, T.float32(-1.5), T.Select(r_a_1 % 6 == 0 and nu % 6 == 5, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 4, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 3, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 2, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 1, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 0, T.float32(1), T.float32(0)))))))))))))))))))))))))))))))))))))
for ax0, ax1, ax2, ax3 in T.grid(6, 6, 1, 1):
with T.block("data_pack_local"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
v2 = T.axis.spatial(64, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) // 3136 + ax2)
v3 = T.axis.spatial(3136, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 3136 // 16 * 16 + (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 16 + ax3)
T.reads(data_pack_local[v0, v1, v2, v3])
T.writes(data_pack[v0, v1, v2, v3])
data_pack[v0, v1, v2, v3] = data_pack_local[v0, v1, v2, v3]
for i0_i1_fused_0 in T.thread_binding(2, thread="blockIdx.z"):
for i2_0 in T.thread_binding(2, thread="blockIdx.y"):
for i3_0 in T.thread_binding(56, thread="blockIdx.x"):
for i0_i1_fused_1 in T.thread_binding(3, thread="vthread.z"):
for i2_1 in T.thread_binding(2, thread="vthread.y"):
for i3_1 in T.thread_binding(1, thread="vthread.x"):
for i0_i1_fused_2 in T.thread_binding(6, thread="threadIdx.z"):
for i2_2 in T.thread_binding(2, thread="threadIdx.y"):
for i3_2 in T.thread_binding(8, thread="threadIdx.x"):
for i4_0 in T.serial(16):
for ax0, ax1, ax2, ax3 in T.grid(3, 6, 4, 32):
with T.block("kernel_shared"):
v0 = T.axis.spatial(6, i0_i1_fused_0 * 3 + ax0)
v1 = T.axis.spatial(6, ax1)
v2 = T.axis.spatial(64, i4_0 * 4 + ax2)
v3 = T.axis.spatial(64, i2_0 * 32 + ax3)
T.reads(kernel[v0, v1, v2, v3])
T.writes(kernel_shared[v0, v1, v2, v3])
T.block_attr({"meta_schedule.cooperative_fetch":4})
kernel_shared[v0, v1, v2, v3] = kernel[v0, v1, v2, v3]
for ax0, ax1, ax2, ax3 in T.grid(3, 6, 4, 56):
with T.block("data_pack_shared"):
v0 = T.axis.spatial(6, i0_i1_fused_0 * 3 + ax0)
v1 = T.axis.spatial(6, ax1)
v2 = T.axis.spatial(64, i4_0 * 4 + ax2)
v3 = T.axis.spatial(3136, i3_0 * 56 + ax3)
T.reads(data_pack[v0, v1, v2, v3])
T.writes(data_pack_shared[v0, v1, v2, v3])
T.block_attr({"meta_schedule.cooperative_fetch":4})
data_pack_shared[v0, v1, v2, v3] = data_pack[v0, v1, v2, v3]
for i4_1, i0_i1_fused_3, i2_3, i3_3 in T.grid(4, 1, 8, 7):
with T.block("bgemm"):
eps = T.axis.spatial(6, (i0_i1_fused_0 * 18 + i0_i1_fused_1 * 6 + i0_i1_fused_2 + i0_i1_fused_3) // 6)
nu = T.axis.spatial(6, (i0_i1_fused_0 * 18 + i0_i1_fused_1 * 6 + i0_i1_fused_2 + i0_i1_fused_3) % 6)
co = T.axis.spatial(64, i2_0 * 32 + i2_1 * 16 + i2_2 * 8 + i2_3)
p = T.axis.spatial(3136, i3_0 * 56 + i3_1 * 56 + i3_2 * 7 + i3_3)
ci = T.axis.reduce(64, i4_0 * 4 + i4_1)
T.reads(kernel_shared[eps, nu, ci, co], data_pack_shared[eps, nu, ci, p])
T.writes(bgemm_local[eps, nu, co, p])
T.block_attr({"meta_schedule.unroll_explicit":512, "pragma_auto_unroll_max_step":256, "schedule_rule":"meta_schedule.winograd_bgemm.nchw.cuda"})
with T.init():
bgemm_local[eps, nu, co, p] = T.float32(0)
bgemm_local[eps, nu, co, p] = bgemm_local[eps, nu, co, p] + kernel_shared[eps, nu, ci, co] * data_pack_shared[eps, nu, ci, p]
for ax0, ax1 in T.grid(8, 7):
with T.block("bgemm_local"):
v0 = T.axis.spatial(6, i0_i1_fused_0 * 3 + i0_i1_fused_1)
v1 = T.axis.spatial(6, i0_i1_fused_2)
v2 = T.axis.spatial(64, i2_0 * 32 + i2_1 * 16 + i2_2 * 8 + ax0)
v3 = T.axis.spatial(3136, i3_0 * 56 + i3_2 * 7 + ax1)
T.reads(bgemm_local[v0, v1, v2, v3])
T.writes(bgemm[v0, v1, v2, v3])
bgemm[v0, v1, v2, v3] = bgemm_local[v0, v1, v2, v3]
for i0_i1_i2_0_i3_0_fused_fused_0 in T.thread_binding(6272, thread="blockIdx.x"):
for i0_i1_i2_0_i3_0_fused_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 1, 4, 4, 6, 6):
with T.block("inverse"):
co = T.axis.spatial(64, ax0 + (i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) // 3136)
p = T.axis.spatial(3136, ax1 + (i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) % 3136 // 56 * 56 + (i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) % 56)
vh, vw, r_a_2, r_a_3 = T.axis.remap("SSRR", [ax2, ax3, ax4, ax5])
T.reads(bgemm[r_a_2, r_a_3, co, p])
T.writes(inverse_local[co, p, vh, vw])
T.block_attr({"schedule_rule":"meta_schedule.winograd_inverse.nchw.cuda"})
with T.init():
inverse_local[co, p, vh, vw] = T.float32(0)
inverse_local[co, p, vh, vw] = inverse_local[co, p, vh, vw] + bgemm[r_a_2, r_a_3, co, p] * T.Select(r_a_2 % 6 == 5 and vh % 4 == 3, T.float32(1), T.Select(r_a_2 % 6 == 5 and vh % 4 == 2, T.float32(0), T.Select(r_a_2 % 6 == 5 and vh % 4 == 1, T.float32(0), T.Select(r_a_2 % 6 == 5 and vh % 4 == 0, T.float32(0), T.Select(r_a_2 % 6 == 4 and vh % 4 == 3, T.float32(-8), T.Select(r_a_2 % 6 == 4 and vh % 4 == 2, T.float32(4), T.Select(r_a_2 % 6 == 4 and vh % 4 == 1, T.float32(-2), T.Select(r_a_2 % 6 == 4 and vh % 4 == 0, T.float32(1), T.Select(r_a_2 % 6 == 3 and vh % 4 == 3, T.float32(0.125), T.Select(r_a_2 % 6 == 3 and vh % 4 == 2, T.float32(0.25), T.Select(r_a_2 % 6 == 3 and vh % 4 == 1, T.float32(0.5), T.Select(r_a_2 % 6 == 3 and vh % 4 == 0, T.float32(1), T.Select(r_a_2 % 6 == 2 and vh % 4 == 3, T.float32(1), T.Select(r_a_2 % 6 == 2 and vh % 4 == 2, T.float32(1), T.Select(r_a_2 % 6 == 2 and vh % 4 == 1, T.float32(1), T.Select(r_a_2 % 6 == 2 and vh % 4 == 0, T.float32(1), T.Select(r_a_2 % 6 == 1 and vh % 4 == 3, T.float32(-1), T.Select(r_a_2 % 6 == 1 and vh % 4 == 2, T.float32(1), T.Select(r_a_2 % 6 == 1 and vh % 4 == 1, T.float32(-1), T.Select(r_a_2 % 6 == 1 and vh % 4 == 0, T.float32(1), T.Select(r_a_2 % 6 == 0 and vh % 4 == 3, T.float32(0), T.Select(r_a_2 % 6 == 0 and vh % 4 == 2, T.float32(0), T.Select(r_a_2 % 6 == 0 and vh % 4 == 1, T.float32(0), T.Select(r_a_2 % 6 == 0 and vh % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) * T.Select(r_a_3 % 6 == 5 and vw % 4 == 3, T.float32(1), T.Select(r_a_3 % 6 == 5 and vw % 4 == 2, T.float32(0), T.Select(r_a_3 % 6 == 5 and vw % 4 == 1, T.float32(0), T.Select(r_a_3 % 6 == 5 and vw % 4 == 0, T.float32(0), T.Select(r_a_3 % 6 == 4 and vw % 4 == 3, T.float32(-8), T.Select(r_a_3 % 6 == 4 and vw % 4 == 2, T.float32(4), T.Select(r_a_3 % 6 == 4 and vw % 4 == 1, T.float32(-2), T.Select(r_a_3 % 6 == 4 and vw % 4 == 0, T.float32(1), T.Select(r_a_3 % 6 == 3 and vw % 4 == 3, T.float32(0.125), T.Select(r_a_3 % 6 == 3 and vw % 4 == 2, T.float32(0.25), T.Select(r_a_3 % 6 == 3 and vw % 4 == 1, T.float32(0.5), T.Select(r_a_3 % 6 == 3 and vw % 4 == 0, T.float32(1), T.Select(r_a_3 % 6 == 2 and vw % 4 == 3, T.float32(1), T.Select(r_a_3 % 6 == 2 and vw % 4 == 2, T.float32(1), T.Select(r_a_3 % 6 == 2 and vw % 4 == 1, T.float32(1), T.Select(r_a_3 % 6 == 2 and vw % 4 == 0, T.float32(1), T.Select(r_a_3 % 6 == 1 and vw % 4 == 3, T.float32(-1), T.Select(r_a_3 % 6 == 1 and vw % 4 == 2, T.float32(1), T.Select(r_a_3 % 6 == 1 and vw % 4 == 1, T.float32(-1), T.Select(r_a_3 % 6 == 1 and vw % 4 == 0, T.float32(1), T.Select(r_a_3 % 6 == 0 and vw % 4 == 3, T.float32(0), T.Select(r_a_3 % 6 == 0 and vw % 4 == 2, T.float32(0), T.Select(r_a_3 % 6 == 0 and vw % 4 == 1, T.float32(0), T.Select(r_a_3 % 6 == 0 and vw % 4 == 0, T.float32(1), T.float32(0)))))))))))))))))))))))))
for i2_1, i3_1 in T.grid(4, 4):
with T.block("output"):
n = T.axis.spatial(1, 0)
co = T.axis.spatial(64, (i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) // 3136)
h = T.axis.spatial(224, (i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) % 3136 // 56 * 4 + i2_1)
w = T.axis.spatial(224, (i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) % 56 * 4 + i3_1)
T.reads(inverse_local[co, n * 3136 + h // 4 * 56 + w // 4, h % 4, w % 4])
T.writes(output[n, co, h, w])
T.block_attr({"schedule_rule":"meta_schedule.winograd_output.nchw.cuda", "winograd_tile_size":4})
output[n, co, h, w] = inverse_local[co, n * 3136 + h // 4 * 56 + w // 4, h % 4, w % 4]
# fmt: on
decision_0 = [
("SamplePerfectTile", [196, 16]),
("SampleCategorical", 1),
("SamplePerfectTile", [2, 3, 6, 1]),
("SamplePerfectTile", [2, 2, 2, 8]),
("SamplePerfectTile", [56, 1, 8, 7]),
("SamplePerfectTile", [16, 4]),
("SampleCategorical", 3),
("SampleCategorical", 3),
("SampleCategorical", 1),
("SampleCategorical", 0),
("SampleCategorical", 4),
]
mod = _conv2d_winograd_nchw()
actual = ms.TuneContext(
mod=mod,
target=_target(),
space_generator=ms.space_generator.PostOrderApply(),
sch_rules="default",
).generate_design_space()
check_sketches(
mod,
sketches=actual,
expected_mods=[winograd_nchw_conv2d],
expected_decisions=[decision_0],
)

def test_cuda_tbg():
# fmt: off
Expand Down Expand Up @@ -1315,3 +1466,4 @@ def tbg_0(query: T.Buffer[(1, 128, 12, 64), "float32"], value: T.Buffer[(1, 128,
test_cuda_sfm()
test_cuda_cbr()
test_cuda_tbg()
test_cuda_winograd_nchw_conv2d()

0 comments on commit 0bb5078

Please sign in to comment.