diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 8d53b003da1ee..451f46f4fed4c 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -5,6 +5,7 @@
 import logging
 import warnings
+from collections import defaultdict
 # Numpy support
 import numpy as np
@@ -1171,6 +1172,100 @@ def _get_abs_layer_name(node):
                              params, num_layers)
     return sym
 
+_control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']
+
+class Branch:
+    """A class that contains the components used to build a Relay if node.
+    """
+    def __init__(self):
+        self._if = None
+        self.cond_vars = set()
+        self.cond = None
+        self.true_branch = None
+        self.false_branch = None
+
+    def _if_node(self):
+        from tvm import relay
+
+        # Rebind the free variables of the condition to fresh variables so
+        # that the condition and both branches share the same bindings.
+        cond_vars = []
+        bind_map = {}
+        for i, var in enumerate(self.cond_vars):
+            if not isinstance(var, _expr.Var):
+                raise TypeError("var is expected to be _expr.Var type, but "
+                                "received {}".format(repr(var)))
+            v = relay.var("cond_var" + str(i),
+                          type_annotation=var.type_annotation)
+            cond_vars.append(v)
+            bind_map[var] = v
+
+        self.cond = relay.bind(self.cond, bind_map)
+        # Reduce the predicate tensor to a scalar boolean for relay.If.
+        cond = relay.op.min(self.cond)
+        self.true_branch = relay.bind(self.true_branch, bind_map)
+        self.false_branch = relay.bind(self.false_branch, bind_map)
+
+        return relay.If(cond, self.true_branch, self.false_branch)
+
+    def if_node(self):
+        """Create an if node if it hasn't been created yet."""
+        if self._if is None:
+            self._if = self._if_node()
+        return self._if
+
+
+class Loop:
+    """A class that contains the components used to build a Relay
+    recursive call, i.e. a while loop.
+    """
+    def __init__(self):
+        self.loop_vars = []
+        self.cond = None
+        self.body = []
+        self._loop = None
+
+    def _while_loop(self):
+        from tvm import relay
+        wl = relay.var('while_loop')
+        sb = relay.scope_builder.ScopeBuilder()
+
+        # Rebind the loop variables to fresh variables that become the
+        # parameters of the recursive loop function.
+        loop_vars = []
+        bind_map = {}
+        for i, var in enumerate(self.loop_vars):
+            assert isinstance(var, _expr.Var), repr(var)
+            v = relay.var("loop_var" + str(i),
+                          type_annotation=var.type_annotation)
+            loop_vars.append(v)
+            bind_map[var] = v
+
+        self.cond = relay.bind(self.cond, bind_map)
+        self.body = [relay.bind(b, bind_map) for b in self.body]
+
+        # Reduce the predicate tensor to a scalar boolean.
+        cond = relay.op.min(self.cond)
+
+        # While the condition holds, recurse with the updated loop
+        # variables; otherwise return the current loop variables as a tuple.
+        with sb.if_scope(cond):
+            sb.ret(wl(*self.body))
+        with sb.else_scope():
+            sb.ret(relay.Tuple(loop_vars))
+
+        loop_fn = relay.Function(loop_vars, sb.get())
+        sb = relay.scope_builder.ScopeBuilder()
+        sb.let(wl, loop_fn)
+        sb.ret(wl(*self.loop_vars))
+        return sb.get()
+
+    def while_loop(self):
+        """Create a while loop node if it hasn't been created yet."""
+        if self._loop is None:
+            self._loop = self._while_loop()
+        return self._loop
+
+
+def _in_while_loop(control_flow_node_map, op_name):
+    """Check whether a control flow name scope belongs to a while loop,
+    i.e. whether the scope contains a LoopCond node."""
+    return op_name in control_flow_node_map and \
+        "LoopCond" in control_flow_node_map[op_name]
+
+
 class GraphProto(object):
     """ A helper class for handling relay graph copying from Tensorflow
     GraphDef.
     Definition:
@@ -1183,6 +1278,9 @@ def __init__(self):
         self._output_shapes = {}
         self._num_param = 0
         self._num_rnn_layer = False
+        self._loops = {}
+        self._branches = {}
 
     def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
         """Construct relay nodes from tensorflow graph definition - GraphDef.
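Reviewer note: the sketch below is not part of the patch; it shows, for orientation, the shape of the Relay program that Loop._while_loop builds for a simple `while i < 10: i = i + 1` loop. Variable names are illustrative, and only relay APIs already used in this file are assumed.

    from tvm import relay

    i = relay.var("i", shape=(), dtype="int32")   # fresh loop variable
    wl = relay.var("while_loop")                  # bound to the loop function

    sb = relay.scope_builder.ScopeBuilder()
    with sb.if_scope(relay.op.min(relay.less(i, relay.const(10)))):
        sb.ret(wl(relay.add(i, relay.const(1))))  # recurse with updated vars
    with sb.else_scope():
        sb.ret(relay.Tuple([i]))                  # condition failed: return vars
    loop_fn = relay.Function([i], sb.get())

    prog = relay.scope_builder.ScopeBuilder()
    prog.let(wl, loop_fn)                         # let while_loop = loop_fn
    prog.ret(wl(relay.const(0)))                  # enter the loop with i = 0
    expr = prog.get()                             # evaluates to Tuple([10])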
@@ -1231,7 +1329,10 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             raise NotImplementedError( \
                 "The following operators are not implemented: {}".format(missing_operators))
 
+        # Map each control flow name scope to the control flow ops it holds,
+        # so while loops can be told apart from conditionals.
+        control_flow_node_map = defaultdict(set)
         for node in graph.node:
+            node_name_prefix = node.name.rsplit('/', 1)[0]
+            control_flow_node_map[node_name_prefix].add(node.op)
             if node.op == 'Placeholder':
                 if shape and node.name in shape:
                     self._input_shapes[node.name] = list(shape[node.name])
@@ -1319,8 +1420,53 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
                 inputs.append(self._nodes[i][0])
                 input_shapes[self._nodes[i][0]] = self._output_shapes[i]
             attr['_input_shapes'] = input_shapes
+            node_name_prefix = node.name.rsplit('/', 1)[0]
 
-            op = self._convert_operator(node.op, inputs, attr, graph)
+            if node.op == "Merge":
+                if _in_while_loop(control_flow_node_map, node_name_prefix):
+                    op = self._nodes[node.input[0]]
+                    self._loops[node_name_prefix] = Loop()
+                else:
+                    if node_name_prefix not in self._branches:
+                        raise RuntimeError("Cannot find a created "
+                                           "conditional for merge node "
+                                           "{}".format(node.name))
+                    branch = self._branches[node_name_prefix]
+                    false_br = self._nodes[node.input[0]]
+                    true_br = self._nodes[node.input[1]]
+                    assert len(true_br) == 1
+                    assert len(false_br) == 1
+                    branch.true_branch = true_br[0]
+                    branch.false_branch = false_br[0]
+                    op = [branch.if_node()]
+            elif node.op == "Exit":
+                loop = self._loops[node_name_prefix]
+                exit_name = node.name.split('/')[-1]
+                assert exit_name.startswith('Exit')
+                # Exit nodes are named "Exit" or "Exit_<n>", where <n> is the
+                # index of the loop variable this node returns.
+                exit_number = int(exit_name[5:]) if len(exit_name) > 4 else 0
+                expr = loop.while_loop()
+                op = _expr.TupleGetItem(expr, exit_number)
+            elif node.op == "Enter":
+                op = self._nodes[node.input[0]]
+            elif node.op == "LoopCond":
+                op = self._nodes[node.input[0]]
+                assert len(op) == 1
+                self._loops[node_name_prefix].cond = op[0]
+            elif node.op == "Switch":
+                op = self._nodes[node.input[0]]
+                assert len(op) == 1
+                if _in_while_loop(control_flow_node_map, node_name_prefix):
+                    self._loops[node_name_prefix].loop_vars.append(op[0])
+                else:
+                    if node_name_prefix not in self._branches:
+                        self._branches[node_name_prefix] = Branch()
+                    self._branches[node_name_prefix].cond = ir_pass.infer_type(op[0])
+            elif node.op == "NextIteration":
+                op = self._nodes[node.input[0]]
+                assert len(op) == 1
+                self._loops[node_name_prefix].body.append(op[0])
+            else:
+                op = self._convert_operator(node.op, inputs, attr, graph)
 
             # Check if op is converted to param
             if isinstance(op, np.ndarray):
@@ -1351,10 +1497,12 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             out_type = ir_pass.infer_type(node_output[0])
             self._output_shapes[node.name] = [get_const_tuple(out_type.checked_type.shape)]
 
-        out = []
         if outputs is None:
-            out = op
+            if node.op == "Exit":
+                out = [op[0].tuple_value]
+            else:
+                out = op
         else:
             out = [self._nodes[out_name][0] for out_name in outputs]
 
@@ -1384,7 +1532,9 @@ def _parse_import_prerequisites(self, graph):
             elif node.op == "Const":
                 pass
             else:
-                if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
+                if any([node.op in t for t in [_identity_list, _convert_map,
+                                               _convert_map_rnn,
+                                               _control_flow_nodes]]):
                     pass
                 else:
                     missing_operators.add(node.op)
diff --git a/tests/python/relay/test_tf_loop_to_relay.py b/tests/python/relay/test_tf_loop_to_relay.py
new file mode 100644
index 0000000000000..49196123274c8
--- /dev/null
+++ b/tests/python/relay/test_tf_loop_to_relay.py
@@ -0,0 +1,298 @@
+"""Unit tests for converting TensorFlow control flow ops to Relay."""
+import tensorflow as tf
+import numpy as np
+from tvm import relay
+from tvm.relay.frontend.tensorflow import from_tensorflow
+
+
+def check_equal(graph, tf_out):
+    expr, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
+    ex = relay.create_executor('debug')
+    relay_out = ex.evaluate(expr)(**params)
+    if isinstance(relay_out, relay.backend.interpreter.TensorValue):
+        np.testing.assert_allclose(tf_out, relay_out.asnumpy())
+    else:
+        if not isinstance(tf_out, list):
+            tf_out = [tf_out]
+        for x, y in zip(tf_out, [r.asnumpy() for r in relay_out]):
+            np.testing.assert_allclose(x, y)
+
+
+def vanilla_loop():
+    graph = tf.Graph()
+    with graph.as_default():
+        i = tf.constant(0)
+
+        def c(i): return tf.less(i, 10)
+
+        def b(i): return tf.add(i, 1)
+        r = tf.while_loop(c, b, [i])
+
+        with tf.Session() as sess:
+            tf_out = sess.run(r)
+
+    check_equal(graph, tf_out)
+
+
+def loop_2_vars():
+    graph = tf.Graph()
+    with graph.as_default():
+        i0 = tf.constant(0)
+        j0 = tf.ones([2, 2])
+
+        def c(i, j): return i < 10
+
+        def b(i, j): return [tf.add(i, 1), j]
+        i1, i2 = tf.while_loop(c, b, loop_vars=[i0, j0])
+        i1 += tf.constant(1337)
+
+        with tf.Session() as sess:
+            tf_out = sess.run(i1)
+
+    check_equal(graph, tf_out)
+
+
+def loop_3_vars():
+    graph = tf.Graph()
+    with graph.as_default():
+        i0 = tf.constant(1)
+        j0 = tf.constant(2)
+        k0 = tf.constant(4)
+
+        def c(i, j, k): return i < 10
+
+        def b(i, j, k): return [i + 1, j * k, k + i]
+        r = tf.while_loop(c, b, loop_vars=[i0, j0, k0])
+
+        with tf.Session() as sess:
+            tf_out = sess.run(r)
+
+    check_equal(graph, tf_out)
+
+
+def loop_conditions():
+    graph = tf.Graph()
+    with graph.as_default():
+        i = tf.constant(1)
+        j = tf.constant(1)
+        k = tf.constant(5)
+
+        def c(i, j, k): return \
+            tf.equal(tf.not_equal(tf.less(i + j, 10),
+                                  tf.less(j * k, 100)),
+                     tf.greater_equal(k, i + j))
+
+        def b(i, j, k): return [i + j, j + k, k + 1]
+        r = tf.while_loop(c, b, loop_vars=[i, j, k])
+        with tf.Session() as sess:
+            tf_out = sess.run(r)
+
+    check_equal(graph, tf_out)
+
+
+def loop_bodies():
+    graph = tf.Graph()
+    with graph.as_default():
+        def body(x):
+            a = tf.constant(np.array([[5, 6], [7, 8]]), dtype=tf.int32)
+            b = tf.constant(np.array([[1, 2], [3, 4]]), dtype=tf.int32)
+            c = a + b
+            return tf.nn.relu(x + c)
+
+        def condition(x):
+            return tf.reduce_sum(x) < 100
+        x = tf.constant(0, shape=[2, 2])
+        r = tf.while_loop(condition, body, [x])
+        with tf.Session() as sess:
+            tf_out = sess.run(r)
+
+    check_equal(graph, tf_out)
+
+
+def nested_loop():
+    graph = tf.Graph()
+    with graph.as_default():
+
+        def body(x):
+            def nest_body(c):
+                return tf.multiply(c, 2)
+
+            def cd(c): return tf.less(c, 10)
+            c = tf.constant(2)
+            res = tf.while_loop(cd, nest_body, loop_vars=[c])
+            return tf.nn.relu(x + res)
+
+        def condition(x):
+            return tf.greater(x, 100)
+        x = tf.constant(3)
+        r = tf.while_loop(condition, body, loop_vars=[x])
+        with tf.Session() as sess:
+            tf_out = sess.run(r)
+
+    check_equal(graph, tf_out)
+
+
+def vanilla_cond():
+    graph = tf.Graph()
+    with graph.as_default():
+        i = tf.constant(1)
+        j = tf.constant(4)
+
+        def f1():
+            return tf.multiply(1, 17)
+
+        def f2():
+            return tf.add(4, 23)
+        r = tf.cond(tf.less(i, j), f1, f2)
+
+        with tf.Session(graph=graph) as sess:
+            tf_out = sess.run(r)
+
+    check_equal(graph, tf_out)
+
+
+def multiple_cond_vars():
+    graph = tf.Graph()
+    with graph.as_default():
+        x1 = tf.constant(7)
+        x2 = tf.constant(12)
+        r = tf.cond(tf.less(tf.add(x1, x2), 10),
+                    lambda: tf.add(10, 2), lambda: tf.square(5))
+        with tf.Session() as sess:
+            tf_out = sess.run(r)
+
+    check_equal(graph, tf_out)
+
+
+def cond_fn_parameters():
+    graph = tf.Graph()
+    with graph.as_default():
+        def fn1(x, y):
+            return tf.multiply(5, 6)
+
+        def fn2(x, y):
+            return tf.add(3, 4)
+
+        i = tf.constant(1)
+        j = tf.constant(2)
+        k = tf.constant(3)
+        r = tf.cond(tf.less(i, j), lambda: fn1(i, k), lambda: fn2(j, k))
+        with tf.Session() as sess:
+            tf_out = sess.run(r, feed_dict={i: 1, j: 2, k: 3})
+
+    check_equal(graph, tf_out)
+
+
+def nested_cond():
+    graph = tf.Graph()
+    with graph.as_default():
+        def fn1(a, b):
+            def nest_fn1():
+                return tf.add(1, 2)
+
+            def nest_fn2():
+                return tf.subtract(10, 5)
+
+            res = tf.cond(tf.less(1, 2), nest_fn1, nest_fn2)
+            return tf.multiply(tf.add(87, res), 10)
+
+        def fn2(a, b):
+            return tf.add(10, 10)
+
+        x = tf.constant(5)
+        y = tf.constant(6)
+        z = tf.constant(7)
+        pred = tf.less(x, y)
+        r = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))
+        with tf.Session() as sess:
+            tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True})
+
+    check_equal(graph, tf_out)
+
+
+def loop_in_cond():
+    graph = tf.Graph()
+    with graph.as_default():
+        def fn1(a, b):
+            i = tf.constant(0)
+
+            def cd(i): return tf.less(i, 10)
+
+            def bd(i): return tf.add(i, 1)
+            res = tf.while_loop(cd, bd, [i])
+            return tf.multiply(tf.add(20, res), 10)
+
+        def fn2(a, b):
+            return tf.add(10, 20)
+
+        x = tf.constant(7)
+        y = tf.constant(20)
+        z = tf.constant(10)
+        pred = tf.less(x, y)
+        r = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))
+
+        with tf.Session() as sess:
+            tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True})
+
+    check_equal(graph, tf_out)
+
+
+def cond_in_loop():
+    graph = tf.Graph()
+    with graph.as_default():
+        def body(x):
+            x = tf.constant(7)
+            z = tf.constant(20)
+            res = tf.cond(tf.less(x, 10),
+                          lambda: tf.add(10, 20),
+                          lambda: tf.square(10))
+            return tf.multiply(res, x)
+
+        x = tf.constant(21)
+
+        def condition(x):
+            return tf.less(x, 100)
+
+        r = tf.while_loop(condition, body, loop_vars=[x])
+        with tf.Session() as sess:
+            tf_out = sess.run(r)
+
+    check_equal(graph, tf_out)
+
+
+def loop_lambda_placeholder():
+    graph = tf.Graph()
+    with graph.as_default():
+        c = lambda i, j: tf.equal(tf.less(i, 17), tf.greater(j, 7))
+        b = lambda i, j: [i + 3, j - 13]
+
+        i = tf.placeholder(tf.float32)
+        j = tf.placeholder(tf.float32)
+        r = tf.while_loop(c, b, loop_vars=[i, j])
+
+        with tf.Session() as sess:
+            tf_out = sess.run(r, feed_dict={i: -203, j: 107})
+
+    check_equal(graph, tf_out)
+
+
+if __name__ == "__main__":
+
+    # tf.while_loop
+    vanilla_loop()
+    loop_2_vars()
+    loop_3_vars()
+    loop_conditions()
+    loop_bodies()
+
+    # tf.cond
+    vanilla_cond()
+    multiple_cond_vars()
+    cond_fn_parameters()
+
+    # nested cases
+    nested_loop()
+    nested_cond()
+    loop_in_cond()
+    cond_in_loop()
+
+    # w/ placeholder and lambda
+    loop_lambda_placeholder()
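Reviewer note: for symmetry with the loop sketch above, this is a sketch, not part of the patch, of the relay.If expression that Branch.if_node produces for the vanilla_cond test; names are illustrative and only relay APIs already used in the patch are assumed.

    from tvm import relay

    i = relay.var("i", shape=(), dtype="int32")
    j = relay.var("j", shape=(), dtype="int32")
    # Branch._if_node reduces the predicate tensor to a scalar boolean.
    cond = relay.op.min(relay.less(i, j))
    true_branch = relay.multiply(relay.const(1), relay.const(17))
    false_branch = relay.add(relay.const(4), relay.const(23))
    func = relay.Function([i, j], relay.If(cond, true_branch, false_branch))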