From cd42c147344566c28b2753817bac36b8db790440 Mon Sep 17 00:00:00 2001 From: Ina_Dobreva Date: Fri, 3 Jan 2020 12:08:40 +0000 Subject: [PATCH] [Relay][Frontend][TFlite] Add add parser support for relational ops Add support for: greater_equal, less, less_equal, equal, not_equal Add tests for the elemwise relational ops --- python/tvm/relay/frontend/tflite.py | 35 +++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 41 ++++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 5902b92c3f567..9b2a633b29e9e 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -89,6 +89,11 @@ def __init__(self, model, subgraph, exp_tab): 'MAXIMUM': self.convert_maximum, 'MINIMUM': self.convert_minimum, 'GREATER': self.convert_greater, + 'GREATER_EQUAL': self.convert_greater_equal, + 'LESS': self.convert_less, + 'LESS_EQUAL': self.convert_less_equal, + 'EQUAL': self.convert_equal, + 'NOT_EQUAL': self.convert_not_equal, 'ZEROS_LIKE': self.convert_zeros_like, 'REDUCE_MIN': self._convert_reduce_min, 'REDUCE_MAX': self._convert_reduce_max, @@ -747,6 +752,36 @@ def convert_squared_difference(self, op): out = _op.power(difference, relay.const(2, exp_type)) return out + def convert_greater_equal(self, op): + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized GREATER_EQUAL operator is not supported yet.') + return self._convert_elemwise(_op.greater_equal, op) + + def convert_less(self, op): + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized LESS operator is not supported yet.') + return self._convert_elemwise(_op.less, op) + + def convert_less_equal(self, op): + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized LESS_EQUAL operator is not supported yet.') + return self._convert_elemwise(_op.less_equal, op) + + def convert_equal(self, op): + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized EQUAL operator is not supported yet.') + return self._convert_elemwise(_op.equal, op) + + def convert_not_equal(self, op): + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized NOT_EQUAL operator is not supported yet.') + return self._convert_elemwise(_op.not_equal, op) + def convert_zeros_like(self, op): """Convert TFLite ZEROS LIKE""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index b7550f40af1e7..670ccc53f8a9f 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -863,6 +863,42 @@ def _test_minimum(data): def _test_greater(data): """ One iteration of greater """ return _test_elemwise(math_ops.greater, data) +####################################################################### +# Greater_equal +# ------------- + +def _test_greater_equal(data): + """ One iteration of greater_equal """ + return _test_elemwise(math_ops.greater_equal, data) +####################################################################### +# Less +# ---- + +def _test_less(data): + """ One iteration of less """ + return _test_elemwise(math_ops.less, data) +####################################################################### +# Less_equal +# ---------- + +def _test_less_equal(data): + """ One iteration of less_equal """ + return _test_elemwise(math_ops.less_equal, data) +####################################################################### +# Equal +# ----- + +def _test_equal(data): + """ One iteration of equal """ + return _test_elemwise(math_ops.equal, data) +####################################################################### +# Not_equal +# --------- + +def _test_not_equal(data): + """ One iteration of not_equal""" + return _test_elemwise(math_ops.not_equal, data) +####################################################################### ####################################################################### # Squared_difference @@ -915,6 +951,11 @@ def test_all_elemwise(): _test_forward_elemwise(_test_minimum) _test_forward_elemwise(_test_greater) _test_forward_elemwise(_test_squared_difference) + _test_forward_elemwise(_test_greater_equal) + _test_forward_elemwise(_test_less) + _test_forward_elemwise(_test_less_equal) + _test_forward_elemwise(_test_equal) + _test_forward_elemwise(_test_not_equal) ####################################################################### # Zeros like