Skip to content

Commit

Permalink
[AUTOTVM] End2End autotvm support for vta (apache#18)
Browse files Browse the repository at this point in the history
* support tuning a whole network

* pass unit test

* update tune resnet

* update all
  • Loading branch information
merrymercy authored and tmoreau89 committed Jan 2, 2019
1 parent 598e2c6 commit 1cd86af
Show file tree
Hide file tree
Showing 20 changed files with 874 additions and 1,042 deletions.
13 changes: 12 additions & 1 deletion python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,18 @@ def set_task(self, task):
for x in arg_bufs]
func = build(s, arg_bufs, "llvm")
tvm_buf = [nd.array(x) for x in self.ref_input]
func(*tvm_buf)

def _run_func():
"""Run tvm function in a thread.
Because there is some issues with python multiprocessing and the thread pool in tvm
"""
func(*tvm_buf)

thread = threading.Thread(target=_run_func)
thread.start()
thread.join()
del thread

self.ref_output = [x.asnumpy() for x in tvm_buf]

def get_build_kwargs(self):
Expand Down
33 changes: 15 additions & 18 deletions python/tvm/autotvm/task/nnvm_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
import warnings
import logging
import sys


from ... import target as _target
Expand All @@ -18,8 +19,7 @@
def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
""" Extract tuning tasks from a nnvm graph.
This function collects tuning tasks by building the graph
with a "tracing" target and tracing all the calls to topi.
This function collects tuning tasks by building the graph and trace all the calls to topi.
Parameters
----------
Expand All @@ -45,7 +45,7 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
import nnvm
import topi

env = TaskExtractEnv.get()
env = TaskExtractEnv(symbols)

#NOTE: To add more symbols, you only need to change the following lists
#nnvm symbol -> topi compute
Expand All @@ -63,26 +63,23 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
else:
warnings.warn("Symbol %s is not tunable, ignored" % sym_name)

# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)
# run compiler to collect all TOPI calls during compilation
nnvm.compiler.engine.clear_cache()
nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype)
nnvm.compiler.engine.clear_cache()

# disable logger temporarily
old_state = logger.disabled
logger.disabled = True

# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target = _target.create("llvm -device=tracing")
nnvm.compiler.engine.clear_cache()
nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)

logger.disabled = old_state
logger.disabled = old_state

# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
target=target, target_host=target_host,
template_key='direct'))
try:
tsk = create(task_name, args,
target=target, target_host=target_host,
template_key='direct')
tasks.append(tsk)
except topi.InvalidShapeError:
print("shape error")

return tasks

Expand Down
7 changes: 7 additions & 0 deletions python/tvm/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,13 @@ def rasp(options=None):
return arm_cpu('rasp3b', options)


def vta(model='unknown', options=None):
opts = ["-device=vta", '-keys=cpu', '-model=%s' % model]
opts = _merge_opts(opts, options)
ret = _api_internal._TargetCreate("ext_dev", *opts)
return ret


def create(target_str):
"""Get a target given target string.
Expand Down
11 changes: 8 additions & 3 deletions src/codegen/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ Target CreateTarget(const std::string& target_name,

std::string libs_flag = "-libs=";
std::string device_flag = "-device=";
std::string keys_flag = "-keys=";
for (auto& item : options) {
t->options_array.push_back(ir::StringImm::make(item));

Expand All @@ -50,12 +51,16 @@ Target CreateTarget(const std::string& target_name,
}
} else if (item.find(device_flag) == 0) {
t->device_name = item.substr(device_flag.length());
t->keys_array.push_back(ir::StringImm::make(t->device_name));
} else if (item.find(keys_flag) == 0) {
std::stringstream ss(item.substr(keys_flag.length()));
std::string key_item;
while (std::getline(ss, key_item, ',')) {
t->keys_array.push_back(ir::StringImm::make(key_item));
}
}
}

