Skip to content

Commit

Permalink
Allow tvmc to compile models with AOT executor
Browse files Browse the repository at this point in the history
The tflite_compiled_model fixture was getting duplicated a few times so
I've added a parameterized fixture tflite_tvmc_compiler which combines
tmpdir_factory setup with compile_model

Nested targets broke a basic string split, so in cases where we use
nested targets I replaced the string split with shlex split
  • Loading branch information
Mousius committed Jun 24, 2021
1 parent 5fa1c6d commit c9d7b40
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 64 deletions.
6 changes: 3 additions & 3 deletions python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def compile_model(
target_host: Optional[str] = None,
desired_layout: Optional[str] = None,
disabled_pass: Optional[str] = None,
pass_context_configs: Optional[str] = None,
pass_context_configs: Optional[List[str]] = None,
):
"""Compile a model from a supported framework into a TVM module.
Expand Down Expand Up @@ -212,8 +212,8 @@ def compile_model(
disabled_pass: str, optional
Comma-separated list of passes which needs to be disabled
during compilation
pass_context_configs: str, optional
String containing a set of configurations to be passed to the
pass_context_configs: list[str], optional
List of strings containing a set of configurations to be passed to the
PassContext.
Expand Down
15 changes: 12 additions & 3 deletions python/tvm/driver/tvmc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
"""
import os
import tarfile
import json
from typing import Optional, Union, List, Dict, Callable, TextIO
import numpy as np

Expand Down Expand Up @@ -332,8 +333,13 @@ def import_package(self, package_path: str):
# Model Library Format (MLF)
self.lib_name = None
self.lib_path = None
with open(temp.relpath("metadata.json")) as metadata_json:
metadata = json.load(metadata_json)

graph = temp.relpath("runtime-config/graph/graph.json")
if "graph" in metadata["runtimes"]:
graph = temp.relpath("runtime-config/graph/graph.json")
else:
graph = None
params = temp.relpath("parameters/default.params")

self.type = "mlf"
Expand All @@ -357,8 +363,11 @@ def import_package(self, package_path: str):
with open(params, "rb") as param_file:
self.params = bytearray(param_file.read())

with open(graph) as graph_file:
self.graph = graph_file.read()
if graph is not None:
with open(graph) as graph_file:
self.graph = graph_file.read()
else:
self.graph = None


class TVMCResult(object):
Expand Down
60 changes: 9 additions & 51 deletions tests/python/driver/tvmc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,6 @@ def download_and_untar(model_url, model_sub_path, temp_dir):
return os.path.join(temp_dir, model_sub_path)


def get_sample_compiled_module(target_dir, package_filename, output_format="so"):
"""Support function that returns a TFLite compiled module"""
base_url = "https://storage.googleapis.com/download.tensorflow.org/models"
model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz"
model_file = download_and_untar(
"{}/{}".format(base_url, model_url),
"mobilenet_v1_1.0_224_quant.tflite",
temp_dir=target_dir,
)

tvmc_model = tvmc.frontends.load_model(model_file)
return tvmc.compiler.compile_model(
tvmc_model,
target="llvm",
package_path=os.path.join(target_dir, package_filename),
output_format=output_format,
)


# PyTest fixtures


Expand Down Expand Up @@ -167,40 +148,17 @@ def onnx_mnist():
return model_file


@pytest.fixture(scope="session")
def tflite_compiled_model(tmpdir_factory):

# Not all CI environments will have TFLite installed
# so we need to safely skip this fixture that will
# crash the tests that rely on it.
# As this is a pytest.fixture, we cannot take advantage
# of pytest.importorskip. Using the block below instead.
try:
import tflite
except ImportError:
print("Cannot import tflite, which is required by tflite_compiled_module_as_tarfile.")
return ""

target_dir = tmpdir_factory.mktemp("data")
return get_sample_compiled_module(target_dir, "mock.tar")


@pytest.fixture(scope="session")
def tflite_compiled_model_mlf(tmpdir_factory):
@pytest.fixture
def tflite_tvmc_compiler(tmpdir_factory):
"""Support function that returns a TFLite compiled module"""

# Not all CI environments will have TFLite installed
# so we need to safely skip this fixture that will
# crash the tests that rely on it.
# As this is a pytest.fixture, we cannot take advantage
# of pytest.importorskip. Using the block below instead.
try:
import tflite
except ImportError:
print("Cannot import tflite, which is required by tflite_compiled_module_as_tarfile.")
return ""
def model_compiler(model_file, **overrides):
package_path = tmpdir_factory.mktemp("data").join("mock.tar")
tvmc_model = tvmc.frontends.load_model(model_file)
args = {"target": "llvm", **overrides}
return tvmc.compiler.compile_model(tvmc_model, package_path=package_path, **args)

