From 4ab4c137f27230f464039ddc077d38a3a59563b3 Mon Sep 17 00:00:00 2001 From: Ina Dobreva <55383260+inadob@users.noreply.github.com> Date: Sun, 6 Oct 2019 01:40:29 +0100 Subject: [PATCH] Add parses support for zeros_like tflite operator (#4042) The tensorflow zeros_like operation provided in array_ops.py produces directly a tensor with zeros without a graph, using only the shape and type of the input. This imposes the use of gen_array_ops.py that produces both a tensor and a graph so a comparison between tflite and tvm can be done. --- python/tvm/relay/frontend/tflite.py | 18 ++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 19 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 01f6c670de08..8b913154e2b8 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -75,6 +75,7 @@ def __init__(self, model, subgraph, exp_tab): 'MAXIMUM': self.convert_maximum, 'MINIMUM': self.convert_minimum, 'GREATER': self.convert_greater, + 'ZEROS_LIKE': self.convert_zeros_like, 'REDUCE_MIN': self._convert_reduce_min, 'REDUCE_MAX': self._convert_reduce_max, 'MEAN': self._convert_reduce_mean, @@ -478,6 +479,23 @@ def convert_minimum(self, op): def convert_greater(self, op): return self._convert_elemwise(_op.greater, op) + def convert_zeros_like(self, op): + """Convert TFLite ZEROS LIKE""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + out = _op.zeros_like(in_expr) + + return out + def _convert_reduce(self, relay_op, op): """Generic method to Convert TFLite MEAN operators""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 06afa59e0a82..670e85ba8384 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import variables try: from tensorflow import lite as interpreter_wrapper @@ -632,6 +633,21 @@ def test_all_elemwise(): _test_forward_elemwise(_test_minimum) _test_forward_elemwise(_test_greater) +####################################################################### +# Zeros like +# -------- + +def _test_zeros_like(data): + """ One iteration of ZEROS LIKE """ + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + out = gen_array_ops.zeros_like(in_data) + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) + +def test_forward_zeros_like(): + """ ZEROS LIKE """ + _test_zeros_like(np.arange(6.0, dtype=np.float32).reshape((1, 6))) + ####################################################################### # Reduce # ------ @@ -1020,6 +1036,9 @@ def test_forward_ssd_mobilenet_v1(): # Elemwise test_all_elemwise() + # Zeros Like + test_forward_zeros_like() + # Reduce test_all_reduce()