if (t->device_name.length() > 0) {
t->keys_array.push_back(ir::StringImm::make(t->device_name));
}
t->device_type = kDLCPU;
t->thread_warp_size = 1;
if (target_name == "c" || target_name == "llvm") {
Expand Down
4 changes: 4 additions & 0 deletions topi/python/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
from . import image
from . import sparse
from . import hls

# some short cut
from .util import InvalidShapeError

# not import testing by default
# because testing can have extra deps that are not necessary
# we can import them from test cases explicitly
Expand Down
4 changes: 4 additions & 0 deletions topi/python/topi/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import tvm
from . import tag

class InvalidShapeError(ValueError):
"""Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)"""
pass

def traverse_inline(s, final_op, callback):
"""Traverse computation graph and do auto inline
Expand Down
3 changes: 2 additions & 1 deletion vta/python/vta/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# to maintain minimum dependency on the board
if sys.argv[0] not in ("-c", "-m"):
from . import top
from .build_module import build_config, lower, build
from . import graph

from .build_module import build_config, lower, build, vta_autotvm_build_func
from .ptr_alias import reinterpret
36 changes: 36 additions & 0 deletions vta/python/vta/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,39 @@ def build(*args, **kwargs):
with build_config():
return tvm.build(*args, **kwargs)
return tvm.build(*args, **kwargs)


def vta_autotvm_build_func(measure_input, tmp_dir, **kwargs):
"""Custom build func for VTA. Used for autotvm"""

import time
import os
from random import getrandbits
from tvm.autotvm.util import get_const_tuple
from tvm.autotvm.measure.measure_methods import BuildResult, InstantiationError

tic = time.time()
try:
filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64))
target, task, config = measure_input

with target:
s, args = task.instantiate(config)
if not config.valid():
raise InstantiationError(config.errors)

func = build(s, args, target_host=task.target_host)
func2 = build(s, args)

arg_info = tuple((get_const_tuple(x.shape), x.dtype) for x in args)
func.export_library(filename)

# check by local simulator
ctx = tvm.context(str(target))
args = [tvm.nd.empty(x[0], dtype=x[1], ctx=ctx) for x in arg_info]
func2(*args)

except Exception as e: # pylint: disable=broad-except
return BuildResult(None, None, e, time.time() - tic)
return BuildResult(filename, arg_info, None, time.time() - tic)

24 changes: 7 additions & 17 deletions vta/python/vta/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,35 +227,25 @@ def gemm(self):
"""GEMM intrinsic"""
return self.dev.gemm

# TODO get rid of it
@property
def target_host(self):
"""The target host"""
return "llvm " + self.llvm_triple
def target(self):
return tvm.target.vta(model=self.TARGET)

@property
def target_vta_cpu(self):
def target_host(self):
"""The target host"""
if self.TARGET == "pynq":
return "llvm -device=arm_cpu -model=pynq {}".format(self.llvm_triple)
return "llvm -target=armv7-none-linux-gnueabihf"
elif self.TARGET == "ultra96":
return "llvm -device=arm_cpu -model=ultra96 {}".format(self.llvm_triple)
return "llvm -target=aarch64-linux-gnu"
elif self.TARGET == "sim":
return "llvm"
else:
raise ValueError("Unknown target %s" % self.TARGET)

@property
def llvm_triple(self):
"""The llvm flags for the target platform"""
if self.TARGET == "pynq":
return "-target=armv7-none-linux-gnueabihf"
elif self.TARGET == "ultra96":
return "-target=aarch64-linux-gnu"
elif self.TARGET == "sim":
return ""
else:
raise ValueError("Unknown target %s" % self.TARGET)
def target_vta_cpu(self):
return tvm.target.arm_cpu(model=self.TARGET)


def get_env():
Expand Down
9 changes: 1 addition & 8 deletions vta/python/vta/top/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
"""TVM TOPI connector, eventually most of these should go to TVM repo"""

from . import vta_conv2d
from . import arm_conv2d
from . import testing

from .bitpack import bitpack
from .vta_dense import packed_dense, schedule_packed_dense
from .vta_conv2d import packed_conv2d, schedule_packed_conv2d
from .vta_group_conv2d import packed_group_conv2d, schedule_packed_group_conv2d
from .vta_conv2d_transpose import packed_conv2d_transpose, schedule_packed_conv2d_transpose
from . import op
24 changes: 0 additions & 24 deletions vta/python/vta/top/arm_conv2d.py

This file was deleted.

Loading

0 comments on commit 1cd86af

Please sign in to comment.