diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index e7282eb9afd6b..3a7965f97bbbd 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -760,6 +760,57 @@ def _impl(inputs, attr, params): return gamma * (-alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0])) return _impl +def _split(has_size_vector): + # TF documentation https://www.tensorflow.org/api_docs/python/tf/split + def _impl(inputs, attr, params): + try: + # order and number of inputs are different: + # if has_size_vector: + # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split-v + # else: + # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split + + # in addition, `axis` and `num_or_size_splits` can be tensors in TensorFlow, + # we can only support constants + if has_size_vector: + input_node_index = 0 + input_axis_index = 2 + size_splits_input_name = inputs[1].list_output_names()[0] + size_splits = params[size_splits_input_name].asnumpy() + section_beginnings = np.cumsum(size_splits)[:-1] + indices_or_sections = tuple(section_beginnings) + else: + input_node_index = 1 + input_axis_index = 0 + indices_or_sections = attr['num_split'] + input_node = inputs[input_node_index] + axis_input_name = inputs[input_axis_index].list_output_names()[0] + axis_input_value = params[axis_input_name].asnumpy()[0] + except (IndexError, KeyError): + raise TypeError( \ + "Unsupported argument for split: `axis` and `num_or_size_splits` " \ + "should be constants") + return _sym.split(input_node, + indices_or_sections=indices_or_sections, + axis=axis_input_value) + return _impl + +def _unpack(): + def _impl(inputs, attr, params): + input_node = inputs[0] + axis = attr['axis'] + input_shape = attr['_input_shapes'][input_node][0] + axis_length = input_shape[axis] + if axis_length < 0: + raise TypeError("Unstack with unknown axis length") + splitted = _sym.split(input_node, + indices_or_sections=axis_length, + axis=axis, + name=attr.get('_node_name', 'unstack')) + + return _sym.Group([_sym.squeeze(split_item, axis=axis) for split_item in splitted]) + return _impl + def _mean(): def _impl(inputs, attr, params): axis = params.pop(inputs[1].list_output_names()[0]) @@ -835,6 +886,9 @@ def _impl(inputs, attr, params): 'Range' : _range(), 'Rank' : _rank(), 'Transpose' : _transpose(), + 'Split' : _split(False), + 'SplitV' : _split(True), + 'Unpack' : _unpack(), 'Tanh' : AttrCvt('tanh'), 'Mean' : _mean(), 'Less' : _broadcast('less'), @@ -1137,21 +1191,30 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None): # Pass the target layout attr["_target_layout"] = layout - #ToDo: Some of the tensorflow operators internaly maintain - #execution layers and its output name will the layer number along with - #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the - #output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case, - #the digit has to be ignored. - if ":" in node.input[0]: - in_name, _ = node.input[0].split(':') - node.input[0] = in_name - # Fill shapes for all inputs in a list inputs = [] for i in node.input: + #ToDo: Some of the tensorflow operators internaly maintain + #execution layers and its output name will the layer number along with + #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the + #output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case, + #the number has to be ignored for single-output nodes. + #On the other hand, for multi-output nodes the number indicates the used output, + #and the lack of the number implies 0 + if ":" in i: + in_name, num_layer = i.split(':') + i = in_name + num_layer = int(num_layer) + else: + num_layer = 0 + if i in self._nodes: - inputs.append(self._nodes[i]) - input_shapes[self._nodes[i]] = self._output_shapes[i] + tvm_n = self._nodes[i] + outputs = tvm_n.list_output_names() + if len(outputs) > 1: + tvm_n = tvm_n[num_layer] + inputs.append(tvm_n) + input_shapes[tvm_n] = self._output_shapes[i] attr['_input_shapes'] = input_shapes inputs = self._fix_extranodes(node.op, attr, inputs) diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index 62d3577ba10ae..95e2558268612 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -75,7 +75,10 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm' def run_tf_graph(sess, input_data, input_node, output_node): """ Generic function to execute tensorflow """ - tensor = sess.graph.get_tensor_by_name(output_node) + if isinstance(output_node, list): + tensor = [sess.graph.get_tensor_by_name(node) for node in output_node] + else: + tensor = sess.graph.get_tensor_by_name(output_node) if isinstance(input_data, list): input_dict = {} @@ -91,7 +94,11 @@ def run_tf_graph(sess, input_data, input_node, output_node): def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, no_gpu=False): """Generic function to generate and compare tensorflow and TVM output""" - out_node = out_name.split(':')[0] if ":" in out_name else out_name + if isinstance(out_name, str): + out_name = [out_name] + + out_node = [name.split(':')[0] if ":" in name else name for name in out_name] + if isinstance(in_name, list): in_node = [0]*len(in_name) @@ -106,7 +113,7 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, final_graph_def = tf.graph_util.convert_variables_to_constants( sess, sess.graph.as_graph_def(add_shapes=True), - [out_node], + out_node, ) tf_output = run_tf_graph(sess, in_data, in_name, out_name) @@ -119,8 +126,12 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, if no_gpu and device == 'cuda': continue - tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device) - tvm.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5) + tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, num_output=len(out_node), target=device) + if len(out_node) == 1: + tvm_output = [tvm_output] + + for tf_tensor, tvm_tensor in zip(tf_output, tvm_output): + tvm.testing.assert_allclose(tf_tensor, tvm_tensor, atol=1e-5, rtol=1e-5) sess.close() @@ -489,6 +500,52 @@ def test_forward_gather(): _test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32') +####################################################################### +# Split +# ------ + +def _test_split(ip_shape, num_or_size_splits, axis, dtype): + tf.reset_default_graph() + in_data = tf.placeholder(dtype, ip_shape, name="in_data") + num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits + tf.split(in_data, num_or_size_splits, axis=axis) + np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype) + + compare_tf_with_tvm([np_data], ['in_data:0'], [f'split:{n}' for n in range(num_split)]) + +def test_forward_split(): + '''test split layer''' + _test_split((6,), 2, 0, 'int32') + _test_split((4,), 4, 0, 'float32') + _test_split((2,6), 3, 1, 'int32') + # negative axis + _test_split((1,4), 2, -1, 'int32') + # list of splits + _test_split((6,), [1,2,3], 0, 'int32') + _test_split((3,6,4), [1,4,1], -2, 'float32') + + +####################################################################### +# Unstack +# ------ + +def _test_unstack(ip_shape, axis, dtype): + tf.reset_default_graph() + in_data = tf.placeholder(dtype, ip_shape, name="in_data") + tf.unstack(in_data, axis=axis) + np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype) + + compare_tf_with_tvm([np_data], ['in_data:0'], [f'unstack:{n}' for n in range(ip_shape[axis])]) + +def test_forward_unstack(): + '''test unstack layer''' + _test_unstack((6,), 0, 'int32') + _test_unstack((2,6), 1, 'float64') + # negative axis + _test_unstack((1,4), -1, 'int32') + _test_unstack((3,6,4), -2, 'float32') + + ####################################################################### # Multi Input to graph # -------------------- @@ -1017,6 +1074,8 @@ def test_forward_rel_ops(): test_forward_resize_bilinear() test_forward_pad() test_forward_gather() + test_forward_unstack() + test_forward_split() #test_forward_stridedslice() # Activations