diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index a0c992b07347..2b0fccf503af 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -205,7 +205,7 @@ def args_to_workload(x, topi_compute_func=None): workload = tuple([args_to_workload(a) for a in x]) elif isinstance(x, (str, int, float, np.int, np.float)): workload = x - elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)): + elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)): workload = x.value elif x is None: workload = 0 diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py index c184c6b46998..66fa07510237 100644 --- a/python/tvm/autotvm/task/topi_integration.py +++ b/python/tvm/autotvm/task/topi_integration.py @@ -68,6 +68,8 @@ def __init__(self): topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw", topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw", topi.nn.dense: "topi_nn_dense", + topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw", + topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc", topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw", } @@ -79,6 +81,8 @@ def __init__(self): topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw], topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw], topi.nn.dense: [topi.generic.schedule_dense], + topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw], + topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc], topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw], } @@ -174,6 +178,24 @@ def _topi_nn_dense(*args, **kwargs): return s, [data, weight, bias, C] return s, [data, weight, C] + @register("topi_nn_bitserial_conv2d_nhwc") + def _topi_bitserial_conv2d_nhwc(*args, **kwargs): + args = deserialize_args(args) + C = topi.nn.bitserial_conv2d_nhwc(*args, **kwargs) + s = topi.generic.nn.schedule_bitserial_conv2d_nhwc([C]) + data = args[0] + kernel = args[1] + return s, [data, kernel, C] + + @register("topi_nn_bitserial_conv2d_nchw") + def _topi_bitserial_conv2d_nchw(*args, **kwargs): + args = deserialize_args(args) + C = topi.nn.bitserial_conv2d_nchw(*args, **kwargs) + s = topi.generic.nn.schedule_bitserial_conv2d_nchw([C]) + data = args[0] + kernel = args[1] + return s, [data, kernel, C] + @register("topi_nn_deformable_conv2d_nchw") def _topi_nn_deformable_conv2d_nchw(*args, **kwargs): assert not kwargs, "Do not support kwargs in template function call" diff --git a/topi/python/topi/arm_cpu/bitserial_conv2d.py b/topi/python/topi/arm_cpu/bitserial_conv2d.py index ffef3ce81b98..1bfca4e5580c 100644 --- a/topi/python/topi/arm_cpu/bitserial_conv2d.py +++ b/topi/python/topi/arm_cpu/bitserial_conv2d.py @@ -1,100 +1,44 @@ # pylint: disable=invalid-name,unused-variable,invalid-name -"""Bitserial conv2d schedule on raspberry pi""" +"""Bitserial conv2d schedule on arm cpu""" from __future__ import absolute_import as _abs -from collections import namedtuple import tvm +from tvm import autotvm from .. import tag from ..nn.pad import pad -from ..nn.bitserial_conv2d import bitserial_conv2d, _get_schedule, _get_workload, bitpack -from ..nn.bitserial_conv2d import SpatialPackNCHW, _WORKLOADS, spatial_pack_nchw +from ..nn.bitserial_conv2d import bitpack, bitserial_conv2d_nhwc from ..nn.util import get_pad_tuple -from ..util import get_const_int +from ..util import get_const_int, get_const_tuple from .. import generic -RaspSpatialPack = namedtuple('SpatialPack', - ['vh', 'vw', 'vc', 'ba', 'bc', 'split_ci', 'kfactor']) - -_QUANTIZED_SCHEDULES_NHWC = [ - RaspSpatialPack(2, 2, 8, 1, 1, False, 8), - RaspSpatialPack(1, 4, 8, 4, 1, False, 8), - RaspSpatialPack(1, 4, 8, 1, 16, False, 8), - RaspSpatialPack(1, 4, 8, 4, 8, False, 8), - RaspSpatialPack(1, 7, 8, 3, 8, False, 16), - RaspSpatialPack(1, 2, 8, 1, 8, False, 16), - RaspSpatialPack(2, 1, 8, 1, 4, False, 16), - RaspSpatialPack(1, 7, 8, 1, 1, True, 16), - RaspSpatialPack(1, 1, 8, 1, 16, True, 16), - RaspSpatialPack(1, 1, 8, 1, 8, True, 16), - RaspSpatialPack(1, 1, 8, 1, 16, True, 16), -] - -_QUANTIZED_SCHEDULES_NCHW = [ - # resnet - SpatialPackNCHW(2, 2, 8, 1, 1), - SpatialPackNCHW(1, 4, 8, 4, 1), - SpatialPackNCHW(1, 4, 8, 1, 16), - SpatialPackNCHW(1, 4, 8, 4, 8), - SpatialPackNCHW(1, 7, 8, 3, 8), - SpatialPackNCHW(1, 2, 8, 1, 8), - SpatialPackNCHW(2, 1, 8, 1, 4), - SpatialPackNCHW(1, 7, 8, 1, 1), - SpatialPackNCHW(1, 1, 8, 1, 16), - SpatialPackNCHW(1, 1, 8, 1, 8), - SpatialPackNCHW(1, 1, 8, 1, 16), -] - -@_get_schedule.register("arm_cpu") -def _get_schedule_bitserial_conv2d(wkl, layout): - if wkl not in _WORKLOADS: - raise ValueError("no schedule for such workload: {}".format(wkl)) - idx = _WORKLOADS.index(wkl) - if layout == "NCHW": - sch = _QUANTIZED_SCHEDULES_NCHW[idx] - elif layout == "NHWC": - sch = _QUANTIZED_SCHEDULES_NHWC[idx] - return sch - - -@bitserial_conv2d.register("arm_cpu") -def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits, - layout='NCHW', pack_dtype=None, out_dtype=None, dorefa=False): - if out_dtype is None: - out_dtype = data.dtype - assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp" - assert layout in ("NCHW", "NHWC"), "only support layouts NCHW and NHWC" - if dorefa: - assert layout == "NCHW", "Cannot support dorea with NHWC layout yet" - wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout) - sch = _get_schedule(wkl, layout) - if layout == "NCHW": - return spatial_pack_nchw(data, kernel, stride, padding, activation_bits, weight_bits, - pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa) - return _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, - weight_bits, out_dtype) - -def _kernel_vec_spatial_pack_nhwc(kernel, kernel_bits, VC): - kernel_q = bitpack(kernel, kernel_bits, pack_axis=2, bit_axis=2, pack_type='uint8') +def _kernel_vec_spatial_pack_nhwc(kernel, kernel_bits, VC, use_bitpack=True): + if use_bitpack: + kernel_q = bitpack(kernel, kernel_bits, pack_axis=2, bit_axis=2, pack_type='uint8') + else: + kernel_q = kernel KH, KW, KB, CI, CO = kernel_q.shape kvshape = (CO//VC, KH, KW, KB, VC, CI) return tvm.compute(kvshape, lambda co, dh, dw, b, vc, ci: \ kernel_q[dh][dw][b][ci][co*VC+vc], name='kernel_vec') -def _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bits, out_dtype): +@autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'arm_cpu', 'direct') +def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weight_bits, + pack_dtype, out_dtype, unipolar): """ Compute convolution with pack on spatial axes. """ assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1" - wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NHWC") - sch = _get_schedule(wkl, "NHWC") - VH = sch.vh - VW = sch.vw - VC = sch.vc + assert pack_dtype == 'uint8', "only support packing into uint8 bits" + assert out_dtype == 'int16', "only support output type of int16" - data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8') - kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC) - N, H, W, IB, CI = data_q.shape - OCO, KH, KW, KB, VC, _ = kernel_vec.shape + N, H, W, CI = get_const_tuple(data.shape) + if len(kernel.shape) == 4: + KH, KW, _, CO = get_const_tuple(kernel.shape) + CI_packed = CI // 8 + else: + KH, KW, KB, CI_packed, CO = get_const_tuple(kernel.shape) - CO = OCO * VC - HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2): + TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel) + else: + TPAD, LPAD, DPAD, RPAD = padding if isinstance(stride, (tuple, list)): HSTR, WSTR = stride @@ -102,75 +46,151 @@ def _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bi HSTR, WSTR = stride, stride HCAT, WCAT = KH-1, KW-1 - PAD_H = H + 2*HPAD - PAD_W = W + 2*WPAD - OH = (H + 2*HPAD - KH) // HSTR + 1 - OW = (W + 2*WPAD - KW) // WSTR + 1 + PAD_H = H + (TPAD + DPAD) + PAD_W = W + (LPAD + RPAD) + OH = (PAD_H - KH) // HSTR + 1 + OW = (PAD_W - KW) // WSTR + 1 + oshape = (1, OH, OW, CO) + + # Pad input channels of weights and data when it is not a multiple of 8 + if CI_packed % 8 != 0: + CI_PAD = CI_packed % 8 + CI_packed += CI_PAD + else: + CI_PAD = 0 + + # ==================== define configuration space ==================== + n, oh, ow, co = cfg.axis(N), cfg.axis(OH), cfg.axis(OW), cfg.axis(CO) + ci, kh, kw = cfg.reduce_axis(CI_packed), cfg.reduce_axis(KH), cfg.reduce_axis(KW) + ib, kb = cfg.reduce_axis(activation_bits), cfg.reduce_axis(weight_bits) + + co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2, + filter=lambda x: x.size[-1] == 8) + oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2, + filter=lambda x: x.size[-1] >= 2) + ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2, + filter=lambda x: x.size[-1] >= 2) + ci_o, ci_i = cfg.define_split("tile_ci", ci, num_outputs=2, + filter=lambda x: x.size[-1] == 8 or x.size[-1] == 16) + re_axes = cfg.define_reorder("reorder_0", + [n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i], + policy='candidate', candidate=[ + [n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i], + [n, oh, ow, co, vh, vw, kw, kh, ci_o, kb, ib, vc, ci_i],]) + cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW) # these are actually binary ops + # ==================== + + VC = cfg["tile_co"].size[-1] + VH = cfg["tile_oh"].size[-1] + VW = cfg["tile_ow"].size[-1] + + data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8') + + kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC, len(kernel.shape) == 4) + if kernel_vec.shape[-1] % 8 != 0 and CI_PAD != 0: + kernel_vec = pad(kernel_vec, [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, CI_PAD]) + + N, H, W, IB, CI = data_q.shape + OCO, KH, KW, KB, VC, CI = kernel_vec.shape + dvshape = (N, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, IB, CI) ovshape = (1, OH // VH, OW // VW, CO // VC, VH, VW, VC) - oshape = (1, OH, OW, CO) - if (HPAD != 0 and WPAD != 0): - data_pad = pad(data_q, (0, HPAD, WPAD, 0, 0), name="data_pad") + if (TPAD != 0 and RPAD != 0): + data_pad = pad(data_q, (0, TPAD, LPAD, 0, 0), (0, DPAD, RPAD, 0, CI_PAD), name="data_pad") + elif CI_PAD != 0: + data_pad = pad(data_q, (0, 0, 0, 0, 0), (0, 0, 0, 0, CI_PAD), name="data_pad") else: data_pad = data_q data_vec = tvm.compute(dvshape, lambda n, h, w, vh, vw, b, ci: \ data_pad[n][h*VH*HSTR+vh][w*VW*WSTR+vw][b][ci], name='data_vec') - ci = tvm.reduce_axis((0, CI), name='ci') dh = tvm.reduce_axis((0, KH), name='dh') dw = tvm.reduce_axis((0, KW), name='dw') ib = tvm.reduce_axis((0, IB), name='ib') kb = tvm.reduce_axis((0, KB), name='kb') - def _conv(n, h, w, co, vh, vw, vc): + def _bipolar_conv(n, h, w, co, vh, vw, vc): return tvm.sum((tvm.popcount( kernel_vec[co, dh, dw, kb, vc, ci].astype('uint16') & data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci].astype('uint16')) << (kb + ib).astype('uint16')), axis=[dh, dw, kb, ib, ci]) + def _unipolar_conv(n, h, w, co, vh, vw, vc): + return tvm.sum( + ((tvm.popcount(kernel_vec[co, dh, dw, kb, vc, ci].astype('int16') & + data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci].astype('int16')) - + tvm.popcount(~kernel_vec[co, dh, dw, kb, vc, ci].astype('int16') & + data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci]).astype('int16')) + << (kb + ib).astype('int16')), axis=[dh, dw, kb, ib, ci]) + if unipolar: + conv_vec = tvm.compute(ovshape, _unipolar_conv, name='conv_vec', tag='unipolar') + else: + conv_vec = tvm.compute(ovshape, _bipolar_conv, name='conv_vec', tag='bipolar') - conv = tvm.compute(ovshape, _conv, name='conv') + conv = tvm.compute(oshape, lambda n, h, w, co: + conv_vec[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC].astype(out_dtype), + name='conv', tag='spatial_bitserial_conv_nhwc') - return tvm.compute(oshape, lambda n, h, w, co: - conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC].astype(out_dtype), - name='output_vec', tag='spatial_bitserial_conv_nhwc') + return conv -def _intrin_popcount(m, k_i, w_b, x_b): - dtype = 'uint8' - w = tvm.placeholder((w_b, m, k_i), dtype=dtype, name='w') - x = tvm.placeholder((x_b, k_i,), dtype=dtype, name='x') +def _intrin_popcount(m, k_i, w_b, x_b, unipolar): + pack_dtype = 'uint8' + w = tvm.placeholder((w_b, m, k_i), dtype=pack_dtype, name='w') + x = tvm.placeholder((x_b, k_i,), dtype=pack_dtype, name='x') k = tvm.reduce_axis((0, k_i), name='k') bw = tvm.reduce_axis((0, w_b), name='bw') bx = tvm.reduce_axis((0, x_b), name='bx') - z = tvm.compute((m,), lambda i: - tvm.sum(tvm.popcount(w[bw, i, k].astype('uint16') & - x[bx, k].astype('uint16')) - << (bw+bx).astype('uint16'), axis=[bw, bx, k]), name='z') - + if unipolar: + dtype = 'int16' + z = tvm.compute((m,), lambda i: + tvm.sum((tvm.popcount(w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype)) - + tvm.popcount(~w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype))) + << (bw+bx).astype(dtype), axis=[bw, bx, k]), name='z') + else: + dtype = 'uint16' + z = tvm.compute((m,), lambda i: + tvm.sum(tvm.popcount(w[bw, i, k].astype(dtype) & x[bx, k].astype(dtype)) + << (bw+bx).astype(dtype), axis=[bw, bx, k]), name='z') Wb = tvm.decl_buffer(w.shape, w.dtype, name="W", offset_factor=k_i, - strides=[tvm.var('ldw'), tvm.var('ldw'), 1]) + strides=[tvm.var('ldw'), tvm.var('ldw'), 1]) # stride can be inferred Xb = tvm.decl_buffer(x.shape, x.dtype, name="X", offset_factor=k_i, strides=[tvm.var('ldw'), 1]) + Zb = tvm.decl_buffer(z.shape, z.dtype, + name="Z", + offset_factor=1, + strides=[1]) def _intrin_func(ins, outs): ww, xx = ins zz = outs[0] - vpadd = "llvm.arm.neon.vpadd.v8u8" - vpadalu = "llvm.arm.neon.vpadalu.v16u8.v8u16" + args_1 = tvm.const(1, 'uint32') args_2 = tvm.const(2, 'uint32') + if unipolar: + vpadd = "llvm.arm.neon.vpadd.v8i8" + vpadalu = "llvm.arm.neon.vpadals.v16i8.v8i16" + full_dtype = 'int8x16' + half_dtype = 'int8x8' + return_dtype = 'int16x8' + else: + vpadd = "llvm.arm.neon.vpadd.v8u8" + vpadalu = "llvm.arm.neon.vpadalu.v16u8.v8u16" + full_dtype = 'uint8x16' + half_dtype = 'uint8x8' + return_dtype = 'uint16x8' + def _instr(index): irb = tvm.ir_builder.create() - if index == 1: - irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8'))) + if index == 1: # reduce reset + irb.emit(zz.vstore(0, tvm.const(0, return_dtype))) return irb.get() - + # body and reduce update cnts8 = [None] * 8 cnts4 = [None] * 4 cnts2 = [None] * 2 @@ -178,154 +198,108 @@ def _instr(index): for bx in range(x_b): if k_i == 16: for i in range(m): - ands = ww.vload([bw, i, 0], 'uint8x16') & xx.vload([bx, 0], 'uint8x16') - cnts = tvm.popcount(ands) - upper_half = tvm.call_pure_intrin('uint8x8', 'vectorhigh', cnts) - lower_half = tvm.call_pure_intrin('uint8x8', 'vectorlow', cnts) + w_ = ww.vload([bw, i, 0], 'uint8x16').astype(full_dtype) + x_ = xx.vload([bx, 0], 'uint8x16').astype(full_dtype) + if unipolar: + cnts = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_) + else: + cnts = tvm.popcount(w_ & x_) + upper_half = tvm.call_pure_intrin(half_dtype, 'vectorhigh', cnts) + lower_half = tvm.call_pure_intrin(half_dtype, 'vectorlow', cnts) cnts8[i] = upper_half + lower_half for i in range(m//2): - cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd, + cnts4[i] = tvm.call_llvm_intrin(half_dtype, vpadd, args_1, cnts8[i*2], cnts8[i*2+1]) for i in range(m//4): - cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd, + cnts2[i] = tvm.call_llvm_intrin(half_dtype, vpadd, args_1, cnts4[i*2], cnts4[i*2+1]) - cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1]) - shifted_cnts = cnts << tvm.const(bw+bx, dtype) - out = tvm.call_llvm_intrin('uint16x8', vpadalu, - args_2, zz.vload(0, 'uint16x8'), shifted_cnts) + cnts = tvm.call_pure_intrin(full_dtype, 'vectorcombine', cnts2[0], cnts2[1]) + shifted_cnts = cnts << tvm.const(bw+bx, pack_dtype) + out = tvm.call_llvm_intrin(return_dtype, vpadalu, + args_2, zz.vload(0, return_dtype), shifted_cnts) else: # ki == 8 for i in range(m): - ands = ww.vload([bw, i, 0], 'uint8x8') & xx.vload([bx, 0], 'uint8x8') - cnts8[i] = tvm.popcount(ands) + w_ = ww.vload([bw, i, 0], 'uint8x8').astype(half_dtype) + x_ = xx.vload([bx, 0], 'uint8x8').astype(half_dtype) + if unipolar: + cnts8[i] = tvm.popcount(w_ & x_) - tvm.popcount(~w_ & x_) + else: + cnts8[i] = tvm.popcount(w_ & x_) for i in range(m//2): - cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd, + cnts4[i] = tvm.call_llvm_intrin(half_dtype, vpadd, args_1, cnts8[i*2], cnts8[i*2+1]) for i in range(m//4): - cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd, + cnts2[i] = tvm.call_llvm_intrin(half_dtype, vpadd, args_1, cnts4[i*2], cnts4[i*2+1]) - cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1]) - shifted_cnts = cnts << tvm.const(bw+bx, dtype) - out = tvm.call_llvm_intrin('uint16x8', vpadalu, - args_2, zz.vload(0, 'uint16x8'), shifted_cnts) + cnts = tvm.call_pure_intrin(full_dtype, 'vectorcombine', cnts2[0], cnts2[1]) + shifted_cnts = cnts << tvm.const(bw+bx, pack_dtype) + out = tvm.call_llvm_intrin(return_dtype, vpadalu, + args_2, zz.vload(0, return_dtype), shifted_cnts) irb.emit(zz.vstore(0, out)) return irb.get() # body, reset, update return _instr(0), _instr(1), _instr(2) with tvm.build_config(offset_factor=1, partition_const_loop=True): - return tvm.decl_tensor_intrin(z.op, _intrin_func, binds={w: Wb, x:Xb}) + return tvm.decl_tensor_intrin(z.op, _intrin_func, binds={w: Wb, x:Xb, z:Zb}) # ARM specific schedule that using custom microkernel -def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec, - kernel, kernel_q, kernel_vec, - conv_out, output, last): - # no stride and padding info here - _, H, W, IB, CI = data_q.shape - KH, KW, KB, _, CO = kernel_q.shape +def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec, + conv_out, output, last, unipolar): + _, _, _, _, _, IB, CI = data_vec.shape + _, KH, KW, KB, _, _ = kernel_vec.shape KB = get_const_int(KB) IB = get_const_int(IB) - if data_pad is None: - padding = (0, 0) - _, in_h, in_w, _, _ = data_q.shape - kern_h, kern_w, _, _ = kernel.shape - _, out_h, out_w, _ = output.shape - hstride = (in_h - kern_h) // (out_h - 1) - wstride = (in_w - kern_w) // (out_w - 1) - stride = get_const_int(hstride), get_const_int(wstride) - else: - _, in_h, in_w, _, _ = data_q.shape - _, pad_h, pad_w, _, _ = data_pad.shape - hpad = (pad_h - in_h) // 2 - wpad = (pad_w - in_w) // 2 - padding = get_const_int(hpad), get_const_int(wpad) - - _, in_h, in_w, _, _ = data_pad.shape - kern_h, kern_w, _, _ = kernel.shape - _, out_h, out_w, _ = output.shape - hstride = (in_h - kern_h) // (out_h - 1) - wstride = (in_w - kern_w) // (out_w - 1) - stride = get_const_int(hstride), get_const_int(wstride) - - wkl = _get_workload(data, kernel, stride, padding, output.dtype, "NHWC") - sch = _get_schedule(wkl, "NHWC") - - VH = sch.vh - VW = sch.vw - VC = sch.vc - ba = sch.ba - bc = sch.bc - - ##### Schedule data packing + VC = cfg["tile_co"].size[-1] + VH = cfg["tile_oh"].size[-1] + VW = cfg["tile_ow"].size[-1] + + ##### Schedule data padding and packing if data_pad is not None: s[data_pad].compute_inline() _, h, _, _, _, _, _ = s[data_vec].op.axis - if ba == 1: - oaxis = h - paxis = h - else: - oh, ih = s[data_vec].split(h, ba) - oaxis = oh - paxis = ih - - s[data_vec].parallel(paxis) - s[data_vec].pragma(oaxis, "parallel_launch_point") - s[data_vec].pragma(paxis, "parallel_stride_pattern") - s[data_vec].pragma(oaxis, "parallel_barrier_when_finish") + cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32) + oh, ih = cfg["tile_ah"].apply(s, data_vec, h) + s[data_vec].parallel(oh) - ##### Schedule kernel packing + #### Schedule kernel packing co, _, _, _, _, _ = s[kernel_vec].op.axis - if bc == 1: - oaxis = co - paxis = co - else: - oco, ico = s[kernel_vec].split(co, bc) - oaxis = oco - paxis = ico - - s[kernel_vec].parallel(paxis) - s[kernel_vec].pragma(oaxis, "parallel_launch_point") - s[kernel_vec].pragma(paxis, "parallel_stride_pattern") - s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish") + cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32) + oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co) + s[kernel_vec].parallel(oco) ##### Schedule Convolution n, oh, ow, co, vh, vw, vc = s[conv_out].op.axis - dh, dw, kb, ib, ci = s[conv_out].op.reduce_axis + kh, kw, kb, ib, ci = s[conv_out].op.reduce_axis - kfactor = sch.kfactor - if sch.split_ci: - oci, ici = s[conv_out].split(ci, kfactor) - s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, oci, kb, ib, vc, ici) - else: - s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, kb, ib, vc, ci) + ci_o, ci_i = cfg['tile_ci'].apply(s, conv_out, ci) + re_axes = cfg["reorder_0"].apply(s, conv_out, + [n, oh, ow, co, vh, vw, kh, kw, ci_o, kb, ib, vc, ci_i]) - pc = _intrin_popcount(8, kfactor, KB, IB) - s[conv_out].tensorize(kb, pc) + # Use microkernel + kfactor = cfg['tile_ci'].size[1] + if kfactor % 8 == 0: + pc = _intrin_popcount(VC, kfactor, KB, IB, unipolar) + s[conv_out].tensorize(kb, pc) n, h, w, co = s[last].op.axis - co, vc = s[last].split(co, VC) - oh, ow, vh, vw = s[last].tile(h, w, VH, VW) - s[last].reorder(n, oh, ow, co, vc, vh, vw) - s[last].vectorize(vw) + co, vc = cfg['tile_co'].apply(s, last, co) + oh, vh = cfg['tile_oh'].apply(s, last, h) + ow, vw = cfg['tile_ow'].apply(s, last, w) + s[last].reorder(n, oh, ow, co, vh, vw, vc) + s[last].vectorize(vc) if last != output: s[last].compute_inline() - s[conv_out].compute_at(s[last], ow) - if co == 1: - oaxis = oh - paxis = oh - else: - oho, iho = s[last].split(oh, bc) - oaxis = oho - paxis = iho - - s[last].parallel(paxis) + s[conv_out].compute_at(s[last], co) + s[last].parallel(oh) s = s.normalize() return s -@generic.schedule_bitserial_conv2d_nhwc.register(["arm_cpu"]) -def schedule_bitserial_conv2d_nhwc(outs): - """Raspverry pi schedule for bitserial conv2d""" +@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nhwc, 'arm_cpu', 'direct') +def schedule_bitserial_conv2d_nhwc(cfg, outs): + """Arm cpu schedule for bitserial conv2d""" s = tvm.create_schedule([x.op for x in outs]) scheduled_ops = [] @@ -344,10 +318,6 @@ def traverse(op): conv_out = op.input_tensors[0] kernel_vec = conv_out.op.input_tensors[0] kernel_q = kernel_vec.op.input_tensors[0] - kernel = kernel_q.op.input_tensors[0] - if "QuantizeInput" in kernel.op.name: - # Need to go up 1 further, from the combine in bitpack - kernel = kernel.op.input_tensors[0] data_vec = conv_out.op.input_tensors[1] data_q = data_vec.op.input_tensors[0] data = data_q.op.input_tensors[0] @@ -355,13 +325,10 @@ def traverse(op): if isinstance(data_q.op, tvm.tensor.ComputeOp) and "pad" in data_q.op.tag: data_pad = data_q data_q = data - data = data_q.op.input_tensors[0] - if "QuantizeInput" in data.op.name: - # Need to go up 1 further, from the combine in bitpack data = data.op.input_tensors[0] - - _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec, - kernel, kernel_q, kernel_vec, conv_out, output, outs[0]) + unipolar = "unipolar" in conv_out.op.tag + _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec, + conv_out, output, outs[0], unipolar) scheduled_ops.append(op) traverse(outs[0].op) diff --git a/topi/python/topi/nn/bitserial_conv2d.py b/topi/python/topi/nn/bitserial_conv2d.py index d41a99a04a9d..431f51777201 100644 --- a/topi/python/topi/nn/bitserial_conv2d.py +++ b/topi/python/topi/nn/bitserial_conv2d.py @@ -1,70 +1,32 @@ # pylint: disable=invalid-name, unused-variable, too-many-locals, too-many-arguments, unused-argument """Bitserial Conv2D operators""" from __future__ import absolute_import as _abs -from collections import namedtuple import numpy as np import tvm +from tvm import autotvm from topi.transform import concatenate from .pad import pad from .util import get_pad_tuple from ..util import get_const_tuple, get_const_int -# workload description of conv2d -Workload = namedtuple('Workload', - ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter', - 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) - -SpatialPackNCHW = namedtuple('SpatialPack', - ['vh', 'vw', 'vc', 'ba', 'bc']) - -SpatialPackNHWC = namedtuple('SpatialPack', - ['vh', 'vw', 'vc', 'ba', 'bc']) - -_WORKLOADS = [ - # workloads of resnet18 on imagenet - # input_size, input_size, ic, oc, kh, kw, pad, pad, stride, stride - Workload('uint32', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1), - Workload('uint32', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1), - Workload('uint32', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2), - Workload('uint32', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2), - Workload('uint32', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1), - Workload('uint32', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2), - Workload('uint32', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2), - Workload('uint32', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1), - Workload('uint32', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2), - Workload('uint32', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2), - Workload('uint32', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1), - - # workload of alexnet on cifar10 - Workload('int32', 'int32', 27, 27, 96, 192, 5, 5, 2, 2, 1, 1), - Workload('int32', 'int32', 13, 13, 192, 384, 3, 3, 1, 1, 1, 1), - Workload('int32', 'int32', 13, 13, 384, 384, 3, 3, 1, 1, 1, 1), - Workload('int32', 'int32', 13, 13, 384, 256, 3, 3, 1, 1, 1, 1), -] - @tvm.target.generic_func -def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits, - layout='NCHW', pack_dtype='uint32', out_dtype='int32', dorefa=True): +def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight_bits, + pack_dtype='uint32', out_dtype='int16', unipolar=True): """Bitserial Conv2D operator. Parameters ---------- input : tvm.Tensor - 4-D with shape [batch, in_channel, in_height, in_width] or - [batch, in_height, in_width, in_channel] + 4-D with shape [batch, in_channel, in_height, in_width] filter : tvm.Tensor - 4-D with shape [num_filter, in_channel, filter_height, filter_width] or - [filter_height, filter_width, in_channel, num_filter] + 4-D with shape [num_filter, in_channel, filter_height, filter_width] stride : int or a list/tuple of two ints stride size, or [stride_height, stride_width] - padding : int or a list/tuple of two ints - padding size, or [pad_height, pad_width] - - layout : str - layout of data + padding : int or a list/tuple of two or four ints + padding size, [pad_height, pad_width], [pad_top, pad_left, pad_down, pad_right] activation_bits: int number of bits used for activations/input elements @@ -78,63 +40,184 @@ def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits pack_dtype: str bit packing type - dorefa: bool - preform the bitserial dot-product using 2 popcounts (required for DoReFa-Net) + unipolar: bool + if binarization style is in unipolar 1/0 format, instead of bipolar -1/+1 format Returns ------- output : tvm.Tensor - 4-D with shape [batch, out_channel, out_height, out_width] or - [batch, out_height, out_width, out_channel] + 4-D with shape [batch, out_channel, out_height, out_width] """ - # search platform specific declaration first - # default declaration - if layout == 'NCHW': - return spatial_pack_nchw(data, kernel, stride, padding, activation_bits, weight_bits, - pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa) - if layout == 'NHWC': - return spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bits, - pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa) - raise ValueError("not support this layout {} yet".format(layout)) - -def _get_workload(data, kernel, stride, padding, out_dtype, layout): - """ Get the workload structure. """ - assert layout in ("NCHW", "NHWC"), \ - "Only support layouts NCHW and NHWC" - if layout == "NCHW": - _, CI, IH, IW = [x.value for x in data.shape] - CO, _, KH, KW = [x.value for x in kernel.shape] - else: # NHWC - IH, IW = data.shape[1].value, data.shape[2].value - KH, KW, CI, CO = [x for x in get_const_tuple(kernel.shape)] - - HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride + assert isinstance(stride, int) or len(stride) == 2 + Input_q = bitpack(data, activation_bits, pack_axis=1, bit_axis=2, pack_type=pack_dtype) + Filter_q = bitpack(filter, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype) + batch, in_channel, activation_bits, in_height, in_width = Input_q.shape + num_filter, channel, kernel_h, kernel_w, weight_bits = Filter_q.shape + + if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2): + TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel) else: - HSTR, WSTR = stride, stride + TPAD, LPAD, DPAD, RPAD = padding + pad_before = [0, 0, 0, TPAD, LPAD] + pad_after = [0, 0, 0, DPAD, RPAD] + + PadInput_q = pad(Input_q, pad_before, pad_after, name="pad_temp") + # compute the output shape + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + out_channel = num_filter + out_height = (in_height - kernel_h + TPAD + DPAD) // stride_h + 1 + out_width = (in_width - kernel_w + LPAD + RPAD) // stride_w + 1 + + rc = tvm.reduce_axis((0, in_channel), name='rc') + ry = tvm.reduce_axis((0, kernel_h), name='ry') + rx = tvm.reduce_axis((0, kernel_w), name='rx') + b1 = tvm.reduce_axis((0, activation_bits), name='b1') + b2 = tvm.reduce_axis((0, weight_bits), name='b2') + + if unipolar: + def _conv(nn, ff, yy, xx): + b1b2 = (b1+b2).astype(out_dtype) + return tvm.sum( + ((tvm.popcount(PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] & + Filter_q[ff, rc, ry, rx, b2]) - + tvm.popcount(PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] & + ~Filter_q[ff, rc, ry, rx, b2])) + << (b1b2)).astype(out_dtype), + axis=[rc, ry, rx, b2, b1]).astype(out_dtype) + else: + def _conv(nn, ff, yy, xx): + b1b2 = (b1+b2).astype(out_dtype) + return tvm.sum((tvm.popcount( + PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] & + Filter_q[ff, rc, ry, rx, b2])<< (b1b2)).astype(out_dtype), + axis=[rc, ry, rx, b2, b1]).astype(out_dtype) - return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR) + return tvm.compute((batch, out_channel, out_height, out_width), _conv, + name="Conv2dOutput", tag="bitserial_conv2d_nchw") @tvm.target.generic_func -def _get_schedule(wkl, layout): - # pylint: disable=unreachable - """ Get the platform specific schedule. """ - target = tvm.target.current_target() - raise RuntimeError( - "No schedule for current target:{}".format(target)) - # This return has no use, merely to supress pylint warning - return wkl - -def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits, - pack_dtype, out_dtype, dorefa=False): +def bitserial_conv2d_nhwc(data, kernel, stride, padding, activation_bits, weight_bits, + pack_dtype='uint32', out_dtype='int16', unipolar=True): + """Bitserial Conv2D operator. + + Parameters + ---------- + input : tvm.Tensor + 4-D with shape [batch, in_height, in_width, in_channel] + + filter : tvm.Tensor + 4-D with shape [filter_height, filter_width, in_channel, num_filter] + + stride : int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + + padding : int or a list/tuple of two or four ints + padding size, [pad_height, pad_width], [pad_top, pad_left, pad_down, pad_right] + + activation_bits: int + number of bits used for activations/input elements + + weight_bits: int + number of bits used for weight elements + + out_dtype: str + return type of convolution + + pack_dtype: str + bit packing type + + unipolar: bool + if binarization style is in unipolar 1/0 format, instead of bipolar -1/+1 format + + Returns + ------- + output : tvm.Tensor + 4-D with shape [batch, out_height, out_width, out_channel] + """ + assert isinstance(stride, int) or len(stride) == 2 + Input_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype) + if len(kernel.shape) == 4: + Filter_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype) + kernel_h, kernel_w, _, num_filter, _ = get_const_tuple(Filter_q.shape) + else: + Filter_q = kernel + kernel_h, kernel_w, _, _, num_filter = get_const_tuple(Filter_q.shape) + batch, in_height, in_width, in_channel_q, _ = get_const_tuple(Input_q.shape) + + if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2): + TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel) + else: + TPAD, LPAD, DPAD, RPAD = padding + pad_before = [0, TPAD, LPAD, 0, 0] + pad_after = [0, DPAD, RPAD, 0, 0] + + # compute the output shape + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + out_channel = num_filter + out_height = (in_height - kernel_h + TPAD + DPAD) // stride_h + 1 + out_width = (in_width - kernel_w + LPAD + RPAD) // stride_w + 1 + PadInput_q = pad(Input_q, pad_before, pad_after, name="PaddedInput") + + rc = tvm.reduce_axis((0, in_channel_q), name='rc') + ry = tvm.reduce_axis((0, kernel_h), name='ry') + rx = tvm.reduce_axis((0, kernel_w), name='rx') + b1 = tvm.reduce_axis((0, activation_bits), name='b1') + b2 = tvm.reduce_axis((0, weight_bits), name='b2') + + if unipolar: + def _conv(nn, yy, xx, ff): + b1b2 = (b1+b2).astype(out_dtype) + return tvm.sum( + ((tvm.popcount(PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] & + Filter_q[ry, rx, rc, ff, b2]) - + tvm.popcount(PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] & + ~Filter_q[ry, rx, rc, ff, b2])) + << b1b2).astype(out_dtype), + axis=[rc, ry, rx, b2, b1]) + + else: + def _conv(nn, yy, xx, ff): + b1b2 = (b1+b2).astype(out_dtype) + return tvm.sum((tvm.popcount( + PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] & + Filter_q[ry, rx, rc, ff, b2]) << b1b2).astype(out_dtype), + axis=[rc, ry, rx, b2, b1]) + + conv = tvm.compute((batch, out_height, out_width, out_channel), _conv, + name="Conv2dOutput", tag="bitserial_conv2d_nhwc") + + return conv + +@autotvm.register_topi_compute(bitserial_conv2d_nchw, ['cpu', 'arm_cpu'], 'direct') +def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits, + pack_dtype='uint32', out_dtype='int16', unipolar=True): """ Compute convolution with pack on spatial axes. """ assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1" data_q = bitpack(data, in_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype) - kernel_q = bitpack(kernel, weight_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype) - IB, _, CI, H, W = data_q.shape - KB, CO, _, KH, KW = kernel_q.shape - HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + # Check if kernel is already bitpacked + if len(kernel.shape) == 4: + kernel_q = bitpack(kernel, weight_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype) + KB, CO, _, KH, KW = get_const_tuple(kernel_q.shape) + else: + kernel_vec = kernel + OCO, _, KH, KW, KB, VC = get_const_tuple(kernel_vec.shape) + CO = OCO * VC + + IB, N, CI, H, W = get_const_tuple(data_q.shape) + KB, CO, _, KH, KW = get_const_tuple(kernel_q.shape) + + if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2): + TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel) + else: + TPAD, LPAD, DPAD, RPAD = padding + pad_before = [0, 0, 0, TPAD, LPAD] + pad_after = [0, 0, 0, DPAD, RPAD] if isinstance(stride, (tuple, list)): HSTR, WSTR = stride @@ -142,38 +225,50 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits, HSTR, WSTR = stride, stride HCAT, WCAT = KH-1, KW-1 - wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NCHW") - sch = _get_schedule(wkl, "NCHW") - VH = sch.vh - VW = sch.vw - VC = sch.vc - - TH = H + 2*HPAD - TW = W + 2*WPAD - OH = (H + 2*HPAD - KH) // HSTR + 1 - OW = (W + 2*WPAD - KW) // WSTR + 1 + TH = H + TPAD + DPAD + TW = W + LPAD + RPAD + OH = (H + TPAD + DPAD - KH) // HSTR + 1 + OW = (W + LPAD + RPAD - KW) // WSTR + 1 + + # ==================== define configuration space ==================== + n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW) + ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW) + ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits) + + co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2, + filter=lambda x: max(x.size[1:]) <= 16) + oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2, + filter=lambda x: max(x.size[1:]) <= 16) + ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2, + filter=lambda x: max(x.size[1:]) <= 16) + cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll') + + re_axes = cfg.define_reorder("reorder_0", + [n, co, oh, ow, vc, vh, vw, kh, kw, kb, ib, ci], + policy='interval_all', interval=(6, 11)) + cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW) # these are actually binary ops + # ==================== + + VC = cfg["tile_co"].size[-1] + VH = cfg["tile_oh"].size[-1] + VW = cfg["tile_ow"].size[-1] - dshape = (IB, 1, CI, H, W) - dpshape = (IB, 1, CI, TH, TW) dvshape = (1, TH//(VH*HSTR), TW//(VW*WSTR), CI, VH*HSTR+HCAT, VW*WSTR+WCAT, IB) - - kshape = (KB, CO, CI, KH, KW) kvshape = (CO//VC, CI, KH, KW, KB, VC) - ovshape = (1, CO//VC, OH//VH, OW//VW, VH, VW, VC) oshape = (1, CO, OH, OW) - DOPAD = (HPAD != 0 and WPAD != 0) - if DOPAD: - data_pad = pad(data_q, (0, 0, 0, HPAD, WPAD), name="data_pad") + if (TPAD != 0 and RPAD != 0): + data_pad = pad(data_q, (0, 0, 0, TPAD, LPAD), (0, 0, 0, DPAD, RPAD), name="data_pad") else: data_pad = data_q data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw, b: \ data_pad[b][n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], name='data_vec') - kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, b, vc: \ - kernel_q[b][co*VC+vc][ci][dh][dw], name='kernel_vec') + if len(kernel.shape) == 4: + kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, b, vc: \ + kernel_q[b][co*VC+vc][ci][dh][dw], name='kernel_vec') ci = tvm.reduce_axis((0, CI), name='ci') dh = tvm.reduce_axis((0, KH), name='dh') @@ -183,7 +278,7 @@ def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits, def _conv(n, co, h, w, vh, vw, vc): b1b2 = (b1+b2).astype(out_dtype) - if dorefa: + if unipolar: return tvm.sum((tvm.popcount( data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype) & kernel_vec[co, ci, dh, dw, b2, vc].astype(out_dtype)) - @@ -203,15 +298,28 @@ def _conv(n, co, h, w, vh, vw, vc): conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC], name='conv_vec', tag='spatial_bitserial_conv_nchw') -def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits, - pack_dtype, out_dtype, dorefa=False): +@autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct') +def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits, + pack_dtype='uint32', out_dtype='int16', unipolar=True): """ Compute convolution with pack on spatial axes. """ assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1" data_q = bitpack(data, in_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype) - kernel_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype) - _, H, W, CI, IB = data_q.shape - KH, KW, _, CO, KB = kernel_q.shape - HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + pack_kernel = len(kernel.shape) == 4 + + if pack_kernel: + kernel_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype) + else: + kernel_q = kernel + + KH, KW, _, CO, KB = get_const_tuple(kernel_q.shape) + N, H, W, CI, IB = get_const_tuple(data_q.shape) + + if isinstance(padding, int) or (isinstance(padding, (tuple, list)) and len(padding) == 2): + TPAD, LPAD, DPAD, RPAD = get_pad_tuple(padding, kernel) + else: + TPAD, LPAD, DPAD, RPAD = padding + pad_before = [0, TPAD, LPAD, 0, 0] + pad_after = [0, DPAD, RPAD, 0, 0] if isinstance(stride, (tuple, list)): HSTR, WSTR = stride @@ -219,24 +327,41 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits, HSTR, WSTR = stride, stride HCAT, WCAT = KH-1, KW-1 - wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NHWC") - sch = _get_schedule(wkl, "NHWC") - VH = sch.vh - VW = sch.vw - VC = sch.vc + PAD_H = H + (TPAD + DPAD) + PAD_W = W + (LPAD + RPAD) + OH = (PAD_H - KH) // HSTR + 1 + OW = (PAD_W - KW) // WSTR + 1 + oshape = (1, OH, OW, CO) - PAD_H = H + 2*HPAD - PAD_W = W + 2*WPAD - OH = (H + 2*HPAD - KH) // HSTR + 1 - OW = (W + 2*WPAD - KW) // WSTR + 1 + # ==================== define configuration space ==================== + n, oh, ow, co = cfg.axis(N), cfg.axis(OH), cfg.axis(OW), cfg.axis(CO) + ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW) + ib, kb = cfg.reduce_axis(in_bits), cfg.reduce_axis(weight_bits) + + co, vc = cfg.define_split('tile_co', co, policy='all', num_outputs=2, + filter=lambda x: max(x.size[1:]) <= 16) + oh, vh = cfg.define_split('tile_oh', oh, policy='all', num_outputs=2, + filter=lambda x: max(x.size[1:]) <= 16) + ow, vw = cfg.define_split('tile_ow', ow, policy='all', num_outputs=2, + filter=lambda x: max(x.size[1:]) <= 16) + cfg.define_annotate('ann_reduce', [ib, kb, kh, kw], policy='try_unroll') + re_axes = cfg.define_reorder("reorder_0", + [n, oh, ow, co, vh, vw, kh, kw, kb, ib, vc, ci], + policy='interval_all', interval=(3, 7)) + cfg.add_flop(2 * N * OH * OW * CO * CI * 8 * KH * KW) # these are actually binary ops + # ==================== + + VC = cfg["tile_co"].size[-1] + VH = cfg["tile_oh"].size[-1] + VW = cfg["tile_ow"].size[-1] dvshape = (1, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, CI, IB) kvshape = (CO, KH, KW, CI, VC, KB) ovshape = (1, OH, OW, CO, VH, VW, VC) oshape = (1, OH, OW, CO) - if (HPAD != 0 and WPAD != 0): - data_pad = pad(data_q, (0, HPAD, WPAD, 0, 0), name="data_pad") + if (DPAD != 0 and RPAD != 0): + data_pad = pad(data_q, (0, TPAD, LPAD, 0, 0), (0, DPAD, RPAD, 0, 0), name="data_pad") else: data_pad = data_q @@ -254,12 +379,12 @@ def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits, def _conv(n, h, w, co, vh, vw, vc): b1b2 = (b1+b2).astype(out_dtype) - if dorefa: + if unipolar: return tvm.sum( - (tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1].astype(out_dtype) & - kernel_vec[co, dh, dw, ci, vc, b2].astype(out_dtype)) - - tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1].astype(out_dtype) & - ~kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype)) << b1b2, + ((tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] & + kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) - + tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1]& + ~kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype)) << b1b2), axis=[dh, dw, ci, b1, b2]) return tvm.sum(tvm.popcount( @@ -273,6 +398,7 @@ def _conv(n, h, w, co, vh, vw, vc): conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC], name='output_unpack', tag='spatial_bitserial_conv_nhwc') + def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"): """Packs data into format necessary for bitserial computation pack_axis : int @@ -334,8 +460,3 @@ def _bitpack(*indices): if bits > 1: return concatenate(output_tuple, axis=bit_axis) return output_tuple - -_SCH_TO_DECL_FUNC_QUANT = { - SpatialPackNCHW: spatial_pack_nchw, - SpatialPackNHWC: spatial_pack_nhwc, -} diff --git a/topi/python/topi/x86/bitserial_conv2d.py b/topi/python/topi/x86/bitserial_conv2d.py index 327f15a49e07..67f773548464 100644 --- a/topi/python/topi/x86/bitserial_conv2d.py +++ b/topi/python/topi/x86/bitserial_conv2d.py @@ -1,74 +1,13 @@ # pylint: disable=invalid-name,unused-variable,invalid-name """Bitserial conv2d schedule on x86""" import tvm +from tvm import autotvm from topi.util import get_const_int from .. import generic, tag -from ..nn.bitserial_conv2d import bitserial_conv2d, _get_schedule, _get_workload -from ..nn.bitserial_conv2d import SpatialPackNCHW, SpatialPackNHWC -from ..nn.bitserial_conv2d import _WORKLOADS, _SCH_TO_DECL_FUNC_QUANT - -_QUANTIZED_SCHEDULES_NCHW = [ - # resnet - SpatialPackNCHW(2, 2, 8, 1, 1), - SpatialPackNCHW(1, 4, 8, 4, 1), - SpatialPackNCHW(1, 4, 8, 1, 16), - SpatialPackNCHW(1, 4, 8, 4, 8), - SpatialPackNCHW(1, 7, 8, 3, 8), - SpatialPackNCHW(1, 2, 8, 1, 8), - SpatialPackNCHW(2, 1, 8, 1, 4), - SpatialPackNCHW(1, 7, 8, 1, 1), - SpatialPackNCHW(1, 1, 8, 1, 16), - SpatialPackNCHW(1, 1, 8, 1, 8), - SpatialPackNCHW(1, 1, 8, 1, 16), - - SpatialPackNCHW(3, 3, 16, 3, 16), - SpatialPackNCHW(1, 1, 16, 2, 16), - SpatialPackNCHW(1, 1, 8, 1, 16), - SpatialPackNCHW(1, 1, 8, 1, 16), -] - -_QUANTIZED_SCHEDULES_NHWC = [ - # resnet - SpatialPackNHWC(2, 2, 8, 1, 1), - SpatialPackNHWC(1, 4, 8, 4, 1), - SpatialPackNHWC(1, 4, 8, 1, 16), - SpatialPackNHWC(1, 4, 8, 4, 8), - SpatialPackNHWC(1, 7, 8, 3, 8), - SpatialPackNHWC(1, 2, 8, 1, 8), - SpatialPackNHWC(2, 1, 8, 1, 4), - SpatialPackNHWC(1, 7, 8, 1, 1), - SpatialPackNHWC(1, 1, 8, 1, 16), - SpatialPackNHWC(1, 1, 8, 1, 8), - SpatialPackNHWC(1, 1, 8, 1, 16), -] - -@_get_schedule.register("cpu") -def _get_schedule_bitserial_conv2d(wkl, layout): - if wkl not in _WORKLOADS: - raise ValueError("no schedule for such workload: {}".format(wkl)) - idx = _WORKLOADS.index(wkl) - if layout == "NCHW": - sch = _QUANTIZED_SCHEDULES_NCHW[idx] - elif layout == "NHWC": - sch = _QUANTIZED_SCHEDULES_NHWC[idx] - return sch - -@bitserial_conv2d.register("cpu") -def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits, - layout='NCHW', pack_dtype=None, out_dtype=None, dorefa=False): - if out_dtype is None: - out_dtype = data.dtype - assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp" - assert layout in ("NCHW", "NHWC"), "only support layouts NCHW and NHWC" - - wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout) - sch = _get_schedule(wkl, layout) - return _SCH_TO_DECL_FUNC_QUANT[type(sch)](data, kernel, stride, padding, activation_bits, - weight_bits, pack_dtype, out_dtype, dorefa) - -@generic.schedule_bitserial_conv2d_nchw.register(["cpu"]) -@generic.schedule_bitserial_conv2d_nhwc.register(["cpu"]) -def schedule_bitserial_conv2d(outs): + +@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nchw, ['cpu'], 'direct') +@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nhwc, ['cpu'], 'direct') +def schedule_bitserial_conv2d(cfg, outs): """CPU schedule for bitserial convolutions NCHW and NHWC""" s = tvm.create_schedule([x.op for x in outs]) scheduled_ops = [] @@ -88,7 +27,6 @@ def traverse(op): conv_out = op.input_tensors[0] kernel_vec = conv_out.op.input_tensors[1] kernel_q = kernel_vec.op.input_tensors[0] - kernel = kernel_q.op.input_tensors[0] data_vec = conv_out.op.input_tensors[0] data_q = data_vec.op.input_tensors[0] data = data_q.op.input_tensors[0] @@ -97,29 +35,27 @@ def traverse(op): data_pad = data_q data_q = data data = data_q.op.input_tensors[0] - if "QuantizeInput" in kernel.op.name: - # Need to go up 1 further, from the combine in bitpack - kernel = kernel.op.input_tensors[0] + if "QuantizeInput" in data.op.name: # Need to go up 1 further, from the combine in bitpack data = data.op.input_tensors[0] if 'spatial_bitserial_conv_nchw' in op.tag: - _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec, - kernel, kernel_q, kernel_vec, - conv_out, output, outs[0]) + _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec, + kernel_q, kernel_vec, + conv_out, output, outs[0]) elif 'spatial_bitserial_conv_nhwc' in op.tag: - _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec, - kernel, kernel_q, kernel_vec, - conv_out, output, outs[0]) + _schedule_bitserial_conv2d_nhwc(cfg, s, data_q, data_pad, data_vec, + kernel_q, kernel_vec, + conv_out, output, outs[0]) scheduled_ops.append(op) traverse(outs[0].op) return s -def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec, - kernel, kernel_q, kernel_vec, - conv_out, output, last): +def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec, + kernel_q, kernel_vec, + conv_out, output, last): IB, _, CI, IH, IW = data_q.shape KB, CO, _, KH, KW = kernel_q.shape _, _, OH, OW = output.shape @@ -138,37 +74,21 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec, wstride = get_const_int((TW - KW) // (OW - 1)) stride = (hstride, wstride) - wkl = _get_workload(data, kernel, stride, padding, output.dtype, "NCHW") - sch = _get_schedule(wkl, "NCHW") - VH = sch.vh - VW = sch.vw - VC = sch.vc - ba = sch.ba - bc = sch.bc - - CC = s.cache_write(conv_out, "global") - n, co, oh, ow, vh, vw, vc = s[conv_out].op.axis - s[conv_out].vectorize(vc) - - s[CC].compute_at(s[conv_out], ow) - n, co, oh, ow, vh, vw, vc = s[CC].op.axis - ci, dh, dw, b1, b2 = s[CC].op.reduce_axis - s[CC].reorder(ci, dh, vh, dw, vw, b1, b2, vc) - s[CC].unroll(b1) - s[CC].unroll(b2) - s[CC].vectorize(vc) + VC = cfg["tile_co"].size[-1] + VH = cfg["tile_oh"].size[-1] + VW = cfg["tile_ow"].size[-1] - ##### Schedule A + ##### Schedule Data padding, and bitpacking if data_pad is not None: s[data_pad].compute_inline() - _, h, _, _, _, _, vw = s[data_vec].op.axis - s[data_vec].vectorize(vw) - if ba == 1: - oaxis = h - paxis = h + _, _, h, _, _, _, _ = s[data_vec].op.axis + cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32) + oh, ih = cfg["tile_ah"].apply(s, data_vec, h) + if cfg["tile_ah"].size[1] == 1: + oaxis = oh + paxis = oh else: - oh, ih = s[data_vec].split(h, ba) oaxis = oh paxis = ih @@ -178,14 +98,14 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec, s[data_vec].pragma(oaxis, "parallel_barrier_when_finish") - ##### Schedule B - co, _, _, _, _, vc = s[kernel_vec].op.axis - s[kernel_vec].vectorize(vc) - if bc == 1: - oaxis = co - paxis = co + ##### Schedule Kenerl bitpacking + co, _, _, _, _, _ = s[kernel_vec].op.axis + cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32) + oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co) + if cfg["tile_bco"].size[1] == 1: + oaxis = oco + paxis = oco else: - oco, ico = s[kernel_vec].split(co, bc) oaxis = oco paxis = ico @@ -195,7 +115,23 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec, s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish") - ##### Schedule C + ##### Schedule Convolution + n, co, oh, ow, vh, vw, vc = s[conv_out].op.axis + ci, dh, dw, ib, kb = s[conv_out].op.reduce_axis + + # s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2) + cfg["reorder_0"].apply(s, conv_out, [n, co, oh, ow, vc, vh, vw, dh, dw, kb, ib, ci]) + cfg["ann_reduce"].apply(s, conv_out, [kb, ib, dh, dw], + axis_lens=[get_const_int(kb.dom.extent), + get_const_int(ib.dom.extent), + get_const_int(dh.dom.extent), + get_const_int(dw.dom.extent)], + max_unroll=16, + cfg=cfg) + + s[conv_out].vectorize(vc) + + # # Schedule output n, co, h, w = s[last].op.axis co, vc = s[last].split(co, VC) oh, ow, vh, vw = s[last].tile(h, w, VH, VW) @@ -204,89 +140,58 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec, s[output].compute_inline() s[conv_out].compute_at(s[last], ow) - if bc == 1: - oaxis = co - paxis = co + oco, ico = cfg["tile_oh"].apply(s, last, co) + if cfg["tile_oh"].size[1] == 1: + oaxis = oco + paxis = oco else: oco, ico = s[last].split(co, bc) oaxis = oco paxis = ico - s[last].parallel(paxis) - s[last].pragma(oaxis, "parallel_launch_point") - s[last].pragma(paxis, "parallel_stride_pattern") - s[last].pragma(oaxis, "parallel_barrier_when_finish") - + s[last].parallel(oco) return s -def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec, - kernel, kernel_q, kernel_vec, - conv_out, output, last): +def _schedule_bitserial_conv2d_nhwc(cfg, s, data_q, data_pad, data_vec, + kernel_q, kernel_vec, + conv_out, output, last): # no stride and padding info here _, IH, IW, CI, IB = data_q.shape KH, KW, _, CO, KB = kernel_q.shape _, OH, OW, _ = output.shape - # Infer padding and stride - if data_pad is None: - padding = (0, 0) - TH, TW = IH, IW - else: - _, TH, TW, _, _ = data_pad.shape - hpad = get_const_int((TH - IH) // 2) - wpad = get_const_int((TW - IW) // 2) - padding = (hpad, wpad) - hstride = get_const_int((TH - KH) // (OH - 1)) - wstride = get_const_int((TW - KW) // (OW - 1)) - stride = (hstride, wstride) + VC = cfg["tile_co"].size[-1] + VH = cfg["tile_oh"].size[-1] + VW = cfg["tile_ow"].size[-1] - wkl = _get_workload(data, kernel, stride, padding, last.dtype, "NHWC") - sch = _get_schedule(wkl, "NHWC") - VH = sch.vh - VW = sch.vw - VC = sch.vc - ba = sch.ba - bc = sch.bc - - ##### Schedule data packing + ##### Schedule data padding and packing if data_pad is not None: s[data_pad].compute_inline() _, h, _, _, _, _, _ = s[data_vec].op.axis - if ba == 1: - oaxis = h - paxis = h - else: - oh, ih = s[data_vec].split(h, ba) - oaxis = oh - paxis = ih - s[data_vec].parallel(paxis) - s[data_vec].pragma(oaxis, "parallel_launch_point") - s[data_vec].pragma(paxis, "parallel_stride_pattern") - s[data_vec].pragma(oaxis, "parallel_barrier_when_finish") - + cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32) + oh, ih = cfg["tile_ah"].apply(s, data_vec, h) + s[data_vec].parallel(oh) ##### Schedule kernel packing co, _, _, _, _, _ = s[kernel_vec].op.axis - if bc == 1: - oaxis = co - paxis = co - else: - oco, ico = s[kernel_vec].split(co, bc) - oaxis = oco - paxis = ico - - s[kernel_vec].parallel(paxis) - s[kernel_vec].pragma(oaxis, "parallel_launch_point") - s[kernel_vec].pragma(paxis, "parallel_stride_pattern") - s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish") - + cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32) + oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co) + s[kernel_vec].parallel(oco) ##### Schedule Convolution n, oh, ow, co, vh, vw, vc = s[conv_out].op.axis dh, dw, ci, b1, b2 = s[conv_out].op.reduce_axis - s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2) + # s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2) + cfg["reorder_0"].apply(s, conv_out, [n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2]) + cfg["ann_reduce"].apply(s, conv_out, [b1, b2, dh, dw], + axis_lens=[get_const_int(b1.dom.extent), + get_const_int(b2.dom.extent), + get_const_int(dh.dom.extent), + get_const_int(dw.dom.extent)], + max_unroll=16, + cfg=cfg) s[conv_out].unroll(b1) s[conv_out].unroll(b2) @@ -302,17 +207,7 @@ def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec, s[output].compute_inline() s[conv_out].compute_at(s[last], ow) - if bc == 1: - oaxis = oh - paxis = oh - else: - oho, iho = s[last].split(oh, bc) - oaxis = oho - paxis = iho - - s[last].parallel(paxis) - s[last].pragma(oaxis, "parallel_launch_point") - s[last].pragma(paxis, "parallel_stride_pattern") - s[last].pragma(oaxis, "parallel_barrier_when_finish") + oho, iho = cfg["tile_oh"].apply(s, last, oh) # reuse parameter + s[last].parallel(oho) return s diff --git a/topi/tests/python/test_topi_bitserial_conv2d.py b/topi/tests/python/test_topi_bitserial_conv2d.py index 6979cf1ce437..15db5e233df7 100644 --- a/topi/tests/python/test_topi_bitserial_conv2d.py +++ b/topi/tests/python/test_topi_bitserial_conv2d.py @@ -11,16 +11,16 @@ def generate_quantized_np(shape, bits, out_dtype): return np.random.randint(min_val, max_val, size=shape).astype(out_dtype) def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, stride, padding, - activation_bits, weight_bits, dorefa): + activation_bits, weight_bits, unipolar): in_height = in_width = in_size - input_type = 'uint32' + input_dtype = 'uint32' out_dtype = 'int32' with tvm.target.create('llvm'): - A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_type, name='A') - W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_type, name='W') - B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, - out_dtype=out_dtype, layout="NCHW", dorefa=dorefa) + A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_dtype, name='A') + W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_dtype, name='W') + B = topi.nn.bitserial_conv2d_nchw(A, W, stride, padding, activation_bits, weight_bits, + out_dtype=out_dtype, unipolar=unipolar) s = topi.generic.schedule_bitserial_conv2d_nchw([B]) a_shape = get_const_tuple(A.shape) @@ -28,9 +28,9 @@ def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, @memoize("topi.tests.test_topi_bitseral_conv2d_nchw") def get_ref_data(): - a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type) - w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type) - if dorefa: + a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_dtype) + w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_dtype) + if unipolar: w_ = np.copy(w_np).astype(out_dtype) for x in np.nditer(w_, op_flags=['readwrite']): x[...] = 1 if x == 1 else -1 @@ -49,16 +49,16 @@ def get_ref_data(): tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding, - activation_bits, weight_bits, dorefa): + activation_bits, weight_bits, unipolar): in_height = in_width = in_size - input_type='uint32' + input_dtype='uint32' out_dtype='int32' with tvm.target.create('llvm'): - A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A') - W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W') - B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype, - layout="NHWC", dorefa=dorefa) + A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_dtype, name='A') + W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_dtype, name='W') + B = topi.nn.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits, + out_dtype=out_dtype, unipolar=unipolar) s = topi.generic.schedule_bitserial_conv2d_nhwc([B]) a_shape = get_const_tuple(A.shape) @@ -66,9 +66,9 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, @memoize("topi.tests.test_topi_bitseral_conv2d_nhwc") def get_ref_data(): - a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type) - w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type) - if dorefa: + a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_dtype) + w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_dtype) + if unipolar: w_ = np.copy(w_np).astype(out_dtype) for x in np.nditer(w_, op_flags=['readwrite']): x[...] = 1 if x == 1 else -1 diff --git a/topi/tests/python/test_topi_bitserial_conv2d_rasp.py b/topi/tests/python/test_topi_bitserial_conv2d_rasp.py index de467818d37f..be3e7399099c 100644 --- a/topi/tests/python/test_topi_bitserial_conv2d_rasp.py +++ b/topi/tests/python/test_topi_bitserial_conv2d_rasp.py @@ -4,6 +4,7 @@ import tvm import topi import topi.testing +from topi.util import get_const_tuple def generate_quantized_np(shape, bits, out_dtype): np.random.seed(0) @@ -13,19 +14,20 @@ def generate_quantized_np(shape, bits, out_dtype): # Verify that certain special instructions from the tensorize pass exist def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding, - activation_bits, weight_bits, dorefa): + activation_bits, weight_bits, unipolar): in_height = in_width = in_size input_type = 'uint32' - out_dtype = 'int32' + out_dtype = 'int16' - with tvm.target.arm_cpu('rasp3b'): + device = 'llvm -device=arm_cpu -model=bcm2837 -target=armv7l-linux-gnueabihf -mattr=+neon' + with tvm.target.create(device): A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A') W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W') - B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype, - layout="NHWC", dorefa=dorefa) + B = topi.nn.bitserial_conv2d_nhwc(A, W, stride, padding, activation_bits, weight_bits, + pack_dtype='uint8', out_dtype='int16', unipolar=unipolar) s = topi.generic.schedule_bitserial_conv2d_nhwc([B]) - func = tvm.build(s, [A, W, B], tvm.target.arm_cpu('rasp3b')) + func = tvm.build(s, [A, W, B], device) assembly = func.get_source('asm') matches = re.findall("vpadal", assembly) @@ -35,6 +37,33 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, matches = re.findall("vpadd", assembly) assert (len(matches) > 0) + ctx = tvm.context(device, 0) + if 'arm' not in os.uname()[4]: + print ("Skipped running code, not an arm device") + return + + print("Running on target: %s" % device) + + def get_ref_data(): + a_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type) + w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type) + if unipolar: + w_ = np.copy(w_np).astype(out_dtype) + for x in np.nditer(w_, op_flags=['readwrite']): + x[...] = 1 if x == 1 else -1 + b_np = topi.testing.conv2d_nhwc_python(a_np, w_, stride, padding).astype(out_dtype) + else: + b_np = topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding).astype(out_dtype) + return a_np, w_np, b_np + a_np, w_np, b_np = get_ref_data() + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + func = tvm.build(s, [A, W, B], device) + + func(a, w, b) + np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + def test_bitserial_conv2d(): in_size = 56 ic, oc = 64, 64 @@ -45,6 +74,9 @@ def test_bitserial_conv2d(): verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, False) verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False) + verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, True) + verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, True) + if __name__ == "__main__": test_bitserial_conv2d()