diff --git a/topi/python/topi/cuda/group_conv2d_nchw.py b/topi/python/topi/cuda/group_conv2d_nchw.py index cbdb3cb8031d..cd8c823ae8ae 100644 --- a/topi/python/topi/cuda/group_conv2d_nchw.py +++ b/topi/python/topi/cuda/group_conv2d_nchw.py @@ -321,8 +321,124 @@ def schedule_group_conv2d_NCHWc_int8(cfg, s, output): return s +def schedule_group_conv2d_nchw_direct(cfg, s, conv): + """Schedule group conv2d NCHW direct template""" + workload = conv.op.attrs["workload"] + groups = get_const_int(workload[6]) + num_filters = get_const_int(conv.shape[1]) + + ##### space definition begin ##### + n, f, y, x = s[conv].op.axis + rc, ry, rx = s[conv].op.reduce_axis + cfg.define_split("tile_n", n, num_outputs=4) + cfg.define_split("tile_g", cfg.axis(groups), num_outputs=2) + cfg.define_split("tile_f", cfg.axis(num_filters // groups), num_outputs=4) + cfg.define_split("tile_y", y, num_outputs=4) + cfg.define_split("tile_x", x, num_outputs=4) + cfg.define_split("tile_rc", rc, num_outputs=2) + cfg.define_split("tile_ry", ry, num_outputs=2) + cfg.define_split("tile_rx", rx, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + + target = tvm.target.current_target() + if target.target_name in ['nvptx', 'rocm']: + cfg.define_knob("unroll_explicit", [1]) + else: + cfg.define_knob("unroll_explicit", [0, 1]) + + pad_data, kernel = s[conv].op.input_tensors + + s[pad_data].compute_inline() + + if conv.op in s.outputs: + output = conv + OL = s.cache_write(conv, 'local') + else: + output = s.outputs[0].output(0) + s[conv].set_scope('local') + OL = conv + + # create cache stage + AA = s.cache_read(pad_data, 'shared', [OL]) + WW = s.cache_read(kernel, 'shared', [OL]) + + # tile and bind spatial axes + n, f, y, x = s[output].op.axis + kernel_scope, n = s[output].split(n, nparts=1) + + g, f = s[output].split(f, nparts=groups) + bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n) + bg, vg = cfg["tile_g"].apply(s, output, g) + bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f) + by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) + bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) + + s[output].reorder(bn, bg, bf, by, bx, vn, vg, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi) + s[output].bind(bn, tvm.thread_axis("blockIdx.z")) + s[output].bind(s[output].fuse(bg, bf), tvm.thread_axis("blockIdx.y")) + s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x")) + s[output].bind(vn, tvm.thread_axis("vthread")) + s[output].bind(vg, tvm.thread_axis("vthread")) + s[output].bind(vf, tvm.thread_axis("vthread")) + s[output].bind(vy, tvm.thread_axis("vthread")) + s[output].bind(vx, tvm.thread_axis("vthread")) + + cfg.define_knob("fuse_yx", [0, 1]) # fuse ty,tx or tn,tf + if cfg["fuse_yx"].val: + s[output].bind(tn, tvm.thread_axis("threadIdx.z")) + s[output].bind(tf, tvm.thread_axis("threadIdx.y")) + tyx = s[output].fuse(ty, tx) + s[output].bind(tyx, tvm.thread_axis("threadIdx.x")) + s[OL].compute_at(s[output], tyx) + + # number of threads + n_tz = cfg["tile_n"].size[2] + n_ty = cfg["tile_f"].size[2] + n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2] + else: + s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z")) + s[output].bind(ty, tvm.thread_axis("threadIdx.y")) + s[output].bind(tx, tvm.thread_axis("threadIdx.x")) + s[OL].compute_at(s[output], tx) + + # number of threads + n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2] + n_ty = cfg["tile_y"].size[2] + n_tx = cfg["tile_x"].size[2] + + # tile reduction axes + n, f, y, x = s[OL].op.axis + rc, ry, rx = s[OL].op.reduce_axis + rco, rci = cfg['tile_rc'].apply(s, OL, rc) + ryo, ryi = cfg['tile_rx'].apply(s, OL, ry) + rxo, rxi = cfg['tile_ry'].apply(s, OL, rx) + s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x) + + s[AA].compute_at(s[OL], rxo) + s[WW].compute_at(s[OL], rxo) + + # cooperative fetching + for load in [AA, WW]: + n, f, y, x = s[load].op.axis + fused = s[load].fuse(n, f, y, x) + fused, tx = s[load].split(fused, factor=n_tx) + fused, ty = s[load].split(fused, factor=n_ty) + fused, tz = s[load].split(fused, factor=n_tz) + s[load].bind(tz, tvm.thread_axis("threadIdx.z")) + s[load].bind(ty, tvm.thread_axis("threadIdx.y")) + s[load].bind(tx, tvm.thread_axis("threadIdx.x")) + + # unroll + s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val) + s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val) + + N, CO, OH, OW = get_const_tuple(output.shape) + _, CI_div_groups, KH, KW = get_const_tuple(kernel.shape) + cfg.add_flop(2 * N * OH * OW * CO * CI_div_groups * KH * KW) + + @autotvm.register_topi_schedule(generic.schedule_group_conv2d_nchw, - ["cuda", "gpu"], ["int8"]) + ["cuda", "gpu"], ["int8", "direct"]) def schedule_conv2d_nchw_cuda(cfg, outs): """TOPI schedule callback of group conv2d for cuda gpu @@ -347,7 +463,7 @@ def _callback(op): if op.tag == "group_conv2d_NCHWc_int8": schedule_group_conv2d_NCHWc_int8(cfg, s, op.output(0)) if op.tag == "group_conv2d_nchw": - raise tvm.error.OpNotImplemented("group_conv2d_nchw not supported") + schedule_group_conv2d_nchw_direct(cfg, s, op.output(0)) traverse_inline(s, outs[0].op, _callback) return s diff --git a/topi/tests/python/test_topi_group_conv2d.py b/topi/tests/python/test_topi_group_conv2d.py index e80999977e5b..0e176780023d 100644 --- a/topi/tests/python/test_topi_group_conv2d.py +++ b/topi/tests/python/test_topi_group_conv2d.py @@ -67,9 +67,6 @@ def check_device(device): if not ctx.exist: print("Skip because %s is not enabled" % device) return - if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version): - print("Skip because int8 intrinsics are not available") - return print("Running on target: %s" % device) with tvm.target.create(device): @@ -94,7 +91,7 @@ def check_device(device): func(a, w, c) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) - for device in ["llvm"]: + for device in ["llvm", "cuda"]: check_device(device)