Skip to content

Commit

Permalink
[TVMC] Allow manual shape specification in tvmc (apache#7366)
Browse files Browse the repository at this point in the history
* add ability to optionally override tvm shapes

* add help documentation for --shapes

* improve documentation

* reformat test_compiler using black

* Incorporate feedback from ekalda for better pytorch support and testing.

* address feedback

* switch input shape syntax to be more pythonic

* add commentary

* reformat common.py

* fix lint issue

* format common.py with black

* torch/pytorch test hiccup

* add -s to setup-pytest-env.sh for clearer error msgs

Co-authored-by: Jocelyn <jocelyn@pop-os.localdomain>
  • Loading branch information
2 people authored and Lokiiiiii committed Mar 1, 2021
1 parent d837a10 commit e19bb81
Show file tree
Hide file tree
Showing 9 changed files with 184 additions and 32 deletions.
9 changes: 8 additions & 1 deletion python/tvm/driver/tvmc/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,13 @@ def add_tune_parser(subparsers):
# can be improved in future to add integration with a modelzoo
# or URL, for example.
parser.add_argument("FILE", help="path to the input model file")
parser.add_argument(
"--input-shapes",
help="specify non-generic shapes for model to run, format is "
'"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"',
type=common.parse_shape_string,
default=None,
)


def drive_tune(args):
Expand All @@ -235,7 +242,7 @@ def drive_tune(args):
)

target = common.target_from_cli(args.target)
mod, params = frontends.load_model(args.FILE, args.model_format)
mod, params = frontends.load_model(args.FILE, args.model_format, shape_dict=args.input_shapes)

# min_repeat_ms should be:
# a. the value provided by the user, if any, or
Expand Down
39 changes: 39 additions & 0 deletions python/tvm/driver/tvmc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
"""
Common utility functions shared by TVMC modules.
"""
import re
import logging
import os.path
import argparse

from urllib.parse import urlparse

Expand Down Expand Up @@ -136,3 +138,40 @@ def tracker_host_port_from_cli(rpc_tracker_str):
logger.info("RPC tracker port: %s", rpc_port)

return rpc_hostname, rpc_port


def parse_shape_string(inputs_string):
    """Parse an input shape dictionary string to a usable dictionary.

    Parameters
    ----------
    inputs_string: str
        A string of the form "input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]" that
        indicates the desired shape for specific model inputs.

    Returns
    -------
    shape_dict: dict
        A dictionary mapping input names to their shape for use in relay frontend converters.
        Strictly positive dimensions are kept as ints; every other dimension (zero or
        negative) is replaced by ``relay.Any()``.

    Raises
    ------
    argparse.ArgumentTypeError
        If no "input_name:[dims]" mapping can be found in ``inputs_string``.
    """
    # Capture the input name and its comma-separated dimensions directly. Relying on the
    # capture groups (rather than stripping spaces and brackets afterwards) keeps the
    # extraction in sync with what the pattern accepts, e.g. a tab between ":" and "[".
    pattern = r"(\w+)\:\s*\[(\-?\d+(?:\,\s*\-?\d+)*)\]"
    input_mappings = re.findall(pattern, inputs_string)
    if not input_mappings:
        raise argparse.ArgumentTypeError(
            "--input-shapes argument must be of the form "
            '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"'
        )

    shape_dict = {}
    for name, dims_string in input_mappings:
        # int() ignores surrounding whitespace, so "10, 10" needs no extra cleanup.
        dims = [int(dim) for dim in dims_string.split(",")]
        # Non-positive dimensions become relay.Any().
        shape_dict[name] = [dim if dim > 0 else relay.Any() for dim in dims]

    return shape_dict
16 changes: 14 additions & 2 deletions python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ def add_compile_parser(subparsers):
# can be improved in future to add integration with a modelzoo
# or URL, for example.
parser.add_argument("FILE", help="path to the input model file")
parser.add_argument(
"--input-shapes",
help="specify non-generic shapes for model to run, format is "
'"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"',
type=common.parse_shape_string,
default=None,
)


def drive_compile(args):
Expand All @@ -98,7 +105,7 @@ def drive_compile(args):
Arguments from command line parser.
Returns
--------
-------
int
Zero if successfully completed
Expand All @@ -112,6 +119,7 @@ def drive_compile(args):
args.model_format,
args.tuning_records,
args.desired_layout,
args.input_shapes,
)

