Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMC] Allow manual shape specification in tvmc #7366

Merged
merged 13 commits into from
Feb 9, 2021
9 changes: 8 additions & 1 deletion python/tvm/driver/tvmc/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,13 @@ def add_tune_parser(subparsers):
# can be improved in future to add integration with a modelzoo
# or URL, for example.
parser.add_argument("FILE", help="path to the input model file")
parser.add_argument(
"--input-shapes",
help="specify non-generic shapes for model to run, format is "
'"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"',
type=common.parse_shape_string,
default=None,
)


def drive_tune(args):
Expand All @@ -235,7 +242,7 @@ def drive_tune(args):
)

target = common.target_from_cli(args.target)
mod, params = frontends.load_model(args.FILE, args.model_format)
mod, params = frontends.load_model(args.FILE, args.model_format, shape_dict=args.input_shapes)

# min_repeat_ms should be:
# a. the value provided by the user, if any, or
Expand Down
39 changes: 39 additions & 0 deletions python/tvm/driver/tvmc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
"""
Common utility functions shared by TVMC modules.
"""
import re
import logging
import os.path
import argparse

from urllib.parse import urlparse

Expand Down Expand Up @@ -136,3 +138,40 @@ def tracker_host_port_from_cli(rpc_tracker_str):
logger.info("RPC tracker port: %s", rpc_port)

return rpc_hostname, rpc_port


def parse_shape_string(inputs_string):
"""Parse an input shape dictionary string to a usable dictionary.
Parameters
----------
inputs_string: str
A string of the form "input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]" that
indicates the desired shape for specific model inputs.
Returns
-------
shape_dict: dict
A dictionary mapping input names to their shape for use in relay frontend converters.
"""

# Create a regex pattern that extracts each separate input mapping.
pattern = r"\w+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]"
input_mappings = re.findall(pattern, inputs_string)
if not input_mappings:
raise argparse.ArgumentTypeError(
"--input-shapes argument must be of the form "
'"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"'
)
shape_dict = {}
for mapping in input_mappings:
# Remove whitespace.
mapping = mapping.replace(" ", "")
# Split mapping into name and shape.
name, shape_string = mapping.split(":")
# Convert shape string into a list of integers or Anys if negative.
shape = [int(x) if int(x) > 0 else relay.Any() for x in shape_string.strip("][").split(",")]
# Add parsed mapping to shape dictionary.
shape_dict[name] = shape

return shape_dict
16 changes: 14 additions & 2 deletions python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ def add_compile_parser(subparsers):
# can be improved in future to add integration with a modelzoo
# or URL, for example.
parser.add_argument("FILE", help="path to the input model file")
parser.add_argument(
"--input-shapes",
help="specify non-generic shapes for model to run, format is "
'"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"',
type=common.parse_shape_string,
default=None,
)


def drive_compile(args):
Expand All @@ -98,7 +105,7 @@ def drive_compile(args):
Arguments from command line parser.
Returns
--------
-------
int
Zero if successfully completed
Expand All @@ -112,6 +119,7 @@ def drive_compile(args):
args.model_format,
args.tuning_records,
args.desired_layout,
args.input_shapes,
)

if dumps:
Expand All @@ -129,6 +137,7 @@ def compile_model(
model_format=None,
tuning_records=None,
alter_layout=None,
shape_dict=None,
):
"""Compile a model from a supported framework into a TVM module.
Expand Down Expand Up @@ -158,6 +167,9 @@ def compile_model(
The layout to convert the graph to. Note, the convert layout
pass doesn't currently guarantee the whole of the graph will
be converted to the chosen layout.
shape_dict: dict, optional
A mapping from input names to their shape. When present,
the default shapes in the model will be overwritten.
Returns
-------
Expand All @@ -172,7 +184,7 @@ def compile_model(
"""
dump_code = [x.strip() for x in dump_code.split(",")] if dump_code else None
mod, params = frontends.load_model(path, model_format)
mod, params = frontends.load_model(path, model_format, shape_dict)

