[TFLITE] Activation functions support #4978

Merged
merged 3 commits into from Mar 11, 2020

Changes from all commits
162 changes: 112 additions & 50 deletions python/tvm/relay/frontend/tflite.py
@@ -62,70 +62,72 @@ def __init__(self, model, subgraph, exp_tab):
# Add more operators
self.convert_map = {
'ABS': self.convert_abs,
'ADD': self.convert_add,
'AVERAGE_POOL_2D': self.convert_average_pool2d,
'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
'CAST': self.convert_cast,
'CEIL': self.convert_ceil,
'CONCATENATION': self.convert_concatenation,
'CONV_2D': self.convert_conv2d,
'COS': self.convert_cos,
'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
'DETECTION_POSTPROCESS': self.convert_detection_postprocess,
'DIV': self.convert_div,
'ELU': self.convert_elu,
'EQUAL': self.convert_equal,
Comment on lines +65 to +77 (Member):
Did you change this to alphabetical order?

Reply (Member Author):
@FrozenGene Yes. It's better to keep it in alphabetical order, like TF, and it reduces conflicts when merging PRs.
'EXP': self.convert_exp,
'FLOOR_DIV': self.convert_floor_div,
'FLOOR_MOD': self.convert_floor_mod,
'FLOOR': self.convert_floor,
'FULLY_CONNECTED': self.convert_fully_connected,
'GREATER_EQUAL': self.convert_greater_equal,
'GREATER': self.convert_greater,
'L2_NORMALIZATION': self.convert_l2_normalization,
'LESS_EQUAL': self.convert_less_equal,
'LESS': self.convert_less,
'LOCAL_RESPONSE_NORMALIZATION': self.convert_lrn,
'LOG': self.convert_log,
'LOGICAL_AND': self.convert_logical_and,
'LOGICAL_OR': self.convert_logical_or,
'LOGISTIC': self.convert_logistic,
'MAX_POOL_2D': self.convert_max_pool2d,
'MAXIMUM': self.convert_maximum,
'MEAN': self._convert_reduce_mean,
'MINIMUM': self.convert_minimum,
'MIRROR_PAD': self.convert_mirror_pad,
'MUL': self.convert_mul,
'NEG': self.convert_neg,
'NOT_EQUAL': self.convert_not_equal,
'PACK': self.convert_pack,
'PAD': self.convert_pad,
'POW': self.convert_pow,
'PRELU': self.convert_prelu,
'REDUCE_MAX': self._convert_reduce_max,
'REDUCE_MIN': self._convert_reduce_min,
'REDUCE_PROD': self._convert_reduce_prod,
'RELU': self.convert_relu,
'RESHAPE': self.convert_reshape,
'RESIZE_BILINEAR': self.convert_resize_bilinear,
'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor,
'RSQRT': self.convert_rsqrt,
'SIN': self.convert_sin,
'SLICE': self.convert_slice,
'SOFTMAX': self.convert_softmax,
'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
'SPLIT': self.convert_split,
'SQRT': self.convert_sqrt,
'SQUARE': self.convert_square,
'SQUARED_DIFFERENCE': self.convert_squared_difference,
'SQUEEZE': self.convert_squeeze,
'SUB': self.convert_sub,
'SUM': self._convert_reduce_sum,
'TAN': self.convert_tan,
'TANH': self.convert_tanh,
'TILE': self.convert_tile,
'TRANSPOSE_CONV': self.convert_transpose_conv,
'TRANSPOSE': self.convert_transpose,
'UNPACK': self.convert_unpack,
'ZEROS_LIKE': self.convert_zeros_like,
}

def check_unsupported_ops(self):
@@ -455,6 +457,43 @@ def convert_l2_normalization(self, op):

return out

def convert_lrn(self, op):
"""Convert TFLite LOCAL_RESPONSE_NORMALIZATION """
try:
from tflite.Operator import Operator
from tflite.BuiltinOptions import BuiltinOptions
from tflite.LocalResponseNormalizationOptions import LocalResponseNormalizationOptions
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized LRN operator is not supported yet.')

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"

assert op.BuiltinOptionsType() == BuiltinOptions.LocalResponseNormalizationOptions
op_options = op.BuiltinOptions()
lrn_options = LocalResponseNormalizationOptions()
lrn_options.Init(op_options.Bytes, op_options.Pos)
radius = lrn_options.Radius()
bias = lrn_options.Bias()
alpha = lrn_options.Alpha()
beta = lrn_options.Beta()
size = (radius * 2) + 1  # TFLite stores the half-window radius; Relay expects the full window size
alpha = alpha * size  # Relay scales the squared window sum by alpha / size, TFLite by alpha alone
axis = 3  # NHWC format
out = _op.nn.lrn(in_expr, size=size, axis=axis, bias=bias, alpha=alpha, beta=beta)

return out
Comment (Contributor):
I think it would be better to add an empty line before the return. The same applies to the rest of the functions.

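Why alpha is scaled by size: TF/TFLite LRN normalizes with (bias + alpha * sum(x^2 over the channel window)) ** beta, while Relay's nn.lrn divides the alpha-scaled window sum by size, so multiplying alpha by size makes the two agree. A minimal NumPy sketch of the TFLite-side formula (the reference function below is illustrative, not part of TVM or this PR):

import numpy as np

def lrn_reference(x, radius, bias, alpha, beta):
    """TF/TFLite-style LRN over the channel axis of an NHWC tensor."""
    out = np.empty_like(x)
    channels = x.shape[-1]
    for c in range(channels):
        # Window of 2 * radius + 1 channels, clipped at the edges
        lo, hi = max(0, c - radius), min(channels, c + radius + 1)
        sqr_sum = np.sum(x[..., lo:hi] ** 2, axis=-1)
        out[..., c] = x[..., c] / (bias + alpha * sqr_sum) ** beta
    return out

# Relay's nn.lrn normalizes with (bias + alpha * sqr_sum / size) ** beta,
# so passing alpha * size with size = 2 * radius + 1 yields the same result.
x = np.random.uniform(size=(1, 6, 4, 3)).astype('float32')
ref = lrn_reference(x, radius=5, bias=1.0, alpha=1.0, beta=0.5)
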
def convert_logistic(self, op):
"""Convert TFLite LOGISTIC"""
try:
@@ -693,6 +732,29 @@ def convert_neg(self, op):
'TFlite quantized NEG operator is not supported yet.')
return self._convert_unary_elemwise(_op.negative, op)

def convert_elu(self, op):
"""Convert TFLite ELU"""
Comment (Contributor):
Please, add the necessary imports.

try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)

if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized ELU operator is not supported yet.')
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
exp_type = self.get_tensor_type_str(input_tensor.tensor.Type())
# ELU(x) = x for x > 0, exp(x) - 1 otherwise, built from relu/exp primitives
out = relay.const(-1.0, exp_type) * \
_op.nn.relu(relay.const(1., exp_type) - _op.exp(in_expr)) + \
_op.nn.relu(in_expr)

return out

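The composition above reproduces ELU without a dedicated Relay op: for x > 0 the exp term's relu is zero and relu(x) = x, while for x <= 0, relu(x) = 0 and -relu(1 - exp(x)) = exp(x) - 1. A quick NumPy check of the identity (a standalone sketch, not part of the PR):

import numpy as np

def elu_via_relu_exp(x):
    """ELU built from the same primitives as the conversion above."""
    relu = lambda v: np.maximum(v, 0.0)
    return -1.0 * relu(1.0 - np.exp(x)) + relu(x)

x = np.linspace(-3.0, 3.0, 7)
expected = np.where(x > 0, x, np.exp(x) - 1.0)  # textbook ELU definition
assert np.allclose(elu_via_relu_exp(x), expected)
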
def convert_square(self, op):
"""Convert TFLite SQUARE"""
try:
54 changes: 42 additions & 12 deletions tests/python/frontend/tflite/test_forward.py
@@ -57,8 +57,8 @@ def convert_to_list(x):


#######################################################################
# Get a real image for e2e testing
# --------------------------------
def get_real_image(im_height, im_width):
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
img_name = 'elephant-299.jpg'
@@ -299,7 +299,7 @@ def test_forward_transpose():

#######################################################################
# Cast
# ----

def _test_cast(data, cast_dtype):
""" One iteration of CAST """
@@ -316,8 +316,8 @@ def test_forward_cast():
_test_cast(np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64)

#######################################################################
# Tile
# ----


def _test_forward_tile(in_shape, reps, dtype):
@@ -758,6 +758,14 @@ def _test_square(data):
""" One iteration of square """
return _test_unary_elemwise(math_ops.square, data)

#######################################################################
# Elu
# ---

def _test_elu(data):
""" One iteration of elu """
return _test_unary_elemwise(nn_ops.elu, data)

def _test_forward_unary_elemwise(test_op):
# functions that need positive input
if test_op.__name__ in {'_test_log', '_test_sqrt', '_test_rsqrt'}:
@@ -780,10 +788,11 @@ def test_all_unary_elemwise():
_test_forward_unary_elemwise(_test_ceil)
_test_forward_unary_elemwise(_test_cos)
_test_forward_unary_elemwise(_test_tan)
_test_forward_unary_elemwise(_test_elu)

#######################################################################
# Element-wise
# ------------

def _test_elemwise(math_op, data, fused_activation_function=None, quantized=False, qnn_op=None):
""" One iteration of elemwise """
@@ -1049,7 +1058,7 @@ def test_all_logical():

#######################################################################
# Zeros like
# ----------

def _test_zeros_like(data):
""" One iteration of ZEROS LIKE """
@@ -1237,7 +1246,7 @@ def test_forward_pad():

#######################################################################
# Pack
# ----

def _test_pack(data, axis):
""" One iteration of pack """
@@ -1291,6 +1300,26 @@ def test_forward_unpack():
_test_unpack(np.array(np.random.uniform(0, 5, (3, 6)), dtype=np.int32), axis=-2, num_unpacks=3)
_test_unpack(np.array(np.random.uniform(0, 5, (2, 3, 4)), dtype=np.int32), axis=-3, num_unpacks=2)


#######################################################################
# Local response normalization
# ----------------------------

def _test_local_response_normalization(data, depth_radius, bias, alpha, beta):
""" One iteration of LOCAL_RESPONSE_NORMALIZATION """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype='float32', name='in_0')
out = nn_ops.local_response_normalization(in_data, depth_radius=depth_radius, bias=bias, alpha=alpha, beta=beta)
compare_tflite_with_tvm(data, 'in_0:0', [in_data], [out])

def test_forward_local_response_normalization():
""" LOCAL_RESPONSE_NORMALIZATION """
data = np.random.uniform(size=(1, 6, 4, 3)).astype('float32')
# LOCAL_RESPONSE_NORMALIZATION comes with the TFLite >= 1.14.0 fbs schema
if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
_test_local_response_normalization(data, depth_radius=5, bias=1, alpha=1, beta=0.5)


#######################################################################
# L2 normalization
# ----------------
@@ -1350,7 +1379,7 @@ def test_forward_softmax():

#######################################################################
# Tanh
# ----

def _test_tanh(data):
""" One iteration of TANH """
@@ -1365,7 +1394,7 @@ def test_forward_tanh():

#######################################################################
# ReLu
# ----

def _test_relu(data):
""" One iteration of ReLU """
@@ -1393,7 +1422,7 @@ def test_forward_prelu():

#######################################################################
# Fully Connected
# ---------------

def _test_fully_connected(tensor_in_sizes, filter_in_sizes, bias_in_size=None):
""" One iteration of fully connected """
@@ -1518,7 +1547,7 @@ def test_forward_mobilenet_v2():

#######################################################################
# Inception
# ---------

def test_forward_inception_v3_net():
"""Test the Inception V3 TF Lite model."""
@@ -1696,6 +1725,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_prelu()
test_forward_fully_connected()
test_forward_l2_normalization()
test_forward_local_response_normalization()

# Elemwise
test_all_elemwise()