diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 2f877e997fde5..9eaf14f8ecfae 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -488,11 +488,9 @@ def convert_lrn(self, op): beta = lrn_options.Beta() size = (radius * 2) + 1 alpha = alpha * size - - # TFLite supports lrn only over the last dim - input_tensor_rank = len(input_tensor.tensor.ShapeAsNumpy()) - axis = input_tensor_rank - 1 + axis = 3 # NHWC format out = _op.nn.lrn(in_expr, size=size, axis=axis, bias=bias, alpha=alpha, beta=beta) + return out def convert_logistic(self, op): @@ -728,6 +726,12 @@ def convert_neg(self, op): def convert_elu(self, op): """Convert TFLite ELU""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + assert isinstance(op, Operator) + if self.is_quantized(op): raise tvm.error.OpNotImplemented( 'TFlite quantized ELU operator is not supported yet.') @@ -740,6 +744,7 @@ def convert_elu(self, op): out = relay.const(-1.0, exp_type) * \ _op.nn.relu(relay.const(1., exp_type) - _op.exp(in_expr)) + \ _op.nn.relu(in_expr) + return out def convert_square(self, op): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 5b779979030b0..33df88e498e5a 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1307,7 +1307,9 @@ def _test_local_response_normalization(data, depth_radius, bias, alpha, beta): def test_forward_local_response_normalization(): """ LOCAL_RESPONSE_NORMALIZATION """ data = np.random.uniform(size=(1, 6, 4, 3)).astype('float32') - _test_local_response_normalization(data, depth_radius=5, bias=1, alpha=1, beta=0.5) + # LOCAL_RESPONSE_NORMALIZATION come with TFLite >= 1.14.0 fbs schema + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + _test_local_response_normalization(data, depth_radius=5, bias=1, alpha=1, beta=0.5) ####################################################################### @@ -1715,9 +1717,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_prelu() test_forward_fully_connected() test_forward_l2_normalization() - # The below activations come with TFLite >= 1.14.0 fbs schema - if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): - test_forward_local_response_normalization() + test_forward_local_response_normalization() # Elemwise test_all_elemwise()