if alter_layout:
mod = common.convert_graph_layout(mod, alter_layout)
Expand Down
47 changes: 28 additions & 19 deletions python/tvm/driver/tvmc/frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,15 @@ def suffixes():
"""File suffixes (extensions) used by this frontend"""

@abstractmethod
def load(self, path):
def load(self, path, shape_dict=None):
"""Load a model from a given path.
Parameters
----------
path: str
Path to a file
shape_dict: dict, optional
Mapping from input names to their shapes.
Returns
-------
Expand Down Expand Up @@ -99,7 +101,7 @@ def name():
def suffixes():
return ["h5"]

def load(self, path):
def load(self, path, shape_dict=None):
# pylint: disable=C0103
tf, keras = import_keras()

Expand All @@ -125,8 +127,10 @@ def load(self, path):
)

inputs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes]
shape_dict = {name: x.shape for (name, x) in zip(model.input_names, inputs)}
return relay.frontend.from_keras(model, shape_dict, layout="NHWC")
input_shapes = {name: x.shape for (name, x) in zip(model.input_names, inputs)}
if shape_dict is not None:
input_shapes.update(shape_dict)
return relay.frontend.from_keras(model, input_shapes, layout="NHWC")

def is_sequential_p(self, model):
_, keras = import_keras()
Expand Down Expand Up @@ -154,14 +158,14 @@ def name():
def suffixes():
return ["onnx"]

def load(self, path):
def load(self, path, shape_dict=None):
# pylint: disable=C0415
import onnx

# pylint: disable=E1101
model = onnx.load(path)

return relay.frontend.from_onnx(model)
return relay.frontend.from_onnx(model, shape=shape_dict)


class TensorflowFrontend(Frontend):
Expand All @@ -175,7 +179,7 @@ def name():
def suffixes():
return ["pb"]

def load(self, path):
def load(self, path, shape_dict=None):
# pylint: disable=C0415
import tensorflow as tf
import tvm.relay.testing.tf as tf_testing
Expand All @@ -188,7 +192,7 @@ def load(self, path):
graph_def = tf_testing.ProcessGraphDefParam(graph_def)

logger.debug("parse TensorFlow model and convert into Relay computation graph")
return relay.frontend.from_tensorflow(graph_def)
return relay.frontend.from_tensorflow(graph_def, shape=shape_dict)


class TFLiteFrontend(Frontend):
Expand All @@ -215,7 +219,7 @@ def name():
def suffixes():
return ["tflite"]

def load(self, path):
def load(self, path, shape_dict=None):
# pylint: disable=C0415
import tflite.Model as model

Expand All @@ -238,11 +242,13 @@ def load(self, path):
raise TVMCException("input file not tflite version 3")

logger.debug("tflite_input_type")
shape_dict, dtype_dict = TFLiteFrontend._input_type(tflite_model)
input_shapes, dtype_dict = TFLiteFrontend._input_type(tflite_model)
if shape_dict is not None:
input_shapes.update(shape_dict)

logger.debug("parse TFLite model and convert into Relay computation graph")
mod, params = relay.frontend.from_tflite(
tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
tflite_model, shape_dict=input_shapes, dtype_dict=dtype_dict
)
return mod, params

Expand Down Expand Up @@ -285,17 +291,18 @@ def suffixes():
# Torch Script is a zip file, but can be named pth
return ["pth", "zip"]

def load(self, path):
def load(self, path, shape_dict=None):
# pylint: disable=C0415
import torch

traced_model = torch.jit.load(path)

inputs = list(traced_model.graph.inputs())[1:]
input_shapes = [inp.type().sizes() for inp in inputs]
if shape_dict is None:
raise TVMCException("--input-shapes must be specified for %s" % self.name())
jwfromm marked this conversation as resolved.
Show resolved Hide resolved

traced_model = torch.jit.load(path)
traced_model.eval() # Switch to inference mode
input_shapes = [("input{}".format(idx), shape) for idx, shape in enumerate(shapes)]

# Convert shape dictionary to list for Pytorch frontend compatibility
input_shapes = list(shape_dict.items())

logger.debug("parse Torch model and convert into Relay computation graph")
return relay.frontend.from_pytorch(traced_model, input_shapes)
Expand Down Expand Up @@ -378,7 +385,7 @@ def guess_frontend(path):
raise TVMCException("failed to infer the model format. Please specify --model-format")


def load_model(path, model_format=None):
def load_model(path, model_format=None, shape_dict=None):
"""Load a model from a supported framework and convert it
into an equivalent relay representation.
Expand All @@ -389,6 +396,8 @@ def load_model(path, model_format=None):
model_format : str, optional
The underlying framework used to create the model.
If not specified, this will be inferred from the file type.
shape_dict : dict, optional
Mapping from input names to their shapes.
Returns
-------
Expand All @@ -404,6 +413,6 @@ def load_model(path, model_format=None):
else:
frontend = guess_frontend(path)

