Skip to content

Commit

Permalink
Gather operation with indices as tensor expr in TFLite frontend (apac…
Browse files Browse the repository at this point in the history
…he#6168)

* gather with indices as tensor expr

Added handling of indices as tensor expr
to gather operation, unit tests amended
Code cheking out of boundary error refactored
in more "pythonic" way. Fixed bug in negative
axis value normalisation

* replaced with get_tensor_expr
  • Loading branch information
d-smirnov authored and Trevor Morris committed Aug 26, 2020
1 parent ac15f11 commit e5ac35e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 46 deletions.
54 changes: 22 additions & 32 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,14 +1347,10 @@ def convert_gather(self, op):
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"

data = self.get_expr(input_tensors[0].tensor_idx)

data = self.get_tensor_expr(input_tensors[0])
indices = input_tensors[1]
indices_type = indices.tensor.Type()
assert indices_type in (TensorType.INT32, TensorType.INT64)
indices_type_str = self.get_tensor_type_str(indices_type)
indices = self.exp_tab.new_const(self.get_tensor_value(indices),
dtype=indices_type_str)

assert op.BuiltinOptionsType() == BuiltinOptions.GatherOptions
op_options = op.BuiltinOptions()
Expand All @@ -1366,37 +1362,31 @@ def convert_gather(self, op):
data_shape = list(input_tensors[0].tensor.ShapeAsNumpy())
data_dim = len(data_shape)

axis_n = axis
if axis_n < 0:
axis_n += axis_n + data_dim
assert axis_n >= 0, "Axis out of bounds"
assert axis_n < data_dim, "Axis out of bounds"

indices_val = self.get_tensor_value(input_tensors[1])
indices_shape = list(indices_val.shape)
indices_len = len(indices_shape)

out_shape = []
for i in range(data_dim):
if axis_n == i:
for j in range(indices_len):
out_shape.append(indices_shape[j])
else:
out_shape.append(data_shape[i])

loopover = [range(s) for s in out_shape]
for idx in list(itertools.product(*loopover)):
indices_position = [idx[j] for j in range(axis_n, axis_n+indices_len)]
axis = data_dim + axis if axis < 0 else axis
assert axis >= 0, "Axis out of bounds"
assert axis < data_dim, "Axis out of bounds"

real_indices = [idx[j] for j in range(axis_n)]
real_indices.append(indices_val[tuple(indices_position)])
real_indices.extend([idx[j] for j in range(axis_n + indices_len, len(idx))])
for r, d in zip(real_indices, data_shape):
if r >= d:
if self.has_expr(indices.tensor_idx):
indices_expr = self.get_expr(indices.tensor_idx)
else:
indices_val = self.get_tensor_value(indices)
indices_expr = self.exp_tab.new_const(indices_val,
dtype=self.get_tensor_type_str(indices_type))
indices_shape = list(indices_val.shape)
indices_len = len(indices_shape)

out_shape = data_shape[:axis] + indices_shape[:] + data_shape[axis+1:]

loopover = [range(s) for s in out_shape]
for idx in list(itertools.product(*loopover)):
real_indices = list(idx[:axis]) \
+ [indices_val[idx[axis: axis + indices_len]]] \
+ list(idx[axis + indices_len:])
if np.any(np.subtract(data_shape, real_indices) < 0):
raise ValueError("TFLite out of bound indices are not supported.")

# Use mode 'fast' since indices are already checked within bounds.
out = _op.take(data, indices, axis=axis, mode="fast")
out = _op.take(data, indices_expr, axis=axis, mode="fast")
return out

def convert_gather_nd(self, op):
Expand Down
41 changes: 27 additions & 14 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,20 +396,31 @@ def test_forward_topk():
# Gather
# ------

def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False):
def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False, wrap_idx=False):
""" One iteration of Gather """
indices = np.asarray(indices).astype('int32')
data = np.random.uniform(1, 10, size=dshape)
data = data.astype(np.uint8) if quantized else data.astype(dtype)
with tf.Graph().as_default():
if wrap_idx:
in_name = "in_indices"
indices_expr = array_ops.placeholder(shape=indices.shape, dtype=indices.dtype, name=in_name)
in_tensor_name = [in_name + ":0"]
in_indices = [indices_expr]
else:
indices_expr = indices
indices = []
in_tensor_name = []
in_indices = []

in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in_data")
if axis:
out = array_ops.gather(in_data, indices, axis=axis)
out = array_ops.gather(in_data, indices_expr, axis=axis)
else:
out = array_ops.gather(in_data, indices) #tflite conversion fails for None axis
out = array_ops.gather(in_data, indices_expr) #tflite conversion fails for None axis
input_range = {'in_data': (-100, 100)} if quantized else None
try:
compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out],
compare_tflite_with_tvm([data] + indices, ['in_data:0'] + in_tensor_name, [in_data] + in_indices, [out],
quantized=quantized, input_range=input_range)
except ValueError as e:
if not oob:
Expand All @@ -420,16 +431,18 @@ def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False):
def test_forward_gather():
""" GATHER """
for quantized in [False, True]:
_test_gather((4,), [1], 0, 'float32', quantized)
_test_gather((4,), [1], None, 'int32', quantized)
_test_gather((1, 4), [0], 0, 'int32', quantized)
_test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32', quantized)
_test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32', quantized)
_test_gather((2, 2), [[[1, 0], [0, 1]]], None, 'float32', quantized)
_test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32', quantized)
_test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32', quantized)
_test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32', quantized)
_test_gather((3, 3, 3), [[[2, 1]]], -1, 'int32', quantized)
for wrap_idx in [False, True]:
_test_gather((4,), [1], 0, 'float32', quantized, wrap_idx)
_test_gather((4,), [1], None, 'int32', quantized, wrap_idx)
_test_gather((1, 4), [0], 0, 'int32', quantized, wrap_idx)
_test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32', quantized, wrap_idx)
_test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32', quantized, wrap_idx)
_test_gather((2, 2), [[[1, 0], [0, 1]]], None, 'float32', quantized, wrap_idx)
_test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32', quantized, wrap_idx)
_test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32', quantized, wrap_idx)
_test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32', quantized, wrap_idx)
_test_gather((3, 3, 3), [[[2, 1]]], -1, 'int32', quantized, wrap_idx)
# Out of boundary error cannot be tested with wrapped index
_test_gather((4,), [16], 0, 'float32', quantized, oob=True)
_test_gather((1, 3, 3), [12], 0, 'int32', quantized, oob=True)
_test_gather((1, 3, 3), [20], 1, 'float32', quantized, oob=True)
Expand Down

0 comments on commit e5ac35e

Please sign in to comment.