Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TFLite] Implemented EXPAND_DIMS Operator for TFLite. #6243

Merged
merged 2 commits into from
Aug 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(self, model, subgraph, exp_tab):
'ELU': self.convert_elu,
'EQUAL': self.convert_equal,
'EXP': self.convert_exp,
'EXPAND_DIMS': self.convert_expand_dims,
'FILL': self.convert_fill,
'FLOOR_DIV': self.convert_floor_div,
'FLOOR_MOD': self.convert_floor_mod,
Expand Down Expand Up @@ -2904,6 +2905,31 @@ def convert_detection_postprocess(self, op):
ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores, valid_count]), size=4)
return ret

def convert_expand_dims(self, op):
"""Convert TFLite EXPAND_DIMS"""
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"

if input_tensors[0].qnn_params:
# Check that input and output tensor have same qnn params.
output_tensors = self.get_output_tensors(op)
assert self.has_same_qnn_params(input_tensors[0], output_tensors[0]), \
"TFLite EXPAND_DIMS requires input and output tensors' \
scale and zero points to be equal"

input_expr = self.get_tensor_expr(input_tensors[0])
axis = self.get_tensor_value(input_tensors[1])
if isinstance(axis, np.ndarray):
assert len(axis) == 1, "only one value is expected."
axis = int(axis)

ndims = len(input_tensors[0].tensor.ShapeAsNumpy())
assert (-1-ndims <= axis <= ndims), "axis out of range"

out = _op.expand_dims(input_expr, axis, 1)

return out

def convert_one_hot(self, op):
"""Convert TFLite ONE_HOT"""
try:
Expand Down
56 changes: 56 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2030,6 +2030,61 @@ def test_forward_padv2():
np.uint8(10)], quantized=True)


#######################################################################
# EXPAND_DIMS
# -----------

def _test_expand_dims(input_shape, input_type, axis, quantized=False):
""" One iteration of EXPAND_DIMS """
with tf.Graph().as_default():
axis= ops.convert_to_tensor(axis, dtype=axis.dtype)

if quantized:
# ignoring input_type as quantized requires uint8
input = np.random.uniform(0, 256, input_shape).astype('uint8')
in_input = tf.placeholder(dtype='float32', shape=input.shape, name="input")

input_range = {'q_input': (-100, 100)}
inq_input = tf.quantization.fake_quant_with_min_max_args(
in_input,
min=-100,
max=100,
name="q_input")

out = array_ops.expand_dims(inq_input, axis=axis)
out = tf.quantization.fake_quant_with_min_max_args(
out,
min=-100,
max=100,
name="out")

compare_tflite_with_tvm(
[input],
["q_input"],
[inq_input],
[out],
quantized=True,
input_range=input_range)
else:
input = np.random.uniform(-100, 100, input_shape).astype(input_type)
in_input = tf.placeholder(dtype=input.dtype, shape=input.shape, name="input")

out = array_ops.expand_dims(in_input, axis=axis)

compare_tflite_with_tvm(
[input],
["input"],
[in_input],
[out])

def test_forward_expand_dims():
""" EXPAND_DIMS """
for quantized in [False, True]:
_test_expand_dims((6, 2, 7, 5), 'float32', np.int32(0), quantized=quantized)
_test_expand_dims((1, 2, 3), 'int32', np.int32(-2), quantized=quantized)
_test_expand_dims((2, 4, 5), 'float32', np.array([1], dtype=np.int32), quantized=quantized)


#######################################################################
# ONE_HOT
# -------
Expand Down Expand Up @@ -3021,6 +3076,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_select()
test_forward_quantize_dequantize()
test_forward_arg_min_max()
test_forward_expand_dims()

# NN
test_forward_convolution()
Expand Down