Skip to content

Commit

Permalink
[TFLITE]Activation functions support (apache#4978)
Browse files Browse the repository at this point in the history
* [TFLITE]elu, leaky_relu, lrn, log_softmax activation functions

* removed ops present in pr 4805

* review_comments updated
  • Loading branch information
siju-samuel authored and zhiics committed Apr 17, 2020
1 parent e1ca67f commit a353f40
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 62 deletions.
162 changes: 112 additions & 50 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,70 +62,72 @@ def __init__(self, model, subgraph, exp_tab):
# Add more operators
self.convert_map = {
'ABS': self.convert_abs,
'ADD': self.convert_add,
'AVERAGE_POOL_2D': self.convert_average_pool2d,
'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
'CAST': self.convert_cast,
'CEIL': self.convert_ceil,
'CONCATENATION': self.convert_concatenation,
'CONV_2D': self.convert_conv2d,
'COS': self.convert_cos,
'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
'DETECTION_POSTPROCESS': self.convert_detection_postprocess,
'DIV': self.convert_div,
'ELU': self.convert_elu,
'EQUAL': self.convert_equal,
'EXP': self.convert_exp,
'FLOOR_DIV': self.convert_floor_div,
'FLOOR_MOD': self.convert_floor_mod,
'FLOOR': self.convert_floor,
'CEIL': self.convert_ceil,
'FULLY_CONNECTED': self.convert_fully_connected,
'GREATER_EQUAL': self.convert_greater_equal,
'GREATER': self.convert_greater,
'L2_NORMALIZATION': self.convert_l2_normalization,
'LESS_EQUAL': self.convert_less_equal,
'LESS': self.convert_less,
'LOCAL_RESPONSE_NORMALIZATION': self.convert_lrn,
'LOG': self.convert_log,
'SIN': self.convert_sin,
'COS': self.convert_cos,
'TAN': self.convert_tan,
'SQRT': self.convert_sqrt,
'RSQRT': self.convert_rsqrt,
'LOGICAL_AND': self.convert_logical_and,
'LOGICAL_OR': self.convert_logical_or,
'LOGISTIC': self.convert_logistic,
'MAX_POOL_2D': self.convert_max_pool2d,
'MAXIMUM': self.convert_maximum,
'MEAN': self._convert_reduce_mean,
'MINIMUM': self.convert_minimum,
'MIRROR_PAD': self.convert_mirror_pad,
'MUL': self.convert_mul,
'NEG': self.convert_neg,
'CONV_2D': self.convert_conv2d,
'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
'AVERAGE_POOL_2D': self.convert_average_pool2d,
'NOT_EQUAL': self.convert_not_equal,
'PACK': self.convert_pack,
'PAD': self.convert_pad,
'POW': self.convert_pow,
'PRELU': self.convert_prelu,
'REDUCE_MAX': self._convert_reduce_max,
'REDUCE_MIN': self._convert_reduce_min,
'REDUCE_PROD': self._convert_reduce_prod,
'RELU':self.convert_relu,
'RESHAPE': self.convert_reshape,
'RESIZE_BILINEAR': self.convert_resize_bilinear,
'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor,
'RSQRT': self.convert_rsqrt,
'SIN': self.convert_sin,
'SLICE': self.convert_slice,
'SOFTMAX': self.convert_softmax,
'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
'SPLIT': self.convert_split,
'SQRT': self.convert_sqrt,
'SQUARE': self.convert_square,
'SQUARED_DIFFERENCE': self.convert_squared_difference,
'SQUEEZE': self.convert_squeeze,
'MAX_POOL_2D': self.convert_max_pool2d,
'CONCATENATION': self.convert_concatenation,
'ADD': self.convert_add,
'SUB': self.convert_sub,
'MUL': self.convert_mul,
'DIV': self.convert_div,
'POW': self.convert_pow,
'MAXIMUM': self.convert_maximum,
'MINIMUM': self.convert_minimum,
'GREATER': self.convert_greater,
'GREATER_EQUAL': self.convert_greater_equal,
'LESS': self.convert_less,
'LESS_EQUAL': self.convert_less_equal,
'EQUAL': self.convert_equal,
'NOT_EQUAL': self.convert_not_equal,
'ZEROS_LIKE': self.convert_zeros_like,
'REDUCE_MIN': self._convert_reduce_min,
'REDUCE_MAX': self._convert_reduce_max,
'MEAN': self._convert_reduce_mean,
'REDUCE_PROD': self._convert_reduce_prod,
'SUM': self._convert_reduce_sum,
'FULLY_CONNECTED': self.convert_fully_connected,
'PAD': self.convert_pad,
'MIRROR_PAD': self.convert_mirror_pad,
'PACK': self.convert_pack,
'UNPACK': self.convert_unpack,
'LOGISTIC': self.convert_logistic,
'TAN': self.convert_tan,
'TANH':self.convert_tanh,
'RELU':self.convert_relu,
'SPLIT': self.convert_split,
'SLICE': self.convert_slice,
'TRANSPOSE': self.convert_transpose,
'CAST': self.convert_cast,
'TILE': self.convert_tile,
'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
'PRELU': self.convert_prelu,
'TRANSPOSE_CONV': self.convert_transpose_conv,
'SQUARED_DIFFERENCE': self.convert_squared_difference,
'LOGICAL_AND': self.convert_logical_and,
'LOGICAL_OR': self.convert_logical_or,
'DETECTION_POSTPROCESS': self.convert_detection_postprocess,
'SQUARE': self.convert_square,
'L2_NORMALIZATION': self.convert_l2_normalization,
'FLOOR_DIV': self.convert_floor_div,
'FLOOR_MOD': self.convert_floor_mod,
'TRANSPOSE': self.convert_transpose,
'UNPACK': self.convert_unpack,
'ZEROS_LIKE': self.convert_zeros_like,
}

def check_unsupported_ops(self):
Expand Down Expand Up @@ -455,6 +457,43 @@ def convert_l2_normalization(self, op):

return out

def convert_lrn(self, op):
"""Convert TFLite LOCAL_RESPONSE_NORMALIZATION """
try:
from tflite.Operator import Operator
from tflite.BuiltinOptions import BuiltinOptions
from tflite.LocalResponseNormalizationOptions import LocalResponseNormalizationOptions
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized LRN operator is not supported yet.')

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"

assert op.BuiltinOptionsType() == BuiltinOptions.LocalResponseNormalizationOptions
op_options = op.BuiltinOptions()
lrn_options = LocalResponseNormalizationOptions()
lrn_options.Init(op_options.Bytes, op_options.Pos)
radius = lrn_options.Radius()
bias = lrn_options.Bias()
alpha = lrn_options.Alpha()
beta = lrn_options.Beta()
size = (radius * 2) + 1
alpha = alpha * size
axis = 3 # NHWC format
out = _op.nn.lrn(in_expr, size=size, axis=axis, bias=bias, alpha=alpha, beta=beta)

return out

def convert_logistic(self, op):
"""Convert TFLite LOGISTIC"""
try:
Expand Down Expand Up @@ -693,6 +732,29 @@ def convert_neg(self, op):
'TFlite quantized NEG operator is not supported yet.')
return self._convert_unary_elemwise(_op.negative, op)

