[Frontend, Tensorflow2] Added support for TensorList ops (apache#8454)
Xingyu Zhou authored and trevor-m committed Jul 26, 2021
1 parent bd40eaf commit 21199de
Showing 5 changed files with 583 additions and 5 deletions.
206 changes: 201 additions & 5 deletions python/tvm/relay/frontend/tensorflow2.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except, too-many-nested-blocks
"""Tensorflow2.x graph to relay converter.
If model is constructed using tf2.x API, then use this converter:
@@ -38,12 +38,20 @@
from .common import infer_type as _infer_type

from .tensorflow_ops import _convert_map as _convert_map_common
from .tensorflow_ops import _need_prelude_for_shape_inference
from .tensorflow_ops import _get_more_static_shape_rank
from .tensorflow2_ops import _convert_map as _convert_map_tf2
from .tensorflow2_ops import _need_prelude_for_shape_inference

from ..ty import Any

__all__ = ["from_tensorflow"]
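# Illustrative usage (not part of this commit): converting a frozen TF2 GraphDef
# that contains TensorList ops inside a tf.while_loop. The module-level entry
# point is assumed to mirror the class method shown further down in this diff
# (graph, layout, shape, outputs); adjust the call if the released API differs.
#
#   from tvm.relay.frontend import tensorflow2 as tf2_frontend
#
#   mod, params = tf2_frontend.from_tensorflow(
#       graph_def,                            # frozen tf.compat.v1.GraphDef
#       layout="NHWC",
#       shape={"input": (1, 224, 224, 3)},    # hypothetical input name/shape
#       outputs=["Identity"],                 # hypothetical output node name
#   )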

# A map to record tensor list write ops and input tl/tensor indices
# Value is (index of tensor list, index of written node)
_tensor_list_write_ops = {
"TensorListSetItem": (0, 2),
}
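# For example (illustrative, based on the TF op signature): TensorListSetItem
# takes inputs (input_handle, index, item), so the entry (0, 2) marks input 0
# as the tensor list handle and input 2 as the tensor being written.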


def _infer_type_with_prelude(val, prelude):
body = _infer_type(val, prelude.mod)
@@ -66,6 +74,11 @@ def set_span(sym, node_name):
return sym


def is_tensor_list_constuctor(tf_node):
"""Check whether is tensor list constructor node."""
return tf_node.op == "TensorListReserve"


def convert_const_node(node, shape):
"""convert tf const node into relay const or var"""

@@ -196,6 +209,10 @@ def __init__(self, module):
self._output_shapes = {}
self._tf_node_map = {}
self._gdef_lib = {}
self._tensor_list_shapes = {}
self._tensor_list_shape_nodes = {}
self._sub_map = {}
self._sub_input_idx_map = {}

def from_tensorflow(
self, graph, layout="NHWC", shape=None, outputs=None, input_types=None, gdef_lib=None
@@ -215,10 +232,134 @@ def from_tensorflow(
)
return func, self._params

def _analysis_tensor_list_op(
self,
graph,
node,
tl_write_nodes,
tl_stack_nodes,
tl_construct_nodes,
sub_func_name="",
root_node="",
):
if sub_func_name and sub_func_name not in self._sub_input_idx_map:
self._sub_input_idx_map[sub_func_name] = {}

if node.op == "Placeholder":
# record placeholder node in sub functions
self._sub_map[sub_func_name] = node
self._sub_input_idx_map[sub_func_name][node.name] = len(
self._sub_input_idx_map[sub_func_name]
)

if node.op.startswith("TensorList"):
if is_tensor_list_constuctor(node):
tl_construct_nodes.append(node)
else:
for tl_write_name, idx in _tensor_list_write_ops.items():
if node.op.startswith(tl_write_name):
tl_write_nodes.append((node, idx, sub_func_name, root_node))
if node.op.startswith("TensorListStack"):
tl_stack_nodes.append(node)
elif node.op.startswith("StatelessWhile"):
root_node = node.name
cond_fn_name, body_fn_name = [
parse_attr(node.attr).get(x).name for x in ["cond", "body"]
]
for fn_name in [cond_fn_name, body_fn_name]:
subfunction = self._gdef_lib[fn_name]
sub_func_name = fn_name
for sub_node in subfunction.node:
# bypass const node
if sub_node.op == "Const":
continue
self._tf_node_map[sub_node.name] = sub_node
self._analysis_tensor_list_op(
subfunction,
sub_node,
tl_write_nodes,
tl_stack_nodes,
tl_construct_nodes,
sub_func_name=sub_func_name,
root_node=root_node,
)
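# Illustrative outcome of this pass (assumed graph layout): when a
# TensorListReserve feeds a StatelessWhile whose body performs a
# TensorListSetItem and the result is read back with TensorListStack,
#   tl_construct_nodes collects the TensorListReserve node,
#   tl_write_nodes collects (SetItem node, (0, 2), body_fn_name, while_node_name),
#   tl_stack_nodes collects the TensorListStack node,
# and _sub_input_idx_map records each subfunction Placeholder's input position.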

def _infer_static_shape_stack_node(self, tl_stack_nodes):
for stack_node in tl_stack_nodes:
if len(stack_node.input) < 2:
# Stack node does not have shape
continue
input_shape_name = stack_node.input[1].split(":")[0]
input_shape_node = self._tf_node_map[input_shape_name]
stack = [self._tf_node_map[stack_node.input[0].split(":")[0]]]
in_idx = -1
while stack:
cnode = stack.pop(0)
if not cnode.op.startswith("TensorList"):
if in_idx and cnode.op.startswith("StatelessWhile"):
stack.append(self._tf_node_map[cnode.input[in_idx].split(":")[0]])
else:
for iname in cnode.input:
if self._tf_node_map[iname.split(":")[0]].op.startswith(
"StatelessWhile"
):
# identify input index based on output index
if iname.split(":")[1]:
in_idx = int(iname.split(":")[1])
stack.append(self._tf_node_map[iname.split(":")[0]])
# identify the corresponding constructor node and add shape to _tensor_list_shapes
elif cnode.name != stack_node.name:
if is_tensor_list_constuctor(cnode):
shape_attr = parse_attr(input_shape_node.attr)
if "value" not in shape_attr:
continue
raw_elem_shape = tensor_util.MakeNdarray(shape_attr["value"])
elem_shape = []
for dim in raw_elem_shape:
if dim < 0:
elem_shape.append(Any())
else:
elem_shape.append(int(dim))
self._tensor_list_shapes[cnode.name] = elem_shape
break
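# Illustrative trace (assumed constant values): for a
# TensorListStack(handle, element_shape) whose element_shape Const holds
# [-1, 128], the walk above follows the handle back through any StatelessWhile
# to the TensorListReserve constructor and records
# self._tensor_list_shapes[reserve_node.name] = [Any(), 128].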

def _infer_static_shape_write_node(self, tl_write_nodes):
for item in tl_write_nodes:
wnode = item[0]
ta_idx, inode_idx = item[1]
sub_func_name = item[2]
root_name = item[3]
stack = [self._tf_node_map[wnode.input[ta_idx].split(":")[0]]]
while stack:
cnode = stack.pop(0)

if not cnode.op.startswith("TensorList"):
if cnode.op == "Placeholder" and sub_func_name:
# map the subfunction placeholder back to the corresponding input of the enclosing while loop
input_idx = self._sub_input_idx_map[sub_func_name][cnode.name]
stack.append(
self._tf_node_map[
self._tf_node_map[root_name].input[input_idx].split(":")[0]
]
)
else:
for iname in cnode.input:
stack.append(self._tf_node_map[iname.split(":")[0]])
# identify the corresponding constructor node and add it to _tensor_list_shape_nodes
elif cnode.name != wnode.name:
if is_tensor_list_constuctor(cnode):
inode = self._tf_node_map[wnode.input[inode_idx].split(":")[0]]
tn = wnode.input[inode_idx].split(":")
output_index = int(tn[1]) if len(tn) > 1 else 0
self._tensor_list_shape_nodes[cnode.name] = (inode, wnode.op, output_index)
break
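# Illustrative result (assumed graph layout): for a
# TensorListSetItem(handle, index, item) whose handle traces back to a
# TensorListReserve, this records
# self._tensor_list_shape_nodes[reserve_node.name] = (item_node, "TensorListSetItem", out_idx),
# so the constructor's element shape can later be inferred from the written tensor.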

def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_types=None):
if input_types is None:
input_types = {}

tl_write_nodes = []
tl_stack_nodes = []
tl_construct_nodes = []
self._layout = layout
for node in graph.node:
name = node.name
@@ -235,6 +376,18 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_
self._nodes[node.name] = sym
if param:
self._params[node.name] = param
# recursively analyze tensor list ops when a while loop is encountered
else:
self._analysis_tensor_list_op(
graph, node, tl_write_nodes, tl_stack_nodes, tl_construct_nodes
)

# Use tensor list stack to infer static tensor list shape
self._infer_static_shape_stack_node(tl_stack_nodes)

# Fetch the nodes that carry the static tensor list shape
self._infer_static_shape_write_node(tl_write_nodes)

for node in graph.node:
self._backtrack_construct(graph, node.name)

@@ -321,16 +474,36 @@ def _convert_operator(self, graph, op_name, node_name, inputs, attrs):
gdef_lib=self._gdef_lib,
)
elif op_name in _convert_map_common:
# assert the common and tf2-specific op maps are exclusive
assert not set(_convert_map_common.keys()) & set(_convert_map_tf2.keys())
if _need_prelude_for_shape_inference(op_name):
sym = _convert_map_common[op_name](inputs, attrs, self._params, self._prelude)
else:
sym = _convert_map_common[op_name](inputs, attrs, self._params, self._module.mod)
elif op_name in _convert_map_tf2:
if _need_prelude_for_shape_inference(op_name):
sym = _convert_map_tf2[op_name](inputs, attrs, self._params, self._prelude)
else:
sym = _convert_map_tf2[op_name](inputs, attrs, self._params, self._module.mod)
else:
raise NotImplementedError("Operator {} not implemented.".format(op_name))

sym = set_span(sym, node_name)
return sym

def _parse_element_shape(self, elem_shape, shape_attr):
if "value" in shape_attr:
raw_elem_shape = tensor_util.MakeNdarray(shape_attr["value"])

if raw_elem_shape.size == 1 and raw_elem_shape == -1:
elem_shape.append(Any())
else:
for dim in raw_elem_shape:
if dim < 0:
elem_shape.append(Any())
else:
elem_shape.append(dim)
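# Illustrative behaviour (values assumed): a scalar shape constant of -1 means
# the element shape is unknown and yields elem_shape = [Any()]; a vector such
# as [-1, 3] yields elem_shape = [Any(), 3], leaving only negative dims dynamic.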

def _backtrack_construct(self, graph, node_name):
"""Convert a specific tensorflow node to relay expression.
Expand Down Expand Up @@ -370,8 +543,8 @@ def _backtrack_construct(self, graph, node_name):
CallNode(Op(add), [Var(x, ty=TensorType([], float32)), Constant(1.0)], (nullptr), [])
"""

input_op_name = node_name.split(":")[0].split("^")[-1]

if input_op_name not in self._nodes:
node = self._tf_node_map[input_op_name]
attr = parse_attr(node.attr)
Expand All @@ -386,8 +559,31 @@ def _backtrack_construct(self, graph, node_name):
attr["_node_name"] = node.name
attr["_target_layout"] = self._layout
inputs = [self._backtrack_construct(graph, iname) for iname in node.input]
op = self._convert_operator(graph, node.op, node.name, inputs, attr)

# infer shape for TensorList op
if is_tensor_list_constuctor(node):
input_shape_name = (
node.input[1] if "TensorListFromTensor" in node.op else node.input[0]
)
input_shape_name = input_shape_name.split(":")[0]
input_shape_node = self._tf_node_map[input_shape_name]
shape_attr = parse_attr(input_shape_node.attr)
elem_shape = []

self._parse_element_shape(elem_shape, shape_attr)

if elem_shape:
attr["shape"] = elem_shape
if (
"identical_element_shapes" in attr and attr["identical_element_shapes"]
) or elem_shape:
shape = elem_shape
if node.name in self._tensor_list_shapes:
preset_shape = self._tensor_list_shapes[node.name]
shape = _get_more_static_shape_rank(shape, preset_shape)
attr["shape"] = shape

op = self._convert_operator(graph, node.op, node.name, inputs, attr)
if isinstance(op, np.ndarray):
self._params[node.name] = tvm.nd.array(op)
op = [