Skip to content

Commit

Permalink
[TVMC][MicroTVM] Fix tvmc micro project_dir arg relative path (#9663)
Browse files Browse the repository at this point in the history
* Add fix for project dir path

* address @gromero comments
  • Loading branch information
mehrdadh authored Dec 8, 2021
1 parent fcea393 commit b54beed
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 25 deletions.
11 changes: 9 additions & 2 deletions python/tvm/driver/tvmc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
import logging
import os.path
import argparse

import pathlib
from typing import Union
from collections import defaultdict
from urllib.parse import urlparse

import tvm

from tvm.driver import tvmc
from tvm import relay
from tvm import transform
Expand Down Expand Up @@ -786,3 +786,10 @@ def get_and_check_options(passed_options, valid_options):
check_options_choices(opts, valid_options)

return opts


def get_project_dir(project_dir: Union[pathlib.Path, str]) -> str:
"""Get project directory path"""
if not os.path.isabs(project_dir):
return os.path.abspath(project_dir)
return project_dir
27 changes: 14 additions & 13 deletions python/tvm/driver/tvmc/micro.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
TVMCSuppressedArgumentParser,
get_project_options,
get_and_check_options,
get_project_dir,
)

try:
Expand Down Expand Up @@ -238,16 +239,16 @@ def drive_micro(args):

def create_project_handler(args):
"""Creates a new project dir."""
project_dir = get_project_dir(args.project_dir)

if os.path.exists(args.project_dir):
if os.path.exists(project_dir):
if args.force:
shutil.rmtree(args.project_dir)
shutil.rmtree(project_dir)
else:
raise TVMCException(
"The specified project dir already exists. "
"To force overwriting it use '-f' or '--force'."
)
project_dir = args.project_dir

template_dir = str(Path(args.template_dir).resolve())
if not os.path.exists(template_dir):
Expand All @@ -268,21 +269,20 @@ def create_project_handler(args):

def build_handler(args):
"""Builds a firmware image given a project dir."""
project_dir = get_project_dir(args.project_dir)

if not os.path.exists(args.project_dir):
raise TVMCException(f"{args.project_dir} doesn't exist.")
if not os.path.exists(project_dir):
raise TVMCException(f"{project_dir} doesn't exist.")

if os.path.exists(args.project_dir + "/build"):
if os.path.exists(project_dir + "/build"):
if args.force:
shutil.rmtree(args.project_dir + "/build")
shutil.rmtree(project_dir + "/build")
else:
raise TVMCException(
f"There is already a build in {args.project_dir}. "
f"There is already a build in {project_dir}. "
"To force rebuild it use '-f' or '--force'."
)

project_dir = args.project_dir

options = get_and_check_options(args.project_option, args.valid_options)

try:
Expand All @@ -295,10 +295,11 @@ def build_handler(args):

def flash_handler(args):
"""Flashes a firmware image to a target device given a project dir."""
if not os.path.exists(args.project_dir + "/build"):
raise TVMCException(f"Could not find a build in {args.project_dir}")

project_dir = args.project_dir
project_dir = get_project_dir(args.project_dir)

if not os.path.exists(project_dir + "/build"):
raise TVMCException(f"Could not find a build in {project_dir}")

options = get_and_check_options(args.project_option, args.valid_options)

Expand Down
6 changes: 4 additions & 2 deletions python/tvm/driver/tvmc/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
TVMCSuppressedArgumentParser,
get_project_options,
get_and_check_options,
get_project_dir,
)
from .main import register_parser
from .model import TVMCPackage, TVMCResult
Expand Down Expand Up @@ -147,7 +148,7 @@ def add_run_parser(subparsers, main_parser):
"Please build TVM with micro support (USE_MICRO ON)!"
)

project_dir = known_args.PATH
project_dir = get_project_dir(known_args.PATH)

try:
project_ = project.GeneratedProject.from_directory(project_dir, None)
Expand Down Expand Up @@ -496,7 +497,8 @@ def run_module(
if tvmc_package.type != "mlf":
raise TVMCException(f"Model {tvmc_package.package_path} is not a MLF archive.")

project_dir = os.path.dirname(tvmc_package.package_path)
project_dir = get_project_dir(tvmc_package.package_path)
project_dir = os.path.dirname(project_dir)

# This is guaranteed to work since project_dir was already checked when
# building the dynamic parser to accommodate the project options, so no
Expand Down
37 changes: 29 additions & 8 deletions tests/micro/common/test_tvmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import pathlib
import sys
import os
import shutil

import tvm
from tvm.contrib.download import download_testdata
Expand Down Expand Up @@ -66,13 +67,22 @@ def test_tvmc_exist(board):


@tvm.testing.requires_micro
def test_tvmc_model_build_only(board):
@pytest.mark.parametrize(
"output_dir,",
[pathlib.Path("./tvmc_relative_path_test"), pathlib.Path(tempfile.mkdtemp())],
)
def test_tvmc_model_build_only(board, output_dir):
target, platform = _get_target_and_platform(board)

if not os.path.isabs(output_dir):
out_dir_temp = os.path.abspath(output_dir)
if os.path.isdir(out_dir_temp):
shutil.rmtree(out_dir_temp)
os.mkdir(out_dir_temp)

model_path = model_path = download_testdata(MODEL_URL, MODEL_FILE, module="data")
temp_dir = pathlib.Path(tempfile.mkdtemp())
tar_path = str(temp_dir / "model.tar")
project_dir = str(temp_dir / "project")
tar_path = str(output_dir / "model.tar")
project_dir = str(output_dir / "project")

runtime = "crt"
executor = "graph"
Expand Down Expand Up @@ -118,17 +128,27 @@ def test_tvmc_model_build_only(board):
["micro", "build", project_dir, platform, "--project-option", f"{platform}_board={board}"]
)
assert cmd_result == 0, "tvmc micro failed in step: build"
shutil.rmtree(output_dir)


@pytest.mark.requires_hardware
@tvm.testing.requires_micro
def test_tvmc_model_run(board):
@pytest.mark.parametrize(
"output_dir,",
[pathlib.Path("./tvmc_relative_path_test"), pathlib.Path(tempfile.mkdtemp())],
)
def test_tvmc_model_run(board, output_dir):
target, platform = _get_target_and_platform(board)

if not os.path.isabs(output_dir):
out_dir_temp = os.path.abspath(output_dir)
if os.path.isdir(out_dir_temp):
shutil.rmtree(out_dir_temp)
os.mkdir(out_dir_temp)

model_path = model_path = download_testdata(MODEL_URL, MODEL_FILE, module="data")
temp_dir = pathlib.Path(tempfile.mkdtemp())
tar_path = str(temp_dir / "model.tar")
project_dir = str(temp_dir / "project")
tar_path = str(output_dir / "model.tar")
project_dir = str(output_dir / "project")

runtime = "crt"
executor = "graph"
Expand Down Expand Up @@ -193,6 +213,7 @@ def test_tvmc_model_run(board):
]
)
assert cmd_result == 0, "tvmc micro failed in step: run"
shutil.rmtree(output_dir)


if __name__ == "__main__":
Expand Down

0 comments on commit b54beed

Please sign in to comment.