Skip to content

Commit

Permalink
Minor improve
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinthesun committed Jun 9, 2020
1 parent b0dce2d commit 0cc0260
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2404,12 +2404,24 @@ def _get_abs_layer_name(node):
_control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']

# A map to record tensor array write ops and input ta/tensor indices
# Value is (index of tensor array, index of written node)
_tensor_array_write_ops = {
"TensorArrayWrite" : (3, 2),
"TensorArrayScatter" : (0, 2),
"TensorArraySplit" : (0, 1),
}

def is_tensor_array_constuctor(tf_node):
is_ta = False
ta_start = "TensorArrayV"
if tf_node.op.startswith(ta_start):
try:
int(tf_node.op[len(ta_start)])
is_ta = True
except ValueError:
pass
return is_ta

def find_parent_loop_name(node_name, while_loop_name_set):
"""Find name of direct parent while loop."""
ploop_name = ""
Expand Down Expand Up @@ -2841,7 +2853,7 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None):
self._while_loop_name_set.add(node_name_prefix)
control_flow_nodes.append(node)
elif node.op.startswith("TensorArray"):
if node.op.startswith("TensorArrayV"):
if is_tensor_array_constuctor(node):
ta_construct_nodes.append(node)
else:
for ta_write_name, idx in _tensor_array_write_ops.items():
Expand All @@ -2855,7 +2867,7 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None):
for gather_node in ta_gather_nodes:
input_ta_name = gather_node.input[0]
input_ta_node = self._tf_node_map[input_ta_name]
if input_ta_node.op.startswith("TensorArrayV"):
if is_tensor_array_constuctor(input_ta_node):
gather_attr = self._parse_attr(gather_node.attr)
if "element_shape" not in gather_attr:
continue
Expand All @@ -2880,7 +2892,7 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None):
for iname in cnode.input:
stack.append(self._tf_node_map[iname.split(":")[0]])
elif cnode.name != wnode.name:
if cnode.op.startswith("TensorArrayV"):
if is_tensor_array_constuctor(cnode):
inode = self._tf_node_map[wnode.input[inode_idx].split(":")[0]]
self._tensor_array_shape_nodes[cnode.name] = (inode, wnode.op)
break
Expand Down Expand Up @@ -3466,7 +3478,7 @@ def _backtrack_construct(self, node_name):
plname = find_parent_loop_name(node_name, self._while_loop_name_set)

# For TensorArrayV3 op, we need to infer shape first
if node.op.startswith("TensorArrayV"):
if is_tensor_array_constuctor(node):
raw_elem_shape = tensor_util.TensorShapeProtoToList(attr['element_shape'])
elem_shape = []
for dim in raw_elem_shape:
Expand Down

0 comments on commit 0cc0260

Please sign in to comment.