Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VTA] TSIM improvements and fixes #3505

Merged
merged 32 commits into from
Jul 8, 2019
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
78abce5
add tsim init function
vegaluisjose Jul 4, 2019
b855d5e
add sim device
vegaluisjose Jul 4, 2019
5eb475b
test wait and resume
vegaluisjose Jul 4, 2019
e3e55d7
launch simulation thread from DPILoader
vegaluisjose Jul 4, 2019
d3c950e
add VTASimDPI module to handle all simulation related stuff
vegaluisjose Jul 4, 2019
c09b3fa
test tsim init
vegaluisjose Jul 4, 2019
4b1f3c1
move exit to simdpi module
vegaluisjose Jul 4, 2019
f4639e6
update vta driver
vegaluisjose Jul 4, 2019
6a149c5
add chisel DPI module
vegaluisjose Jul 4, 2019
06d7b9e
get back simshell
vegaluisjose Jul 4, 2019
7e752cd
update vta to support dpi sim
vegaluisjose Jul 5, 2019
f36c58b
update unittests
vegaluisjose Jul 5, 2019
556ae39
add tsim to integration-conv2d test
vegaluisjose Jul 5, 2019
a9918b9
run resnet on tsim
vegaluisjose Jul 6, 2019
ecae18e
remove max-cycles
vegaluisjose Jul 6, 2019
47bc419
match tsim counters with sim counters
vegaluisjose Jul 6, 2019
50e2084
use env in simulator to switch between sim and tsim
vegaluisjose Jul 6, 2019
c011e19
update unittest
vegaluisjose Jul 6, 2019
3bc575d
rollback conv2d test
vegaluisjose Jul 6, 2019
540f7f4
update resnet
vegaluisjose Jul 7, 2019
8a87b47
add stats to matrix multiply
vegaluisjose Jul 7, 2019
048942e
add stats
vegaluisjose Jul 7, 2019
7aa6080
print stats after assert
vegaluisjose Jul 7, 2019
3e5cbbf
update other tests
vegaluisjose Jul 7, 2019
424cf62
add stats to gemm
vegaluisjose Jul 7, 2019
87855ab
add return and remove unused libs
vegaluisjose Jul 7, 2019
09ea11b
add missing arg
vegaluisjose Jul 7, 2019
6869377
return lib
vegaluisjose Jul 7, 2019
6a976d8
update comments for linter
vegaluisjose Jul 7, 2019
edfb2d4
add more comments to VTASimDPI module
vegaluisjose Jul 8, 2019
c0f2a81
remove trailing spaces
vegaluisjose Jul 8, 2019
63b7aa0
remove trailing spaces
vegaluisjose Jul 8, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package test

import chisel3._
import chisel3.experimental.MultiIOModule
import vta.dpi._
import accel._

