From 481fc5dccd0e6bf5c0d23294fa94a65ada06513f Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Thu, 8 Aug 2019 17:21:31 -0700 Subject: [PATCH 01/18] adding support for graphpack over multiply op --- vta/python/vta/top/graphpack.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py index d894fc0dbec6..6087270418df 100644 --- a/vta/python/vta/top/graphpack.py +++ b/vta/python/vta/top/graphpack.py @@ -85,8 +85,8 @@ def _pack_weight_conv2d_transpose(data, dshape, cfactor): return data -def _pack_bias(data, dshape, dtype, bfactor, cfactor): - """Pack the bias parameter. +def _pack_const(data, dshape, dtype, bfactor, cfactor): + """Pack a constant parameter. """ dshape = _to_shape(dshape) assert len(dshape) == 3 @@ -124,6 +124,7 @@ def __init__(self, bfactor, cfactor, weight_bits): self.conv2d = op.op.get("nn.conv2d") self.conv2d_transpose = op.op.get("nn.conv2d_transpose") self.add = op.op.get("add") + self.multiply = op.op.get("multiply") self.bias_add = op.op.get("nn.bias_add") self.number_of_conv2d = 0 super().__init__() @@ -206,16 +207,26 @@ def visit_call(self, call): elif call.op == self.add and tuple(input_types[0].shape) == tuple(input_types[1].shape): pass elif call.op == self.add and len(input_types[1].shape) == 3: - data, bias = args - bias = _pack_bias(bias, + data, const = args + const = _pack_const(const, _to_shape(input_types[1].shape), input_types[1].dtype, self.bfactor, self.cfactor) - return relay.Call(self.add, [data, bias]) + return relay.Call(self.add, [data, const]) + elif call.op == self.multiply and tuple(input_types[0].shape) == tuple(input_types[1].shape): + pass + elif call.op == self.multiply and len(input_types[1].shape) == 3: + data, const = args + const = _pack_const(const, + _to_shape(input_types[1].shape), + input_types[1].dtype, + self.bfactor, + self.cfactor) + return relay.Call(self.multiply, [data, const]) elif self.start_pack and call.op == self.bias_add: data, bias = args - bias = _pack_bias(bias, + bias = _pack_const(bias, _to_shape(input_types[1].shape), input_types[1].dtype, self.bfactor, From f2e6bffd4aa64dba7aebce9006324d1b7b946265 Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Thu, 8 Aug 2019 17:23:55 -0700 Subject: [PATCH 02/18] increasing resnet model coverage --- .../frontend/deploy_resnet_on_vta.py | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/vta/tutorials/frontend/deploy_resnet_on_vta.py b/vta/tutorials/frontend/deploy_resnet_on_vta.py index c01f989a159d..7c020e7c4a9b 100644 --- a/vta/tutorials/frontend/deploy_resnet_on_vta.py +++ b/vta/tutorials/frontend/deploy_resnet_on_vta.py @@ -40,7 +40,7 @@ from __future__ import absolute_import, print_function -import argparse, json, os, requests, time +import argparse, json, os, requests, sys, time from io import BytesIO from os.path import join, isfile from PIL import Image @@ -53,6 +53,7 @@ from tvm import rpc, autotvm, relay from tvm.contrib import graph_runtime, util, download from tvm.contrib.debugger import debug_runtime +from tvm.relay import transform import vta from vta.testing import simulator @@ -61,6 +62,8 @@ # Make sure that TVM was compiled with RPC=1 assert tvm.module.enabled("rpc") +# Increase python recursion limit to traverse Relay program +sys.setrecursionlimit(10000) ###################################################################### # Define the platform and model targets @@ -75,13 +78,24 @@ device = "vta" target = env.target if device == "vta" else env.target_vta_cpu +# Dictionary lookup for when to start/end bit packing +# TODO(zihengjiang, tmoreau89) v2s will be supported once #3543 is merged +pack_dict = { + "resnet18_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"], + "resnet34_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"], + "resnet18_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], + "resnet34_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], + "resnet50_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], + "resnet101_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], + "resnet152_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], +} + # Name of Gluon model to compile # The ``start_pack`` and ``stop_pack`` labels indicate where # to start and end the graph packing relay pass: in other words # where to start and finish offloading to VTA. model = "resnet18_v1" -start_pack="nn.max_pool2d" -stop_pack="nn.global_avg_pool2d" +assert model in pack_dict ###################################################################### # Obtain an execution remote @@ -170,8 +184,8 @@ env.BATCH, env.BLOCK_OUT, env.WGT_WIDTH, - start_name=start_pack, - stop_name=stop_pack) + start_name=pack_dict[model][0], + stop_name=pack_dict[model][1]) # Compile Relay program with AlterOpLayout disabled with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): From dd6de0cf7cd59f8239fab67040b300e5540be8ee Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Thu, 8 Aug 2019 17:40:52 -0700 Subject: [PATCH 03/18] fix indentation --- vta/python/vta/top/graphpack.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py index 6087270418df..e66b91e58d8f 100644 --- a/vta/python/vta/top/graphpack.py +++ b/vta/python/vta/top/graphpack.py @@ -209,28 +209,28 @@ def visit_call(self, call): elif call.op == self.add and len(input_types[1].shape) == 3: data, const = args const = _pack_const(const, - _to_shape(input_types[1].shape), - input_types[1].dtype, - self.bfactor, - self.cfactor) + _to_shape(input_types[1].shape), + input_types[1].dtype, + self.bfactor, + self.cfactor) return relay.Call(self.add, [data, const]) elif call.op == self.multiply and tuple(input_types[0].shape) == tuple(input_types[1].shape): pass elif call.op == self.multiply and len(input_types[1].shape) == 3: data, const = args const = _pack_const(const, - _to_shape(input_types[1].shape), - input_types[1].dtype, - self.bfactor, - self.cfactor) + _to_shape(input_types[1].shape), + input_types[1].dtype, + self.bfactor, + self.cfactor) return relay.Call(self.multiply, [data, const]) elif self.start_pack and call.op == self.bias_add: data, bias = args bias = _pack_const(bias, - _to_shape(input_types[1].shape), - input_types[1].dtype, - self.bfactor, - self.cfactor) + _to_shape(input_types[1].shape), + input_types[1].dtype, + self.bfactor, + self.cfactor) return relay.Call(self.add, [data, bias]) elif self.start_pack and call.op == op.op.get('cast') and \ input_types[0].dtype == 'int32': From c2f0b3431f58317154753905cda0dd9bb630c01a Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Thu, 8 Aug 2019 18:43:14 -0700 Subject: [PATCH 04/18] lint --- vta/python/vta/top/graphpack.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py index e66b91e58d8f..a4c054880ac2 100644 --- a/vta/python/vta/top/graphpack.py +++ b/vta/python/vta/top/graphpack.py @@ -204,7 +204,8 @@ def visit_call(self, call): output_padding=call.attrs.output_padding, out_dtype=call.attrs.out_dtype) return conv2d - elif call.op == self.add and tuple(input_types[0].shape) == tuple(input_types[1].shape): + elif call.op == self.add and \ + tuple(input_types[0].shape) == tuple(input_types[1].shape): pass elif call.op == self.add and len(input_types[1].shape) == 3: data, const = args @@ -214,7 +215,8 @@ def visit_call(self, call): self.bfactor, self.cfactor) return relay.Call(self.add, [data, const]) - elif call.op == self.multiply and tuple(input_types[0].shape) == tuple(input_types[1].shape): + elif call.op == self.multiply and \ + tuple(input_types[0].shape) == tuple(input_types[1].shape): pass elif call.op == self.multiply and len(input_types[1].shape) == 3: data, const = args From 8d6be27bc9ddcc96c2f022bc0f531671c752d8a3 Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Fri, 9 Aug 2019 13:56:39 -0700 Subject: [PATCH 05/18] moving recursion limit fix into graphpack pass --- vta/python/vta/top/graphpack.py | 10 ++++++++++ vta/tutorials/frontend/deploy_resnet_on_vta.py | 3 --- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py index a4c054880ac2..169946c8ee5b 100644 --- a/vta/python/vta/top/graphpack.py +++ b/vta/python/vta/top/graphpack.py @@ -319,7 +319,13 @@ def graph_pack(expr, expr : Expr The transformed expression. """ + import sys assert isinstance(expr, relay.Function) + + # Increase python recursion limit to traverse Relay program + oldrecursionlimit = sys.getrecursionlimit() + sys.setrecursionlimit(10000) + expr = get_subgraph(expr, start_name, stop_name) expr = run_opt_pass(expr, transform.InferType()) packer = ExprPack( @@ -327,4 +333,8 @@ def graph_pack(expr, weight_bits) expr = packer.visit(expr) assert not packer.start_pack + + # Restore recursion limit + sys.setrecursionlimit(oldrecursionlimit) + return run_opt_pass(expr, transform.InferType()) diff --git a/vta/tutorials/frontend/deploy_resnet_on_vta.py b/vta/tutorials/frontend/deploy_resnet_on_vta.py index 7c020e7c4a9b..dfd360a953d7 100644 --- a/vta/tutorials/frontend/deploy_resnet_on_vta.py +++ b/vta/tutorials/frontend/deploy_resnet_on_vta.py @@ -62,9 +62,6 @@ # Make sure that TVM was compiled with RPC=1 assert tvm.module.enabled("rpc") -# Increase python recursion limit to traverse Relay program -sys.setrecursionlimit(10000) - ###################################################################### # Define the platform and model targets # ------------------------------------- From a8307b51969ec3297241a2ef89482144a9073e23 Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Fri, 9 Aug 2019 14:54:33 -0700 Subject: [PATCH 06/18] moving recursionlimit to relay init --- python/tvm/relay/__init__.py | 5 +++++ vta/python/vta/top/graphpack.py | 9 --------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 82712442a6b6..bbc742b76a6a 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -59,6 +59,10 @@ from .scope_builder import ScopeBuilder +# Required to traverse large programs +from sys import setrecursionlimit +setrecursionlimit(10000) + # Span Span = base.Span @@ -136,3 +140,4 @@ # Feature Feature = feature.Feature + diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py index 169946c8ee5b..53084eb4cc59 100644 --- a/vta/python/vta/top/graphpack.py +++ b/vta/python/vta/top/graphpack.py @@ -321,11 +321,6 @@ def graph_pack(expr, """ import sys assert isinstance(expr, relay.Function) - - # Increase python recursion limit to traverse Relay program - oldrecursionlimit = sys.getrecursionlimit() - sys.setrecursionlimit(10000) - expr = get_subgraph(expr, start_name, stop_name) expr = run_opt_pass(expr, transform.InferType()) packer = ExprPack( @@ -333,8 +328,4 @@ def graph_pack(expr, weight_bits) expr = packer.visit(expr) assert not packer.start_pack - - # Restore recursion limit - sys.setrecursionlimit(oldrecursionlimit) - return run_opt_pass(expr, transform.InferType()) From 62eaec8d7f841c854c0b47f41e505ca014fd27f6 Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Fri, 9 Aug 2019 17:02:18 -0700 Subject: [PATCH 07/18] trailing line --- python/tvm/relay/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index bbc742b76a6a..db4ea084522f 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -140,4 +140,3 @@ # Feature Feature = feature.Feature - From 0cb0aca82f31481f0f978ee4e99cb85f8a26a1ca Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Fri, 9 Aug 2019 15:05:25 -0700 Subject: [PATCH 08/18] pooling on NCHWnc format --- src/relay/op/nn/pooling.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 72de07173d35..06502c4ca9d0 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -161,9 +161,12 @@ Array Pool2DCompute(const Attrs& attrs, CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) << "max_pool2d does not support input split on width"; - CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) + CHECK(inputs[0].ndim() == 4U || + inputs[0].ndim() == 5U || + inputs[0].ndim() == 6U) << "Pool2D only support 4-D input (e.g., NCHW)" - << " or 5-D input (last dimension is a split of channel)"; + << " or 5-D input (last dimension is a split of channel)" + << " or 6-D input (last 2 dimensions are split of batch and channel)"; if (param->padding.size() == 1) { padding.push_back(padding[0]); From 0462406d6f357e3b8c8a853ebbce0fa6a5ce568c Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Fri, 9 Aug 2019 15:06:14 -0700 Subject: [PATCH 09/18] adding more models --- vta/tutorials/frontend/deploy_resnet_on_vta.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vta/tutorials/frontend/deploy_resnet_on_vta.py b/vta/tutorials/frontend/deploy_resnet_on_vta.py index dfd360a953d7..016bbdfe6ac2 100644 --- a/vta/tutorials/frontend/deploy_resnet_on_vta.py +++ b/vta/tutorials/frontend/deploy_resnet_on_vta.py @@ -76,8 +76,9 @@ target = env.target if device == "vta" else env.target_vta_cpu # Dictionary lookup for when to start/end bit packing -# TODO(zihengjiang, tmoreau89) v2s will be supported once #3543 is merged +# TODO(zihengjiang, tmoreau89) some quantization will break until #3543 is merged pack_dict = { + "alexnet": ["nn.max_pool2d", "nn.batch_flatten"], "resnet18_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"], "resnet34_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"], "resnet18_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], @@ -85,6 +86,10 @@ "resnet50_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], "resnet101_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], "resnet152_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], + "vgg11": ["nn.max_pool2d", "nn.batch_flatten"], + "vgg13": ["nn.max_pool2d", "nn.batch_flatten"], + "vgg16": ["nn.max_pool2d", "nn.batch_flatten"], + "vgg19": ["nn.max_pool2d", "nn.batch_flatten"], } # Name of Gluon model to compile From 80f3a23ac8ff4f44c864743b815c3253cfd94f56 Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Fri, 9 Aug 2019 15:10:28 -0700 Subject: [PATCH 10/18] deploy_resnet_on_vta.py --- .../frontend/deploy_resnet_on_vta.py | 12 +- .../frontend/deploy_vision_on_vta.py | 291 ++++++++++++++++++ 2 files changed, 297 insertions(+), 6 deletions(-) create mode 100644 vta/tutorials/frontend/deploy_vision_on_vta.py diff --git a/vta/tutorials/frontend/deploy_resnet_on_vta.py b/vta/tutorials/frontend/deploy_resnet_on_vta.py index 016bbdfe6ac2..f14a0e36b039 100644 --- a/vta/tutorials/frontend/deploy_resnet_on_vta.py +++ b/vta/tutorials/frontend/deploy_resnet_on_vta.py @@ -15,12 +15,12 @@ # specific language governing permissions and limitations # under the License. """ -Deploy Pretrained ResNet Model from MxNet on VTA +Deploy Pretrained Vision Model from MxNet on VTA ================================================ **Author**: `Thierry Moreau `_ -This tutorial provides an end-to-end demo, on how to run ResNet-18 inference -onto the VTA accelerator design to perform ImageNet classification tasks. +This tutorial provides an end-to-end demo, on how to run ImageNet classification +inference onto the VTA accelerator design to perform ImageNet classification tasks. It showcases Relay as a front end compiler that can perform quantization (VTA only supports int8/32 inference) as well as graph packing (in order to enable tensorization in the core) to massage the compute graph for the hardware target. @@ -141,7 +141,7 @@ ###################################################################### # Build the inference graph runtime # --------------------------------- -# Grab ResNet-18 model from Gluon model zoo and compile with Relay. +# Grab vision model from Gluon model zoo and compile with Relay. # The compilation steps are: # 1) Front end translation from MxNet into Relay module. # 2) Apply 8-bit quantization: here we skip the first conv layer, @@ -156,7 +156,7 @@ # Load pre-configured AutoTVM schedules with autotvm.tophub.context(target): - # Populate the shape and data type dictionary for ResNet input + # Populate the shape and data type dictionary for ImageNet input dtype_dict = {"data": 'float32'} shape_dict = {"data": (env.BATCH, 3, 224, 224)} @@ -215,7 +215,7 @@ m = graph_runtime.create(graph, lib, ctx) ###################################################################### -# Perform ResNet-18 inference +# Perform image classification # --------------------------- # We run classification on an image sample from ImageNet # We just need to download the categories files, `synset.txt` diff --git a/vta/tutorials/frontend/deploy_vision_on_vta.py b/vta/tutorials/frontend/deploy_vision_on_vta.py new file mode 100644 index 000000000000..8229de1d1140 --- /dev/null +++ b/vta/tutorials/frontend/deploy_vision_on_vta.py @@ -0,0 +1,291 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Deploy Pretrained ResNet Model from MxNet on VTA +================================================ +**Author**: `Thierry Moreau `_ + +This tutorial provides an end-to-end demo, on how to run ResNet-18 inference +onto the VTA accelerator design to perform ImageNet classification tasks. +It showcases Relay as a front end compiler that can perform quantization (VTA +only supports int8/32 inference) as well as graph packing (in order to enable +tensorization in the core) to massage the compute graph for the hardware target. +""" + +###################################################################### +# Install dependencies +# -------------------- +# To use the autotvm package in tvm, we need to install some extra dependencies. +# (change "3" to "2" if you use python2): +# +# .. code-block:: bash +# +# pip3 install --user mxnet requests pillow +# +# Now return to the python code. Import packages. + +from __future__ import absolute_import, print_function + +import argparse, json, os, requests, sys, time +from io import BytesIO +from os.path import join, isfile +from PIL import Image + +from mxnet.gluon.model_zoo import vision +import numpy as np +from matplotlib import pyplot as plt + +import tvm +from tvm import rpc, autotvm, relay +from tvm.contrib import graph_runtime, util, download +from tvm.contrib.debugger import debug_runtime +from tvm.relay import transform + +import vta +from vta.testing import simulator +from vta.top import graph_pack + +# Make sure that TVM was compiled with RPC=1 +assert tvm.module.enabled("rpc") + +###################################################################### +# Define the platform and model targets +# ------------------------------------- +# Execute on CPU vs. VTA, and define the model. + +# Load VTA parameters from the vta/config/vta_config.json file +env = vta.get_env() + +# Set ``device=arm_cpu`` to run inference on the CPU +# or ``device=vta`` to run inference on the FPGA. +device = "vta" +target = env.target if device == "vta" else env.target_vta_cpu + +# Dictionary lookup for when to start/end bit packing +# TODO(zihengjiang, tmoreau89) some quantization will break until #3543 is merged +pack_dict = { + "alexnet": ["nn.max_pool2d", "nn.batch_flatten"], + "resnet18_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"], + "resnet34_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"], + "resnet18_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], + "resnet34_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], + "resnet50_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], + "resnet101_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], + "resnet152_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], + "vgg11": ["nn.max_pool2d", "nn.batch_flatten"], + "vgg13": ["nn.max_pool2d", "nn.batch_flatten"], + "vgg16": ["nn.max_pool2d", "nn.batch_flatten"], + "vgg19": ["nn.max_pool2d", "nn.batch_flatten"], +} + +# Name of Gluon model to compile +# The ``start_pack`` and ``stop_pack`` labels indicate where +# to start and end the graph packing relay pass: in other words +# where to start and finish offloading to VTA. +model = "resnet18_v1" +assert model in pack_dict + +###################################################################### +# Obtain an execution remote +# --------------------------------- +# When target is 'pynq', reconfigure FPGA and runtime. +# Otherwise, if target is 'sim', execute locally. + +if env.TARGET not in ["sim", "tsim"]: + + # Get remote from tracker node if environment variable is set. + # To set up the tracker, you'll need to follow the "Auto-tuning + # a convolutional network for VTA" tutorial. + tracker_host = os.environ.get("TVM_TRACKER_HOST", None) + tracker_port = int(os.environ.get("TVM_TRACKER_PORT", None)) + # Otherwise if you have a device you want to program directly from + # the host, make sure you've set the variables below to the IP of + # your board. + device_host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99") + device_port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091")) + if not tracker_host or not tracker_port: + remote = rpc.connect(device_host, device_port) + else: + remote = autotvm.measure.request_remote(env.TARGET, tracker_host, tracker_port, timeout=10000) + + # Reconfigure the JIT runtime and FPGA. + # You can program the FPGA with your own custom bitstream + # by passing the path to the bitstream file instead of None. + reconfig_start = time.time() + vta.reconfig_runtime(remote) + vta.program_fpga(remote, bitstream=None) + reconfig_time = time.time() - reconfig_start + print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time)) + +# In simulation mode, host the RPC server locally. +else: + remote = rpc.LocalSession() + +# Get execution context from remote +ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0) + +###################################################################### +# Build the inference graph runtime +# --------------------------------- +# Grab ResNet-18 model from Gluon model zoo and compile with Relay. +# The compilation steps are: +# 1) Front end translation from MxNet into Relay module. +# 2) Apply 8-bit quantization: here we skip the first conv layer, +# and dense layer which will both be executed in fp32 on the CPU. +# 3) Perform graph packing to alter the data layout for tensorization. +# 4) Perform constant folding to reduce number of operators (e.g. eliminate +# batch norm multiply). +# 5) Perform relay build to object file. +# 6) Load the object file onto remote (FPGA device). +# 7) Generate graph runtime, `m`. + +# Load pre-configured AutoTVM schedules +with autotvm.tophub.context(target): + + # Populate the shape and data type dictionary for ResNet input + dtype_dict = {"data": 'float32'} + shape_dict = {"data": (env.BATCH, 3, 224, 224)} + + # Get off the shelf gluon model, and convert to relay + gluon_model = vision.get_model(model, pretrained=True) + + # Measure build start time + build_start = time.time() + + # Start front end compilation + mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict) + + # Update shape and type dictionary + shape_dict.update({k: v.shape for k, v in params.items()}) + dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) + + # Perform quantization in Relay + with relay.quantize.qconfig(global_scale=8.0, + skip_conv_layers=[0]): + relay_prog = relay.quantize.quantize(mod["main"], params=params) + + # Perform graph packing and constant folding for VTA target + if target.device_name == "vta": + assert env.BLOCK_IN == env.BLOCK_OUT + relay_prog = graph_pack( + relay_prog, + env.BATCH, + env.BLOCK_OUT, + env.WGT_WIDTH, + start_name=pack_dict[model][0], + stop_name=pack_dict[model][1]) + + # Compile Relay program with AlterOpLayout disabled + with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): + if target.device_name != "vta": + graph, lib, params = relay.build( + relay_prog, target=target, + params=params, target_host=env.target_host) + else: + with vta.build_config(): + graph, lib, params = relay.build( + relay_prog, target=target, + params=params, target_host=env.target_host) + + # Measure Relay build time + build_time = time.time() - build_start + print(model + " inference graph built in {0:.2f}s!".format(build_time)) + + # Send the inference library over to the remote RPC server + temp = util.tempdir() + lib.save(temp.relpath("graphlib.o")) + remote.upload(temp.relpath("graphlib.o")) + lib = remote.load_module("graphlib.o") + + # Graph runtime + m = graph_runtime.create(graph, lib, ctx) + +###################################################################### +# Perform ResNet-18 inference +# --------------------------- +# We run classification on an image sample from ImageNet +# We just need to download the categories files, `synset.txt` +# and an input test image. + +# Download ImageNet categories +categ_url = "https://github.com/uwsaml/web-data/raw/master/vta/models/" +categ_fn = "synset.txt" +download.download(join(categ_url, categ_fn), categ_fn) +synset = eval(open(categ_fn).read()) + +# Download test image +image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg' +response = requests.get(image_url) + +# Prepare test image for inference +image = Image.open(BytesIO(response.content)).resize((224, 224)) +plt.imshow(image) +plt.show() +image = np.array(image) - np.array([123., 117., 104.]) +image /= np.array([58.395, 57.12, 57.375]) +image = image.transpose((2, 0, 1)) +image = image[np.newaxis, :] +image = np.repeat(image, env.BATCH, axis=0) + +# Set the network parameters and inputs +m.set_input(**params) +m.set_input('data', image) + +# Perform inference and gather execution statistics +# More on: https://docs.tvm.ai/api/python/module.html#tvm.module.Module.time_evaluator +num = 4 # number of times we run module for a single measurement +rep = 3 # number of measurements (we derive std dev from this) +timer = m.module.time_evaluator("run", ctx, number=num, repeat=rep) + +if env.TARGET in ["sim", "tsim"]: + simulator.clear_stats() + timer() + sim_stats = simulator.stats() + print("\nExecution statistics:") + for k, v in sim_stats.items(): + # Since we execute the workload many times, we need to normalize stats + # Note that there is always one warm up run + # Therefore we divide the overall stats by (num * rep + 1) + print("\t{:<16}: {:>16}".format(k, v // (num * rep + 1))) +else: + tcost = timer() + std = np.std(tcost.results) * 1000 / env.BATCH + mean = tcost.mean * 1000 / env.BATCH + print("\nPerformed inference in %.2fms/sample (std = %.2f)" % (mean, std)) + +# Get classification results +tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0))) +top_categories = np.argsort(tvm_output.asnumpy()[0]) + +# Report top-5 classification results +print("\n%s prediction" % model) +print("\t#1:", synset[top_categories[-1]]) +print("\t#2:", synset[top_categories[-2]]) +print("\t#3:", synset[top_categories[-3]]) +print("\t#4:", synset[top_categories[-4]]) +print("\t#5:", synset[top_categories[-5]]) + +# This just checks that one of the 5 top categories +# is one variety of cat; this is by no means an accurate +# assessment of how quantization affects classification +# accuracy but is meant to catch changes to the +# quantization pass that would accuracy in the CI. +cat_detected = False +for k in top_categories[-5:]: + if "cat" in synset[k]: + cat_detected = True +assert(cat_detected) From 0346408399309894663459bd73bb9eac9d23e12c Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Sat, 10 Aug 2019 18:57:14 -0700 Subject: [PATCH 11/18] generalizing to vision models --- .../frontend/deploy_resnet_on_vta.py | 293 ------------------ .../frontend/deploy_vision_on_vta.py | 14 +- 2 files changed, 7 insertions(+), 300 deletions(-) delete mode 100644 vta/tutorials/frontend/deploy_resnet_on_vta.py diff --git a/vta/tutorials/frontend/deploy_resnet_on_vta.py b/vta/tutorials/frontend/deploy_resnet_on_vta.py deleted file mode 100644 index f14a0e36b039..000000000000 --- a/vta/tutorials/frontend/deploy_resnet_on_vta.py +++ /dev/null @@ -1,293 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Deploy Pretrained Vision Model from MxNet on VTA -================================================ -**Author**: `Thierry Moreau `_ - -This tutorial provides an end-to-end demo, on how to run ImageNet classification -inference onto the VTA accelerator design to perform ImageNet classification tasks. -It showcases Relay as a front end compiler that can perform quantization (VTA -only supports int8/32 inference) as well as graph packing (in order to enable -tensorization in the core) to massage the compute graph for the hardware target. -""" - -###################################################################### -# Install dependencies -# -------------------- -# To use the autotvm package in tvm, we need to install some extra dependencies. -# (change "3" to "2" if you use python2): -# -# .. code-block:: bash -# -# pip3 install --user mxnet requests pillow -# -# Now return to the python code. Import packages. - -from __future__ import absolute_import, print_function - -import argparse, json, os, requests, sys, time -from io import BytesIO -from os.path import join, isfile -from PIL import Image - -from mxnet.gluon.model_zoo import vision -import numpy as np -from matplotlib import pyplot as plt - -import tvm -from tvm import rpc, autotvm, relay -from tvm.contrib import graph_runtime, util, download -from tvm.contrib.debugger import debug_runtime -from tvm.relay import transform - -import vta -from vta.testing import simulator -from vta.top import graph_pack - -# Make sure that TVM was compiled with RPC=1 -assert tvm.module.enabled("rpc") - -###################################################################### -# Define the platform and model targets -# ------------------------------------- -# Execute on CPU vs. VTA, and define the model. - -# Load VTA parameters from the vta/config/vta_config.json file -env = vta.get_env() - -# Set ``device=arm_cpu`` to run inference on the CPU -# or ``device=vta`` to run inference on the FPGA. -device = "vta" -target = env.target if device == "vta" else env.target_vta_cpu - -# Dictionary lookup for when to start/end bit packing -# TODO(zihengjiang, tmoreau89) some quantization will break until #3543 is merged -pack_dict = { - "alexnet": ["nn.max_pool2d", "nn.batch_flatten"], - "resnet18_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"], - "resnet34_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"], - "resnet18_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], - "resnet34_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], - "resnet50_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], - "resnet101_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], - "resnet152_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], - "vgg11": ["nn.max_pool2d", "nn.batch_flatten"], - "vgg13": ["nn.max_pool2d", "nn.batch_flatten"], - "vgg16": ["nn.max_pool2d", "nn.batch_flatten"], - "vgg19": ["nn.max_pool2d", "nn.batch_flatten"], -} - -# Name of Gluon model to compile -# The ``start_pack`` and ``stop_pack`` labels indicate where -# to start and end the graph packing relay pass: in other words -# where to start and finish offloading to VTA. -model = "resnet18_v1" -assert model in pack_dict - -###################################################################### -# Obtain an execution remote -# --------------------------------- -# When target is 'pynq', reconfigure FPGA and runtime. -# Otherwise, if target is 'sim', execute locally. - -if env.TARGET not in ["sim", "tsim"]: - - # Get remote from tracker node if environment variable is set. - # To set up the tracker, you'll need to follow the "Auto-tuning - # a convolutional network for VTA" tutorial. - tracker_host = os.environ.get("TVM_TRACKER_HOST", None) - tracker_port = os.environ.get("TVM_TRACKER_PORT", None) - # Otherwise if you have a device you want to program directly from - # the host, make sure you've set the variables below to the IP of - # your board. - device_host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99") - device_port = os.environ.get("VTA_PYNQ_RPC_PORT", "9091") - if not tracker_host or not tracker_port: - remote = rpc.connect(device_host, int(device_port)) - else: - remote = autotvm.measure.request_remote(env.TARGET, tracker_host, int(tracker_port), timeout=10000) - - # Reconfigure the JIT runtime and FPGA. - # You can program the FPGA with your own custom bitstream - # by passing the path to the bitstream file instead of None. - reconfig_start = time.time() - vta.reconfig_runtime(remote) - vta.program_fpga(remote, bitstream=None) - reconfig_time = time.time() - reconfig_start - print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time)) - -# In simulation mode, host the RPC server locally. -else: - remote = rpc.LocalSession() - -# Get execution context from remote -ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0) - -###################################################################### -# Build the inference graph runtime -# --------------------------------- -# Grab vision model from Gluon model zoo and compile with Relay. -# The compilation steps are: -# 1) Front end translation from MxNet into Relay module. -# 2) Apply 8-bit quantization: here we skip the first conv layer, -# and dense layer which will both be executed in fp32 on the CPU. -# 3) Perform graph packing to alter the data layout for tensorization. -# 4) Perform constant folding to reduce number of operators (e.g. eliminate -# batch norm multiply). -# 5) Perform relay build to object file. -# 6) Load the object file onto remote (FPGA device). -# 7) Generate graph runtime, `m`. - -# Load pre-configured AutoTVM schedules -with autotvm.tophub.context(target): - - # Populate the shape and data type dictionary for ImageNet input - dtype_dict = {"data": 'float32'} - shape_dict = {"data": (env.BATCH, 3, 224, 224)} - - # Get off the shelf gluon model, and convert to relay - gluon_model = vision.get_model(model, pretrained=True) - - # Measure build start time - build_start = time.time() - - # Start front end compilation - mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict) - - # Update shape and type dictionary - shape_dict.update({k: v.shape for k, v in params.items()}) - dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) - - # Perform quantization in Relay - with relay.quantize.qconfig(global_scale=8.0, - skip_conv_layers=[0]): - relay_prog = relay.quantize.quantize(mod["main"], params=params) - - # Perform graph packing and constant folding for VTA target - if target.device_name == "vta": - assert env.BLOCK_IN == env.BLOCK_OUT - relay_prog = graph_pack( - relay_prog, - env.BATCH, - env.BLOCK_OUT, - env.WGT_WIDTH, - start_name=pack_dict[model][0], - stop_name=pack_dict[model][1]) - - # Compile Relay program with AlterOpLayout disabled - with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): - if target.device_name != "vta": - graph, lib, params = relay.build( - relay_prog, target=target, - params=params, target_host=env.target_host) - else: - with vta.build_config(): - graph, lib, params = relay.build( - relay_prog, target=target, - params=params, target_host=env.target_host) - - # Measure Relay build time - build_time = time.time() - build_start - print(model + " inference graph built in {0:.2f}s!".format(build_time)) - - # Send the inference library over to the remote RPC server - temp = util.tempdir() - lib.save(temp.relpath("graphlib.o")) - remote.upload(temp.relpath("graphlib.o")) - lib = remote.load_module("graphlib.o") - - # Graph runtime - m = graph_runtime.create(graph, lib, ctx) - -###################################################################### -# Perform image classification -# --------------------------- -# We run classification on an image sample from ImageNet -# We just need to download the categories files, `synset.txt` -# and an input test image. - -# Download ImageNet categories -categ_url = "https://github.com/uwsaml/web-data/raw/master/vta/models/" -categ_fn = "synset.txt" -download.download(join(categ_url, categ_fn), categ_fn) -synset = eval(open(categ_fn).read()) - -# Download test image -image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg' -response = requests.get(image_url) - -# Prepare test image for inference -image = Image.open(BytesIO(response.content)).resize((224, 224)) -plt.imshow(image) -plt.show() -image = np.array(image) - np.array([123., 117., 104.]) -image /= np.array([58.395, 57.12, 57.375]) -image = image.transpose((2, 0, 1)) -image = image[np.newaxis, :] -image = np.repeat(image, env.BATCH, axis=0) - -# Set the network parameters and inputs -m.set_input(**params) -m.set_input('data', image) - -# Perform inference and gather execution statistics -# More on: https://docs.tvm.ai/api/python/module.html#tvm.module.Module.time_evaluator -num = 4 # number of times we run module for a single measurement -rep = 3 # number of measurements (we derive std dev from this) -timer = m.module.time_evaluator("run", ctx, number=num, repeat=rep) - -if env.TARGET in ["sim", "tsim"]: - simulator.clear_stats() - timer() - sim_stats = simulator.stats() - print("\nExecution statistics:") - for k, v in sim_stats.items(): - # Since we execute the workload many times, we need to normalize stats - # Note that there is always one warm up run - # Therefore we divide the overall stats by (num * rep + 1) - print("\t{:<16}: {:>16}".format(k, v // (num * rep + 1))) -else: - tcost = timer() - std = np.std(tcost.results) * 1000 - mean = tcost.mean * 1000 - print("\nPerformed inference in %.2fms (std = %.2f) for %d samples" % (mean, std, env.BATCH)) - print("Average per sample inference time: %.2fms" % (mean/env.BATCH)) - -# Get classification results -tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0))) -for b in range(env.BATCH): - top_categories = np.argsort(tvm_output.asnumpy()[b]) - - # Report top-5 classification results - print("\n{} prediction for sample {}".format(model, b)) - print("\t#1:", synset[top_categories[-1]]) - print("\t#2:", synset[top_categories[-2]]) - print("\t#3:", synset[top_categories[-3]]) - print("\t#4:", synset[top_categories[-4]]) - print("\t#5:", synset[top_categories[-5]]) - - # This just checks that one of the 5 top categories - # is one variety of cat; this is by no means an accurate - # assessment of how quantization affects classification - # accuracy but is meant to catch changes to the - # quantization pass that would accuracy in the CI. - cat_detected = False - for k in top_categories[-5:]: - if "cat" in synset[k]: - cat_detected = True - assert(cat_detected) diff --git a/vta/tutorials/frontend/deploy_vision_on_vta.py b/vta/tutorials/frontend/deploy_vision_on_vta.py index 8229de1d1140..431a8f1f8857 100644 --- a/vta/tutorials/frontend/deploy_vision_on_vta.py +++ b/vta/tutorials/frontend/deploy_vision_on_vta.py @@ -15,12 +15,12 @@ # specific language governing permissions and limitations # under the License. """ -Deploy Pretrained ResNet Model from MxNet on VTA +Deploy Pretrained Vision Model from MxNet on VTA ================================================ **Author**: `Thierry Moreau `_ -This tutorial provides an end-to-end demo, on how to run ResNet-18 inference -onto the VTA accelerator design to perform ImageNet classification tasks. +This tutorial provides an end-to-end demo, on how to run ImageNet classification +inference onto the VTA accelerator design to perform ImageNet classification tasks. It showcases Relay as a front end compiler that can perform quantization (VTA only supports int8/32 inference) as well as graph packing (in order to enable tensorization in the core) to massage the compute graph for the hardware target. @@ -141,7 +141,7 @@ ###################################################################### # Build the inference graph runtime # --------------------------------- -# Grab ResNet-18 model from Gluon model zoo and compile with Relay. +# Grab vision model from Gluon model zoo and compile with Relay. # The compilation steps are: # 1) Front end translation from MxNet into Relay module. # 2) Apply 8-bit quantization: here we skip the first conv layer, @@ -156,7 +156,7 @@ # Load pre-configured AutoTVM schedules with autotvm.tophub.context(target): - # Populate the shape and data type dictionary for ResNet input + # Populate the shape and data type dictionary for ImageNet classifier input dtype_dict = {"data": 'float32'} shape_dict = {"data": (env.BATCH, 3, 224, 224)} @@ -215,8 +215,8 @@ m = graph_runtime.create(graph, lib, ctx) ###################################################################### -# Perform ResNet-18 inference -# --------------------------- +# Perform image classification inference +# -------------------------------------- # We run classification on an image sample from ImageNet # We just need to download the categories files, `synset.txt` # and an input test image. From 8eafee1786c4d9a3ffdaee26270e14c7fd05bd53 Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Thu, 15 Aug 2019 11:59:26 -0700 Subject: [PATCH 12/18] merge conflicts --- vta/python/vta/top/graphpack.py | 1 - .../frontend/deploy_vision_on_vta.py | 53 +++++++++---------- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py index 53084eb4cc59..a4c054880ac2 100644 --- a/vta/python/vta/top/graphpack.py +++ b/vta/python/vta/top/graphpack.py @@ -319,7 +319,6 @@ def graph_pack(expr, expr : Expr The transformed expression. """ - import sys assert isinstance(expr, relay.Function) expr = get_subgraph(expr, start_name, stop_name) expr = run_opt_pass(expr, transform.InferType()) diff --git a/vta/tutorials/frontend/deploy_vision_on_vta.py b/vta/tutorials/frontend/deploy_vision_on_vta.py index 431a8f1f8857..18e74bca2a35 100644 --- a/vta/tutorials/frontend/deploy_vision_on_vta.py +++ b/vta/tutorials/frontend/deploy_vision_on_vta.py @@ -76,7 +76,6 @@ target = env.target if device == "vta" else env.target_vta_cpu # Dictionary lookup for when to start/end bit packing -# TODO(zihengjiang, tmoreau89) some quantization will break until #3543 is merged pack_dict = { "alexnet": ["nn.max_pool2d", "nn.batch_flatten"], "resnet18_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"], @@ -111,16 +110,16 @@ # To set up the tracker, you'll need to follow the "Auto-tuning # a convolutional network for VTA" tutorial. tracker_host = os.environ.get("TVM_TRACKER_HOST", None) - tracker_port = int(os.environ.get("TVM_TRACKER_PORT", None)) + tracker_port = os.environ.get("TVM_TRACKER_PORT", None) # Otherwise if you have a device you want to program directly from # the host, make sure you've set the variables below to the IP of # your board. device_host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99") device_port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091")) if not tracker_host or not tracker_port: - remote = rpc.connect(device_host, device_port) + remote = rpc.connect(device_host, int(device_port)) else: - remote = autotvm.measure.request_remote(env.TARGET, tracker_host, tracker_port, timeout=10000) + remote = autotvm.measure.request_remote(env.TARGET, tracker_host, int(tracker_port), timeout=10000) # Reconfigure the JIT runtime and FPGA. # You can program the FPGA with your own custom bitstream @@ -263,29 +262,29 @@ print("\t{:<16}: {:>16}".format(k, v // (num * rep + 1))) else: tcost = timer() - std = np.std(tcost.results) * 1000 / env.BATCH - mean = tcost.mean * 1000 / env.BATCH - print("\nPerformed inference in %.2fms/sample (std = %.2f)" % (mean, std)) + std = np.std(tcost.results) * 1000 + mean = tcost.mean * 1000 + print("\nPerformed inference in %.2fms (std = %.2f) for %d samples" % (mean, std, env.BATCH)) + print("Average per sample inference time: %.2fms" % (mean/env.BATCH)) # Get classification results tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0))) -top_categories = np.argsort(tvm_output.asnumpy()[0]) - -# Report top-5 classification results -print("\n%s prediction" % model) -print("\t#1:", synset[top_categories[-1]]) -print("\t#2:", synset[top_categories[-2]]) -print("\t#3:", synset[top_categories[-3]]) -print("\t#4:", synset[top_categories[-4]]) -print("\t#5:", synset[top_categories[-5]]) - -# This just checks that one of the 5 top categories -# is one variety of cat; this is by no means an accurate -# assessment of how quantization affects classification -# accuracy but is meant to catch changes to the -# quantization pass that would accuracy in the CI. -cat_detected = False -for k in top_categories[-5:]: - if "cat" in synset[k]: - cat_detected = True -assert(cat_detected) +for b in range(env.BATCH): + top_categories = np.argsort(tvm_output.asnumpy()[b]) + # Report top-5 classification results + print("\n{} prediction for sample {}".format(model, b)) + print("\t#1:", synset[top_categories[-1]]) + print("\t#2:", synset[top_categories[-2]]) + print("\t#3:", synset[top_categories[-3]]) + print("\t#4:", synset[top_categories[-4]]) + print("\t#5:", synset[top_categories[-5]]) + # This just checks that one of the 5 top categories + # is one variety of cat; this is by no means an accurate + # assessment of how quantization affects classification + # accuracy but is meant to catch changes to the + # quantization pass that would accuracy in the CI. + cat_detected = False + for k in top_categories[-5:]: + if "cat" in synset[k]: + cat_detected = True + assert(cat_detected) From 018943e4d7a010465d3e978c8ddc0f0a70e5edc3 Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Sun, 18 Aug 2019 09:34:49 -0700 Subject: [PATCH 13/18] fix, apply quantization to VTA only --- vta/tutorials/frontend/deploy_vision_on_vta.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/vta/tutorials/frontend/deploy_vision_on_vta.py b/vta/tutorials/frontend/deploy_vision_on_vta.py index 18e74bca2a35..ccf3b248c3c3 100644 --- a/vta/tutorials/frontend/deploy_vision_on_vta.py +++ b/vta/tutorials/frontend/deploy_vision_on_vta.py @@ -115,7 +115,7 @@ # the host, make sure you've set the variables below to the IP of # your board. device_host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99") - device_port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091")) + device_port = os.environ.get("VTA_PYNQ_RPC_PORT", "9091") if not tracker_host or not tracker_port: remote = rpc.connect(device_host, int(device_port)) else: @@ -172,13 +172,12 @@ shape_dict.update({k: v.shape for k, v in params.items()}) dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) - # Perform quantization in Relay - with relay.quantize.qconfig(global_scale=8.0, - skip_conv_layers=[0]): - relay_prog = relay.quantize.quantize(mod["main"], params=params) - - # Perform graph packing and constant folding for VTA target if target.device_name == "vta": + # Perform quantization in Relay + with relay.quantize.qconfig(global_scale=8.0, + skip_conv_layers=[0]): + relay_prog = relay.quantize.quantize(mod["main"], params=params) + # Perform graph packing and constant folding for VTA target assert env.BLOCK_IN == env.BLOCK_OUT relay_prog = graph_pack( relay_prog, @@ -187,6 +186,8 @@ env.WGT_WIDTH, start_name=pack_dict[model][0], stop_name=pack_dict[model][1]) + else: + relay_prog = mod["main"] # Compile Relay program with AlterOpLayout disabled with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): From 659b512466fd2eede91aba1545f77ffcdd64c85b Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Mon, 2 Sep 2019 12:51:53 -0700 Subject: [PATCH 14/18] improving comments --- src/relay/op/nn/pooling.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 06502c4ca9d0..76dec994ba3d 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -165,8 +165,8 @@ Array Pool2DCompute(const Attrs& attrs, inputs[0].ndim() == 5U || inputs[0].ndim() == 6U) << "Pool2D only support 4-D input (e.g., NCHW)" - << " or 5-D input (last dimension is a split of channel)" - << " or 6-D input (last 2 dimensions are split of batch and channel)"; + << " or 5-D input (e.g. NCHWc on for vector instructions)" + << " or 6-D input (e.g. NCHWnc for tensor accelerators)"; if (param->padding.size() == 1) { padding.push_back(padding[0]); From 61deb90051e518294cb5c90b4cbe999d6a9c0e03 Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Mon, 2 Sep 2019 12:55:08 -0700 Subject: [PATCH 15/18] trimming models that have runtime issues for the moment --- vta/tutorials/frontend/deploy_vision_on_vta.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vta/tutorials/frontend/deploy_vision_on_vta.py b/vta/tutorials/frontend/deploy_vision_on_vta.py index ccf3b248c3c3..a508fc425aec 100644 --- a/vta/tutorials/frontend/deploy_vision_on_vta.py +++ b/vta/tutorials/frontend/deploy_vision_on_vta.py @@ -77,18 +77,12 @@ # Dictionary lookup for when to start/end bit packing pack_dict = { - "alexnet": ["nn.max_pool2d", "nn.batch_flatten"], "resnet18_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"], "resnet34_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"], "resnet18_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], "resnet34_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], "resnet50_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], "resnet101_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], - "resnet152_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"], - "vgg11": ["nn.max_pool2d", "nn.batch_flatten"], - "vgg13": ["nn.max_pool2d", "nn.batch_flatten"], - "vgg16": ["nn.max_pool2d", "nn.batch_flatten"], - "vgg19": ["nn.max_pool2d", "nn.batch_flatten"], } # Name of Gluon model to compile From d4ae6c997222a11092a84519b961cc1354edb187 Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Mon, 2 Sep 2019 12:56:22 -0700 Subject: [PATCH 16/18] lint --- python/tvm/relay/__init__.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index db4ea084522f..2ab9012e6455 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -17,6 +17,11 @@ # pylint: disable=wildcard-import, redefined-builtin, invalid-name """The Relay IR namespace containing the IR definition and compiler.""" from __future__ import absolute_import + +# Required to traverse large programs +from sys import setrecursionlimit +setrecursionlimit(10000) + from ..api import register_func from . import base from . import ty @@ -59,10 +64,6 @@ from .scope_builder import ScopeBuilder -# Required to traverse large programs -from sys import setrecursionlimit -setrecursionlimit(10000) - # Span Span = base.Span From dfc9d5f74a35edb657549bd61af83057c12f6a7b Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Mon, 2 Sep 2019 13:17:33 -0700 Subject: [PATCH 17/18] lint --- python/tvm/relay/__init__.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 2ab9012e6455..db4ea084522f 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -17,11 +17,6 @@ # pylint: disable=wildcard-import, redefined-builtin, invalid-name """The Relay IR namespace containing the IR definition and compiler.""" from __future__ import absolute_import - -# Required to traverse large programs -from sys import setrecursionlimit -setrecursionlimit(10000) - from ..api import register_func from . import base from . import ty @@ -64,6 +59,10 @@ from .scope_builder import ScopeBuilder +# Required to traverse large programs +from sys import setrecursionlimit +setrecursionlimit(10000) + # Span Span = base.Span From 6954bf389ea550a9adb8044a9455eca1036ead4f Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Mon, 2 Sep 2019 15:56:12 -0700 Subject: [PATCH 18/18] lint --- python/tvm/relay/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index db4ea084522f..b56ef6507782 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -17,6 +17,7 @@ # pylint: disable=wildcard-import, redefined-builtin, invalid-name """The Relay IR namespace containing the IR definition and compiler.""" from __future__ import absolute_import +from sys import setrecursionlimit from ..api import register_func from . import base from . import ty @@ -60,7 +61,6 @@ from .scope_builder import ScopeBuilder # Required to traverse large programs -from sys import setrecursionlimit setrecursionlimit(10000) # Span