Skip to content

Commit

Permalink
[Relay][Frontend] Adding ADD operator to tflite frontend for compiling the MobileNetV2 (apache#2919)
Browse files Browse the repository at this point in the history
  • Loading branch information
gomida authored and wweic committed Apr 10, 2019
1 parent d060856 commit 4966566
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 8 deletions.
3 changes: 3 additions & 0 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@ def set_expr(self, name, expr):
if name not in self.exprs:
self.exprs[name] = expr

def has_expr(self, name):
    """Return True if an expression has been registered under *name*.

    Membership in the dict is already a boolean, so return it directly
    instead of the redundant ``True if ... else False`` form.
    """
    return name in self.exprs

def set_padding(self, paddings):
    # Record explicit padding values to be consumed by the next converted
    # op, and flag that a padding is currently pending.
    self.paddings = paddings
    self.in_padding = True
Expand Down
49 changes: 48 additions & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def __init__(self, model, subgraph, exp_tab):
'SOFTMAX': self.convert_softmax,
'SQUEEZE': self.convert_squeeze,
'MAX_POOL_2D': self.convert_max_pool2d,
"CONCATENATION": self.convert_concatenation
'CONCATENATION': self.convert_concatenation,
'ADD': self.convert_add
}

def check_unsupported_ops(self):
Expand Down Expand Up @@ -292,6 +293,49 @@ def convert_concatenation(self, op):
out = self.convert_fused_activation_function(out, fused_activation_fn)
return out

def convert_add(self, op):
    """Convert a TFLite ADD operator into a Relay add expression.

    Parameters
    ----------
    op : tflite.Operator.Operator
        The TFLite ADD operator; must have exactly two input tensors.

    Returns
    -------
    out : relay expression
        The expression computing ``lhs + rhs``.

    Raises
    ------
    ImportError
        If the ``tflite`` package is not installed.
    tvm.error.OpAttributeInvalid
        If an un-fused constant operand has an unsupported rank.
    """
    try:
        from tflite.Operator import Operator
    except ImportError:
        raise ImportError("The tflite package must be installed")

    assert isinstance(op, Operator)
    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 2, "input tensors length should be 2"

    lhs_tensor = input_tensors[0]
    lhs_expr = self.get_expr(lhs_tensor.tensor_idx)

    rhs_tensor = input_tensors[1]
    if self.has_expr(rhs_tensor.tensor_idx):
        # In most cases, we can assume that TOCO fuses ADD operators
        # with constants - it means both will be tensors.
        rhs_expr = self.get_expr(rhs_tensor.tensor_idx)
    else:
        # However, in some corner cases, the ADD operator is not fused,
        # we can receive as constant.
        rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type())
        rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
                                          dtype=rhs_type_str)

        # In this case, we have to be careful about formatting: the
        # constant is stored in TFLite's NHWC-style layout, so permute
        # it to the NCHW-style layout used on this side (the axes used
        # mirror the N H W C -> N C H W comments below).
        input_shape_length = len(rhs_tensor.tensor.ShapeAsNumpy())
        if input_shape_length in (1, 2):
            # Rank 1/2 constants have no spatial layout to fix up.
            pass
        elif input_shape_length == 3:
            # N H*W C to N C H*W
            rhs_expr = _op.transpose(rhs_expr, axes=(0, 2, 1))
        elif input_shape_length == 4:
            # N H W C to N C H W
            rhs_expr = _op.transpose(rhs_expr, axes=(0, 3, 1, 2))
        else:
            msg = 'Input shape length {} for operator ADD is not valid.'
            raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))

    out = _op.add(lhs_expr, rhs_expr)
    return out

def convert_squeeze(self, op):
"""Convert TFLite squeeze"""
try:
Expand Down Expand Up @@ -554,6 +598,9 @@ def convert_pool2d(self, op, pool_type):
def get_expr(self, input_tensor_idx):
    """Look up the Relay expression registered for a subgraph tensor index."""
    tensor_name = get_tensor_name(self.subgraph, input_tensor_idx)
    return self.exp_tab.get_expr(tensor_name)

def has_expr(self, input_tensor_idx):
    """Check whether an expression exists for a subgraph tensor index."""
    tensor_name = get_tensor_name(self.subgraph, input_tensor_idx)
    return self.exp_tab.has_expr(tensor_name)

