From 75436da6eb500dceb0483b9189a7db71d5e63484 Mon Sep 17 00:00:00 2001 From: Ina_Dobreva Date: Thu, 19 Dec 2019 12:12:41 +0000 Subject: [PATCH] [Relay][Frontend][TFLite] Add parser support for logical operators * Add parser support for logical_and, logical_or * Add boolean dtype as a valid tensor type * BOOLEAN dtype is supported only from tf 1.15 so logical ops work only in that and newer versions * Logical_not is ommited since tflite can't convert it --> throws errors for addv2 --- python/tvm/relay/frontend/tflite.py | 34 ++++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 29 +++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 791c056c4a3d..7e4c37ad8235 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -117,6 +117,8 @@ def __init__(self, model, subgraph, exp_tab): 'PRELU': self.convert_prelu, 'TRANSPOSE_CONV': self.convert_transpose_conv, 'SQUARED_DIFFERENCE': self.convert_squared_difference, + 'LOGICAL_AND': self.convert_logical_and, + 'LOGICAL_OR': self.convert_logical_or, } def check_unsupported_ops(self): @@ -222,6 +224,9 @@ def get_tensor_value(self, tensor_wrapper): if tensor_wrapper.tensor.Type() == TensorType.INT64: return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int64).reshape( tensor_wrapper.tensor.ShapeAsNumpy()) + if tensor_wrapper.tensor.Type() == TensorType.BOOL: + return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.bool_).reshape( + tensor_wrapper.tensor.ShapeAsNumpy()) raise NotImplementedError("Tensor type {} is currently not supported" .format(str(tensor_wrapper.tensor.Type()))) @@ -240,6 +245,8 @@ def get_tensor_type_str(self, tensor_type): return "int32" if tensor_type == TensorType.INT64: return "int64" + if tensor_type == TensorType.BOOL: + return "bool" raise NotImplementedError("Tensor type {} is currently not supported" .format(str(tensor_type))) @@ -792,6 +799,33 @@ def convert_not_equal(self, op): 'TFlite quantized NOT_EQUAL operator is not supported yet.') return self._convert_elemwise(_op.not_equal, op) + def _convert_logical_binary(self, relay_op, op): + """Generic method to convert logical binary ops""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + lhs_tensor = input_tensors[0] + lhs_expr = self.get_expr(lhs_tensor.tensor_idx) + rhs_tensor = input_tensors[1] + rhs_expr = self.get_expr(rhs_tensor.tensor_idx) + out = relay_op(lhs_expr, rhs_expr) + + return out + + def convert_logical_and(self, op): + """Convert tflite LOGICAL_AND""" + return self._convert_logical_binary(_op.logical_and, op) + + def convert_logical_or(self, op): + """Convert tflite LOGICAL_OR""" + return self._convert_logical_binary(_op.logical_or, op) + def convert_zeros_like(self, op): """Convert TFLite ZEROS LIKE""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 9835bfcb46bf..a4d636ad9042 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -955,6 +955,32 @@ def test_all_elemwise(): _test_forward_elemwise(_test_equal) _test_forward_elemwise(_test_not_equal) +####################################################################### +# Logical operators +# ----------------- + +def _test_logical_binary(logical_bin_op, data): + + with tf.Graph().as_default(): + in_data = [array_ops.placeholder(shape=data[0].shape, dtype='bool', name='in_0'), + array_ops.placeholder(shape=data[1].shape, dtype='bool', name='in_1')] + out = logical_bin_op(in_data[0], in_data[1], name='out') + compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out]) + +def _test_forward_logical_and(data): + """ One iteration of logical and """ + return _test_logical_binary(math_ops.logical_and, data) + +def _test_forward_logical_or(data): + """ One iteration of logical or """ + return _test_logical_binary(math_ops.logical_or, data) + +def test_all_logical(): + data = [np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool'), + np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool')] + _test_forward_logical_and(data) + _test_forward_logical_or(data) + ####################################################################### # Zeros like # -------- @@ -1519,6 +1545,9 @@ def test_forward_mediapipe_hand_landmark(): # Reduce test_all_reduce() + # Logical + test_all_logical() + # End to End test_forward_mobilenet_v1() test_forward_mobilenet_v2()