mod, params = frontend.load(path)
mod, params = frontend.load(path, shape_dict)

return mod, params
17 changes: 17 additions & 0 deletions tests/python/driver/tvmc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,23 @@ def keras_resnet50(tmpdir_factory):
return model_file_name


@pytest.fixture(scope="session")
def pytorch_resnet18(tmpdir_factory):
try:
import torch
import torchvision.models as models
except ImportError:
# Not all environments provide Pytorch, so skip if that's the case.
return ""
model = models.resnet18()
model_file_name = "{}/{}".format(tmpdir_factory.mktemp("data"), "resnet18.pth")
# Trace model into torchscript.
traced_cpu = torch.jit.trace(model, torch.randn(1, 3, 224, 224))
torch.jit.save(traced_cpu, model_file_name)

return model_file_name


@pytest.fixture(scope="session")
def onnx_resnet50():
base_url = "https://github.com/onnx/models/raw/master/vision/classification/resnet/model"
Expand Down
33 changes: 33 additions & 0 deletions tests/python/driver/tvmc/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytest

import tvm
from tvm import relay
from tvm.driver import tvmc


Expand Down Expand Up @@ -149,3 +150,35 @@ def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090():

assert expected_host == actual_host
assert expected_port == actual_port


def test_shape_parser():
jwfromm marked this conversation as resolved.
Show resolved Hide resolved
# Check that a valid input is parsed correctly
shape_string = "input:[10,10,10]"
shape_dict = tvmc.common.parse_shape_string(shape_string)
assert shape_dict == {"input": [10, 10, 10]}
# Check that multiple valid input shapes are parse correctly
shape_string = "input:[10,10,10] input2:[20,20,20,20]"
shape_dict = tvmc.common.parse_shape_string(shape_string)
assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}
# Check that alternate syntax parses correctly
shape_string = "input: [10, 10, 10] input2: [20, 20, 20, 20]"
shape_dict = tvmc.common.parse_shape_string(shape_string)
assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}
shape_string = "input:[10,10,10],input2:[20,20,20,20]"
shape_dict = tvmc.common.parse_shape_string(shape_string)
assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}
# Check that negative dimensions parse to Any correctly.
shape_string = "input:[-1,3,224,224]"
shape_dict = tvmc.common.parse_shape_string(shape_string)
# Convert to strings to allow comparison with Any.
assert str(shape_dict) == "{'input': [?, 3, 224, 224]}"

# Check that invalid pattern raises expected error.
shape_string = "input:[a,10]"
with pytest.raises(argparse.ArgumentTypeError):
tvmc.common.parse_shape_string(shape_string)
# Check that input with invalid separators raises error.
shape_string = "input:5,10 input2:10,10"
with pytest.raises(argparse.ArgumentTypeError):
tvmc.common.parse_shape_string(shape_string)
Loading