From 46657ed1863364fc4ab010957490b526bdced74f Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Fri, 6 Oct 2017 08:59:26 -0700 Subject: [PATCH] Conv2d modified for better performance (#516) * conv2d tweaked for better end-to-end performance * syntax changed --- topi/python/topi/cuda/conv2d_nchw.py | 43 +++++++++++++++++----------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py index 4987f8d6fef2..8e0f22781c1d 100644 --- a/topi/python/topi/cuda/conv2d_nchw.py +++ b/topi/python/topi/cuda/conv2d_nchw.py @@ -66,9 +66,20 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L): def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag): """Schedule conv2d for specific feature_in_out_filter pattern""" if util.get_const_int(Filter_S.shape[0]) == util.get_const_int(Filter_S.shape[1]): - num_thread_x = 8 + mark = util.get_const_int(Out.shape[2]) * util.get_const_int(Out.shape[3]) + num_thread_x = 0 + if mark % 8 == 0 and mark % 7 == 0: + num_thread_x = 8 + vthread_x = 7 + else: + for i in range(5, mark): + if mark % i == 0 and num_thread_x == 0: + vthread_x = i + mark = mark // i + if mark % i == 0 and vthread_x > 0: + num_thread_x = i + break num_thread_y = 8 - vthread_x = 7 vthread_y = 2 ifactor = 8 @@ -80,20 +91,20 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag): thread_yz = tvm.thread_axis((0, vthread_y), "vthread", name="vy") i, oc, h, w = s[Out].op.axis - oh, ih = s[Out].split(h, nparts=vthread_x) - w = s[Out].fuse(ih, w) + w = s[Out].fuse(h, w) + ow, iw = s[Out].split(w, factor=num_thread_x*vthread_x) ooc, ioc = s[Out].split(oc, factor=num_thread_y*vthread_y) - ow, iw = s[Out].split(w, factor=num_thread_x) + oiw, iiw = s[Out].split(iw, nparts=vthread_x) oioc, iioc = s[Out].split(ioc, nparts=vthread_y) - s[Out].reorder(i, ooc, oh, oioc, ow, iioc, iw) - s[Out].bind(iw, thread_x) + s[Out].reorder(i, ooc, ow, oioc, oiw, iioc, iiw) + s[Out].bind(iiw, thread_x) s[Out].bind(iioc, thread_y) - s[Out].bind(ow, thread_xz) + s[Out].bind(oiw, thread_xz) s[Out].bind(oioc, thread_yz) - s[Out].bind(oh, block_x) + s[Out].bind(ow, block_x) s[Out].bind(ooc, block_y) - s[Out_L].compute_at(s[Out], iw) + s[Out_L].compute_at(s[Out], iiw) # schedule Out_L local write i, oc, h, w = s[Out_L].op.axis @@ -260,9 +271,9 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L): else: # scheduler params - vthread_x = min(8, util.get_const_int(Out.shape[2])) + vthread_x = util.get_const_int(Out.shape[2]) num_thread_x = 16 - num_thread_y = min(8, util.get_const_int(Out.shape[3])) + num_thread_y = util.get_const_int(Out.shape[3]) ofactor = 8 block_x = tvm.thread_axis("blockIdx.x") thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x") @@ -271,12 +282,10 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L): i, oc, h, w = s[Out].op.axis ooc, ioc = s[Out].split(oc, factor=num_thread_x) - oh, ih = s[Out].split(h, factor=vthread_x) - ow, iw = s[Out].split(w, factor=num_thread_y) - s[Out].reorder(i, ooc, oh, ih, ow, iw, ioc) + s[Out].reorder(i, ooc, h, w, ioc) s[Out].bind(ioc, thread_x) - s[Out].bind(iw, thread_y) - s[Out].bind(ih, thread_xz) + s[Out].bind(w, thread_y) + s[Out].bind(h, thread_xz) s[Out].bind(ooc, block_x) s[Out_L].compute_at(s[Out], ioc)