Skip to content

Commit

Permalink
Conv2d modified for better performance (#516)
Browse files Browse the repository at this point in the history
* conv2d tweaked for better end-to-end performance

* syntax changed
  • Loading branch information
Laurawly authored and tqchen committed Oct 6, 2017
1 parent 13970eb commit 46657ed
Showing 1 changed file with 26 additions and 17 deletions.
43 changes: 26 additions & 17 deletions topi/python/topi/cuda/conv2d_nchw.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,20 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
if util.get_const_int(Filter_S.shape[0]) == util.get_const_int(Filter_S.shape[1]):
num_thread_x = 8
mark = util.get_const_int(Out.shape[2]) * util.get_const_int(Out.shape[3])
num_thread_x = 0
if mark % 8 == 0 and mark % 7 == 0:
num_thread_x = 8
vthread_x = 7
else:
for i in range(5, mark):
if mark % i == 0 and num_thread_x == 0:
vthread_x = i
mark = mark // i
if mark % i == 0 and vthread_x > 0:
num_thread_x = i
break
num_thread_y = 8
vthread_x = 7
vthread_y = 2
ifactor = 8

Expand All @@ -80,20 +91,20 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
thread_yz = tvm.thread_axis((0, vthread_y), "vthread", name="vy")

i, oc, h, w = s[Out].op.axis
oh, ih = s[Out].split(h, nparts=vthread_x)
w = s[Out].fuse(ih, w)
w = s[Out].fuse(h, w)
ow, iw = s[Out].split(w, factor=num_thread_x*vthread_x)
ooc, ioc = s[Out].split(oc, factor=num_thread_y*vthread_y)
ow, iw = s[Out].split(w, factor=num_thread_x)
oiw, iiw = s[Out].split(iw, nparts=vthread_x)
oioc, iioc = s[Out].split(ioc, nparts=vthread_y)
s[Out].reorder(i, ooc, oh, oioc, ow, iioc, iw)
s[Out].bind(iw, thread_x)
s[Out].reorder(i, ooc, ow, oioc, oiw, iioc, iiw)
s[Out].bind(iiw, thread_x)
s[Out].bind(iioc, thread_y)
s[Out].bind(ow, thread_xz)
s[Out].bind(oiw, thread_xz)
s[Out].bind(oioc, thread_yz)
s[Out].bind(oh, block_x)
s[Out].bind(ow, block_x)
s[Out].bind(ooc, block_y)

s[Out_L].compute_at(s[Out], iw)
s[Out_L].compute_at(s[Out], iiw)

# schedule Out_L local write
i, oc, h, w = s[Out_L].op.axis
Expand Down Expand Up @@ -260,9 +271,9 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):

else:
# scheduler params
vthread_x = min(8, util.get_const_int(Out.shape[2]))
vthread_x = util.get_const_int(Out.shape[2])
num_thread_x = 16
num_thread_y = min(8, util.get_const_int(Out.shape[3]))
num_thread_y = util.get_const_int(Out.shape[3])
ofactor = 8
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
Expand All @@ -271,12 +282,10 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):

i, oc, h, w = s[Out].op.axis
ooc, ioc = s[Out].split(oc, factor=num_thread_x)
oh, ih = s[Out].split(h, factor=vthread_x)
ow, iw = s[Out].split(w, factor=num_thread_y)
s[Out].reorder(i, ooc, oh, ih, ow, iw, ioc)
s[Out].reorder(i, ooc, h, w, ioc)
s[Out].bind(ioc, thread_x)
s[Out].bind(iw, thread_y)
s[Out].bind(ih, thread_xz)
s[Out].bind(w, thread_y)
s[Out].bind(h, thread_xz)
s[Out].bind(ooc, block_x)

s[Out_L].compute_at(s[Out], ioc)
Expand Down

0 comments on commit 46657ed

Please sign in to comment.