[FRONTEND][TENSORFLOW] Support Unstack and Split (apache#2105)
alexeyr authored and tqchen committed Dec 13, 2018
1 parent 4bbf96e commit 00d509d
Showing 2 changed files with 136 additions and 63 deletions.
106 changes: 88 additions & 18 deletions nnvm/python/nnvm/frontend/tensorflow.py
@@ -36,6 +36,7 @@ def __call__(self, inputs, attrs, *args):
self._ignores.append('_node_name')
self._ignores.append('is_training')
self._ignores.append('_target_layout')
self._ignores.append('_input_0d_mismatch')
# Retain the names
try:
attrs['name'] = attrs['_node_name']
@@ -319,8 +320,7 @@ def _impl(inputs, attr, params):
dim_input = inputs.pop(1)
axis = params[dim_input.list_output_names()[0]]
params.pop(dim_input.list_output_names()[0])
return AttrCvt(op_name="expand_dims", ignores=['Tdim'],
extras={'axis': axis.asnumpy()[0]})(inputs, attr)
return _expand_dims_0d_aware(inputs[0], attr, axis=axis.asnumpy()[0])
return _impl

def _resize_bilinear():
@@ -383,7 +383,7 @@ def _impl(inputs, attr, params):
def _pack():
def _impl(inputs, attr, params):
axis = int(attr["axis"])
inputs_reshaped = [_sym.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
inputs_reshaped = [_expand_dims_0d_aware(i, attr, axis=axis, num_newaxis=1) for i in inputs]
return _sym.concatenate(*inputs_reshaped, axis=axis, name=attr["_node_name"])

return _impl
@@ -787,15 +787,64 @@ def _impl(inputs, attr, params):
)(inputs, attr)
return _impl

def _split():
def _split(has_size_vector):
# TF documentation https://www.tensorflow.org/api_docs/python/tf/split
def _impl(inputs, attr, params):
axis = params.pop(inputs[0].list_output_names()[0])
return AttrCvt(
op_name="split", ignores=['T'],
transforms={'num_split': 'indices_or_sections'},
extras={'axis': axis.asnumpy()[0]})(inputs[1], attr)
try:
# order and number of inputs are different:
# if has_size_vector:
# https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split-v
# else:
# https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split

# in addition, `axis` and `num_or_size_splits` can be tensors in TensorFlow,
# but we can only support constants here
if has_size_vector:
input_node_index = 0
input_axis_index = 2
size_splits_input_name = inputs[1].list_output_names()[0]
size_splits = params[size_splits_input_name].asnumpy()
section_beginnings = np.cumsum(size_splits)[:-1]
indices_or_sections = tuple(section_beginnings)
else:
input_node_index = 1
input_axis_index = 0
indices_or_sections = attr['num_split']
input_node = inputs[input_node_index]
axis_input_name = inputs[input_axis_index].list_output_names()[0]
axis_input_value = params[axis_input_name].asnumpy()[0]
except (IndexError, KeyError):
raise TypeError( \
"Unsupported argument for split: `axis` and `num_or_size_splits` " \
"should be constants")
return _sym.split(input_node,
indices_or_sections=indices_or_sections,
axis=axis_input_value)
return _impl
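
For `SplitV`, the converter has to translate TensorFlow's per-section `size_splits` vector into the section boundaries that nnvm's `split` expects. A minimal NumPy sketch of that conversion, standalone and outside the frontend:

```python
import numpy as np

# TF SplitV takes per-section sizes; nnvm's split takes the indices at
# which each new section begins (excluding 0 and the total length).
size_splits = np.array([1, 4, 1])                 # sections of length 1, 4, 1
section_beginnings = np.cumsum(size_splits)[:-1]  # -> array([1, 5])
indices_or_sections = tuple(section_beginnings)   # (1, 5)

# np.split uses the same boundary convention, so it can sanity-check this:
parts = np.split(np.arange(6), indices_or_sections)
assert [len(p) for p in parts] == [1, 4, 1]
```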

def _unpack():
def _impl(inputs, attr, params):
input_node = inputs[0]
axis = attr['axis']
input_shape = attr['_input_shapes'][input_node][0]
axis_length = input_shape[axis]
if axis_length < 0:
raise TypeError("Unstack with unknown axis length")
splitted = _sym.split(input_node,
indices_or_sections=axis_length,
axis=axis,
name=attr.get('_node_name', 'unstack'))

return _sym.Group([_sym.squeeze(split_item, axis=axis) for split_item in splitted])
return _impl
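
The decomposition used here — unstack as a split into `axis_length` sections followed by squeezing the split axis out of every piece — can be checked against NumPy directly. A small sketch with NumPy standing in for the nnvm symbols:

```python
import numpy as np

def unstack(data, axis=0):
    # Mirror of the converter: one section per index along `axis`,
    # then squeeze that axis away from every piece.
    axis_length = data.shape[axis]
    pieces = np.split(data, axis_length, axis=axis)
    return [np.squeeze(p, axis=axis) for p in pieces]

x = np.arange(6).reshape(2, 3)
columns = unstack(x, axis=1)
assert [c.tolist() for c in columns] == [[0, 3], [1, 4], [2, 5]]
```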

def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
if data in attr['_input_0d_mismatch']:
return data if num_newaxis == 1 else \
_sym.expand_dims(data, axis=axis, num_newaxis=num_newaxis-1)

return _sym.expand_dims(data, axis=axis, num_newaxis=num_newaxis)
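
The 0-d bookkeeping exists because nnvm has no true scalars: a tensor that is 0-d in TensorFlow arrives as a 1-d tensor of length one, so one implicit `expand_dims` has in effect already happened. A hypothetical illustration of the shape mismatch (plain NumPy, not frontend code):

```python
import numpy as np

tf_scalar = np.float32(2.0)           # shape () on the TensorFlow side
nnvm_view = np.asarray([tf_scalar])   # nnvm represents it with shape (1,)

# Expanding the TF scalar once yields the shape nnvm already has,
# which is why `_expand_dims_0d_aware` skips one new axis.
assert np.expand_dims(tf_scalar, axis=0).shape == nnvm_view.shape == (1,)
```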

# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -863,7 +912,9 @@ def _impl(inputs, attr, params):
'GreaterEqual' : _broadcast('greater_equal'),
'Equal' : _broadcast('equal'),
'NotEqual' : _broadcast('not_equal'),
'Split' : _split(),
'Split' : _split(False),
'SplitV' : _split(True),
'Unpack' : _unpack(),
}

# _convert_map_rnn defines maps of rnn operator name to
@@ -1059,6 +1110,7 @@ def __init__(self):
self._output_shapes = {}
self._num_param = 0
self._num_rnn_layer = False
self._outputs_are_0d = {}

def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct nnvm nodes from tensorflow graph definition - GraphDef.
@@ -1114,6 +1166,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
# Operator name 'Const' is treated as a parameter to build NNVM params dict.

input_shapes = {}
input_0d_mismatch = set()
attr = self._parse_attr(node.attr)

#Variable converted to Const will not have only value attr
@@ -1133,6 +1186,9 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
else:
raise NotImplementedError( \
"Please freeze the graph with add_shapes=True")
self._outputs_are_0d[node.name] = [ \
not shape if isinstance(shape, list) else False \
for shape in self._output_shapes[node.name]]
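
`not shape` leans on an empty list being falsy in Python, so an output is recorded as 0-d exactly when its shape list is `[]`. A quick sketch of the predicate on assumed shape values:

```python
# [] denotes a TF scalar output; any dimensions make it non-0d,
# and a non-list entry (e.g. None) is conservatively treated as non-0d.
shapes = [[], [3], [2, 6], None]
outputs_are_0d = [not s if isinstance(s, list) else False for s in shapes]
assert outputs_are_0d == [True, False, False, False]
```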

if node.op == "Placeholder":
self._nodes[node.name] = _sym.Variable(name=node.name,
@@ -1162,24 +1218,32 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
# Fill shapes for all inputs in a list
inputs = []
for i in node.input:
#ToDo: Some of the tensorflow operators internaly maintain
#execution layers and its output name will the layer number along with
#graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
#output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
#the digit has to be ignored.
# Some TensorFlow operators internally maintain execution layers
# and their output name includes the layer number along with
# graph node name. E.g. the node name is 'Model/RNN/cell_0/RnnCell', but the
# output tensor name is 'Model/RNN/cell_0/RnnCell:0'. In this case,
# the number has to be ignored for single-output nodes.
# On the other hand, for multi-output nodes the number is the output index,
# and the lack of the number implies 0.
tensor_name = i.split(':')
node_name = tensor_name[0]
if node_name in self._nodes:
in_sym = self._nodes[node_name]
if len(in_sym.list_output_names()) > 1:
tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0
in_sym = in_sym[tensor_slot]
input_shape = (self._output_shapes[node_name])[tensor_slot]
input_shape = self._output_shapes[node_name][tensor_slot]
else:
tensor_slot = 0
input_shape = self._output_shapes[node_name][0]
inputs.append(in_sym)
input_shapes[in_sym] = [input_shape]
# This means the node is 1d in NNVM and 0d in TF.
# See `_expand_dims_0d_aware`.
if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
input_0d_mismatch.add(in_sym)
attr['_input_shapes'] = input_shapes
attr['_input_0d_mismatch'] = input_0d_mismatch
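
The `node:slot` convention parsed in this loop is plain string handling: a trailing `:N` selects output slot N of a multi-output node, and its absence implies slot 0. A minimal sketch, with hypothetical node names:

```python
def parse_tensor_name(name):
    # 'Model/RNN/cell_0/RnnCell:1' -> ('Model/RNN/cell_0/RnnCell', 1);
    # no ':N' suffix implies output slot 0.
    parts = name.split(':')
    slot = int(parts[1]) if len(parts) > 1 else 0
    return parts[0], slot

assert parse_tensor_name('Model/RNN/cell_0/RnnCell:1') == ('Model/RNN/cell_0/RnnCell', 1)
assert parse_tensor_name('in_data') == ('in_data', 0)
```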

inputs = self._fix_extranodes(node.op, attr, inputs)
op = self._convert_operator(node.op, inputs, attr, graph)
@@ -1207,15 +1271,21 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
if outputs is None:
out.append(final_op)
else:
out = [self._nodes[out_name] for out_name in outputs]
for out_name in outputs:
if ":" in out_name:
out_name, out_num = out_name.split(":")
out_num = int(out_num)
out.append(self._nodes[out_name][out_num])
else:
out.append(self._nodes[out_name])

# Also add the RNN outputs to the 'head' nodes of the nnvm graph
if self._num_rnn_layer:
out_rnn = _sym.concatenate(*self._out_rnn, axis=0)
out.append(out_rnn)

if isinstance(out, list):
out = _sym.Group(out)
out = _sym.Group(out) if len(out) > 1 else out[0]

return out, self._params

93 changes: 48 additions & 45 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
@@ -124,7 +124,8 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
if no_gpu and device == 'cuda':
continue

tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device)
tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
num_output=len(out_node), target=device, out_names=out_name)
# since the names from the tensorflow and nnvm runs are not exactly the same,
# only the first len(tf_output) outputs will be compared
for i in range(len(tf_output)):
@@ -506,14 +507,24 @@ def test_forward_gather():
# Split
# -----

def _test_split(in_shape, axis, num_split, dtype):
def _test_split(in_shape, axis, num_or_size_splits, dtype):
np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)

""" One iteration of a Split """
tf.reset_default_graph()
in_data = tf.placeholder(dtype, in_shape, name="in_data")
num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits
tf.split(in_data, num_or_size_splits, axis=axis)

with tf.Graph().as_default():
in_data = tf.placeholder(dtype, in_shape, name="in_data")
tf.split(in_data, num_split, axis)
np_data = np.random.uniform(size=in_shape).astype(dtype)
compare_tf_with_tvm(np_data, 'in_data:0', 'split:0')
compare_tf_with_tvm([np_data], ['in_data:0'], [f'split:{n}' for n in range(num_split)])

# and now test together with concat
tf.reset_default_graph()
in_data = tf.placeholder(dtype, in_shape, name="in_data")
splitted = tf.split(in_data, num_or_size_splits, axis=axis)
tf.concat(splitted, axis)

compare_tf_with_tvm([np_data], 'in_data:0', 'concat:0')
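
The test asks for `'split:0'` through `'split:{n-1}'` because TensorFlow names each output slot of the one `split` node that way. A standalone check of the naming, assuming a TF 1.x graph-mode session:

```python
import tensorflow as tf

tf.reset_default_graph()
in_data = tf.placeholder('float32', (6,), name='in_data')
pieces = tf.split(in_data, [1, 2, 3], axis=0)

# Every piece is a different output slot of the same 'split' node.
assert [p.name for p in pieces] == ['split:0', 'split:1', 'split:2']
```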

def test_forward_split():
'''test split layer'''
@@ -523,11 +534,11 @@ def test_forward_split():
_test_split((6,), 0, 3, 'float32')
# rank 2
_test_split((6, 2), 0, 3, 'float32')
_test_split((2, 6), 1, 3, 'float32')
_test_split((2, 6), 1, 6, 'float32')
# rank 3
_test_split((6, 2, 4), 0, 3, 'float32')
_test_split((6, 2, 4), 0, 2, 'int32')
_test_split((2, 6, 4), 1, 3, 'float32')
_test_split((2, 4, 6), 2, 3, 'float32')
_test_split((2, 4, 6), 2, 1, 'float32')
# rank 4
_test_split((6, 1, 3, 5), 0, 3, 'float32')
_test_split((1, 6, 3, 5), 1, 3, 'float32')
@@ -538,45 +549,37 @@ def test_forward_split():
_test_split((1, 6, 3, 5), -3, 3, 'float32')
_test_split((1, 3, 6, 5), -2, 3, 'float32')
_test_split((1, 3, 5, 6), -1, 3, 'float32')
# size_splits list
_test_split((6,), 0, [1, 2, 3], 'int32')
_test_split((3, 6, 4), -2, [1, 4, 1], 'float32')


#######################################################################
# Split followed by concat
# ------------------------
# Unstack
# -------

def _test_split_concat(in_shape, axis, num_split, dtype):
""" One iteration of a split_concat pair"""
def _test_unstack(ip_shape, axis, dtype):
np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)

with tf.Graph().as_default():
in_data = tf.placeholder(dtype, in_shape, name="in_data")
splitted = tf.split(in_data, num_split, axis)
tf.concat(splitted, axis)
np_data = np.random.uniform(size=in_shape).astype(dtype)
compare_tf_with_tvm(np_data, 'in_data:0', 'concat:0')

def test_forward_split_concat():
'''test split followed by concat layers'''
# rank 1
_test_split_concat((3,), 0, 1, 'float32')
_test_split_concat((3,), 0, 3, 'float32')
_test_split_concat((6,), 0, 3, 'float32')
# rank 2
_test_split_concat((6, 2), 0, 3, 'float32')
_test_split_concat((2, 6), 1, 3, 'float32')
# rank 3
_test_split_concat((6, 2, 4), 0, 3, 'float32')
_test_split_concat((2, 6, 4), 1, 3, 'float32')
_test_split_concat((2, 4, 6), 2, 3, 'float32')
# rank 4
_test_split((6, 1, 3, 5), 0, 3, 'float32')
_test_split((1, 6, 3, 5), 1, 3, 'float32')
_test_split((1, 3, 6, 5), 2, 3, 'float32')
_test_split((1, 3, 5, 6), 3, 3, 'float32')
# split along negative axis
_test_split((6, 1, 3, 5), -4, 3, 'float32')
_test_split((1, 6, 3, 5), -3, 3, 'float32')
_test_split((1, 3, 6, 5), -2, 3, 'float32')
_test_split((1, 3, 5, 6), -1, 3, 'float32')
tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
tf.unstack(in_data, axis=axis)

compare_tf_with_tvm([np_data], ['in_data:0'], [f'unstack:{n}' for n in range(ip_shape[axis])])

tf.reset_default_graph()
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
tf.stack(tf.unstack(in_data, axis=axis), axis=axis)

compare_tf_with_tvm([np_data], ['in_data:0'], 'stack:0')

def test_forward_unstack():
'''test unstack layer'''
_test_unstack((6,), 0, 'int32')
_test_unstack((2,6), 1, 'float64')
# negative axis
_test_unstack((1,4), -1, 'int32')
_test_unstack((3,6,4), -2, 'float32')


#######################################################################
Expand Down Expand Up @@ -1139,7 +1142,7 @@ def test_forward_rel_ops():
test_forward_gather()
test_forward_stridedslice()
test_forward_split()
test_forward_split_concat()
test_forward_unstack()

# Activations
test_forward_sigmoid()
