Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMC] Python Scripting Init Files #7698

Merged
merged 7 commits into from
Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/tvm/driver/tvmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin
"""
TVMC - TVM driver command-line interface
"""

from . import autotuner
from . import compiler
from . import runner
from .frontends import load_model as load
from .compiler import compile_model as compile
from .runner import run_module as run
20 changes: 9 additions & 11 deletions tests/python/driver/tvmc/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_save_dumps(tmpdir_factory):
def verify_compile_tflite_module(model, shape_dict=None):
pytest.importorskip("tflite")

graph, lib, params, dumps = tvmc.compiler.compile_model(
graph, lib, params, dumps = tvmc.compile(
model, target="llvm", dump_code="ll", alter_layout="NCHW", shape_dict=shape_dict
)

Expand Down Expand Up @@ -74,7 +74,7 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
def test_cross_compile_aarch64_tflite_module(tflite_mobilenet_v1_1_quant):
pytest.importorskip("tflite")

graph, lib, params, dumps = tvmc.compiler.compile_model(
graph, lib, params, dumps = tvmc.compile(
tflite_mobilenet_v1_1_quant,
target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr='+neon'",
dump_code="asm",
Expand All @@ -91,9 +91,7 @@ def test_compile_keras__save_module(keras_resnet50, tmpdir_factory):
# some CI environments wont offer tensorflow/Keras, so skip in case it is not present
pytest.importorskip("tensorflow")

graph, lib, params, dumps = tvmc.compiler.compile_model(
keras_resnet50, target="llvm", dump_code="ll"
)
graph, lib, params, dumps = tvmc.compile(keras_resnet50, target="llvm", dump_code="ll")

expected_temp_dir = tmpdir_factory.mktemp("saved_output")
expected_file_name = "saved.tar"
Expand All @@ -111,7 +109,7 @@ def test_cross_compile_aarch64_keras_module(keras_resnet50):
# some CI environments wont offer tensorflow/Keras, so skip in case it is not present
pytest.importorskip("tensorflow")

graph, lib, params, dumps = tvmc.compiler.compile_model(
graph, lib, params, dumps = tvmc.compile(
keras_resnet50,
target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr='+neon'",
dump_code="asm",
Expand All @@ -129,7 +127,7 @@ def verify_compile_onnx_module(model, shape_dict=None):
# some CI environments wont offer onnx, so skip in case it is not present
pytest.importorskip("onnx")

graph, lib, params, dumps = tvmc.compiler.compile_model(
graph, lib, params, dumps = tvmc.compile(
model, target="llvm", dump_code="ll", shape_dict=shape_dict
)

Expand Down Expand Up @@ -158,7 +156,7 @@ def test_cross_compile_aarch64_onnx_module(onnx_resnet50):
# some CI environments wont offer onnx, so skip in case it is not present
pytest.importorskip("onnx")

graph, lib, params, dumps = tvmc.compiler.compile_model(
graph, lib, params, dumps = tvmc.compile(
onnx_resnet50,
target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
dump_code="asm",
Expand All @@ -176,7 +174,7 @@ def test_cross_compile_aarch64_onnx_module(onnx_resnet50):
def test_compile_opencl(tflite_mobilenet_v1_0_25_128):
pytest.importorskip("tflite")

graph, lib, params, dumps = tvmc.compiler.compile_model(
graph, lib, params, dumps = tvmc.compile(
tflite_mobilenet_v1_0_25_128,
target="opencl",
target_host="llvm",
Expand All @@ -197,7 +195,7 @@ def test_compile_opencl(tflite_mobilenet_v1_0_25_128):
def test_compile_tflite_module_with_external_codegen(tflite_mobilenet_v1_1_quant):
pytest.importorskip("tflite")

graph, lib, params, dumps = tvmc.compiler.compile_model(
graph, lib, params, dumps = tvmc.compile(
tflite_mobilenet_v1_1_quant, target="ethos-n77, llvm", dump_code="relay"
)

Expand All @@ -221,7 +219,7 @@ def test_compile_check_configs_composite_target(mock_pc, mock_fe, mock_ct, mock_
mock_ct.return_value = mock_codegen
mock_relay.return_value = mock.MagicMock()

graph, lib, params, dumps = tvmc.compiler.compile_model(
graph, lib, params, dumps = tvmc.compile(
"no_file_needed", target="mockcodegen -testopt=value, llvm"
)

Expand Down
18 changes: 8 additions & 10 deletions tests/python/driver/tvmc/test_frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,22 +93,22 @@ def test_load_model__invalid_path__no_language():
pytest.importorskip("tflite")

with pytest.raises(FileNotFoundError):
tvmc.frontends.load_model("not/a/file.tflite")
tvmc.load("not/a/file.tflite")


def test_load_model__invalid_path__with_language():
# some CI environments wont offer onnx, so skip in case it is not present
pytest.importorskip("onnx")

with pytest.raises(FileNotFoundError):
tvmc.frontends.load_model("not/a/file.txt", model_format="onnx")
tvmc.load("not/a/file.txt", model_format="onnx")


def test_load_model__tflite(tflite_mobilenet_v1_1_quant):
# some CI environments wont offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")

mod, params = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant)
mod, params = tvmc.load(tflite_mobilenet_v1_1_quant)
assert type(mod) is IRModule
assert type(params) is dict
# check whether one known value is part of the params dict
Expand Down Expand Up @@ -149,7 +149,7 @@ def test_load_model__pb(pb_mobilenet_v1_1_quant):
# some CI environments wont offer TensorFlow, so skip in case it is not present
pytest.importorskip("tensorflow")

mod, params = tvmc.frontends.load_model(pb_mobilenet_v1_1_quant)
mod, params = tvmc.load(pb_mobilenet_v1_1_quant)
assert type(mod) is IRModule
assert type(params) is dict
# check whether one known value is part of the params dict
Expand All @@ -161,7 +161,7 @@ def test_load_model___wrong_language__to_keras(tflite_mobilenet_v1_1_quant):
pytest.importorskip("tensorflow")

with pytest.raises(OSError):
tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="keras")
tvmc.load(tflite_mobilenet_v1_1_quant, model_format="keras")


def test_load_model___wrong_language__to_tflite(keras_resnet50):
Expand All @@ -179,7 +179,7 @@ def test_load_model___wrong_language__to_onnx(tflite_mobilenet_v1_1_quant):
from google.protobuf.message import DecodeError

with pytest.raises(DecodeError):
tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="onnx")
tvmc.load(tflite_mobilenet_v1_1_quant, model_format="onnx")


@pytest.mark.skip(reason="https://github.com/apache/tvm/issues/7455")
Expand All @@ -188,9 +188,7 @@ def test_load_model__pth(pytorch_resnet18):
pytest.importorskip("torch")
pytest.importorskip("torchvision")

mod, params = tvmc.frontends.load_model(
pytorch_resnet18, shape_dict={"input": [1, 3, 224, 224]}
)
mod, params = tvmc.load(pytorch_resnet18, shape_dict={"input": [1, 3, 224, 224]})
assert type(mod) is IRModule
assert type(params) is dict
# check whether one known value is part of the params dict
Expand All @@ -202,7 +200,7 @@ def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant):
pytest.importorskip("torch")

with pytest.raises(RuntimeError) as e:
tvmc.frontends.load_model(
tvmc.load(
tflite_mobilenet_v1_1_quant,
model_format="pytorch",
shape_dict={"input": [1, 3, 224, 224]},
Expand Down
2 changes: 1 addition & 1 deletion tests/python/driver/tvmc/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_run_tflite_module__with_profile__valid_input(
# some CI environments wont offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")

outputs, times = tvmc.runner.run_module(
outputs, times = tvmc.run(
tflite_compiled_module_as_tarfile,
inputs_file=imagenet_cat,
hostname=None,
Expand Down