Skip to content

Commit

Permalink
[tvmc] Introduce 'run' subcommand (part 4/4) (apache#6578)
Browse files Browse the repository at this point in the history
* [tvmc] Introduce 'run' subcommand (part 4/4)

 * Add 'tvmc run' subcommand to execute compiled modules
 * Include options to locally or remotelly using RPC
 * Include support to cpu and gpu devices


Co-authored-by: Marcus Shawcroft <marcus.shawcroft@arm.com>
Co-authored-by: Matthew Barrett <matthew.barrett@arm.com>

* adjust based on code review comments

* make test fixture to safely skip environments without tflite

* make --help option more clear

* improve error message to show expected inputs

* code-review adjusts

* update doc-string to default zeros->random

Co-authored-by: Marcus Shawcroft <marcus.shawcroft@arm.com>
Co-authored-by: Matthew Barrett <matthew.barrett@arm.com>
  • Loading branch information
3 people authored and Tushar Dey committed Oct 15, 2020
1 parent 5cb6b25 commit 2c4ad40
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 7 deletions.
1 change: 1 addition & 0 deletions python/tvm/driver/tvmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@

from . import autotuner
from . import compiler
from . import runner
35 changes: 35 additions & 0 deletions python/tvm/driver/tvmc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import logging
import os.path

from urllib.parse import urlparse

import tvm

from tvm import relay
Expand Down Expand Up @@ -102,3 +104,36 @@ def target_from_cli(target):
logger.debug("creating target from input: %s", target)

return tvm.target.Target(target)


def tracker_host_port_from_cli(rpc_tracker_str):
"""Extract hostname and (optional) port from strings
like "1.2.3.4:9090" or "4.3.2.1".
Used as a helper function to cover --rpc-tracker
command line argument, in different subcommands.
Parameters
----------
rpc_tracker_str : str
hostname (or IP address) and port of the RPC tracker,
in the format 'hostname[:port]'.
Returns
-------
rpc_hostname : str or None
hostname or IP address, extracted from input.
rpc_port : int or None
port number extracted from input (9090 default).
"""

rpc_hostname = rpc_port = None

if rpc_tracker_str:
parsed_url = urlparse("//%s" % rpc_tracker_str)
rpc_hostname = parsed_url.hostname
rpc_port = parsed_url.port or 9090
logger.info("RPC tracker hostname: %s", rpc_hostname)
logger.info("RPC tracker port: %s", rpc_port)

return rpc_hostname, rpc_port
4 changes: 2 additions & 2 deletions python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,6 @@ def compile_model(
target_host = target_host or ""

if tuning_records and os.path.exists(tuning_records):
# TODO (@leandron) a new PR will introduce the 'tune' subcommand
# the is used to generate the tuning records file
logger.debug("tuning records file provided: %s", tuning_records)
with autotvm.apply_history_best(tuning_records):
with tvm.transform.PassContext(opt_level=3):
Expand All @@ -212,6 +210,8 @@ def compile_model(
source = str(mod) if source_type == "relay" else lib.get_source(source_type)
dumps[source_type] = source

# TODO we need to update this return to use the updated graph module APIs
# as these getter functions will be deprecated in the next release (@leandron)
return graph_module.get_json(), graph_module.get_lib(), graph_module.get_params(), dumps


Expand Down
41 changes: 36 additions & 5 deletions tests/python/driver/tvmc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
import pytest
import tarfile

import tvm.driver.tvmc.compiler
import numpy as np

from tvm.contrib.download import download_testdata
from PIL import Image

from tvm.driver import tvmc

from tvm.driver.tvmc.common import convert_graph_layout
from tvm.contrib.download import download_testdata

# Support functions

Expand All @@ -40,7 +42,7 @@ def download_and_untar(model_url, model_sub_path, temp_dir):


def get_sample_compiled_module(target_dir):
"""Support function that retuns a TFLite compiled module"""
"""Support function that returns a TFLite compiled module"""
base_url = "https://storage.googleapis.com/download.tensorflow.org/models"
model_url = "mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz"
model_file = download_and_untar(
Expand All @@ -49,7 +51,7 @@ def get_sample_compiled_module(target_dir):
temp_dir=target_dir,
)

return tvmc.compiler.compile_model(model_file, targets=["llvm"])
return tvmc.compiler.compile_model(model_file, target="llvm")


# PyTest fixtures
Expand Down Expand Up @@ -110,10 +112,39 @@ def onnx_resnet50():

@pytest.fixture(scope="session")
def tflite_compiled_module_as_tarfile(tmpdir_factory):

# Not all CI environments will have TFLite installed
# so we need to safely skip this fixture that will
# crash the tests that rely on it.
# As this is a pytest.fixture, we cannot take advantage
# of pytest.importorskip. Using the block below instead.
try:
import tflite
except ImportError:
print("Cannot import tflite, which is required by tflite_compiled_module_as_tarfile.")
return ""

target_dir = tmpdir_factory.mktemp("data")
graph, lib, params, _ = get_sample_compiled_module(target_dir)

module_file = os.path.join(target_dir, "mock.tar")
tvmc.compiler.save_module(module_file, graph, lib, params)

return module_file


@pytest.fixture(scope="session")
def imagenet_cat(tmpdir_factory):
tmpdir_name = tmpdir_factory.mktemp("data")
cat_file_name = "imagenet_cat.npz"

cat_url = "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true"
image_path = download_testdata(cat_url, "inputs", module=["tvmc"])
resized_image = Image.open(image_path).resize((224, 224))
image_data = np.asarray(resized_image).astype("float32")
image_data = np.expand_dims(image_data, axis=0)

cat_file_full_path = os.path.join(tmpdir_name, cat_file_name)
np.savez(cat_file_full_path, input=image_data)

return cat_file_full_path
31 changes: 31 additions & 0 deletions tests/python/driver/tvmc/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,34 @@ def _is_layout_transform(node):
tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

assert not any(layout_transform_calls), "Unexpected 'layout_transform' call"


def test_tracker_host_port_from_cli__hostname_port():
input_str = "1.2.3.4:9090"
expected_host = "1.2.3.4"
expected_port = 9090

actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str)

assert expected_host == actual_host
assert expected_port == actual_port


def test_tracker_host_port_from_cli__hostname_port__empty():
input_str = ""

actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str)

assert actual_host is None
assert actual_port is None


def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090():
input_str = "1.2.3.4"
expected_host = "1.2.3.4"
expected_port = 9090

actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str)

assert expected_host == actual_host
assert expected_port == actual_port

0 comments on commit 2c4ad40

Please sign in to comment.