target_dir = tmpdir_factory.mktemp("data")
return get_sample_compiled_module(target_dir, "mock.tar", "mlf")
return model_compiler


@pytest.fixture(scope="session")
Expand Down
43 changes: 37 additions & 6 deletions tests/python/driver/tvmc/test_mlf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,28 @@

import pytest
import os
import shlex

import tvm
from tvm.driver import tvmc
from tvm.driver.tvmc.main import _main
from tvm.driver.tvmc.model import TVMCPackage, TVMCException


def test_tvmc_cl_compile_run_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory):
@pytest.mark.parametrize(
["target", "pass_configs"], [["llvm", []], ["c --executor=aot", ["tir.disable_vectorize=1"]]]
)
def test_tvmc_cl_compile_run_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory, target, pass_configs):
pytest.importorskip("tflite")

output_dir = tmpdir_factory.mktemp("mlf")
input_model = tflite_mobilenet_v1_1_quant
output_file = os.path.join(output_dir, "mock.tar")

# Compile the input model and generate a Model Library Format (MLF) archive.
tvmc_cmd = (
f"tvmc compile {input_model} --target='llvm' --output {output_file} --output-format mlf"
)
tvmc_args = tvmc_cmd.split(" ")[1:]
pass_config_args = " ".join([f"--pass-config {pass_config}" for pass_config in pass_configs])
tvmc_cmd = f"tvmc compile {input_model} --target='{target}' {pass_config_args} --output {output_file} --output-format mlf"
tvmc_args = shlex.split(tvmc_cmd)[1:]
_main(tvmc_args)
assert os.path.exists(output_file), "Could not find the exported MLF archive."

Expand Down Expand Up @@ -82,9 +85,13 @@ def test_tvmc_export_package_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory):
assert str(exp.value) == expected_reason, on_error


def test_tvmc_import_package_mlf(tflite_compiled_model_mlf):
def test_tvmc_import_package_mlf(tflite_mobilenet_v1_1_quant, tflite_tvmc_compiler):
pytest.importorskip("tflite")

tflite_compiled_model_mlf = tflite_tvmc_compiler(
tflite_mobilenet_v1_1_quant, output_format="mlf"
)

# Compile and export a model to a MLF archive so it can be imported.
exported_tvmc_package = tflite_compiled_model_mlf
archive_path = exported_tvmc_package.package_path
Expand All @@ -97,3 +104,27 @@ def test_tvmc_import_package_mlf(tflite_compiled_model_mlf):
assert tvmc_package.graph is not None, ".graph must be set in the MLF archive."
assert tvmc_package.params is not None, ".params must be set in the MLF archive."
assert tvmc_package.type == "mlf", ".type must be set to 'mlf' in the MLF format."


def test_tvmc_import_package_mlf_aot(tflite_mobilenet_v1_1_quant, tflite_tvmc_compiler):
pytest.importorskip("tflite")

tflite_compiled_model_mlf = tflite_tvmc_compiler(
tflite_mobilenet_v1_1_quant,
target="c --executor=aot",
output_format="mlf",
pass_context_configs=["tir.disable_vectorize=1"],
)

# Compile and export a model to a MLF archive so it can be imported.
exported_tvmc_package = tflite_compiled_model_mlf
archive_path = exported_tvmc_package.package_path

# Import the MLF archive. TVMCPackage constructor will call import_package method.
tvmc_package = TVMCPackage(archive_path)

assert tvmc_package.lib_name is None, ".lib_name must not be set in the MLF archive."
assert tvmc_package.lib_path is None, ".lib_path must not be set in the MLF archive."
assert tvmc_package.graph is None, ".graph must not be set in the MLF archive for AOT."
assert tvmc_package.params is not None, ".params must be set in the MLF archive."
assert tvmc_package.type == "mlf", ".type must be set to 'mlf' in the MLF format."
5 changes: 4 additions & 1 deletion tests/python/driver/tvmc/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,15 @@ def test_get_top_results_keep_results():
assert len(sut[1]) == expected_number_of_results_per_line


def test_run_tflite_module__with_profile__valid_input(tflite_compiled_model, imagenet_cat):
def test_run_tflite_module__with_profile__valid_input(
tflite_mobilenet_v1_1_quant, tflite_tvmc_compiler, imagenet_cat
):
# some CI environments wont offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")

inputs = np.load(imagenet_cat)

tflite_compiled_model = tflite_tvmc_compiler(tflite_mobilenet_v1_1_quant)
result = tvmc.run(
tflite_compiled_model,
inputs=inputs,
Expand Down

0 comments on commit c9d7b40

Please sign in to comment.