if dumps:
Expand All @@ -129,6 +137,7 @@ def compile_model(
model_format=None,
tuning_records=None,
alter_layout=None,
shape_dict=None,
):
"""Compile a model from a supported framework into a TVM module.
Expand Down Expand Up @@ -158,6 +167,9 @@ def compile_model(
The layout to convert the graph to. Note, the convert layout
pass doesn't currently guarantee the whole of the graph will
be converted to the chosen layout.
shape_dict: dict, optional
A mapping from input names to their shape. When present,
the default shapes in the model will be overwritten.
Returns
-------
Expand All @@ -172,7 +184,7 @@ def compile_model(
"""
dump_code = [x.strip() for x in dump_code.split(",")] if dump_code else None
mod, params = frontends.load_model(path, model_format)
mod, params = frontends.load_model(path, model_format, shape_dict)

if alter_layout:
mod = common.convert_graph_layout(mod, alter_layout)
Expand Down
47 changes: 28 additions & 19 deletions python/tvm/driver/tvmc/frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,15 @@ def suffixes():
"""File suffixes (extensions) used by this frontend"""

@abstractmethod
def load(self, path):
def load(self, path, shape_dict=None):
"""Load a model from a given path.
Parameters
----------
path: str
Path to a file
shape_dict: dict, optional
Mapping from input names to their shapes.
Returns
-------
Expand Down Expand Up @@ -99,7 +101,7 @@ def name():
def suffixes():
return ["h5"]

def load(self, path):
def load(self, path, shape_dict=None):
# pylint: disable=C0103
tf, keras = import_keras()

Expand All @@ -125,8 +127,10 @@ def load(self, path):
)

inputs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes]
shape_dict = {name: x.shape for (name, x) in zip(model.input_names, inputs)}
return relay.frontend.from_keras(model, shape_dict, layout="NHWC")
input_shapes = {name: x.shape for (name, x) in zip(model.input_names, inputs)}
if shape_dict is not None:
input_shapes.update(shape_dict)
return relay.frontend.from_keras(model, input_shapes, layout="NHWC")

def is_sequential_p(self, model):
_, keras = import_keras()
Expand Down Expand Up @@ -154,14 +158,14 @@ def name():
def suffixes():
return ["onnx"]

def load(self, path):
def load(self, path, shape_dict=None):
# pylint: disable=C0415
import onnx

# pylint: disable=E1101
model = onnx.load(path)

return relay.frontend.from_onnx(model)
return relay.frontend.from_onnx(model, shape=shape_dict)


class TensorflowFrontend(Frontend):
Expand All @@ -175,7 +179,7 @@ def name():
def suffixes():
return ["pb"]

def load(self, path):
def load(self, path, shape_dict=None):
# pylint: disable=C0415
import tensorflow as tf
import tvm.relay.testing.tf as tf_testing
Expand All @@ -188,7 +192,7 @@ def load(self, path):
graph_def = tf_testing.ProcessGraphDefParam(graph_def)

logger.debug("parse TensorFlow model and convert into Relay computation graph")
return relay.frontend.from_tensorflow(graph_def)
return relay.frontend.from_tensorflow(graph_def, shape=shape_dict)


class TFLiteFrontend(Frontend):
Expand All @@ -215,7 +219,7 @@ def name():
def suffixes():
return ["tflite"]

def load(self, path):
def load(self, path, shape_dict=None):
# pylint: disable=C0415
import tflite.Model as model

Expand All @@ -238,11 +242,13 @@ def load(self, path):
raise TVMCException("input file not tflite version 3")

logger.debug("tflite_input_type")
shape_dict, dtype_dict = TFLiteFrontend._input_type(tflite_model)
input_shapes, dtype_dict = TFLiteFrontend._input_type(tflite_model)
if shape_dict is not None:
input_shapes.update(shape_dict)

logger.debug("parse TFLite model and convert into Relay computation graph")
mod, params = relay.frontend.from_tflite(
tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
tflite_model, shape_dict=input_shapes, dtype_dict=dtype_dict
)
return mod, params

Expand Down Expand Up @@ -285,17 +291,18 @@ def suffixes():
# Torch Script is a zip file, but can be named pth
return ["pth", "zip"]

def load(self, path):
def load(self, path, shape_dict=None):
# pylint: disable=C0415
import torch

traced_model = torch.jit.load(path)

inputs = list(traced_model.graph.inputs())[1:]
input_shapes = [inp.type().sizes() for inp in inputs]
if shape_dict is None:
raise TVMCException("--input-shapes must be specified for %s" % self.name())

traced_model = torch.jit.load(path)
traced_model.eval() # Switch to inference mode
input_shapes = [("input{}".format(idx), shape) for idx, shape in enumerate(shapes)]

# Convert shape dictionary to list for Pytorch frontend compatibility
input_shapes = list(shape_dict.items())

logger.debug("parse Torch model and convert into Relay computation graph")
return relay.frontend.from_pytorch(traced_model, input_shapes)
Expand Down Expand Up @@ -378,7 +385,7 @@ def guess_frontend(path):
raise TVMCException("failed to infer the model format. Please specify --model-format")


def load_model(path, model_format=None):
def load_model(path, model_format=None, shape_dict=None):
"""Load a model from a supported framework and convert it
into an equivalent relay representation.
Expand All @@ -389,6 +396,8 @@ def load_model(path, model_format=None):
model_format : str, optional
The underlying framework used to create the model.
If not specified, this will be inferred from the file type.
shape_dict : dict, optional
Mapping from input names to their shapes.
Returns
-------
Expand All @@ -404,6 +413,6 @@ def load_model(path, model_format=None):
else:
frontend = guess_frontend(path)

