diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index d07f2af3e08b..abfb60e44ea7 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -258,6 +258,9 @@ def set_expr(self, name, expr): if name not in self.exprs: self.exprs[name] = expr + def has_expr(self, name): + return True if name in self.exprs else False + def set_padding(self, paddings): self.paddings = paddings self.in_padding = True diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 0e31500fe67d..9d616c9e4025 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -46,7 +46,8 @@ def __init__(self, model, subgraph, exp_tab): 'SOFTMAX': self.convert_softmax, 'SQUEEZE': self.convert_squeeze, 'MAX_POOL_2D': self.convert_max_pool2d, - "CONCATENATION": self.convert_concatenation + 'CONCATENATION': self.convert_concatenation, + 'ADD': self.convert_add } def check_unsupported_ops(self): @@ -292,6 +293,49 @@ def convert_concatenation(self, op): out = self.convert_fused_activation_function(out, fused_activation_fn) return out + def convert_add(self, op): + """Convert TFLite add""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + lhs_tensor = input_tensors[0] + lhs_expr = self.get_expr(lhs_tensor.tensor_idx) + + rhs_tensor = input_tensors[1] + if self.has_expr(rhs_tensor.tensor_idx): + # In most cases, we can assume that TOCO fuses ADD operators + # with constants - it means both will be tensors. + rhs_expr = self.get_expr(rhs_tensor.tensor_idx) + else: + # However, in some corner cases, the ADD operator is not fused, + # we can receive as constant. + rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type()) + rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor), + dtype=rhs_type_str) + + # In this case, we have to be careful about formatting. + input_shape_length = len(rhs_tensor.tensor.ShapeAsNumpy()) + if input_shape_length in (1, 2): + pass + elif input_shape_length == 3: + # N H*W C to N C H*W + rhs_expr = _op.transpose(rhs_expr, axes=(0, 2, 1)) + elif input_shape_length == 4: + # N H W C to N C H W + rhs_expr = _op.transpose(rhs_expr, axes=(0, 3, 1, 2)) + else: + msg = 'Input shape length {} for operator ADD is not valid.' + raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length)) + + out = _op.add(lhs_expr, rhs_expr) + return out + def convert_squeeze(self, op): """Convert TFLite squeeze""" try: @@ -554,6 +598,9 @@ def convert_pool2d(self, op, pool_type): def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) + def has_expr(self, input_tensor_idx): + return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx)) + def build_str_map(obj): """Build string map of TFLite enum int value diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 3c41792ae903..9abbfad8e429 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -11,6 +11,8 @@ from tvm.contrib import util import tensorflow as tf from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables @@ -99,7 +101,7 @@ def run_tflite_graph(tflite_model_buf, input_data): def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors, - output_tensors, output_need_transpose_nchw=False, + output_tensors, output_need_transpose=False, init_global_variables=False): """Generic function to generate and compare TFLite and TVM output""" tflite_in_data = convert_to_list(tflite_in_data) @@ -126,9 +128,19 @@ def compare_tflite_with_tvm(tflite_in_data, tvm_in_data, in_name, input_tensors, tvm_output = run_tvm_graph(tflite_model_buffer, tvm_in_data, in_node, target=device) for i in range(len(tflite_output)): - if output_need_transpose_nchw: + if output_need_transpose: + dim = len(tvm_output[i].shape) + if dim == 3: + # N C H*W to N H*W C + axes = (0, 2, 1) + elif dim == 4: + # N C H W to N H W C + axes = (0, 2, 3, 1) + else: + raise NotImplementedError("Not support input shape {} of transpose : ". + format(str(dim))) tvm.testing.assert_allclose(tflite_output[i], - np.transpose(tvm_output[i], axes=(0, 2, 3, 1)), + np.transpose(tvm_output[i], axes=axes), atol=1e-5, rtol=1e-5) else: tvm.testing.assert_allclose(tflite_output[i], tvm_output[i], @@ -152,7 +164,7 @@ def _test_pooling_iteration(input_shape, **kwargs): out = nn_ops.pool(in_data, **kwargs) compare_tflite_with_tvm(x, tvm_data, 'Placeholder:0', [in_data], [out], - output_need_transpose_nchw=True) + output_need_transpose=True) def _test_pooling(input_shape, **kwargs): @@ -236,7 +248,7 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes, # TFLite output is NHWC, TVM is NCHW, we need transpose compare_tflite_with_tvm(tflite_data_array, tvm_data_array, 'Placeholder:0', [in_data], [out], - output_need_transpose_nchw=True) + output_need_transpose=True) def test_forward_convolution(): @@ -330,6 +342,53 @@ def test_forward_concatenation(): np.arange(6).reshape((2, 1, 1, 3))], 1) +####################################################################### +# Add +# --- + +def _test_add(data): + """ One iteration of add """ + + assert len(data) == 2 + need_transpose = False + if len(data[0].shape) == 1 or len(data[0].shape) == 2: + tvm_data = data + elif len(data[0].shape) == 3: + need_transpose = True + tvm_data = [np.transpose(d, axes=(0, 2, 1)) for d in data] + elif len(data[0].shape) == 4: + need_transpose = True + tvm_data = [np.transpose(d, axes=(0, 3, 1, 2)) for d in data] + else: + raise NotImplementedError("Not support input shape {} of add : ". + format(str(len(data.shape)))) + + # Test with two tensors + with tf.Graph().as_default(): + in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in_0'), + array_ops.placeholder(shape=data[1].shape, dtype=data[1].dtype, name='in_1')] + out = math_ops.add(in_data[0], in_data[1]) + compare_tflite_with_tvm(data, tvm_data, ['in_0:0','in_1:0'], + in_data, [out], need_transpose) + + # Test with tensor and constant + with tf.Graph().as_default(): + in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')] + out = math_ops.add(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype)) + compare_tflite_with_tvm([data[0]], [tvm_data[0]], ['in:0'], + in_data, [out], need_transpose) + + +def test_forward_add(): + """ Add """ + _test_add([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)), + np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3))]) + _test_add([np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)), + np.arange(6.0, dtype=np.float32).reshape((2, 1, 3))]) + _test_add([np.arange(3.0, dtype=np.float32).reshape((1, 3)), + np.arange(3.0, dtype=np.float32).reshape((1, 3))]) + + ####################################################################### # Squeeze # ------- @@ -388,7 +447,7 @@ def test_forward_softmax(): # Mobilenet # --------- -def test_forward_mobilenet(): +def test_forward_mobilenet_v1(): '''test mobilenet v1 tflite model''' # MobilenetV1 tflite_model_file = tf_testing.get_workload_official( @@ -403,6 +462,21 @@ def test_forward_mobilenet(): tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) +def test_forward_mobilenet_v2(): + '''test mobilenet v2 tflite model''' + # MobilenetV2 + tflite_model_file = tf_testing.get_workload_official( + "http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz", + "mobilenet_v2_1.0_224.tflite") + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32') + tvm_data = np.transpose(data, axes=(0, 3, 1, 2)) + tflite_output = run_tflite_graph(tflite_model_buf, data) + tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input') + tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), + rtol=1e-5, atol=1e-5) + ####################################################################### # Inception V3 # ------------ @@ -436,6 +510,10 @@ def test_forward_inception_v3_net(): test_forward_pooling() test_forward_softmax() + # Math + test_forward_add() + # End to End - test_forward_mobilenet() + test_forward_mobilenet_v1() + test_forward_mobilenet_v2() test_forward_inception_v3_net()