From 31bbc63db5ab5bfd8b95e02f776a809b57efc386 Mon Sep 17 00:00:00 2001 From: Samuel Date: Tue, 31 Mar 2020 02:21:01 +0530 Subject: [PATCH] [TFLITE]TOP_K op parser support (#5051) * [TFLITE]TOP_K op parser support * Testcase updated --- python/tvm/relay/frontend/tflite.py | 19 +++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 19 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index aa5157024eeb..7f7ae3068e2a 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -129,6 +129,7 @@ def __init__(self, model, subgraph, exp_tab): 'TAN': self.convert_tan, 'TANH':self.convert_tanh, 'TILE': self.convert_tile, + 'TOPK_V2': self.convert_topk_v2, 'TRANSPOSE_CONV': self.convert_transpose_conv, 'TRANSPOSE': self.convert_transpose, 'UNPACK': self.convert_unpack, @@ -1550,6 +1551,24 @@ def convert_tile(self, op): return out + def convert_topk_v2(self, op): + """ Convert TFLite TOPK_v2 """ + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + input_tensor = input_tensors[0] + input_tensor_idx = input_tensor.tensor_idx + in_expr = self.get_expr(input_tensor_idx) + k = self.get_tensor_value(input_tensors[1]) + out = _op.topk(in_expr, int(k)) + + return out + def convert_pool2d(self, op, pool_type): """pool2d implementation.""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 5f0c4448bdda..42726b7038d5 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -272,6 +272,24 @@ def test_forward_slice(): _test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 1], size=[-1, -1]) _test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], size=[-1]) +####################################################################### +# Topk +# ---- +def _test_topk(in_shape, k=1): + """ One iteration of TOPK """ + data = np.random.uniform(size=in_shape).astype('float32') + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + out = nn_ops.top_k(in_data, k, name='TopK') + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out[0]]) + +def test_forward_topk(): + """ TOPK """ + _test_topk((3,), 1) + _test_topk((3,), 3) + _test_topk((3, 5, 7), 3) + _test_topk((3, 5, 7), 3) + ####################################################################### # transpose # --------- @@ -1775,6 +1793,7 @@ def test_forward_mediapipe_hand_landmark(): test_all_resize() test_forward_squeeze() test_forward_slice() + test_forward_topk() test_forward_depthtospace() test_forward_spacetodepth()