[Frontend][TFLite] ADD_N operator #5474

Merged (9 commits) on May 7, 2020
57 changes: 35 additions & 22 deletions python/tvm/relay/frontend/tflite.py
@@ -64,6 +64,7 @@ def __init__(self, model, subgraph, exp_tab):
        self.convert_map = {
            'ABS': self.convert_abs,
            'ADD': self.convert_add,
            'ADD_N': self.convert_add_n,
            'AVERAGE_POOL_2D': self.convert_average_pool2d,
            'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
            'CAST': self.convert_cast,
@@ -774,6 +775,21 @@ def convert_square(self, op):

        return out

    def get_tensor_or_const_expr(self, tensor):
        """Return the Relay expression for a tensor: the already-registered
        expression if one exists for this tensor index, otherwise a new
        constant built from the tensor's value."""
        if self.has_expr(tensor.tensor_idx):
            # In most cases, we can assume that TOCO fuses elemwise operators
            # with constants - it means both will be tensors.
            expr = self.get_expr(tensor.tensor_idx)
        else:
            # However, in some corner cases, the elemwise operator is not
            # fused, so the input arrives as a constant instead.
            type_str = self.get_tensor_type_str(tensor.tensor.Type())
            expr = self.exp_tab.new_const(self.get_tensor_value(tensor),
                                          dtype=type_str)

        return expr

    def _convert_elemwise(self, relay_op, op):
        """Generic method to convert TFLite elemwise operators."""
        try:
@@ -789,29 +805,10 @@ def _convert_elemwise(self, relay_op, op):
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 2, "input tensors length should be 2"

-        lhs_tensor = input_tensors[0]
-        if self.has_expr(lhs_tensor.tensor_idx):
-            # In most cases, we can assume that TOCO fuses elemwise operators
-            # with constants - it means both will be tensors.
-            lhs_expr = self.get_expr(lhs_tensor.tensor_idx)
-        else:
-            # However, in some corner cases, the elemwise operator is not fused,
-            # we can receive as constant.
-            lhs_type_str = self.get_tensor_type_str(lhs_tensor.tensor.Type())
-            lhs_expr = self.exp_tab.new_const(self.get_tensor_value(lhs_tensor),
-                                              dtype=lhs_type_str)
-
+        lhs_tensor = input_tensors[0]
         rhs_tensor = input_tensors[1]
-        if self.has_expr(rhs_tensor.tensor_idx):
-            # In most cases, we can assume that TOCO fuses elemwise operators
-            # with constants - it means both will be tensors.
-            rhs_expr = self.get_expr(rhs_tensor.tensor_idx)
-        else:
-            # However, in some corner cases, the elemwise operator is not fused,
-            # we can receive as constant.
-            rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type())
-            rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor),
-                                              dtype=rhs_type_str)
+        lhs_expr = self.get_tensor_or_const_expr(lhs_tensor)
+        rhs_expr = self.get_tensor_or_const_expr(rhs_tensor)

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors length should be 1"
@@ -863,6 +860,22 @@ def convert_add(self, op):
            return self._convert_elemwise(_qnn.op.add, op)
        return self._convert_elemwise(_op.add, op)

    def convert_add_n(self, op):
        """Convert TFLite ADD_N"""
        # TFLite has no quantized form of ADD_N,
        # so no quantization checks are needed here.

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors length should be 1"

        input_tensors = self.get_input_tensors(op)
        lhs_expr = self.get_tensor_or_const_expr(input_tensors[0])
        for rhs_tensor in input_tensors[1:]:
            rhs_expr = self.get_tensor_or_const_expr(rhs_tensor)
            lhs_expr = _op.add(lhs_expr, rhs_expr)
        return lhs_expr

    def convert_sub(self, op):
        """Convert TFLite SUB"""
        # Check if the input tensor is quantized, call QNN op
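For context, the new convert_add_n above lowers TFLite's variadic ADD_N into a left fold of binary Relay adds. Below is a minimal standalone sketch of that folding pattern, assuming a working TVM install; the variable names and the fixed (3, 3, 3) shape are illustrative choices, not part of this diff.

from tvm import relay

# Illustrative: build ADD_N(x0, x1, x2) as ((x0 + x1) + x2),
# mirroring the accumulation loop in convert_add_n.
args = [relay.var("x%d" % i, shape=(3, 3, 3), dtype="float32")
        for i in range(3)]
out = args[0]
for rhs in args[1:]:
    out = relay.add(out, rhs)
func = relay.Function(args, out)
print(func)  # shows the chained add expressions
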
36 changes: 36 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
@@ -1896,6 +1896,41 @@ def test_forward_mediapipe_hand_landmark():
        tvm.testing.assert_allclose(np.squeeze(tvm_output[i]), np.squeeze(tflite_output[i]),
                                    rtol=1e-5, atol=1e-5)

#######################################################################
# AddN
# ----------------------


def _test_forward_add_n(inputs):
    """One iteration of add_n with the given input tensors."""
    tf.reset_default_graph()
    with tf.Graph().as_default():
        temp = [tf.placeholder(shape=each.shape, dtype=each.dtype)
                for each in inputs]
        output = tf.add_n(temp)
        compare_tflite_with_tvm(list(inputs), [t.name for t in temp],
                                temp, [output])


def test_forward_add_n():
    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
        x = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
        y = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
        z = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32)
        m, n, o = x.astype(np.float32), y.astype(np.float32), z.astype(np.float32)
        in0 = x
        in1 = [x, y]
        in2 = (x, y, z)
        in3 = m
        in4 = [m, n]
        in5 = (m, n, o)
        _test_forward_add_n(in0)
        _test_forward_add_n(in1)
        _test_forward_add_n(in2)
        _test_forward_add_n(in3)
        _test_forward_add_n(in4)
        _test_forward_add_n(in5)
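
As a sanity reference for what these cases exercise: tf.add_n computes an elementwise sum over equally shaped, equally typed tensors, which is exactly what the chained Relay adds must reproduce. A hedged NumPy equivalent (illustrative only; add_n_ref is not a helper from this PR):

import numpy as np

def add_n_ref(tensors):
    # Reference semantics of ADD_N: elementwise sum of equally
    # shaped, equally typed tensors.
    out = tensors[0].copy()
    for t in tensors[1:]:
        out = out + t
    return out

x, y, z = [np.random.randint(1, 100, size=(3, 3, 3)) for _ in range(3)]
assert np.array_equal(add_n_ref([x, y, z]), x + y + z)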

#######################################################################
# Main
# ----
@@ -1948,6 +1983,7 @@ def test_forward_mediapipe_hand_landmark():

    # Elemwise
    test_all_elemwise()
    test_forward_add_n()

    # Unary elemwise
    test_all_unary_elemwise()