From d437eff94849989d695cb6c5235c47db045c90b4 Mon Sep 17 00:00:00 2001 From: lixiaoquan Date: Wed, 27 May 2020 02:29:29 +0800 Subject: [PATCH] [TF] Support TupleWrapper as direct ancestor of control flow ops (#5639) --- python/tvm/relay/frontend/tensorflow.py | 59 ++++++++----------- .../frontend/tensorflow/test_control_flow.py | 20 +++++++ 2 files changed, 45 insertions(+), 34 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 002fb857e258c..d930ce35fcf9c 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -3120,21 +3120,19 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_ branch = self._branches[node_name_prefix] false_br = self._backtrack_construct(node.input[0]) true_br = self._backtrack_construct(node.input[1]) - assert len(true_br) == 1 - assert len(false_br) == 1 - branch.true_branch = true_br[0] - branch.false_branch = false_br[0] - op = [branch.if_node()] + branch.true_branch = true_br + branch.false_branch = false_br + op = branch.if_node() if node_name_prefix not in self._while_loop_name_set: try: cond_val = np.all(_infer_value(branch.cond, self._params, self._mod).asnumpy()) if cond_val: - op = [branch.true_branch] + op = branch.true_branch else: - op = [branch.false_branch] + op = branch.false_branch except Exception: - op = [branch.if_node()] + op = branch.if_node() elif node.op == "Exit": loop = self._loops[node_name_prefix] @@ -3160,17 +3158,15 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_ if exit_number == j: body_pos = i break - op = [_expr.TupleGetItem(expr, body_pos)] + op = _expr.TupleGetItem(expr, body_pos) elif node.op == "Enter": op = self._backtrack_construct(node.input[0]) elif node.op == "LoopCond": op = self._backtrack_construct(node.input[0]) - assert len(op) == 1 - self._loops[node_name_prefix].cond = op[0] + self._loops[node_name_prefix].cond = op elif node.op == "Switch": op = self._backtrack_construct(node.input[0]) cond = self._backtrack_construct(node.input[1]) - assert len(op) == 1 if _in_while_loop(self._control_flow_node_map, node_name_prefix): if node_name_prefix not in self._loop_var_order: self._loop_var_order[node_name_prefix] = [] @@ -3179,11 +3175,11 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_ else: self._loop_var_order[node_name_prefix].\ append(int(node.name.split("Switch_")[-1])) - self._loops[node_name_prefix].loop_vars.append(op[0]) + self._loops[node_name_prefix].loop_vars.append(op) else: if node_name_prefix not in self._branches: self._branches[node_name_prefix] = Branch() - self._branches[node_name_prefix].cond = cond[0] + self._branches[node_name_prefix].cond = cond elif node.op == "NextIteration": if node_name_prefix not in self._loop_body_order: self._loop_body_order[node_name_prefix] = [] @@ -3193,9 +3189,7 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_ self._loop_body_order[node_name_prefix].\ append(int(node.name.split("NextIteration_")[-1])) op = self._backtrack_construct(node.input[0]) - - assert len(op) == 1 - self._loops[node_name_prefix].body.append(op[0]) + self._loops[node_name_prefix].body.append(op) else: raise Exception("Cannot identify control flow operator: " + "{}".format(node.op)) @@ -3266,10 +3260,10 @@ def _backtrack_construct(self, node_name): op : relay.Expr Converted relay expression """ - node_name = node_name.split(':')[0].split("^")[-1] + input_op_name = node_name.split(':')[0].split("^")[-1] - if node_name not in self._nodes: - node = self._tf_node_map[node_name] + if input_op_name not in self._nodes: + node = self._tf_node_map[input_op_name] attr = self._parse_attr(node.attr) if node.op in _control_flow_nodes: @@ -3278,20 +3272,10 @@ def _backtrack_construct(self, node_name): attr, self._control_flow_node_map) else: - attr["_output_shapes"] = self._output_shapes[node_name] + attr["_output_shapes"] = self._output_shapes[input_op_name] attr["_node_name"] = node.name attr["_target_layout"] = self._layout - inputs = [] - for iname in node.input: - in_op = self._backtrack_construct(iname) - if isinstance(in_op, _expr.TupleWrapper): - tn = iname.split(':') - tensor_slot = int(tn[1]) if len(tn) > 1 else 0 - in_op = in_op[tensor_slot] - else: - in_op = in_op[0] - - inputs.append(in_op) + inputs = [self._backtrack_construct(iname) for iname in node.input] op = self._convert_operator(node.op, inputs, attr, self._graph) if isinstance(op, np.ndarray): @@ -3305,9 +3289,16 @@ def _backtrack_construct(self, node_name): node_hash = s_hash(op) if isinstance(op, _expr.Tuple) else s_hash(op[0]) self._hash2tfnode[node_hash] = node - self._nodes[node_name] = op + self._nodes[input_op_name] = op + + out = self._nodes[input_op_name] + + if isinstance(out, _expr.TupleWrapper): + tn = node_name.split(':') + tensor_slot = int(tn[1]) if len(tn) > 1 else 0 + return out[tensor_slot] - return self._nodes[node_name] + return out[0] def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): """Load tensorflow graph which is a python tensorflow graph object into relay. diff --git a/tests/python/frontend/tensorflow/test_control_flow.py b/tests/python/frontend/tensorflow/test_control_flow.py index 9777a8dc4462c..90035279bf63b 100644 --- a/tests/python/frontend/tensorflow/test_control_flow.py +++ b/tests/python/frontend/tensorflow/test_control_flow.py @@ -21,6 +21,7 @@ tf.disable_v2_behavior() except ImportError: import tensorflow as tf +from tensorflow.python.ops import control_flow_ops import numpy as np from tvm import nd from tvm import relay @@ -368,6 +369,23 @@ def condition(x, y): check_equal(graph, tf_out, {dname: np_data}) +def test_switch(): + graph = tf.Graph() + + with graph.as_default(): + data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype('float32') + dname = 'data' + flag_name = 'flag' + data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name=dname) + split = tf.split(data, 2, axis=0) + flag = tf.placeholder(shape={}, dtype=tf.bool, name=flag_name) + output_false, output_true = control_flow_ops.switch(split[1], flag) + with tf.Session() as sess: + tf_out = sess.run(output_false, feed_dict={data.name: data_np, flag.name: False}) + + check_equal(graph, tf_out, {dname: data_np, flag_name: False}) + + if __name__ == "__main__": # tf.while_loop test_vanilla_loop() @@ -390,3 +408,5 @@ def condition(x, y): test_cond_in_loop() test_vanilla_loop_bound() test_nested_loop_bound() + + test_switch()