Skip to content

Commit

Permalink
[TFLite] Support PRelu (#4298)
Browse files Browse the repository at this point in the history
  • Loading branch information
FrozenGene authored and tqchen committed Nov 10, 2019
1 parent fc28f7a commit 2f65a87
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
26 changes: 25 additions & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def __init__(self, model, subgraph, exp_tab):
'CAST': self.convert_cast,
'TILE': self.convert_tile,
'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd
'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
'PRELU': self.convert_prelu,
}

def check_unsupported_ops(self):
Expand Down Expand Up @@ -1325,6 +1326,29 @@ def convert_space_to_batch_nd(self, op):

return reshaped_permuted_reshaped_padded

def convert_prelu(self, op):
    """Convert TFLite PRELU to Relay `nn.prelu`.

    TFLite PRELU carries two inputs: the data tensor and the alpha
    (per-channel slope) tensor.  The slopes are constants baked into
    the model, so they are lifted into a Relay constant here.
    """
    try:
        from tflite.Operator import Operator
    except ImportError:
        raise ImportError("The tflite package must be installed")

    assert isinstance(op, Operator)
    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 2, "input tensors length should be 2"

    input_tensor = input_tensors[0]
    alpha_tensor = input_tensors[1]
    # Materialize the constant slope tensor with its own dtype.
    alpha_tensor_type = alpha_tensor.tensor.Type()
    alpha_tensor_type_str = self.get_tensor_type_str(alpha_tensor_type)
    alpha_expr = self.exp_tab.new_const(self.get_tensor_value(alpha_tensor),
                                        dtype=alpha_tensor_type_str)
    in_expr = self.get_expr(input_tensor.tensor_idx)
    # TFLite tensors are laid out NHWC, so the channel axis is the last
    # dimension.  Derive it from the input rank instead of hard-coding 3,
    # which only covered 4D inputs.
    channel_axis = len(input_tensor.tensor.ShapeAsNumpy()) - 1
    out = _op.nn.prelu(in_expr, alpha_expr, axis=channel_axis)

    return out


def get_expr(self, input_tensor_idx):
    """Return the Relay expression bound to a subgraph tensor index."""
    tensor_name = get_tensor_name(self.subgraph, input_tensor_idx)
    return self.exp_tab.get_expr(tensor_name)

Expand Down
14 changes: 14 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,19 @@ def test_forward_relu():
""" ReLU """
_test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)))

def _test_prelu(data):
    """ One iteration of PReLU """
    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        # One slope value per channel (last dimension), all 0.2.
        slopes = np.full((data.shape[-1],), 0.2, dtype=data.dtype)
        # This specific pattern will be replaced into PRelu by tflite
        positive_part = nn_ops.relu(in_data)
        negative_part = -slopes * nn_ops.relu(-in_data)
        out = positive_part + negative_part
        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])

def test_forward_prelu():
    """ PReLU """
    input_data = np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32")
    _test_prelu(input_data)

#######################################################################
# Fully Connected
# -------
Expand Down Expand Up @@ -1121,6 +1134,7 @@ def test_forward_ssd_mobilenet_v1():
test_forward_softmax()
test_forward_tanh()
test_forward_relu()
test_forward_prelu()
test_forward_fully_connected()

# Elemwise
Expand Down

0 comments on commit 2f65a87

Please sign in to comment.