diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index b042af9fbe65..8966aa6b389e 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -94,7 +94,8 @@ def __init__(self, model, subgraph, exp_tab):
             'CAST': self.convert_cast,
             'TILE': self.convert_tile,
             'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
-            'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd
+            'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
+            'PRELU': self.convert_prelu,
         }
 
     def check_unsupported_ops(self):
@@ -1325,6 +1326,34 @@ def convert_space_to_batch_nd(self, op):
 
         return reshaped_permuted_reshaped_padded
 
+    def convert_prelu(self, op):
+        """Convert TFLite PRELU (two inputs: data and per-channel alpha)."""
+        try:
+            from tflite.Operator import Operator
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+
+        input_tensor = input_tensors[0]
+        alpha_tensor = input_tensors[1]
+        alpha_tensor_type = alpha_tensor.tensor.Type()
+        alpha_tensor_type_str = self.get_tensor_type_str(alpha_tensor_type)
+        # TFLite may store alpha with leading unit dims (e.g. (1, 1, C));
+        # nn.prelu requires a rank-1 alpha of length C, so flatten it.
+        alpha_expr = self.exp_tab.new_const(
+            self.get_tensor_value(alpha_tensor).flatten(),
+            dtype=alpha_tensor_type_str)
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+        # NOTE(review): axis=3 assumes 4-D NHWC input (TFLite's default
+        # layout); confirm before relying on this for other input ranks.
+        out = _op.nn.prelu(in_expr, alpha_expr, axis=3)
+
+        return out
+
+
     def get_expr(self, input_tensor_idx):
         return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
 
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index de19fe34f811..c2c3cb5df48d 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -894,6 +894,20 @@ def test_forward_relu():
     """ ReLU """
     _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
 
+def _test_prelu(data):
+    """ One iteration of PReLU """
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        alpha = np.full((data.shape[-1],), 0.2, dtype=data.dtype)
+        # TFLite's converter fuses this relu(x) - alpha * relu(-x)
+        # pattern into a single PRELU op.
+        out = nn_ops.relu(in_data) + (-alpha * nn_ops.relu(-in_data))
+        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+def test_forward_prelu():
+    """ PReLU """
+    _test_prelu(np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"))
+
 #######################################################################
 # Fully Connected
 # -------
@@ -1121,6 +1135,7 @@ def test_forward_ssd_mobilenet_v1():
     test_forward_softmax()
     test_forward_tanh()
     test_forward_relu()
+    test_forward_prelu()
     test_forward_fully_connected()
 
     # Elemwise