Skip to content

Commit

Permalink
[TVMC] Separate model loading from model compilation in TVMC. (apache#7739)
Browse files Browse the repository at this point in the history

* add to init files for clean tvmc python

* adjust tests to new imports

* add to compiler.py

* update so model loads in drive_compile

* update test_compiler.py to load outside of tvmc.compile, need to correct one error

* fix mock.patch test

* remove merge artifact (circular import issue)

* fix typo and remove merge artifact

* fix import in test_compiler.py

* black needed files

* remove unnecessary argument model_format from compile_module

* load before compile in conftest.py

* fix conftest.py issue

* fix typo in test_compiler.py

Co-authored-by: Jocelyn <jocelyn@pop-os.localdomain>
Co-authored-by: Josh Fromm <jwfromm@uw.edu>
  • Loading branch information
3 people authored and Trevor Morris committed May 6, 2021
1 parent 66a1434 commit b1a68c3
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 33 deletions.
23 changes: 9 additions & 14 deletions python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,16 @@ def drive_compile(args):
Zero if successfully completed
"""
mod, params = frontends.load_model(args.FILE, args.model_format, args.input_shapes)

graph, lib, params, dumps = compile_model(
args.FILE,
mod,
params,
args.target,
args.dump_code,
None,
args.model_format,
args.tuning_records,
args.desired_layout,
args.input_shapes,
)

if dumps:
Expand All @@ -131,14 +131,13 @@ def drive_compile(args):


def compile_model(
path,
mod,
params,
target,
dump_code=None,
target_host=None,
model_format=None,
tuning_records=None,
alter_layout=None,
shape_dict=None,
):
"""Compile a model from a supported framework into a TVM module.
Expand All @@ -148,8 +147,10 @@ def compile_model(
Parameters
----------
path: str
Path to a file
mod: IRModule
The relay module to be compiled.
params: dict
A dictionary containing the module's parameters.
target : str
The target for which to compile. Can be a plain string or
a path.
Expand All @@ -159,18 +160,13 @@ def compile_model(
target_host : str, optional
The target of the host machine if host-side code
needs to be generated.
model_format: str, optional
A string representing a name of a frontend to be used
tuning_records: str, optional
Path to the file produced by the tuning to be used during
compilation.
alter_layout: str, optional
The layout to convert the graph to. Note, the convert layout
pass doesn't currently guarantee the whole of the graph will
be converted to the chosen layout.
shape_dict: dict, optional
A mapping from input names to their shape. When present,
the default shapes in the model will be overwritten.
Returns
-------
Expand All @@ -185,7 +181,6 @@ def compile_model(
"""
dump_code = [x.strip() for x in dump_code.split(",")] if dump_code else None
mod, params = frontends.load_model(path, model_format, shape_dict)
config = {}

if alter_layout:
Expand Down
3 changes: 2 additions & 1 deletion tests/python/driver/tvmc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def get_sample_compiled_module(target_dir):
temp_dir=target_dir,
)

return tvmc.compiler.compile_model(model_file, target="llvm")
mod, params = tvmc.frontends.load_model(model_file)
return tvmc.compiler.compile_model(mod, params, target="llvm")


# PyTest fixtures
Expand Down
41 changes: 23 additions & 18 deletions tests/python/driver/tvmc/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def test_save_dumps(tmpdir_factory):

def verify_compile_tflite_module(model, shape_dict=None):
pytest.importorskip("tflite")

mod, params = tvmc.load(model, shape_dict=shape_dict)
graph, lib, params, dumps = tvmc.compile(
model, target="llvm", dump_code="ll", alter_layout="NCHW", shape_dict=shape_dict
mod, params, target="llvm", dump_code="ll", alter_layout="NCHW"
)

# check for output types
Expand Down Expand Up @@ -74,8 +74,10 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
def test_cross_compile_aarch64_tflite_module(tflite_mobilenet_v1_1_quant):
pytest.importorskip("tflite")

mod, params = tvmc.load(tflite_mobilenet_v1_1_quant)
graph, lib, params, dumps = tvmc.compile(
tflite_mobilenet_v1_1_quant,
mod,
params,
target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr='+neon'",
dump_code="asm",
)
Expand All @@ -91,7 +93,8 @@ def test_compile_keras__save_module(keras_resnet50, tmpdir_factory):
# some CI environments wont offer tensorflow/Keras, so skip in case it is not present
pytest.importorskip("tensorflow")

graph, lib, params, dumps = tvmc.compile(keras_resnet50, target="llvm", dump_code="ll")
mod, params = tvmc.load(keras_resnet50)
graph, lib, params, dumps = tvmc.compile(mod, params, target="llvm", dump_code="ll")

expected_temp_dir = tmpdir_factory.mktemp("saved_output")
expected_file_name = "saved.tar"
Expand All @@ -109,8 +112,10 @@ def test_cross_compile_aarch64_keras_module(keras_resnet50):
# some CI environments wont offer tensorflow/Keras, so skip in case it is not present
pytest.importorskip("tensorflow")

mod, params = tvmc.load(keras_resnet50)
graph, lib, params, dumps = tvmc.compile(
keras_resnet50,
mod,
params,
target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr='+neon'",
dump_code="asm",
)
Expand All @@ -126,10 +131,8 @@ def test_cross_compile_aarch64_keras_module(keras_resnet50):
def verify_compile_onnx_module(model, shape_dict=None):
# some CI environments wont offer onnx, so skip in case it is not present
pytest.importorskip("onnx")

graph, lib, params, dumps = tvmc.compile(
model, target="llvm", dump_code="ll", shape_dict=shape_dict
)
mod, params = tvmc.load(model, shape_dict=shape_dict)
graph, lib, params, dumps = tvmc.compile(mod, params, target="llvm", dump_code="ll")

# check for output types
assert type(graph) is str
Expand All @@ -156,8 +159,10 @@ def test_cross_compile_aarch64_onnx_module(onnx_resnet50):
# some CI environments wont offer onnx, so skip in case it is not present
pytest.importorskip("onnx")

mod, params = tvmc.load(onnx_resnet50)
graph, lib, params, dumps = tvmc.compile(
onnx_resnet50,
mod,
params,
target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
dump_code="asm",
)
Expand All @@ -173,9 +178,10 @@ def test_cross_compile_aarch64_onnx_module(onnx_resnet50):
@tvm.testing.requires_opencl
def test_compile_opencl(tflite_mobilenet_v1_0_25_128):
pytest.importorskip("tflite")

mod, params = tvmc.load(tflite_mobilenet_v1_0_25_128)
graph, lib, params, dumps = tvmc.compile(
tflite_mobilenet_v1_0_25_128,
mod,
params,
target="opencl --host=llvm",
alter_layout="NCHW",
)
Expand All @@ -193,9 +199,9 @@ def test_compile_opencl(tflite_mobilenet_v1_0_25_128):
)
def test_compile_tflite_module_with_external_codegen(tflite_mobilenet_v1_1_quant):
pytest.importorskip("tflite")

mod, params = tvmc.load(tflite_mobilenet_v1_1_quant)
graph, lib, params, dumps = tvmc.compile(
tflite_mobilenet_v1_1_quant, target="ethos-n77, llvm", dump_code="relay"
mod, params, target="ethos-n77, llvm", dump_code="relay"
)

# check for output types
Expand All @@ -207,7 +213,7 @@ def test_compile_tflite_module_with_external_codegen(tflite_mobilenet_v1_1_quant

@mock.patch("tvm.relay.build")
@mock.patch("tvm.driver.tvmc.composite_target.get_codegen_by_target")
@mock.patch("tvm.driver.tvmc.frontends.load_model")
@mock.patch("tvm.driver.tvmc.load")
@mock.patch("tvm.transform.PassContext")
def test_compile_check_configs_composite_target(mock_pc, mock_fe, mock_ct, mock_relay):
mock_codegen = {}
Expand All @@ -218,9 +224,8 @@ def test_compile_check_configs_composite_target(mock_pc, mock_fe, mock_ct, mock_
mock_ct.return_value = mock_codegen
mock_relay.return_value = mock.MagicMock()

graph, lib, params, dumps = tvmc.compile(
"no_file_needed", target="mockcodegen -testopt=value, llvm"
)
mod, params = tvmc.load("no_file_needed")
graph, lib, params, dumps = tvmc.compile(mod, params, target="mockcodegen -testopt=value, llvm")

mock_pc.assert_called_once_with(
opt_level=3, config={"relay.ext.mock.options": {"testopt": "value"}}
Expand Down

0 comments on commit b1a68c3

Please sign in to comment.