Skip to content

Commit

Permalink
[TVMC] Allow optional arguments to be passed to importers (#7674)
Browse files Browse the repository at this point in the history
* add support for optional args for frontends tvmc

* remove unnecessary comments

* Add changes suggested by Matt W. via PR

Co-authored-by: Jocelyn <jocelyn@pop-os.localdomain>
  • Loading branch information
CircleSpin and Jocelyn authored Mar 18, 2021
1 parent 4976bb2 commit 38aed59
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 20 deletions.
27 changes: 14 additions & 13 deletions python/tvm/driver/tvmc/frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def suffixes():
"""File suffixes (extensions) used by this frontend"""

@abstractmethod
def load(self, path, shape_dict=None):
def load(self, path, shape_dict=None, **kwargs):
"""Load a model from a given path.
Parameters
Expand Down Expand Up @@ -101,7 +101,7 @@ def name():
def suffixes():
return ["h5"]

def load(self, path, shape_dict=None):
def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0103
tf, keras = import_keras()

Expand Down Expand Up @@ -130,7 +130,8 @@ def load(self, path, shape_dict=None):
input_shapes = {name: x.shape for (name, x) in zip(model.input_names, inputs)}
if shape_dict is not None:
input_shapes.update(shape_dict)
return relay.frontend.from_keras(model, input_shapes, layout="NHWC")
kwargs.setdefault("layout", "NHWC")
return relay.frontend.from_keras(model, input_shapes, **kwargs)

def is_sequential_p(self, model):
_, keras = import_keras()
Expand Down Expand Up @@ -158,14 +159,14 @@ def name():
def suffixes():
return ["onnx"]

def load(self, path, shape_dict=None):
def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
import onnx

# pylint: disable=E1101
model = onnx.load(path)

return relay.frontend.from_onnx(model, shape=shape_dict)
return relay.frontend.from_onnx(model, shape=shape_dict, **kwargs)


class TensorflowFrontend(Frontend):
Expand All @@ -179,7 +180,7 @@ def name():
def suffixes():
return ["pb"]

def load(self, path, shape_dict=None):
def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
import tensorflow as tf
import tvm.relay.testing.tf as tf_testing
Expand All @@ -192,7 +193,7 @@ def load(self, path, shape_dict=None):
graph_def = tf_testing.ProcessGraphDefParam(graph_def)

logger.debug("parse TensorFlow model and convert into Relay computation graph")
return relay.frontend.from_tensorflow(graph_def, shape=shape_dict)
return relay.frontend.from_tensorflow(graph_def, shape=shape_dict, **kwargs)


class TFLiteFrontend(Frontend):
Expand All @@ -206,7 +207,7 @@ def name():
def suffixes():
return ["tflite"]

def load(self, path, shape_dict=None):
def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
import tflite.Model as model

Expand All @@ -229,7 +230,7 @@ def load(self, path, shape_dict=None):
raise TVMCException("input file not tflite version 3")

logger.debug("parse TFLite model and convert into Relay computation graph")
mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict)
mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, **kwargs)
return mod, params


Expand All @@ -245,7 +246,7 @@ def suffixes():
# Torch Script is a zip file, but can be named pth
return ["pth", "zip"]

def load(self, path, shape_dict=None):
def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
import torch

Expand All @@ -259,7 +260,7 @@ def load(self, path, shape_dict=None):
input_shapes = list(shape_dict.items())

logger.debug("parse Torch model and convert into Relay computation graph")
return relay.frontend.from_pytorch(traced_model, input_shapes)
return relay.frontend.from_pytorch(traced_model, input_shapes, **kwargs)


ALL_FRONTENDS = [
Expand Down Expand Up @@ -339,7 +340,7 @@ def guess_frontend(path):
raise TVMCException("failed to infer the model format. Please specify --model-format")


def load_model(path, model_format=None, shape_dict=None):
def load_model(path, model_format=None, shape_dict=None, **kwargs):
"""Load a model from a supported framework and convert it
into an equivalent relay representation.
Expand Down Expand Up @@ -367,6 +368,6 @@ def load_model(path, model_format=None, shape_dict=None):
else:
frontend = guess_frontend(path)

mod, params = frontend.load(path, shape_dict)
mod, params = frontend.load(path, shape_dict, **kwargs)

return mod, params
22 changes: 15 additions & 7 deletions tests/python/driver/tvmc/test_frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,26 +115,34 @@ def test_load_model__tflite(tflite_mobilenet_v1_1_quant):
assert "_param_1" in params.keys()


def test_load_model__keras(keras_resnet50):
@pytest.mark.parametrize("load_model_kwargs", [{}, {"layout": "NCHW"}])
def test_load_model__keras(keras_resnet50, load_model_kwargs):
# some CI environments wont offer TensorFlow/Keras, so skip in case it is not present
pytest.importorskip("tensorflow")

mod, params = tvmc.frontends.load_model(keras_resnet50)
mod, params = tvmc.frontends.load_model(keras_resnet50, **load_model_kwargs)
assert type(mod) is IRModule
assert type(params) is dict
## check whether one known value is part of the params dict
assert "_param_1" in params.keys()


def verify_load_model__onnx(model, **kwargs):
mod, params = tvmc.frontends.load_model(model, **kwargs)
assert type(mod) is IRModule
assert type(params) is dict
return mod, params


def test_load_model__onnx(onnx_resnet50):
# some CI environments wont offer onnx, so skip in case it is not present
pytest.importorskip("onnx")

mod, params = tvmc.frontends.load_model(onnx_resnet50)
assert type(mod) is IRModule
assert type(params) is dict
## check whether one known value is part of the params dict
mod, params = verify_load_model__onnx(onnx_resnet50)
# check whether one known value is part of the params dict
assert "resnetv24_batchnorm0_gamma" in params.keys()
mod, params = verify_load_model__onnx(onnx_resnet50, freeze_params=True)
# check that the parameter dict is empty, implying that they have been folded into constants
assert params == {}


def test_load_model__pb(pb_mobilenet_v1_1_quant):
Expand Down

0 comments on commit 38aed59

Please sign in to comment.