def build_str_map(obj):
"""Build string map of TFLite enum int value
Expand Down
92 changes: 85 additions & 7 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from tvm.contrib import util
import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
Expand Down Expand Up @@ -99,7 +101,7 @@ def run_tflite_graph(tflite_model_buf, input_data):


def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors,
output_tensors, output_need_transpose_nchw=False,
output_tensors, output_need_transpose=False,
init_global_variables=False):
"""Generic function to generate and compare TFLite and TVM output"""
tflite_in_data = convert_to_list(tflite_in_data)
Expand All @@ -126,9 +128,19 @@ def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors,

tvm_output = run_tvm_graph(tflite_model_buffer, tvm_in_data, in_node, target=device)
for i in range(len(tflite_output)):
if output_need_transpose_nchw:
if output_need_transpose:
dim = len(tvm_output[i].shape)
if dim == 3:
# N C H*W to N H*W C
axes = (0, 2, 1)
elif dim == 4:
# N C H W to N H W C
axes = (0, 2, 3, 1)
else:
raise NotImplementedError("Not support input shape {} of transpose : ".
format(str(dim)))
tvm.testing.assert_allclose(tflite_output[i],
np.transpose(tvm_output[i], axes=(0, 2, 3, 1)),
np.transpose(tvm_output[i], axes=axes),
atol=1e-5, rtol=1e-5)
else:
tvm.testing.assert_allclose(tflite_output[i], tvm_output[i],
Expand All @@ -152,7 +164,7 @@ def _test_pooling_iteration(input_shape, **kwargs):
out = nn_ops.pool(in_data, **kwargs)

compare_tflite_with_tvm(x, tvm_data, 'Placeholder:0', [in_data], [out],
output_need_transpose_nchw=True)
output_need_transpose=True)


def _test_pooling(input_shape, **kwargs):
Expand Down Expand Up @@ -236,7 +248,7 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
# TFLite output is NHWC, TVM is NCHW, we need transpose
compare_tflite_with_tvm(tflite_data_array, tvm_data_array,
'Placeholder:0', [in_data], [out],
output_need_transpose_nchw=True)
output_need_transpose=True)


def test_forward_convolution():
Expand Down Expand Up @@ -330,6 +342,53 @@ def test_forward_concatenation():
np.arange(6).reshape((2, 1, 1, 3))], 1)


#######################################################################
# Add
# ---

def _test_add(data):
    """One iteration of add.

    Parameters
    ----------
    data : list of numpy.ndarray
        Exactly two same-shaped arrays; rank must be 1-4. TVM-side copies
        are transposed to NCHW-style layout for rank 3/4 inputs.
    """

    assert len(data) == 2
    need_transpose = False
    if len(data[0].shape) == 1 or len(data[0].shape) == 2:
        tvm_data = data
    elif len(data[0].shape) == 3:
        need_transpose = True
        tvm_data = [np.transpose(d, axes=(0, 2, 1)) for d in data]
    elif len(data[0].shape) == 4:
        need_transpose = True
        tvm_data = [np.transpose(d, axes=(0, 3, 1, 2)) for d in data]
    else:
        # BUG FIX: `data` is a list, so the old `len(data.shape)` raised
        # AttributeError instead of the intended NotImplementedError.
        # Report the rank of the first input array instead.
        raise NotImplementedError("Not support input shape {} of add : ".
                                  format(str(len(data[0].shape))))

    # Test with two tensors
    with tf.Graph().as_default():
        in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'),
                   array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')]
        out = math_ops.add(in_data[0], in_data[1])
        compare_tflite_with_tvm(data, tvm_data, ['in_0:0', 'in_1:0'],
                                in_data, [out], need_transpose)

    # Test with tensor and constant
    with tf.Graph().as_default():
        in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
        out = math_ops.add(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
        compare_tflite_with_tvm([data[0]], [tvm_data[0]], ['in:0'],
                                in_data, [out], need_transpose)


def test_forward_add():
    """ Add """
    # Cover 4-D, 3-D and 2-D inputs; _test_add handles layout per rank.
    for shape in ((2, 1, 1, 3), (2, 1, 3), (1, 3)):
        count = int(np.prod(shape))
        lhs = np.arange(count, dtype=np.float32).reshape(shape)
        rhs = np.arange(count, dtype=np.float32).reshape(shape)
        _test_add([lhs, rhs])


#######################################################################
# Squeeze
# -------
Expand Down Expand Up @@ -388,7 +447,7 @@ def test_forward_softmax():
# Mobilenet
# ---------

def test_forward_mobilenet():
def test_forward_mobilenet_v1():
'''test mobilenet v1 tflite model'''
# MobilenetV1
tflite_model_file = tf_testing.get_workload_official(
Expand All @@ -403,6 +462,21 @@ def test_forward_mobilenet():
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)

def test_forward_mobilenet_v2():
    '''test mobilenet v2 tflite model'''
    # MobilenetV2
    model_path = tf_testing.get_workload_official(
        "http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz",
        "mobilenet_v2_1.0_224.tflite")
    with open(model_path, "rb") as model_file:
        model_buf = model_file.read()
    # Random NHWC input for TFLite; transpose to NCHW for the TVM graph.
    nhwc_data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
    nchw_data = np.transpose(nhwc_data, axes=(0, 3, 1, 2))
    tflite_output = run_tflite_graph(model_buf, nhwc_data)
    tvm_output = run_tvm_graph(model_buf, nchw_data, 'input')
    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                rtol=1e-5, atol=1e-5)

#######################################################################
# Inception V3
# ------------
Expand Down Expand Up @@ -436,6 +510,10 @@ def test_forward_inception_v3_net():
test_forward_pooling()
test_forward_softmax()

# Math
test_forward_add()

# End to End
test_forward_mobilenet()
test_forward_mobilenet_v1()
test_forward_mobilenet_v2()
test_forward_inception_v3_net()

0 comments on commit 4966566

Please sign in to comment.