Skip to content

Commit

Permalink
add the option to open a saved project for debugging.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkatanbaf committed Sep 15, 2022
1 parent 397cf87 commit 9ccdd30
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 30 deletions.
21 changes: 20 additions & 1 deletion python/tvm/micro/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,35 @@ class AutoTvmModuleLoader:
project_options : dict
project generation option
project_dir: str
if use_existing is False: The path to save the generated microTVM Project.
if use_existing is True: The path to a generated microTVM Project for debugging.
use_existing: bool
skips the project generation and opens transport to the project at the project_dir address.
"""

def __init__(
self, template_project_dir: Union[pathlib.Path, str], project_options: dict = None
self,
template_project_dir: Union[pathlib.Path, str],
project_options: dict = None,
project_dir: Union[pathlib.Path, str] = None,
use_existing: bool = False,
):
self._project_options = project_options
self._use_existing = use_existing

if isinstance(template_project_dir, (pathlib.Path, str)):
self._template_project_dir = str(template_project_dir)
elif not isinstance(template_project_dir, str):
raise TypeError(f"Incorrect type {type(template_project_dir)}.")

if isinstance(project_dir, (pathlib.Path, str)):
self._project_dir = str(project_dir)
else:
self._project_dir = None

@contextlib.contextmanager
def __call__(self, remote_kw, build_result):
with open(build_result.filename, "rb") as build_file:
Expand All @@ -147,6 +164,8 @@ def __call__(self, remote_kw, build_result):
build_result_bin,
self._template_project_dir,
json.dumps(self._project_options),
self._project_dir,
self._use_existing,
],
)
system_lib = remote.get_function("runtime.SystemLib")()
Expand Down
55 changes: 39 additions & 16 deletions python/tvm/micro/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import json
import logging
import sys

import pathlib
import shutil
from typing import Union
from ..error import register_error
from .._ffi import get_global_func, register_func
from ..contrib import graph_executor
Expand Down Expand Up @@ -259,6 +261,8 @@ def compile_and_create_micro_session(
mod_src_bytes: bytes,
template_project_dir: str,
project_options: dict = None,
project_dir: Union[pathlib.Path, str] = None,
use_existing: bool = False,
):
"""Compile the given libraries and sources into a MicroBinary, then invoke create_micro_session.
Expand All @@ -275,25 +279,44 @@ def compile_and_create_micro_session(
project_options: dict
Options for the microTVM API Server contained in template_project_dir.
"""
temp_dir = utils.tempdir()
# Keep temp directory for generate project
temp_dir.set_keep_for_debug(True)
model_library_format_path = temp_dir / "model.tar.gz"
with open(model_library_format_path, "wb") as mlf_f:
mlf_f.write(mod_src_bytes)
project_dir: str
if use_existing is False: The path to save the generated microTVM Project.
if use_existing is True: The path to a generated microTVM Project for debugging.
try:
template_project = project.TemplateProject.from_directory(template_project_dir)
generated_project = template_project.generate_project_from_mlf(
model_library_format_path,
str(temp_dir / "generated-project"),
use_existing: bool
skips the project generation and opens transport to the project at the project_dir address.
"""

if use_existing:
project_dir = pathlib.Path(project_dir)
assert project_dir.is_dir(), f"{project_dir} does not exist."
build_dir = project_dir / "generated-project" / "build"
shutil.rmtree(build_dir)
generated_project = project.GeneratedProject.from_directory(
project_dir / "generated-project",
options=json.loads(project_options),
)
except Exception as exception:
logging.error("Project Generate Error: %s", str(exception))
raise exception
else:
if project_dir:
temp_dir = utils.tempdir(custom_path=project_dir, keep_for_debug=True)
else:
temp_dir = utils.tempdir()

model_library_format_path = temp_dir / "model.tar.gz"
with open(model_library_format_path, "wb") as mlf_f:
mlf_f.write(mod_src_bytes)

try:
template_project = project.TemplateProject.from_directory(template_project_dir)
generated_project = template_project.generate_project_from_mlf(
model_library_format_path,
str(temp_dir / "generated-project"),
options=json.loads(project_options),
)
except Exception as exception:
logging.error("Project Generate Error: %s", str(exception))
raise exception

generated_project.build()
generated_project.flash()
Expand Down
39 changes: 26 additions & 13 deletions python/tvm/micro/testing/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pathlib import Path
from contextlib import ExitStack
import tempfile
import shutil

import tvm
from tvm.relay.op.contrib import cmsisnn
Expand All @@ -53,6 +54,7 @@ def tune_model(
"project_type": "host_driven",
**(project_options or {}),
}

module_loader = tvm.micro.AutoTvmModuleLoader(
template_project_dir=tvm.micro.get_microtvm_template_projects(platform),
project_options=project_options,
Expand Down Expand Up @@ -99,6 +101,7 @@ def create_aot_session(
timeout_override=None,
use_cmsis_nn=False,
project_options=None,
use_existing=False,
):
"""AOT-compiles and uploads a model to a microcontroller, and returns the RPC session"""

Expand All @@ -125,21 +128,31 @@ def create_aot_session(
parameter_size = len(tvm.runtime.save_param_dict(lowered.get_params()))
print(f"Model parameter size: {parameter_size}")

project = tvm.micro.generate_project(
str(tvm.micro.get_microtvm_template_projects(platform)),
lowered,
build_dir / "project",
{
f"{platform}_board": board,
"project_type": "host_driven",
# {} shouldn't be the default value for project options ({}
# is mutable), so we use this workaround
**(project_options or {}),
},
)
project_options = {
f"{platform}_board": board,
"project_type": "host_driven",
# {} shouldn't be the default value for project options ({}
# is mutable), so we use this workaround
**(project_options or {}),
}

if use_existing:
shutil.rmtree(build_dir / "project" / "build")
project = tvm.micro.GeneratedProject.from_directory(
build_dir / "project",
options=project_options,
)

else:
project = tvm.micro.generate_project(
str(tvm.micro.get_microtvm_template_projects(platform)),
lowered,
build_dir / "project",
project_options,
)

project.build()
project.flash()

return tvm.micro.Session(project.transport(), timeout_override=timeout_override)


Expand Down

0 comments on commit 9ccdd30

Please sign in to comment.