From d2759720a6a354acbc9c2cbc9323e016bf2620bb Mon Sep 17 00:00:00 2001 From: Alexey Romanov Date: Fri, 16 Nov 2018 15:45:23 +0300 Subject: [PATCH] WIP fix for stack(unstack(data)) != data --- nnvm/python/nnvm/frontend/tensorflow.py | 14 +++++++++---- .../frontend/tensorflow/test_forward.py | 20 ++++++++++++++++--- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 3a7965f97bbbd..81ddf52311ec9 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -808,7 +808,12 @@ def _impl(inputs, attr, params): axis=axis, name=attr.get('_node_name', 'unstack')) - return _sym.Group([_sym.squeeze(split_item, axis=axis) for split_item in splitted]) + if len(input_shape) > 1: + squeezed = [_sym.squeeze(split_item, axis=axis) for split_item in splitted] + else: + # FIXME split_item[0] still has shape [1] instead of [] + squeezed = [split_item[0] for split_item in splitted] + return _sym.Group(squeezed) return _impl def _mean(): @@ -1210,11 +1215,12 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None): if i in self._nodes: tvm_n = self._nodes[i] - outputs = tvm_n.list_output_names() - if len(outputs) > 1: + tvm_n_shape = self._output_shapes[i] + if len(tvm_n.list_output_names()) > 1: tvm_n = tvm_n[num_layer] + tvm_n_shape = tvm_n_shape[num_layer] inputs.append(tvm_n) - input_shapes[tvm_n] = self._output_shapes[i] + input_shapes[tvm_n] = tvm_n_shape attr['_input_shapes'] = input_shapes inputs = self._fix_extranodes(node.op, attr, inputs) diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index 95e2558268612..c9406c3c2a62e 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -505,14 +505,21 @@ def test_forward_gather(): # ------ def _test_split(ip_shape, num_or_size_splits, axis, dtype): + np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype) + num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits + tf.reset_default_graph() in_data = tf.placeholder(dtype, ip_shape, name="in_data") - num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits tf.split(in_data, num_or_size_splits, axis=axis) - np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype) compare_tf_with_tvm([np_data], ['in_data:0'], [f'split:{n}' for n in range(num_split)]) + tf.reset_default_graph() + in_data = tf.placeholder(dtype, ip_shape, name="in_data") + tf.concat(tf.split(in_data, num_or_size_splits, axis=axis), axis=axis) + + compare_tf_with_tvm([np_data], ['in_data:0'], 'concat:0') + def test_forward_split(): '''test split layer''' _test_split((6,), 2, 0, 'int32') @@ -530,13 +537,20 @@ def test_forward_split(): # ------ def _test_unstack(ip_shape, axis, dtype): + np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype) + tf.reset_default_graph() in_data = tf.placeholder(dtype, ip_shape, name="in_data") tf.unstack(in_data, axis=axis) - np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype) compare_tf_with_tvm([np_data], ['in_data:0'], [f'unstack:{n}' for n in range(ip_shape[axis])]) + tf.reset_default_graph() + in_data = tf.placeholder(dtype, ip_shape, name="in_data") + tf.stack(tf.unstack(in_data, axis=axis), axis=axis) + + compare_tf_with_tvm([np_data], ['in_data:0'], 'stack:0') + def test_forward_unstack(): '''test unstack layer''' _test_unstack((6,), 0, 'int32')