Skip to content

Commit

Permalink
Simplify logic for multi-output nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
alexeyr committed Dec 6, 2018
1 parent de8da06 commit 0003a35
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,22 +1225,20 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
# the number has to be ignored for single-output nodes.
# On the other hand, for multi-output nodes the number is the output index,
# and the lack of the number implies 0.
tensor_name = i.split(':')
node_name = tensor_name[0]
node_name, *output_index = i.split(':')
if node_name in self._nodes:
in_sym = self._nodes[node_name]
if len(in_sym.list_output_names()) > 1:
tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0
in_sym = in_sym[tensor_slot]
input_shape = self._output_shapes[node_name][tensor_slot]
output_index = int(output_index[0]) if output_index else 0
in_sym = in_sym[output_index]
else:
tensor_slot = 0
input_shape = self._output_shapes[node_name][0]
output_index = 0
input_shape = self._output_shapes[node_name][output_index]
inputs.append(in_sym)
input_shapes[in_sym] = [input_shape]
# This means the node is 1d in NVM and 0d in TF.
# See `_expand_dims_0d_aware`.
if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
if self._outputs_are_0d[node_name][output_index] and input_shape:
input_0d_mismatch.add(in_sym)
attr['_input_shapes'] = input_shapes
attr['_input_0d_mismatch'] = input_0d_mismatch
Expand Down

0 comments on commit 0003a35

Please sign in to comment.