Skip to content

Commit

Permalink
[microTVM] Add support for AutoTVM (#8715)
Browse files Browse the repository at this point in the history
* Initial commit of API server impl.

* initial commit of api client

* Add TVM-side glue code to use Project API

* Change tvm.micro.Session to use Project API

* Rework how crt_config.h is used on the host.

 * use template crt_config.h for host test runtime; delete
   src/runtime/crt/host/crt_config.h so that it doesn't diverge from
   the template
 * bring template crt_config.h inline with the one actually in use
  * rename to MAX_STRLEN_DLTYPE
 * Create a dedicated TVM-side host crt_config.h in src/runtime/micro

* Modify Transport infrastructure to work with Project API

* Add host microTVM API server

* Zephyr implementation of microTVM API server

 * move all zephyr projects to apps/microtvm/zephyr/template_project

* consolidate CcompilerAnnotator

* Allow model library format with c backend, add test.

* Update unit tests

* fix incorrect doc

* Delete old Zephyr build infrastructure

* Delete old build abstractions

* Delete old Transport implementations and simplify module

* lint

* ASF header

* address gromero comments

* final fixes?

* fix is_shutdown

* fix user-facing API

* fix TempDirectory / operator

* Update micro_tflite tutorial

* lint

* fix test_crt and test_link_params

* undo global micro import, hopefully fix fixture

* lint

* fix more tests

* Add session_constructor_args to tracker request() function.

 * Allows tracker clients to open non-traditional RPC sessions

* Generate entry_func symbol in C host codegen.

 * Needed for AutoTVM.

* print MeasureErrorNo enum value in MeasureResult repr

* Add microTVM session constructor.

 * This constructor is to be called from the RPC driver to flash and
   connect to the RPC server on the microcontroller.

* add build_kwargs as a Builder constructor arg.

 * build_kwargs is derived from pre-configured args, the runner, and
   now from the script.
 * user-supplied build kwargs override the other two, and a warning is
   printed if any key is overridden.

* Add do_fork option to Builder, to support stateful builders

 * When AutoTVM builder forks, any global state modified by the
   build_func is lost between builds

* Checkin module_loader used to build and flash microTVM for autotuning.

* Import micro into top-level when enabled.

 * AutoTVM RPC server needs to load the micro session constructor.

* Add tvm.contrib.random.random_fill to microTVM.

 * Allows autotuning with random data.

* Move compilation to runner :O

* Add a tutorial for AutoTVM with microcontrollers.

* Fix si_prefix in autotuner callback

* black format and git-clang-format

* Switch tutorial back to qemu version

* improve error reporting so CI will show test error

* black format

* autotvm is working

* fix tutorial

* fix dependencies

* fix auto tune issue

* lint

* address comments

* fix lint

* test crt and zephyr added

* fix func registery size

* moved autotune test and fixed

* fix crt test

* address comments

* change relay text

* change relay in text_zephyr

* class added

* changed relay module in tutorial and cleanup

* address comments

* address TK comments

* change fork

* final comments

* retrigger due to flahy test

* fix tutorial

* retrigger

* fix changes due to merge

Co-authored-by: Andrew Reusch <areusch@octoml.ai>
  • Loading branch information
mehrdadh and areusch authored Sep 9, 2021
1 parent e4478aa commit aa2b37d
Show file tree
Hide file tree
Showing 22 changed files with 778 additions and 33 deletions.
2 changes: 1 addition & 1 deletion apps/bundle_deploy/crt_config/crt_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
#define TVM_CRT_MAX_REGISTERED_MODULES 2

/*! Size of the global function registry, in bytes. */
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 512

/*! Maximum packet size, in bytes, including the length header. */
#define TVM_CRT_MAX_PACKET_SIZE_BYTES 512
Expand Down
1 change: 1 addition & 0 deletions apps/microtvm/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ tensorflow-estimator = {version = "^2.1", optional = true}
# TFLite frontend
tflite = {version = "2.1.0", optional = true}
wheel = "*"
cloudpickle = "^1.6.0"


[tool.poetry.extras]
Expand Down
1 change: 1 addition & 0 deletions include/tvm/runtime/crt/error_codes.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ typedef enum {
kTvmErrorFunctionCallNumArguments = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 0),
kTvmErrorFunctionCallWrongArgType = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 1),
kTvmErrorFunctionCallNotImplemented = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 2),
kTvmErrorFunctionCallInvalidArg = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 3),

