Skip to content

Commit

Permalink
WIP fix for stack(unstack(data)) != data
Browse files Browse the repository at this point in the history
  • Loading branch information
alexeyr committed Nov 16, 2018
1 parent 76f4651 commit d275972
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
14 changes: 10 additions & 4 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,12 @@ def _impl(inputs, attr, params):
axis=axis,
name=attr.get('_node_name', 'unstack'))

return _sym.Group([_sym.squeeze(split_item, axis=axis) for split_item in splitted])
if len(input_shape) > 1:
squeezed = [_sym.squeeze(split_item, axis=axis) for split_item in splitted]
else:
# FIXME split_item[0] still has shape [1] instead of []
squeezed = [split_item[0] for split_item in splitted]
return _sym.Group(squeezed)
return _impl

def _mean():
Expand Down Expand Up @@ -1210,11 +1215,12 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None):

if i in self._nodes:
tvm_n = self._nodes[i]
outputs = tvm_n.list_output_names()
if len(outputs) > 1:
tvm_n_shape = self._output_shapes[i]
if len(tvm_n.list_output_names()) > 1:
tvm_n = tvm_n[num_layer]
tvm_n_shape = tvm_n_shape[num_layer]
inputs.append(tvm_n)
input_shapes[tvm_n] = self._output_shapes[i]
input_shapes[tvm_n] = tvm_n_shape
attr['_input_shapes'] = input_shapes

inputs = self._fix_extranodes(node.op, attr, inputs)
Expand Down
20 changes: 17 additions & 3 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,14 +505,21 @@ def test_forward_gather():
# ------

def _test_split(ip_shape, num_or_size_splits, axis, dtype):
    """Check tf.split against TVM: first the raw split outputs, then a
    concat(split(x)) round-trip that should reproduce the input."""
    np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)
    if isinstance(num_or_size_splits, list):
        num_split = len(num_or_size_splits)
    else:
        num_split = num_or_size_splits

    # Graph 1: bare split — compare every output tensor individually.
    tf.reset_default_graph()
    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
    tf.split(in_data, num_or_size_splits, axis=axis)

    split_outputs = ['split:{}'.format(n) for n in range(num_split)]
    compare_tf_with_tvm([np_data], ['in_data:0'], split_outputs)

    # Graph 2: concat of the split pieces along the same axis must equal x.
    tf.reset_default_graph()
    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
    tf.concat(tf.split(in_data, num_or_size_splits, axis=axis), axis=axis)

    compare_tf_with_tvm([np_data], ['in_data:0'], 'concat:0')

def test_forward_split():
'''test split layer'''
_test_split((6,), 2, 0, 'int32')
Expand All @@ -530,13 +537,20 @@ def test_forward_split():
# ------

def _test_unstack(ip_shape, axis, dtype):
    """Check tf.unstack against TVM: first each unstacked output, then a
    stack(unstack(x)) round-trip that should reproduce the input."""
    np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)

    # Graph 1: bare unstack — one output per slice along the chosen axis.
    tf.reset_default_graph()
    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
    tf.unstack(in_data, axis=axis)

    unstack_outputs = ['unstack:{}'.format(n) for n in range(ip_shape[axis])]
    compare_tf_with_tvm([np_data], ['in_data:0'], unstack_outputs)

    # Graph 2: restacking the pieces along the same axis must equal x.
    tf.reset_default_graph()
    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
    tf.stack(tf.unstack(in_data, axis=axis), axis=axis)

    compare_tf_with_tvm([np_data], ['in_data:0'], 'stack:0')

def test_forward_unstack():
'''test unstack layer'''
_test_unstack((6,), 0, 'int32')
Expand Down

0 comments on commit d275972

Please sign in to comment.