diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 4eb63f2ab65d8..7f631c471a884 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -1225,22 +1225,20 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): # the number has to be ignored for single-output nodes. # On the other hand, for multi-output nodes the number is the output index, # and the lack of the number implies 0. - tensor_name = i.split(':') - node_name = tensor_name[0] + node_name, *output_index = i.split(':') if node_name in self._nodes: in_sym = self._nodes[node_name] if len(in_sym.list_output_names()) > 1: - tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0 - in_sym = in_sym[tensor_slot] - input_shape = self._output_shapes[node_name][tensor_slot] + output_index = int(output_index[0]) if output_index else 0 + in_sym = in_sym[output_index] else: - tensor_slot = 0 - input_shape = self._output_shapes[node_name][0] + output_index = 0 + input_shape = self._output_shapes[node_name][output_index] inputs.append(in_sym) input_shapes[in_sym] = [input_shape] # This means the node is 1d in NVM and 0d in TF. # See `_expand_dims_0d_aware`. - if self._outputs_are_0d[node_name][tensor_slot] and input_shape: + if self._outputs_are_0d[node_name][output_index] and input_shape: input_0d_mismatch.add(in_sym) attr['_input_shapes'] = input_shapes attr['_input_0d_mismatch'] = input_0d_mismatch