Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF] Support TupleWrapper as direct ancestor of control flow ops #5639

Merged
merged 1 commit into from
May 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 25 additions & 34 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3073,21 +3073,19 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_
branch = self._branches[node_name_prefix]
false_br = self._backtrack_construct(node.input[0])
true_br = self._backtrack_construct(node.input[1])
assert len(true_br) == 1
assert len(false_br) == 1
branch.true_branch = true_br[0]
branch.false_branch = false_br[0]
op = [branch.if_node()]
branch.true_branch = true_br
branch.false_branch = false_br
op = branch.if_node()
if node_name_prefix not in self._while_loop_name_set:
try:
cond_val = np.all(_infer_value(branch.cond, self._params,
self._mod).asnumpy())
if cond_val:
op = [branch.true_branch]
op = branch.true_branch
else:
op = [branch.false_branch]
op = branch.false_branch
except Exception:
op = [branch.if_node()]
op = branch.if_node()
elif node.op == "Exit":
loop = self._loops[node_name_prefix]

Expand All @@ -3113,17 +3111,15 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_
if exit_number == j:
body_pos = i
break
op = [_expr.TupleGetItem(expr, body_pos)]
op = _expr.TupleGetItem(expr, body_pos)
elif node.op == "Enter":
op = self._backtrack_construct(node.input[0])
elif node.op == "LoopCond":
op = self._backtrack_construct(node.input[0])
assert len(op) == 1
self._loops[node_name_prefix].cond = op[0]
self._loops[node_name_prefix].cond = op
elif node.op == "Switch":
op = self._backtrack_construct(node.input[0])
cond = self._backtrack_construct(node.input[1])
assert len(op) == 1
if _in_while_loop(self._control_flow_node_map, node_name_prefix):
if node_name_prefix not in self._loop_var_order:
self._loop_var_order[node_name_prefix] = []
Expand All @@ -3132,11 +3128,11 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_
else:
self._loop_var_order[node_name_prefix].\
append(int(node.name.split("Switch_")[-1]))
self._loops[node_name_prefix].loop_vars.append(op[0])
self._loops[node_name_prefix].loop_vars.append(op)
else:
if node_name_prefix not in self._branches:
self._branches[node_name_prefix] = Branch()
self._branches[node_name_prefix].cond = cond[0]
self._branches[node_name_prefix].cond = cond
elif node.op == "NextIteration":
if node_name_prefix not in self._loop_body_order:
self._loop_body_order[node_name_prefix] = []
Expand All @@ -3146,9 +3142,7 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_
self._loop_body_order[node_name_prefix].\
append(int(node.name.split("NextIteration_")[-1]))
op = self._backtrack_construct(node.input[0])

assert len(op) == 1
self._loops[node_name_prefix].body.append(op[0])
self._loops[node_name_prefix].body.append(op)
else:
raise Exception("Cannot identify control flow operator: " +
"{}".format(node.op))
Expand Down Expand Up @@ -3219,10 +3213,10 @@ def _backtrack_construct(self, node_name):
op : relay.Expr
Converted relay expression
"""
node_name = node_name.split(':')[0].split("^")[-1]
input_op_name = node_name.split(':')[0].split("^")[-1]

if node_name not in self._nodes:
node = self._tf_node_map[node_name]
if input_op_name not in self._nodes:
node = self._tf_node_map[input_op_name]
attr = self._parse_attr(node.attr)

if node.op in _control_flow_nodes:
Expand All @@ -3231,20 +3225,10 @@ def _backtrack_construct(self, node_name):
attr,
self._control_flow_node_map)
else:
attr["_output_shapes"] = self._output_shapes[node_name]
attr["_output_shapes"] = self._output_shapes[input_op_name]
attr["_node_name"] = node.name
attr["_target_layout"] = self._layout
inputs = []
for iname in node.input:
in_op = self._backtrack_construct(iname)
if isinstance(in_op, _expr.TupleWrapper):
tn = iname.split(':')
tensor_slot = int(tn[1]) if len(tn) > 1 else 0
in_op = in_op[tensor_slot]
else:
in_op = in_op[0]

inputs.append(in_op)
inputs = [self._backtrack_construct(iname) for iname in node.input]
op = self._convert_operator(node.op, inputs, attr, self._graph)

if isinstance(op, np.ndarray):
Expand All @@ -3258,9 +3242,16 @@ def _backtrack_construct(self, node_name):

node_hash = s_hash(op) if isinstance(op, _expr.Tuple) else s_hash(op[0])
self._hash2tfnode[node_hash] = node
self._nodes[node_name] = op
self._nodes[input_op_name] = op

out = self._nodes[input_op_name]

if isinstance(out, _expr.TupleWrapper):
tn = node_name.split(':')
tensor_slot = int(tn[1]) if len(tn) > 1 else 0
return out[tensor_slot]

return self._nodes[node_name]
return out[0]

def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
"""Load tensorflow graph which is a python tensorflow graph object into relay.
Expand Down
20 changes: 20 additions & 0 deletions tests/python/frontend/tensorflow/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
tf.disable_v2_behavior()
except ImportError:
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
import numpy as np
from tvm import nd
from tvm import relay
Expand Down Expand Up @@ -368,6 +369,23 @@ def condition(x, y):
check_equal(graph, tf_out, {dname: np_data})


def test_switch():
graph = tf.Graph()

with graph.as_default():
data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype('float32')
dname = 'data'
flag_name = 'flag'
data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name=dname)
split = tf.split(data, 2, axis=0)
flag = tf.placeholder(shape={}, dtype=tf.bool, name=flag_name)
output_false, output_true = control_flow_ops.switch(split[1], flag)
with tf.Session() as sess:
tf_out = sess.run(output_false, feed_dict={data.name: data_np, flag.name: False})

check_equal(graph, tf_out, {dname: data_np, flag_name: False})


if __name__ == "__main__":
# tf.while_loop
test_vanilla_loop()
Expand All @@ -390,3 +408,5 @@ def condition(x, y):
test_cond_in_loop()
test_vanilla_loop_bound()
test_nested_loop_bound()

test_switch()