From fe60bf8e37f2d3727d448b57729e33ddd0b8e3e5 Mon Sep 17 00:00:00 2001
From: huajsj
Date: Wed, 11 Aug 2021 01:37:04 -0700
Subject: [PATCH] [VTA] Make vta graph_pack compatible with latest TVM, and
 bring back object detection tutorials.

---
 vta/python/vta/top/graphpack.py            |  33 ++-
 vta/tutorials/frontend/deploy_detection.py | 324 +++++++++++++++++++++
 2 files changed, 351 insertions(+), 6 deletions(-)
 create mode 100644 vta/tutorials/frontend/deploy_detection.py

diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py
index a982b88b75e84..f15e4922b4a8b 100644
--- a/vta/python/vta/top/graphpack.py
+++ b/vta/python/vta/top/graphpack.py
@@ -56,13 +56,24 @@ def _pack_batch_channel(data, dshape, bfactor, cfactor):
     return data
 
 
-def _unpack_batch_channel(data, old_shape):
+def _unpack_batch_channel(data, old_shape, unpack_transpose=False):
     """Unpack the data channel dimension."""
-    data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
+    if unpack_transpose:
+        data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
     data = op.reshape(data, newshape=old_shape)
     return data
 
 
+def _channel_const_match(channel_length, cfactor_out):
+    """Round the channel length up when it is not divisible by cfactor_out."""
+    diff = int(channel_length) % cfactor_out
+    if diff != 0:
+        diff = cfactor_out - diff
+        channel_length = channel_length + diff
+
+    return diff, channel_length
+
+
 def _const_shape_match(data, dshape, cfactor_out):
     """Pad the constant if the shape[0] not divisible by cfactor_out."""
     assert len(dshape) == 3
@@ -299,6 +310,7 @@ def __init__(self, bfactor, cfactor, weight_bits):
         self.upsampling = op.op.get("nn.upsampling")
         self.reshape = op.op.get("reshape")
         self.number_of_conv2d = 0
+        self.unpack_transpose = True
         super().__init__()
 
     def visit_call(self, call):
@@ -319,7 +331,7 @@ def visit_call(self, call):
                 self.start_pack = False
                 data = args[0]
                 data_shape = _get_tensor_shape(call.args[0])
-                return _unpack_batch_channel(data, data_shape)
+                return _unpack_batch_channel(data, data_shape, self.unpack_transpose)
         if self.start_pack:
             # Operator cases
             if call.op == self.conv2d and odtype == "int32":
@@ -429,12 +441,12 @@ def visit_call(self, call):
                 if len(pad_width) == 6:
                     pass
                 elif len(pad_width) == 4:
-                    (data,) = args
+                    (data, pad_value) = args
                     new_pad_width = []
                     new_pad_width.extend(pad_width)
                     for _ in range(2):
                         new_pad_width.append([0, 0])
-                    return op.nn.pad(data, pad_value=call.attrs.pad_value, pad_width=new_pad_width)
+                    return op.nn.pad(data, pad_value=pad_value, pad_width=new_pad_width)
             elif call.op == self.upsampling:
                 (data,) = args
                 scale_h = call.attrs.scale_h
@@ -445,8 +457,17 @@ def visit_call(self, call):
                 return op.nn.upsampling(data, scale_h, scale_w, data_layout, method, align_corners)
             elif call.op == self.reshape and len(input_types[0].shape) == 4:
                 (data,) = args
+                self.unpack_transpose = False
                 data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
-                return op.reshape(data, [int(x) for x in input_types[0].shape])
+                new_shape = [int(x) for x in input_types[0].shape]
+                # Round the target channel up to a multiple of cfactor so it matches the packed data
+                pad, new_shape[1] = _channel_const_match(new_shape[1], self.cfactor)
+                data = op.reshape(data, new_shape)
+                # Strip the padded channels to recover the requested shape
+                if pad != 0:
+                    new_pad_width = [[0, 0], [0, -pad], [0, 0], [0, 0]]
+                    data = op.nn.pad(data, pad_width=new_pad_width)
+                return data
 
         return relay.Call(self.visit(call.op), args, call.attrs)
 
diff --git a/vta/tutorials/frontend/deploy_detection.py b/vta/tutorials/frontend/deploy_detection.py
new file mode 100644
index 0000000000000..e80b93c3bc6a7
--- /dev/null
+++ b/vta/tutorials/frontend/deploy_detection.py
@@ -0,0 +1,324 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Deploy Pretrained Vision Detection Model from Darknet on VTA
+============================================================
+**Author**: `Hua Jiang `_
+
+This tutorial provides an end-to-end demo of running Darknet YOLOv3-tiny
+inference on the VTA accelerator design to perform image detection tasks.
+It showcases Relay as a front end compiler that can perform quantization (VTA
+only supports int8/32 inference) as well as graph packing (in order to enable
+tensorization in the core) to massage the compute graph for the hardware target.
+"""
+
+######################################################################
+# Install dependencies
+# --------------------
+# To use the autotvm package in tvm, we need to install some extra dependencies
+# (change "3" to "2" if you use python2):
+#
+# .. code-block:: bash
+#
+#   pip3 install "Pillow<7"
+#
+# Parsing the YOLOv3-tiny model with Darknet also depends on the CFFI and
+# OpenCV (cv2) libraries, so we need to install them before executing this
+# script:
+#
+# .. code-block:: bash
+#
+#   pip3 install cffi
+#   pip3 install opencv-python
+#
+# Now return to the python code. Import packages.
+
+from __future__ import absolute_import, print_function
+
+import sys
+import os
+import time
+import matplotlib.pyplot as plt
+import numpy as np
+import tvm
+import vta
+from tvm import rpc, autotvm, relay
+from tvm.relay.testing import yolo_detection, darknet
+from tvm.relay.testing.darknet import __darknetffi__
+from tvm.contrib import graph_runtime, utils
+from tvm.contrib.download import download_testdata
+from vta.testing import simulator
+from vta.top import graph_pack
+
+# Make sure that TVM was compiled with RPC=1
+assert tvm.runtime.enabled("rpc")
+
+##############################################################################
+# Download yolo net configuration file, weight file, darknet library file
+# based on the model name
+# ----------------------------------------------------------------------------
+MODEL_NAME = "yolov3-tiny"
+REPO_URL = "https://github.com/dmlc/web-data/blob/master/darknet/"
+
+cfg_path = download_testdata(
+    "https://github.com/pjreddie/darknet/blob/master/cfg/" + MODEL_NAME + ".cfg" + "?raw=true",
+    MODEL_NAME + ".cfg",
+    module="darknet",
+)
+weights_path = download_testdata(
+    "https://pjreddie.com/media/files/" + MODEL_NAME + ".weights" + "?raw=true",
+    MODEL_NAME + ".weights",
+    module="darknet",
+)
+
+if sys.platform in ["linux", "linux2"]:
+    darknet_lib_path = download_testdata(
+        REPO_URL + "lib/" + "libdarknet2.0.so" + "?raw=true", "libdarknet2.0.so", module="darknet"
+    )
+elif sys.platform == "darwin":
+    darknet_lib_path = download_testdata(
+        REPO_URL + "lib_osx/" + "libdarknet_mac2.0.so" + "?raw=true",
+        "libdarknet_mac2.0.so",
+        module="darknet",
+    )
+else:
+    raise NotImplementedError("Darknet lib is not supported on {} platform".format(sys.platform))
+
+#################################################
+# Download yolo categories and illustration font.
+# ------------------------------------------------
+coco_path = download_testdata(
+    REPO_URL + "data/" + "coco.names" + "?raw=true", "coco.names", module="data"
+)
+font_path = download_testdata(
+    REPO_URL + "data/" + "arial.ttf" + "?raw=true", "arial.ttf", module="data"
+)
+with open(coco_path) as f:
+    content = f.readlines()
+names = [x.strip() for x in content]
+
+########################################
+# Define the platform and model targets.
+# --------------------------------------
+# Execute on CPU vs. VTA, and define the model.
+
+# Load VTA parameters from the vta/config/vta_config.json file
+env = vta.get_env()
+# Set ``device=arm_cpu`` to run inference on the CPU
+# or ``device=vta`` to run inference on the FPGA.
+device = "vta"
+target = env.target if device == "vta" else env.target_vta_cpu
+
+pack_dict = {
+    "yolov3-tiny": ["nn.max_pool2d", "cast", 4, 186],
+}
+
+# Name of Darknet model to compile
+# The ``start_pack`` and ``stop_pack`` labels indicate where
+# to start and end the graph packing relay pass: in other words
+# where to start and finish offloading to VTA.
+# The number 4 indicates that the ``start_pack`` index is 4, and the
+# number 186 indicates that the ``stop_pack`` index is 186. Using the
+# operator name together with its index lets us locate the correct
+# place to start/end packing when there are multiple ``nn.max_pool2d``
+# or ``cast`` operators; ``print(mod.astext(show_meta_data=False))``
+# can help to find the operator name and index information (see the
+# short sketch after the compilation steps below).
+assert MODEL_NAME in pack_dict
+
+#############################
+# Obtain an execution remote.
+# ---------------------------
+# When target is 'pynq' or other FPGA backend, reconfigure FPGA and runtime.
+# Otherwise, if target is 'sim', execute locally.
+
+if env.TARGET not in ["sim", "tsim"]:
+    # Get remote from tracker node if environment variable is set.
+    # To set up the tracker, you'll need to follow the "Auto-tuning
+    # a convolutional network for VTA" tutorial.
+    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
+    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
+    # Otherwise if you have a device you want to program directly from
+    # the host, make sure you've set the variables below to the IP of
+    # your board.
+    device_host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
+    device_port = os.environ.get("VTA_PYNQ_RPC_PORT", "9091")
+    if not tracker_host or not tracker_port:
+        remote = rpc.connect(device_host, int(device_port))
+    else:
+        remote = autotvm.measure.request_remote(
+            env.TARGET, tracker_host, int(tracker_port), timeout=10000
+        )
+    # Reconfigure the JIT runtime and FPGA.
+    # You can program the FPGA with your own custom bitstream
+    # by passing the path to the bitstream file instead of None.
+    reconfig_start = time.time()
+    vta.reconfig_runtime(remote)
+    vta.program_fpga(remote, bitstream=None)
+    reconfig_time = time.time() - reconfig_start
+    print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))
+
+# In simulation mode, host the RPC server locally.
+else:
+    remote = rpc.LocalSession()
+
+# Get execution context from remote
+ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
+
+####################################
+# Build the inference graph runtime.
+# ----------------------------------
+# Using the Darknet library, load the downloaded vision model and compile it
+# with Relay. The compilation steps are:
+#
+# 1. Front end translation from Darknet into Relay module.
+# 2. Apply 8-bit quantization: here we skip the first conv layer,
+#    and dense layer which will both be executed in fp32 on the CPU.
+# 3. Perform graph packing to alter the data layout for tensorization.
+# 4. Perform constant folding to reduce number of operators (e.g. eliminate batch norm multiply).
+# 5. Perform relay build to object file.
+# 6. Load the object file onto remote (FPGA device).
+# 7. Generate graph runtime, `m`.
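+#
+# As a short, illustrative sketch (not executed by this tutorial), the
+# ``start_pack``/``stop_pack`` names and indices listed in ``pack_dict``
+# can be located by printing the Relay module once it has been produced
+# by ``relay.frontend.from_darknet`` in the block below:
+#
+# .. code-block:: python
+#
+#    # Print the Relay IR and search for the anchor operators; the
+#    # numbering visible in the printed text helps identify which
+#    # ``nn.max_pool2d``/``cast`` calls the indices 4 and 186 refer to.
+#    print(mod.astext(show_meta_data=False))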
+
+# Load pre-configured AutoTVM schedules
+with autotvm.tophub.context(target):
+    net = __darknetffi__.dlopen(darknet_lib_path).load_network(
+        cfg_path.encode("utf-8"), weights_path.encode("utf-8"), 0
+    )
+    dshape = (env.BATCH, net.c, net.h, net.w)
+    dtype = "float32"
+
+    # Measure build start time
+    build_start = time.time()
+
+    # Start front end compilation
+    mod, params = relay.frontend.from_darknet(net, dtype=dtype, shape=dshape)
+
+    if target.device_name == "vta":
+        # Perform quantization in Relay
+        # Note: We set opt_level to 3 in order to fold batch norm
+        with relay.build_config(opt_level=3):
+            with relay.quantize.qconfig(
+                global_scale=23.0,
+                skip_conv_layers=[0],
+                store_lowbit_output=True,
+                round_for_shift=True,
+            ):
+                mod = relay.quantize.quantize(mod, params=params)
+            # Perform graph packing and constant folding for VTA target
+            mod = graph_pack(
+                mod["main"],
+                env.BATCH,
+                env.BLOCK_OUT,
+                env.WGT_WIDTH,
+                start_name=pack_dict[MODEL_NAME][0],
+                stop_name=pack_dict[MODEL_NAME][1],
+                start_name_idx=pack_dict[MODEL_NAME][2],
+                stop_name_idx=pack_dict[MODEL_NAME][3],
+            )
+    else:
+        mod = mod["main"]
+
+    # Compile Relay program with AlterOpLayout disabled
+    with vta.build_config(disabled_pass={"AlterOpLayout"}):
+        graph, lib, params = relay.build(
+            mod, target=target, params=params, target_host=env.target_host
+        )
+
+    # Measure Relay build time
+    build_time = time.time() - build_start
+    print(MODEL_NAME + " inference graph built in {0:.2f}s!".format(build_time))
+
+    # Send the inference library over to the remote RPC server
+    temp = utils.tempdir()
+    lib.save(temp.relpath("graphlib.o"))
+    remote.upload(temp.relpath("graphlib.o"))
+    lib = remote.load_module("graphlib.o")
+
+    # Graph runtime
+    m = graph_runtime.create(graph, lib, ctx)
+
+####################################
+# Perform image detection inference.
+# ----------------------------------
+# We run detection on a downloaded test image.
+# Download test image
+[neth, netw] = dshape[2:]
+test_image = "person.jpg"
+img_url = REPO_URL + "data/" + test_image + "?raw=true"
+img_path = download_testdata(img_url, test_image, "data")
+data = darknet.load_image(img_path, neth, netw).transpose(1, 2, 0)
+
+# Prepare test image for inference
+plt.imshow(data)
+plt.show()
+data = data.transpose((2, 0, 1))
+data = data[np.newaxis, :]
+data = np.repeat(data, env.BATCH, axis=0)
+
+# Set the network parameters and inputs
+m.set_input("data", data)
+m.set_input(**params)
+
+# Perform inference and gather execution statistics
+# More on: https://docs.tvm.ai/api/python/module.html#tvm.runtime.Module.time_evaluator
+num = 1  # number of times we run module for a single measurement
+rep = 1  # number of measurements (we derive std dev from this)
+timer = m.module.time_evaluator("run", ctx, number=num, repeat=rep)
+
+if env.TARGET in ["sim", "tsim"]:
+    simulator.clear_stats()
+    timer()
+    sim_stats = simulator.stats()
+    print("\nExecution statistics:")
+    for k, v in sim_stats.items():
+        # Since we execute the workload many times, we need to normalize stats
+        # Note that there is always one warm up run
+        # Therefore we divide the overall stats by (num * rep + 1)
+        print("\t{:<16}: {:>16}".format(k, v // (num * rep + 1)))
+else:
+    tcost = timer()
+    std = np.std(tcost.results) * 1000
+    mean = tcost.mean * 1000
+    print("\nPerformed inference in %.2fms (std = %.2f) for %d samples" % (mean, std, env.BATCH))
+    print("Average per sample inference time: %.2fms" % (mean / env.BATCH))
+
+# Get detection results from the output
+thresh = 0.5
+nms_thresh = 0.45
+tvm_out = []
+for i in range(2):
+    layer_out = {}
+    layer_out["type"] = "Yolo"
+    # Get the yolo layer attributes (n, out_c, out_h, out_w, classes, total)
+    layer_attr = m.get_output(i * 4 + 3).asnumpy()
+    layer_out["biases"] = m.get_output(i * 4 + 2).asnumpy()
+    layer_out["mask"] = m.get_output(i * 4 + 1).asnumpy()
+    out_shape = (layer_attr[0], layer_attr[1] // layer_attr[0], layer_attr[2], layer_attr[3])
+    layer_out["output"] = m.get_output(i * 4).asnumpy().reshape(out_shape)
+    layer_out["classes"] = layer_attr[4]
+    tvm_out.append(layer_out)
+    thresh = 0.560
+
+# Show detection results
+img = darknet.load_image_color(img_path)
+_, im_h, im_w = img.shape
+dets = yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh, 1, tvm_out)
+last_layer = net.layers[net.n - 1]
+yolo_detection.do_nms_sort(dets, last_layer.classes, nms_thresh)
+yolo_detection.show_detections(img, dets, thresh, names, last_layer.classes)
+yolo_detection.draw_detections(font_path, img, dets, thresh, names, last_layer.classes)
+plt.imshow(img.transpose(1, 2, 0))
+plt.show()
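
For reference, the following is a minimal sketch of the channel rounding performed by the
``_channel_const_match`` helper added to graphpack.py above; the channel length 255 and
``cfactor_out`` of 16 are example values chosen purely for illustration:

    # Mirror of the helper added in graphpack.py: round the channel length up
    # to the next multiple of cfactor_out and report how much padding that adds.
    def _channel_const_match(channel_length, cfactor_out):
        diff = int(channel_length) % cfactor_out
        if diff != 0:
            diff = cfactor_out - diff
            channel_length = channel_length + diff
        return diff, channel_length

    print(_channel_const_match(255, 16))  # (1, 256): one padded channel is added, then stripped
    print(_channel_const_match(256, 16))  # (0, 256): already aligned, nothing to strip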