// Time Evaluator - times functions for use with debug runtime.
kTvmErrorTimeEvaluatorBadHandle = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryTimeEvaluator, 0),
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@
# Contrib initializers
from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel

if support.libinfo().get("USE_MICRO", "OFF") == "ON":
from . import micro

# NOTE: This file should be python2 compatible so we can
# raise proper error message when user run the package using
# an older version of the python
Expand Down
33 changes: 29 additions & 4 deletions python/tvm/autotvm/measure/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=pointless-string-statement,consider-using-enumerate,invalid-name
"""User facing API for specifying how to measure the generated code"""
import enum
import multiprocessing
from collections import namedtuple

Expand Down Expand Up @@ -52,8 +53,19 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost"
The absolute time stamp when we finish measurement.
"""

def __repr__(self):
error_no_str = (
str(self.error_no)
if self.error_no not in MeasureErrorNo
else str(MeasureErrorNo(self.error_no))
)
return (
f"{self.__class__.__name__}(costs={self.costs!r}, error_no={error_no_str}, "
f"all_cost={self.all_cost}, timestamp={self.timestamp!r})"
)

class MeasureErrorNo(object):

class MeasureErrorNo(enum.IntEnum):
"""Error type for MeasureResult"""

NO_ERROR = 0 # no error
Expand All @@ -77,12 +89,15 @@ class Builder(object):
n_parallel: int, optional
The number of tasks submitted in parallel
By default it will use all cpu cores
build_kwargs: dict, optional
Keyword args given to the build function.
"""

def __init__(self, timeout=10, n_parallel=None):
def __init__(self, timeout=10, n_parallel=None, build_kwargs=None):
self.timeout = timeout
self.n_parallel = n_parallel or multiprocessing.cpu_count()
self.build_kwargs = {}
self.user_build_kwargs = build_kwargs if build_kwargs is not None else {}
self.runner_build_kwargs = None
self.task = None

