Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BYORTL][Verilator] update ops and add MobileNet #7972

Merged
merged 8 commits into from
May 18, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/vta-hw
5 changes: 4 additions & 1 deletion src/runtime/contrib/verilator/verilator_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@ namespace tvm {
namespace runtime {
namespace contrib {

extern "C" TVM_DLL void verilator_add(VerilatorHandle handle, int* data, int* weight, int* out,
extern "C" TVM_DLL void verilator_add(VerilatorHandle handle, int* left, int* right, int* out,
int p_h_, int p_w_);

extern "C" TVM_DLL void verilator_bias_add(VerilatorHandle handle, int* data, int* bias, int* out,
int p_n_, int p_c_, int p_h_, int p_w_);

} // namespace contrib
} // namespace runtime
} // namespace tvm
Expand Down
19 changes: 12 additions & 7 deletions src/runtime/contrib/verilator/verilator_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ VerilatorRuntime::~VerilatorRuntime() {
auto dealloc = reinterpret_cast<VerilatorDeallocFunc>(lib_->GetSymbol("VerilatorDealloc"));
ICHECK(dealloc != nullptr);
dealloc(device_);
delete lib_;
lib_->~VerilatorLibrary();
}

void VerilatorRuntime::SetLibrary(const std::string& lib_path) { lib_path_ = lib_path; }
Expand All @@ -100,15 +100,14 @@ void VerilatorRuntime::Init(const Array<NDArray>& consts) {
ICHECK(reset != nullptr);
read_ = reinterpret_cast<VerilatorReadFunc>(lib_->GetSymbol("VerilatorRead"));
ICHECK(read_ != nullptr);
add_op_ = reinterpret_cast<VerilatorAddFunc>(lib_->GetSymbol("verilator_add"));

// alloc verilator device
device_ = alloc();

// enable profiler
if (prof_enable_) prof_ = VerilatorProfiler::ThreadLocal();

// reset verilator device.
// reset verilator device
reset(device_, reset_cycles_);

CHECK_EQ(consts.size(), const_idx_.size())
Expand Down Expand Up @@ -136,11 +135,17 @@ void VerilatorRuntime::Run() {
if (node.GetOpType() == "kernel") {
CHECK_EQ(node.GetOpType(), "kernel");
auto op_name = node.GetOpName();
auto entry = node.GetInputs()[0];
auto shape = node.GetOpShape()[entry.index_];
if ("add" == op_name) {
auto entry = node.GetInputs()[0];
auto shape = nodes_[entry.id_].GetOpShape()[entry.index_];
ICHECK(add_op_ != nullptr);
add_op_(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]);
auto add = reinterpret_cast<VerilatorAddFunc>(lib_->GetSymbol("verilator_add"));
ICHECK(add != nullptr);
add(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]);
} else if ("nn.bias_add" == op_name) {
auto bias_add =
reinterpret_cast<VerilatorBiasAddFunc>(lib_->GetSymbol("verilator_bias_add"));
ICHECK(bias_add != nullptr);
bias_add(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[3], shape[1], shape[2]);
} else {
LOG(FATAL) << "Unsupported op: " << op_name;
}
Expand Down
5 changes: 2 additions & 3 deletions src/runtime/contrib/verilator/verilator_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ using namespace tvm::runtime::json;
typedef VerilatorHandle (*VerilatorAllocFunc)();
typedef void (*VerilatorDeallocFunc)(VerilatorHandle);
typedef void (*VerilatorResetFunc)(VerilatorHandle, int);
typedef void (*VerilatorAddFunc)(VerilatorHandle, int*, int*, int*, int, int);
typedef int (*VerilatorReadFunc)(VerilatorHandle, int, int);
typedef void (*VerilatorAddFunc)(VerilatorHandle, int*, int*, int*, int, int);
typedef void (*VerilatorBiasAddFunc)(VerilatorHandle, int*, int*, int*, int, int, int, int);

class VerilatorLibrary : public Library {
public:
Expand Down Expand Up @@ -122,8 +123,6 @@ class VerilatorRuntime : public JSONRuntimeBase {
VerilatorProfiler* prof_{nullptr};
/*! \brief the verilator read function */
VerilatorReadFunc read_{nullptr};
/*! \brief the verilator add op function */
VerilatorAddFunc add_op_{nullptr};
/*! \brief the verilator reset cycles */
int reset_cycles_{1};
/*! \brief the verilator profiler status */
Expand Down
128 changes: 104 additions & 24 deletions tests/python/contrib/test_verilator/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
import sys
import subprocess as sp
import json

import tvm
from tvm import relay
Expand Down Expand Up @@ -48,6 +49,10 @@ def _func_wrapper(expr):
return _func_wrapper


_register_verilator_op("add")
_register_verilator_op("nn.bias_add")


def skip_test():
"""Skip test if it requires the Verilator codegen and it's not present."""
if not tvm.get_global_func("relay.ext.verilator", True):
Expand All @@ -59,8 +64,33 @@ def skip_test():
return False


def clear_stats():
"""Clear profiler statistics."""
f = tvm.get_global_func("verilator.profiler_clear", True)
if f:
f()


def stats():
"""Get profiler statistics."""

x = tvm.get_global_func("verilator.profiler_status")()
return json.loads(x)


def offload(mod):
"""Offload ops based on the registered ops"""
"""Offload ops based on the registered ops

Paramters
---------
mod : Module
The input module.

Returns
-------
mod : Module
The output module with offloaded ops.
"""

backend = "verilator"
mod = transform.AnnotateTarget([backend])(mod)
Expand All @@ -69,7 +99,7 @@ def offload(mod):


def verilator_app_path():
"""Find verilator hardware app path"""
"""Create verilator hardware app path."""

cur_dir = os.path.dirname(os.path.realpath(__file__))
return os.path.join(
Expand All @@ -82,37 +112,87 @@ def verilator_app_path():
"vta-hw",
"apps",
"verilator",
"add",
)


def compile_hardware():
"""Compile hardware into shared library"""
def compile_hardware(lanes):
"""Compile hardware into shared library

Paramters
---------
lanes : Int
The number of vector lanes.

Returns
-------
path : Str
The path of the shared library.
"""
lib_name = "libverilator_{}".format(lanes)
lib_name_ext = "{}.so".format(lib_name)
lib = os.path.join(verilator_app_path(), lib_name_ext)
if not os.path.isfile(lib):
opt_lib_name = "LIB_NAME={}".format(lib_name)
opt_lanes = "LANES={}".format(lanes)
cmd = []
cmd.append("make")
cmd.append("--directory")
cmd.append(verilator_app_path())
cmd.append(opt_lib_name)
cmd.append(opt_lanes)
sp.run(cmd, check=True, stdout=sp.DEVNULL)
return lib


cmd = []
cmd.append("make")
cmd.append("--directory")
cmd.append(verilator_app_path())
sp.run(cmd, check=True)
def compiler_opts(lib):
"""Create compiler options

Paramters
---------
lib : Str
The path of the hardware shared library.

def compile_module(mod):
"""Compile Relay module and hardware library"""
Returns
-------
opts : Dict
The compiler options.
"""
opts = {
"lib_path": lib,
"profiler_enable": True,
"profiler_cycle_counter_id": 0,
}
return opts

lib = os.path.join(verilator_app_path(), "libverilator.so")
if not os.path.isfile(lib):
compile_hardware()

opts = {"lib_path": lib}
def run_module(inp, mod, params=None, opts=None):
"""Compile Relay module and hardware library

with tvm.transform.PassContext(opt_level=3, config={"relay.ext.verilator.options": opts}):
exe = relay.vm.compile(mod, target="llvm", params=None)
code, lib = exe.save()
return runtime.vm.Executable.load_exec(code, lib)
Paramters
---------
inp : Data
The input data.

mod : Module
The relay module.

def run_module(exe, inputs):
"""Run Relay module"""
params : Parameters
The model Parameters.

dev = tvm.cpu()
vm = runtime.vm.VirtualMachine(exe, dev)
return vm.run(**inputs)
opts : Dict
The compiler

Returns
-------
out : Data
The output data.
"""

with tvm.transform.PassContext(opt_level=3, config={"relay.ext.verilator.options": opts}):
lib = relay.vm.compile(mod, target="llvm", params=params)
code, lib = lib.save()
exe = runtime.vm.Executable.load_exec(code, lib)
vm = runtime.vm.VirtualMachine(exe, tvm.cpu())
out = vm.run(**inp)
return out
Loading