Skip to content

Commit

Permalink
[Relay][Frontend][TFlite] Add parses support for unary elemwise ops (a…
Browse files Browse the repository at this point in the history
…pache#4634)

* [Relay][Frontend][Tflite] Add parses support for unary elemwise ops

* Add generic method to convert unary functions: abs, exp, ceil, floor
  log, sin, cos, sqrt, rsqrt, neg
* Add relevant tests

* Delete excessive underscores as requested in PR review

* Change parameter name as suggested in PR review
  • Loading branch information
inadob authored and alexwong committed Feb 28, 2020
1 parent 23f60be commit 444ee5e
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 0 deletions.
97 changes: 97 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ def __init__(self, model, subgraph, exp_tab):

# Add more operators
self.convert_map = {
'ABS': self.convert_abs,
'EXP': self.convert_exp,
'FLOOR': self.convert_floor,
'CEIL': self.convert_ceil,
'LOG': self.convert_log,
'SIN': self.convert_sin,
'COS': self.convert_cos,
'SQRT': self.convert_sqrt,
'RSQRT': self.convert_rsqrt,
'NEG': self.convert_neg,
'CONV_2D': self.convert_conv2d,
'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
'AVERAGE_POOL_2D': self.convert_average_pool2d,
Expand Down Expand Up @@ -495,6 +505,93 @@ def convert_concatenation(self, op):
.format('qnn.op.concatenate'))
return out

def _convert_unary_elemwise(self, relay_op, op):
"""Generic method to convert TFLite unary elemwise functions"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
out = relay_op(in_expr)

return out

def convert_abs(self, op):
"""Convert TFLite ABS"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized ABS operator is not supported yet.')
return self._convert_unary_elemwise(_op.abs, op)

def convert_ceil(self, op):
"""Convert TFLite CEIL"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized CEIL operator is not supported yet.')
return self._convert_unary_elemwise(_op.ceil, op)

def convert_floor(self, op):
"""Convert TFLite FLOOR"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized FLOOR operator is not supported yet.')
return self._convert_unary_elemwise(_op.floor, op)

def convert_exp(self, op):
"""Convert TFLite EXP"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized EXP operator is not supported yet.')
return self._convert_unary_elemwise(_op.exp, op)

def convert_log(self, op):
"""Convert TFLite LOG"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized LOG operator is not supported yet.')
return self._convert_unary_elemwise(_op.log, op)

def convert_sin(self, op):
"""Convert TFLite SIN"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized SIN operator is not supported yet.')
return self._convert_unary_elemwise(_op.sin, op)

def convert_cos(self, op):
"""Convert TFLite COS"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized COS operator is not supported yet.')
return self._convert_unary_elemwise(_op.cos, op)

def convert_sqrt(self, op):
"""Convert TFLite SQRT"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized SQRT operator is not supported yet.')
return self._convert_unary_elemwise(_op.sqrt, op)

def convert_rsqrt(self, op):
"""Convert TFLite RSQRT"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized RSQRT operator is not supported yet.')
return self._convert_unary_elemwise(_op.rsqrt, op)

def convert_neg(self, op):
"""Convert TFLite NEG"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized NEG operator is not supported yet.')
return self._convert_unary_elemwise(_op.negative, op)

def _convert_elemwise(self, relay_op, op):
"""Generic method to Convert TFLite elemwise"""
try:
Expand Down
106 changes: 106 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,109 @@ def test_forward_concatenation():
np.arange(6).reshape((2, 1, 1, 3)),
np.arange(6).reshape((2, 1, 1, 3))], 1)

#######################################################################
# Unary elemwise
# --------------

def _test_unary_elemwise(math_op, data):
""" One iteration of unary elemwise """

with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name='in')
out = math_op(in_data)
compare_tflite_with_tvm(data, ['in:0'], in_data, [out])

#######################################################################
# Abs
# ---

def _test_abs(data):
""" One iteration of abs """
return _test_unary_elemwise(math_ops.abs, data)
#######################################################################
# Ceil
# ----

def _test_ceil(data):
""" One iteration of ceil """
return _test_unary_elemwise(math_ops.ceil, data)
#######################################################################
# Floor
# -----

def _test_floor(data):
""" One iteration of floor """
return _test_unary_elemwise(math_ops.floor, data)
#######################################################################
# Exp
# ---

def _test_exp(data):
""" One iteration of exp """
return _test_unary_elemwise(math_ops.exp, data)
#######################################################################
# Log
# ---

def _test_log(data):
""" One iteration of log """
return _test_unary_elemwise(math_ops.log, data)
#######################################################################
# Sin
# ---

def _test_sin(data):
""" One iteration of sin """
return _test_unary_elemwise(math_ops.sin, data)
#######################################################################
# Cos
# ---

def _test_cos(data):
""" One iteration of cos """
return _test_unary_elemwise(math_ops.cos, data)
#######################################################################
# Sqrt
# ----

def _test_sqrt(data):
""" One iteration of sqrt """
return _test_unary_elemwise(math_ops.sqrt, data)
#######################################################################
# Rsqrt
# -----

def _test_rsqrt(data):
""" One iteration of rsqrt """
return _test_unary_elemwise(math_ops.rsqrt, data)
#######################################################################
# Neg
# ---

def _test_neg(data):
""" One iteration of neg """
return _test_unary_elemwise(math_ops.neg, data)
#######################################################################

def _test_forward_unary_elemwise(test_op):
# functions that need positive input
if test_op in {'_test_log', '_test_sqrt', '_test_rsqrt'}:
test_op(np.arange(6.0, dtype=np.float32).reshape((2, 1, 3)))
test_op(np.arange(6.0, dtype=np.int32).reshape((2, 1, 3)))
else:
np.array(np.random.uniform(-5, 5, (3, 1)), dtype=np.int32)

def test_all_unary_elemwise():
_test_forward_unary_elemwise(_test_abs)
_test_forward_unary_elemwise(_test_ceil)
_test_forward_unary_elemwise(_test_floor)
_test_forward_unary_elemwise(_test_exp)
_test_forward_unary_elemwise(_test_log)
_test_forward_unary_elemwise(_test_sin)
_test_forward_unary_elemwise(_test_cos)
_test_forward_unary_elemwise(_test_sqrt)
_test_forward_unary_elemwise(_test_rsqrt)
_test_forward_unary_elemwise(_test_neg)

#######################################################################
# Element-wise
Expand Down Expand Up @@ -1320,6 +1423,9 @@ def test_forward_mediapipe_hand_landmark():
# Elemwise
test_all_elemwise()

# Unary elemwise
test_all_unary_elemwise()

# Zeros Like
test_forward_zeros_like()

Expand Down

0 comments on commit 444ee5e

Please sign in to comment.