mod, params = frontend.load(path)
mod, params = frontend.load(path, shape_dict)

return mod, params
17 changes: 17 additions & 0 deletions tests/python/driver/tvmc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,23 @@ def keras_resnet50(tmpdir_factory):
return model_file_name


@pytest.fixture(scope="session")
def pytorch_resnet18(tmpdir_factory):
    """Session fixture: path to a traced TorchScript ResNet-18, or "" without PyTorch."""
    try:
        import torch
        import torchvision.models as models
    except ImportError:
        # PyTorch is not installed everywhere; hand back an empty path in that case.
        return ""

    model = models.resnet18()
    output_dir = tmpdir_factory.mktemp("data")
    model_file_name = "{}/{}".format(output_dir, "resnet18.pth")

    # Trace the model into TorchScript using a random example input, then save it.
    example_input = torch.randn(1, 3, 224, 224)
    torch.jit.save(torch.jit.trace(model, example_input), model_file_name)

    return model_file_name


@pytest.fixture(scope="session")
def onnx_resnet50():
base_url = "https://github.com/onnx/models/raw/master/vision/classification/resnet/model"
Expand Down
33 changes: 33 additions & 0 deletions tests/python/driver/tvmc/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytest

import tvm
from tvm import relay
from tvm.driver import tvmc


Expand Down Expand Up @@ -149,3 +150,35 @@ def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090():

assert expected_host == actual_host
assert expected_port == actual_port


def test_shape_parser():
    """Check parse_shape_string against valid, dynamic and malformed shape strings."""
    single_input = {"input": [10, 10, 10]}
    two_inputs = {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}

    # Each entry pairs a shape string with the dictionary it must parse to.
    valid_cases = [
        # A single valid input.
        ("input:[10,10,10]", single_input),
        # Multiple valid inputs separated by a space.
        ("input:[10,10,10] input2:[20,20,20,20]", two_inputs),
        # Alternate syntax: whitespace after ":" and "," inside the mapping.
        ("input: [10, 10, 10] input2: [20, 20, 20, 20]", two_inputs),
        # Alternate syntax: mappings separated by a comma.
        ("input:[10,10,10],input2:[20,20,20,20]", two_inputs),
    ]
    for shape_string, expected in valid_cases:
        assert tvmc.common.parse_shape_string(shape_string) == expected

    # Negative dimensions must parse to Any; compare via str() since Any objects
    # cannot be compared directly against a literal.
    dynamic_shape = tvmc.common.parse_shape_string("input:[-1,3,224,224]")
    assert str(dynamic_shape) == "{'input': [?, 3, 224, 224]}"

    # A non-numeric dimension and a mapping without brackets must both be rejected.
    invalid_cases = ["input:[a,10]", "input:5,10 input2:10,10"]
    for shape_string in invalid_cases:
        with pytest.raises(argparse.ArgumentTypeError):
            tvmc.common.parse_shape_string(shape_string)
Loading

0 comments on commit e19bb81

Please sign in to comment.