Skip to content

Commit

Permalink
[Relay][Frontend][TFlite] Add parses support for SLICE
Browse files Browse the repository at this point in the history
* TFlite 1.13: convertor gives nonsense output when size[i]==-1
* TF parser: SLICE need fixing for size[i]==-1 -> gives wrong output
  bcs of indices
  • Loading branch information
inadob committed Dec 9, 2019
1 parent b16e2ff commit 71b76d1
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
30 changes: 30 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(self, model, subgraph, exp_tab):
'TANH':self.convert_tanh,
'RELU':self.convert_relu,
'SPLIT': self.convert_split,
'SLICE': self.convert_slice,
'TRANSPOSE': self.convert_transpose,
'CAST': self.convert_cast,
'TILE': self.convert_tile,
Expand Down Expand Up @@ -1033,6 +1034,35 @@ def convert_split(self, op):

return out

def convert_slice(self, op):
"""Convert TFLite SLICE"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 3, "input tensors length should be == 3"
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)

begin = list(self.get_tensor_value(input_tensors[1]))
size = list(self.get_tensor_value(input_tensors[2]))
# strided_slice(Relay) needs the slice's end indices, not the size
end = size
input_tensor_shape = input_tensor.tensor.ShapeAsNumpy()
input_tensor_rank = len(input_tensor_shape)
for i in range(input_tensor_rank):
if size[i] == -1:
end[i] = input_tensor_shape[i] - begin[i] + 1
else:
end[i] += begin[i]

out = _op.strided_slice(in_expr, begin, end)

return out

def convert_transpose(self, op):
"""transpose implementation."""
try:
Expand Down
21 changes: 21 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,26 @@ def test_forward_split():
_test_split((1, 3, 6, 5), -2, 3, 'float32')
_test_split((1, 3, 5, 6), -1, 3, 'float32')

#######################################################################
# slice
# -----

def _test_slice(data, begin, size):
""" One iteration of SLICE """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = array_ops.slice(in_data, begin, size)
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])

def test_forward_slice():
""" SLICE """
_test_slice(np.arange(4, dtype=np.float32).reshape((4, )), begin=[0], size=[2])
_test_slice(np.arange(18, dtype=np.int32).reshape((3, 2, 3)), begin=[1, 0, 0], size=[1, 1, 3])
# tflite 1.13 outputs nonsense values if size[i] == -1
if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
_test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 1], size=[-1, -1])


#######################################################################
# transpose
# ---------
Expand Down Expand Up @@ -1209,6 +1229,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_reshape()
test_all_resize()
test_forward_squeeze()
test_forward_slice()

# NN
test_forward_convolution()
Expand Down

0 comments on commit 71b76d1

Please sign in to comment.