Skip to content

Commit

Permalink
[COMPILER] Refactor compiler to enable configuration (apache#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jul 12, 2018
1 parent 2fc7621 commit 8eae56a
Show file tree
Hide file tree
Showing 19 changed files with 977 additions and 729 deletions.
16 changes: 9 additions & 7 deletions vta/examples/resnet18/pynq/imagenet_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
if verbose:
logging.basicConfig(level=logging.INFO)

# Change to -device=tcpu to run cpu only inference.
# Change to -device=vta-cpu to run cpu only inference.
target = "llvm -device=vta"
target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"

synset = eval(open(os.path.join(CATEG_FILE)).read())
image = Image.open(os.path.join(TEST_FILE)).resize((224, 224))
Expand Down Expand Up @@ -117,15 +118,16 @@ def mark_nop(graph, conv_layer=-1, skip_conv_layer=()):
sym = sym.apply("InferType")

with nnvm.compiler.build_config(opt_level=3):
bdict = {}
if "vta" not in target:
bdict = {"add_lower_pass": []}
else:
bdict = {"add_lower_pass": vta.debug_mode(0)}
with tvm.build_config(**bdict):
graph, lib, params = nnvm.compiler.build(
sym, target, shape_dict, dtype_dict,
params=params)
params=params, target_host=target_host)
else:
with vta.build_config():
graph, lib, params = nnvm.compiler.build(
sym, target, shape_dict, dtype_dict,
params=params, target_host=target_host)


temp = util.tempdir()
lib.save(temp.relpath("graphlib.o"))
Expand Down
15 changes: 4 additions & 11 deletions vta/python/vta/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
"""TVM-based VTA Compiler Toolchain"""
from __future__ import absolute_import as _abs

from .hw_spec import *
from .environment import get_env, Environment

try:
from .runtime import SCOPE_INP, SCOPE_OUT, SCOPE_WGT, DMA_COPY, ALU
from .intrin import GEVM, GEMM
from .build import debug_mode
from . import mock, ir_pass
# allow optional import in config mode.
from . import arm_conv2d, vta_conv2d
except AttributeError:
pass

from .rpc_client import reconfig_runtime, program_fpga

try:
from .build_module import build_config, lower, build
from .rpc_client import reconfig_runtime, program_fpga
from . import graph
except ImportError:
pass
55 changes: 0 additions & 55 deletions vta/python/vta/build.py

This file was deleted.

101 changes: 101 additions & 0 deletions vta/python/vta/build_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""VTA specific buildin for runtime."""
from __future__ import absolute_import as _abs

import tvm
from . import ir_pass
from .environment import get_env


def lift_coproc_scope(x):
"""Lift coprocessings cope to the """
x = ir_pass.lift_alloc_to_scope_begin(x)
x = tvm.ir_pass.LiftAttrScope(x, "coproc_scope", False)
return x

def early_rewrite(stmt):
"""Try to do storage rewrite in early pass."""
try:
return tvm.ir_pass.StorageRewrite(stmt)
except tvm.TVMError:
return stmt


def build_config(debug_flag=0, **kwargs):
"""Build a build config for VTA.
Parameters
----------
debug_flag : int
The dbeug flag to be passed.
kwargs : dict
Additional configurations.
Returns
-------
build_config: BuildConfig
The build config that can be used in TVM.
Example
--------
.. code-block:: python
# build a vta module.
with vta.build_config():
vta_module = tvm.build(s, ...)
"""
env = get_env()
def add_debug(stmt):
debug = tvm.call_extern(
"int32", "VTASetDebugMode",
env.dev.command_handle,
debug_flag)

return tvm.make.stmt_seq(debug, stmt)
pass_list = [(1, ir_pass.inject_dma_intrin),
(1, ir_pass.inject_skip_copy),
(1, ir_pass.annotate_alu_coproc_scope),
(1, lambda x: tvm.ir_pass.LiftAttrScope(x, "coproc_uop_scope", True)),
(1, lift_coproc_scope),
(1, ir_pass.inject_coproc_sync),
(1, early_rewrite)]
if debug_flag:
pass_list.append((1, add_debug))
pass_list.append((2, ir_pass.inject_alu_intrin))
pass_list.append((3, ir_pass.fold_uop_loop))
pass_list.append((3, ir_pass.cpu_access_rewrite))
return tvm.build_config(add_lower_pass=pass_list, **kwargs)


def lower(*args, **kwargs):
"""Thin wrapper of tvm.lower
This wrapper automatically applies VTA's build_config
if there is no user specified build_config in context.
See Also
--------
tvm.lower : The original TVM's lower function
"""
cfg = tvm.build_module.current_build_config()
if not cfg.add_lower_pass:
with build_config():
return tvm.lower(*args, **kwargs)
return tvm.lower(*args, **kwargs)


def build(*args, **kwargs):
"""Thin wrapper of tvm.build
This wrapper automatically applies VTA's build_config
if there is no user specified build_config in context.
See Also
--------
tvm.build : The original TVM's build function
"""
cfg = tvm.build_module.current_build_config()
if not cfg.add_lower_pass:
with build_config():
return tvm.build(*args, **kwargs)
return tvm.build(*args, **kwargs)
Loading

0 comments on commit 8eae56a

Please sign in to comment.