diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index f484290bb5d0..ba7722f6b38e 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -111,16 +111,16 @@ def drive_compile(args): Zero if successfully completed """ + mod, params = frontends.load_model(args.FILE, args.model_format, args.input_shapes) graph, lib, params, dumps = compile_model( - args.FILE, + mod, + params, args.target, args.dump_code, None, - args.model_format, args.tuning_records, args.desired_layout, - args.input_shapes, ) if dumps: @@ -131,14 +131,13 @@ def drive_compile(args): def compile_model( - path, + mod, + params, target, dump_code=None, target_host=None, - model_format=None, tuning_records=None, alter_layout=None, - shape_dict=None, ): """Compile a model from a supported framework into a TVM module. @@ -148,8 +147,10 @@ def compile_model( Parameters ---------- - path: str - Path to a file + mod: IRModule + The relay module to be compiled. + params: dict + A dictionary containing the module's parameters. target : str The target for which to compile. Can be a plain string or a path. @@ -159,8 +160,6 @@ def compile_model( target_host : str, optional The target of the host machine if host-side code needs to be generated. - model_format: str, optional - A string representing a name of a frontend to be used tuning_records: str, optional Path to the file produced by the tuning to be used during compilation. @@ -168,9 +167,6 @@ def compile_model( The layout to convert the graph to. Note, the convert layout pass doesn't currently guarantee the whole of the graph will be converted to the chosen layout. - shape_dict: dict, optional - A mapping from input names to their shape. When present, - the default shapes in the model will be overwritten. Returns ------- @@ -185,7 +181,6 @@ def compile_model( """ dump_code = [x.strip() for x in dump_code.split(",")] if dump_code else None - mod, params = frontends.load_model(path, model_format, shape_dict) config = {} if alter_layout: diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py index 534953deecbc..3345b4f07585 100644 --- a/tests/python/driver/tvmc/conftest.py +++ b/tests/python/driver/tvmc/conftest.py @@ -51,7 +51,8 @@ def get_sample_compiled_module(target_dir): temp_dir=target_dir, ) - return tvmc.compiler.compile_model(model_file, target="llvm") + mod, params = tvmc.frontends.load_model(model_file) + return tvmc.compiler.compile_model(mod, params, target="llvm") # PyTest fixtures diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 17b2834feb11..6d17b4e37114 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -44,9 +44,9 @@ def test_save_dumps(tmpdir_factory): def verify_compile_tflite_module(model, shape_dict=None): pytest.importorskip("tflite") - + mod, params = tvmc.load(model, shape_dict=shape_dict) graph, lib, params, dumps = tvmc.compile( - model, target="llvm", dump_code="ll", alter_layout="NCHW", shape_dict=shape_dict + mod, params, target="llvm", dump_code="ll", alter_layout="NCHW" ) # check for output types @@ -74,8 +74,10 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): def test_cross_compile_aarch64_tflite_module(tflite_mobilenet_v1_1_quant): pytest.importorskip("tflite") + mod, params = tvmc.load(tflite_mobilenet_v1_1_quant) graph, lib, params, dumps = tvmc.compile( - tflite_mobilenet_v1_1_quant, + mod, + params, target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr='+neon'", dump_code="asm", ) @@ -91,7 +93,8 @@ def test_compile_keras__save_module(keras_resnet50, tmpdir_factory): # some CI environments wont offer tensorflow/Keras, so skip in case it is not present pytest.importorskip("tensorflow") - graph, lib, params, dumps = tvmc.compile(keras_resnet50, target="llvm", dump_code="ll") + mod, params = tvmc.load(keras_resnet50) + graph, lib, params, dumps = tvmc.compile(mod, params, target="llvm", dump_code="ll") expected_temp_dir = tmpdir_factory.mktemp("saved_output") expected_file_name = "saved.tar" @@ -109,8 +112,10 @@ def test_cross_compile_aarch64_keras_module(keras_resnet50): # some CI environments wont offer tensorflow/Keras, so skip in case it is not present pytest.importorskip("tensorflow") + mod, params = tvmc.load(keras_resnet50) graph, lib, params, dumps = tvmc.compile( - keras_resnet50, + mod, + params, target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr='+neon'", dump_code="asm", ) @@ -126,10 +131,8 @@ def test_cross_compile_aarch64_keras_module(keras_resnet50): def verify_compile_onnx_module(model, shape_dict=None): # some CI environments wont offer onnx, so skip in case it is not present pytest.importorskip("onnx") - - graph, lib, params, dumps = tvmc.compile( - model, target="llvm", dump_code="ll", shape_dict=shape_dict - ) + mod, params = tvmc.load(model, shape_dict=shape_dict) + graph, lib, params, dumps = tvmc.compile(mod, params, target="llvm", dump_code="ll") # check for output types assert type(graph) is str @@ -156,8 +159,10 @@ def test_cross_compile_aarch64_onnx_module(onnx_resnet50): # some CI environments wont offer onnx, so skip in case it is not present pytest.importorskip("onnx") + mod, params = tvmc.load(onnx_resnet50) graph, lib, params, dumps = tvmc.compile( - onnx_resnet50, + mod, + params, target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon", dump_code="asm", ) @@ -173,9 +178,10 @@ def test_cross_compile_aarch64_onnx_module(onnx_resnet50): @tvm.testing.requires_opencl def test_compile_opencl(tflite_mobilenet_v1_0_25_128): pytest.importorskip("tflite") - + mod, params = tvmc.load(tflite_mobilenet_v1_0_25_128) graph, lib, params, dumps = tvmc.compile( - tflite_mobilenet_v1_0_25_128, + mod, + params, target="opencl --host=llvm", alter_layout="NCHW", ) @@ -193,9 +199,9 @@ def test_compile_opencl(tflite_mobilenet_v1_0_25_128): ) def test_compile_tflite_module_with_external_codegen(tflite_mobilenet_v1_1_quant): pytest.importorskip("tflite") - + mod, params = tvmc.load(tflite_mobilenet_v1_1_quant) graph, lib, params, dumps = tvmc.compile( - tflite_mobilenet_v1_1_quant, target="ethos-n77, llvm", dump_code="relay" + mod, params, target="ethos-n77, llvm", dump_code="relay" ) # check for output types @@ -207,7 +213,7 @@ def test_compile_tflite_module_with_external_codegen(tflite_mobilenet_v1_1_quant @mock.patch("tvm.relay.build") @mock.patch("tvm.driver.tvmc.composite_target.get_codegen_by_target") -@mock.patch("tvm.driver.tvmc.frontends.load_model") +@mock.patch("tvm.driver.tvmc.load") @mock.patch("tvm.transform.PassContext") def test_compile_check_configs_composite_target(mock_pc, mock_fe, mock_ct, mock_relay): mock_codegen = {} @@ -218,9 +224,8 @@ def test_compile_check_configs_composite_target(mock_pc, mock_fe, mock_ct, mock_ mock_ct.return_value = mock_codegen mock_relay.return_value = mock.MagicMock() - graph, lib, params, dumps = tvmc.compile( - "no_file_needed", target="mockcodegen -testopt=value, llvm" - ) + mod, params = tvmc.load("no_file_needed") + graph, lib, params, dumps = tvmc.compile(mod, params, target="mockcodegen -testopt=value, llvm") mock_pc.assert_called_once_with( opt_level=3, config={"relay.ext.mock.options": {"testopt": "value"}}