From b54beed37ca2baad6002990b014a2119223e0900 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Wed, 8 Dec 2021 04:33:54 -0800 Subject: [PATCH] [TVMC][MicroTVM] Fix tvmc micro `project_dir` arg relative path (#9663) * Add fix for project dir path * address @gromero comments --- python/tvm/driver/tvmc/common.py | 11 ++++++++-- python/tvm/driver/tvmc/micro.py | 27 ++++++++++++----------- python/tvm/driver/tvmc/runner.py | 6 ++++-- tests/micro/common/test_tvmc.py | 37 +++++++++++++++++++++++++------- 4 files changed, 56 insertions(+), 25 deletions(-) diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 97b7c5206a38..5319193886b4 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -22,12 +22,12 @@ import logging import os.path import argparse - +import pathlib +from typing import Union from collections import defaultdict from urllib.parse import urlparse import tvm - from tvm.driver import tvmc from tvm import relay from tvm import transform @@ -786,3 +786,10 @@ def get_and_check_options(passed_options, valid_options): check_options_choices(opts, valid_options) return opts + + +def get_project_dir(project_dir: Union[pathlib.Path, str]) -> str: + """Get project directory path""" + if not os.path.isabs(project_dir): + return os.path.abspath(project_dir) + return project_dir diff --git a/python/tvm/driver/tvmc/micro.py b/python/tvm/driver/tvmc/micro.py index ef72446b931c..a9c17b840ca6 100644 --- a/python/tvm/driver/tvmc/micro.py +++ b/python/tvm/driver/tvmc/micro.py @@ -29,6 +29,7 @@ TVMCSuppressedArgumentParser, get_project_options, get_and_check_options, + get_project_dir, ) try: @@ -238,16 +239,16 @@ def drive_micro(args): def create_project_handler(args): """Creates a new project dir.""" + project_dir = get_project_dir(args.project_dir) - if os.path.exists(args.project_dir): + if os.path.exists(project_dir): if args.force: - shutil.rmtree(args.project_dir) + shutil.rmtree(project_dir) else: raise TVMCException( "The specified project dir already exists. " "To force overwriting it use '-f' or '--force'." ) - project_dir = args.project_dir template_dir = str(Path(args.template_dir).resolve()) if not os.path.exists(template_dir): @@ -268,21 +269,20 @@ def create_project_handler(args): def build_handler(args): """Builds a firmware image given a project dir.""" + project_dir = get_project_dir(args.project_dir) - if not os.path.exists(args.project_dir): - raise TVMCException(f"{args.project_dir} doesn't exist.") + if not os.path.exists(project_dir): + raise TVMCException(f"{project_dir} doesn't exist.") - if os.path.exists(args.project_dir + "/build"): + if os.path.exists(project_dir + "/build"): if args.force: - shutil.rmtree(args.project_dir + "/build") + shutil.rmtree(project_dir + "/build") else: raise TVMCException( - f"There is already a build in {args.project_dir}. " + f"There is already a build in {project_dir}. " "To force rebuild it use '-f' or '--force'." ) - project_dir = args.project_dir - options = get_and_check_options(args.project_option, args.valid_options) try: @@ -295,10 +295,11 @@ def build_handler(args): def flash_handler(args): """Flashes a firmware image to a target device given a project dir.""" - if not os.path.exists(args.project_dir + "/build"): - raise TVMCException(f"Could not find a build in {args.project_dir}") - project_dir = args.project_dir + project_dir = get_project_dir(args.project_dir) + + if not os.path.exists(project_dir + "/build"): + raise TVMCException(f"Could not find a build in {project_dir}") options = get_and_check_options(args.project_option, args.valid_options) diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index b140cf67b10b..4a3790666cdc 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -40,6 +40,7 @@ TVMCSuppressedArgumentParser, get_project_options, get_and_check_options, + get_project_dir, ) from .main import register_parser from .model import TVMCPackage, TVMCResult @@ -147,7 +148,7 @@ def add_run_parser(subparsers, main_parser): "Please build TVM with micro support (USE_MICRO ON)!" ) - project_dir = known_args.PATH + project_dir = get_project_dir(known_args.PATH) try: project_ = project.GeneratedProject.from_directory(project_dir, None) @@ -496,7 +497,8 @@ def run_module( if tvmc_package.type != "mlf": raise TVMCException(f"Model {tvmc_package.package_path} is not a MLF archive.") - project_dir = os.path.dirname(tvmc_package.package_path) + project_dir = get_project_dir(tvmc_package.package_path) + project_dir = os.path.dirname(project_dir) # This is guaranteed to work since project_dir was already checked when # building the dynamic parser to accommodate the project options, so no diff --git a/tests/micro/common/test_tvmc.py b/tests/micro/common/test_tvmc.py index d462b3fadd9b..eb0b3a628442 100644 --- a/tests/micro/common/test_tvmc.py +++ b/tests/micro/common/test_tvmc.py @@ -23,6 +23,7 @@ import pathlib import sys import os +import shutil import tvm from tvm.contrib.download import download_testdata @@ -66,13 +67,22 @@ def test_tvmc_exist(board): @tvm.testing.requires_micro -def test_tvmc_model_build_only(board): +@pytest.mark.parametrize( + "output_dir,", + [pathlib.Path("./tvmc_relative_path_test"), pathlib.Path(tempfile.mkdtemp())], +) +def test_tvmc_model_build_only(board, output_dir): target, platform = _get_target_and_platform(board) + if not os.path.isabs(output_dir): + out_dir_temp = os.path.abspath(output_dir) + if os.path.isdir(out_dir_temp): + shutil.rmtree(out_dir_temp) + os.mkdir(out_dir_temp) + model_path = model_path = download_testdata(MODEL_URL, MODEL_FILE, module="data") - temp_dir = pathlib.Path(tempfile.mkdtemp()) - tar_path = str(temp_dir / "model.tar") - project_dir = str(temp_dir / "project") + tar_path = str(output_dir / "model.tar") + project_dir = str(output_dir / "project") runtime = "crt" executor = "graph" @@ -118,17 +128,27 @@ def test_tvmc_model_build_only(board): ["micro", "build", project_dir, platform, "--project-option", f"{platform}_board={board}"] ) assert cmd_result == 0, "tvmc micro failed in step: build" + shutil.rmtree(output_dir) @pytest.mark.requires_hardware @tvm.testing.requires_micro -def test_tvmc_model_run(board): +@pytest.mark.parametrize( + "output_dir,", + [pathlib.Path("./tvmc_relative_path_test"), pathlib.Path(tempfile.mkdtemp())], +) +def test_tvmc_model_run(board, output_dir): target, platform = _get_target_and_platform(board) + if not os.path.isabs(output_dir): + out_dir_temp = os.path.abspath(output_dir) + if os.path.isdir(out_dir_temp): + shutil.rmtree(out_dir_temp) + os.mkdir(out_dir_temp) + model_path = model_path = download_testdata(MODEL_URL, MODEL_FILE, module="data") - temp_dir = pathlib.Path(tempfile.mkdtemp()) - tar_path = str(temp_dir / "model.tar") - project_dir = str(temp_dir / "project") + tar_path = str(output_dir / "model.tar") + project_dir = str(output_dir / "project") runtime = "crt" executor = "graph" @@ -193,6 +213,7 @@ def test_tvmc_model_run(board): ] ) assert cmd_result == 0, "tvmc micro failed in step: run" + shutil.rmtree(output_dir) if __name__ == "__main__":