Skip to content

Commit

Permalink
[FRONTEND][TFLITE]Gather, StridedSlice op added
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed Jan 29, 2020
1 parent 1b8522e commit e13576b
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 0 deletions.
154 changes: 154 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def __init__(self, model, subgraph, exp_tab):
'PRELU': self.convert_prelu,
'TRANSPOSE_CONV': self.convert_transpose_conv,
'SQUARED_DIFFERENCE': self.convert_squared_difference,
'GATHER': self.convert_gather,
'STRIDED_SLICE': self.convert_strided_slice,
}

def check_unsupported_ops(self):
Expand Down Expand Up @@ -747,6 +749,158 @@ def convert_squared_difference(self, op):
out = _op.power(difference, relay.const(2, exp_type))
return out

def convert_gather(self, op):
"""Method to Convert TFLite Gather operator"""
# Check if the input tensor is quantized, call QNN op
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized gather operator is not supported yet.')
input_tensors = self.get_input_tensors(op)

try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.GatherOptions import GatherOptions
from tflite.TensorType import TensorType
except ImportError:
raise ImportError("The tflite package must be installed")

assert op.BuiltinOptionsType() == BuiltinOptions.GatherOptions
op_options = op.BuiltinOptions()
gather_options = GatherOptions()
gather_options.Init(op_options.Bytes, op_options.Pos)
axis = gather_options.Axis()

data = self.get_expr(input_tensors[0].tensor_idx)

indices = input_tensors[1]
indices_type = indices.tensor.Type()

assert indices_type in (TensorType.INT32, TensorType.INT64)
indices_type_str = self.get_tensor_type_str(indices_type)
indices = self.exp_tab.new_const(self.get_tensor_value(indices),
dtype=indices_type_str)
out = _op.take(data, indices, axis=axis)
return out

def convert_strided_slice(self, op):
"""Method to Convert TFLite Strided Slice operator"""
# Check if the input tensor is quantized, call QNN op
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized strided slice operator is not supported yet.')
input_tensors = self.get_input_tensors(op)

try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.StridedSliceOptions import StridedSliceOptions
except ImportError:
raise ImportError("The tflite package must be installed")

data_expr = self.get_expr(input_tensors[0].tensor_idx)

begin = list(self.get_tensor_value(input_tensors[1]))
end = list(self.get_tensor_value(input_tensors[2]))
stride = list(self.get_tensor_value(input_tensors[3]))

assert op.BuiltinOptionsType() == BuiltinOptions.StridedSliceOptions
op_options = op.BuiltinOptions()
options = StridedSliceOptions()
options.Init(op_options.Bytes, op_options.Pos)
begin_mask = options.BeginMask()
end_mask = options.EndMask()
ellipsis_mask = options.EllipsisMask()
new_axis_mask = options.NewAxisMask()
shrink_axis_mask = options.ShrinkAxisMask()

data_shape = list(input_tensors[0].tensor.ShapeAsNumpy())

data_dim = len(data_shape)
stride_dim = len(list(input_tensors[3].tensor.ShapeAsNumpy()))

