Skip to content

Commit

Permalink
Add vectorization to cuda conv2d_nhwc schedule
Browse files Browse the repository at this point in the history
Adding vectorization significantly improved performance. About 6-7x
boost.
  • Loading branch information
echuraev committed Aug 3, 2021
1 parent 3cb0e93 commit dbbc678
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion python/tvm/topi/cuda/conv2d_nhwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
cfg.define_knob("vthread_n", [1] if dynamic_batch else [1, 2])
cfg.define_knob("vthread_c", [1, 2])
cfg.define_knob("step", [16, 3, 32, 64])
cfg.define_knob("vectorize", [4, 2, 8, 16])

# fallback support
target = tvm.target.Target.current()
Expand All @@ -70,6 +71,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
vthread_n = cfg["vthread_n"].val
vthread_c = cfg["vthread_c"].val
step = cfg["step"].val
vec_factor = cfg["vectorize"].val
block_factor_c = tile_c * num_thread_c * vthread_c

offset = 8
Expand All @@ -86,7 +88,9 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):

# Schedule for output
ni, hi, wi, fi = s[output].op.axis
bz = s[output].fuse(hi, wi)
bz = wi
fi, vec = s[output].split(fi, factor=vec_factor)
s[output].vectorize(vec)
tx, fi = s[output].split(fi, factor=tile_c)
txz, tx = s[output].split(tx, factor=num_thread_c)
bx, txz = s[output].split(txz, factor=vthread_c)
Expand Down Expand Up @@ -125,6 +129,8 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
_, _, ic, o = s[WW].op.axis
t = s[WW].fuse(ic, o)
s[WW].storage_align(ic, W_align - 1, W_align)
t, vec = s[WW].split(t, factor=vec_factor)
s[WW].vectorize(vec)
ty, tx = s[WW].split(t, factor=num_thread_c)
_, ty = s[WW].split(ty, factor=num_thread_n)
s[WW].bind(tx, thread_x)
Expand Down

0 comments on commit dbbc678

Please sign in to comment.