Skip to content

Commit

Permalink
[TFLite] Implemented EXPAND_DIMS Operator for TFLite. (apache#6243)
Browse files Browse the repository at this point in the history
  • Loading branch information
jainris authored and Trevor Morris committed Aug 26, 2020
1 parent 3d55cc9 commit c7e9efc
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
26 changes: 26 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(self, model, subgraph, exp_tab):
'ELU': self.convert_elu,
'EQUAL': self.convert_equal,
'EXP': self.convert_exp,
'EXPAND_DIMS': self.convert_expand_dims,
'FILL': self.convert_fill,
'FLOOR_DIV': self.convert_floor_div,
'FLOOR_MOD': self.convert_floor_mod,
Expand Down Expand Up @@ -2904,6 +2905,31 @@ def convert_detection_postprocess(self, op):
ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores, valid_count]), size=4)
return ret

def convert_expand_dims(self, op):
"""Convert TFLite EXPAND_DIMS"""
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"

if input_tensors[0].qnn_params:
# Check that input and output tensor have same qnn params.
output_tensors = self.get_output_tensors(op)
assert self.has_same_qnn_params(input_tensors[0], output_tensors[0]), \
"TFLite EXPAND_DIMS requires input and output tensors' \
scale and zero points to be equal"

input_expr = self.get_tensor_expr(input_tensors[0])
axis = self.get_tensor_value(input_tensors[1])
if isinstance(axis, np.ndarray):
assert len(axis) == 1, "only one value is expected."
axis = int(axis)

ndims = len(input_tensors[0].tensor.ShapeAsNumpy())
assert (-1-ndims <= axis <= ndims), "axis out of range"

out = _op.expand_dims(input_expr, axis, 1)

return out

def convert_one_hot(self, op):
"""Convert TFLite ONE_HOT"""
try:
Expand Down
56 changes: 56 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2030,6 +2030,61 @@ def test_forward_padv2():
np.uint8(10)], quantized=True)


#######################################################################
# EXPAND_DIMS
# -----------

def _test_expand_dims(input_shape, input_type, axis, quantized=False):
""" One iteration of EXPAND_DIMS """
with tf.Graph().as_default():
axis= ops.convert_to_tensor(axis, dtype=axis.dtype)

if quantized:
# ignoring input_type as quantized requires uint8
input = np.random.uniform(0, 256, input_shape).astype('uint8')
in_input = tf.placeholder(dtype='float32', shape=input.shape, name="input")

input_range = {'q_input': (-100, 100)}
inq_input = tf.quantization.fake_quant_with_min_max_args(
in_input,
min=-100,
max=100,
name="q_input")

out = array_ops.expand_dims(inq_input, axis=axis)
out = tf.quantization.fake_quant_with_min_max_args(
out,
min=-100,
max=100,
name="out")

compare_tflite_with_tvm(
[input],
["q_input"],
[inq_input],
[out],
quantized=True,
input_range=input_range)
else:
input = np.random.uniform(-100, 100, input_shape).astype(input_type)
in_input = tf.placeholder(dtype=input.dtype, shape=input.shape, name="input")

out = array_ops.expand_dims(in_input, axis=axis)

compare_tflite_with_tvm(
[input],
["input"],
[in_input],
[out])

def test_forward_expand_dims():
""" EXPAND_DIMS """
for quantized in [False, True]:
_test_expand_dims((6, 2, 7, 5), 'float32', np.int32(0), quantized=quantized)
_test_expand_dims((1, 2, 3), 'int32', np.int32(-2), quantized=quantized)
_test_expand_dims((2, 4, 5), 'float32', np.array([1], dtype=np.int32), quantized=quantized)


#######################################################################
# ONE_HOT
# -------
Expand Down Expand Up @@ -3021,6 +3076,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_select()
test_forward_quantize_dequantize()
test_forward_arg_min_max()
test_forward_expand_dims()

# NN
test_forward_convolution()
Expand Down

0 comments on commit c7e9efc

Please sign in to comment.