[FRONTEND][TENSORFLOW] Support Unstack and Split
alexeyr committed Nov 14, 2018
1 parent 1b86373 commit 76f4651
Showing 2 changed files with 138 additions and 16 deletions.
85 changes: 74 additions & 11 deletions nnvm/python/nnvm/frontend/tensorflow.py
@@ -760,6 +760,57 @@ def _impl(inputs, attr, params):
         return gamma * (-alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0]))
     return _impl
 
+def _split(has_size_vector):
+    # TF documentation: https://www.tensorflow.org/api_docs/python/tf/split
+    def _impl(inputs, attr, params):
+        try:
+            # The order and number of inputs differ between the two ops:
+            # if has_size_vector:
+            #   https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split-v
+            # else:
+            #   https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split
+
+            # In addition, `axis` and `num_or_size_splits` can be tensors in
+            # TensorFlow; only constants are supported here.
+            if has_size_vector:
+                input_node_index = 0
+                input_axis_index = 2
+                size_splits_input_name = inputs[1].list_output_names()[0]
+                size_splits = params[size_splits_input_name].asnumpy()
+                section_beginnings = np.cumsum(size_splits)[:-1]
+                indices_or_sections = tuple(section_beginnings)
+            else:
+                input_node_index = 1
+                input_axis_index = 0
+                indices_or_sections = attr['num_split']
+            input_node = inputs[input_node_index]
+            axis_input_name = inputs[input_axis_index].list_output_names()[0]
+            axis_input_value = params[axis_input_name].asnumpy()[0]
+        except (IndexError, KeyError):
+            raise TypeError(
+                "Unsupported argument for split: `axis` and `num_or_size_splits` "
+                "should be constants")
+        return _sym.split(input_node,
+                          indices_or_sections=indices_or_sections,
+                          axis=axis_input_value)
+    return _impl
+
+def _unpack():
+    def _impl(inputs, attr, params):
+        input_node = inputs[0]
+        axis = attr['axis']
+        input_shape = attr['_input_shapes'][input_node][0]
+        axis_length = input_shape[axis]
+        if axis_length < 0:
+            raise TypeError("Unstack with unknown axis length")
+        splitted = _sym.split(input_node,
+                              indices_or_sections=axis_length,
+                              axis=axis,
+                              name=attr.get('_node_name', 'unstack'))
+
+        return _sym.Group([_sym.squeeze(split_item, axis=axis) for split_item in splitted])
+    return _impl
+
 def _mean():
     def _impl(inputs, attr, params):
         axis = params.pop(inputs[1].list_output_names()[0])
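The SplitV branch above converts TensorFlow's per-section sizes into the section start offsets that _sym.split expects, via np.cumsum; Unpack is lowered as a split followed by a squeeze. A minimal standalone numpy sketch of both conversions (illustration only, not part of the patch):

import numpy as np

# SplitV receives section sizes, e.g. [1, 2, 3] along a length-6 axis;
# _sym.split wants the indices where sections begin (excluding 0): (1, 3).
size_splits = np.array([1, 2, 3])
boundaries = tuple(np.cumsum(size_splits)[:-1])
assert boundaries == (1, 3)

# Unpack (tf.unstack) splits into axis_length sections, each then squeezed
# along the split axis: a (2, 6) input unstacked on axis 0 gives two (6,) arrays.
data = np.arange(12).reshape(2, 6)
parts = [np.squeeze(p, axis=0) for p in np.split(data, 2, axis=0)]
assert [p.shape for p in parts] == [(6,), (6,)]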
@@ -835,6 +886,9 @@ def _impl(inputs, attr, params):
     'Range'     : _range(),
     'Rank'      : _rank(),
     'Transpose' : _transpose(),
+    'Split'     : _split(False),
+    'SplitV'    : _split(True),
+    'Unpack'    : _unpack(),
     'Tanh'      : AttrCvt('tanh'),
     'Mean'      : _mean(),
     'Less'      : _broadcast('less'),
@@ -1137,21 +1191,30 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None):
                 # Pass the target layout
                 attr["_target_layout"] = layout
 
-                # TODO: Some of the tensorflow operators internally maintain
-                # execution layers, and the output name will be the layer
-                # number appended to the graph node name, e.g. node name
-                # 'Model/RNN/cell_0/RnnCell' but output name
-                # 'Model/RNN/cell_0/RnnCell:0'. In this case, the digit has to be ignored.
-                if ":" in node.input[0]:
-                    in_name, _ = node.input[0].split(':')
-                    node.input[0] = in_name
 
                 # Fill shapes for all inputs in a list
                 inputs = []
                 for i in node.input:
+                    # TODO: Some of the tensorflow operators internally maintain
+                    # execution layers, and the output name will be the layer
+                    # number appended to the graph node name, e.g. node name
+                    # 'Model/RNN/cell_0/RnnCell' but output name
+                    # 'Model/RNN/cell_0/RnnCell:0'. For single-output nodes the
+                    # number has to be ignored, while for multi-output nodes it
+                    # selects which output to use; a missing number implies output 0.
+                    if ":" in i:
+                        in_name, num_layer = i.split(':')
+                        i = in_name
+                        num_layer = int(num_layer)
+                    else:
+                        num_layer = 0
+
                     if i in self._nodes:
-                        inputs.append(self._nodes[i])
-                        input_shapes[self._nodes[i]] = self._output_shapes[i]
+                        tvm_n = self._nodes[i]
+                        outputs = tvm_n.list_output_names()
+                        if len(outputs) > 1:
+                            tvm_n = tvm_n[num_layer]
+                        inputs.append(tvm_n)
+                        input_shapes[tvm_n] = self._output_shapes[i]
                 attr['_input_shapes'] = input_shapes
 
                 inputs = self._fix_extranodes(node.op, attr, inputs)
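The importer change above leans on TensorFlow's tensor-naming convention: the i-th output of an op is addressed as 'op_name:i', and a bare 'op_name' means output 0. A small self-contained sketch of that parsing (the helper name is illustrative, not from the patch):

def _parse_tensor_name(name):
    # 'Model/RNN/cell_0/RnnCell:2' -> ('Model/RNN/cell_0/RnnCell', 2)
    # 'Model/RNN/cell_0/RnnCell'   -> ('Model/RNN/cell_0/RnnCell', 0)
    if ':' in name:
        op_name, index = name.split(':')
        return op_name, int(index)
    return name, 0

assert _parse_tensor_name('split:1') == ('split', 1)
assert _parse_tensor_name('in_data') == ('in_data', 0)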
69 changes: 64 additions & 5 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
@@ -75,7 +75,10 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'
 def run_tf_graph(sess, input_data, input_node, output_node):
     """ Generic function to execute tensorflow """
 
-    tensor = sess.graph.get_tensor_by_name(output_node)
+    if isinstance(output_node, list):
+        tensor = [sess.graph.get_tensor_by_name(node) for node in output_node]
+    else:
+        tensor = sess.graph.get_tensor_by_name(output_node)
 
     if isinstance(input_data, list):
         input_dict = {}
@@ -91,7 +94,11 @@ def run_tf_graph(sess, input_data, input_node, output_node):
 def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, no_gpu=False):
     """Generic function to generate and compare tensorflow and TVM output"""
 
-    out_node = out_name.split(':')[0] if ":" in out_name else out_name
+    if isinstance(out_name, str):
+        out_name = [out_name]
+
+    out_node = [name.split(':')[0] if ":" in name else name for name in out_name]
+
 
     if isinstance(in_name, list):
         in_node = [0]*len(in_name)
Expand All @@ -106,7 +113,7 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
final_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
[out_node],
out_node,
)

tf_output = run_tf_graph(sess, in_data, in_name, out_name)
@@ -119,8 +126,12 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
             if no_gpu and device == 'cuda':
                 continue
 
-            tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device)
-            tvm.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
+            tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, num_output=len(out_node), target=device)
+            if len(out_node) == 1:
+                tvm_output = [tvm_output]
+
+            for tf_tensor, tvm_tensor in zip(tf_output, tvm_output):
+                tvm.testing.assert_allclose(tf_tensor, tvm_tensor, atol=1e-5, rtol=1e-5)
 
         sess.close()
 
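With this change compare_tf_with_tvm accepts either a single tensor name or a list of names, comparing each TensorFlow output against the corresponding TVM output. A sketch of a multi-output call, mirroring the test helpers added below (and assuming the test module's tf/np imports):

tf.reset_default_graph()
in_data = tf.placeholder('float32', (2, 6), name="in_data")
tf.split(in_data, 2, axis=1)  # produces output tensors 'split:0' and 'split:1'
np_data = np.random.uniform(-5, 5, size=(2, 6)).astype('float32')

# Both outputs are checked; a plain string such as 'split:0' would check just one.
compare_tf_with_tvm([np_data], ['in_data:0'], ['split:0', 'split:1'])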
@@ -489,6 +500,52 @@ def test_forward_gather():
     _test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')
 
 
+#######################################################################
+# Split
+# -----
+
+def _test_split(ip_shape, num_or_size_splits, axis, dtype):
+    tf.reset_default_graph()
+    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+    num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits
+    tf.split(in_data, num_or_size_splits, axis=axis)
+    np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)
+
+    compare_tf_with_tvm([np_data], ['in_data:0'], [f'split:{n}' for n in range(num_split)])
+
+def test_forward_split():
+    '''test split layer'''
+    _test_split((6,), 2, 0, 'int32')
+    _test_split((4,), 4, 0, 'float32')
+    _test_split((2,6), 3, 1, 'int32')
+    # negative axis
+    _test_split((1,4), 2, -1, 'int32')
+    # list of splits
+    _test_split((6,), [1,2,3], 0, 'int32')
+    _test_split((3,6,4), [1,4,1], -2, 'float32')
+
+
+#######################################################################
+# Unstack
+# -------
+
+def _test_unstack(ip_shape, axis, dtype):
+    tf.reset_default_graph()
+    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+    tf.unstack(in_data, axis=axis)
+    np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)
+
+    compare_tf_with_tvm([np_data], ['in_data:0'], [f'unstack:{n}' for n in range(ip_shape[axis])])
+
+def test_forward_unstack():
+    '''test unstack layer'''
+    _test_unstack((6,), 0, 'int32')
+    _test_unstack((2,6), 1, 'float64')
+    # negative axis
+    _test_unstack((1,4), -1, 'int32')
+    _test_unstack((3,6,4), -2, 'float32')
+
+
 #######################################################################
 # Multi Input to graph
 # --------------------
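The negative-axis cases in the tests above follow numpy/TensorFlow indexing, where axis -2 on a rank-3 tensor addresses axis 1. A quick standalone check of the (3, 6, 4) size-list case:

import numpy as np

data = np.zeros((3, 6, 4))
# Sizes [1, 4, 1] along axis -2 correspond to split indices [1, 5].
parts = np.split(data, [1, 5], axis=-2)
assert [p.shape for p in parts] == [(3, 1, 4), (3, 4, 4), (3, 1, 4)]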
@@ -1017,6 +1074,8 @@ def test_forward_rel_ops():
     test_forward_resize_bilinear()
     test_forward_pad()
     test_forward_gather()
+    test_forward_unstack()
+    test_forward_split()
     #test_forward_stridedslice()
 
     # Activations
