diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py index 795a61edcbb3..82a12e2c264a 100644 --- a/python/tvm/micro/build.py +++ b/python/tvm/micro/build.py @@ -120,18 +120,35 @@ class AutoTvmModuleLoader: project_options : dict project generation option + + project_dir: str + if use_existing is False: The path to save the generated microTVM Project. + if use_existing is True: The path to a generated microTVM Project for debugging. + + use_existing: bool + skips the project generation and opens transport to the project at the project_dir address. """ def __init__( - self, template_project_dir: Union[pathlib.Path, str], project_options: dict = None + self, + template_project_dir: Union[pathlib.Path, str], + project_options: dict = None, + project_dir: Union[pathlib.Path, str] = None, + use_existing: bool = False, ): self._project_options = project_options + self._use_existing = use_existing if isinstance(template_project_dir, (pathlib.Path, str)): self._template_project_dir = str(template_project_dir) elif not isinstance(template_project_dir, str): raise TypeError(f"Incorrect type {type(template_project_dir)}.") + if isinstance(project_dir, (pathlib.Path, str)): + self._project_dir = str(project_dir) + else: + self._project_dir = None + @contextlib.contextmanager def __call__(self, remote_kw, build_result): with open(build_result.filename, "rb") as build_file: @@ -147,6 +164,8 @@ def __call__(self, remote_kw, build_result): build_result_bin, self._template_project_dir, json.dumps(self._project_options), + self._project_dir, + self._use_existing, ], ) system_lib = remote.get_function("runtime.SystemLib")() diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py index 967eaee62958..8f7e55ad735b 100644 --- a/python/tvm/micro/session.py +++ b/python/tvm/micro/session.py @@ -20,7 +20,9 @@ import json import logging import sys - +import pathlib +import shutil +from typing import Union from ..error import register_error from .._ffi import get_global_func, register_func from ..contrib import graph_executor @@ -259,6 +261,8 @@ def compile_and_create_micro_session( mod_src_bytes: bytes, template_project_dir: str, project_options: dict = None, + project_dir: Union[pathlib.Path, str] = None, + use_existing: bool = False, ): """Compile the given libraries and sources into a MicroBinary, then invoke create_micro_session. @@ -275,25 +279,44 @@ def compile_and_create_micro_session( project_options: dict Options for the microTVM API Server contained in template_project_dir. - """ - temp_dir = utils.tempdir() - # Keep temp directory for generate project - temp_dir.set_keep_for_debug(True) - model_library_format_path = temp_dir / "model.tar.gz" - with open(model_library_format_path, "wb") as mlf_f: - mlf_f.write(mod_src_bytes) + project_dir: str + if use_existing is False: The path to save the generated microTVM Project. + if use_existing is True: The path to a generated microTVM Project for debugging. - try: - template_project = project.TemplateProject.from_directory(template_project_dir) - generated_project = template_project.generate_project_from_mlf( - model_library_format_path, - str(temp_dir / "generated-project"), + use_existing: bool + skips the project generation and opens transport to the project at the project_dir address. + """ + + if use_existing: + project_dir = pathlib.Path(project_dir) + assert project_dir.is_dir(), f"{project_dir} does not exist." + build_dir = project_dir / "generated-project" / "build" + shutil.rmtree(build_dir) + generated_project = project.GeneratedProject.from_directory( + project_dir / "generated-project", options=json.loads(project_options), ) - except Exception as exception: - logging.error("Project Generate Error: %s", str(exception)) - raise exception + else: + if project_dir: + temp_dir = utils.tempdir(custom_path=project_dir, keep_for_debug=True) + else: + temp_dir = utils.tempdir() + + model_library_format_path = temp_dir / "model.tar.gz" + with open(model_library_format_path, "wb") as mlf_f: + mlf_f.write(mod_src_bytes) + + try: + template_project = project.TemplateProject.from_directory(template_project_dir) + generated_project = template_project.generate_project_from_mlf( + model_library_format_path, + str(temp_dir / "generated-project"), + options=json.loads(project_options), + ) + except Exception as exception: + logging.error("Project Generate Error: %s", str(exception)) + raise exception generated_project.build() generated_project.flash() diff --git a/python/tvm/micro/testing/evaluation.py b/python/tvm/micro/testing/evaluation.py index c8a90ff5b40f..1d80ed5568b2 100644 --- a/python/tvm/micro/testing/evaluation.py +++ b/python/tvm/micro/testing/evaluation.py @@ -27,6 +27,7 @@ from pathlib import Path from contextlib import ExitStack import tempfile +import shutil import tvm from tvm.relay.op.contrib import cmsisnn @@ -53,6 +54,7 @@ def tune_model( "project_type": "host_driven", **(project_options or {}), } + module_loader = tvm.micro.AutoTvmModuleLoader( template_project_dir=tvm.micro.get_microtvm_template_projects(platform), project_options=project_options, @@ -99,6 +101,7 @@ def create_aot_session( timeout_override=None, use_cmsis_nn=False, project_options=None, + use_existing=False, ): """AOT-compiles and uploads a model to a microcontroller, and returns the RPC session""" @@ -125,21 +128,31 @@ def create_aot_session( parameter_size = len(tvm.runtime.save_param_dict(lowered.get_params())) print(f"Model parameter size: {parameter_size}") - project = tvm.micro.generate_project( - str(tvm.micro.get_microtvm_template_projects(platform)), - lowered, - build_dir / "project", - { - f"{platform}_board": board, - "project_type": "host_driven", - # {} shouldn't be the default value for project options ({} - # is mutable), so we use this workaround - **(project_options or {}), - }, - ) + project_options = { + f"{platform}_board": board, + "project_type": "host_driven", + # {} shouldn't be the default value for project options ({} + # is mutable), so we use this workaround + **(project_options or {}), + } + + if use_existing: + shutil.rmtree(build_dir / "project" / "build") + project = tvm.micro.GeneratedProject.from_directory( + build_dir / "project", + options=project_options, + ) + + else: + project = tvm.micro.generate_project( + str(tvm.micro.get_microtvm_template_projects(platform)), + lowered, + build_dir / "project", + project_options, + ) + project.build() project.flash() - return tvm.micro.Session(project.transport(), timeout_override=timeout_override)