Expand All @@ -28,32 +29,39 @@ import accel._
* Instantiate Host and Memory DPI modules.
*
*/
class VTASimShell extends Module {
val io = IO(new Bundle {
val host = new VTAHostDPIMaster
val mem = new VTAMemDPIClient
})
val host = Module(new VTAHostDPI)
val mem = Module(new VTAMemDPI)
mem.io.dpi <> io.mem
mem.io.reset := reset
mem.io.clock := clock
io.host <> host.io.dpi
host.io.reset := reset
host.io.clock := clock
class VTASimShell extends MultiIOModule {
val host = IO(new VTAHostDPIMaster)
val mem = IO(new VTAMemDPIClient)
val sim_clock = IO(Input(Clock()))
val sim_wait = IO(Output(Bool()))
val mod_sim = Module(new VTASimDPI)
val mod_host = Module(new VTAHostDPI)
val mod_mem = Module(new VTAMemDPI)
mod_mem.io.clock := clock
mod_mem.io.reset := reset
mod_mem.io.dpi <> mem
mod_host.io.clock := clock
mod_host.io.reset := reset
host <> mod_host.io.dpi
mod_sim.io.clock := sim_clock
mod_sim.io.reset := reset
sim_wait := mod_sim.io.dpi_wait
}

/** Test accelerator.
*
* Instantiate and connect the simulation-shell and the accelerator.
*
*/
class TestAccel extends Module {
val io = IO(new Bundle {})
class TestAccel extends MultiIOModule {
val sim_clock = IO(Input(Clock()))
val sim_wait = IO(Output(Bool()))
val sim_shell = Module(new VTASimShell)
val vta_accel = Module(new Accel)
vta_accel.io.host <> sim_shell.io.host
sim_shell.io.mem <> vta_accel.io.mem
sim_shell.sim_clock := sim_clock
sim_wait := sim_shell.sim_wait
sim_shell.mem <> vta_accel.io.mem
vta_accel.io.host <> sim_shell.host
}

/** Generate TestAccel as top module */
Expand Down
13 changes: 12 additions & 1 deletion vta/apps/tsim_example/hardware/verilog/src/TestAccel.v
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
module TestAccel
(
input clock,
input reset
input reset,
input sim_clock,
output sim_wait
);

localparam HOST_ADDR_BITS = 8;
Expand Down Expand Up @@ -53,6 +55,14 @@ module TestAccel
logic [MEM_DATA_BITS-1:0] mem_rd_bits;
logic mem_rd_ready;

VTASimDPI sim
(
.clock (sim_clock),
.reset (reset),

.dpi_wait (sim_wait)
);

VTAHostDPI host
(
.clock (clock),
Expand Down Expand Up @@ -114,4 +124,5 @@ module TestAccel
.mem_rd_bits (mem_rd_bits),
.mem_rd_ready (mem_rd_ready)
);

endmodule
47 changes: 27 additions & 20 deletions vta/apps/tsim_example/python/tsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,22 @@
import os.path as osp
from sys import platform

def driver(hw_backend):
def get_ext():
return ".dylib" if platform == "darwin" else ".so"

def load_dll(dll):
try:
return [ctypes.CDLL(dll, ctypes.RTLD_GLOBAL)]
except OSError:
return []

def load_sw():
cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
sw_libname = "libsw" + get_ext()
sw_lib = osp.join(cur_path, "..", "build", sw_libname)
load_dll(sw_lib)

def init(hw_backend):
"""Init hardware and software shared library for accelerator

Parameters
Expand All @@ -29,23 +44,15 @@ def driver(hw_backend):
Hardware backend can be verilog or chisel

"""
_ext = ".dylib" if platform == "darwin" else ".so"
_hw_libname = "libhw" + _ext
_sw_libname = "libsw" + _ext
_cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
hw_libname = "libhw" + get_ext()
if hw_backend in ("verilog", "chisel"):
_hw_lib = osp.join(_cur_path, "..", "hardware", hw_backend, "build", _hw_libname)
_sw_lib = osp.join(_cur_path, "..", "build", _sw_libname)

def load_dll(dll):
try:
return [ctypes.CDLL(dll, ctypes.RTLD_GLOBAL)]
except OSError:
return []

def run(a, b, c):
load_dll(_sw_lib)
f = tvm.get_global_func("tvm.vta.driver")
m = tvm.module.load(_hw_lib, "vta-tsim")
return f(m, a, b, c)
return run
hw_lib = osp.join(cur_path, "..", "hardware", hw_backend, "build", hw_libname)
m = tvm.module.load(hw_lib, "vta-tsim")
load_sw()
f = tvm.get_global_func("tvm.vta.tsim.init")
f(m)

def load_module():
load_sw()
return tvm.get_global_func("tvm.vta.driver")
66 changes: 53 additions & 13 deletions vta/apps/tsim_example/src/driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,56 @@ uint32_t get_half_addr(void *p, bool upper) {
using vta::dpi::DPIModuleNode;
using tvm::runtime::Module;

class DPILoader {
public:
~DPILoader() {
dpi_->SimResume();
dpi_->SimFinish();
}

void Init(Module module) {
mod_ = module;
dpi_ = this->Get();
dpi_->SimLaunch();
dpi_->SimWait();
}

DPIModuleNode* Get() {
return static_cast<DPIModuleNode*>(mod_.operator->());
}

static DPILoader* Global() {
static DPILoader inst;
return &inst;
}

// TVM module
Module mod_;
// DPI Module
DPIModuleNode* dpi_{nullptr};
};

class Device {
public:
Device(Module module)
: module_(module) {
dpi_ = static_cast<DPIModuleNode*>(
module.operator->());
Device() {
loader_ = DPILoader::Global();
}

uint32_t Run(uint32_t c, uint32_t length, void* inp, void* out) {
uint32_t cycles;
this->Init();
this->Launch(c, length, inp, out);
cycles = this->WaitForCompletion();
dpi_->Finish();
return cycles;
}

private:
void Init() {
dpi_ = loader_->Get();
dpi_->SimResume();
}

void Launch(uint32_t c, uint32_t length, void* inp, void* out) {
dpi_->Launch(wait_cycles_);
dpi_->WriteReg(0x08, c);
dpi_->WriteReg(0x0c, length);
dpi_->WriteReg(0x10, get_half_addr(inp, false));
Expand All @@ -70,24 +101,33 @@ class Device {
if (val == 2) break; // finish
}
val = dpi_->ReadReg(0x04);
dpi_->SimWait();
return val;
}

// wait cycles
uint32_t wait_cycles_{100000000};
DPIModuleNode* dpi_;
Module module_;
// DPI loader
DPILoader* loader_{nullptr};
// DPI Module
DPIModuleNode* dpi_{nullptr};
};

using tvm::runtime::TVMRetValue;
using tvm::runtime::TVMArgs;

TVM_REGISTER_GLOBAL("tvm.vta.tsim.init")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
DPILoader::Global()->Init(m);
});

TVM_REGISTER_GLOBAL("tvm.vta.driver")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module dev_mod = args[0];
DLTensor* A = args[1];
DLTensor* B = args[2];
Device dev_(dev_mod);
uint32_t cycles = dev_.Run(static_cast<int>(args[3]), A->shape[0], A->data, B->data);
DLTensor* A = args[0];
DLTensor* B = args[1];
Device dev_;
uint32_t cycles = dev_.Run(static_cast<int>(args[2]), A->shape[0], A->data, B->data);
*rv = static_cast<int>(cycles);
});

