Skip to content

Commit

Permalink
Add support for tflite arg_min and arg_max (apache#5992)
Browse files Browse the repository at this point in the history
* [Relay][Frontend][TFLite] Add parser support for arg_min_max

* this implementation supports only the case when the axis is a scalar
* tflite 1.13 removes all dims of size 1, Relay doesn't do this
* WARNING: every newer version of tflite > 1.13 needs keepdims=TRUE

* Migrated to tflite 2.1.0

keepdims set to False and added some checks

Note the unit tests emmitted following warning:
/workspace/src/te/schedule/bound.cc:119: not in feed graph consumer = compute(T_multiply_red_temp, 0x53f5050)

* linter

* Removed quantized argmin

Removed quantized argmin due to inablility to provide proper test case

* added negative ranges

* re-trigger CI

Co-authored-by: Ina_Dobreva <Ina.Dobreva@arm.com>
  • Loading branch information
2 people authored and Trevor Morris committed Jul 14, 2020
1 parent 67f6e4c commit 2d3a8b2
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 0 deletions.
50 changes: 50 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def __init__(self, model, subgraph, exp_tab):
'ABS': self.convert_abs,
'ADD': self.convert_add,
'ADD_N': self.convert_add_n,
'ARG_MAX': self.convert_arg_max,
'ARG_MIN': self.convert_arg_min,
'AVERAGE_POOL_2D': self.convert_average_pool2d,
'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
'CAST': self.convert_cast,
Expand Down Expand Up @@ -1634,6 +1636,54 @@ def convert_reduce_sum(self, op):
def convert_reduce_any(self, op):
return self._convert_reduce(_op.reduce.any, op)

def _convert_arg_min_max(self, relay_op, op):
"""Generic method converting TFLite arg_min_max"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.ArgMinOptions import ArgMinOptions
from tflite.ArgMaxOptions import ArgMaxOptions
except ImportError:
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "two input tensor arguments expected"

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "one output tensor expected"

input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
axis_tensor = input_tensors[1]
# In Tensorflow, `axis` argument is a Tensor, not attribute. We
# support the case where it inputs from a scalar constant.
axis_value = self.get_tensor_value(axis_tensor)
assert axis_value.size == 1
axis_value = axis_value.item()

if op.BuiltinOptionsType() == BuiltinOptions.ArgMinOptions:
arg_min_max_options = ArgMinOptions()
elif op.BuiltinOptionsType() == BuiltinOptions.ArgMaxOptions:
arg_min_max_options = ArgMaxOptions()
op_options = op.BuiltinOptions()
arg_min_max_options.Init(op_options.Bytes, op_options.Pos)

# set keepdims to True since tflite 1.13 removes all dims of size 1
# WARNING: all other versions of tflite > 1.13 need keepdims=False
out = relay_op(in_expr, axis=axis_value, keepdims=False, exclude=False)

return out

def convert_arg_min(self, op):
"""Convert TFLite ARG_MIN"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized ARG_MIN operator is not supported yet.')
return self._convert_arg_min_max(_op.argmin, op)

def convert_arg_max(self, op):
"""Convert TFLite ARG_MAX"""
return self._convert_arg_min_max(_op.argmax, op)

def convert_fully_connected(self, op):
"""Convert TFLite fully connected"""
try:
Expand Down
34 changes: 34 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1755,6 +1755,39 @@ def test_all_reduce():
if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
_test_forward_reduce(_test_reduce_any, dtype="bool")

#######################################################################
# Arg_min_max
# -----------

def _test_arg_min_max(math_op, data, axis, quantized=False):
""" One iteration of arg_min_max"""

with tf.Graph().as_default():
t_name="in"
in_data = array_ops.placeholder(shape=data.shape, dtype=np.float32, name=t_name )
input_range=None
qmin, qmax = -100, 102
if quantized:
inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=qmin, max=qmax, name= 'q' + t_name )
input_range = { inq_data.name.split(':')[0]: (qmin, qmax)}
out = math_op(input=inq_data, axis=axis)
compare_tflite_with_tvm([data], [inq_data.name], [inq_data], [out], quantized=True, input_range=input_range)
else:
out = math_op(input=in_data, axis=axis)
compare_tflite_with_tvm([data], [in_data.name], [in_data], [out])

def test_forward_arg_min_max():
# test quantized
for data in [np.array(np.random.uniform(-100, 100, (3, 4)), dtype=np.uint8)]:
# There is no quantized version of ArgMin
for axis in [None, 0, 1, -1]:
_test_arg_min_max(math_ops.argmax, data, axis, True)

for data in [np.array(np.random.uniform(-100, 100, (3, 4)), dtype=np.float32)]:
for axis in [None, 0, 1, -1]:
_test_arg_min_max(math_ops.argmax, data, axis)
_test_arg_min_max(math_ops.argmin, data, axis)


#######################################################################
# Select, Where
Expand Down Expand Up @@ -2834,6 +2867,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_sparse_to_dense()
test_forward_select()
test_forward_quantize_dequantize()
test_forward_arg_min_max()

# NN
test_forward_convolution()
Expand Down

0 comments on commit 2d3a8b2

Please sign in to comment.