diff --git a/3rdparty/vta-hw b/3rdparty/vta-hw index 43194178b4e5..dfe9f572a43d 160000 --- a/3rdparty/vta-hw +++ b/3rdparty/vta-hw @@ -1 +1 @@ -Subproject commit 43194178b4e570a5f1dd4f3f9d37ee16fc1b65be +Subproject commit dfe9f572a43d41e0c1ecdf036cea97042a0febfe diff --git a/src/runtime/contrib/verilator/verilator_kernel.h b/src/runtime/contrib/verilator/verilator_kernel.h index f62097c0d795..57353297db8d 100644 --- a/src/runtime/contrib/verilator/verilator_kernel.h +++ b/src/runtime/contrib/verilator/verilator_kernel.h @@ -33,9 +33,12 @@ namespace tvm { namespace runtime { namespace contrib { -extern "C" TVM_DLL void verilator_add(VerilatorHandle handle, int* data, int* weight, int* out, +extern "C" TVM_DLL void verilator_add(VerilatorHandle handle, int* left, int* right, int* out, int p_h_, int p_w_); +extern "C" TVM_DLL void verilator_bias_add(VerilatorHandle handle, int* data, int* bias, int* out, + int p_n_, int p_c_, int p_h_, int p_w_); + } // namespace contrib } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/verilator/verilator_runtime.cc b/src/runtime/contrib/verilator/verilator_runtime.cc index 5dfb8441c864..85172d480ead 100644 --- a/src/runtime/contrib/verilator/verilator_runtime.cc +++ b/src/runtime/contrib/verilator/verilator_runtime.cc @@ -80,7 +80,7 @@ VerilatorRuntime::~VerilatorRuntime() { auto dealloc = reinterpret_cast(lib_->GetSymbol("VerilatorDealloc")); ICHECK(dealloc != nullptr); dealloc(device_); - delete lib_; + lib_->~VerilatorLibrary(); } void VerilatorRuntime::SetLibrary(const std::string& lib_path) { lib_path_ = lib_path; } @@ -100,7 +100,6 @@ void VerilatorRuntime::Init(const Array& consts) { ICHECK(reset != nullptr); read_ = reinterpret_cast(lib_->GetSymbol("VerilatorRead")); ICHECK(read_ != nullptr); - add_op_ = reinterpret_cast(lib_->GetSymbol("verilator_add")); // alloc verilator device device_ = alloc(); @@ -108,7 +107,7 @@ void VerilatorRuntime::Init(const Array& consts) { // enable profiler if 
(prof_enable_) prof_ = VerilatorProfiler::ThreadLocal(); - // reset verilator device. + // reset verilator device reset(device_, reset_cycles_); CHECK_EQ(consts.size(), const_idx_.size()) @@ -136,11 +135,17 @@ void VerilatorRuntime::Run() { if (node.GetOpType() == "kernel") { CHECK_EQ(node.GetOpType(), "kernel"); auto op_name = node.GetOpName(); + auto entry = node.GetInputs()[0]; + auto shape = node.GetOpShape()[entry.index_]; if ("add" == op_name) { - auto entry = node.GetInputs()[0]; - auto shape = nodes_[entry.id_].GetOpShape()[entry.index_]; - ICHECK(add_op_ != nullptr); - add_op_(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]); + auto add = reinterpret_cast(lib_->GetSymbol("verilator_add")); + ICHECK(add != nullptr); + add(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]); + } else if ("nn.bias_add" == op_name) { + auto bias_add = + reinterpret_cast(lib_->GetSymbol("verilator_bias_add")); + ICHECK(bias_add != nullptr); + bias_add(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[3], shape[1], shape[2]); } else { LOG(FATAL) << "Unsupported op: " << op_name; } diff --git a/src/runtime/contrib/verilator/verilator_runtime.h b/src/runtime/contrib/verilator/verilator_runtime.h index acdaa3b03ce2..664a0413dfcf 100644 --- a/src/runtime/contrib/verilator/verilator_runtime.h +++ b/src/runtime/contrib/verilator/verilator_runtime.h @@ -50,8 +50,9 @@ using namespace tvm::runtime::json; typedef VerilatorHandle (*VerilatorAllocFunc)(); typedef void (*VerilatorDeallocFunc)(VerilatorHandle); typedef void (*VerilatorResetFunc)(VerilatorHandle, int); -typedef void (*VerilatorAddFunc)(VerilatorHandle, int*, int*, int*, int, int); typedef int (*VerilatorReadFunc)(VerilatorHandle, int, int); +typedef void (*VerilatorAddFunc)(VerilatorHandle, int*, int*, int*, int, int); +typedef void (*VerilatorBiasAddFunc)(VerilatorHandle, int*, int*, int*, int, int, int, int); class VerilatorLibrary : public Library { public: @@ -122,8 +123,6 @@ class 
VerilatorRuntime : public JSONRuntimeBase { VerilatorProfiler* prof_{nullptr}; /*! \brief the verilator read function */ VerilatorReadFunc read_{nullptr}; - /*! \brief the verilator add op function */ - VerilatorAddFunc add_op_{nullptr}; /*! \brief the verilator reset cycles */ int reset_cycles_{1}; /*! \brief the verilator profiler status */ diff --git a/tests/python/contrib/test_verilator/infrastructure.py b/tests/python/contrib/test_verilator/infrastructure.py index cf9f8bd4c6bc..779f7872eb2b 100644 --- a/tests/python/contrib/test_verilator/infrastructure.py +++ b/tests/python/contrib/test_verilator/infrastructure.py @@ -19,6 +19,7 @@ import os import sys import subprocess as sp +import json import tvm from tvm import relay @@ -48,6 +49,10 @@ def _func_wrapper(expr): return _func_wrapper +_register_verilator_op("add") +_register_verilator_op("nn.bias_add") + + def skip_test(): """Skip test if it requires the Verilator codegen and it's not present.""" if not tvm.get_global_func("relay.ext.verilator", True): @@ -59,8 +64,33 @@ def skip_test(): return False +def clear_stats(): + """Clear profiler statistics.""" + f = tvm.get_global_func("verilator.profiler_clear", True) + if f: + f() + + +def stats(): + """Get profiler statistics.""" + + x = tvm.get_global_func("verilator.profiler_status")() + return json.loads(x) + + def offload(mod): - """Offload ops based on the registered ops""" + """Offload ops based on the registered ops + + Parameters + ---------- + mod : Module + The input module. + + Returns + ------- + mod : Module + The output module with offloaded ops.
+ """ backend = "verilator" mod = transform.AnnotateTarget([backend])(mod) @@ -69,7 +99,7 @@ def offload(mod): def verilator_app_path(): - """Find verilator hardware app path""" + """Create verilator hardware app path.""" cur_dir = os.path.dirname(os.path.realpath(__file__)) return os.path.join( @@ -82,37 +112,87 @@ def verilator_app_path(): "vta-hw", "apps", "verilator", + "add", ) -def compile_hardware(): - """Compile hardware into shared library""" +def compile_hardware(lanes): + """Compile hardware into shared library + + Parameters + ---------- + lanes : Int + The number of vector lanes. + + Returns + ------- + path : Str + The path of the shared library. + """ + lib_name = "libverilator_{}".format(lanes) + lib_name_ext = "{}.so".format(lib_name) + lib = os.path.join(verilator_app_path(), lib_name_ext) + if not os.path.isfile(lib): + opt_lib_name = "LIB_NAME={}".format(lib_name) + opt_lanes = "LANES={}".format(lanes) + cmd = [] + cmd.append("make") + cmd.append("--directory") + cmd.append(verilator_app_path()) + cmd.append(opt_lib_name) + cmd.append(opt_lanes) + sp.run(cmd, check=True, stdout=sp.DEVNULL) + return lib + - cmd = [] - cmd.append("make") - cmd.append("--directory") - cmd.append(verilator_app_path()) - sp.run(cmd, check=True) +def compiler_opts(lib): + """Create compiler options + Parameters + ---------- + lib : Str + The path of the hardware shared library. -def compile_module(mod): - """Compile Relay module and hardware library""" + Returns + ------- + opts : Dict + The compiler options.
+ """ + opts = { + "lib_path": lib, + "profiler_enable": True, + "profiler_cycle_counter_id": 0, + } + return opts - lib = os.path.join(verilator_app_path(), "libverilator.so") - if not os.path.isfile(lib): - compile_hardware() - opts = {"lib_path": lib} +def run_module(inp, mod, params=None, opts=None): + """Compile and run Relay module - with tvm.transform.PassContext(opt_level=3, config={"relay.ext.verilator.options": opts}): - exe = relay.vm.compile(mod, target="llvm", params=None) - code, lib = exe.save() - return runtime.vm.Executable.load_exec(code, lib) + Parameters + ---------- + inp : Data + The input data. + mod : Module + The relay module. -def run_module(exe, inputs): - """Run Relay module""" + params : Parameters + The model parameters. - dev = tvm.cpu() - vm = runtime.vm.VirtualMachine(exe, dev) - return vm.run(**inputs) + opts : Dict + The compiler options. + + Returns + ------- + out : Data + The output data. + """ + + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.verilator.options": opts}): + lib = relay.vm.compile(mod, target="llvm", params=params) + code, lib = lib.save() + exe = runtime.vm.Executable.load_exec(code, lib) + vm = runtime.vm.VirtualMachine(exe, tvm.cpu()) + out = vm.run(**inp) + return out diff --git a/tests/python/contrib/test_verilator/test_mobilenet.py b/tests/python/contrib/test_verilator/test_mobilenet.py new file mode 100644 index 000000000000..8447f19141ce --- /dev/null +++ b/tests/python/contrib/test_verilator/test_mobilenet.py @@ -0,0 +1,240 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import te, relay, transform +from tvm.contrib.download import download_testdata +from tvm.contrib import graph_executor as runtime + +import os +from PIL import Image +import numpy as np + +from test_verilator.infrastructure import ( + compile_hardware, + compiler_opts, + offload, + clear_stats, + stats, +) + + +def extract(path): + """Extract a tgz or gz file. + + Parameters + ---------- + path : Str + The path of the compressed file. + """ + import tarfile + + if path.endswith("tgz") or path.endswith("gz"): + dir_path = os.path.dirname(path) + tar = tarfile.open(path) + tar.extractall(path=dir_path) + tar.close() + else: + raise RuntimeError("Could not decompress the file: " + path) + + +def get_real_image(im_height, im_width): + """Get a real image. + + Parameters + ---------- + im_height : Int + The image height. + + im_width : Int + The image width. + + Returns + ------- + data: Data + The image array.
+ """ + repo_base = "https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/" + img_name = "elephant-299.jpg" + image_url = os.path.join(repo_base, img_name) + img_path = download_testdata(image_url, img_name, module="data") + image = Image.open(img_path).resize((im_height, im_width)) + x = np.array(image).astype("uint8") + data = np.reshape(x, (1, im_height, im_width, 3)) + return data + + +def get_mobilenet_model(): + """Return mobilenet model.""" + model_url = "https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz" + model_path = download_testdata( + model_url, "mobilenet_v1_1.0_224_quant.tgz", module=["tf", "official"] + ) + model_dir = os.path.dirname(model_path) + extract(model_path) + tflite_model_file = os.path.join(model_dir, "mobilenet_v1_1.0_224_quant.tflite") + tflite_model_buf = open(tflite_model_file, "rb").read() + try: + import tflite + + return tflite.Model.GetRootAsModel(tflite_model_buf, 0) + except AttributeError: + import tflite.Model + + return tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + + +def get_input_tensor_name(): + """Return input name.""" + return "input" + + +def compile_model_to_relay(model): + """Compile model to relay. + + Parameters + ---------- + model : Model + The input model. + + Returns + ------- + mod: Module + The relay module. + + params: Parameters + The model parameters. + """ + input_tensor = get_input_tensor_name() + input_shape = (1, 224, 224, 3) + input_dtype = "uint8" + mod, params = relay.frontend.from_tflite( + model, + shape_dict={input_tensor: input_shape}, + dtype_dict={input_tensor: input_dtype}, + ) + return mod, params + + +def run_model(mod, params=None, opts=None): + """Run model. + + Parameters + ---------- + mod: Module + The relay module. + + params: Parameters + The model parameters. + + opts: Dict + The compiler options. + + Returns + ------- + out: Data + The output data.
+ """ + with transform.PassContext(opt_level=3, config={"relay.ext.verilator.options": opts}): + lib = relay.build(mod, target="llvm", params=params) + module = runtime.GraphModule(lib["default"](tvm.cpu())) + image_data = get_real_image(224, 224) + input_tensor = get_input_tensor_name() + module.set_input(input_tensor, image_data) + module.run() + out = module.get_output(0).asnumpy() + return out + + +def get_labels(): + """Return labels.""" + label_file_url = "".join( + [ + "https://raw.githubusercontent.com/", + "tensorflow/tensorflow/master/tensorflow/lite/java/demo/", + "app/src/main/assets/", + "labels_mobilenet_quant_v1_224.txt", + ] + ) + label_file = "labels_mobilenet_quant_v1_224.txt" + label_path = download_testdata(label_file_url, label_file, module="data") + # List of 1001 classes + with open(label_path) as f: + labels = f.readlines() + return labels + + +def check_result(res): + """Check prediction.""" + labels = get_labels() + predictions = np.squeeze(res) + prediction = np.argmax(predictions) + # 387 is the elephant + assert prediction == 387 + + +def print_test_info(lanes, cycles): + """Print test info + + Parameters + ---------- + lanes : Int + The number of vector lanes. + + cycles : Int + The number of cycles. + """ + print( + "[mobilenet] vector-lanes:{} number of cycles:{} spent in nn.bias_add".format(lanes, cycles) + ) + + +def is_tflite_available(): + """Return whether the tflite package can be imported.""" + try: + import tflite + + return True + except: + return False + + +def tmobilenet(lanes): + """Mobilenet test template. + Parameters + ---------- + lanes : Int + The number of vector lanes.
+ """ + if not is_tflite_available(): + return + model = get_mobilenet_model() + mod, params = compile_model_to_relay(model) + mod = offload(mod) + lib = compile_hardware(lanes) + opts = compiler_opts(lib) + clear_stats() + res = run_model(mod, params, opts) + values = stats() + check_result(res) + print_test_info(lanes, values["cycle_counter"]) + + +def test_mobilenet(): + """Mobilenet tests.""" + tmobilenet(4) + tmobilenet(32) diff --git a/tests/python/contrib/test_verilator/test_verilator_codegen.py b/tests/python/contrib/test_verilator/test_verilator_codegen.py deleted file mode 100644 index 664e254041b2..000000000000 --- a/tests/python/contrib/test_verilator/test_verilator_codegen.py +++ /dev/null @@ -1,67 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-"""Verilator codegen tests""" - -import numpy as np - -import tvm -from tvm import relay - -from test_verilator.infrastructure import ( - _register_verilator_op, - skip_test, - compile_module, - run_module, - offload, -) - - -_register_verilator_op("add") - - -def create_module_add(shape, dtype): - x = relay.var("x", shape=shape, dtype=dtype) - y = relay.var("y", shape=shape, dtype=dtype) - z = relay.add(x, y) - f = relay.Function([x, y], z) - mod = tvm.IRModule() - mod["main"] = f - return mod - - -def run_check_add(exe, shape, dtype): - x_data = np.random.randint(5, size=shape, dtype=dtype) - y_data = np.random.randint(5, size=shape, dtype=dtype) - ref = x_data + y_data - inputs = {"x": x_data, "y": y_data} - out = run_module(exe, inputs) - tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5, atol=1e-5) - - -def test_add(): - if skip_test(): - return - dtype = "int32" - shape = (8, 4) - mod = create_module_add(shape, dtype) - mod = offload(mod) - exe = compile_module(mod) - run_check_add(exe, shape, dtype) - - -if __name__ == "__main__": - test_add() diff --git a/tests/python/contrib/test_verilator/test_verilator_ops.py b/tests/python/contrib/test_verilator/test_verilator_ops.py new file mode 100644 index 000000000000..19ed1f067fed --- /dev/null +++ b/tests/python/contrib/test_verilator/test_verilator_ops.py @@ -0,0 +1,191 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Verilator codegen tests""" + +import numpy as np + +import tvm +from tvm import relay + +from test_verilator.infrastructure import ( + skip_test, + compile_hardware, + compiler_opts, + run_module, + offload, + clear_stats, + stats, +) + + +def create_module_add(shape, dtype): + """Create add module. + + Parameters + ---------- + shape : Tuple + The shape tuple. + + dtype : Str + The data type. + + Returns + ------- + mod: Module + The relay module. + """ + x = relay.var("x", shape=shape, dtype=dtype) + y = relay.var("y", shape=shape, dtype=dtype) + z = relay.add(x, y) + f = relay.Function([x, y], z) + mod = tvm.IRModule() + mod["main"] = f + return mod + + +def create_module_bias_add(xshape, yshape, dtype): + """Create bias_add module. + + Parameters + ---------- + xshape : Tuple + The x shape tuple. + + yshape : Tuple + The y shape tuple. + + dtype : Str + The data type. + + Returns + ------- + mod: Module + The relay module. + """ + x = relay.var("x", shape=xshape, dtype=dtype) + y = relay.var("y", shape=yshape, dtype=dtype) + z = relay.nn.bias_add(x, y, axis=3) + f = relay.Function([x, y], z) + mod = tvm.IRModule() + mod["main"] = f + return mod + + +def run_and_check(xshape, yshape, dtype, mod, opts): + """Run and check values. + + Parameters + ---------- + xshape : Tuple + The x shape tuple. + + yshape : Tuple + The y shape tuple. + + dtype : Str + The data type. + + mod: Module + The relay module. + + opts: Dict + The compiler options. + + Returns + ------- + cycles: Int + The number of cycles.
+ """ + x_data = np.random.randint(5, size=xshape, dtype=dtype) + y_data = np.random.randint(5, size=yshape, dtype=dtype) + ref = x_data + y_data + inp = {"x": x_data, "y": y_data} + clear_stats() + out = run_module(inp, mod, params=None, opts=opts) + values = stats() + tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5, atol=1e-5) + return values["cycle_counter"] + + +def print_test_info(test, lanes, cycles): + """Print counter + + Parameters + ---------- + test : Str + The name of the test. + + lanes : Int + The number of vector lanes. + + cycles : Int + The number of cycles. + """ + print("test:{} vector-lanes:{} number of cycles:{}".format(test, lanes, cycles)) + + +def tadd(lanes): + """Add test template. + + Parameters + ---------- + lanes : Int + The number of vector lanes. + """ + if skip_test(): + return + dtype = "int32" + shape = (8, 4) + mod = create_module_add(shape, dtype) + mod = offload(mod) + lib = compile_hardware(lanes) + opts = compiler_opts(lib) + cycles = run_and_check(shape, shape, dtype, mod, opts) + print_test_info("add", lanes, cycles) + + +def tbias(lanes): + """Bias-add test template. + + Parameters + ---------- + lanes : Int + The number of vector lanes. + """ + if skip_test(): + return + dtype = "int32" + xshape = (1, 112, 112, 32) + yshape = (32,) + mod = create_module_bias_add(xshape, yshape, dtype) + mod = offload(mod) + lib = compile_hardware(lanes) + opts = compiler_opts(lib) + cycles = run_and_check(xshape, yshape, dtype, mod, opts) + print_test_info("nn.bias_add", lanes, cycles) + + +def test_add(): + """add tests.""" + tadd(1) + tadd(4) + + +def test_bias_add(): + """bias_add tests.""" + tbias(1) + tbias(32)