diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 48c88d042ab8..f2a9e5852990 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1347,14 +1347,10 @@ def convert_gather(self, op): input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" - data = self.get_expr(input_tensors[0].tensor_idx) - + data = self.get_tensor_expr(input_tensors[0]) indices = input_tensors[1] indices_type = indices.tensor.Type() assert indices_type in (TensorType.INT32, TensorType.INT64) - indices_type_str = self.get_tensor_type_str(indices_type) - indices = self.exp_tab.new_const(self.get_tensor_value(indices), - dtype=indices_type_str) assert op.BuiltinOptionsType() == BuiltinOptions.GatherOptions op_options = op.BuiltinOptions() @@ -1366,37 +1362,31 @@ def convert_gather(self, op): data_shape = list(input_tensors[0].tensor.ShapeAsNumpy()) data_dim = len(data_shape) - axis_n = axis - if axis_n < 0: - axis_n += axis_n + data_dim - assert axis_n >= 0, "Axis out of bounds" - assert axis_n < data_dim, "Axis out of bounds" - - indices_val = self.get_tensor_value(input_tensors[1]) - indices_shape = list(indices_val.shape) - indices_len = len(indices_shape) - - out_shape = [] - for i in range(data_dim): - if axis_n == i: - for j in range(indices_len): - out_shape.append(indices_shape[j]) - else: - out_shape.append(data_shape[i]) - - loopover = [range(s) for s in out_shape] - for idx in list(itertools.product(*loopover)): - indices_position = [idx[j] for j in range(axis_n, axis_n+indices_len)] + axis = data_dim + axis if axis < 0 else axis + assert axis >= 0, "Axis out of bounds" + assert axis < data_dim, "Axis out of bounds" - real_indices = [idx[j] for j in range(axis_n)] - real_indices.append(indices_val[tuple(indices_position)]) - real_indices.extend([idx[j] for j in range(axis_n + indices_len, len(idx))]) - for r, d in zip(real_indices, data_shape): - if r >= d: + if self.has_expr(indices.tensor_idx): + indices_expr = self.get_expr(indices.tensor_idx) + else: + indices_val = self.get_tensor_value(indices) + indices_expr = self.exp_tab.new_const(indices_val, + dtype=self.get_tensor_type_str(indices_type)) + indices_shape = list(indices_val.shape) + indices_len = len(indices_shape) + + out_shape = data_shape[:axis] + indices_shape[:] + data_shape[axis+1:] + + loopover = [range(s) for s in out_shape] + for idx in list(itertools.product(*loopover)): + real_indices = list(idx[:axis]) \ + + [indices_val[idx[axis: axis + indices_len]]] \ + + list(idx[axis + indices_len:]) + if np.any(np.subtract(data_shape, real_indices) < 0): raise ValueError("TFLite out of bound indices are not supported.") # Use mode 'fast' since indices are already checked within bounds. - out = _op.take(data, indices, axis=axis, mode="fast") + out = _op.take(data, indices_expr, axis=axis, mode="fast") return out def convert_gather_nd(self, op): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index ebb4d77cce64..ebfa10fc35fd 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -396,20 +396,31 @@ def test_forward_topk(): # Gather # ------ -def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False): +def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False, wrap_idx=False): """ One iteration of Gather """ indices = np.asarray(indices).astype('int32') data = np.random.uniform(1, 10, size=dshape) data = data.astype(np.uint8) if quantized else data.astype(dtype) with tf.Graph().as_default(): + if wrap_idx: + in_name = "in_indices" + indices_expr = array_ops.placeholder(shape=indices.shape, dtype=indices.dtype, name=in_name) + in_tensor_name = [in_name + ":0"] + in_indices = [indices_expr] + else: + indices_expr = indices + indices = [] + in_tensor_name = [] + in_indices = [] + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in_data") if axis: - out = array_ops.gather(in_data, indices, axis=axis) + out = array_ops.gather(in_data, indices_expr, axis=axis) else: - out = array_ops.gather(in_data, indices) #tflite conversion fails for None axis + out = array_ops.gather(in_data, indices_expr) #tflite conversion fails for None axis input_range = {'in_data': (-100, 100)} if quantized else None try: - compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], + compare_tflite_with_tvm([data] + indices, ['in_data:0'] + in_tensor_name, [in_data] + in_indices, [out], quantized=quantized, input_range=input_range) except ValueError as e: if not oob: @@ -420,16 +431,18 @@ def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False): def test_forward_gather(): """ GATHER """ for quantized in [False, True]: - _test_gather((4,), [1], 0, 'float32', quantized) - _test_gather((4,), [1], None, 'int32', quantized) - _test_gather((1, 4), [0], 0, 'int32', quantized) - _test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32', quantized) - _test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32', quantized) - _test_gather((2, 2), [[[1, 0], [0, 1]]], None, 'float32', quantized) - _test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32', quantized) - _test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32', quantized) - _test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32', quantized) - _test_gather((3, 3, 3), [[[2, 1]]], -1, 'int32', quantized) + for wrap_idx in [False, True]: + _test_gather((4,), [1], 0, 'float32', quantized, wrap_idx) + _test_gather((4,), [1], None, 'int32', quantized, wrap_idx) + _test_gather((1, 4), [0], 0, 'int32', quantized, wrap_idx) + _test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32', quantized, wrap_idx) + _test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32', quantized, wrap_idx) + _test_gather((2, 2), [[[1, 0], [0, 1]]], None, 'float32', quantized, wrap_idx) + _test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32', quantized, wrap_idx) + _test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32', quantized, wrap_idx) + _test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32', quantized, wrap_idx) + _test_gather((3, 3, 3), [[[2, 1]]], -1, 'int32', quantized, wrap_idx) + # Out of boundary error cannot be tested with wrapped index _test_gather((4,), [16], 0, 'float32', quantized, oob=True) _test_gather((1, 3, 3), [12], 0, 'int32', quantized, oob=True) _test_gather((1, 3, 3), [20], 1, 'float32', quantized, oob=True)