def _transform_mask(stride_dim, ellipsis_mask):
"""Handle mask inputs to create new begin, end, stride and output shape"""
m_begin = [0] * data_dim
m_end = [0] * data_dim
m_stride = [0] * data_dim
fshape_indices = []
#Count new axis after ellipsis_mask, consider while applying ellipsis_mask.
ellipsis_seen = False
new_axes_after_ellipsis = 0
for i in range(stride_dim):
mask = 1 << i
if ellipsis_seen and (mask & new_axis_mask) != 0:
new_axes_after_ellipsis += 1
if (mask & ellipsis_mask) != 0:
ellipsis_seen = True
if not ellipsis_seen:
#Used later for extending the stride attributes in the below loop.
ellipsis_mask |= (1 << stride_dim)
stride_dim += 1
final_index = 0
for index in range(stride_dim):
mask = 1 << index
if mask & ellipsis_mask:
#Identify the end index for applying ellipsis_mask
to_index = min(((data_dim - (stride_dim-index)) + 1 \
+ new_axes_after_ellipsis), data_dim)
for i in range(final_index, to_index):
m_begin[final_index] = 0
m_end[final_index] = data_shape[final_index]
m_stride[final_index] = 1
fshape_indices.append(final_index)
final_index += 1
elif mask &new_axis_mask:
fshape_indices.append(-1)
elif not mask & new_axis_mask:
if final_index == len(m_begin):
break
if mask & begin_mask:
m_begin[final_index] = data_shape[final_index] \
if stride[index] < 0 else 0
elif begin[index]:
m_begin[final_index] = begin[index]
if mask & end_mask:
m_end[final_index] = 0 if stride[index] < 0 \
else data_shape[final_index]
elif end[index]:
m_end[final_index] = end[index]
m_stride[final_index] = stride[index]
if mask & shrink_axis_mask:
#Tensorflow make axis with shrink_axis_mask as dimension 1
m_begin[final_index] = data_shape[final_index] + begin[index] \
if begin[index] < 0 else begin[index]
m_end[final_index] = begin[index] + 1
m_stride[final_index] = 1
fshape_indices.append(-2)
else:
fshape_indices.append(final_index)

final_index += 1
return m_begin, m_end, m_stride, fshape_indices

fshape_indices = None
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)

out = _op.strided_slice(data_expr, begin=begin, end=end, strides=stride)
out_shape = _infer_shape(out)
if not fshape_indices:
fshape_indices = range(len(out_shape))

#Create final output shape.
final_output = []
for gather_index in fshape_indices:
if gather_index == -1:
final_output.append(1)
elif gather_index == -2:
pass
else:
final_output.append(out_shape[gather_index])

if not final_output:
return out
return _op.reshape(out, newshape=tuple(final_output))

def convert_zeros_like(self, op):
"""Convert TFLite ZEROS LIKE"""
try:
Expand Down
53 changes: 53 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,57 @@ def test_forward_slice():
_test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 1], size=[-1, -1])
_test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], size=[-1])

#######################################################################
# Gather
# ------

def _test_gather(dshape, indices, axis, dtype):
""" One iteration of Gather """
data = np.random.uniform(1, 10, size=dshape).astype(dtype)
indices = np.asarray(indices).astype('int32')

with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = array_ops.gather(in_data, indices, axis=axis)
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])

def test_forward_gather():
""" GATHER """
_test_gather((4,), [1], None, 'float32')
_test_gather((1, 4), [0], 0, 'int32')
_test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32')
_test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'int32')
_test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32')
_test_gather((2, 2), [[[1, 0], [0, 1]]], 0, 'float32')
_test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32')
_test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32')
_test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32')

#######################################################################
# StridedSlice
# ------------

def _test_stridedslice(ip_shape, begin, end, stride, dtype,
begin_mask=0, end_mask=0, new_axis_mask=0,
shrink_axis_mask=0, ellipsis_mask=0):
""" One iteration of a Stridedslice """
data = np.random.uniform(size=ip_shape).astype(dtype)

with tf.Graph().as_default():
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
out = array_ops.strided_slice(in_data, begin, end, stride,
begin_mask=begin_mask,
end_mask=end_mask, new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask,
ellipsis_mask=ellipsis_mask)
compare_tflite_with_tvm(data, 'in_data:0', [in_data], [out])

def test_forward_stridedslice():
'''test StridedSlice'''
_test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1)
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
_test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=1)

#######################################################################
# transpose
# ---------
Expand Down Expand Up @@ -1456,6 +1507,8 @@ def test_forward_mediapipe_hand_landmark():
test_all_resize()
test_forward_squeeze()
test_forward_slice()
test_forward_gather()
test_forward_stridedslice()

# NN
test_forward_convolution()
Expand Down

0 comments on commit e13576b

Please sign in to comment.