Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VTA][Relay] Extending Vision model coverage compilation for VTA #3740

Merged
merged 18 commits into from
Sep 5, 2019
Merged
4 changes: 4 additions & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the IR definition and compiler."""
from __future__ import absolute_import
from sys import setrecursionlimit
from ..api import register_func
from . import base
from . import ty
Expand Down Expand Up @@ -59,6 +60,9 @@

from .scope_builder import ScopeBuilder

# Required to traverse large programs
setrecursionlimit(10000)

# Span
Span = base.Span

Expand Down
7 changes: 5 additions & 2 deletions src/relay/op/nn/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,12 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
<< "max_pool2d does not support input split on width";

CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
CHECK(inputs[0].ndim() == 4U ||
inputs[0].ndim() == 5U ||
inputs[0].ndim() == 6U)
<< "Pool2D only support 4-D input (e.g., NCHW)"
<< " or 5-D input (last dimension is a split of channel)";
<< " or 5-D input (e.g. NCHWc on for vector instructions)"
<< " or 6-D input (e.g. NCHWnc for tensor accelerators)";

if (param->padding.size() == 1) {
padding.push_back(padding[0]);
Expand Down
43 changes: 28 additions & 15 deletions vta/python/vta/top/graphpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def _pack_weight_conv2d_transpose(data, dshape, cfactor):
return data


def _pack_bias(data, dshape, dtype, bfactor, cfactor):
"""Pack the bias parameter.
def _pack_const(data, dshape, dtype, bfactor, cfactor):
"""Pack a constant parameter.
"""
dshape = _to_shape(dshape)
assert len(dshape) == 3
Expand Down Expand Up @@ -124,6 +124,7 @@ def __init__(self, bfactor, cfactor, weight_bits):
self.conv2d = op.op.get("nn.conv2d")
self.conv2d_transpose = op.op.get("nn.conv2d_transpose")
self.add = op.op.get("add")
self.multiply = op.op.get("multiply")
self.bias_add = op.op.get("nn.bias_add")
self.number_of_conv2d = 0
super().__init__()
Expand Down Expand Up @@ -203,23 +204,35 @@ def visit_call(self, call):
output_padding=call.attrs.output_padding,
out_dtype=call.attrs.out_dtype)
return conv2d
elif call.op == self.add and tuple(input_types[0].shape) == tuple(input_types[1].shape):
elif call.op == self.add and \
tuple(input_types[0].shape) == tuple(input_types[1].shape):
pass
elif call.op == self.add and len(input_types[1].shape) == 3:
data, bias = args
bias = _pack_bias(bias,
_to_shape(input_types[1].shape),
input_types[1].dtype,
self.bfactor,
self.cfactor)
return relay.Call(self.add, [data, bias])
data, const = args
const = _pack_const(const,
_to_shape(input_types[1].shape),
input_types[1].dtype,
self.bfactor,
self.cfactor)
return relay.Call(self.add, [data, const])
elif call.op == self.multiply and \
tuple(input_types[0].shape) == tuple(input_types[1].shape):
pass
elif call.op == self.multiply and len(input_types[1].shape) == 3:
data, const = args
const = _pack_const(const,
_to_shape(input_types[1].shape),
input_types[1].dtype,
self.bfactor,
self.cfactor)
return relay.Call(self.multiply, [data, const])
elif self.start_pack and call.op == self.bias_add:
data, bias = args
bias = _pack_bias(bias,
_to_shape(input_types[1].shape),
input_types[1].dtype,
self.bfactor,
self.cfactor)
bias = _pack_const(bias,
_to_shape(input_types[1].shape),
input_types[1].dtype,
self.bfactor,
self.cfactor)
return relay.Call(self.add, [data, bias])
elif self.start_pack and call.op == op.op.get('cast') and \
input_types[0].dtype == 'int32':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
# specific language governing permissions and limitations
# under the License.
"""
Deploy Pretrained ResNet Model from MxNet on VTA
Deploy Pretrained Vision Model from MxNet on VTA
================================================
**Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_

This tutorial provides an end-to-end demo, on how to run ResNet-18 inference
onto the VTA accelerator design to perform ImageNet classification tasks.
This tutorial provides an end-to-end demo, on how to run ImageNet classification
inference onto the VTA accelerator design to perform ImageNet classification tasks.
It showcases Relay as a front end compiler that can perform quantization (VTA
only supports int8/32 inference) as well as graph packing (in order to enable
tensorization in the core) to massage the compute graph for the hardware target.
Expand All @@ -40,7 +40,7 @@

from __future__ import absolute_import, print_function

import argparse, json, os, requests, time
import argparse, json, os, requests, sys, time
from io import BytesIO
from os.path import join, isfile
from PIL import Image
Expand All @@ -53,6 +53,7 @@
from tvm import rpc, autotvm, relay
from tvm.contrib import graph_runtime, util, download
from tvm.contrib.debugger import debug_runtime
from tvm.relay import transform

import vta
from vta.testing import simulator
Expand All @@ -61,7 +62,6 @@
# Make sure that TVM was compiled with RPC=1
assert tvm.module.enabled("rpc")


######################################################################
# Define the platform and model targets
# -------------------------------------
Expand All @@ -75,13 +75,22 @@
device = "vta"
target = env.target if device == "vta" else env.target_vta_cpu

# Dictionary lookup for when to start/end bit packing
pack_dict = {
"resnet18_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
"resnet34_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
"resnet18_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
"resnet34_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
"resnet50_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
"resnet101_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
}

# Name of Gluon model to compile
# The ``start_pack`` and ``stop_pack`` labels indicate where
# to start and end the graph packing relay pass: in other words
# where to start and finish offloading to VTA.
model = "resnet18_v1"
start_pack="nn.max_pool2d"
stop_pack="nn.global_avg_pool2d"
assert model in pack_dict

######################################################################
# Obtain an execution remote
Expand Down Expand Up @@ -125,7 +134,7 @@
######################################################################
# Build the inference graph runtime
# ---------------------------------
# Grab ResNet-18 model from Gluon model zoo and compile with Relay.
# Grab vision model from Gluon model zoo and compile with Relay.
# The compilation steps are:
# 1) Front end translation from MxNet into Relay module.
# 2) Apply 8-bit quantization: here we skip the first conv layer,
Expand All @@ -140,7 +149,7 @@
# Load pre-configured AutoTVM schedules
with autotvm.tophub.context(target):

# Populate the shape and data type dictionary for ResNet input
# Populate the shape and data type dictionary for ImageNet classifier input
dtype_dict = {"data": 'float32'}
shape_dict = {"data": (env.BATCH, 3, 224, 224)}

Expand All @@ -157,21 +166,22 @@
shape_dict.update({k: v.shape for k, v in params.items()})
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})

# Perform quantization in Relay
with relay.quantize.qconfig(global_scale=8.0,
skip_conv_layers=[0]):
relay_prog = relay.quantize.quantize(mod["main"], params=params)

# Perform graph packing and constant folding for VTA target
if target.device_name == "vta":
# Perform quantization in Relay
with relay.quantize.qconfig(global_scale=8.0,
skip_conv_layers=[0]):
relay_prog = relay.quantize.quantize(mod["main"], params=params)
# Perform graph packing and constant folding for VTA target
assert env.BLOCK_IN == env.BLOCK_OUT
relay_prog = graph_pack(
relay_prog,
env.BATCH,
env.BLOCK_OUT,
env.WGT_WIDTH,
start_name=start_pack,
stop_name=stop_pack)
start_name=pack_dict[model][0],
stop_name=pack_dict[model][1])
else:
relay_prog = mod["main"]

# Compile Relay program with AlterOpLayout disabled
with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
Expand Down Expand Up @@ -199,8 +209,8 @@
m = graph_runtime.create(graph, lib, ctx)

######################################################################
# Perform ResNet-18 inference
# ---------------------------
# Perform image classification inference
# --------------------------------------
# We run classification on an image sample from ImageNet
# We just need to download the categories files, `synset.txt`
# and an input test image.
Expand Down Expand Up @@ -256,15 +266,13 @@
tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0)))
for b in range(env.BATCH):
top_categories = np.argsort(tvm_output.asnumpy()[b])

# Report top-5 classification results
print("\n{} prediction for sample {}".format(model, b))
print("\t#1:", synset[top_categories[-1]])
print("\t#2:", synset[top_categories[-2]])
print("\t#3:", synset[top_categories[-3]])
print("\t#4:", synset[top_categories[-4]])
print("\t#5:", synset[top_categories[-5]])

# This just checks that one of the 5 top categories
# is one variety of cat; this is by no means an accurate
# assessment of how quantization affects classification
Expand Down