diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py
index 9c1290bedb6b..10f23a49b5de 100644
--- a/nnvm/python/nnvm/frontend/tensorflow.py
+++ b/nnvm/python/nnvm/frontend/tensorflow.py
@@ -36,6 +36,7 @@ def __call__(self, inputs, attrs, *args):
         self._ignores.append('_node_name')
         self._ignores.append('is_training')
         self._ignores.append('_target_layout')
+        self._ignores.append('_input_0d_mismatch')
         # Retain the names
         try:
             attrs['name'] = attrs['_node_name']
@@ -319,8 +320,7 @@ def _impl(inputs, attr, params):
         dim_input = inputs.pop(1)
         axis = params[dim_input.list_output_names()[0]]
         params.pop(dim_input.list_output_names()[0])
-        return AttrCvt(op_name="expand_dims", ignores=['Tdim'],
-                       extras={'axis': axis.asnumpy()[0]})(inputs, attr)
+        return _expand_dims_0d_aware(inputs[0], attr, axis=axis.asnumpy()[0])
     return _impl
 
 def _resize_bilinear():
@@ -383,7 +383,7 @@ def _impl(inputs, attr, params):
 def _pack():
     def _impl(inputs, attr, params):
         axis = int(attr["axis"])
-        inputs_reshaped = [_sym.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
+        inputs_reshaped = [_expand_dims_0d_aware(i, attr, axis=axis, num_newaxis=1) for i in inputs]
         return _sym.concatenate(*inputs_reshaped, axis=axis, name=attr["_node_name"])
     return _impl
 
@@ -787,15 +787,64 @@ def _impl(inputs, attr, params):
                    )(inputs, attr)
     return _impl
 
-def _split():
+def _split(has_size_vector):
+    # TF documentation: https://www.tensorflow.org/api_docs/python/tf/split
     def _impl(inputs, attr, params):
-        axis = params.pop(inputs[0].list_output_names()[0])
-        return AttrCvt(
-            op_name="split", ignores=['T'],
-            transforms={'num_split': 'indices_or_sections'},
-            extras={'axis': axis.asnumpy()[0]})(inputs[1], attr)
+        try:
+            # The order and number of inputs differ between the two ops:
+            # if has_size_vector:
+            #     https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split-v
+            # else:
+            #     https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split
+
+            # In addition, `axis` and `num_or_size_splits` can be tensors in TensorFlow;
+            # here we can only support constants.
+            if has_size_vector:
+                input_node_index = 0
+                input_axis_index = 2
+                size_splits_input_name = inputs[1].list_output_names()[0]
+                size_splits = params[size_splits_input_name].asnumpy()
+                section_beginnings = np.cumsum(size_splits)[:-1]
+                indices_or_sections = tuple(section_beginnings)
+            else:
+                input_node_index = 1
+                input_axis_index = 0
+                indices_or_sections = attr['num_split']
+            input_node = inputs[input_node_index]
+            axis_input_name = inputs[input_axis_index].list_output_names()[0]
+            axis_input_value = params[axis_input_name].asnumpy()[0]
+        except (IndexError, KeyError):
+            raise TypeError( \
+                "Unsupported argument for split: `axis` and `num_or_size_splits` " \
+                "should be constants")
+        return _sym.split(input_node,
+                          indices_or_sections=indices_or_sections,
+                          axis=axis_input_value)
     return _impl
 
+def _unpack():
+    def _impl(inputs, attr, params):
+        input_node = inputs[0]
+        axis = attr['axis']
+        input_shape = attr['_input_shapes'][input_node][0]
+        axis_length = input_shape[axis]
+        if axis_length < 0:
+            raise TypeError("Unstack with unknown axis length")
+        splitted = _sym.split(input_node,
+                              indices_or_sections=axis_length,
+                              axis=axis,
+                              name=attr.get('_node_name', 'unstack'))
+
+        return _sym.Group([_sym.squeeze(split_item, axis=axis) for split_item in splitted])
+    return _impl
+
+def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
+    if data in attr['_input_0d_mismatch']:
+        return data if num_newaxis == 1 else \
+            _sym.expand_dims(data, axis=axis, num_newaxis=num_newaxis-1)
+
+    return _sym.expand_dims(data, axis=axis, num_newaxis=num_newaxis)
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
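For reference, a minimal standalone sketch (plain NumPy; the values are illustrative only and mirror the [1, 4, 1] case exercised by the tests below) of how SplitV's `size_splits` vector is turned into the section boundaries that `_sym.split` expects:

    import numpy as np

    # TF:   tf.split(value, num_or_size_splits=[1, 4, 1], axis=-2)
    # NNVM: split(value, indices_or_sections=(1, 5), axis=-2)
    size_splits = np.array([1, 4, 1])
    section_beginnings = np.cumsum(size_splits)[:-1]   # array([1, 5])
    indices_or_sections = tuple(section_beginnings)    # start index of every section except the first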
@@ -863,7 +912,9 @@ def _impl(inputs, attr, params):
     'GreaterEqual'                      : _broadcast('greater_equal'),
     'Equal'                             : _broadcast('equal'),
     'NotEqual'                          : _broadcast('not_equal'),
-    'Split'                             : _split(),
+    'Split'                             : _split(False),
+    'SplitV'                            : _split(True),
+    'Unpack'                            : _unpack(),
 }
 
 # _convert_map_rnn defines maps of rnn operator name to
@@ -1059,6 +1110,7 @@ def __init__(self):
         self._output_shapes = {}
         self._num_param = 0
         self._num_rnn_layer = False
+        self._outputs_are_0d = {}
 
     def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
         """Construct nnvm nodes from tensorflow graph definition - GraphDef.
@@ -1114,6 +1166,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             # Operator name 'Const' is treated as a parameter to build NNVM params dict.
             input_shapes = {}
+            input_0d_mismatch = set()
 
             attr = self._parse_attr(node.attr)
 
             #Variable converted to Const will not have only value attr
@@ -1133,6 +1186,9 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             else:
                 raise NotImplementedError( \
                     "Please freeze the graph with add_shapes=True")
+            self._outputs_are_0d[node.name] = [ \
+                not shape if isinstance(shape, list) else False \
+                for shape in self._output_shapes[node.name]]
 
             if node.op == "Placeholder":
                 self._nodes[node.name] = _sym.Variable(name=node.name,
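As background for the `_outputs_are_0d` bookkeeping added above: TensorFlow reports a scalar (0-d) output with an empty shape list, which is what the `not shape` test relies on, while NNVM has no 0-d tensors and represents such values with shape (1,). A small illustrative check, with made-up shapes:

    tf_shape = []        # 0-d (scalar) output as reported in '_output_shapes'
    nnvm_shape = (1,)    # how NNVM ends up representing the same value
    is_0d = not tf_shape if isinstance(tf_shape, list) else False
    assert is_0d
    # Such inputs are later collected into '_input_0d_mismatch' so that
    # _expand_dims_0d_aware can skip one of the new axes requested by Pack/ExpandDims.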
@@ -1162,11 +1218,13 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
                 # Fill shapes for all inputs in a list
                 inputs = []
                 for i in node.input:
-                    #ToDo: Some of the tensorflow operators internaly maintain
-                    #execution layers and its output name will the layer number along with
-                    #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
-                    #output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
-                    #the digit has to be ignored.
+                    # Some TensorFlow operators internally maintain execution layers
+                    # and their output names include the layer number along with the
+                    # graph node name. E.g. the node name is 'Model/RNN/cell_0/RnnCell', but
+                    # its output tensor is named 'Model/RNN/cell_0/RnnCell:0'. In this case,
+                    # the number has to be ignored for single-output nodes.
+                    # For multi-output nodes, on the other hand, the number is the output
+                    # index, and a missing number implies slot 0.
                     tensor_name = i.split(':')
                     node_name = tensor_name[0]
                     if node_name in self._nodes:
@@ -1174,12 +1232,18 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
                         if len(in_sym.list_output_names()) > 1:
                             tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0
                             in_sym = in_sym[tensor_slot]
-                            input_shape = (self._output_shapes[node_name])[tensor_slot]
+                            input_shape = self._output_shapes[node_name][tensor_slot]
                         else:
+                            tensor_slot = 0
                             input_shape = self._output_shapes[node_name][0]
                         inputs.append(in_sym)
                         input_shapes[in_sym] = [input_shape]
+                        # This input is 0d in TF, but NNVM has no 0d tensors, so it is 1d here.
+                        # See `_expand_dims_0d_aware`.
+                        if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
+                            input_0d_mismatch.add(in_sym)
                 attr['_input_shapes'] = input_shapes
+                attr['_input_0d_mismatch'] = input_0d_mismatch
 
                 inputs = self._fix_extranodes(node.op, attr, inputs)
                 op = self._convert_operator(node.op, inputs, attr, graph)
@@ -1207,7 +1271,13 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
         if outputs is None:
             out.append(final_op)
         else:
-            out = [self._nodes[out_name] for out_name in outputs]
+            for out_name in outputs:
+                if ":" in out_name:
+                    out_name, out_num = out_name.split(":")
+                    out_num = int(out_num)
+                    out.append(self._nodes[out_name][out_num])
+                else:
+                    out.append(self._nodes[out_name])
 
         #Add the RNN outputs also with 'head' nodes of the nnvm graph
         if self._num_rnn_layer:
@@ -1215,7 +1285,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             out.append(out_rnn)
 
         if isinstance(out, list):
-            out = _sym.Group(out)
+            out = _sym.Group(out) if len(out) > 1 else out[0]
 
         return out, self._params
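The reworked output handling above lets a caller name an individual output slot of a multi-output node. A sketch of how that might look from the importer's entry point; the graph, names and shapes are illustrative only, and `add_shapes=True` matches the requirement enforced earlier in `from_tensorflow`:

    import tensorflow as tf
    import nnvm.frontend

    with tf.Graph().as_default() as graph:
        data = tf.placeholder(tf.float32, (6, 4), name='in_data')
        tf.split(data, 3, axis=0)   # output tensors 'split:0', 'split:1', 'split:2'
        graph_def = graph.as_graph_def(add_shapes=True)

    # Request only the middle section; a bare node name still means output slot 0.
    sym, params = nnvm.frontend.from_tensorflow(graph_def, shape={'in_data': (6, 4)},
                                                outputs=['split:1'])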
diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py
index 219ceb5bd379..ed3d0272b4fc 100644
--- a/nnvm/tests/python/frontend/tensorflow/test_forward.py
+++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py
@@ -124,7 +124,8 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
         if no_gpu and device == 'cuda':
             continue
 
-        tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device)
+        tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
+                                   num_output=len(out_node), target=device, out_names=out_name)
         # since the names from tensorflow and nnvm runs are not exactly same,
         # first len(tf_output) will be compared
         for i in range(len(tf_output)):
@@ -506,14 +507,24 @@ def test_forward_gather():
 # Split
 # -----
 
-def _test_split(in_shape, axis, num_split, dtype):
+def _test_split(in_shape, axis, num_or_size_splits, dtype):
+    """ One iteration of a Split """
+    np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
+    tf.reset_default_graph()
+    in_data = tf.placeholder(dtype, in_shape, name="in_data")
+    num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits
+    tf.split(in_data, num_or_size_splits, axis=axis)
 
-    with tf.Graph().as_default():
-        in_data = tf.placeholder(dtype, in_shape, name="in_data")
-        tf.split(in_data, num_split, axis)
-        np_data = np.random.uniform(size=in_shape).astype(dtype)
-        compare_tf_with_tvm(np_data, 'in_data:0', 'split:0')
+    compare_tf_with_tvm([np_data], ['in_data:0'], [f'split:{n}' for n in range(num_split)])
+
+    # and now test the same split feeding a concat
+    tf.reset_default_graph()
+    in_data = tf.placeholder(dtype, in_shape, name="in_data")
+    splitted = tf.split(in_data, num_or_size_splits, axis=axis)
+    tf.concat(splitted, axis)
+
+    compare_tf_with_tvm([np_data], 'in_data:0', 'concat:0')
 
 def test_forward_split():
     '''test split layer'''
@@ -523,11 +534,11 @@ def test_forward_split():
     _test_split((6,), 0, 3, 'float32')
     # rank 2
     _test_split((6, 2), 0, 3, 'float32')
-    _test_split((2, 6), 1, 3, 'float32')
+    _test_split((2, 6), 1, 6, 'float32')
     # rank 3
-    _test_split((6, 2, 4), 0, 3, 'float32')
+    _test_split((6, 2, 4), 0, 2, 'int32')
     _test_split((2, 6, 4), 1, 3, 'float32')
-    _test_split((2, 4, 6), 2, 3, 'float32')
+    _test_split((2, 4, 6), 2, 1, 'float32')
     # rank 4
     _test_split((6, 1, 3, 5), 0, 3, 'float32')
     _test_split((1, 6, 3, 5), 1, 3, 'float32')
@@ -538,45 +549,37 @@ def test_forward_split():
     _test_split((1, 6, 3, 5), -3, 3, 'float32')
     _test_split((1, 3, 6, 5), -2, 3, 'float32')
     _test_split((1, 3, 5, 6), -1, 3, 'float32')
+    # size_splits list
+    _test_split((6,), 0, [1, 2, 3], 'int32')
+    _test_split((3, 6, 4), -2, [1, 4, 1], 'float32')
 
 #######################################################################
-# Split followed by concat
-# ------------------------
+# Unstack
+# -------
 
-def _test_split_concat(in_shape, axis, num_split, dtype):
-    """ One iteration of a split_concat pair"""
+def _test_unstack(ip_shape, axis, dtype):
+    np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)
 
-    with tf.Graph().as_default():
-        in_data = tf.placeholder(dtype, in_shape, name="in_data")
-        splitted = tf.split(in_data, num_split, axis)
-        tf.concat(splitted, axis)
-        np_data = np.random.uniform(size=in_shape).astype(dtype)
-        compare_tf_with_tvm(np_data, 'in_data:0', 'concat:0')
-
-def test_forward_split_concat():
-    '''test split followed by concat layers'''
-    # rank 1
-    _test_split_concat((3,), 0, 1, 'float32')
-    _test_split_concat((3,), 0, 3, 'float32')
-    _test_split_concat((6,), 0, 3, 'float32')
-    # rank 2
-    _test_split_concat((6, 2), 0, 3, 'float32')
-    _test_split_concat((2, 6), 1, 3, 'float32')
-    # rank 3
-    _test_split_concat((6, 2, 4), 0, 3, 'float32')
-    _test_split_concat((2, 6, 4), 1, 3, 'float32')
-    _test_split_concat((2, 4, 6), 2, 3, 'float32')
-    # rank 4
-    _test_split((6, 1, 3, 5), 0, 3, 'float32')
-    _test_split((1, 6, 3, 5), 1, 3, 'float32')
-    _test_split((1, 3, 6, 5), 2, 3, 'float32')
-    _test_split((1, 3, 5, 6), 3, 3, 'float32')
-    # split along negative axis
-    _test_split((6, 1, 3, 5), -4, 3, 'float32')
-    _test_split((1, 6, 3, 5), -3, 3, 'float32')
-    _test_split((1, 3, 6, 5), -2, 3, 'float32')
-    _test_split((1, 3, 5, 6), -1, 3, 'float32')
+    tf.reset_default_graph()
+    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+    tf.unstack(in_data, axis=axis)
+
+    compare_tf_with_tvm([np_data], ['in_data:0'], [f'unstack:{n}' for n in range(ip_shape[axis])])
+
+    tf.reset_default_graph()
+    in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+    tf.stack(tf.unstack(in_data, axis=axis), axis=axis)
+
+    compare_tf_with_tvm([np_data], 'in_data:0', 'stack:0')
+
+def test_forward_unstack():
+    '''test unstack layer'''
+    _test_unstack((6,), 0, 'int32')
+    _test_unstack((2, 6), 1, 'float64')
+    # negative axis
+    _test_unstack((1, 4), -1, 'int32')
+    _test_unstack((3, 6, 4), -2, 'float32')
 
 
 #######################################################################
@@ -1139,7 +1142,7 @@ def test_forward_rel_ops():
     test_forward_gather()
     test_forward_stridedslice()
     test_forward_split()
-    test_forward_split_concat()
+    test_forward_unstack()
 
     # Activations
     test_forward_sigmoid()
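A quick NumPy illustration (values made up) of the identity that the `_unpack` converter and the `stack(unstack(x))` round-trip test rely on: unstacking along an axis is a split into single-element sections followed by squeezing that axis:

    import numpy as np

    x = np.arange(12).reshape(3, 4)
    # tf.unstack(x, axis=0) would yield three tensors of shape (4,)
    pieces = [np.squeeze(p, axis=0) for p in np.split(x, x.shape[0], axis=0)]
    assert all(np.array_equal(p, x[i]) for i, p in enumerate(pieces))
    # Stacking the pieces back along the same axis recovers the original array.
    assert np.array_equal(np.stack(pieces, axis=0), x)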