diff --git a/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala b/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala index ca6803c15aba..d184cd2c286a 100644 --- a/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala +++ b/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala @@ -103,20 +103,21 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)( when(dec.xpad_1 =/= 0.U) { state := sXPad1 }.elsewhen(dec.ypad_1 =/= 0.U) { - state := sYPad1 - } - .otherwise { - state := sIdle - } - }.elsewhen(dataCtrl.io.stride || dataCtrl.io.split) { + state := sYPad1 + } + .otherwise { + state := sIdle + } + }.elsewhen(dataCtrl.io.stride) { when(dec.xpad_1 =/= 0.U) { state := sXPad1 }.elsewhen(dec.xpad_0 =/= 0.U) { - state := sXPad0 - } - .otherwise { - state := sReadCmd - } + state := sXPad0 + }.otherwise { + state := sReadCmd + } + }.elsewhen(dataCtrl.io.split) { + state := sReadCmd } } } @@ -168,13 +169,11 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)( xPadCtrl0.io.start := dec.xpad_0 =/= 0.U & ((state === sIdle & io.start) | (state === sYPad0 & yPadCtrl0.io.done) | - (io.vme_rd.data - .fire() & ~dataCtrlDone & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 === 0.U) | + (io.vme_rd.data.fire() & ~dataCtrlDone & dataCtrl.io.stride & dec.xpad_1 === 0.U) | (state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone)) xPadCtrl1.io.start := dec.xpad_1 =/= 0.U & io.vme_rd.data.fire() & - ((dataCtrl.io.done) | - (~dataCtrl.io.done & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 =/= 0.U)) + ((dataCtrl.io.done) | (~dataCtrl.io.done & dataCtrl.io.stride & dec.xpad_1 =/= 0.U)) yPadCtrl0.io.inst := io.inst yPadCtrl1.io.inst := io.inst diff --git a/vta/tests/python/unittest/test_vta_insn.py b/vta/tests/python/unittest/test_vta_insn.py index 574273f274f4..ef3c45ce58d6 100644 --- a/vta/tests/python/unittest/test_vta_insn.py +++ b/vta/tests/python/unittest/test_vta_insn.py @@ -24,6 +24,7 @@ import vta.testing from vta.testing import simulator +np.random.seed(0xdeadb) def test_save_load_out(): """Test save/store output command""" @@ -88,68 +89,73 @@ def _run(env, remote): def test_padded_load(): """Test padded load.""" def _run(env, remote): - # declare - n = 3 - m = 5 - pad_before = [2, 1, 0, 0] - pad_after = [1, 2, 0, 0] - x = tvm.placeholder( - (n, m, env.BATCH, env.BLOCK_OUT), - name="x", - dtype=env.acc_dtype) - x_buf = topi.nn.pad(x, pad_before, pad_after, name="y") - # insert no-op that won't be optimized away - y_buf = tvm.compute((n + pad_before[0] + pad_after[0], + def check_padded_load(pad_before, pad_after, test_name=None): + # declare + n = 3 + m = 5 + x = tvm.placeholder( + (n, m, env.BATCH, env.BLOCK_OUT), + name="x", + dtype=env.acc_dtype) + x_buf = topi.nn.pad(x, pad_before, pad_after, name="y") + # insert no-op that won't be optimized away + y_buf = tvm.compute((n + pad_before[0] + pad_after[0], + m + pad_before[1] + pad_after[1], + env.BATCH, + env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf") + y = tvm.compute((n + pad_before[0] + pad_after[0], m + pad_before[1] + pad_after[1], env.BATCH, - env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf") - y = tvm.compute((n + pad_before[0] + pad_after[0], - m + pad_before[1] + pad_after[1], - env.BATCH, - env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y") - # schedule - s = tvm.create_schedule(y.op) - s[x_buf].set_scope(env.acc_scope) - s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy) - s[y_buf].set_scope(env.acc_scope) - s[y_buf].pragma(y_buf.op.axis[0], env.alu) - s[y].pragma(y.op.axis[0], env.dma_copy) - # build - with vta.build_config(): - mod = vta.build(s, [x, y], "ext_dev", env.target_host) + env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y") + # schedule + s = tvm.create_schedule(y.op) + s[x_buf].set_scope(env.acc_scope) + s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy) + s[y_buf].set_scope(env.acc_scope) + s[y_buf].pragma(y_buf.op.axis[0], env.alu) + s[y].pragma(y.op.axis[0], env.dma_copy) + # build + with vta.build_config(): + mod = vta.build(s, [x, y], "ext_dev", env.target_host) - if not remote: - return - temp = util.tempdir() - mod.save(temp.relpath("padded_load.o")) - remote.upload(temp.relpath("padded_load.o")) - f = remote.load_module("padded_load.o") - # verify - ctx = remote.ext_dev(0) - x_np = np.random.randint(-10, 10, size=( - n, m, env.BATCH, env.BLOCK_OUT)).astype(x.dtype) - y_np = np.zeros((n + pad_before[0] + pad_after[0], - m + pad_before[1] + pad_after[1], - env.BATCH, - env.BLOCK_OUT)).astype(y.dtype) - y_np[pad_before[0]:pad_before[0] + n, - pad_before[1]:pad_before[1] + m, - :] = x_np - x_nd = tvm.nd.array(x_np, ctx) - y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype) + if not remote: + return + temp = util.tempdir() + mod.save(temp.relpath("padded_load.o")) + remote.upload(temp.relpath("padded_load.o")) + f = remote.load_module("padded_load.o") + # verify + ctx = remote.ext_dev(0) + x_np = np.random.randint(0, 10, size=( + n, m, env.BATCH, env.BLOCK_OUT)).astype(x.dtype) + y_np = np.zeros((n + pad_before[0] + pad_after[0], + m + pad_before[1] + pad_after[1], + env.BATCH, + env.BLOCK_OUT)).astype(y.dtype) + y_np[pad_before[0]:pad_before[0] + n, + pad_before[1]:pad_before[1] + m, + :] = x_np + x_nd = tvm.nd.array(x_np, ctx) + y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype) - if env.TARGET in ["sim", "tsim"]: - simulator.clear_stats() + if env.TARGET in ["sim", "tsim"]: + simulator.clear_stats() - f(x_nd, y_nd) + f(x_nd, y_nd) - np.testing.assert_equal(y_np, y_nd.asnumpy()) + np.testing.assert_equal(y_np, y_nd.asnumpy()) - if env.TARGET in ["sim", "tsim"]: - sim_stats = simulator.stats() - print("Padded load execution statistics:") - for k, v in sim_stats.items(): - print("\t{:<16}: {:>16}".format(k, v)) + if env.TARGET in ["sim", "tsim"]: + sim_stats = simulator.stats() + print("Padded {} load execution statistics:".format(test_name)) + for k, v in sim_stats.items(): + print("\t{:<16}: {:>16}".format(k, v)) + + check_padded_load([2, 0, 0, 0], [0, 0, 0, 0], test_name="Y0") + check_padded_load([0, 2, 0, 0], [0, 0, 0, 0], test_name="Y1") + check_padded_load([0, 0, 0, 0], [2, 0, 0, 0], test_name="X0") + check_padded_load([0, 0, 0, 0], [0, 2, 0, 0], test_name="X1") + check_padded_load([1, 1, 0, 0], [1, 1, 0, 0], test_name="all") vta.testing.run(_run)