Skip to content

Commit

Permalink
[TOPI][CUDA] schedule for group_conv2d (apache#3663)
Browse files Browse the repository at this point in the history
* [TOPI][CUDA] schedule for group_conv2d

* Fix #flops
  • Loading branch information
vinx13 authored and wweic committed Aug 9, 2019
1 parent 607d05c commit 12f768d
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 6 deletions.
120 changes: 118 additions & 2 deletions topi/python/topi/cuda/group_conv2d_nchw.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,124 @@ def schedule_group_conv2d_NCHWc_int8(cfg, s, output):
return s


def schedule_group_conv2d_nchw_direct(cfg, s, conv):
"""Schedule group conv2d NCHW direct template"""
workload = conv.op.attrs["workload"]
groups = get_const_int(workload[6])
num_filters = get_const_int(conv.shape[1])

##### space definition begin #####
n, f, y, x = s[conv].op.axis
rc, ry, rx = s[conv].op.reduce_axis
cfg.define_split("tile_n", n, num_outputs=4)
cfg.define_split("tile_g", cfg.axis(groups), num_outputs=2)
cfg.define_split("tile_f", cfg.axis(num_filters // groups), num_outputs=4)
cfg.define_split("tile_y", y, num_outputs=4)
cfg.define_split("tile_x", x, num_outputs=4)
cfg.define_split("tile_rc", rc, num_outputs=2)
cfg.define_split("tile_ry", ry, num_outputs=2)
cfg.define_split("tile_rx", rx, num_outputs=2)
cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])

target = tvm.target.current_target()
if target.target_name in ['nvptx', 'rocm']:
cfg.define_knob("unroll_explicit", [1])
else:
cfg.define_knob("unroll_explicit", [0, 1])

pad_data, kernel = s[conv].op.input_tensors

s[pad_data].compute_inline()

if conv.op in s.outputs:
output = conv
OL = s.cache_write(conv, 'local')
else:
output = s.outputs[0].output(0)
s[conv].set_scope('local')
OL = conv

# create cache stage
AA = s.cache_read(pad_data, 'shared', [OL])
WW = s.cache_read(kernel, 'shared', [OL])

# tile and bind spatial axes
n, f, y, x = s[output].op.axis
kernel_scope, n = s[output].split(n, nparts=1)

g, f = s[output].split(f, nparts=groups)
bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
bg, vg = cfg["tile_g"].apply(s, output, g)
bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

s[output].reorder(bn, bg, bf, by, bx, vn, vg, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
s[output].bind(s[output].fuse(bg, bf), tvm.thread_axis("blockIdx.y"))
s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
s[output].bind(vn, tvm.thread_axis("vthread"))
s[output].bind(vg, tvm.thread_axis("vthread"))
s[output].bind(vf, tvm.thread_axis("vthread"))
s[output].bind(vy, tvm.thread_axis("vthread"))
s[output].bind(vx, tvm.thread_axis("vthread"))

cfg.define_knob("fuse_yx", [0, 1]) # fuse ty,tx or tn,tf
if cfg["fuse_yx"].val:
s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
tyx = s[output].fuse(ty, tx)
s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
s[OL].compute_at(s[output], tyx)

# number of threads
n_tz = cfg["tile_n"].size[2]
n_ty = cfg["tile_f"].size[2]
n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
else:
s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
s[OL].compute_at(s[output], tx)

# number of threads
n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
n_ty = cfg["tile_y"].size[2]
n_tx = cfg["tile_x"].size[2]

# tile reduction axes
n, f, y, x = s[OL].op.axis
rc, ry, rx = s[OL].op.reduce_axis
rco, rci = cfg['tile_rc'].apply(s, OL, rc)
ryo, ryi = cfg['tile_rx'].apply(s, OL, ry)
rxo, rxi = cfg['tile_ry'].apply(s, OL, rx)
s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)

s[AA].compute_at(s[OL], rxo)
s[WW].compute_at(s[OL], rxo)

# cooperative fetching
for load in [AA, WW]:
n, f, y, x = s[load].op.axis
fused = s[load].fuse(n, f, y, x)
fused, tx = s[load].split(fused, factor=n_tx)
fused, ty = s[load].split(fused, factor=n_ty)
fused, tz = s[load].split(fused, factor=n_tz)
s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
s[load].bind(tx, tvm.thread_axis("threadIdx.x"))

# unroll
s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)

N, CO, OH, OW = get_const_tuple(output.shape)
_, CI_div_groups, KH, KW = get_const_tuple(kernel.shape)
cfg.add_flop(2 * N * OH * OW * CO * CI_div_groups * KH * KW)


@autotvm.register_topi_schedule(generic.schedule_group_conv2d_nchw,
["cuda", "gpu"], ["int8"])
["cuda", "gpu"], ["int8", "direct"])
def schedule_conv2d_nchw_cuda(cfg, outs):
"""TOPI schedule callback of group conv2d for cuda gpu
Expand All @@ -347,7 +463,7 @@ def _callback(op):
if op.tag == "group_conv2d_NCHWc_int8":
schedule_group_conv2d_NCHWc_int8(cfg, s, op.output(0))
if op.tag == "group_conv2d_nchw":
raise tvm.error.OpNotImplemented("group_conv2d_nchw not supported")
schedule_group_conv2d_nchw_direct(cfg, s, op.output(0))

traverse_inline(s, outs[0].op, _callback)
return s
5 changes: 1 addition & 4 deletions topi/tests/python/test_topi_group_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,6 @@ def check_device(device):
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
print("Skip because int8 intrinsics are not available")
return

print("Running on target: %s" % device)
with tvm.target.create(device):
Expand All @@ -94,7 +91,7 @@ def check_device(device):
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

for device in ["llvm"]:
for device in ["llvm", "cuda"]:
check_device(device)


Expand Down

0 comments on commit 12f768d

Please sign in to comment.