diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 2065d60a299e..b9a165711913 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -78,6 +78,7 @@ def __init__(self, model, subgraph, exp_tab): 'ELU': self.convert_elu, 'EQUAL': self.convert_equal, 'EXP': self.convert_exp, + 'FILL': self.convert_fill, 'FLOOR_DIV': self.convert_floor_div, 'FLOOR_MOD': self.convert_floor_mod, 'FLOOR': self.convert_floor, @@ -123,6 +124,7 @@ def __init__(self, model, subgraph, exp_tab): 'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd, 'SPACE_TO_DEPTH': self.convert_space_to_depth, 'SPLIT': self.convert_split, + 'SPLIT_V': self.convert_split_v, 'SQRT': self.convert_sqrt, 'SQUARE': self.convert_square, 'SQUARED_DIFFERENCE': self.convert_squared_difference, @@ -1212,6 +1214,21 @@ def convert_zeros_like(self, op): return out + def convert_fill(self, op): + """Convert TFLite FILL""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + if self.has_expr(input_tensors[0].tensor_idx): + raise tvm.error.OpNotImplemented("For dims parameter of Fill operator," + " only constant values are supported.") + + in_dims = list(self.get_tensor_value(input_tensors[0])) + in_value_expr = self.get_expr(input_tensors[1].tensor_idx) + out = _op.full(in_value_expr, in_dims) + + return out + def _convert_reduce(self, relay_op, op): """Generic method to Convert TFLite MEAN operators""" try: @@ -1617,6 +1634,35 @@ def convert_split(self, op): return out + def convert_split_v(self, op): + """SPLIT_V implementation.""" + input_tensors = self.get_input_tensors(op) + + assert len(input_tensors) == 3, "input tensors length should be 3" + + input_tensor = input_tensors[0] + input_tensor_idx = input_tensor.tensor_idx + in_expr = self.get_expr(input_tensor_idx) + + if self.has_expr(input_tensors[1].tensor_idx): + raise tvm.error.OpNotImplemented("For size_splits parameter of SPLIT_V operator, " + "only constant values are supported.") + size_splits = list(self.get_tensor_value(input_tensors[1])) + size_splits = tuple(np.cumsum(size_splits)[:-1]) + + axis_tensor = input_tensors[2] + split_axis = self.get_tensor_value(axis_tensor) + + out = _op.split(in_expr, size_splits, axis=int(split_axis)) + # Relay does not like a TupleWrapper of 1 element, further this + # only shows up with tf1.13 if we use a split with num_splits==1. + # In tf 1.14 this doesn't appear as it is automatically a reshape + # operation. + if isinstance(out, _expr.TupleWrapper) and out.size == 1: + out = out[0] + + return out + def convert_slice(self, op): """Convert TFLite SLICE""" input_tensors = self.get_input_tensors(op) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 8d5d7abe1aaf..ce31a6dc6bef 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -216,15 +216,19 @@ def with_fused_activation_function(input_tensor, fn_name): return math_ops.tanh(input_tensor) raise AssertionError("Unknown fused_activation_function {}".format(fn_name)) -def _test_split(in_shape, axis, num_Splits, dtype): - '''internal split tester taking as parameters in_shape, number of tensors to split into - and dtype (data type)''' + +def _test_split(in_shape, axis, num_splits, dtype): + """internal split tester taking as parameters in_shape, number of tensors to split into + and dtype (data type)""" + np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype) with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=in_shape, dtype=dtype) - out = array_ops.split(in_data, num_Splits, axis=axis) - out_names = ['out_' + str(n) + ':0' for n in range(num_Splits)] - compare_tflite_with_tvm([np_data], ['Placeholder:0'], [in_data], out, + in_data = array_ops.placeholder(shape=in_shape, dtype=dtype, name="in_data") + out = array_ops.split(in_data, num_splits, axis=axis) + num_splits = len(num_splits) if isinstance(num_splits, list) \ + else num_splits + out_names = ['out_' + str(n) + ':0' for n in range(num_splits)] + compare_tflite_with_tvm([np_data], ['in_data'], [in_data], out, out_names=out_names) def test_forward_split(): @@ -252,6 +256,9 @@ def test_forward_split(): _test_split((1, 6, 3, 5), -3, 3, 'float32') _test_split((1, 3, 6, 5), -2, 3, 'float32') _test_split((1, 3, 5, 6), -1, 3, 'float32') + # size_splits split + _test_split((6,), 0, [1, 2, 3], 'float32') + _test_split((3, 6, 4), -2, [1, 4, 1], 'float32') ####################################################################### # slice @@ -1210,6 +1217,39 @@ def test_forward_zeros_like(): """ ZEROS LIKE """ _test_zeros_like(np.arange(6.0, dtype=np.float32).reshape((1, 6))) + +####################################################################### +# Fill +# ---- + +def _test_fill(dims, value_data, value_dtype): + """ Use the fill op to create a tensor of value_data with constant dims.""" + + value_data = np.array(value_data, dtype=value_dtype) + # TF 1.13 TFLite convert method does not accept empty shapes + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + with tf.Graph().as_default(): + value = array_ops.placeholder(dtype=value_dtype, name="value", shape=[]) + out = tf.fill(dims, value) + compare_tflite_with_tvm([value_data], ["value"], [value], [out]) + + with tf.Graph().as_default(): + input1 = array_ops.placeholder(dtype=value_dtype, name="input1", shape=dims) + # Fill op gets converted to static tensor during conversion + out = tf.fill(dims, value_data) + out1 = tf.add(out, input1) + input1_data = np.random.uniform(0, 5, size=dims).astype(value_dtype) + compare_tflite_with_tvm([input1_data], ["input1"], [input1], [out1]) + + +def test_forward_fill(): + """ Test FILL op """ + + _test_fill((1, 2, 2, 4), 5, "int32") + _test_fill((1, 2, 2, 4), 5, "float32") + _test_fill((5, ), 5, "int32") + + ####################################################################### # Reduce # ------ @@ -1961,6 +2001,9 @@ def test_forward_mediapipe_hand_landmark(): # Zeros Like test_forward_zeros_like() + # Fill + test_forward_fill() + # Reduce test_all_reduce()