def set_task(self, task, build_kwargs=None):
Expand All @@ -97,7 +112,17 @@ def set_task(self, task, build_kwargs=None):
The additional kwargs for build function
"""
self.task = task
self.build_kwargs = build_kwargs
self.build_kwargs = dict(build_kwargs.items()) if build_kwargs is not None else {}
if any(k in self.build_kwargs for k in self.user_build_kwargs):
logging.warn(
"Overriding these runner-supplied kwargs with user-supplied:\n%s",
"\n".join(
f" * {k}: from {build_kwargs[k]!r} to {self.user_build_kwargs[k]!r}"
for k in sorted([k for k in build_kwargs if k in self.user_build_kwargs])
),
)
for k, v in self.user_build_kwargs.items():
self.build_kwargs[k] = v

def build(self, measure_inputs):
"""Build programs
Expand Down
27 changes: 24 additions & 3 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,22 @@ class LocalBuilder(Builder):
The timeout of a compilation
n_parallel: int
The number of tasks run in parallel. "None" will use all cpu cores
build_kwargs: dict
If supplied, additional kwargs passed to build_func. Overrides any build_kwargs supplied
by the Runner.
build_func: callable or str
If is 'default', use default build function
If is 'ndk', use function for android ndk
If id 'stackvm', use function for stackvm
If is callable, use it as custom build function, expect lib_format field.
do_fork: bool
If False, do not fork when building. Requires n_parallel=1.
"""

def __init__(self, timeout=10, n_parallel=None, build_func="default"):
super(LocalBuilder, self).__init__(timeout, n_parallel)
def __init__(
self, timeout=10, n_parallel=None, build_kwargs=None, build_func="default", do_fork=False
):
super(LocalBuilder, self).__init__(timeout, n_parallel, build_kwargs)

if isinstance(build_func, str):
if build_func == "default":
Expand All @@ -99,6 +106,11 @@ def __init__(self, timeout=10, n_parallel=None, build_func="default"):
else:
raise ValueError("Invalid build_func" + build_func)
self.build_func = _WrappedBuildFunc(build_func)
if not do_fork:
assert n_parallel in (
None,
1,
), f"if do_fork=False, need n_parallel=None or 1; got {n_parallel}"
self.executor = PopenPoolExecutor(
timeout=timeout, initializer=reset_global_scope, initargs=(AutotvmGlobalScope.current,)
)
Expand Down Expand Up @@ -518,7 +530,16 @@ def __call__(self, measure_input, tmp_dir, **kwargs):
)
# TODO(tvm-team) consider linline _build_func_common
func, arg_info = _build_func_common(measure_input, **kwargs)
func.export_library(filename, self.build_func)
if self.build_func.output_format == ".model-library-format":
# Late import to preserve autoTVM with USE_MICRO OFF
try:
from tvm import micro # pylint: disable=import-outside-toplevel
except ImportError:
raise ImportError("Requires USE_MICRO")

micro.export_model_library_format(func, filename)
else:
func.export_library(filename, self.build_func)
except Exception as e: # pylint: disable=broad-except
return BuildResult(None, None, e, time.time() - tic)
return BuildResult(filename, arg_info, None, time.time() - tic)
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/autotvm/tuner/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def __del__(self):

if logger.level < logging.DEBUG: # only print progress bar in non-debug mode
sys.stdout.write(
"\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) "
"| %.2f s" % (prefix, 0, 0, 0, total, time.time() - tic)
"\r%s Current/Best: %7.2f/%7.2f %sFLOPS | Progress: (%d/%d) "
"| %.2f s" % (prefix, 0, 0, si_prefix, 0, total, time.time() - tic)
)
sys.stdout.flush()

Expand Down
2 changes: 2 additions & 0 deletions python/tvm/micro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
"""MicroTVM module for bare-metal backends"""

from .build import autotvm_build_func
from .build import AutoTvmModuleLoader
from .build import get_standalone_crt_dir
from .model_library_format import export_model_library_format, UnsupportedInModelLibraryFormatError
from .project import generate_project, GeneratedProject, TemplateProject
Expand Down
55 changes: 55 additions & 0 deletions python/tvm/micro/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@

"""Defines top-level glue functions for building microTVM artifacts."""

import json
import logging
import os
import pathlib

from .._ffi import libinfo
from .. import rpc as _rpc


_LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -57,3 +60,55 @@ def get_standalone_crt_dir() -> str:
raise CrtNotFoundError()

return STANDALONE_CRT_DIR


class AutoTvmModuleLoader:
"""MicroTVM AutoTVM Module Loader
Parameters
----------
template_project_dir : str
project template path
project_options : dict
project generation option
"""

def __init__(self, template_project_dir: str, project_options: dict = None):
self._project_options = project_options

if isinstance(template_project_dir, pathlib.Path):
self._template_project_dir = str(template_project_dir)
elif not isinstance(template_project_dir, str):
raise TypeError(f"Incorrect type {type(template_project_dir)}.")

def __call__(self, remote_kw, build_result):
with open(build_result.filename, "rb") as build_file:
build_result_bin = build_file.read()

tracker = _rpc.connect_tracker(remote_kw["host"], remote_kw["port"])
remote = tracker.request(
remote_kw["device_key"],
priority=remote_kw["priority"],
session_timeout=remote_kw["timeout"],
session_constructor_args=[
"tvm.micro.compile_and_create_micro_session",
build_result_bin,
self._template_project_dir,
json.dumps(self._project_options),
],
)
system_lib = remote.get_function("runtime.SystemLib")()
yield remote, system_lib
try:
remote.get_function("tvm.micro.destroy_micro_session")()
except tvm.error.TVMError as exception:
_LOG.warning("Error destroying remote session: %s", str(exception), exc_info=1)


def autotvm_build_func():
"""A dummy build function which causes autotvm to use a different export format."""


# A sentinel value for the output format.
autotvm_build_func.output_format = ".model-library-format"
17 changes: 10 additions & 7 deletions python/tvm/micro/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,9 @@ def __init__(self, api_client):
if not self._info["is_template"]:
raise NotATemplateProjectError()

def generate_project(self, graph_executor_factory, project_dir, options):
"""Generate a project given GraphRuntimeFactory."""
model_library_dir = utils.tempdir()
model_library_format_path = model_library_dir.relpath("model.tar")
export_model_library_format(graph_executor_factory, model_library_format_path)

def generate_project_from_mlf(self, model_library_format_path, project_dir, options):
self._api_client.generate_project(
model_library_format_path=model_library_format_path,
model_library_format_path=str(model_library_format_path),
standalone_crt_dir=get_standalone_crt_dir(),
project_dir=project_dir,
options=options,
Expand All @@ -119,6 +114,14 @@ def generate_project(self, graph_executor_factory, project_dir, options):
def info(self):
return self._info

def generate_project(self, graph_executor_factory, project_dir, options):
"""Generate a project given GraphRuntimeFactory."""
model_library_dir = utils.tempdir()
model_library_format_path = model_library_dir.relpath("model.tar")
export_model_library_format(graph_executor_factory, model_library_format_path)

return self.generate_project_from_mlf(model_library_format_path, project_dir, options)


def generate_project(
template_project_dir: typing.Union[pathlib.Path, str],
Expand Down
73 changes: 72 additions & 1 deletion python/tvm/micro/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@

"""Defines a top-level glue class that operates the Transport and Flasher classes."""

import json
import logging
import sys

from ..error import register_error
from .._ffi import get_global_func
from .._ffi import get_global_func, register_func
from ..contrib import graph_executor
from ..contrib import utils
from ..contrib.debugger import debug_executor
from ..rpc import RPCSession
from . import project
from .transport import IoTimeoutError
from .transport import TransportLogger

Expand Down Expand Up @@ -234,3 +237,71 @@ def create_local_debug_executor(graph_json_str, mod, device, dump_root=None):
graph_json_str,
dump_root=dump_root,
)


RPC_SESSION = None


@register_func("tvm.micro.compile_and_create_micro_session")
def compile_and_create_micro_session(
mod_src_bytes: bytes,
template_project_dir: str,
project_options: dict = None,
):
"""Compile the given libraries and sources into a MicroBinary, then invoke create_micro_session.
Parameters
----------
mod_src_bytes : bytes
The content of a tarfile which contains the TVM-generated sources which together form the
SystemLib. This tar is expected to be created by export_library. The tar will be extracted
into a directory and the sources compiled into a MicroLibrary using the Compiler.
template_project_dir: str
The path to a template microTVM Project API project which is used to generate the embedded
project that is built and flashed onto the target device.
project_options: dict
Options for the microTVM API Server contained in template_project_dir.
"""
global RPC_SESSION

temp_dir = utils.tempdir()
# Keep temp directory for generate project
temp_dir.set_keep_for_debug(True)
model_library_format_path = temp_dir / "model.tar.gz"
with open(model_library_format_path, "wb") as mlf_f:
mlf_f.write(mod_src_bytes)

try:
template_project = project.TemplateProject.from_directory(template_project_dir)
generated_project = template_project.generate_project_from_mlf(
model_library_format_path,
temp_dir / "generated-project",
options=json.loads(project_options),
)
except Exception as exception:
logging.error("Project Generate Error: %s", str(exception))
raise exception

generated_project.build()
generated_project.flash()
transport = generated_project.transport()

RPC_SESSION = Session(transport_context_manager=transport)
RPC_SESSION.__enter__()
return RPC_SESSION._rpc._sess


@register_func
def destroy_micro_session():
"""Destroy RPC session for microTVM autotune."""
global RPC_SESSION

if RPC_SESSION is not None:
exc_type, exc_value, traceback = RPC_SESSION.__exit__(None, None, None)
RPC_SESSION = None
if (exc_type, exc_value, traceback) != (None, None, None):
exc = exc_type(exc_value) # See PEP 3109
exc.__traceback__ = traceback
raise exc
Loading

0 comments on commit aa2b37d

Please sign in to comment.