diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 58d9c39a245c4..4e62dc83f5f18 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -2404,12 +2404,24 @@ def _get_abs_layer_name(node): _control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond'] # A map to record tensor array write ops and input ta/tensor indices +# Value is (index of tensor array, index of written node) _tensor_array_write_ops = { "TensorArrayWrite" : (3, 2), "TensorArrayScatter" : (0, 2), "TensorArraySplit" : (0, 1), } +def is_tensor_array_constuctor(tf_node): + is_ta = False + ta_start = "TensorArrayV" + if tf_node.op.startswith(ta_start): + try: + int(tf_node.op[len(ta_start)]) + is_ta = True + except ValueError: + pass + return is_ta + def find_parent_loop_name(node_name, while_loop_name_set): """Find name of direct parent while loop.""" ploop_name = "" @@ -2841,7 +2853,7 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None): self._while_loop_name_set.add(node_name_prefix) control_flow_nodes.append(node) elif node.op.startswith("TensorArray"): - if node.op.startswith("TensorArrayV"): + if is_tensor_array_constuctor(node): ta_construct_nodes.append(node) else: for ta_write_name, idx in _tensor_array_write_ops.items(): @@ -2855,7 +2867,7 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None): for gather_node in ta_gather_nodes: input_ta_name = gather_node.input[0] input_ta_node = self._tf_node_map[input_ta_name] - if input_ta_node.op.startswith("TensorArrayV"): + if is_tensor_array_constuctor(input_ta_node): gather_attr = self._parse_attr(gather_node.attr) if "element_shape" not in gather_attr: continue @@ -2880,7 +2892,7 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None): for iname in cnode.input: stack.append(self._tf_node_map[iname.split(":")[0]]) elif cnode.name != wnode.name: - if cnode.op.startswith("TensorArrayV"): + if is_tensor_array_constuctor(cnode): inode = self._tf_node_map[wnode.input[inode_idx].split(":")[0]] self._tensor_array_shape_nodes[cnode.name] = (inode, wnode.op) break @@ -3466,7 +3478,7 @@ def _backtrack_construct(self, node_name): plname = find_parent_loop_name(node_name, self._while_loop_name_set) # For TensorArrayV3 op, we need to infer shape first - if node.op.startswith("TensorArrayV"): + if is_tensor_array_constuctor(node): raw_elem_shape = tensor_util.TensorShapeProtoToList(attr['element_shape']) elem_shape = [] for dim in raw_elem_shape: