Skip to content

Commit

Permalink
Clarify that graph JSON is required only for graph executor
Browse files Browse the repository at this point in the history
Plus other clean ups
  • Loading branch information
Mousius committed Jun 25, 2021
1 parent c9d7b40 commit b76aae7
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 13 deletions.
6 changes: 2 additions & 4 deletions python/tvm/driver/tvmc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,8 @@ def import_package(self, package_path: str):
with open(temp.relpath("metadata.json")) as metadata_json:
metadata = json.load(metadata_json)

if "graph" in metadata["runtimes"]:
graph = temp.relpath("runtime-config/graph/graph.json")
else:
graph = None
is_graph_runtime = "graph" in metadata["runtimes"]
graph = temp.relpath("runtime-config/graph/graph.json") if is_graph_runtime else None
params = temp.relpath("parameters/default.params")

self.type = "mlf"
Expand Down
2 changes: 1 addition & 1 deletion tests/python/driver/tvmc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def onnx_mnist():


@pytest.fixture
def tflite_tvmc_compiler(tmpdir_factory):
def tflite_compile_model(tmpdir_factory):
"""Support function that returns a TFLite compiled module"""

def model_compiler(model_file, **overrides):
Expand Down
14 changes: 8 additions & 6 deletions tests/python/driver/tvmc/test_mlf.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def test_tvmc_export_package_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory):
assert str(exp.value) == expected_reason, on_error


def test_tvmc_import_package_mlf(tflite_mobilenet_v1_1_quant, tflite_tvmc_compiler):
def test_tvmc_import_package_mlf_graph(tflite_mobilenet_v1_1_quant, tflite_compile_model):
pytest.importorskip("tflite")

tflite_compiled_model_mlf = tflite_tvmc_compiler(
tflite_compiled_model_mlf = tflite_compile_model(
tflite_mobilenet_v1_1_quant, output_format="mlf"
)

Expand All @@ -101,15 +101,17 @@ def test_tvmc_import_package_mlf(tflite_mobilenet_v1_1_quant, tflite_tvmc_compil

assert tvmc_package.lib_name is None, ".lib_name must not be set in the MLF archive."
assert tvmc_package.lib_path is None, ".lib_path must not be set in the MLF archive."
assert tvmc_package.graph is not None, ".graph must be set in the MLF archive."
assert (
tvmc_package.graph is not None
), ".graph must be set in the MLF archive for Graph executor."
assert tvmc_package.params is not None, ".params must be set in the MLF archive."
assert tvmc_package.type == "mlf", ".type must be set to 'mlf' in the MLF format."


def test_tvmc_import_package_mlf_aot(tflite_mobilenet_v1_1_quant, tflite_tvmc_compiler):
def test_tvmc_import_package_mlf_aot(tflite_mobilenet_v1_1_quant, tflite_compile_model):
pytest.importorskip("tflite")

tflite_compiled_model_mlf = tflite_tvmc_compiler(
tflite_compiled_model_mlf = tflite_compile_model(
tflite_mobilenet_v1_1_quant,
target="c --executor=aot",
output_format="mlf",
Expand All @@ -125,6 +127,6 @@ def test_tvmc_import_package_mlf_aot(tflite_mobilenet_v1_1_quant, tflite_tvmc_co

assert tvmc_package.lib_name is None, ".lib_name must not be set in the MLF archive."
assert tvmc_package.lib_path is None, ".lib_path must not be set in the MLF archive."
assert tvmc_package.graph is None, ".graph must not be set in the MLF archive for AOT."
assert tvmc_package.graph is None, ".graph must not be set in the MLF archive for AOT executor."
assert tvmc_package.params is not None, ".params must be set in the MLF archive."
assert tvmc_package.type == "mlf", ".type must be set to 'mlf' in the MLF format."
4 changes: 2 additions & 2 deletions tests/python/driver/tvmc/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,14 @@ def test_get_top_results_keep_results():


def test_run_tflite_module__with_profile__valid_input(
tflite_mobilenet_v1_1_quant, tflite_tvmc_compiler, imagenet_cat
tflite_mobilenet_v1_1_quant, tflite_compile_model, imagenet_cat
):
# some CI environments wont offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")

inputs = np.load(imagenet_cat)

tflite_compiled_model = tflite_tvmc_compiler(tflite_mobilenet_v1_1_quant)
tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant)
result = tvmc.run(
tflite_compiled_model,
inputs=inputs,
Expand Down

0 comments on commit b76aae7

Please sign in to comment.