diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 36221b7467aa..1ec82372cb71 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -67,6 +67,8 @@ def __init__(self, model, subgraph, exp_tab): 'ABS': self.convert_abs, 'ADD': self.convert_add, 'ADD_N': self.convert_add_n, + 'ARG_MAX': self.convert_arg_max, + 'ARG_MIN': self.convert_arg_min, 'AVERAGE_POOL_2D': self.convert_average_pool2d, 'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd, 'CAST': self.convert_cast, @@ -1634,6 +1636,54 @@ def convert_reduce_sum(self, op): def convert_reduce_any(self, op): return self._convert_reduce(_op.reduce.any, op) + def _convert_arg_min_max(self, relay_op, op): + """Generic method converting TFLite arg_min_max""" + try: + from tflite.BuiltinOptions import BuiltinOptions + from tflite.ArgMinOptions import ArgMinOptions + from tflite.ArgMaxOptions import ArgMaxOptions + except ImportError: + raise ImportError("The tflite package must be installed") + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "two input tensor arguments expected" + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "one output tensor expected" + + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + axis_tensor = input_tensors[1] + # In Tensorflow, `axis` argument is a Tensor, not attribute. We + # support the case where it inputs from a scalar constant. + axis_value = self.get_tensor_value(axis_tensor) + assert axis_value.size == 1 + axis_value = axis_value.item() + + if op.BuiltinOptionsType() == BuiltinOptions.ArgMinOptions: + arg_min_max_options = ArgMinOptions() + elif op.BuiltinOptionsType() == BuiltinOptions.ArgMaxOptions: + arg_min_max_options = ArgMaxOptions() + op_options = op.BuiltinOptions() + arg_min_max_options.Init(op_options.Bytes, op_options.Pos) + + # set keepdims to True since tflite 1.13 removes all dims of size 1 + # WARNING: all other versions of tflite > 1.13 need keepdims=False + out = relay_op(in_expr, axis=axis_value, keepdims=False, exclude=False) + + return out + + def convert_arg_min(self, op): + """Convert TFLite ARG_MIN""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized ARG_MIN operator is not supported yet.') + return self._convert_arg_min_max(_op.argmin, op) + + def convert_arg_max(self, op): + """Convert TFLite ARG_MAX""" + return self._convert_arg_min_max(_op.argmax, op) + def convert_fully_connected(self, op): """Convert TFLite fully connected""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 52491b2de308..511846728309 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1755,6 +1755,39 @@ def test_all_reduce(): if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'): _test_forward_reduce(_test_reduce_any, dtype="bool") +####################################################################### +# Arg_min_max +# ----------- + +def _test_arg_min_max(math_op, data, axis, quantized=False): + """ One iteration of arg_min_max""" + + with tf.Graph().as_default(): + t_name="in" + in_data = array_ops.placeholder(shape=data.shape, dtype=np.float32, name=t_name ) + input_range=None + qmin, qmax = -100, 102 + if quantized: + inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=qmin, max=qmax, name= 'q' + t_name ) + input_range = { inq_data.name.split(':')[0]: (qmin, qmax)} + out = math_op(input=inq_data, axis=axis) + compare_tflite_with_tvm([data], [inq_data.name], [inq_data], [out], quantized=True, input_range=input_range) + else: + out = math_op(input=in_data, axis=axis) + compare_tflite_with_tvm([data], [in_data.name], [in_data], [out]) + +def test_forward_arg_min_max(): + # test quantized + for data in [np.array(np.random.uniform(-100, 100, (3, 4)), dtype=np.uint8)]: + # There is no quantized version of ArgMin + for axis in [None, 0, 1, -1]: + _test_arg_min_max(math_ops.argmax, data, axis, True) + + for data in [np.array(np.random.uniform(-100, 100, (3, 4)), dtype=np.float32)]: + for axis in [None, 0, 1, -1]: + _test_arg_min_max(math_ops.argmax, data, axis) + _test_arg_min_max(math_ops.argmin, data, axis) + ####################################################################### # Select, Where @@ -2834,6 +2867,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_sparse_to_dense() test_forward_select() test_forward_quantize_dequantize() + test_forward_arg_min_max() # NN test_forward_convolution()