Skip to content

Commit

Permalink
[Relay][Frontend][TFLite] Add parser support for logical operators
Browse files Browse the repository at this point in the history
* Add parser support for logical_and, logical_or
* Add boolean dtype as a valid tensor type
* BOOLEAN dtype is supported only from tf 1.15
  so logical ops work only in that and newer versions
* Logical_not is ommited since tflite can't convert it -->
  throws errors for addv2
  • Loading branch information
inadob committed Feb 4, 2020
1 parent 6f7d6fa commit 75436da
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
34 changes: 34 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def __init__(self, model, subgraph, exp_tab):
'PRELU': self.convert_prelu,
'TRANSPOSE_CONV': self.convert_transpose_conv,
'SQUARED_DIFFERENCE': self.convert_squared_difference,
'LOGICAL_AND': self.convert_logical_and,
'LOGICAL_OR': self.convert_logical_or,
}

def check_unsupported_ops(self):
Expand Down Expand Up @@ -222,6 +224,9 @@ def get_tensor_value(self, tensor_wrapper):
if tensor_wrapper.tensor.Type() == TensorType.INT64:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int64).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
if tensor_wrapper.tensor.Type() == TensorType.BOOL:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.bool_).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
raise NotImplementedError("Tensor type {} is currently not supported"
.format(str(tensor_wrapper.tensor.Type())))

Expand All @@ -240,6 +245,8 @@ def get_tensor_type_str(self, tensor_type):
return "int32"
if tensor_type == TensorType.INT64:
return "int64"
if tensor_type == TensorType.BOOL:
return "bool"
raise NotImplementedError("Tensor type {} is currently not supported"
.format(str(tensor_type)))

Expand Down Expand Up @@ -792,6 +799,33 @@ def convert_not_equal(self, op):
'TFlite quantized NOT_EQUAL operator is not supported yet.')
return self._convert_elemwise(_op.not_equal, op)

def _convert_logical_binary(self, relay_op, op):
"""Generic method to convert logical binary ops"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"

lhs_tensor = input_tensors[0]
lhs_expr = self.get_expr(lhs_tensor.tensor_idx)
rhs_tensor = input_tensors[1]
rhs_expr = self.get_expr(rhs_tensor.tensor_idx)
out = relay_op(lhs_expr, rhs_expr)

return out

def convert_logical_and(self, op):
"""Convert tflite LOGICAL_AND"""
return self._convert_logical_binary(_op.logical_and, op)

def convert_logical_or(self, op):
"""Convert tflite LOGICAL_OR"""
return self._convert_logical_binary(_op.logical_or, op)

def convert_zeros_like(self, op):
"""Convert TFLite ZEROS LIKE"""
try:
Expand Down
29 changes: 29 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,32 @@ def test_all_elemwise():
_test_forward_elemwise(_test_equal)
_test_forward_elemwise(_test_not_equal)

#######################################################################
# Logical operators
# -----------------

def _test_logical_binary(logical_bin_op, data):

with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype='bool', name='in_0'),
array_ops.placeholder(shape=data[1].shape, dtype='bool', name='in_1')]
out = logical_bin_op(in_data[0], in_data[1], name='out')
compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])

def _test_forward_logical_and(data):
""" One iteration of logical and """
return _test_logical_binary(math_ops.logical_and, data)

def _test_forward_logical_or(data):
""" One iteration of logical or """
return _test_logical_binary(math_ops.logical_or, data)

def test_all_logical():
data = [np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool'),
np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool')]
_test_forward_logical_and(data)
_test_forward_logical_or(data)

#######################################################################
# Zeros like
# --------
Expand Down Expand Up @@ -1519,6 +1545,9 @@ def test_forward_mediapipe_hand_landmark():
# Reduce
test_all_reduce()

# Logical
test_all_logical()

# End to End
test_forward_mobilenet_v1()
test_forward_mobilenet_v2()
Expand Down

0 comments on commit 75436da

Please sign in to comment.