Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMC] Separate model loading from model compilation in TVMC. #7739

Merged
merged 15 commits into from
Apr 2, 2021
23 changes: 9 additions & 14 deletions python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,16 @@ def drive_compile(args):
Zero if successfully completed

"""
mod, params = frontends.load_model(args.FILE, args.model_format, args.input_shapes)

graph, lib, params, dumps = compile_model(
args.FILE,
mod,
params,
args.target,
args.dump_code,
None,
args.model_format,
args.tuning_records,
args.desired_layout,
args.input_shapes,
)

if dumps:
Expand All @@ -130,14 +130,13 @@ def drive_compile(args):


def compile_model(
path,
mod,
params,
target,
dump_code=None,
target_host=None,
model_format=None,
tuning_records=None,
alter_layout=None,
shape_dict=None,
):
"""Compile a model from a supported framework into a TVM module.

Expand All @@ -147,8 +146,10 @@ def compile_model(

Parameters
----------
path: str
Path to a file
mod: IRModule
The relay module to be compiled.
params: dict
A dictionary containing the module's parameters.
target : str
The target for which to compile. Can be a plain string or
a path.
Expand All @@ -158,18 +159,13 @@ def compile_model(
target_host : str, optional
The target of the host machine if host-side code
needs to be generated.
model_format: str, optional
A string representing a name of a frontend to be used
tuning_records: str, optional
Path to the file produced by the tuning to be used during
compilation.
alter_layout: str, optional
The layout to convert the graph to. Note, the convert layout
pass doesn't currently guarantee the whole of the graph will
be converted to the chosen layout.
shape_dict: dict, optional
A mapping from input names to their shape. When present,
the default shapes in the model will be overwritten.

Returns
-------
Expand All @@ -184,7 +180,6 @@ def compile_model(

"""
dump_code = [x.strip() for x in dump_code.split(",")] if dump_code else None
mod, params = frontends.load_model(path, model_format, shape_dict)
config = {}

if alter_layout:
Expand Down
3 changes: 2 additions & 1 deletion tests/python/driver/tvmc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def get_sample_compiled_module(target_dir):
temp_dir=target_dir,
)

return tvmc.compiler.compile_model(model_file, target="llvm")
mod, params = tvmc.frontends.load_model(model_file)
return tvmc.compiler.compile_model(mod, params, target="llvm")


# PyTest fixtures
Expand Down
41 changes: 23 additions & 18 deletions tests/python/driver/tvmc/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def test_save_dumps(tmpdir_factory):

def verify_compile_tflite_module(model, shape_dict=None):
pytest.importorskip("tflite")

mod, params = tvmc.load(model, shape_dict=shape_dict)
graph, lib, params, dumps = tvmc.compile(
model, target="llvm", dump_code="ll", alter_layout="NCHW", shape_dict=shape_dict
mod, params, target="llvm", dump_code="ll", alter_layout="NCHW"
)

# check for output types
Expand Down Expand Up @@ -74,8 +74,10 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
def test_cross_compile_aarch64_tflite_module(tflite_mobilenet_v1_1_quant):
pytest.importorskip("tflite")

mod, params = tvmc.load(tflite_mobilenet_v1_1_quant)
graph, lib, params, dumps = tvmc.compile(
tflite_mobilenet_v1_1_quant,
mod,
params,
target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr='+neon'",
dump_code="asm",
)
Expand All @@ -91,7 +93,8 @@ def test_compile_keras__save_module(keras_resnet50, tmpdir_factory):
# some CI environments won't offer tensorflow/Keras, so skip in case it is not present
pytest.importorskip("tensorflow")

graph, lib, params, dumps = tvmc.compile(keras_resnet50, target="llvm", dump_code="ll")
mod, params = tvmc.load(keras_resnet50)
graph, lib, params, dumps = tvmc.compile(mod, params, target="llvm", dump_code="ll")

expected_temp_dir = tmpdir_factory.mktemp("saved_output")
expected_file_name = "saved.tar"
Expand All @@ -109,8 +112,10 @@ def test_cross_compile_aarch64_keras_module(keras_resnet50):
# some CI environments won't offer tensorflow/Keras, so skip in case it is not present
pytest.importorskip("tensorflow")

mod, params = tvmc.load(keras_resnet50)
graph, lib, params, dumps = tvmc.compile(
keras_resnet50,
mod,
params,
target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr='+neon'",
dump_code="asm",
)
Expand All @@ -126,10 +131,8 @@ def test_cross_compile_aarch64_keras_module(keras_resnet50):
def verify_compile_onnx_module(model, shape_dict=None):
# some CI environments won't offer onnx, so skip in case it is not present
pytest.importorskip("onnx")

graph, lib, params, dumps = tvmc.compile(
model, target="llvm", dump_code="ll", shape_dict=shape_dict
)
mod, params = tvmc.load(model, shape_dict=shape_dict)
graph, lib, params, dumps = tvmc.compile(mod, params, target="llvm", dump_code="ll")

# check for output types
assert type(graph) is str
Expand All @@ -156,8 +159,10 @@ def test_cross_compile_aarch64_onnx_module(onnx_resnet50):
# some CI environments won't offer onnx, so skip in case it is not present
pytest.importorskip("onnx")

mod, params = tvmc.load(onnx_resnet50)
graph, lib, params, dumps = tvmc.compile(
onnx_resnet50,
mod,
params,
target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
dump_code="asm",
)
Expand All @@ -173,9 +178,10 @@ def test_cross_compile_aarch64_onnx_module(onnx_resnet50):
@tvm.testing.requires_opencl
def test_compile_opencl(tflite_mobilenet_v1_0_25_128):
pytest.importorskip("tflite")

mod, params = tvmc.load(tflite_mobilenet_v1_0_25_128)
graph, lib, params, dumps = tvmc.compile(
tflite_mobilenet_v1_0_25_128,
mod,
params,
target="opencl",
target_host="llvm",
alter_layout="NCHW",
Expand All @@ -194,9 +200,9 @@ def test_compile_opencl(tflite_mobilenet_v1_0_25_128):
)
def test_compile_tflite_module_with_external_codegen(tflite_mobilenet_v1_1_quant):
pytest.importorskip("tflite")

mod, params = tvmc.load(tflite_mobilenet_v1_1_quant)
graph, lib, params, dumps = tvmc.compile(
tflite_mobilenet_v1_1_quant, target="ethos-n77, llvm", dump_code="relay"
mod, params, target="ethos-n77, llvm", dump_code="relay"
)

# check for output types
Expand All @@ -208,7 +214,7 @@ def test_compile_tflite_module_with_external_codegen(tflite_mobilenet_v1_1_quant

@mock.patch("tvm.relay.build")
@mock.patch("tvm.driver.tvmc.composite_target.get_codegen_by_target")
@mock.patch("tvm.driver.tvmc.frontends.load_model")
@mock.patch("tvm.driver.tvmc.load")
@mock.patch("tvm.transform.PassContext")
def test_compile_check_configs_composite_target(mock_pc, mock_fe, mock_ct, mock_relay):
mock_codegen = {}
Expand All @@ -219,9 +225,8 @@ def test_compile_check_configs_composite_target(mock_pc, mock_fe, mock_ct, mock_
mock_ct.return_value = mock_codegen
mock_relay.return_value = mock.MagicMock()

graph, lib, params, dumps = tvmc.compile(
"no_file_needed", target="mockcodegen -testopt=value, llvm"
)
mod, params = tvmc.load("no_file_needed")
graph, lib, params, dumps = tvmc.compile(mod, params, target="mockcodegen -testopt=value, llvm")

mock_pc.assert_called_once_with(
opt_level=3, config={"relay.ext.mock.options": {"testopt": "value"}}
Expand Down