From e13576b45d6b3d0bf80a3f59b89e523fb0d87efb Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Wed, 29 Jan 2020 17:02:10 +0530 Subject: [PATCH] [FRONTEND][TFLITE]Gather, StridedSlice op added --- python/tvm/relay/frontend/tflite.py | 154 +++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 53 +++++++ 2 files changed, 207 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 5902b92c3f567..10d80118a7ca0 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -112,6 +112,8 @@ def __init__(self, model, subgraph, exp_tab): 'PRELU': self.convert_prelu, 'TRANSPOSE_CONV': self.convert_transpose_conv, 'SQUARED_DIFFERENCE': self.convert_squared_difference, + 'GATHER': self.convert_gather, + 'STRIDED_SLICE': self.convert_strided_slice, } def check_unsupported_ops(self): @@ -747,6 +749,158 @@ def convert_squared_difference(self, op): out = _op.power(difference, relay.const(2, exp_type)) return out + def convert_gather(self, op): + """Method to Convert TFLite Gather operator""" + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized gather operator is not supported yet.') + input_tensors = self.get_input_tensors(op) + + try: + from tflite.BuiltinOptions import BuiltinOptions + from tflite.GatherOptions import GatherOptions + from tflite.TensorType import TensorType + except ImportError: + raise ImportError("The tflite package must be installed") + + assert op.BuiltinOptionsType() == BuiltinOptions.GatherOptions + op_options = op.BuiltinOptions() + gather_options = GatherOptions() + gather_options.Init(op_options.Bytes, op_options.Pos) + axis = gather_options.Axis() + + data = self.get_expr(input_tensors[0].tensor_idx) + + indices = input_tensors[1] + indices_type = indices.tensor.Type() + + assert indices_type in (TensorType.INT32, TensorType.INT64) + indices_type_str = self.get_tensor_type_str(indices_type) + indices = self.exp_tab.new_const(self.get_tensor_value(indices), + dtype=indices_type_str) + out = _op.take(data, indices, axis=axis) + return out + + def convert_strided_slice(self, op): + """Method to Convert TFLite Strided Slice operator""" + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized strided slice operator is not supported yet.') + input_tensors = self.get_input_tensors(op) + + try: + from tflite.BuiltinOptions import BuiltinOptions + from tflite.StridedSliceOptions import StridedSliceOptions + except ImportError: + raise ImportError("The tflite package must be installed") + + data_expr = self.get_expr(input_tensors[0].tensor_idx) + + begin = list(self.get_tensor_value(input_tensors[1])) + end = list(self.get_tensor_value(input_tensors[2])) + stride = list(self.get_tensor_value(input_tensors[3])) + + assert op.BuiltinOptionsType() == BuiltinOptions.StridedSliceOptions + op_options = op.BuiltinOptions() + options = StridedSliceOptions() + options.Init(op_options.Bytes, op_options.Pos) + begin_mask = options.BeginMask() + end_mask = options.EndMask() + ellipsis_mask = options.EllipsisMask() + new_axis_mask = options.NewAxisMask() + shrink_axis_mask = options.ShrinkAxisMask() + + data_shape = list(input_tensors[0].tensor.ShapeAsNumpy()) + + data_dim = len(data_shape) + stride_dim = len(list(input_tensors[3].tensor.ShapeAsNumpy())) + + def _transform_mask(stride_dim, ellipsis_mask): + """Handle mask inputs to create new begin, end, stride and output shape""" + m_begin = [0] * data_dim + m_end = [0] * data_dim + m_stride = [0] * data_dim + fshape_indices = [] + #Count new axis after ellipsis_mask, consider while applying ellipsis_mask. + ellipsis_seen = False + new_axes_after_ellipsis = 0 + for i in range(stride_dim): + mask = 1 << i + if ellipsis_seen and (mask & new_axis_mask) != 0: + new_axes_after_ellipsis += 1 + if (mask & ellipsis_mask) != 0: + ellipsis_seen = True + if not ellipsis_seen: + #Used later for extending the stride attributes in the below loop. + ellipsis_mask |= (1 << stride_dim) + stride_dim += 1 + final_index = 0 + for index in range(stride_dim): + mask = 1 << index + if mask & ellipsis_mask: + #Identify the end index for applying ellipsis_mask + to_index = min(((data_dim - (stride_dim-index)) + 1 \ + + new_axes_after_ellipsis), data_dim) + for i in range(final_index, to_index): + m_begin[final_index] = 0 + m_end[final_index] = data_shape[final_index] + m_stride[final_index] = 1 + fshape_indices.append(final_index) + final_index += 1 + elif mask &new_axis_mask: + fshape_indices.append(-1) + elif not mask & new_axis_mask: + if final_index == len(m_begin): + break + if mask & begin_mask: + m_begin[final_index] = data_shape[final_index] \ + if stride[index] < 0 else 0 + elif begin[index]: + m_begin[final_index] = begin[index] + if mask & end_mask: + m_end[final_index] = 0 if stride[index] < 0 \ + else data_shape[final_index] + elif end[index]: + m_end[final_index] = end[index] + m_stride[final_index] = stride[index] + if mask & shrink_axis_mask: + #Tensorflow make axis with shrink_axis_mask as dimension 1 + m_begin[final_index] = data_shape[final_index] + begin[index] \ + if begin[index] < 0 else begin[index] + m_end[final_index] = begin[index] + 1 + m_stride[final_index] = 1 + fshape_indices.append(-2) + else: + fshape_indices.append(final_index) + + final_index += 1 + return m_begin, m_end, m_stride, fshape_indices + + fshape_indices = None + if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: + begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) + + out = _op.strided_slice(data_expr, begin=begin, end=end, strides=stride) + out_shape = _infer_shape(out) + if not fshape_indices: + fshape_indices = range(len(out_shape)) + + #Create final output shape. + final_output = [] + for gather_index in fshape_indices: + if gather_index == -1: + final_output.append(1) + elif gather_index == -2: + pass + else: + final_output.append(out_shape[gather_index]) + + if not final_output: + return out + return _op.reshape(out, newshape=tuple(final_output)) + def convert_zeros_like(self, op): """Convert TFLite ZEROS LIKE""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index b7550f40af1e7..bc0462770ed2e 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -244,6 +244,57 @@ def test_forward_slice(): _test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 1], size=[-1, -1]) _test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], size=[-1]) +####################################################################### +# Gather +# ------ + +def _test_gather(dshape, indices, axis, dtype): + """ One iteration of Gather """ + data = np.random.uniform(1, 10, size=dshape).astype(dtype) + indices = np.asarray(indices).astype('int32') + + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + out = array_ops.gather(in_data, indices, axis=axis) + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) + +def test_forward_gather(): + """ GATHER """ + _test_gather((4,), [1], None, 'float32') + _test_gather((1, 4), [0], 0, 'int32') + _test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32') + _test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'int32') + _test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32') + _test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'float32') + _test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32') + _test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32') + _test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32') + +####################################################################### +# StridedSlice +# ------------ + +def _test_stridedslice(ip_shape, begin, end, stride, dtype, + begin_mask=0, end_mask=0, new_axis_mask=0, + shrink_axis_mask=0, ellipsis_mask=0): + """ One iteration of a Stridedslice """ + data = np.random.uniform(size=ip_shape).astype(dtype) + + with tf.Graph().as_default(): + in_data = tf.placeholder(dtype, ip_shape, name="in_data") + out = array_ops.strided_slice(in_data, begin, end, stride, + begin_mask=begin_mask, + end_mask=end_mask, new_axis_mask=new_axis_mask, + shrink_axis_mask=shrink_axis_mask, + ellipsis_mask=ellipsis_mask) + compare_tflite_with_tvm(data, 'in_data:0', [in_data], [out]) + +def test_forward_stridedslice(): + '''test StridedSlice''' + _test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1) + _test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32') + _test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=1) + ####################################################################### # transpose # --------- @@ -1456,6 +1507,8 @@ def test_forward_mediapipe_hand_landmark(): test_all_resize() test_forward_squeeze() test_forward_slice() + test_forward_gather() + test_forward_stridedslice() # NN test_forward_convolution()