def convert_elu(self, op):
"""Convert TFLite ELU"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(op, Operator)

if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized ELU operator is not supported yet.')
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
exp_type = self.get_tensor_type_str(input_tensor.tensor.Type())
out = relay.const(-1.0, exp_type) * \
_op.nn.relu(relay.const(1., exp_type) - _op.exp(in_expr)) + \
_op.nn.relu(in_expr)

return out

def convert_square(self, op):
"""Convert TFLite SQUARE"""
try:
Expand Down
54 changes: 42 additions & 12 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def convert_to_list(x):


#######################################################################
# Get a real image for e2e testing.
# --------------------------------------
# Get a real image for e2e testing
# --------------------------------
def get_real_image(im_height, im_width):
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
img_name = 'elephant-299.jpg'
Expand Down Expand Up @@ -299,7 +299,7 @@ def test_forward_transpose():

#######################################################################
# Cast
# --------
# ----

def _test_cast(data, cast_dtype):
""" One iteration of CAST """
Expand All @@ -316,8 +316,8 @@ def test_forward_cast():
_test_cast(np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64)

#######################################################################
# tile
# ---------
# Tile
# ----


def _test_forward_tile(in_shape, reps, dtype):
Expand Down Expand Up @@ -758,6 +758,14 @@ def _test_square(data):
""" One iteration of square """
return _test_unary_elemwise(math_ops.square, data)

#######################################################################
# Elu
# ---

def _test_elu(data):
""" One iteration of elu """
return _test_unary_elemwise(nn_ops.elu, data)

def _test_forward_unary_elemwise(test_op):
# functions that need positive input
if test_op.__name__ in {'_test_log', '_test_sqrt', '_test_rsqrt'}:
Expand All @@ -780,10 +788,11 @@ def test_all_unary_elemwise():
_test_forward_unary_elemwise(_test_ceil)
_test_forward_unary_elemwise(_test_cos)
_test_forward_unary_elemwise(_test_tan)
_test_forward_unary_elemwise(_test_elu)

#######################################################################
# Element-wise
# ---
# ------------

def _test_elemwise(math_op, data, fused_activation_function=None, quantized=False, qnn_op=None):
""" One iteration of elemwise """
Expand Down Expand Up @@ -1049,7 +1058,7 @@ def test_all_logical():

#######################################################################
# Zeros like
# --------
# ----------

def _test_zeros_like(data):
""" One iteration of ZEROS LIKE """
Expand Down Expand Up @@ -1237,7 +1246,7 @@ def test_forward_pad():

#######################################################################
# Pack
# -------------
# ----

def _test_pack(data, axis):
""" One iteration of pack """
Expand Down Expand Up @@ -1291,6 +1300,26 @@ def test_forward_unpack():
_test_unpack(np.array(np.random.uniform(0, 5, (3, 6)), dtype=np.int32), axis=-2, num_unpacks=3)
_test_unpack(np.array(np.random.uniform(0, 5, (2, 3, 4)), dtype=np.int32), axis=-3, num_unpacks=2)


#######################################################################
# Local response normalization
# ----------------------------

def _test_local_response_normalization(data, depth_radius, bias, alpha, beta):
""" One iteration of LOCAL_RESPONSE_NORMALIZATION """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype='float32', name='in_0')
out = nn_ops.local_response_normalization(in_data, depth_radius=depth_radius, bias=bias, alpha=alpha, beta=beta)
compare_tflite_with_tvm(data, 'in_0:0', [in_data], [out])

def test_forward_local_response_normalization():
""" LOCAL_RESPONSE_NORMALIZATION """
data = np.random.uniform(size=(1, 6, 4, 3)).astype('float32')
# LOCAL_RESPONSE_NORMALIZATION come with TFLite >= 1.14.0 fbs schema
if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
_test_local_response_normalization(data, depth_radius=5, bias=1, alpha=1, beta=0.5)


#######################################################################
# L2 normalization
# ----------------
Expand Down Expand Up @@ -1350,7 +1379,7 @@ def test_forward_softmax():

#######################################################################
# Tanh
# --------
# ----

def _test_tanh(data):
""" One iteration of TANH """
Expand All @@ -1365,7 +1394,7 @@ def test_forward_tanh():

#######################################################################
# ReLu
# --------
# ----

def _test_relu(data):
""" One iteration of ReLU """
Expand Down Expand Up @@ -1393,7 +1422,7 @@ def test_forward_prelu():

#######################################################################
# Fully Connected
# -------
# ---------------

def _test_fully_connected(tensor_in_sizes, filter_in_sizes, bias_in_size=None):
""" One iteration of fully connected """
Expand Down Expand Up @@ -1518,7 +1547,7 @@ def test_forward_mobilenet_v2():

#######################################################################
# Inception
# ------------
# ---------

def test_forward_inception_v3_net():
"""Test the Inception V3 TF Lite model."""
Expand Down Expand Up @@ -1696,6 +1725,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_prelu()
test_forward_fully_connected()
test_forward_l2_normalization()
test_forward_local_response_normalization()

# Elemwise
test_all_elemwise()
Expand Down

0 comments on commit a353f40

Please sign in to comment.