From daaf172ca8d3380f2ad241bfb19345c9a9dad808 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 5 Aug 2018 18:59:51 -0700 Subject: [PATCH] Add conv2d transpose (#9) * pass mobilenet * rename * failed to tensorize * can pass one * pass stride = 1 for conv2d_transpose * add conv2d transpose * add end2end support for gan * fix --- python/tvm/build_module.py | 1 + python/tvm/contrib/util.py | 2 +- vta/python/vta/build_module.py | 3 +- vta/python/vta/ir_pass.py | 129 +++++++++++ vta/python/vta/top/__init__.py | 2 + vta/python/vta/top/arm_conv2d.py | 1 - vta/python/vta/top/vta_conv2d.py | 39 +++- vta/python/vta/top/vta_conv2d_transpose.py | 217 ++++++++++++++++++ vta/python/vta/top/vta_dense.py | 155 +++++++++++++ vta/python/vta/top/vta_group_conv2d.py | 2 - .../test_benchmark_topi_conv2d_transpose.py | 151 ++++++++++++ .../integration/test_benchmark_topi_dense.py | 129 +++++++++++ ...py => test_benchmark_topi_group_conv2d.py} | 5 +- 13 files changed, 826 insertions(+), 10 deletions(-) create mode 100644 vta/python/vta/top/vta_conv2d_transpose.py create mode 100644 vta/python/vta/top/vta_dense.py create mode 100644 vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py create mode 100644 vta/tests/python/integration/test_benchmark_topi_dense.py rename vta/tests/python/integration/{test_benchmark_topi_group_conv.py => test_benchmark_topi_group_conv2d.py} (97%) diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 4068b1ce3a94e..1f27caab6f612 100755 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -369,6 +369,7 @@ def lower(sch, stmt = ir_pass.LowerStorageAccessInfo(stmt) stmt = ir_pass.RemoveNoOp(stmt) stmt = ir_pass.RewriteUnsafeSelect(stmt) + stmt = ir_pass.CanonicalSimplify(stmt) for f in lower_phase3: stmt = f(stmt) # Instrument BoundCheckers diff --git a/python/tvm/contrib/util.py b/python/tvm/contrib/util.py index d3a727f9389ff..8a84846146aa8 100644 --- a/python/tvm/contrib/util.py +++ b/python/tvm/contrib/util.py @@ -174,4 +174,4 @@ def find_all(op): for out in outputs: find_all(out) - return lower(s, inputs, simple_mode=True) + return lower(s, inputs + [x.output(0) for x in outputs], simple_mode=True) diff --git a/vta/python/vta/build_module.py b/vta/python/vta/build_module.py index 299a914b15f5d..a1d2299ba7aa6 100644 --- a/vta/python/vta/build_module.py +++ b/vta/python/vta/build_module.py @@ -53,7 +53,8 @@ def add_debug(stmt): debug_flag) return tvm.make.stmt_seq(debug, stmt) - pass_list = [(1, ptr_alias.lower_ptr_alias), + pass_list = [(0, ir_pass.inject_conv2d_transpose_skip), + (1, ptr_alias.lower_ptr_alias), (1, ir_pass.inject_dma_intrin), (1, ir_pass.inject_skip_copy), (1, ir_pass.annotate_alu_coproc_scope), diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py index 3efef7135edb2..c35fbd38dbc3f 100644 --- a/vta/python/vta/ir_pass.py +++ b/vta/python/vta/ir_pass.py @@ -277,6 +277,135 @@ def _do_fold(stmt): stmt_in, _do_fold, None, ["AttrStmt"]) +def show_dir(x): + print(type(x), x) + for key in dir(x): + print(key, getattr(x, key)) + + +def _get_gemm_intrin_buffer(): + env = get_env() + wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH + assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN + wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN) + assert wgt_shape[0] * wgt_shape[1] == wgt_lanes + inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH + assert inp_lanes == env.BATCH * env.BLOCK_IN + inp_shape = (env.BATCH, env.BLOCK_IN) + assert inp_shape[0] * inp_shape[1] == inp_lanes + out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH + assert out_lanes == env.BATCH * env.BLOCK_OUT + out_shape = (env.BATCH, env.BLOCK_OUT) + assert out_shape[0] * out_shape[1] == out_lanes + wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]), + dtype="int%d" % env.WGT_WIDTH, + name=env.wgt_scope) + inp = tvm.placeholder((inp_shape[0], inp_shape[1]), + dtype="int%d" % env.INP_WIDTH, + name=env.inp_scope) + k = tvm.reduce_axis((0, wgt_shape[1]), name="k") + out_dtype = "int%d" % env.ACC_WIDTH + out = tvm.compute((out_shape[0], out_shape[1]), + lambda i, j: tvm.sum(inp[i, k].astype(out_dtype) * + wgt[j, k].astype(out_dtype), + axis=[k]), + name="out") + wgt_layout = tvm.decl_buffer( + wgt.shape, wgt.dtype, env.wgt_scope, + scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes) + inp_layout = tvm.decl_buffer( + inp.shape, inp.dtype, env.inp_scope, + scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes) + out_layout = tvm.decl_buffer( + out.shape, out.dtype, env.acc_scope, + scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes) + + return wgt_layout, inp_layout, out_layout + + +def inject_conv2d_transpose_skip(stmt_in): + env = get_env() + dwgt, dinp, dout = _get_gemm_intrin_buffer() + + calls = [] + selects = [] + + def _find_basics(op): + if isinstance(op, tvm.expr.Call): + calls.append(op) + elif isinstance(op, tvm.expr.Select): + selects.append(op) + + def _do_fold(op): + if _match_pragma(op, "conv2d_transpose_gemm"): + is_init = ".init" in str(op) + tvm.ir_pass.PostOrderVisit(op, _find_basics) + + if is_init: + # create inner most block + irb = tvm.ir_builder.create() + dev = env.dev + irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)) + irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop) + irb.emit(tvm.call_extern("int32", "VTAUopPush", + 0, 1, + dout.access_ptr("rw", "int32"), + 0, 0, + 0, 0, 0)) + inner = irb.get() + args = op.body.body.args + res_tensor = op.body.body.func.output(0) + tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, 16) + inner = tvm.make.AttrStmt( + [dout, res_tensor], 'buffer_bind_scope', + tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner) + return inner + else: + conv_call, data_call, kernel_call = calls[-3:] + pad_data_tensor, kernel_tensor, res_tensor = (data_call.func.output(0), + kernel_call.func.output(0), conv_call.func.output(0)) + + if selects: + condition = selects[0].condition + else: + condition = tvm.const(1, 'int') + + # create inner most block + irb = tvm.ir_builder.create() + with irb.if_scope(condition): + dev = env.dev + irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)) + irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop) + irb.emit(tvm.call_extern("int32", "VTAUopPush", + 0, 0, + dout.access_ptr("rw", "int32"), + dinp.access_ptr("r", "int32"), + dwgt.access_ptr("r", "int32"), + 0, 0, 0)) + inner = irb.get() + + args = conv_call.args + tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, 16) + inner = tvm.make.AttrStmt( + [dout, res_tensor], 'buffer_bind_scope', + tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner) + args = kernel_call.args + tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 16, 0, 16) + inner = tvm.make.AttrStmt( + [dwgt, kernel_tensor], 'buffer_bind_scope', + tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner) + args = data_call.args + tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, 16) + inner = tvm.make.AttrStmt( + [dinp, pad_data_tensor], 'buffer_bind_scope', + tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner) + return inner + return None + ret = tvm.ir_pass.IRTransform( + stmt_in, _do_fold, None, ["AttrStmt"]) + return ret + + def inject_coproc_sync(stmt_in): """Pass to inject skip copy stmt, used in debug. diff --git a/vta/python/vta/top/__init__.py b/vta/python/vta/top/__init__.py index 6c07c64f27d74..fbe4d04dca7e1 100644 --- a/vta/python/vta/top/__init__.py +++ b/vta/python/vta/top/__init__.py @@ -4,5 +4,7 @@ from . import arm_conv2d from .bitpack import bitpack +from .vta_dense import packed_dense, schedule_packed_dense from .vta_conv2d import packed_conv2d, schedule_packed_conv2d from .vta_group_conv2d import packed_group_conv2d, schedule_packed_group_conv2d +from .vta_conv2d_transpose import packed_conv2d_transpose, schedule_packed_conv2d_transpose diff --git a/vta/python/vta/top/arm_conv2d.py b/vta/python/vta/top/arm_conv2d.py index 012c16b098ed4..634348a87cfe8 100644 --- a/vta/python/vta/top/arm_conv2d.py +++ b/vta/python/vta/top/arm_conv2d.py @@ -5,7 +5,6 @@ from topi.nn import conv2d, conv2d_alter_layout from topi import generic - @conv2d.register(["vtacpu", "vta"]) def compute(*args, **kwargs): with tvm.target.arm_cpu("vtacpu"): diff --git a/vta/python/vta/top/vta_conv2d.py b/vta/python/vta/top/vta_conv2d.py index 345f379817c34..c3d2131ff16f3 100644 --- a/vta/python/vta/top/vta_conv2d.py +++ b/vta/python/vta/top/vta_conv2d.py @@ -13,7 +13,7 @@ from ..environment import get_env from ..ptr_alias import reinterpret from .vta_group_conv2d import packed_group_conv2d, schedule_packed_group_conv2d - +from .vta_conv2d_transpose import packed_conv2d_transpose, schedule_packed_conv2d_transpose Workload = namedtuple("Conv2DWorkload", ['batch', 'height', 'width', 'in_filter', 'out_filter', @@ -156,6 +156,7 @@ def _get_data_movement_byte(schedule, layer): return [fil_sched[xfer_size.index(min(xfer_size))]] return fil_sched + def packed_conv2d(data, kernel, padding, @@ -309,6 +310,42 @@ def schedule_conv2d(attrs, outs, target): return _nn.schedule_conv2d(attrs, outs, target) +@reg.register_compute("conv2d_transpose", level=15) +def compute_conv2d_transpose(attrs, inputs, out): + """ 2D convolution algorithm. + """ + padding = attrs.get_int_tuple("padding") + strides = attrs.get_int_tuple("strides") + dilation = attrs.get_int_tuple("dilation") + layout = attrs["layout"] + out_dtype = attrs['out_dtype'] + + print(inputs) + + assert dilation == (1, 1), "not support dilate now" + if is_packed_layout(layout): + return packed_conv2d_transpose(inputs[0], inputs[1], + padding, strides, + out_dtype=out_dtype) + return _nn.compute_conv2d_transpose(attrs, inputs, out) + + +@reg.register_schedule("conv2d_transpose", level=15) +def schedule_conv2d_transpose(attrs, outs, target): + """ 2D convolution schedule. + """ + layout = attrs["layout"] + + if is_packed_layout(layout): + target = tvm.target.create(target) + if target.device_name == "vta": + return schedule_packed_conv2d_transpose(outs) + elif str(target).startswith("llvm"): + return tvm.create_schedule([x.op for x in outs]) + else: + raise RuntimeError("not support target %s" % target) + return _nn.schedule_conv2d_transpose(attrs, outs, target) + def _get_workload(data, pad_data, kernel, output): """ Get the workload structure. """ diff --git a/vta/python/vta/top/vta_conv2d_transpose.py b/vta/python/vta/top/vta_conv2d_transpose.py new file mode 100644 index 0000000000000..d53d69d02f2c3 --- /dev/null +++ b/vta/python/vta/top/vta_conv2d_transpose.py @@ -0,0 +1,217 @@ +import logging +from collections import namedtuple + +import tvm +import topi +from topi.nn.util import get_pad_tuple +from topi.util import get_const_int, get_const_tuple +from tvm.contrib.util import get_lower_ir + +from ..environment import get_env + + +Workload = namedtuple("Conv2DTransposeWorkload", + ('batch', 'height', 'width', 'in_filter', 'out_filter', + 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride')) + +Schedule = namedtuple("Conv2DTransposeSchedule", + ('b_factor', 'oc_factor', 'ic_factor', 'h_factor', 'w_factor', + 'oc_nthread', 'h_nthread', 'debug_sync')) + + +def find_schedules(layer, vt_only=False, best_only=False): + return [Schedule(1, 1, 1, 2, 4, 1, 1, False)] + + +def packed_conv2d_transpose(data, + kernel, + padding, + strides, + out_dtype="int32"): + batch, in_c, in_h, in_w, B_BATCH, B_CI = get_const_tuple(data.shape) + out_c, _, filter_h, filter_w, B_CO, B_CI = get_const_tuple(kernel.shape) + stride_h, stride_w = strides + + # padding stage + fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w)) + bpad_top = filter_h - 1 - fpad_top + bpad_bottom = filter_h - 1 - fpad_bottom + bpad_left = filter_w - 1 - fpad_left + bpad_right = filter_w - 1 - fpad_right + + # padding stage + FirstPad = topi.nn.pad(data, + [0, 0, (bpad_top + stride_h - 1) // stride_h, + (bpad_left + stride_w - 1) // stride_w, 0, 0], + [0, 0, (bpad_bottom + stride_h - 1) // stride_h, + (bpad_right + stride_w - 1) // stride_w, 0, 0], + name='pad_data') + border_h = (stride_h - bpad_top % stride_h) % stride_h # remove extra padding introduced by dilatation + border_w = (stride_w - bpad_left % stride_w) % stride_w + + # dilation stage + data = FirstPad + strides = [1, 1, stride_h, stride_w, 1, 1] + n = len(data.shape) + + def _dilate(*indices): + not_zero = [] + index_tuple = [] + for i in range(n): + if not topi.util.equal_const_int(strides[i], 1): + index_tuple.append(indices[i] // strides[i]) + not_zero.append((indices[i] % strides[i]).equal(0)) + else: + index_tuple.append(indices[i]) + if not_zero: + not_zero = tvm.all(*not_zero) + return tvm.select(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype)) + return data(*index_tuple) + + # convolution stage + out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + dc = tvm.reduce_axis((0, in_c), name='dc') + dh = tvm.reduce_axis((0, filter_h), name='dh') + dw = tvm.reduce_axis((0, filter_w), name='dw') + dci = tvm.reduce_axis((0, B_CI), name='dci') + + Output = tvm.compute( + (batch, out_c, out_h, out_w, B_BATCH, B_CO), + lambda b, c, h, w, b_n, b_co: tvm.sum( + _dilate(b, dc, h + dh + border_h, w + dw + border_w, b_n, dci).astype(out_dtype) * + kernel[c, dc, dh, dw, b_co, dci].astype(out_dtype), + axis=[dc, dh, dw, dci]), + tag="packed_conv2d_transpose", + name='res', + attrs={"workload": (n, in_h, in_w, in_c, out_c, filter_h, filter_w, + padding[0], padding[1], stride_h, stride_w)}) + + return Output + +global_plan = None + +def set_global_plan(plan): + global global_plan + global_plan = plan + +def schedule_packed_conv2d_transpose(outs): + assert len(outs) == 1 + output = outs[0] + ewise_inputs = [] + ewise_ops = [] + conv2d_res = [] + assert output.dtype == "int8" + assert output.op.input_tensors[0].dtype == "int32" + # + #return tvm.create_schedule(output.op) + + def _traverse(op): + if topi.tag.is_broadcast(op.tag): + if not op.same_as(output.op): + ewise_ops.append(op) + for tensor in op.input_tensors: + if isinstance(tensor.op, tvm.tensor.PlaceholderOp): + ewise_inputs.append((op, tensor)) + else: + _traverse(tensor.op) + else: + assert op.tag == "packed_conv2d_transpose" + conv2d_res.append(op) + + _traverse(output.op) + assert len(conv2d_res) == 1 + conv2d_stage = conv2d_res[0].output(0) + + data, kernel = conv2d_stage.op.input_tensors + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + temp = data.op.input_tensors[0] + pad_data = data + data = temp + else: + pad_data = None + + wrkld = Workload(*conv2d_stage.op.attrs['workload']) + plan = find_schedules(wrkld, vt_only=True, best_only=True)[0] + logging.info("Trying to find plan for %s", wrkld) + env = get_env() + + load_inp = load_wgt = load_out = store_out = env.dma_copy + alu = env.alu + gemm = env.gemm + + # schedule1 + s = tvm.create_schedule(output.op) + + # setup pad + if pad_data is not None: + cdata = pad_data + s[pad_data].set_scope(env.inp_scope) + else: + cdata = s.cache_read(data, env.inp_scope, [conv2d_stage]) + ckernel = s.cache_read(kernel, env.wgt_scope, [conv2d_stage]) + s[conv2d_stage].set_scope(env.acc_scope) + # cache read input + cache_read_ewise = [] + + for consumer, tensor in ewise_inputs: + cache_read_ewise.append( + s.cache_read(tensor, env.acc_scope, [consumer])) + # set ewise scope + for op in ewise_ops: + s[op].set_scope(env.acc_scope) + s[op].pragma(s[op].op.axis[0], alu) + + # tile + oc_factor = (plan.oc_factor if plan.oc_factor else 1) + h_factor = (plan.h_factor if plan.h_factor else 1) + w_factor = (plan.w_factor if plan.w_factor else 1) + + x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis + x_co0, x_co1 = s[output].split(x_co, factor=oc_factor) + x_i0, x_i1 = s[output].split(x_i, factor=h_factor) + x_j0, x_j1 = s[output].split(x_j, factor=w_factor) + s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci) + store_pt = x_j0 + + # set all compute scopes + s[conv2d_stage].compute_at(s[output], store_pt) + for op in ewise_ops: + s[op].compute_at(s[output], store_pt) + + for tensor in cache_read_ewise: + s[tensor].compute_at(s[output], store_pt) + s[tensor].pragma(s[tensor].op.axis[0], load_out) + + # virtual threading along output channel axes + if plan.oc_nthread > 1: + _, v_t = s[output].split(x_co0, factor=plan.oc_nthread) + s[output].reorder(v_t, x_bo) + s[output].bind(v_t, tvm.thread_axis("cthread")) + + # virtual threading along spatial rows + if plan.h_nthread > 1: + _, v_t = s[output].split(x_i0, factor=plan.h_nthread) + s[output].reorder(v_t, x_bo) + s[output].bind(v_t, tvm.thread_axis("cthread")) + + x_bo, x_co, x_i, x_j, x_bi, x_ci = s[conv2d_stage].op.axis + k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis + s[conv2d_stage].reorder(x_bo, k_o, d_j, d_i, x_co, x_i, x_j, x_bi, x_ci, k_i) + + for axis in [d_j, d_i, x_i, x_j]: + s[conv2d_stage].unroll(axis) + + ic_factor = plan.ic_factor or 1 + if ic_factor: + k_o, _ = s[conv2d_stage].split(k_o, factor=ic_factor) + s[cdata].compute_at(s[conv2d_stage], k_o) + s[ckernel].compute_at(s[conv2d_stage], k_o) + + # Use VTA instructions + s[cdata].pragma(s[cdata].op.axis[0], load_inp) + s[ckernel].pragma(s[ckernel].op.axis[0], load_wgt) + s[conv2d_stage].pragma(x_bi, "conv2d_transpose_gemm") + s[output].pragma(x_co1, store_out) + + return s diff --git a/vta/python/vta/top/vta_dense.py b/vta/python/vta/top/vta_dense.py new file mode 100644 index 0000000000000..a190e4e979ac2 --- /dev/null +++ b/vta/python/vta/top/vta_dense.py @@ -0,0 +1,155 @@ +import logging +from collections import namedtuple + +import tvm +import topi +from topi.util import get_const_int, get_const_tuple + +from ..environment import get_env + +Workload = namedtuple("DenseWorkload", + ('batch', 'in_dim', 'out_dim')) + +Schedule = namedtuple("GroupConv2DSchedule", ('factor', )) + + +def find_schedules(layer, vt_only=False, best_only=False): + return [Schedule(0, 0, 1, 0, 0, 0, 0, False)] + + +def packed_dense(data, + weight, + out_dtype="int32"): + """ Packed conv2d function.""" + env = get_env() + + N, IN, B_BATCH, B_CI = get_const_tuple(data.shape) + OUT, IN, B_OUT, B_IN = get_const_tuple(weight.shape) + + oshape = (N, OUT, B_BATCH, B_OUT) + + ko = tvm.reduce_axis((0, IN), name='ko') + ki = tvm.reduce_axis((0, env.BLOCK_IN), name='ki') + + out = tvm.compute( + oshape, + lambda n, o, b_n, b_out: tvm.sum(data[n, ko, b_n, ki].astype(out_dtype) * + weight[o, ko, b_out, ki].astype(out_dtype), + axis=[ko, ki]), + name="res", tag="packed_dense", + attrs={'workload': (N, IN * B_CI, OUT * B_OUT)}) + return out + + +def schedule_packed_dense(outs): + """ Schedule the packed conv2d. + """ + assert len(outs) == 1 + output = outs[0] + return tvm.create_schedule(output.op) + + def _traverse(op): + if topi.tag.is_broadcast(op.tag): + if not op.same_as(output.op): + ewise_ops.append(op) + for tensor in op.input_tensors: + if isinstance(tensor.op, tvm.tensor.PlaceholderOp): + ewise_inputs.append((op, tensor)) + else: + _traverse(tensor.op) + else: + assert op.tag == "packed_group_conv2d" + conv2d_res.append(op) + + _traverse(output.op) + assert len(conv2d_res) == 1 + conv2d_stage = conv2d_res[0].output(0) + + data, kernel = conv2d_stage.op.input_tensors + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + temp = data.op.input_tensors[0] + pad_data = data + data = temp + else: + pad_data = None + wrkld = _get_workload(data, pad_data, kernel, output) + plan = find_schedules(wrkld, vt_only=True, best_only=True)[0] + logging.info("Trying to find plan for %s", wrkld) + env = get_env() + + load_inp = load_wgt = load_out = store_out = env.dma_copy + alu = env.alu + gemm = env.gemm + + # schedule1 + oshape = topi.util.get_const_tuple(output.shape) + s = tvm.create_schedule(output.op) + + # setup pad + if pad_data is not None: + cdata = pad_data + s[pad_data].set_scope(env.inp_scope) + else: + cdata = s.cache_read(data, env.inp_scope, [conv2d_stage]) + ckernel = s.cache_read(kernel, env.wgt_scope, [conv2d_stage]) + s[conv2d_stage].set_scope(env.acc_scope) + # cache read input + cache_read_ewise = [] + + for consumer, tensor in ewise_inputs: + cache_read_ewise.append( + s.cache_read(tensor, env.acc_scope, [consumer])) + # set ewise scope + for op in ewise_ops: + s[op].set_scope(env.acc_scope) + s[op].pragma(s[op].op.axis[0], alu) + + # tile + oc_factor = (plan.oc_factor if plan.oc_factor else 1) + h_factor = (plan.h_factor if plan.h_factor else 1) + w_factor = (plan.w_factor if plan.w_factor else 1) + + x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis + x_co0, x_co1 = s[output].split(x_co, factor=oc_factor) + x_i0, x_i1 = s[output].split(x_i, factor=h_factor) + x_j0, x_j1 = s[output].split(x_j, factor=w_factor) + s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci) + store_pt = x_j0 + + # set all compute scopes + s[conv2d_stage].compute_at(s[output], store_pt) + for op in ewise_ops: + s[op].compute_at(s[output], store_pt) + + for tensor in cache_read_ewise: + s[tensor].compute_at(s[output], store_pt) + s[tensor].pragma(s[tensor].op.axis[0], load_out) + + # virtual threading along output channel axes + if plan.oc_nthread > 1: + _, v_t = s[output].split(x_co0, factor=plan.oc_nthread) + s[output].reorder(v_t, x_bo) + s[output].bind(v_t, tvm.thread_axis("cthread")) + + # virtual threading along spatial rows + if plan.h_nthread > 1: + _, v_t = s[output].split(x_i0, factor=plan.h_nthread) + s[output].reorder(v_t, x_bo) + s[output].bind(v_t, tvm.thread_axis("cthread")) + + x_bo, x_co, x_i, x_j, x_bi, x_ci = s[conv2d_stage].op.axis + k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis + s[conv2d_stage].reorder(x_bo, k_o, x_j, d_j, d_i, x_co, x_i, x_bi, x_ci, k_i) + + if plan.ic_factor: + k_o, _ = s[conv2d_stage].split(k_o, factor=plan.ic_factor) + s[cdata].compute_at(s[conv2d_stage], k_o) + s[ckernel].compute_at(s[conv2d_stage], k_o) + + # Use VTA instructions + s[cdata].pragma(s[cdata].op.axis[0], load_inp) + s[ckernel].pragma(s[ckernel].op.axis[0], load_wgt) + s[conv2d_stage].tensorize(x_bi, gemm) + s[output].pragma(x_co1, store_out) + + return s diff --git a/vta/python/vta/top/vta_group_conv2d.py b/vta/python/vta/top/vta_group_conv2d.py index e6891233a18d4..c883b154f1c8f 100644 --- a/vta/python/vta/top/vta_group_conv2d.py +++ b/vta/python/vta/top/vta_group_conv2d.py @@ -3,8 +3,6 @@ import tvm import topi - - from topi.util import get_const_int, get_const_tuple from tvm.contrib.util import get_lower_ir diff --git a/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py b/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py new file mode 100644 index 0000000000000..e338dc8fc7209 --- /dev/null +++ b/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py @@ -0,0 +1,151 @@ +"""Testing if we can generate code in topi style""" + +import tvm +from tvm import autotvm +from tvm.contrib import util +from tvm.contrib.pickle_memoize import memoize +import topi +import topi.testing +import vta +import vta.testing +import numpy as np + +Workload = vta.top.vta_conv2d_transpose.Workload +Schedule = vta.top.vta_conv2d_transpose.Schedule + +@tvm.tag_scope(tag=topi.tag.ELEMWISE) +def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + + +def test_vta_conv2d_transpose(): + def run_vta_conv2d_transpose(env, remote, name, wl, profile=True): + assert wl.batch % env.BATCH == 0 + assert wl.in_filter % env.BLOCK_IN == 0 + assert wl.out_filter % env.BLOCK_OUT == 0 + + data_shape = (wl.batch//env.BATCH, wl.in_filter//env.BLOCK_IN, + wl.height, wl.width, env.BATCH, env.BLOCK_IN) + kernel_shape = (wl.out_filter//env.BLOCK_OUT, wl.in_filter // env.BLOCK_IN, + wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN) + bias_shape = (wl.batch//env.BATCH, wl.out_filter//env.BLOCK_OUT, + 1, 1, env.BATCH, env.BLOCK_OUT) + + fout_height = (wl.height - 1) * wl.hstride - 2 * wl.hpad + wl.hkernel + fout_width = (wl.width - 1) * wl.wstride - 2 * wl.wpad + wl.wkernel + + data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) + kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype) + bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype) + + res_conv = vta.top.packed_conv2d_transpose( + data, kernel, (wl.hpad, wl.wpad), (wl.hstride, wl.wstride)) + res = topi.right_shift(res_conv, 8) + res = topi.add(res, bias) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + # To compute number of ops, use a x2 factor for FMA + num_ops = 2 * wl.batch * fout_height * fout_width * wl.hkernel * wl.wkernel * \ + wl.out_filter * wl.in_filter / (wl.hstride * wl.wstride) + + a_shape = (wl.batch, wl.in_filter, wl.height, wl.width) + w_shape = (wl.in_filter, wl.out_filter, wl.hkernel, wl.wkernel) + data_dtype = data.dtype + kernel_dtype = kernel.dtype + acc_dtype = env.acc_dtype + stride = (wl.hstride, wl.wstride) + padding = (wl.hpad, wl.wpad) + + @memoize("vta.tests.test_conv2d_transpose") + def get_ref_data(): + a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype) + w_np = (np.random.uniform(size=w_shape) * 4).astype(kernel_dtype) + a_np = np.abs(a_np) + w_np = np.abs(w_np) + b_np = topi.testing.conv2d_transpose_nchw_python( + a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding).astype(acc_dtype) + return a_np, w_np, b_np + + def verify(s, check_correctness): + mod = vta.build(s, [data, kernel, bias, res], "ext_dev", + env.target_host, name="conv2d_transpose") + temp = util.tempdir() + + mod.save(temp.relpath("conv2d_transpose.o")) + remote.upload(temp.relpath("conv2d_transpose.o")) + f = remote.load_module("conv2d_transpose.o") + # verify + ctx = remote.ext_dev(0) + # Data in original format + data_orig, kernel_orig, res_ref = get_ref_data() + bias_orig = (np.random.uniform(size=(wl.out_filter,)) * 4).astype("int32") + bias_orig = np.abs(bias_orig) + + data_packed = data_orig.reshape( + wl.batch//env.BATCH, env.BATCH, + wl.in_filter//env.BLOCK_IN, env.BLOCK_IN, + wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3)) + kernel_packed = kernel_orig.reshape( + wl.in_filter//env.BLOCK_IN, env.BLOCK_IN, + wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT, + wl.hkernel, wl.wkernel).transpose((2, 0, 4, 5, 3, 1)) + kernel_flipped = np.flip(kernel_packed, [2, 3]) + + bias_packed = bias_orig.reshape( + 1, wl.out_filter // env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT) + res_shape = topi.util.get_const_tuple(res.shape) + + res_np = np.zeros(res_shape).astype(res.dtype) + data_arr = tvm.nd.array(data_packed, ctx) + kernel_arr = tvm.nd.array(kernel_flipped, ctx) + bias_arr = tvm.nd.array(bias_packed, ctx) + res_arr = tvm.nd.array(res_np, ctx) + time_f = f.time_evaluator("conv2d_transpose", ctx, number=5) + cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) + res_unpack = res_arr.asnumpy().transpose( + (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, fout_height, fout_width) + if check_correctness: + assert wl.hpad == wl.wpad + stride = (wl.hstride, wl.wstride) + padding = (wl.hpad, wl.wpad) + res_ref = res_ref >> 8 + res_ref += bias_orig.reshape(wl.out_filter, 1, 1) + res_ref = np.clip(res_ref, 0, 127).astype("int8") + np.testing.assert_allclose(res_unpack, res_ref) + return cost + + def conv2d_transpose_normal(print_ir): + print("----- Conv2d Transpose End-to-End Test-------") + with vta.build_config(): + s = vta.top.schedule_packed_conv2d_transpose([res]) + if print_ir: + print(vta.lower(s, [data, kernel, bias, res], simple_mode=True)) + cost = verify(s, True) + gops = (num_ops / cost.mean) / float(10 ** 9) + print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops)) + + conv2d_transpose_normal(False) + + def _run(env, remote): + tasks = [ + # mobilenet + ('DCGAN.CT1', Workload(1, 4, 4, 1024, 512, 4, 4, 1, 1, 2, 2)), + ('DCGAN.CT2', Workload(1, 8, 8, 512, 256, 4, 4, 1, 1, 2, 2)), + ('DCGAN.CT3', Workload(1, 16, 16, 256, 128, 4, 4, 1, 1, 2, 2)), + ('DCGAN.CT4', Workload(1, 32, 32, 128, env.BLOCK_IN, 4, 4, 1, 1, 2, 2)), + ] + + for tsk in tasks: + print(tsk) + name, wkl = tsk + run_vta_conv2d_transpose(env, remote, name, wkl) + vta.testing.run(_run) + +if __name__ == "__main__": + test_vta_conv2d_transpose() diff --git a/vta/tests/python/integration/test_benchmark_topi_dense.py b/vta/tests/python/integration/test_benchmark_topi_dense.py new file mode 100644 index 0000000000000..5f4d8e8e47658 --- /dev/null +++ b/vta/tests/python/integration/test_benchmark_topi_dense.py @@ -0,0 +1,129 @@ +"""Testing if we can generate code in topi style""" + +import tvm +from tvm import autotvm +from tvm.contrib import util +from tvm.contrib.pickle_memoize import memoize +import topi +import topi.testing +import vta +import vta.testing +import numpy as np + +Workload = vta.top.vta_dense.Workload + + +@tvm.tag_scope(tag=topi.tag.ELEMWISE) +def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + + +def test_vta_dense(): + def run_vta_dense(env, remote, name, wl, profile=True): + data_shape = (wl.batch//env.BATCH, wl.in_dim//env.BLOCK_IN, + env.BATCH, env.BLOCK_IN) + weight_shape = (wl.out_dim//env.BLOCK_OUT, wl.in_dim//env.BLOCK_IN, + env.BLOCK_OUT, env.BLOCK_IN) + bias_shape = (wl.batch//env.BATCH, wl.out_dim//env.BLOCK_OUT, + env.BATCH, env.BLOCK_OUT) + + data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) + weight = tvm.placeholder(weight_shape, name="kernel", dtype=env.wgt_dtype) + bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype) + data_dtype = data.dtype + weight_dtype = weight.dtype + + res = vta.top.packed_dense(data, weight) + res = topi.right_shift(res, 8) + res = topi.add(res, bias) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + # To compute number of ops, use a x2 factor for FMA + num_ops = 2 * wl.batch * wl.in_dim * wl.out_dim + a_shape = (wl.batch, wl.in_dim) + w_shape = (wl.out_dim, wl.in_dim) + acc_dtype = env.acc_dtype + + @memoize("vta.tests.test_dense") + def get_ref_data(): + a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype) + w_np = (np.random.uniform(size=w_shape) * 4).astype(weight_dtype) + a_np = np.abs(a_np) + w_np = np.abs(w_np) + b_np = np.dot(a_np.astype(acc_dtype), w_np.astype(acc_dtype).T).astype(acc_dtype) + return a_np, w_np, b_np + + def verify(s, check_correctness): + mod = vta.build(s, [data, weight, bias, res], "ext_dev", + env.target_host, name="dense") + temp = util.tempdir() + + mod.save(temp.relpath("dense.o")) + remote.upload(temp.relpath("dense.o")) + f = remote.load_module("dense.o") + # verify + ctx = remote.ext_dev(0) + # Data in original format + data_orig, id_card_opriginal, res_ref = get_ref_data() + bias_orig = (np.random.uniform(size=(wl.out_dim,)) * 4).astype("int32") + bias_orig = np.ones_like(bias_orig) + + data_packed = data_orig.reshape( + wl.batch//env.BATCH, env.BATCH, + wl.in_dim//env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3)) + weight_packed = id_card_opriginal.reshape( + wl.out_dim//env.BLOCK_OUT, env.BLOCK_OUT, + wl.in_dim//env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3)) + bias_packed = bias_orig.reshape( + 1, wl.out_dim // env.BLOCK_OUT, 1, env.BLOCK_OUT) + res_shape = topi.util.get_const_tuple(res.shape) + + res_np = np.zeros(res_shape).astype(res.dtype) + data_arr = tvm.nd.array(data_packed, ctx) + weight_arr = tvm.nd.array(weight_packed, ctx) + bias_arr = tvm.nd.array(bias_packed, ctx) + res_arr = tvm.nd.array(res_np, ctx) + + time_f = f.time_evaluator("dense", ctx, number=5) + cost = time_f(data_arr, weight_arr, bias_arr, res_arr) + res_unpack = res_arr.asnumpy().transpose( + (0, 2, 1, 3)).reshape(wl.batch, wl.out_dim) + if check_correctness: + res_ref = res_ref >> 8 + res_ref += bias_orig.reshape(1, wl.out_dim) + res_ref = np.clip(res_ref, 0, 127).astype("int8") + np.testing.assert_allclose(res_unpack, res_ref) + return cost + + def dense_normal(print_ir): + print("----- dense End-to-End Test-------") + with vta.build_config(): + s = vta.top.schedule_packed_dense([res]) + if print_ir: + print(vta.lower(s, [data, weight, bias, res], simple_mode=True)) + cost = verify(s, True) + gops = (num_ops / cost.mean) / float(10 ** 9) + print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops)) + + dense_normal(False) + + def _run(env, remote): + tasks = [ + ('dense.DEN1', Workload(1, 1024, 1024)), + ('dense.DEN2', Workload(1, 512, 512)), + ] + + for tsk in tasks: + name, wkl = tsk + run_vta_dense(env, remote, name, wkl) + + vta.testing.run(_run) + +if __name__ == "__main__": + test_vta_dense() diff --git a/vta/tests/python/integration/test_benchmark_topi_group_conv.py b/vta/tests/python/integration/test_benchmark_topi_group_conv2d.py similarity index 97% rename from vta/tests/python/integration/test_benchmark_topi_group_conv.py rename to vta/tests/python/integration/test_benchmark_topi_group_conv2d.py index 0b16c41350c07..59c6e262f0afc 100644 --- a/vta/tests/python/integration/test_benchmark_topi_group_conv.py +++ b/vta/tests/python/integration/test_benchmark_topi_group_conv2d.py @@ -68,7 +68,7 @@ def run_vta_group_conv2d(env, remote, name, wl, profile=True): padding = (wl.hpad, wl.wpad) groups = wl.groups - @memoize("vta.tests.test_benchmark_topi.conv2d.verify_nhwc") + @memoize("vta.tests.test_group_conv2d") def get_ref_data(): a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype) w_np = (np.random.uniform(size=w_shape) * 4).astype(kernel_dtype) @@ -115,9 +115,6 @@ def verify(s, check_correctness): res_unpack = res_arr.asnumpy().transpose( (0, 4, 1, 5, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width) if check_correctness: - assert wl.hpad == wl.wpad - stride = (wl.hstride, wl.wstride) - padding = (wl.hpad, wl.wpad) res_ref = res_ref >> 8 res_ref += bias_orig.reshape(wl.out_filter, 1, 1) res_ref = np.clip(res_ref, 0, 127).astype("int8")