Fix for stack(unstack(data)) != data
alexeyr committed Nov 22, 2018
1 parent 1895ad5 commit 785b082
Showing 2 changed files with 41 additions and 9 deletions.
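The failure mode named in the commit title can be reproduced with the TF 1.x graph-mode API that the test suite below already uses: stacking the tensors produced by unstack along the same axis should give back the original tensor, but before this change the NNVM frontend added an extra axis (for example shape (6, 1) instead of (6,)), because outputs that are 0-d in TensorFlow are carried as 1-d tensors in NNVM. A minimal sketch of the expected round trip, assuming TF 1.x and NumPy; the snippet is illustrative and not part of the commit:

import numpy as np
import tensorflow as tf

tf.reset_default_graph()
in_data = tf.placeholder("int32", (6,), name="in_data")
# unstack yields six 0-d tensors; stacking them back should restore shape (6,)
round_trip = tf.stack(tf.unstack(in_data, axis=0), axis=0, name="stack")

with tf.Session() as sess:
    x = np.arange(6, dtype="int32")
    out = sess.run(round_trip, feed_dict={in_data: x})
    assert out.shape == x.shape and (out == x).all()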
30 changes: 24 additions & 6 deletions nnvm/python/nnvm/frontend/tensorflow.py
@@ -36,6 +36,7 @@ def __call__(self, inputs, attrs, *args):
self._ignores.append('_node_name')
self._ignores.append('is_training')
self._ignores.append('_target_layout')
self._ignores.append('_input_0d_mismatch')
# Retain the names
try:
attrs['name'] = attrs['_node_name']
@@ -315,8 +316,7 @@ def _impl(inputs, attr, params):
dim_input = inputs.pop(1)
axis = params[dim_input.list_output_names()[0]]
params.pop(dim_input.list_output_names()[0])
return AttrCvt(op_name="expand_dims", ignores=['Tdim'],
extras={'axis': axis.asnumpy()[0]})(inputs, attr)
return _expand_dims_0d_aware(inputs[0], attr, axis=axis.asnumpy()[0])
return _impl

def _resize_bilinear():
@@ -379,7 +379,7 @@ def _impl(inputs, attr, params):
def _pack():
def _impl(inputs, attr, params):
axis = int(attr["axis"])
inputs_reshaped = [_sym.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
inputs_reshaped = [_expand_dims_0d_aware(i, attr, axis=axis, num_newaxis=1) for i in inputs]
return _sym.concatenate(*inputs_reshaped, axis=axis, name=attr["_node_name"])

return _impl
@@ -834,6 +834,13 @@ def _impl(inputs, attr, params):
)(inputs, attr)
return _impl

def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
if data in attr['_input_0d_mismatch']:
return data if num_newaxis == 1 else \
_sym.expand_dims(data, axis=axis, num_newaxis=num_newaxis-1)

return _sym.expand_dims(data, axis=axis, num_newaxis=num_newaxis)

# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -1098,6 +1105,7 @@ def __init__(self):
self._output_shapes = {}
self._num_param = 0
self._num_rnn_layer = False
self._outputs_are_0d = {}

def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef.
@@ -1153,6 +1161,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
# Operator name 'Const' is treated as a parameter to build NNVM params dict.

input_shapes = {}
input_0d_mismatch = set()
attr = self._parse_attr(node.attr)

#Variable converted to Const will not have only value attr
@@ -1172,6 +1181,9 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
else:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")
self._outputs_are_0d[node.name] = [ \
not shape if isinstance(shape, list) else False \
for shape in self._output_shapes[node.name]]

if node.op == "Placeholder":
self._nodes[node.name] = _sym.Variable(name=node.name,
@@ -1217,12 +1229,18 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

if i in self._nodes:
tvm_n = self._nodes[i]
tvm_n_outputs = tvm_n.list_output_names()
if len(tvm_n_outputs) > 1:
tvm_n_shape = self._output_shapes[i]
if len(tvm_n.list_output_names()) > 1:
tvm_n = tvm_n[num_layer]
tvm_n_shape = [tvm_n_shape[num_layer]]
inputs.append(tvm_n)
input_shapes[tvm_n] = self._output_shapes[i]
input_shapes[tvm_n] = tvm_n_shape
#This means the node is 1d in NNVM and 0d in TF.
#See `_expand_dims_0d_aware`.
if self._outputs_are_0d[i][num_layer] and tvm_n_shape[0]:
input_0d_mismatch.add(tvm_n)
attr['_input_shapes'] = input_shapes
attr['_input_0d_mismatch'] = input_0d_mismatch

inputs = self._fix_extranodes(node.op, attr, inputs)
op = self._convert_operator(node.op, inputs, attr, graph)
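To summarize the frontend changes above: from_tensorflow now records, per node, which outputs are 0-d in TensorFlow (_outputs_are_0d), marks inputs whose NNVM representation is 1-d while the TF output is 0-d (_input_0d_mismatch), and _expand_dims_0d_aware skips one of the axes it would normally add for such inputs. The following standalone sketch mimics that decision with a shape-only stand-in for _sym.expand_dims; the helper names and example shapes are illustrative and not part of the commit:

def fake_expand_dims(shape, axis, num_newaxis):
    # Shape-only stand-in for _sym.expand_dims: insert num_newaxis size-1 axes at axis.
    return shape[:axis] + (1,) * num_newaxis + shape[axis:]

def expand_dims_0d_aware(shape, is_0d_in_tf, axis, num_newaxis=1):
    # A tensor that is 0-d in TF is carried as a 1-d tensor in NNVM, so it
    # already has the axis that expand_dims would otherwise add.
    if is_0d_in_tf:
        return shape if num_newaxis == 1 else \
            fake_expand_dims(shape, axis, num_newaxis - 1)
    return fake_expand_dims(shape, axis, num_newaxis)

# unstack of a (6,) tensor yields TF scalars represented as (1,) in NNVM;
# concatenating six of these along axis 0 now gives (6,) rather than (6, 1).
print(expand_dims_0d_aware((1,), True, axis=0))     # -> (1,)
print(expand_dims_0d_aware((2, 3), False, axis=0))  # -> (1, 2, 3)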
20 changes: 17 additions & 3 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
@@ -508,14 +508,21 @@ def test_forward_gather():
# ------

def _test_split(ip_shape, num_or_size_splits, axis, dtype):
np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)
num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits

tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits
tf.split(in_data, num_or_size_splits, axis=axis)
np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)

compare_tf_with_tvm([np_data], ['in_data:0'], [f'split:{n}' for n in range(num_split)])

tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
tf.concat(tf.split(in_data, num_or_size_splits, axis=axis), axis=axis)

compare_tf_with_tvm([np_data], ['in_data:0'], 'concat:0')

def test_forward_split():
'''test split layer'''
_test_split((6,), 2, 0, 'int32')
@@ -533,13 +540,20 @@ def test_forward_split():
# ------

def _test_unstack(ip_shape, axis, dtype):
np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)

tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
tf.unstack(in_data, axis=axis)
np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)

compare_tf_with_tvm([np_data], ['in_data:0'], [f'unstack:{n}' for n in range(ip_shape[axis])])

tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
tf.stack(tf.unstack(in_data, axis=axis), axis=axis)

compare_tf_with_tvm([np_data], ['in_data:0'], 'stack:0')

def test_forward_unstack():
'''test unstack layer'''
_test_unstack((6,), 0, 'int32')
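For completeness, the property the reworked _test_split now checks through compare_tf_with_tvm can also be stated directly against TensorFlow: concatenating the pieces returned by split along the same axis must reproduce the input. A quick TF-only sketch, assuming TF 1.x and arbitrarily chosen shapes; not part of the commit:

import numpy as np
import tensorflow as tf

tf.reset_default_graph()
in_data = tf.placeholder("float32", (6, 4), name="in_data")
# split into three (2, 4) pieces, then glue them back together along axis 0
out = tf.concat(tf.split(in_data, 3, axis=0), axis=0, name="concat")

with tf.Session() as sess:
    x = np.random.uniform(-5, 5, size=(6, 4)).astype("float32")
    np.testing.assert_array_equal(sess.run(out, feed_dict={in_data: x}), x)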
