Skip to content

Commit

Permalink
[TF] Support TupleWrapper as direct ancestor of control flow ops (apa…
Browse files Browse the repository at this point in the history
  • Loading branch information
lixiaoquan authored and kevinthesun committed Jun 2, 2020
1 parent ae560c8 commit d437eff
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 34 deletions.
59 changes: 25 additions & 34 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3120,21 +3120,19 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_
branch = self._branches[node_name_prefix]
false_br = self._backtrack_construct(node.input[0])
true_br = self._backtrack_construct(node.input[1])
assert len(true_br) == 1
assert len(false_br) == 1
branch.true_branch = true_br[0]
branch.false_branch = false_br[0]
op = [branch.if_node()]
branch.true_branch = true_br
branch.false_branch = false_br
op = branch.if_node()
if node_name_prefix not in self._while_loop_name_set:
try:
cond_val = np.all(_infer_value(branch.cond, self._params,
self._mod).asnumpy())
if cond_val:
op = [branch.true_branch]
op = branch.true_branch
else:
op = [branch.false_branch]
op = branch.false_branch
except Exception:
op = [branch.if_node()]
op = branch.if_node()
elif node.op == "Exit":
loop = self._loops[node_name_prefix]

Expand All @@ -3160,17 +3158,15 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_
if exit_number == j:
body_pos = i
break
op = [_expr.TupleGetItem(expr, body_pos)]
op = _expr.TupleGetItem(expr, body_pos)
elif node.op == "Enter":
op = self._backtrack_construct(node.input[0])
elif node.op == "LoopCond":
op = self._backtrack_construct(node.input[0])
assert len(op) == 1
self._loops[node_name_prefix].cond = op[0]
self._loops[node_name_prefix].cond = op
elif node.op == "Switch":
op = self._backtrack_construct(node.input[0])
cond = self._backtrack_construct(node.input[1])
assert len(op) == 1
if _in_while_loop(self._control_flow_node_map, node_name_prefix):
if node_name_prefix not in self._loop_var_order:
self._loop_var_order[node_name_prefix] = []
Expand All @@ -3179,11 +3175,11 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_
else:
self._loop_var_order[node_name_prefix].\
append(int(node.name.split("Switch_")[-1]))
self._loops[node_name_prefix].loop_vars.append(op[0])
self._loops[node_name_prefix].loop_vars.append(op)
else:
if node_name_prefix not in self._branches:
self._branches[node_name_prefix] = Branch()
self._branches[node_name_prefix].cond = cond[0]
self._branches[node_name_prefix].cond = cond
elif node.op == "NextIteration":
if node_name_prefix not in self._loop_body_order:
self._loop_body_order[node_name_prefix] = []
Expand All @@ -3193,9 +3189,7 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_
self._loop_body_order[node_name_prefix].\
append(int(node.name.split("NextIteration_")[-1]))
op = self._backtrack_construct(node.input[0])

assert len(op) == 1
self._loops[node_name_prefix].body.append(op[0])
self._loops[node_name_prefix].body.append(op)
else:
raise Exception("Cannot identify control flow operator: " +
"{}".format(node.op))
Expand Down Expand Up @@ -3266,10 +3260,10 @@ def _backtrack_construct(self, node_name):
op : relay.Expr
Converted relay expression
"""
node_name = node_name.split(':')[0].split("^")[-1]
input_op_name = node_name.split(':')[0].split("^")[-1]

if node_name not in self._nodes:
node = self._tf_node_map[node_name]
if input_op_name not in self._nodes:
node = self._tf_node_map[input_op_name]
attr = self._parse_attr(node.attr)

if node.op in _control_flow_nodes:
Expand All @@ -3278,20 +3272,10 @@ def _backtrack_construct(self, node_name):
attr,
self._control_flow_node_map)
else:
attr["_output_shapes"] = self._output_shapes[node_name]
attr["_output_shapes"] = self._output_shapes[input_op_name]
attr["_node_name"] = node.name
attr["_target_layout"] = self._layout
inputs = []
for iname in node.input:
in_op = self._backtrack_construct(iname)
if isinstance(in_op, _expr.TupleWrapper):
tn = iname.split(':')
tensor_slot = int(tn[1]) if len(tn) > 1 else 0
in_op = in_op[tensor_slot]
else:
in_op = in_op[0]

inputs.append(in_op)
inputs = [self._backtrack_construct(iname) for iname in node.input]
op = self._convert_operator(node.op, inputs, attr, self._graph)

if isinstance(op, np.ndarray):
Expand All @@ -3305,9 +3289,16 @@ def _backtrack_construct(self, node_name):

node_hash = s_hash(op) if isinstance(op, _expr.Tuple) else s_hash(op[0])
self._hash2tfnode[node_hash] = node
self._nodes[node_name] = op
self._nodes[input_op_name] = op

out = self._nodes[input_op_name]

if isinstance(out, _expr.TupleWrapper):
tn = node_name.split(':')
tensor_slot = int(tn[1]) if len(tn) > 1 else 0
return out[tensor_slot]

return self._nodes[node_name]
return out[0]

def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
"""Load tensorflow graph which is a python tensorflow graph object into relay.
Expand Down
20 changes: 20 additions & 0 deletions tests/python/frontend/tensorflow/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
tf.disable_v2_behavior()
except ImportError:
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
import numpy as np
from tvm import nd
from tvm import relay
Expand Down Expand Up @@ -368,6 +369,23 @@ def condition(x, y):
check_equal(graph, tf_out, {dname: np_data})


def test_switch():
graph = tf.Graph()

with graph.as_default():
data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype('float32')
dname = 'data'
flag_name = 'flag'
data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name=dname)
split = tf.split(data, 2, axis=0)
flag = tf.placeholder(shape={}, dtype=tf.bool, name=flag_name)
output_false, output_true = control_flow_ops.switch(split[1], flag)
with tf.Session() as sess:
tf_out = sess.run(output_false, feed_dict={data.name: data_np, flag.name: False})

check_equal(graph, tf_out, {dname: data_np, flag_name: False})


if __name__ == "__main__":
# tf.while_loop
test_vanilla_loop()
Expand All @@ -390,3 +408,5 @@ def condition(x, y):
test_cond_in_loop()
test_vanilla_loop_bound()
test_nested_loop_bound()

test_switch()

0 comments on commit d437eff

Please sign in to comment.