Expand Down
3 changes: 2 additions & 1 deletion vta/apps/tsim_example/tests/python/chisel_accel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ def test_accel():
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx)
b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx)
f = tsim.driver("chisel")
f = tsim.load_module()
cycles = f(a, b, c)
msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg)
print("[PASS] " + msg)

if __name__ == "__main__":
tsim.init("chisel")
for i in range(10):
test_accel()
3 changes: 2 additions & 1 deletion vta/apps/tsim_example/tests/python/verilog_accel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ def test_accel():
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx)
b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx)
f = tsim.driver("verilog")
f = tsim.load_module()
cycles = f(a, b, c)
msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg)
print("[PASS] " + msg)

if __name__ == "__main__":
tsim.init("verilog")
for i in range(10):
test_accel()
21 changes: 0 additions & 21 deletions vta/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ module VTAHostDPI #

import "DPI-C" function void VTAHostDPI
(
output byte unsigned exit,
output byte unsigned req_valid,
output byte unsigned req_opcode,
output byte unsigned req_addr,
Expand All @@ -50,7 +49,6 @@ module VTAHostDPI #
typedef logic [31:0] dpi32_t;

dpi1_t __reset;
dpi8_t __exit;
dpi8_t __req_valid;
dpi8_t __req_opcode;
dpi8_t __req_addr;
Expand Down Expand Up @@ -80,15 +78,13 @@ module VTAHostDPI #
// evaluate DPI function
always_ff @(posedge clock) begin
if (reset | __reset) begin
__exit = 0;
__req_valid = 0;
__req_opcode = 0;
__req_addr = 0;
__req_value = 0;
end
else begin
VTAHostDPI(
__exit,
__req_valid,
__req_opcode,
__req_addr,
Expand All @@ -99,21 +95,4 @@ module VTAHostDPI #
end
end

logic [63:0] cycles;

always_ff @(posedge clock) begin
if (reset | __reset) begin
cycles <= 'd0;
end
else begin
cycles <= cycles + 1'b1;
end
end

always_ff @(posedge clock) begin
if (__exit == 'd1) begin
$finish;
end
end

endmodule
Loading