From 393212f9ccf0a8ab68003d43a111f0f9fea418f6 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Fri, 7 Jun 2019 00:11:34 -0700 Subject: [PATCH] Add LOGISTIC operator to relay tflite frontend --- python/tvm/relay/frontend/tflite.py | 18 ++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 9c8f50f0b020e..098adda82547c 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -66,6 +66,7 @@ def __init__(self, model, subgraph, exp_tab): 'ADD': self.convert_add, 'MUL': self.convert_mul, 'FULLY_CONNECTED': self.convert_fully_connected, + 'LOGISTIC': self.convert_logistic, } def check_unsupported_ops(self): @@ -217,6 +218,23 @@ def convert_reshape(self, op): return out + def convert_logistic(self, op): + """Convert TFLite LOGISTIC""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + out = _op.sigmoid(in_expr) + return out + def convert_softmax(self, op): """Convert TFLite softmax""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 677fbb87bf461..1b1170c5016f8 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -394,6 +394,23 @@ def test_forward_squeeze(): _test_squeeze(np.arange(6).reshape((1, 2, 1, 3)), [0, 2]) _test_squeeze(np.arange(6).reshape((2, 1, 3, 1)), [1, 3]) + +####################################################################### +# LOGISTIC +# ------- + +def _test_logistic(data): + """ One iteration of LOGISTIC """ + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + out = math_ops.sigmoid(in_data) + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) + +def test_forward_logistic(): + """ LOGISTIC """ + _test_logistic(np.arange(6.0, dtype=np.float32).reshape((1, 6))) + + ####################################################################### # Softmax # ------- @@ -533,6 +550,7 @@ def test_forward_inception_v4_net(): # NN test_forward_convolution() + test_forward_logistic() test_forward_pooling() test_forward_softmax() test_forward_fully_connected()