From 8a9ca25fad4b609b3dd5f78c5dd781f0b04048e6 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Thu, 30 Apr 2020 18:46:18 +0530 Subject: [PATCH 1/3] [TFLITE]Select/Where op support for tflite frontend --- python/tvm/relay/frontend/tflite.py | 54 +++++++++++++++----- tests/python/frontend/tflite/test_forward.py | 22 ++++++++ 2 files changed, 63 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 5c8bbfb3c8f9..73fdd7b4989b 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -99,7 +99,7 @@ def __init__(self, model, subgraph, exp_tab): 'LOGISTIC': self.convert_logistic, 'MAX_POOL_2D': self.convert_max_pool2d, 'MAXIMUM': self.convert_maximum, - 'MEAN': self._convert_reduce_mean, + 'MEAN': self.convert_reduce_mean, 'MINIMUM': self.convert_minimum, 'MIRROR_PAD': self.convert_mirror_pad, 'MUL': self.convert_mul, @@ -109,16 +109,17 @@ def __init__(self, model, subgraph, exp_tab): 'PAD': self.convert_pad, 'POW': self.convert_pow, 'PRELU': self.convert_prelu, - 'REDUCE_ANY': self._convert_reduce_any, - 'REDUCE_MAX': self._convert_reduce_max, - 'REDUCE_MIN': self._convert_reduce_min, - 'REDUCE_PROD': self._convert_reduce_prod, + 'REDUCE_ANY': self.convert_reduce_any, + 'REDUCE_MAX': self.convert_reduce_max, + 'REDUCE_MIN': self.convert_reduce_min, + 'REDUCE_PROD': self.convert_reduce_prod, 'RELU':self.convert_relu, 'RESHAPE': self.convert_reshape, 'RESIZE_BILINEAR': self.convert_resize_bilinear, 'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor, 'ROUND': self.convert_round, 'RSQRT': self.convert_rsqrt, + 'SELECT': self.convert_select, 'SIN': self.convert_sin, 'SLICE': self.convert_slice, 'SOFTMAX': self.convert_softmax, @@ -132,7 +133,7 @@ def __init__(self, model, subgraph, exp_tab): 'SQUEEZE': self.convert_squeeze, 'STRIDED_SLICE': self.convert_strided_slice, 'SUB': self.convert_sub, - 'SUM': self._convert_reduce_sum, + 'SUM': self.convert_reduce_sum, 'TAN': self.convert_tan, 'TANH':self.convert_tanh, 'TILE': self.convert_tile, @@ -140,6 +141,7 @@ def __init__(self, model, subgraph, exp_tab): 'TRANSPOSE_CONV': self.convert_transpose_conv, 'TRANSPOSE': self.convert_transpose, 'UNPACK': self.convert_unpack, + 'WHERE': self.convert_select, 'ZEROS_LIKE': self.convert_zeros_like, } @@ -1241,7 +1243,7 @@ def convert_fill(self, op): return out def _convert_reduce(self, relay_op, op): - """Generic method to Convert TFLite MEAN operators""" + """Generic method to Convert TFLite REDUCE operators""" try: from tflite.BuiltinOptions import BuiltinOptions from tflite.ReducerOptions import ReducerOptions @@ -1285,22 +1287,22 @@ def _convert_reduce(self, relay_op, op): return out - def _convert_reduce_min(self, op): + def convert_reduce_min(self, op): return self._convert_reduce(_op.reduce.min, op) - def _convert_reduce_max(self, op): + def convert_reduce_max(self, op): return self._convert_reduce(_op.reduce.max, op) - def _convert_reduce_mean(self, op): + def convert_reduce_mean(self, op): return self._convert_reduce(_op.reduce.mean, op) - def _convert_reduce_prod(self, op): + def convert_reduce_prod(self, op): return self._convert_reduce(_op.reduce.prod, op) - def _convert_reduce_sum(self, op): + def convert_reduce_sum(self, op): return self._convert_reduce(_op.reduce.sum, op) - def _convert_reduce_any(self, op): + def convert_reduce_any(self, op): return self._convert_reduce(_op.reduce.any, op) def convert_fully_connected(self, op): @@ -1697,6 +1699,18 @@ def convert_slice(self, op): return out + def convert_select(self, op): + """Convert TFLite SELECT""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 3, "input tensors length should be == 3" + cond = self.get_tensor_or_const_expr(input_tensors[0]) + x = self.get_tensor_or_const_expr(input_tensors[1]) + y = self.get_tensor_or_const_expr(input_tensors[2]) + + out = _op.where(cond, x, y) + + return out + def convert_transpose(self, op): """transpose implementation.""" input_tensors = self.get_input_tensors(op) @@ -2357,6 +2371,20 @@ def get_expr(self, input_tensor_idx): def has_expr(self, input_tensor_idx): return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx)) + def get_tensor_or_const_expr(self, tensor): + """ Returns constant expr for constant else a tensor expr""" + if self.has_expr(tensor.tensor_idx): + # In most cases, we can assume that TOCO fuses elemwise operators + # with constants - it means both will be tensors. + expr = self.get_expr(tensor.tensor_idx) + else: + # However, in some corner cases, the elemwise operator is not fused, + # we can receive as constant. + type_str = self.get_tensor_type_str(tensor.tensor.Type()) + expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str) + + return expr + def get_scalar_from_constant(expr): """ Returns scalar value from Relay constant scalar. """ diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 26bb86dfe24a..6dfdd169a6e3 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1376,6 +1376,27 @@ def test_all_reduce(): ####################################################################### +# Select, Where +# ------------- + +def test_forward_select(): + with tf.Graph().as_default(): + with tf.Session() as sess: + input1 = tf.placeholder( + tf.int32, shape=[1, 4, 4, 3], name='input1') + input2 = tf.placeholder( + tf.int32, shape=[1, 4, 4, 3], name='input2') + mask = input1 > input2 + out = tf.where(mask, input1 + 1, input2 * 2) + in_data1 = np.random.uniform( + 0, 10, size=(1, 4, 4, 3)).astype("int32") + in_data2 = np.random.uniform( + 0, 10, size=(1, 4, 4, 3)).astype("int32") + + compare_tflite_with_tvm([in_data1, in_data2], [ + 'input1:0', 'input2:0'], [input1, input2], [out]) + + # Squeeze # ------- @@ -2014,6 +2035,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_stridedslice() test_forward_depthtospace() test_forward_spacetodepth() + test_forward_select() # NN test_forward_convolution() From 08ad3eed0d20456904cc54d8da024799e01b0db9 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Thu, 30 Apr 2020 22:26:22 +0530 Subject: [PATCH 2/3] Review comment fixed --- python/tvm/relay/frontend/tflite.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 73fdd7b4989b..ba590d8b2645 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1703,9 +1703,9 @@ def convert_select(self, op): """Convert TFLite SELECT""" input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 3, "input tensors length should be == 3" - cond = self.get_tensor_or_const_expr(input_tensors[0]) - x = self.get_tensor_or_const_expr(input_tensors[1]) - y = self.get_tensor_or_const_expr(input_tensors[2]) + cond = self.get_tensor_expr(input_tensors[0]) + x = self.get_tensor_expr(input_tensors[1]) + y = self.get_tensor_expr(input_tensors[2]) out = _op.where(cond, x, y) @@ -2371,7 +2371,7 @@ def get_expr(self, input_tensor_idx): def has_expr(self, input_tensor_idx): return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx)) - def get_tensor_or_const_expr(self, tensor): + def get_tensor_expr(self, tensor): """ Returns constant expr for constant else a tensor expr""" if self.has_expr(tensor.tensor_idx): # In most cases, we can assume that TOCO fuses elemwise operators From e30ef2563417af9f155fff86b9d51644d784ed66 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Tue, 5 May 2020 23:05:14 +0530 Subject: [PATCH 3/3] Review comment fixed --- python/tvm/relay/frontend/tflite.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index ba590d8b2645..6f449311136b 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -99,7 +99,7 @@ def __init__(self, model, subgraph, exp_tab): 'LOGISTIC': self.convert_logistic, 'MAX_POOL_2D': self.convert_max_pool2d, 'MAXIMUM': self.convert_maximum, - 'MEAN': self.convert_reduce_mean, + 'MEAN': self._convert_reduce_mean, 'MINIMUM': self.convert_minimum, 'MIRROR_PAD': self.convert_mirror_pad, 'MUL': self.convert_mul, @@ -109,10 +109,10 @@ def __init__(self, model, subgraph, exp_tab): 'PAD': self.convert_pad, 'POW': self.convert_pow, 'PRELU': self.convert_prelu, - 'REDUCE_ANY': self.convert_reduce_any, - 'REDUCE_MAX': self.convert_reduce_max, - 'REDUCE_MIN': self.convert_reduce_min, - 'REDUCE_PROD': self.convert_reduce_prod, + 'REDUCE_ANY': self._convert_reduce_any, + 'REDUCE_MAX': self._convert_reduce_max, + 'REDUCE_MIN': self._convert_reduce_min, + 'REDUCE_PROD': self._convert_reduce_prod, 'RELU':self.convert_relu, 'RESHAPE': self.convert_reshape, 'RESIZE_BILINEAR': self.convert_resize_bilinear, @@ -133,7 +133,7 @@ def __init__(self, model, subgraph, exp_tab): 'SQUEEZE': self.convert_squeeze, 'STRIDED_SLICE': self.convert_strided_slice, 'SUB': self.convert_sub, - 'SUM': self.convert_reduce_sum, + 'SUM': self._convert_reduce_sum, 'TAN': self.convert_tan, 'TANH':self.convert_tanh, 'TILE': self.convert_tile, @@ -1243,7 +1243,7 @@ def convert_fill(self, op): return out def _convert_reduce(self, relay_op, op): - """Generic method to Convert TFLite REDUCE operators""" + """Generic method to Convert TFLite MEAN operators""" try: from tflite.BuiltinOptions import BuiltinOptions from tflite.ReducerOptions import ReducerOptions @@ -1287,22 +1287,22 @@ def _convert_reduce(self, relay_op, op): return out - def convert_reduce_min(self, op): + def _convert_reduce_min(self, op): return self._convert_reduce(_op.reduce.min, op) - def convert_reduce_max(self, op): + def _convert_reduce_max(self, op): return self._convert_reduce(_op.reduce.max, op) - def convert_reduce_mean(self, op): + def _convert_reduce_mean(self, op): return self._convert_reduce(_op.reduce.mean, op) - def convert_reduce_prod(self, op): + def _convert_reduce_prod(self, op): return self._convert_reduce(_op.reduce.prod, op) - def convert_reduce_sum(self, op): + def _convert_reduce_sum(self, op): return self._convert_reduce(_op.reduce.sum, op) - def convert_reduce_any(self, op): + def _convert_reduce_any(self, op): return self._convert_reduce(_op.reduce.any, op) def convert_fully_connected(self, op):