diff --git a/vta/python/vta/testing/util.py b/vta/python/vta/testing/util.py index c01d206e91268..6ec3a606baab5 100644 --- a/vta/python/vta/testing/util.py +++ b/vta/python/vta/testing/util.py @@ -24,6 +24,7 @@ def run(run_func): # with ./apps/pynq_rpc/start_rpc_server.sh # Set your VTA_LOCAL_SIM_RPC environment variable to # the port it's listening to, e.g. 9090 + local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0")) if local_rpc: remote = rpc.connect("localhost", local_rpc) diff --git a/vta/python/vta/top/vta_conv2d.py b/vta/python/vta/top/vta_conv2d.py index a472925465e21..3e9f78dfc00d7 100644 --- a/vta/python/vta/top/vta_conv2d.py +++ b/vta/python/vta/top/vta_conv2d.py @@ -87,9 +87,11 @@ def _get_data_movement_byte(schedule, layer): return total_xfer_byte # Scheduling exploration + OH = (layer.height + 2 * layer.hpad - layer.hkernel) // layer.hstride + 1 + OW = (layer.width + 2 * layer.wpad - layer.wkernel) // layer.wstride + 1 batch_factors = _find_factors(layer.batch // env.BATCH) - height_factors = _find_factors(layer.height // layer.hstride) - width_factors = _find_factors(layer.width // layer.wstride) + height_factors = _find_factors(OH) + width_factors = _find_factors(OW) cin_factors = _find_factors(layer.in_filter // env.BLOCK_IN) cout_factors = _find_factors(layer.out_filter // env.BLOCK_OUT) ht_factors = [1, 2] @@ -323,8 +325,6 @@ def compute_conv2d_transpose(attrs, inputs, out): layout = attrs["layout"] out_dtype = attrs['out_dtype'] - print(inputs) - assert dilation == (1, 1), "not support dilate now" if is_packed_layout(layout): return packed_conv2d_transpose(inputs[0], inputs[1], diff --git a/vta/python/vta/top/vta_conv2d_transpose.py b/vta/python/vta/top/vta_conv2d_transpose.py index d53d69d02f2c3..761263248b6d8 100644 --- a/vta/python/vta/top/vta_conv2d_transpose.py +++ b/vta/python/vta/top/vta_conv2d_transpose.py @@ -18,9 +18,34 @@ ('b_factor', 'oc_factor', 'ic_factor', 'h_factor', 'w_factor', 'oc_nthread', 'h_nthread', 'debug_sync')) +workloads = [ + Workload(1, 4, 4, 1024, 512, 4, 4, 1, 1, 2, 2), + Workload(1, 8, 8, 512, 256, 4, 4, 1, 1, 2, 2), + Workload(1, 16, 16, 256, 128, 4, 4, 1, 1, 2, 2), +] + +schedules = [ + Schedule(1, 16, 1, 8, 8, 1, 1, False), + Schedule(1, 4, 1, 16, 16, 1, 1, False), + Schedule(1, 1, 1, 32, 32, 1, 1, False), +] + +injected_schedule = None + def find_schedules(layer, vt_only=False, best_only=False): - return [Schedule(1, 1, 1, 2, 4, 1, 1, False)] + global injected_schedule + if injected_schedule: + return [injected_schedule] + for i, wkl in enumerate(workloads): + if str(wkl) == str(layer): + return [schedules[i]] + raise RuntimeError("No schedule for " + str(layer)) + + +def inject_schedule(sch): + global injected_schedule + injected_schedule = sch def packed_conv2d_transpose(data, @@ -28,6 +53,8 @@ def packed_conv2d_transpose(data, padding, strides, out_dtype="int32"): + env = get_env() + batch, in_c, in_h, in_w, B_BATCH, B_CI = get_const_tuple(data.shape) out_c, _, filter_h, filter_w, B_CO, B_CI = get_const_tuple(kernel.shape) stride_h, stride_w = strides @@ -84,8 +111,8 @@ def _dilate(*indices): axis=[dc, dh, dw, dci]), tag="packed_conv2d_transpose", name='res', - attrs={"workload": (n, in_h, in_w, in_c, out_c, filter_h, filter_w, - padding[0], padding[1], stride_h, stride_w)}) + attrs={"workload": (batch * env.BATCH, in_h, in_w, in_c * env.BLOCK_IN, out_c * env.BLOCK_OUT, + filter_h, filter_w, padding[0], padding[1], stride_h, stride_w)}) return Output @@ -103,8 +130,6 @@ def schedule_packed_conv2d_transpose(outs): conv2d_res = [] assert output.dtype == "int8" assert output.op.input_tensors[0].dtype == "int32" - # - #return tvm.create_schedule(output.op) def _traverse(op): if topi.tag.is_broadcast(op.tag): @@ -197,9 +222,11 @@ def _traverse(op): x_bo, x_co, x_i, x_j, x_bi, x_ci = s[conv2d_stage].op.axis k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis - s[conv2d_stage].reorder(x_bo, k_o, d_j, d_i, x_co, x_i, x_j, x_bi, x_ci, k_i) + x_i, x_ii = s[conv2d_stage].split(x_i, 4) + x_j, x_jj = s[conv2d_stage].split(x_j, 2) + s[conv2d_stage].reorder(x_bo, k_o, x_j, x_co, x_i, x_jj, d_j, d_i, x_ii, x_bi, x_ci, k_i) - for axis in [d_j, d_i, x_i, x_j]: + for axis in [d_j, d_i, x_ii, x_jj]: s[conv2d_stage].unroll(axis) ic_factor = plan.ic_factor or 1 diff --git a/vta/python/vta/top/vta_group_conv2d.py b/vta/python/vta/top/vta_group_conv2d.py index c883b154f1c8f..97e9e939951b3 100644 --- a/vta/python/vta/top/vta_group_conv2d.py +++ b/vta/python/vta/top/vta_group_conv2d.py @@ -16,10 +16,46 @@ ('b_factor', 'oc_factor', 'ic_factor', 'h_factor', 'w_factor', 'oc_nthread', 'h_nthread', 'debug_sync')) +workloads = [ + Workload(1, 112, 112, 32, 32, 2, 3, 3, 1, 1, 1, 1), + Workload(1, 112, 112, 64, 64, 4, 3, 3, 1, 1, 2, 2), + Workload(1, 56, 56, 128, 128, 8, 3, 3, 1, 1, 1, 1), + Workload(1, 56, 56, 128, 128, 8, 3, 3, 1, 1, 2, 2), + Workload(1, 28, 28, 256, 256, 16, 3, 3, 1, 1, 1, 1), + Workload(1, 28, 28, 256, 256, 16, 3, 3, 1, 1, 2, 2), + Workload(1, 14, 14, 512, 512, 32, 3, 3, 1, 1, 1, 1), + Workload(1, 14, 14, 512, 512, 32, 3, 3, 1, 1, 2, 2), + Workload(1, 7, 7, 1024, 1024, 64, 3, 3, 1, 1, 1, 1), +] + +schedules = [ + Schedule(1, 1, 1, 28, 56, 1, 1, False), + Schedule(1, 1, 1, 14, 28, 1, 1, False), + Schedule(1, 1, 1, 28, 56, 1, 1, False), + Schedule(1, 1, 1, 14, 28, 1, 1, False), + Schedule(1, 1, 1, 28, 28, 1, 1, False), + Schedule(1, 1, 1, 14, 14, 1, 1, False), + Schedule(1, 1, 1, 14, 14, 1, 1, False), + Schedule(1, 1, 1, 7, 7, 1, 1, False), + Schedule(1, 1, 1, 7, 7, 1, 1, False), +] + +injected_schedule = None + +# load schedule def find_schedules(layer, vt_only=False, best_only=False): - return [Schedule(0, 0, 1, 0, 0, 0, 0, False)] - + global injected_schedule + if injected_schedule: + return [injected_schedule] + for i, wkl in enumerate(workloads): + if str(wkl) == str(layer): + return [schedules[i]] + raise RuntimeError("No schedule for " + str(layer)) + +def inject_schedule(sch): + global injected_schedule + injected_schedule = sch def _get_workload(data, pad_data, kernel, output): """ Get the workload structure. @@ -141,7 +177,6 @@ def _traverse(op): pad_data = None wrkld = _get_workload(data, pad_data, kernel, output) plan = find_schedules(wrkld, vt_only=True, best_only=True)[0] - logging.info("Trying to find plan for %s", wrkld) env = get_env() load_inp = load_wgt = load_out = store_out = env.dma_copy diff --git a/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py b/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py index e338dc8fc7209..2f4e6c4935b49 100644 --- a/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py +++ b/vta/tests/python/integration/test_benchmark_topi_conv2d_transpose.py @@ -1,5 +1,8 @@ """Testing if we can generate code in topi style""" +import pickle +import json + import tvm from tvm import autotvm from tvm.contrib import util @@ -10,8 +13,8 @@ import vta.testing import numpy as np -Workload = vta.top.vta_conv2d_transpose.Workload -Schedule = vta.top.vta_conv2d_transpose.Schedule +from vta.top.vta_conv2d_transpose import Workload, Schedule, inject_schedule + @tvm.tag_scope(tag=topi.tag.ELEMWISE) def my_clip(x, a_min, a_max): @@ -23,6 +26,15 @@ def my_clip(x, a_min, a_max): return x +# Helper function to get factors +def _find_factors(n): + factors = [] + for f in range(1, n + 1): + if n % f == 0: + factors.append(f) + return factors + + def test_vta_conv2d_transpose(): def run_vta_conv2d_transpose(env, remote, name, wl, profile=True): assert wl.batch % env.BATCH == 0 @@ -106,8 +118,12 @@ def verify(s, check_correctness): kernel_arr = tvm.nd.array(kernel_flipped, ctx) bias_arr = tvm.nd.array(bias_packed, ctx) res_arr = tvm.nd.array(res_np, ctx) - time_f = f.time_evaluator("conv2d_transpose", ctx, number=5) + + remote.get_function("vta.simulator.profiler_clear")() + time_f = f.time_evaluator("conv2d_transpose", ctx, number=1) cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) + stats = json.loads(remote.get_function("vta.simulator.profiler_status")()) + res_unpack = res_arr.asnumpy().transpose( (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, fout_height, fout_width) if check_correctness: @@ -118,19 +134,20 @@ def verify(s, check_correctness): res_ref += bias_orig.reshape(wl.out_filter, 1, 1) res_ref = np.clip(res_ref, 0, 127).astype("int8") np.testing.assert_allclose(res_unpack, res_ref) - return cost + return cost, stats def conv2d_transpose_normal(print_ir): - print("----- Conv2d Transpose End-to-End Test-------") + # print("----- Conv2d Transpose End-to-End Test-------") with vta.build_config(): s = vta.top.schedule_packed_conv2d_transpose([res]) if print_ir: print(vta.lower(s, [data, kernel, bias, res], simple_mode=True)) - cost = verify(s, True) - gops = (num_ops / cost.mean) / float(10 ** 9) - print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops)) + cost, stats = verify(s, True) + # gops = (num_ops / cost.mean) / float(10 ** 9) + # print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops)) + return cost, stats - conv2d_transpose_normal(False) + return conv2d_transpose_normal(False) def _run(env, remote): tasks = [ @@ -138,13 +155,64 @@ def _run(env, remote): ('DCGAN.CT1', Workload(1, 4, 4, 1024, 512, 4, 4, 1, 1, 2, 2)), ('DCGAN.CT2', Workload(1, 8, 8, 512, 256, 4, 4, 1, 1, 2, 2)), ('DCGAN.CT3', Workload(1, 16, 16, 256, 128, 4, 4, 1, 1, 2, 2)), - ('DCGAN.CT4', Workload(1, 32, 32, 128, env.BLOCK_IN, 4, 4, 1, 1, 2, 2)), ] - for tsk in tasks: + # for tsk in tasks: + # print(tsk) + # name, wkl = tsk + # run_vta_conv2d_transpose(env, remote, name, wkl) + # exit() + + map_list = {} + for i, tsk in enumerate(tasks): print(tsk) name, wkl = tsk - run_vta_conv2d_transpose(env, remote, name, wkl) + + fout_height = (wkl.height - 1) * wkl.hstride - 2 * wkl.hpad + wkl.hkernel + fout_width = (wkl.width - 1) * wkl.wstride - 2 * wkl.wpad + wkl.wkernel + + batch_factors = _find_factors(wkl.batch // env.BATCH) + height_factors = _find_factors(fout_height) + width_factors = _find_factors(fout_width) + cin_factors = _find_factors(wkl.in_filter // env.BLOCK_IN) + cout_factors = _find_factors(wkl.out_filter // env.BLOCK_OUT) + ht_factors = [1] + cot_factors = [1] + + sch_list = [] + cost_list = [] + ct = 0 + total = np.prod([len(x) for x in [batch_factors, height_factors, width_factors, cin_factors, cout_factors, + ht_factors, cot_factors]]) + best = 1 << 32 + for b_f in batch_factors: + for h_f in height_factors: + for w_f in width_factors: + for ci_f in cin_factors: + for co_f in cout_factors: + for h_t in ht_factors: + for co_t in cot_factors: + sch = Schedule(b_f, co_f, ci_f, h_f, w_f, h_t, co_t, False) + inject_schedule(sch) + try: + _, stats = run_vta_conv2d_transpose(env, remote, name, wkl) + cost = stats['inp_load_nbytes'] + stats['wgt_load_nbytes'] + stats['acc_load_nbytes'] + \ + stats['out_store_nbytes'] + stats['uop_load_nbytes'] + except tvm.TVMError: + cost = 1 << 32 + best = min(best, cost) + print("[Task %d/%d] %d/%d : %d / %d" % (i, len(tasks), ct, total, cost, best)) + ct += 1 + sch_list.append(sch) + cost_list.append(cost) + cost_list = np.array(cost_list) + + sort_index = np.argsort(cost_list) + + map_list[str(wkl)] = tuple(sch_list[sort_index[0]]) + + pickle.dump(map_list, open("conv_tmp.pkl", "wb")) + vta.testing.run(_run) if __name__ == "__main__": diff --git a/vta/tests/python/integration/test_benchmark_topi_group_conv2d.py b/vta/tests/python/integration/test_benchmark_topi_group_conv2d.py index 59c6e262f0afc..53758245eb51e 100644 --- a/vta/tests/python/integration/test_benchmark_topi_group_conv2d.py +++ b/vta/tests/python/integration/test_benchmark_topi_group_conv2d.py @@ -1,7 +1,8 @@ """Testing if we can generate code in topi style""" - +import pickle import tvm from tvm import autotvm +import json from tvm.contrib import util from tvm.contrib.pickle_memoize import memoize import topi @@ -10,7 +11,7 @@ import vta.testing import numpy as np -Workload = vta.top.vta_group_conv2d.Workload +from vta.top.vta_group_conv2d import Workload, Schedule, inject_schedule @tvm.tag_scope(tag=topi.tag.ELEMWISE) @@ -22,9 +23,17 @@ def my_clip(x, a_min, a_max): x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") return x +# Helper function to get factors +def _find_factors(n): + factors = [] + for f in range(1, n + 1): + if n % f == 0: + factors.append(f) + return factors + def test_vta_group_conv2d(): - def run_vta_group_conv2d(env, remote, name, wl, profile=True): + def run_vta_group_conv2d(env, remote, name, wl, profile=False): assert wl.in_filter % wl.groups == 0 assert wl.out_filter % wl.groups == 0 assert wl.in_filter % (wl.groups * env.BLOCK_IN) == 0 @@ -110,8 +119,12 @@ def verify(s, check_correctness): kernel_arr = tvm.nd.array(kernel_packed, ctx) bias_arr = tvm.nd.array(bias_packed, ctx) res_arr = tvm.nd.array(res_np, ctx) - time_f = f.time_evaluator("group_conv2d", ctx, number=5) + + remote.get_function("vta.simulator.profiler_clear")() + time_f = f.time_evaluator("group_conv2d", ctx, number=1) cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) + stats = json.loads(remote.get_function("vta.simulator.profiler_status")()) + res_unpack = res_arr.asnumpy().transpose( (0, 4, 1, 5, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width) if check_correctness: @@ -119,38 +132,87 @@ def verify(s, check_correctness): res_ref += bias_orig.reshape(wl.out_filter, 1, 1) res_ref = np.clip(res_ref, 0, 127).astype("int8") np.testing.assert_allclose(res_unpack, res_ref) - return cost + return cost, stats def group_conv_normal(print_ir): - print("----- Group conv2d End-to-End Test-------") + # print("----- Group conv2d End-to-End Test-------") with vta.build_config(): s = vta.top.schedule_packed_group_conv2d([res]) if print_ir: print(vta.lower(s, [data, kernel, bias, res], simple_mode=True)) - cost = verify(s, True) - gops = (num_ops / cost.mean) / float(10 ** 9) - print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops)) + cost, stats = verify(s, True) + # gops = (num_ops / cost.mean) / float(10 ** 9) + # print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops)) + return cost, stats - group_conv_normal(False) + return group_conv_normal(False) def _run(env, remote): tasks = [ # mobilenet ('mobilenet.D1', Workload(1, 112, 112, 32, 32, 2, 3, 3, 1, 1, 1, 1)), ('mobilenet.D2', Workload(1, 112, 112, 64, 64, 4, 3, 3, 1, 1, 2, 2)), - ('mobilenet.D3', Workload(1, 56, 56, 64, 64, 4, 3, 3, 1, 1, 1, 1)), + ('mobilenet.D3', Workload(1, 56, 56, 128, 128, 8, 3, 3, 1, 1, 1, 1)), ('mobilenet.D4', Workload(1, 56, 56, 128, 128, 8, 3, 3, 1, 1, 2, 2)), - ('mobilenet.D5', Workload(1, 28, 28, 256, 256, 8, 3, 3, 1, 1, 1, 1)), + ('mobilenet.D5', Workload(1, 28, 28, 256, 256, 16, 3, 3, 1, 1, 1, 1)), ('mobilenet.D6', Workload(1, 28, 28, 256, 256, 16, 3, 3, 1, 1, 2, 2)), - ('mobilenet.D7', Workload(1, 14, 14, 256, 256, 16, 3, 3, 1, 1, 1, 1)), - ('mobilenet.D8', Workload(1, 14, 14, 256, 256, 16, 3, 3, 1, 1, 2, 2)), + ('mobilenet.D7', Workload(1, 14, 14, 512, 512, 32, 3, 3, 1, 1, 1, 1)), + ('mobilenet.D8', Workload(1, 14, 14, 512, 512, 32, 3, 3, 1, 1, 2, 2)), ('mobilenet.D9', Workload(1, 7, 7, 1024, 1024, 64, 3, 3, 1, 1, 1, 1)), ] - for tsk in tasks: + # for tsk in tasks: + # print(tsk) + # name, wkl = tsk + # run_vta_group_conv2d(env, remote, name, wkl) + # return + + map_list = {} + for i, tsk in enumerate(tasks): print(tsk) name, wkl = tsk - run_vta_group_conv2d(env, remote, name, wkl) + + batch_factors = _find_factors(wkl.batch // env.BATCH) + height_factors = _find_factors(wkl.height // wkl.hstride) + width_factors = _find_factors(wkl.width // wkl.wstride) + cin_factors = _find_factors(wkl.in_filter // env.BLOCK_IN) + cout_factors = _find_factors(wkl.out_filter // env.BLOCK_OUT) + ht_factors = [1] + cot_factors = [1] + + sch_list = [] + cost_list = [] + ct = 0 + total = np.prod([len(x) for x in [batch_factors, height_factors, width_factors, cin_factors, cout_factors, + ht_factors, cot_factors]]) + best = 1 << 32 + for b_f in batch_factors: + for h_f in height_factors: + for w_f in width_factors: + for ci_f in cin_factors: + for co_f in cout_factors: + for h_t in ht_factors: + for co_t in cot_factors: + sch = Schedule(b_f, co_f, ci_f, h_f, w_f, h_t, co_t, False) + inject_schedule(sch) + try: + _, stats = run_vta_group_conv2d(env, remote, name, wkl) + cost = stats['inp_load_nbytes'] + stats['wgt_load_nbytes'] + stats['acc_load_nbytes'] + \ + stats['out_store_nbytes'] + stats['uop_load_nbytes'] + except tvm.TVMError: + cost = 1 << 32 + best = min(best, cost) + print("[Task %d/%d] %d/%d : %d / %d" % (i, len(tasks), ct, total, cost, best)) + ct += 1 + sch_list.append(sch) + cost_list.append(cost) + cost_list = np.array(cost_list) + + sort_index = np.argsort(cost_list) + + map_list[str(wkl)] = tuple(sch_list[sort_index[0]]) + + pickle.dump(map_list, open("group_conv_tmp.pkl", "wb")) vta.testing.run(_run)