From 0bb5078f5f9c8078c70052c2f0e0fc56ce4f68d7 Mon Sep 17 00:00:00 2001
From: Xiyou Zhou
Date: Tue, 2 Aug 2022 00:04:55 -0700
Subject: [PATCH] Add winograd conv2d nchw search space test.

Add the `_conv2d_winograd_nchw` workload helper and
`test_cuda_winograd_nchw_conv2d`, which generates the design space for that
workload with `PostOrderApply` and the default schedule rules, and compares it
against one expected sketch and its sampling decisions.

---
 .../unittest/test_meta_schedule_space_cuda.py | 167 ++++++++++++++++++
 1 file changed, 167 insertions(+)

diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
index d617742d9457..ad49d05dc3d1 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Tests for MetaSchedule search space on CUDA"""
+from tvm import te, topi, autotvm
 from tvm import meta_schedule as ms
 from tvm.meta_schedule.testing.space_generation import check_sketches, print_sketches
 from tvm.meta_schedule.testing.te_workload import create_te_workload
@@ -25,6 +26,23 @@
 def _target():
     return Target("nvidia/geforce-rtx-3070")
 
+
+def _conv2d_winograd_nchw():
+    """Create the PrimFunc of the CUDA Winograd NCHW conv2d workload (`pre_computed=True`)."""
+    data = te.placeholder((1, 64, 224, 224), name="data", dtype="float32")
+    kernel = te.placeholder((6, 6, 64, 64), name="kernel", dtype="float32")
+    output = topi.cuda.conv2d_winograd.winograd_cuda(
+        cfg=autotvm.ConfigSpace(),
+        data=data,
+        kernel=kernel,
+        strides=(1, 1),
+        padding=(1, 1),
+        dilation=(1, 1),
+        out_dtype="float32",
+        pre_computed=True,
+    )
+    return te.create_prim_func([data, kernel, output])
+
 
 def test_cuda_c1d():
     # fmt: off
@@ -1216,6 +1234,154 @@ def cbr_0(data: T.Buffer[(1, 224, 224, 3), "float32"], kernel: T.Buffer[(7, 7, 3
         expected_decisions=[decision_0],
     )
 
+
+def test_cuda_winograd_nchw_conv2d():
+    """Check the design space generated for the Winograd NCHW conv2d workload on CUDA."""
+    # fmt: off
+    @T.prim_func
+    def winograd_nchw_conv2d(data: T.Buffer[(1, 64, 224, 224), "float32"], kernel: T.Buffer[(6, 6, 64, 64), "float32"], output: T.Buffer[(1, 64, 224, 224), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.unroll_explicit":1024})
+            data_pack = T.alloc_buffer([6, 6, 64, 3136], dtype="float32")
+            bgemm = T.alloc_buffer([6, 6, 64, 3136], dtype="float32")
+            inverse_local = T.alloc_buffer([64, 3136, 4, 4], dtype="float32", scope="local")
+            data_pack_local = T.alloc_buffer([6, 6, 64, 3136], dtype="float32", scope="local")
+            bgemm_local = T.alloc_buffer([6, 6, 64, 3136], dtype="float32", scope="local")
+            kernel_shared = T.alloc_buffer([6, 6, 64, 64], dtype="float32", scope="shared")
+            data_pack_shared = T.alloc_buffer([6, 6, 64, 3136], dtype="float32", scope="shared")
+            for i2_i3_0_fused_i3_1_fused_0 in T.thread_binding(3136, thread="blockIdx.x"):
+                for i2_i3_0_fused_i3_1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
+                    for i0 in T.unroll(6):
+                        for i1 in T.unroll(6):
+                            for i4 in T.unroll(6):
+                                for i5 in T.unroll(6):
+                                    with T.block("data_pack"):
+                                        eps, nu = T.axis.remap("SS", [i0, i1])
+                                        ci = T.axis.spatial(64, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) // 3136)
+                                        p = T.axis.spatial(3136, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 3136 // 16 * 16 + (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 16)
+                                        r_a, r_a_1 = T.axis.remap("RR", [i4, i5])
+                                        T.reads(data[p // 3136, ci, p % 3136 // 56 * 4 + r_a - 1, p % 56 * 4 + r_a_1 - 1])
+                                        T.writes(data_pack_local[eps, nu, ci, p])
+                                        T.block_attr({"schedule_rule":"meta_schedule.winograd_data_pack.nchw.cuda"})
+                                        with T.init():
+                                            data_pack_local[eps, nu, ci, p] = T.float32(0)
+                                        data_pack_local[eps, nu, ci, p] = data_pack_local[eps, nu, ci, p] + T.if_then_else(1 <= p % 3136 // 56 * 4 + r_a and p % 3136 // 56 * 4 + r_a < 225 and 1 <= p % 56 * 4 + r_a_1 and p % 56 * 4 + r_a_1 < 225, data[p // 3136, ci, p % 3136 // 56 * 4 + r_a - 1, p % 56 * 4 + r_a_1 - 1], T.float32(0), dtype="float32") * T.Select(r_a % 6 == 5 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 5 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 0, T.float32(0), T.Select(r_a % 6 == 4 and eps % 6 == 5, T.float32(1.5), T.Select(r_a % 6 == 4 and eps % 6 == 4, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 3, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 2, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 0, T.float32(1), T.Select(r_a % 6 == 3 and eps % 6 == 5, T.float32(-2), T.Select(r_a % 6 == 3 and eps % 6 == 4, T.float32(-0.5), T.Select(r_a % 6 == 3 and eps % 6 == 3, T.float32(2), T.Select(r_a % 6 == 3 and eps % 6 == 2, T.float32(2.5), T.Select(r_a % 6 == 3 and eps % 6 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and eps % 6 == 0, T.float32(1.5), T.Select(r_a % 6 == 2 and eps % 6 == 5, T.float32(-1.5), T.Select(r_a % 6 == 2 and eps % 6 == 4, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 3, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 2, T.float32(0.5), T.Select(r_a % 6 == 2 and eps % 6 == 1, T.float32(-2.5), T.Select(r_a % 6 == 2 and eps % 6 == 0, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 4, T.float32(0.5), T.Select(r_a % 6 == 1 and eps % 6 == 3, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 2, T.float32(-1), T.Select(r_a % 6 == 1 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 0, T.float32(-1.5), T.Select(r_a % 6 == 0 and eps % 6 == 5, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(r_a_1 % 6 == 5 and nu % 6 == 5, T.float32(1), T.Select(r_a_1 % 6 == 5 and nu % 6 == 4, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 3, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 2, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 1, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 0, T.float32(0), T.Select(r_a_1 % 6 == 4 and nu % 6 == 5, T.float32(1.5), T.Select(r_a_1 % 6 == 4 and nu % 6 == 4, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 3, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 2, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 1, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 0, T.float32(1), T.Select(r_a_1 % 6 == 3 and nu % 6 == 5, T.float32(-2), T.Select(r_a_1 % 6 == 3 and nu % 6 == 4, T.float32(-0.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 3, T.float32(2), T.Select(r_a_1 % 6 == 3 and nu % 6 == 2, T.float32(2.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 1, T.float32(0.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 0, T.float32(1.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 5, T.float32(-1.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 4, T.float32(-1), T.Select(r_a_1 % 6 == 2 and nu % 6 == 3, T.float32(-1), T.Select(r_a_1 % 6 == 2 and nu % 6 == 2, T.float32(0.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 1, T.float32(-2.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 0, T.float32(-2), T.Select(r_a_1 % 6 == 1 and nu % 6 == 5, T.float32(1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 4, T.float32(0.5), T.Select(r_a_1 % 6 == 1 and nu % 6 == 3, T.float32(-2), T.Select(r_a_1 % 6 == 1 and nu % 6 == 2, T.float32(-1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 1, T.float32(1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 0, T.float32(-1.5), T.Select(r_a_1 % 6 == 0 and nu % 6 == 5, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 4, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 3, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 2, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 1, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 0, T.float32(1), T.float32(0)))))))))))))))))))))))))))))))))))))
+                    for ax0, ax1, ax2, ax3 in T.grid(6, 6, 1, 1):
+                        with T.block("data_pack_local"):
+                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                            v2 = T.axis.spatial(64, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) // 3136 + ax2)
+                            v3 = T.axis.spatial(3136, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 3136 // 16 * 16 + (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 16 + ax3)
+                            T.reads(data_pack_local[v0, v1, v2, v3])
+                            T.writes(data_pack[v0, v1, v2, v3])
+                            data_pack[v0, v1, v2, v3] = data_pack_local[v0, v1, v2, v3]
+            for i0_i1_fused_0 in T.thread_binding(2, thread="blockIdx.z"):
+                for i2_0 in T.thread_binding(2, thread="blockIdx.y"):
+                    for i3_0 in T.thread_binding(56, thread="blockIdx.x"):
+                        for i0_i1_fused_1 in T.thread_binding(3, thread="vthread.z"):
+                            for i2_1 in T.thread_binding(2, thread="vthread.y"):
+                                for i3_1 in T.thread_binding(1, thread="vthread.x"):
+                                    for i0_i1_fused_2 in T.thread_binding(6, thread="threadIdx.z"):
+                                        for i2_2 in T.thread_binding(2, thread="threadIdx.y"):
+                                            for i3_2 in T.thread_binding(8, thread="threadIdx.x"):
+                                                for i4_0 in T.serial(16):
+                                                    for ax0, ax1, ax2, ax3 in T.grid(3, 6, 4, 32):
+                                                        with T.block("kernel_shared"):
+                                                            v0 = T.axis.spatial(6, i0_i1_fused_0 * 3 + ax0)
+                                                            v1 = T.axis.spatial(6, ax1)
+                                                            v2 = T.axis.spatial(64, i4_0 * 4 + ax2)
+                                                            v3 = T.axis.spatial(64, i2_0 * 32 + ax3)
+                                                            T.reads(kernel[v0, v1, v2, v3])
+                                                            T.writes(kernel_shared[v0, v1, v2, v3])
+                                                            T.block_attr({"meta_schedule.cooperative_fetch":4})
+                                                            kernel_shared[v0, v1, v2, v3] = kernel[v0, v1, v2, v3]
+                                                    for ax0, ax1, ax2, ax3 in T.grid(3, 6, 4, 56):
+                                                        with T.block("data_pack_shared"):
+                                                            v0 = T.axis.spatial(6, i0_i1_fused_0 * 3 + ax0)
+                                                            v1 = T.axis.spatial(6, ax1)
+                                                            v2 = T.axis.spatial(64, i4_0 * 4 + ax2)
+                                                            v3 = T.axis.spatial(3136, i3_0 * 56 + ax3)
+                                                            T.reads(data_pack[v0, v1, v2, v3])
+                                                            T.writes(data_pack_shared[v0, v1, v2, v3])
+                                                            T.block_attr({"meta_schedule.cooperative_fetch":4})
+                                                            data_pack_shared[v0, v1, v2, v3] = data_pack[v0, v1, v2, v3]
+                                                    for i4_1, i0_i1_fused_3, i2_3, i3_3 in T.grid(4, 1, 8, 7):
+                                                        with T.block("bgemm"):
+                                                            eps = T.axis.spatial(6, (i0_i1_fused_0 * 18 + i0_i1_fused_1 * 6 + i0_i1_fused_2 + i0_i1_fused_3) // 6)
+                                                            nu = T.axis.spatial(6, (i0_i1_fused_0 * 18 + i0_i1_fused_1 * 6 + i0_i1_fused_2 + i0_i1_fused_3) % 6)
+                                                            co = T.axis.spatial(64, i2_0 * 32 + i2_1 * 16 + i2_2 * 8 + i2_3)
+                                                            p = T.axis.spatial(3136, i3_0 * 56 + i3_1 * 56 + i3_2 * 7 + i3_3)
+                                                            ci = T.axis.reduce(64, i4_0 * 4 + i4_1)
+                                                            T.reads(kernel_shared[eps, nu, ci, co], data_pack_shared[eps, nu, ci, p])
+                                                            T.writes(bgemm_local[eps, nu, co, p])
+                                                            T.block_attr({"meta_schedule.unroll_explicit":512, "pragma_auto_unroll_max_step":256, "schedule_rule":"meta_schedule.winograd_bgemm.nchw.cuda"})
+                                                            with T.init():
+                                                                bgemm_local[eps, nu, co, p] = T.float32(0)
+                                                            bgemm_local[eps, nu, co, p] = bgemm_local[eps, nu, co, p] + kernel_shared[eps, nu, ci, co] * data_pack_shared[eps, nu, ci, p]
+                                                for ax0, ax1 in T.grid(8, 7):
+                                                    with T.block("bgemm_local"):
+                                                        v0 = T.axis.spatial(6, i0_i1_fused_0 * 3 + i0_i1_fused_1)
+                                                        v1 = T.axis.spatial(6, i0_i1_fused_2)
+                                                        v2 = T.axis.spatial(64, i2_0 * 32 + i2_1 * 16 + i2_2 * 8 + ax0)
+                                                        v3 = T.axis.spatial(3136, i3_0 * 56 + i3_2 * 7 + ax1)
+                                                        T.reads(bgemm_local[v0, v1, v2, v3])
+                                                        T.writes(bgemm[v0, v1, v2, v3])
+                                                        bgemm[v0, v1, v2, v3] = bgemm_local[v0, v1, v2, v3]
+            for i0_i1_i2_0_i3_0_fused_fused_0 in T.thread_binding(6272, thread="blockIdx.x"):
+                for i0_i1_i2_0_i3_0_fused_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
+                    for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 1, 4, 4, 6, 6):
+                        with T.block("inverse"):
+                            co = T.axis.spatial(64, ax0 + (i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) // 3136)
+                            p = T.axis.spatial(3136, ax1 + (i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) % 3136 // 56 * 56 + (i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) % 56)
+                            vh, vw, r_a_2, r_a_3 = T.axis.remap("SSRR", [ax2, ax3, ax4, ax5])
+                            T.reads(bgemm[r_a_2, r_a_3, co, p])
+                            T.writes(inverse_local[co, p, vh, vw])
+                            T.block_attr({"schedule_rule":"meta_schedule.winograd_inverse.nchw.cuda"})
+                            with T.init():
+                                inverse_local[co, p, vh, vw] = T.float32(0)
+                            inverse_local[co, p, vh, vw] = inverse_local[co, p, vh, vw] + bgemm[r_a_2, r_a_3, co, p] * T.Select(r_a_2 % 6 == 5 and vh % 4 == 3, T.float32(1), T.Select(r_a_2 % 6 == 5 and vh % 4 == 2, T.float32(0), T.Select(r_a_2 % 6 == 5 and vh % 4 == 1, T.float32(0), T.Select(r_a_2 % 6 == 5 and vh % 4 == 0, T.float32(0), T.Select(r_a_2 % 6 == 4 and vh % 4 == 3, T.float32(-8), T.Select(r_a_2 % 6 == 4 and vh % 4 == 2, T.float32(4), T.Select(r_a_2 % 6 == 4 and vh % 4 == 1, T.float32(-2), T.Select(r_a_2 % 6 == 4 and vh % 4 == 0, T.float32(1), T.Select(r_a_2 % 6 == 3 and vh % 4 == 3, T.float32(0.125), T.Select(r_a_2 % 6 == 3 and vh % 4 == 2, T.float32(0.25), T.Select(r_a_2 % 6 == 3 and vh % 4 == 1, T.float32(0.5), T.Select(r_a_2 % 6 == 3 and vh % 4 == 0, T.float32(1), T.Select(r_a_2 % 6 == 2 and vh % 4 == 3, T.float32(1), T.Select(r_a_2 % 6 == 2 and vh % 4 == 2, T.float32(1), T.Select(r_a_2 % 6 == 2 and vh % 4 == 1, T.float32(1), T.Select(r_a_2 % 6 == 2 and vh % 4 == 0, T.float32(1), T.Select(r_a_2 % 6 == 1 and vh % 4 == 3, T.float32(-1), T.Select(r_a_2 % 6 == 1 and vh % 4 == 2, T.float32(1), T.Select(r_a_2 % 6 == 1 and vh % 4 == 1, T.float32(-1), T.Select(r_a_2 % 6 == 1 and vh % 4 == 0, T.float32(1), T.Select(r_a_2 % 6 == 0 and vh % 4 == 3, T.float32(0), T.Select(r_a_2 % 6 == 0 and vh % 4 == 2, T.float32(0), T.Select(r_a_2 % 6 == 0 and vh % 4 == 1, T.float32(0), T.Select(r_a_2 % 6 == 0 and vh % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) * T.Select(r_a_3 % 6 == 5 and vw % 4 == 3, T.float32(1), T.Select(r_a_3 % 6 == 5 and vw % 4 == 2, T.float32(0), T.Select(r_a_3 % 6 == 5 and vw % 4 == 1, T.float32(0), T.Select(r_a_3 % 6 == 5 and vw % 4 == 0, T.float32(0), T.Select(r_a_3 % 6 == 4 and vw % 4 == 3, T.float32(-8), T.Select(r_a_3 % 6 == 4 and vw % 4 == 2, T.float32(4), T.Select(r_a_3 % 6 == 4 and vw % 4 == 1, T.float32(-2), T.Select(r_a_3 % 6 == 4 and vw % 4 == 0, T.float32(1), T.Select(r_a_3 % 6 == 3 and vw % 4 == 3, T.float32(0.125), T.Select(r_a_3 % 6 == 3 and vw % 4 == 2, T.float32(0.25), T.Select(r_a_3 % 6 == 3 and vw % 4 == 1, T.float32(0.5), T.Select(r_a_3 % 6 == 3 and vw % 4 == 0, T.float32(1), T.Select(r_a_3 % 6 == 2 and vw % 4 == 3, T.float32(1), T.Select(r_a_3 % 6 == 2 and vw % 4 == 2, T.float32(1), T.Select(r_a_3 % 6 == 2 and vw % 4 == 1, T.float32(1), T.Select(r_a_3 % 6 == 2 and vw % 4 == 0, T.float32(1), T.Select(r_a_3 % 6 == 1 and vw % 4 == 3, T.float32(-1), T.Select(r_a_3 % 6 == 1 and vw % 4 == 2, T.float32(1), T.Select(r_a_3 % 6 == 1 and vw % 4 == 1, T.float32(-1), T.Select(r_a_3 % 6 == 1 and vw % 4 == 0, T.float32(1), T.Select(r_a_3 % 6 == 0 and vw % 4 == 3, T.float32(0), T.Select(r_a_3 % 6 == 0 and vw % 4 == 2, T.float32(0), T.Select(r_a_3 % 6 == 0 and vw % 4 == 1, T.float32(0), T.Select(r_a_3 % 6 == 0 and vw % 4 == 0, T.float32(1), T.float32(0)))))))))))))))))))))))))
+                    for i2_1, i3_1 in T.grid(4, 4):
+                        with T.block("output"):
+                            n = T.axis.spatial(1, 0)
+                            co = T.axis.spatial(64, (i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) // 3136)
+                            h = T.axis.spatial(224, (i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) % 3136 // 56 * 4 + i2_1)
+                            w = T.axis.spatial(224, (i0_i1_i2_0_i3_0_fused_fused_0 * 32 + i0_i1_i2_0_i3_0_fused_fused_1) % 56 * 4 + i3_1)
+                            T.reads(inverse_local[co, n * 3136 + h // 4 * 56 + w // 4, h % 4, w % 4])
+                            T.writes(output[n, co, h, w])
+                            T.block_attr({"schedule_rule":"meta_schedule.winograd_output.nchw.cuda", "winograd_tile_size":4})
+                            output[n, co, h, w] = inverse_local[co, n * 3136 + h // 4 * 56 + w // 4, h % 4, w % 4]
+    # fmt: on
+    decision_0 = [
+        ("SamplePerfectTile", [196, 16]),
+        ("SampleCategorical", 1),
+        ("SamplePerfectTile", [2, 3, 6, 1]),
+        ("SamplePerfectTile", [2, 2, 2, 8]),
+        ("SamplePerfectTile", [56, 1, 8, 7]),
+        ("SamplePerfectTile", [16, 4]),
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 1),
+        ("SampleCategorical", 0),
+        ("SampleCategorical", 4),
+    ]
+    mod = _conv2d_winograd_nchw()
+    actual = ms.TuneContext(
+        mod=mod,
+        target=_target(),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules="default",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[winograd_nchw_conv2d],
+        expected_decisions=[decision_0],
+    )
+
 
 def test_cuda_tbg():
     # fmt: off
@@ -1315,3 +1481,4 @@ def tbg_0(query: T.Buffer[(1, 128, 12, 64), "float32"], value: T.Buffer[(1, 128,
     test_cuda_sfm()
     test_cuda_cbr()
     test_cuda_tbg()
+    test_cuda_winograd_nchw_conv2d()