From b76aae788d131b3acc4b8ad9cf7a06dfa547b648 Mon Sep 17 00:00:00 2001 From: Chris Sidebottom Date: Fri, 25 Jun 2021 13:48:53 +0100 Subject: [PATCH] Clarify that graph JSON is required only for graph executor Plus other clean ups --- python/tvm/driver/tvmc/model.py | 6 ++---- tests/python/driver/tvmc/conftest.py | 2 +- tests/python/driver/tvmc/test_mlf.py | 14 ++++++++------ tests/python/driver/tvmc/test_runner.py | 4 ++-- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py index 6b93815c1b615..8c8828ddd49b5 100644 --- a/python/tvm/driver/tvmc/model.py +++ b/python/tvm/driver/tvmc/model.py @@ -336,10 +336,8 @@ def import_package(self, package_path: str): with open(temp.relpath("metadata.json")) as metadata_json: metadata = json.load(metadata_json) - if "graph" in metadata["runtimes"]: - graph = temp.relpath("runtime-config/graph/graph.json") - else: - graph = None + is_graph_runtime = "graph" in metadata["runtimes"] + graph = temp.relpath("runtime-config/graph/graph.json") if is_graph_runtime else None params = temp.relpath("parameters/default.params") self.type = "mlf" diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py index 9c39b9bcceaa9..209c371a296a9 100644 --- a/tests/python/driver/tvmc/conftest.py +++ b/tests/python/driver/tvmc/conftest.py @@ -149,7 +149,7 @@ def onnx_mnist(): @pytest.fixture -def tflite_tvmc_compiler(tmpdir_factory): +def tflite_compile_model(tmpdir_factory): """Support function that returns a TFLite compiled module""" def model_compiler(model_file, **overrides): diff --git a/tests/python/driver/tvmc/test_mlf.py b/tests/python/driver/tvmc/test_mlf.py index 81580a657ea64..8cebbd33ff2c1 100644 --- a/tests/python/driver/tvmc/test_mlf.py +++ b/tests/python/driver/tvmc/test_mlf.py @@ -85,10 +85,10 @@ def test_tvmc_export_package_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory): assert str(exp.value) == expected_reason, on_error -def test_tvmc_import_package_mlf(tflite_mobilenet_v1_1_quant, tflite_tvmc_compiler): +def test_tvmc_import_package_mlf_graph(tflite_mobilenet_v1_1_quant, tflite_compile_model): pytest.importorskip("tflite") - tflite_compiled_model_mlf = tflite_tvmc_compiler( + tflite_compiled_model_mlf = tflite_compile_model( tflite_mobilenet_v1_1_quant, output_format="mlf" ) @@ -101,15 +101,17 @@ def test_tvmc_import_package_mlf(tflite_mobilenet_v1_1_quant, tflite_tvmc_compil assert tvmc_package.lib_name is None, ".lib_name must not be set in the MLF archive." assert tvmc_package.lib_path is None, ".lib_path must not be set in the MLF archive." - assert tvmc_package.graph is not None, ".graph must be set in the MLF archive." + assert ( + tvmc_package.graph is not None + ), ".graph must be set in the MLF archive for Graph executor." assert tvmc_package.params is not None, ".params must be set in the MLF archive." assert tvmc_package.type == "mlf", ".type must be set to 'mlf' in the MLF format." -def test_tvmc_import_package_mlf_aot(tflite_mobilenet_v1_1_quant, tflite_tvmc_compiler): +def test_tvmc_import_package_mlf_aot(tflite_mobilenet_v1_1_quant, tflite_compile_model): pytest.importorskip("tflite") - tflite_compiled_model_mlf = tflite_tvmc_compiler( + tflite_compiled_model_mlf = tflite_compile_model( tflite_mobilenet_v1_1_quant, target="c --executor=aot", output_format="mlf", @@ -125,6 +127,6 @@ def test_tvmc_import_package_mlf_aot(tflite_mobilenet_v1_1_quant, tflite_tvmc_co assert tvmc_package.lib_name is None, ".lib_name must not be set in the MLF archive." assert tvmc_package.lib_path is None, ".lib_path must not be set in the MLF archive." - assert tvmc_package.graph is None, ".graph must not be set in the MLF archive for AOT." + assert tvmc_package.graph is None, ".graph must not be set in the MLF archive for AOT executor." assert tvmc_package.params is not None, ".params must be set in the MLF archive." assert tvmc_package.type == "mlf", ".type must be set to 'mlf' in the MLF format." diff --git a/tests/python/driver/tvmc/test_runner.py b/tests/python/driver/tvmc/test_runner.py index 72e42256c9adb..7acb376baba64 100644 --- a/tests/python/driver/tvmc/test_runner.py +++ b/tests/python/driver/tvmc/test_runner.py @@ -72,14 +72,14 @@ def test_get_top_results_keep_results(): def test_run_tflite_module__with_profile__valid_input( - tflite_mobilenet_v1_1_quant, tflite_tvmc_compiler, imagenet_cat + tflite_mobilenet_v1_1_quant, tflite_compile_model, imagenet_cat ): # some CI environments wont offer TFLite, so skip in case it is not present pytest.importorskip("tflite") inputs = np.load(imagenet_cat) - tflite_compiled_model = tflite_tvmc_compiler(tflite_mobilenet_v1_1_quant) + tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant) result = tvmc.run( tflite_compiled_model, inputs=inputs,