Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Frontend][TFlite] Add parses support for unary elemwise ops #4634

Merged
merged 3 commits into from
Jan 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ def __init__(self, model, subgraph, exp_tab):

# Add more operators
self.convert_map = {
'ABS': self.convert_abs,
'EXP': self.convert_exp,
'FLOOR': self.convert_floor,
'CEIL': self.convert_ceil,
'LOG': self.convert_log,
'SIN': self.convert_sin,
'COS': self.convert_cos,
'SQRT': self.convert_sqrt,
'RSQRT': self.convert_rsqrt,
'NEG': self.convert_neg,
'CONV_2D': self.convert_conv2d,
'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
'AVERAGE_POOL_2D': self.convert_average_pool2d,
Expand Down Expand Up @@ -483,6 +493,93 @@ def convert_concatenation(self, op):
.format('qnn.op.concatenate'))
return out

def _convert_unary_elemwise(self, relay_op, op):
"""Generic method to convert TFLite unary elemwise functions"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
out = relay_op(in_expr)

return out

def convert_abs(self, op):
"""Convert TFLite ABS"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized ABS operator is not supported yet.')
return self._convert_unary_elemwise(_op.abs, op)

def convert_ceil(self, op):
"""Convert TFLite CEIL"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized CEIL operator is not supported yet.')
return self._convert_unary_elemwise(_op.ceil, op)

def convert_floor(self, op):
"""Convert TFLite FLOOR"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized FLOOR operator is not supported yet.')
return self._convert_unary_elemwise(_op.floor, op)

def convert_exp(self, op):
"""Convert TFLite EXP"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized EXP operator is not supported yet.')
return self._convert_unary_elemwise(_op.exp, op)

def convert_log(self, op):
"""Convert TFLite LOG"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized LOG operator is not supported yet.')
return self._convert_unary_elemwise(_op.log, op)

def convert_sin(self, op):
"""Convert TFLite SIN"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized SIN operator is not supported yet.')
return self._convert_unary_elemwise(_op.sin, op)

def convert_cos(self, op):
"""Convert TFLite COS"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized COS operator is not supported yet.')
return self._convert_unary_elemwise(_op.cos, op)

def convert_sqrt(self, op):
"""Convert TFLite SQRT"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized SQRT operator is not supported yet.')
return self._convert_unary_elemwise(_op.sqrt, op)

def convert_rsqrt(self, op):
"""Convert TFLite RSQRT"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized RSQRT operator is not supported yet.')
return self._convert_unary_elemwise(_op.rsqrt, op)

def convert_neg(self, op):
"""Convert TFLite NEG"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized NEG operator is not supported yet.')
return self._convert_unary_elemwise(_op.negative, op)

def _convert_elemwise(self, relay_op, op):
"""Generic method to Convert TFLite elemwise"""
try:
Expand Down
106 changes: 106 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,109 @@ def test_forward_concatenation():
np.arange(6).reshape((2, 1, 1, 3)),
np.arange(6).reshape((2, 1, 1, 3))], 1)

#######################################################################
# Unary elemwise
# --------------

def _test_unary_elemwise(math_op, data):
""" One iteration of unary elemwise """

with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name='in')
out = math_op(in_data)
compare_tflite_with_tvm(data, ['in:0'], in_data, [out])

#######################################################################
# Abs
# ---

def _test_abs(data):
""" One iteration of abs """
return _test_unary_elemwise(math_ops.abs, data)
#######################################################################
# Ceil
# ----

def _test_ceil(data):
""" One iteration of ceil """
return _test_unary_elemwise(math_ops.ceil, data)
#######################################################################
# Floor
# -----

def _test_floor(data):
""" One iteration of floor """
return _test_unary_elemwise(math_ops.floor, data)
#######################################################################
# Exp
# ---

def _test_exp(data):
""" One iteration of exp """
return _test_unary_elemwise(math_ops.exp, data)
#######################################################################
# Log
# ---

def _test_log(data):
""" One iteration of log """
return _test_unary_elemwise(math_ops.log, data)
#######################################################################
# Sin
# ---

def _test_sin(data):
""" One iteration of sin """
return _test_unary_elemwise(math_ops.sin, data)
#######################################################################
# Cos
# ---

def _test_cos(data):
""" One iteration of cos """
return _test_unary_elemwise(math_ops.cos, data)
#######################################################################
# Sqrt
# ----

def _test_sqrt(data):
""" One iteration of sqrt """
return _test_unary_elemwise(math_ops.sqrt, data)
#######################################################################
# Rsqrt
# -----

def _test_rsqrt(data):
""" One iteration of rsqrt """
return _test_unary_elemwise(math_ops.rsqrt, data)
#######################################################################
# Neg
# ---

def _test_neg(data):
""" One iteration of neg """
return _test_unary_elemwise(math_ops.neg, data)
#######################################################################

def _test_forward_unary_elemwise(test_op):
# functions that need positive input
if test_op in {'_test_log', '_test_sqrt', '_test_rsqrt'}:
test_op(np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)))
test_op(np.arange(6.0, dtype=np.int32).reshape((2, 1, 3)))
else:
np.array(np.random.uniform(-5, 5, (3, 1)), dtype=np.int32)

def test_all_unary_elemwise():
_test_forward_unary_elemwise(_test_abs)
_test_forward_unary_elemwise(_test_ceil)
_test_forward_unary_elemwise(_test_floor)
_test_forward_unary_elemwise(_test_exp)
_test_forward_unary_elemwise(_test_log)
_test_forward_unary_elemwise(_test_sin)
_test_forward_unary_elemwise(_test_cos)
_test_forward_unary_elemwise(_test_sqrt)
_test_forward_unary_elemwise(_test_rsqrt)
_test_forward_unary_elemwise(_test_neg)

#######################################################################
# Element-wise
Expand Down Expand Up @@ -1320,6 +1423,9 @@ def test_forward_mediapipe_hand_landmark():
# Elemwise
test_all_elemwise()

# Unary elemwise
test_all_unary_elemwise()

# Zeros Like
test_forward_zeros_like()

Expand Down