Skip to content

Commit

Permalink
[TFLITE]Select op support for tflite frontend (apache#5486)
Browse files Browse the repository at this point in the history
* [TFLITE]Select/Where op support for tflite frontend

* Review comment fixed

* Review comment fixed
  • Loading branch information
siju-samuel authored and Trevor Morris committed Jun 18, 2020
1 parent f822bf4 commit c08872d
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
28 changes: 28 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(self, model, subgraph, exp_tab):
'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor,
'ROUND': self.convert_round,
'RSQRT': self.convert_rsqrt,
'SELECT': self.convert_select,
'SIN': self.convert_sin,
'SLICE': self.convert_slice,
'SOFTMAX': self.convert_softmax,
Expand All @@ -140,6 +141,7 @@ def __init__(self, model, subgraph, exp_tab):
'TRANSPOSE_CONV': self.convert_transpose_conv,
'TRANSPOSE': self.convert_transpose,
'UNPACK': self.convert_unpack,
'WHERE': self.convert_select,
'ZEROS_LIKE': self.convert_zeros_like,
}

Expand Down Expand Up @@ -1697,6 +1699,18 @@ def convert_slice(self, op):

return out

def convert_select(self, op):
"""Convert TFLite SELECT"""
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 3, "input tensors length should be == 3"
cond = self.get_tensor_expr(input_tensors[0])
x = self.get_tensor_expr(input_tensors[1])
y = self.get_tensor_expr(input_tensors[2])

out = _op.where(cond, x, y)

return out

def convert_transpose(self, op):
"""transpose implementation."""
input_tensors = self.get_input_tensors(op)
Expand Down Expand Up @@ -2357,6 +2371,20 @@ def get_expr(self, input_tensor_idx):
def has_expr(self, input_tensor_idx):
return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx))

def get_tensor_expr(self, tensor):
""" Returns constant expr for constant else a tensor expr"""
if self.has_expr(tensor.tensor_idx):
# In most cases, we can assume that TOCO fuses elemwise operators
# with constants - it means both will be tensors.
expr = self.get_expr(tensor.tensor_idx)
else:
# However, in some corner cases, the elemwise operator is not fused,
# we can receive as constant.
type_str = self.get_tensor_type_str(tensor.tensor.Type())
expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str)

return expr


def get_scalar_from_constant(expr):
""" Returns scalar value from Relay constant scalar. """
Expand Down
22 changes: 22 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,27 @@ def test_all_reduce():


#######################################################################
# Select, Where
# -------------

def test_forward_select():
with tf.Graph().as_default():
with tf.Session() as sess:
input1 = tf.placeholder(
tf.int32, shape=[1, 4, 4, 3], name='input1')
input2 = tf.placeholder(
tf.int32, shape=[1, 4, 4, 3], name='input2')
mask = input1 > input2
out = tf.where(mask, input1 + 1, input2 * 2)
in_data1 = np.random.uniform(
0, 10, size=(1, 4, 4, 3)).astype("int32")
in_data2 = np.random.uniform(
0, 10, size=(1, 4, 4, 3)).astype("int32")

compare_tflite_with_tvm([in_data1, in_data2], [
'input1:0', 'input2:0'], [input1, input2], [out])


# Squeeze
# -------

Expand Down Expand Up @@ -1997,6 +2018,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_stridedslice()
test_forward_depthtospace()
test_forward_spacetodepth()
test_forward_select()

# NN
test_forward_convolution()
Expand Down

0 comments on commit c08872d

Please sign in to comment.