From 71b76d11a34b0fe918d85a65415f14536414892e Mon Sep 17 00:00:00 2001 From: Ina_Dobreva Date: Wed, 4 Dec 2019 17:59:33 +0000 Subject: [PATCH] [Relay][Frontend][TFlite] Add parses support for SLICE * TFlite 1.13: convertor gives nonsense output when size[i]==-1 * TF parser: SLICE need fixing for size[i]==-1 -> gives wrong output bcs of indices --- python/tvm/relay/frontend/tflite.py | 30 ++++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 21 ++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index e2dc0e77d980e..77026f21822c4 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -90,6 +90,7 @@ def __init__(self, model, subgraph, exp_tab): 'TANH':self.convert_tanh, 'RELU':self.convert_relu, 'SPLIT': self.convert_split, + 'SLICE': self.convert_slice, 'TRANSPOSE': self.convert_transpose, 'CAST': self.convert_cast, 'TILE': self.convert_tile, @@ -1033,6 +1034,35 @@ def convert_split(self, op): return out + def convert_slice(self, op): + """Convert TFLite SLICE""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 3, "input tensors length should be == 3" + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + begin = list(self.get_tensor_value(input_tensors[1])) + size = list(self.get_tensor_value(input_tensors[2])) + # strided_slice(Relay) needs the slice's end indices, not the size + end = size + input_tensor_shape = input_tensor.tensor.ShapeAsNumpy() + input_tensor_rank = len(input_tensor_shape) + for i in range(input_tensor_rank): + if size[i] == -1: + end[i] = input_tensor_shape[i] - begin[i] + 1 + else: + end[i] += begin[i] + + out = _op.strided_slice(in_expr, begin, end) + + return out + def convert_transpose(self, op): """transpose implementation.""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 555dc579b0d8a..6bcab3a11c635 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -219,6 +219,26 @@ def test_forward_split(): _test_split((1, 3, 6, 5), -2, 3, 'float32') _test_split((1, 3, 5, 6), -1, 3, 'float32') +####################################################################### +# slice +# ----- + +def _test_slice(data, begin, size): + """ One iteration of SLICE """ + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + out = array_ops.slice(in_data, begin, size) + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) + +def test_forward_slice(): + """ SLICE """ + _test_slice(np.arange(4, dtype=np.float32).reshape((4, )), begin=[0], size=[2]) + _test_slice(np.arange(18, dtype=np.int32).reshape((3, 2, 3)), begin=[1, 0, 0], size=[1, 1, 3]) + # tflite 1.13 outputs nonsense values if size[i] == -1 + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + _test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 1], size=[-1, -1]) + + ####################################################################### # transpose # --------- @@ -1209,6 +1229,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_reshape() test_all_resize() test_forward_squeeze() + test_forward_slice() # NN test_forward_convolution()