Skip to content

Commit

Permalink
Add LOGISTIC operator to relay tflite frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov committed Jun 7, 2019
1 parent c4763cd commit e31f6ac
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
18 changes: 18 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(self, model, subgraph, exp_tab):
'CONCATENATION': self.convert_concatenation,
'ADD': self.convert_add,
'FULLY_CONNECTED': self.convert_fully_connected,
'LOGISTIC': self.convert_logistic,
}

def check_unsupported_ops(self):
Expand Down Expand Up @@ -215,6 +216,23 @@ def convert_reshape(self, op):

return out

def convert_logistic(self, op):
"""Convert TFLite LOGISTIC"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)

out = _op.sigmoid(in_expr)
return out

def convert_softmax(self, op):
"""Convert TFLite softmax"""
try:
Expand Down
18 changes: 18 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,23 @@ def test_forward_squeeze():
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3)), [0, 2])
_test_squeeze(np.arange(6).reshape((2, 1, 3, 1)), [1, 3])


#######################################################################
# LOGISTIC
# -------

def _test_logistic(data):
""" One iteration of LOGISTIC """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = math_ops.sigmoid(in_data)
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])

def test_forward_logistic():
""" LOGISTIC """
_test_logistic(np.arange(6.0, dtype=np.float32).reshape((1, 6)))


#######################################################################
# Softmax
# -------
Expand Down Expand Up @@ -508,6 +525,7 @@ def test_forward_inception_v4_net():

# NN
test_forward_convolution()
test_forward_logistic()
test_forward_pooling()
test_forward_softmax()
test_forward_fully_connected()
Expand Down

0 comments on commit e31f6ac

Please sign in to comment.