From ae119f81e714e302a813d2440d288efa4c8a35e3 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Thu, 11 Jun 2020 22:54:23 -0700 Subject: [PATCH] [Frontend][TensorFlow] Improve Control Flow and TensorArray (#5699) * Improve TF parser control flow and tensor array * Fix tf tensor array scatter * Add ssd test * Add back static ta test * Minor fix for frontend and test_forward * SplitRel for dynamic shape * Fix test ssd * Fix loop var naming issue * Minor improve * Fix format * Fix clang format * Fix tensor array in pytorch frontend * Fix stack size issue for ssd test * Address comments * Fix slice size * Fix build * Rebase --- python/tvm/relay/frontend/common.py | 8 +- python/tvm/relay/frontend/pytorch.py | 18 +- python/tvm/relay/frontend/tensorflow.py | 814 ++++++++++-------- python/tvm/relay/prelude.py | 9 +- src/relay/op/tensor/transform.cc | 26 +- .../frontend/tensorflow/test_control_flow.py | 26 +- .../frontend/tensorflow/test_forward.py | 81 +- tests/python/relay/test_adt.py | 1 - 8 files changed, 602 insertions(+), 381 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 05222c65ecd1..6310e3bfcf29 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -497,13 +497,13 @@ def infer_value(input_val, params, mod=None): portion of the relay graph. This is often needed for functions that whose output shape depends on the value of a tensor. """ + # Check that all free variables have associated parameters. + assert all(var.name_hint in params.keys() for var in analysis.free_vars( + input_val)), "All inputs to infer must be available in params." try: # TODO(kevinthesun): Use VM for all cases. # pylint: disable=import-outside-toplevel from tvm.contrib import graph_runtime - # Check that all free variables have associated parameters. - assert all(var.name_hint in params.keys() for var in analysis.free_vars( - input_val)), "All inputs to infer must be available in params." 
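The hunk above moves the free-variable check ahead of the try block so it guards both the graph-runtime path and the VM fallback, and the rest of the patch threads `mod` into `_infer_value` calls throughout the TensorFlow frontend. A minimal sketch of that calling pattern (the converter name is hypothetical; `_infer_value` is the frontend alias for `common.infer_value`):

    from tvm.relay.frontend.common import infer_value as _infer_value

    def _example_fold_size(inputs, params, mod):
        # Try to constant-fold the size argument; passing mod lets type
        # inference see global/Prelude definitions referenced by the expr.
        try:
            return _infer_value(inputs[1], params, mod).asnumpy().tolist()
        except Exception:
            # Keep it symbolic; downstream ops handle the dynamic case.
            return inputs[1]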
func = _function.Function(analysis.free_vars(input_val), input_val) with tvm.transform.PassContext(opt_level=0): graph, lib, params = tvm.relay.build(func, target="llvm", params=params) @@ -520,7 +520,7 @@ def infer_value(input_val, params, mod=None): exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm") inputs = [] for param in mod['main'].params: - inputs.append(tvm.nd.array(params[param.name_hint])) + inputs.append(params[param.name_hint]) result = exc.evaluate()(*inputs) return result diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 380388a3df58..2113d7d1f796 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -211,12 +211,12 @@ def tensor_array_concat(lst, axis): assert axis == 0, "Tensor array concat supported only for axis 0" tensor_array, shape = _convert_to_tensor_array(lst, prelude) concat_shape = (Any(),) + shape[1:] - static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) - static_tensor_array_ops.define_tensor_get_data(concat_shape) - concat = prelude.get_var_static('tensor_array_concat', "float32", shape) concatenated = concat(tensor_array) - get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape) + + static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", concat_shape) + static_tensor_array_ops.register() + get_tensor = prelude.get_var_static('tensor_get_data', "float32", concat_shape) return get_tensor(concatenated) def _impl(inputs, input_types): @@ -1619,14 +1619,14 @@ def _impl(inputs, input_types): def _tensor_array_stack(prelude): def _impl(inputs, input_types): tensor_array, shape = _convert_to_tensor_array(inputs[0], prelude) + + stacked_shape = (Any(),) + shape stack = prelude.get_var_static('tensor_array_stack', "float32", shape) stacked = stack(tensor_array) - stacked_shape = (Any(),) + shape - static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) - static_tensor_array_ops.define_tensor_get_data(stacked_shape) - # passing stacked_shape below gives "'Prelude' object has no attribute" error - get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape) + static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", stacked_shape) + static_tensor_array_ops.register() + get_tensor = prelude.get_var_static('tensor_get_data', "float32", stacked_shape) return get_tensor(stacked) return _impl diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 8a10cbe8d59e..f65446691023 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -27,7 +27,7 @@ from tvm.ir import IRModule from tvm.relay.prelude import Prelude, StaticTensorArrayOps, get_tensor_array_shape -from tvm.ir import structural_hash as s_hash +from topi.util import get_const_tuple from .. import analysis from .. import expr as _expr @@ -40,7 +40,6 @@ from .common import infer_shape as _infer_shape from .common import infer_channels as _infer_channels from .common import infer_value as _infer_value -from .common import infer_value_simulated as _infer_value_simulated __all__ = ['from_tensorflow'] @@ -96,6 +95,23 @@ def _get_tuple_param(params, input_node): def _need_prelude_for_shape_inference(op): return "TensorArray" in op +def _get_more_static_shape(shape0, shape1): + """Compare two shapes with the same rank, + and return the one with fewer symbolic dimension. 
+ """ + assert len(shape0) == len(shape1) + num_sym_dim0 = 0 + num_sym_dim1 = 0 + for dim0, dim1 in zip(list(shape0), list(shape1)): + if not isinstance(dim0, int): + num_sym_dim0 += 1 + if not isinstance(dim1, int): + num_sym_dim1 += 1 + + if num_sym_dim0 < num_sym_dim1: + return shape0 + return shape1 + def _rsqrt(): def _impl(inputs, attr, params, mod): inputs.append(tvm.relay.const(-0.5, attr['T'].name)) @@ -275,7 +291,7 @@ def _impl(inputs, attr, params, mod): inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2] # NCHW Layout require weights transpose - weights_shape = _infer_shape(inputs[1]) + weights_shape = _infer_shape(inputs[1], mod) if attr['data_format'] == 'NCHW': tmp_shape = weights_shape if opname in ['conv', 'conv_transpose']: @@ -287,7 +303,7 @@ def _impl(inputs, attr, params, mod): weights_shape = tmp_shape - input_shape = _infer_shape(inputs_data) + input_shape = _infer_shape(inputs_data, mod) if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] inputs_data = _op.transpose(inputs_data, axes=(0, 3, 1, 2)) @@ -379,9 +395,6 @@ def _impl(inputs, attr, params, mod): else: attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW' - use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4) - channel_axis = 1 if attr['data_format'] == "NCHW" else 3 - # Ignore the new attributes from TF2.0, for now. out = AttrCvt( op_name=_dimension_picker('conv', @@ -394,11 +407,6 @@ def _impl(inputs, attr, params, mod): 'group': ('groups', 1)}, custom_check=_dimension_constraint())([inputs_data, inputs[1]], attr) - if use_bias: - out = _op.nn.bias_add(out, - inputs[2] if opname != 'conv_transpose' else inputs[3], - axis=channel_axis) - if flip_layout: out = _op.transpose(out, axes=(0, 2, 3, 1)) @@ -689,7 +697,7 @@ def _impl(inputs, attr, params, mod): try: crop_size = _get_list_param(params, inputs[3]) except (IndexError, KeyError): - crop_size = _infer_value(inputs[3], params).asnumpy().tolist() + crop_size = _infer_value(inputs[3], params, mod).asnumpy().tolist() method = attr['method'].decode() method = 'nearest_neighbor' if method == 'nearest' else method @@ -723,9 +731,9 @@ def _impl(inputs, attr, params, mod): # Important that the size is defined. If an axis is not, we need to infer what # the shape should be. if -1 in size: - size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist() + size = _infer_value(inputs[1], params, mod).asnumpy().reshape([-1]).tolist() else: - size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist() + size = _infer_value(inputs[1], params, mod).asnumpy().reshape([-1]).tolist() attr['size'] = size inputs.pop(1) @@ -844,52 +852,20 @@ def _impl(inputs, attr, params, mod): def _tensor_array(): def _impl(inputs, attr, params, prelude): - try: - from tensorflow.python.framework import tensor_util - except ImportError as e: - raise ImportError( - "Unable to import tensorflow which is required {}".format(e)) - dtype_str = attr.get('dtype').name assert not attr["dynamic_size"], "Dynamic size tensor array is " \ "not supported in TVM yet." - raw_elem_shape = tensor_util.TensorShapeProtoToList(attr['element_shape']) - elem_shape = [] - for dim in raw_elem_shape: - if dim < 0: - elem_shape.append(Any()) - else: - elem_shape.append(dim) - - if elem_shape: - # Element shape is specified. - # Directly create static tensor array with given shape. 
- static_tensor_array_ops = StaticTensorArrayOps(prelude, - dtype_str, - elem_shape) - static_tensor_array_ops.register() - tensor_array_constructor = prelude.get_var_static('tensor_array', - dtype_str, - elem_shape) - tensor_array = tensor_array_constructor(inputs[0]) - _static_tensor_array_map[tensor_array] = tensor_array - elif attr['identical_element_shapes']: - # identical_element_shapes is set but element shape is not given. - # We create a static tensor array with dummy shape and record it in - # _static_tensor_array_map. Later when creating other tensor array ops - # which uses this tensor array, we reconstruct this tensor array with - # actual shape. - dummy_shape = () + if "shape" in attr: + shape = attr["shape"] static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, - dummy_shape) + shape) static_tensor_array_ops.register() tensor_array_constructor = prelude.get_var_static('tensor_array', dtype_str, - dummy_shape) + shape) tensor_array = tensor_array_constructor(inputs[0]) - _static_tensor_array_map[tensor_array] = None else: tensor_array_constructor = prelude.get_var('tensor_array', dtype_str) tensor_array = tensor_array_constructor(inputs[0]) @@ -912,21 +888,12 @@ def _impl(inputs, attr, params, prelude): values = unstack_function(inputs[2]) tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str) else: + input_t_shape = _get_more_static_shape(input_t_shape, input_shape) + values_shape = (values_shape[0],) + input_t_shape static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_t_shape) static_tensor_array_ops.register() - # For scatter operation, it is possible to write to a newly create - # tensor array. We need to check and recreate its input tensor array. - if input_ta in _static_tensor_array_map and \ - _static_tensor_array_map[input_ta] is None: - ta_constructor = prelude.get_var_static('tensor_array', - dtype_str, - input_t_shape) - new_ta = ta_constructor(input_ta.args[0]) - _static_tensor_array_map[input_ta] = new_ta - input_ta = new_ta - # Register static indices shape if isinstance(indices_shape[0], int): static_tensor_array_ops.define_tensor_array_scatter(indices_shape, True) @@ -960,24 +927,28 @@ def _impl(inputs, attr, params, prelude): dtype_str, input_shape) static_tensor_array_ops.register() + if not isinstance(indices_shape[0], int): gather_function = prelude.get_var_static('tensor_array_gather', dtype_str, input_shape) out_tensor_t = gather_function(inputs[2], inputs[1]) + out_shape = (indices_shape[0],) + input_shape + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + out_shape) + static_tensor_array_ops.register() # Output shape is (indices_shape[0],) + input_shape - static_tensor_array_ops.define_tensor_get_data((indices_shape[0],) + input_shape) get_data_func = prelude.get_var_static('tensor_get_data', dtype_str, - input_shape) + out_shape) out = get_data_func(out_tensor_t) else: # For fixed length indices, directly generate static shape output read_func = prelude.get_var_static('tensor_array_read', dtype_str, input_shape) - static_tensor_array_ops.define_tensor_get_data(input_shape) get_data_func = prelude.get_var_static('tensor_get_data', dtype_str, input_shape) @@ -987,7 +958,10 @@ def _impl(inputs, attr, params, prelude): out_tensor = get_data_func(read_func(inputs[2], index)) tensor_list.append(_op.expand_dims(out_tensor, axis=0)) - out = _op.concatenate(tensor_list, axis=0) + if indices_shape[0] > 1: + out = _op.concatenate(tensor_list, axis=0) + else: + out = tensor_list[0] 
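    # With this patch StaticTensorArrayOps.register() also defines
    # tensor_get_data for its dtype/shape (see the prelude.py hunk below), so
    # the gather path simply re-registers with the gathered output shape
    # instead of calling define_tensor_get_data directly. Standalone usage,
    # mirroring tests/python/relay/test_adt.py (shape values illustrative):
    #     p = Prelude(tvm.IRModule())
    #     ops = StaticTensorArrayOps(p, "float32", (2, 3))
    #     ops.register()
    #     get_data = p.get_var_static("tensor_get_data", "float32", (2, 3))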
return out return _impl @@ -1011,34 +985,30 @@ def _impl(inputs, attr, params, prelude): v = tensor_func(inputs[2]) write_func = prelude.get_var('tensor_array_write', dtype_str) else: - # For write operation, it is possible to write to a newly create - # tensor array. We need to check and recreate its input tensor array. - if input_ta in _static_tensor_array_map and \ - _static_tensor_array_map[input_ta] is None: - static_tensor_array_ops = StaticTensorArrayOps(prelude, - dtype_str, - input_t_shape) - static_tensor_array_ops.register() - ta_constructor = prelude.get_var_static('tensor_array', - dtype_str, - input_t_shape) - new_ta = ta_constructor(input_ta.args[0]) - _static_tensor_array_map[input_ta] = new_ta - input_ta = new_ta - input_ta_shape = input_t_shape - else: - input_ta_rank = len(input_ta_shape) - assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \ - format(input_ta_rank, input_rank) - static_tensor_array_ops = StaticTensorArrayOps(prelude, - dtype_str, - input_ta_shape) - static_tensor_array_ops.register() + input_ta_rank = len(input_ta_shape) + assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \ + format(input_ta_rank, input_rank) + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + input_ta_shape) + static_tensor_array_ops.register() tensor_func = prelude.get_var_static("tensor_constructor", dtype_str, input_ta_shape) v = tensor_func(inputs[2]) + # Write tensor with more static shape + actual_shape = _get_more_static_shape(input_t_shape, input_ta_shape) + if actual_shape != input_t_shape: + new_shape = [] + num_any_dim = 0 + for dim in actual_shape: + if not isinstance(dim, int): + num_any_dim += 1 + new_shape.append(dim if isinstance(dim, int) else -1) + if num_any_dim <= 1: + v = tensor_func(_op.reshape(inputs[2], new_shape)) + write_func = prelude.get_var_static('tensor_array_write', dtype_str, input_ta_shape) @@ -1059,7 +1029,6 @@ def _impl(inputs, attr, params, prelude): dtype_str, input_shape) static_tensor_array_ops.register() - static_tensor_array_ops.define_tensor_get_data(input_shape) read_func = prelude.get_var_static("tensor_array_read", dtype_str, input_shape) out_tensor = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) get_data_func = prelude.get_var_static('tensor_get_data', @@ -1075,39 +1044,22 @@ def _impl(inputs, attr, params, prelude): dtype_str = attr.get('T').name input_ta = inputs[0] input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) - input_t_shape = _infer_shape(inputs[1], prelude.mod) - input_rank = len(input_t_shape) lengths = _op.cast(inputs[2], 'int32') lengths_shape = _infer_shape(lengths, prelude.mod) value_shape = _infer_shape(inputs[1], prelude.mod) + input_rank = len(value_shape) if input_ta_shape is None: v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1]) split_func = prelude.get_var('tensor_array_split', dtype_str) else: - # For split operation, it is possible to write to a newly create - # tensor array. We need to check and recreate its input tensor array. 
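The write path above now coerces the incoming value toward the more static of the two shapes, relying on reshape accepting at most one -1 for the single remaining unknown dimension. A minimal standalone sketch of that coercion (shapes are illustrative):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(relay.Any(), relay.Any()), dtype="float32")
    # Target the more static shape (Any(), 16): keep the known 16 and write
    # the one symbolic dimension as -1.
    y = relay.reshape(x, newshape=[-1, 16])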
- if input_ta in _static_tensor_array_map and \ - _static_tensor_array_map[input_ta] is None: - input_ta_shape = (Any(),) + input_t_shape[1:] - static_tensor_array_ops = StaticTensorArrayOps(prelude, - dtype_str, - input_ta_shape) - static_tensor_array_ops.register() - ta_constructor = prelude.get_var_static('tensor_array', - dtype_str, - input_ta_shape) - new_ta = ta_constructor(input_ta.args[0]) - _static_tensor_array_map[input_ta] = new_ta - input_ta = new_ta - else: - input_ta_rank = len(input_ta_shape) - assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \ - format(input_ta_rank, input_rank) - static_tensor_array_ops = StaticTensorArrayOps(prelude, - dtype_str, - input_ta_shape) - static_tensor_array_ops.register() + input_ta_rank = len(input_ta_shape) + assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \ + format(input_ta_rank, input_rank) + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + input_ta_shape) + static_tensor_array_ops.register() # Check static value/indices shape if isinstance(value_shape[0], int) or isinstance(lengths_shape[0], int): @@ -1149,10 +1101,14 @@ def _impl(inputs, attr, params, prelude): static_tensor_array_ops.register() concat_func = prelude.get_var_static("tensor_array_concat", dtype_str, input_shape) out_tensor = concat_func(inputs[1]) - static_tensor_array_ops.define_tensor_get_data((Any(),) + input_shape[1:]) + out_shape = (Any(),) + input_shape[1:] + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + out_shape) + static_tensor_array_ops.register() get_data_func = prelude.get_var_static('tensor_get_data', dtype_str, - input_shape) + out_shape) out = get_data_func(out_tensor) return out @@ -1160,9 +1116,13 @@ def _impl(inputs, attr, params, prelude): def _tile(): def _impl(inputs, attr, params, mod): - reps = _get_list_param(params, inputs.pop()) - new_input = [] - new_input.append(inputs.pop(0)) + reps_input = inputs.pop() + if isinstance(reps_input, _expr.Call): + np_reps = _infer_value(reps_input, params, mod).asnumpy() + reps = [np_reps.flatten()[i] for i in range(np_reps.flatten().shape[0])] + else: + reps = _get_list_param(params, reps_input) + new_input = [inputs.pop(0)] return AttrCvt( op_name='tile', @@ -1177,7 +1137,7 @@ def _impl(inputs, attr, params, mod): except (IndexError, KeyError, AttributeError): # Handle symbolic begin try: - begin = _infer_value(inputs[1], params).asnumpy().tolist() + begin = _infer_value(inputs[1], params, mod).asnumpy().tolist() except Exception: begin = inputs[1] try: @@ -1185,10 +1145,21 @@ def _impl(inputs, attr, params, mod): except (IndexError, KeyError, AttributeError): # Handle symbolic size try: - size = _infer_value(inputs[2], params).asnumpy().tolist() + size = _infer_value(inputs[2], params, mod).asnumpy().tolist() except Exception: size = inputs[2] - return _op.strided_slice(inputs[0], begin=begin, end=size, slice_mode="size") + + # Align begin and strides for dynamic shape. 
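    # For example (sizes illustrative): a rank-4 input with begin=[0, 2] is
    # padded to begin=[0, 2, 0, 0] and strides=[1, 1, 1, 1], so the dynamic
    # strided_slice receives full-rank begin/strides even when TF supplied
    # fewer entries than the data rank.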
+ data_dim = len(_infer_shape(inputs[0], mod)) + strides = [1] * data_dim + if not isinstance(begin, (_expr.Call, _expr.Var)): + for _ in range(len(begin), data_dim): + begin.append(0) + elif not isinstance(size, (_expr.Call, _expr.Var)): + for _ in range(len(size), data_dim): + size.append(-1) + return _op.strided_slice(inputs[0], begin=begin, end=size, + strides=strides, slice_mode="size") return _impl @@ -1202,8 +1173,8 @@ def _impl(inputs, attr, params, mod): # Shape operator is already pruned, hence # try to infer shape by precompute prune if possible. try: - params_new = _infer_value(pop_node, params) - shape_arg = tuple(params_new.asnumpy().astype('int64').flatten()) + params_new = _infer_value(pop_node, params, mod) + shape_arg = tuple(params_new.asnumpy().astype('int32').flatten()) except Exception: # Deal with symbolic shape case. if isinstance(pop_node, _expr.Call) and \ @@ -1211,6 +1182,7 @@ def _impl(inputs, attr, params, mod): # shape_of is the direct ancestor. return _op.reshape_like(inputs[0], pop_node.args[0]) shape_arg = pop_node + return AttrCvt( op_name="reshape", extras={'newshape': shape_arg}, @@ -1218,6 +1190,7 @@ def _impl(inputs, attr, params, mod): return _impl + def _depth_to_space(): def _impl(inputs, attr, params, mod): block_size = int(attr['block_size']) @@ -1239,7 +1212,8 @@ def _impl(inputs, attr, params, mod): def _bias_add(): def _impl(inputs, attr, params, mod): # Must expand for proper broadcasting in NCHW. - if attr['data_format'].decode("utf-8") == 'NCHW': + if 'data_format' in attr and \ + attr['data_format'].decode("utf-8") == 'NCHW': bias = _op.reshape(inputs[1], newshape=(1, -1, 1, 1)) else: bias = inputs[1] @@ -1251,7 +1225,7 @@ def _impl(inputs, attr, params, mod): if isinstance(inputs[1], _expr.Var): shape = params[inputs[1].name_hint] else: - shape = _infer_value(inputs[1], params) + shape = _infer_value(inputs[1], params, mod) shape = list(shape.asnumpy().reshape([-1])) return _op.broadcast_to(inputs[0], shape) return _impl @@ -1286,7 +1260,7 @@ def _impl(inputs, attr, params, mod): # For run-time calculation moving_mean_shape = [int(n) for n in inputs[3].type_annotation.shape] moving_variance_shape = [int(n) for n in inputs[4].type_annotation.shape] - if (moving_mean_shape[0] == 0 and moving_variance_shape[0] == 0): + if moving_mean_shape[0] == 0 and moving_variance_shape[0] == 0: inputs[3] = _op.mean(inputs[0], axis=axis, keepdims=False, exclude=True) inputs[4] = _op.variance(inputs[0], axis=axis, keepdims=False, exclude=True) out = AttrCvt(op_name='batch_norm', @@ -1352,7 +1326,10 @@ def _impl(inputs, attr, params, mod): # Output shape must be defined to avoid errors. If any axis is not, we must # try to compute its shape. 
if output_shape is None or -1 in output_shape: - output_shape = _infer_value(inputs[0], params).asnumpy().reshape([-1]).tolist() + try: + output_shape = _expr.Constant(_infer_value(inputs[0], params, mod)) + except Exception: + output_shape = inputs[0] fill_arg = _get_num_param(params, inputs.pop(1)) dtype = attr['T'].name @@ -1387,6 +1364,8 @@ def _reduce(op): def _impl(inputs, attr, params, mod): axis = _get_list_param(params, inputs[1]) axis = tuple(axis) + if not axis: + axis = None return AttrCvt( op_name=op, extras={'axis': axis}, @@ -1444,15 +1423,49 @@ def _impl(inputs, attr, params, mod): begin = _get_list_param(params, inputs[1]) end = _get_list_param(params, inputs[2]) stride = _get_list_param(params, inputs[3]) + begin_mask = int(attr.get('begin_mask', 0)) end_mask = int(attr.get('end_mask', 0)) ellipsis_mask = int(attr.get('ellipsis_mask', 0)) new_axis_mask = int(attr.get('new_axis_mask', 0)) shrink_axis_mask = int(attr.get('shrink_axis_mask', 0)) - data_shape = _infer_shape(inputs[0], mod) + in_type = _infer_type(inputs[0], mod) + data_shape = get_const_tuple(in_type.checked_type.shape) data_dim = len(data_shape) stride_dim = len(stride) + # This is a special routine to handle strided_slice after shape_of. + # We need this since in some cases we want to do strided_slice on + # a partial symbolic shape, such as (1, ?), and get a static shape + # (1,). Directly slice on shape_of will result in fully dynamic shape. + # TODO(kevinthesun): Can we generalize this process with partial eval? + if isinstance(inputs[0], _expr.Call) and inputs[0].op == _op.get("shape_of"): + bg = begin[0] + ed = end[0] + st = stride[0] + + if ed <= 0 < st: + ed += data_shape[0] + + in_shape = _infer_shape(inputs[0].args[0], mod) + dtype = in_type.checked_type.dtype + out_data = [] + idx = bg + while idx < ed: + if isinstance(in_shape[idx], int): + out_data.append(in_shape[idx]) + else: + break + idx += st + + # Only return when in_shape is fully static in the range from begin to end. 
+ if idx >= st: + ret = _expr.const(out_data, dtype) + if shrink_axis_mask: + ret = _op.squeeze(ret) + + return ret + def _transform_mask(stride_dim, ellipsis_mask): """Handle mask inputs to create new begin, end, stride and output shape""" m_begin = [0] * data_dim @@ -1492,19 +1505,19 @@ def _transform_mask(stride_dim, ellipsis_mask): break if mask & begin_mask: m_begin[final_index] = data_shape[final_index] \ - if stride[index] < 0 else 0 + if stride[index] < 0 else 0 elif begin[index]: m_begin[final_index] = begin[index] if mask & end_mask: m_end[final_index] = 0 if stride[index] < 0 \ - else data_shape[final_index] + else data_shape[final_index] elif end[index]: m_end[final_index] = end[index] m_stride[final_index] = stride[index] if mask & shrink_axis_mask: #Tensorflow make axis with shrink_axis_mask as dimension 1 m_begin[final_index] = data_shape[final_index] + begin[index] \ - if begin[index] < 0 else begin[index] + if begin[index] < 0 else begin[index] m_end[final_index] = begin[index] + 1 m_stride[final_index] = 1 fshape_indices.append(-2) @@ -1588,7 +1601,7 @@ def _impl(inputs, attr, params, mod): try: axes = _get_list_param(params, inputs[1]) except (IndexError, KeyError, AttributeError): - axes = _infer_value_simulated(inputs[1], params).asnumpy() + axes = _infer_value(inputs[1], params, mod).asnumpy().tolist() return _op.transpose(inputs[0], axes=axes) return _impl @@ -1620,7 +1633,8 @@ def _impl(inputs, attr, params, mod): input_shape = _infer_shape(inputs[0], mod) name = attr["_node_name"] - params[name] = tvm.nd.array([len(input_shape)]) + params[name] = tvm.nd.array(np.array([len(input_shape)]) + .astype("int32")) return [_expr.var(name, shape=params[name].shape, dtype='int32')] @@ -1633,24 +1647,22 @@ def _impl(inputs, attr, params, mod): start = _get_param(params, inputs[0])[0] except (IndexError, KeyError, AttributeError): try: - start = _infer_value(inputs[1], params).asnumpy().tolist() + start = _infer_value(inputs[1], params, mod).asnumpy().tolist() start = start if not isinstance(start, list) else start[0] except Exception: # Symbolic start start = inputs[0] - if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant): - limit = _get_param(params, inputs[1])[0] - else: - if any(['Rank' in param for param in params]): - limit = params.pop('Rank').asnumpy()[0] - else: - try: - limit = _infer_value(inputs[1], params, mod).asnumpy().tolist() - limit = limit if not isinstance(limit, list) else limit[0] - except Exception: - # Symbolic limit - limit = inputs[1] + try: + limit = _get_param(params, inputs[1])[0] \ + if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) \ + else params.pop('Rank').asnumpy()[0] + except (IndexError, KeyError, AttributeError): + try: + limit = _infer_value(inputs[1], params, mod).asnumpy().tolist() + limit = limit if not isinstance(limit, list) else limit[0] + except Exception: + limit = inputs[1] try: delta = _get_param(params, inputs[2])[0] @@ -1785,16 +1797,21 @@ def _impl(inputs, attr, params, mod): try: k = int(_get_num_param(params, k_input)) except (IndexError, KeyError, AttributeError): - k = int(_infer_value(k_input, params).asnumpy().tolist()) - if k < 1: - raise tvm.error.OpAttributeInvalid( - 'Attribute k must be positive in operator TopKV2') + try: + k = int(_infer_value(k_input, params, mod).asnumpy().tolist()) + except Exception: + k = k_input + if isinstance(k, int): + if k < 1: + raise tvm.error.OpAttributeInvalid( + 'Attribute k must be positive in operator TopKV2') + k = _expr.const(k) 
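The shape_of special case above is the heart of keeping partially symbolic shapes static: when every dimension the slice touches is already a Python int, the converter emits a constant instead of a dynamic strided_slice over shape_of. A standalone sketch of that folding rule (the function name and the "?" placeholder for a symbolic dimension are illustrative):

    def fold_static_shape_slice(in_shape, begin, end, stride=1):
        # Collect dimensions in [begin, end); bail out if any is symbolic.
        out = []
        idx = begin
        while idx < end:
            if not isinstance(in_shape[idx], int):
                return None        # fall back to the dynamic slice path
            out.append(in_shape[idx])
            idx += stride
        return out

    print(fold_static_shape_slice((1, "?", 256), 0, 1))   # -> [1]
    print(fold_static_shape_slice((1, "?", 256), 0, 2))   # -> None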
if attr['sorted'] is False: raise tvm.error.OpAttributeUnImplemented( 'Attribute sorted=False is not supported in operator TopKV2') return AttrCvt(op_name='topk', ignores=['sorted'], - extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})(inputs, attr) + extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})([inputs[0]], attr) return _impl def _floordiv(): @@ -1821,12 +1838,12 @@ def _impl(inputs, attr, params, mod): try: block_shape = _get_list_param(params, inputs[1]) except (IndexError, KeyError, AttributeError): - block_shape = _infer_value(inputs[1], params).asnumpy().tolist() + block_shape = _infer_value(inputs[1], params, mod).asnumpy().tolist() try: paddings = _get_list_param(params, inputs[2]) except (IndexError, KeyError, AttributeError): - paddings = _infer_value(inputs[2], params).asnumpy() + paddings = _infer_value(inputs[2], params, mod).asnumpy() paddings = np.squeeze(paddings) if len(paddings.shape) == 1: paddings = np.expand_dims(paddings, axis=0) @@ -1851,7 +1868,7 @@ def _impl(inputs, attr, params, mod): axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \ list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length)) permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes) - permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded) + permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded, mod) # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension, # producing an output tensor of shape: # [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ..., @@ -1871,12 +1888,12 @@ def _impl(inputs, attr, params, mod): try: block_shape = _get_list_param(params, inputs[1]) except (IndexError, KeyError, AttributeError): - block_shape = _infer_value(inputs[1], params).asnumpy().tolist() + block_shape = _infer_value(inputs[1], params, mod).asnumpy().tolist() try: crops = _get_list_param(params, inputs[2]) except (IndexError, KeyError, AttributeError): - crops = _infer_value(inputs[2], params).asnumpy() + crops = _infer_value(inputs[2], params, mod).asnumpy() crops = np.squeeze(crops) if len(crops.shape) == 1: crops = np.expand_dims(crops, axis=0) @@ -1905,7 +1922,7 @@ def _impl(inputs, attr, params, mod): # [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], # ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], # input_shape[M+1], ..., input_shape[N-1]] - reshaped_permuted_shape = _infer_shape(reshaped_permuted) + reshaped_permuted_shape = _infer_shape(reshaped_permuted, mod) cropped = reshaped_permuted for axis in range(1, M+1): crop = crops[axis - 1] @@ -2395,29 +2412,36 @@ def _get_abs_layer_name(node): # 1.x. _control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond'] -# A map to record tensor array with fixed rank shape -_static_tensor_array_map = {} - -class RewriteSubgraph(ExprMutator): - """ - A helper class to rewrite expr in while loop function to variable - - Parameters - ---------- - rewrite_map : Dict[expr, expr] - A dictionay contains a set of expr to var mapping. 
- """ - def __init__(self, rewrite_map): - ExprMutator.__init__(self) - self.rewrite_map = rewrite_map - - def visit(self, expr): - if expr in self.rewrite_map: - return self.rewrite_map[expr] - return super().visit(expr) +# A map to record tensor array write ops and input ta/tensor indices +# Value is (index of tensor array, index of written node) +_tensor_array_write_ops = { + "TensorArrayWrite" : (3, 2), + "TensorArrayScatter" : (0, 2), + "TensorArraySplit" : (0, 1), +} -def rewrite_subgraph(expr, rewrites): - return RewriteSubgraph(rewrites).visit(expr) +def is_tensor_array_constuctor(tf_node): + """Check whether is tensor array constructor node.""" + is_ta = False + ta_start = "TensorArrayV" + if tf_node.op.startswith(ta_start): + is_ta = tf_node.op[len(ta_start)].isnumeric() + return is_ta + +def find_parent_loop_name(node_name, while_loop_name_set): + """Find name of direct parent while loop.""" + ploop_name = "" + name_prefix = node_name.rsplit('/', 1)[0] + if name_prefix.startswith("^"): + name_prefix = name_prefix[1:] + for lname in while_loop_name_set: + if name_prefix.startswith(lname) and len(ploop_name) < len(lname): + ploop_name = lname + + if len(ploop_name) == 0: + ploop_name = name_prefix + + return ploop_name def _in_while_loop(control_flow_node_map, op_name): """ @@ -2444,6 +2468,28 @@ def _in_while_loop(control_flow_node_map, op_name): return op_name in control_flow_node_map and \ "LoopCond" in control_flow_node_map[op_name] +class RewriteSubgraph(ExprMutator): + """ + A helper class to rewrite expr in while loop function to variable. + + Parameters + ---------- + rewrite_map : Dict[expr, expr] + A dictionay contains a set of expr to var mapping. + """ + def __init__(self, rewrite_map): + ExprMutator.__init__(self) + self.rewrite_map = rewrite_map + + def visit(self, expr): + if expr in self.rewrite_map: + return self.rewrite_map[expr] + return super().visit(expr) + +def rewrite_subgraph(expr, rewrites): + """Rewrite loop body.""" + return RewriteSubgraph(rewrites).visit(expr) + class Branch: """A class contains the components that are used to build up a Relay if node. @@ -2524,118 +2570,50 @@ def if_node(self): self._if = self._if_node() return self._if +class VarChecker(ExprVisitor): + """Check whether a Variable is used in loop body. -class LoopBound(ExprVisitor): - """ - When a loop body is create, we get a Relay expression backtracing all - the way back to input node. This will result in lots of unnecessary - expression placed into loop body and compute multiple times. For example, - consider the following tensorflow code: - - .. code-block:: python - - i = tf.constant(0) - data = tf.compat.v1.placeholder(tf.float32, shape=(1024, 1024)) - slice = tf.strided_slice(data, 0, 512) - def c(i): return tf.less(i, 10) - def b(i): return [tf.add(i, 1), tf.add(i, 1) + slice] - r = tf.while_loop(c, b, [i]) - - If we directly create recursive function, slice will be placed into function body. - Instead, we recognize whether slice is inside while_loop block and pass it as an - extra loop variable to avoid duplicate computation. - - TODO(kevinthesun): Add a LICM pass for Relay to handle generic loop/function. + Parameters + ---------- + var : relay.expr.Var + Relay Variable to be checked. 
""" - def __init__(self, loop_name, hash2tfnode, while_loop_name_set): + def __init__(self, var): ExprVisitor.__init__(self) - self._loop_name = loop_name - self._hash2tfnode = hash2tfnode - self._while_loop_name_set = while_loop_name_set - self.extra_loop_var_names = set() - - def _find_parent_loop_name(self, node_name): - """Find name of direct parent while loop.""" - ploop_name = "" - name_prefix = node_name.rsplit('/', 1)[0] - if name_prefix.startswith("^"): - name_prefix = name_prefix[1:] - # To get the name of the direct parent while loop for a given node, - # we iterate all the while loop names inside TensorFlow graph def. - # If we find a loop name with which current node name starts, - # it means current node is under this loop. However, due to nested - # loop, this loop may not be the direct parent while loop of current - # node. We need to keep the longest loop name, which represents the - # innermost while loop corresponding to current node. - for lname in self._while_loop_name_set: - if name_prefix.startswith(lname) and len(ploop_name) < len(lname): - ploop_name = lname - - if len(ploop_name) == 0: - ploop_name = name_prefix - - return ploop_name + self._var = var + self.used = False def visit(self, expr): - """ - For each expression in the body, look up the corresponding - TensorFlow node with its structural hash. If the current loop is the - direct parent of this node, we check whether its every input node belongs - to the current loop. If not, we mark this input node as an extra loop - variable to the current loop. - """ - expr_hash = s_hash(expr) - - if expr_hash in self._hash2tfnode: - node = self._hash2tfnode[expr_hash] - ploop_name = self._find_parent_loop_name(node.name) - # It is possibel that a node is under nested loop of current loop. - # We only check the direct children of current loop. - if ploop_name == self._loop_name: - for iname in node.input: - iploop_name = self._find_parent_loop_name(iname) - # Use startswith to deal with nested loop - if not iploop_name.startswith(self._loop_name): - if iname not in self.extra_loop_var_names: - self.extra_loop_var_names.add(iname) + if self._var == expr: + self.used = True super().visit(expr) - class Loop: """ A class contains the components that are used to build up a Relay recursive call. - Parameters ---------- - loop_vars : List[tvm.relay.Expr] - The loop variables that used in a while loop. - - cond : tvm.relay.Expr - The condition of a while loop. + mod : tvm.IRModule + Module for current parsed IR. - body : tvm.relay.Expr - The body of a matched while loop. + loop_name : str + Name prefix of while loop in TensorFlow graph. - _loop : tvm.relay.Expr - An internal variable indicates where a recursive call is already created - for a matched TF while loop construct. + lvar2expr : dict from str to dict from Relay.expr.Var to Relay.expr + A dictionary recording all loop vars and corresponding + relay expression. Examples -------- The following is a vanilla loop from TensorFlow: - .. code-block:: python - i = tf.constant(0) c = lambda i: tf.less(i, 10) b = lambda i: tf.add(i, 1) r = tf.while_loop(c, b, [i]) - It will be converted to the following recursive call in Relay: - .. 
code-block:: python - fn (%while/Less/y: Tensor[(1,), int32], %while/Add/y: Tensor[(1,), int32], %Const: Tensor[(1,), int32]) { @@ -2657,86 +2635,74 @@ class Loop: %6 } """ - def __init__(self, mod, loop_name, hash2tfnode, - node_map, while_loop_name_set): - self.loop_vars = [] + def __init__(self, mod, loop_name, lvar2expr): self.cond = None self.body = [] self._loop = None self._mod = mod self._loop_name = loop_name - self._hash2tfnode = hash2tfnode - self._node_map = node_map - self._while_loop_name_set = while_loop_name_set + self._lvar2expr = lvar2expr + self.loop_vars = [] + self.aligned = False def _while_loop(self): """An internal API to create a Relay recursive call for a matched TF `while_loop` construct. """ + bind_map = {} wl = tvm.relay.var('while_loop') - sb = tvm.relay.scope_builder.ScopeBuilder() - loop_checker = LoopBound(self._loop_name, - self._hash2tfnode, - self._while_loop_name_set) - for body in self.body: - loop_checker.visit(body) - - loop_vars = [] - bind_map = {} - loop_var_hash_set = set() - for var in self.loop_vars: - loop_var_hash_set.add(s_hash(var)) - - extra_nodes = [] - for extra_loop_var_name in loop_checker.extra_loop_var_names: - extra_loop_var_name = extra_loop_var_name.split(':')[0].split("^")[-1] - extra_node = self._node_map[extra_loop_var_name] - extra_node = extra_node if isinstance(extra_node, _expr.Tuple) else extra_node[0] - if s_hash(extra_node) not in loop_var_hash_set: - self.loop_vars.append(extra_node) - extra_nodes.append(extra_node) - - for i, var in enumerate(self.loop_vars): - if not isinstance(var, _expr.Var): - var_chk = _infer_type(var, self._mod) - var_type = var_chk.checked_type - else: - var_type = var.type_annotation - - v = tvm.relay.var("loop_var" + str(i), type_annotation=var_type) - loop_vars.append(v) - bind_map[var] = v - - - self.cond = rewrite_subgraph(self.cond, bind_map) - self.body = [rewrite_subgraph(b, bind_map) for b in self.body] - - self.body_shape = [] - for body in self.body: - current_node = body - shape = _infer_shape(current_node, self._mod) - while not isinstance(shape, (tuple, list)): - current_node = current_node.args[-1] - shape = _infer_shape(current_node, self._mod) - self.body_shape.append(shape) + lv_list = [] + expr_list = [] + extra_vars = [] + + for i, lv in enumerate(self.loop_vars): + if self._loop_name not in self._lvar2expr: + self._lvar2expr[self._loop_name] = {} + + # Handle the case when loop var is not properly lifted. + # This can happen when loop var node name is set accidentally + # beginning with loop name. 
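    # (For instance, test_vanilla_loop below names its counter
    #  tf.constant(0, name="while/constant"): the node shares the "while"
    #  prefix, so _licm_construct does not lift it as loop invariant, and it
    #  is bound to a fresh loop_var here instead.)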
+ if lv not in self._lvar2expr[self._loop_name]: + var_name = "{}_loop_var_{}".format(self._loop_name, i) + var_type = _infer_type(lv, self._mod).checked_type + loop_var = tvm.relay.var(var_name, type_annotation=var_type) + self._lvar2expr[self._loop_name][loop_var] = lv + bind_map[lv] = loop_var + self.loop_vars[i] = loop_var + lv = loop_var + + lv_list.append(lv) + expr_list.append(self._lvar2expr[self._loop_name][lv]) + + if bind_map: + self.cond = rewrite_subgraph(self.cond, bind_map) + self.body = [rewrite_subgraph(b, bind_map) for b in self.body] cond = tvm.relay.op.min(self.cond) + for lv, exp in self._lvar2expr[self._loop_name].items(): + if lv not in self.loop_vars: + var_checker = VarChecker(lv) + for bd in self.body + [cond]: + var_checker.visit(bd) + if var_checker.used: + lv_list.append(lv) + expr_list.append(exp) + extra_vars.append(lv) + break + with sb.if_scope(cond): - extra_args = [] - if extra_nodes: - extra_args = list(loop_vars[-len(extra_nodes):]) - sb.ret(wl(*list(self.body + extra_args))) + sb.ret(wl(*list(self.body + extra_vars))) with sb.else_scope(): - sb.ret(tvm.relay.Tuple(loop_vars)) + sb.ret(tvm.relay.Tuple(lv_list)) - loop_fn = tvm.relay.Function(loop_vars, sb.get()) + loop_fn = tvm.relay.Function(lv_list, sb.get()) sb = tvm.relay.scope_builder.ScopeBuilder() sb.let(wl, loop_fn) - loop_ret = wl(*self.loop_vars) + loop_ret = wl(*expr_list) sb.ret(loop_ret) ret = sb.get() @@ -2770,9 +2736,13 @@ def __init__(self): self._control_flow_node_map = defaultdict(set) self._loop_body_order = {} self._loop_var_order = {} - self._hash2tfnode = {} + self._lvar2expr = {} + self._lname_map = {} + self._sorted_cf_node_names = [] self._while_loop_name_set = set() self._main_graph_proto = self + self._tensor_array_shapes = {} + self._tensor_array_shape_nodes = {} def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None): """Construct relay nodes from tensorflow graph definition - GraphDef. 
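The scenario the removed LoopBound class documented still applies to the new lvar2expr/VarChecker machinery: an expression computed outside tf.while_loop but used in its body is passed in as an extra loop variable rather than being recomputed inside the recursive Relay function. The TensorFlow example from that docstring illustrates the case:

    import tensorflow as tf

    i = tf.constant(0)
    data = tf.compat.v1.placeholder(tf.float32, shape=(1024, 1024))
    slice = tf.strided_slice(data, 0, 512)

    def c(i): return tf.less(i, 10)
    def b(i): return [tf.add(i, 1), tf.add(i, 1) + slice]

    r = tf.while_loop(c, b, [i])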
@@ -2820,6 +2790,9 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None): missing_operators = self._parse_import_prerequisites(graph) control_flow_nodes = [] + ta_write_nodes = [] + ta_gather_nodes = [] + ta_construct_nodes = [] self._in_shape = shape self._layout = layout self._graph = graph @@ -2883,6 +2856,50 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None): if node.op == "Exit": self._while_loop_name_set.add(node_name_prefix) control_flow_nodes.append(node) + elif node.op.startswith("TensorArray"): + if is_tensor_array_constuctor(node): + ta_construct_nodes.append(node) + else: + for ta_write_name, idx in _tensor_array_write_ops.items(): + if node.op.startswith(ta_write_name): + ta_write_nodes.append((node, idx)) + break + if node.op.startswith("TensorArrayGather"): + ta_gather_nodes.append(node) + + # Use tensor array gather to infer static tensor array shape + for gather_node in ta_gather_nodes: + input_ta_name = gather_node.input[0] + input_ta_node = self._tf_node_map[input_ta_name] + if is_tensor_array_constuctor(input_ta_node): + gather_attr = self._parse_attr(gather_node.attr) + if "element_shape" not in gather_attr: + continue + raw_elem_shape = tensor_util.TensorShapeProtoToList(gather_attr["element_shape"]) + elem_shape = [] + for dim in raw_elem_shape: + if dim < 0: + elem_shape.append(Any()) + else: + elem_shape.append(int(dim)) + self._tensor_array_shapes[input_ta_node.name] = elem_shape + + # Fetch node contains static tensor array shape + for item in ta_write_nodes: + wnode = item[0] + ta_idx, inode_idx = item[1] + + stack = [self._tf_node_map[wnode.input[ta_idx].split(":")[0]]] + while stack: + cnode = stack.pop(0) + if not cnode.op.startswith("TensorArray"): + for iname in cnode.input: + stack.append(self._tf_node_map[iname.split(":")[0]]) + elif cnode.name != wnode.name: + if is_tensor_array_constuctor(cnode): + inode = self._tf_node_map[wnode.input[inode_idx].split(":")[0]] + self._tensor_array_shape_nodes[cnode.name] = (inode, wnode.op) + break # First, parse all control flow nodes. # Convert tf.cond to Branch and tf.while_loop to Loop. @@ -2907,6 +2924,9 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None): if i == len(control_flow_nodes) - 1: sorted_cf_nodes.extend(exits) + for node in sorted_cf_nodes: + self._sorted_cf_node_names.append(node.name) + for node in sorted_cf_nodes: self._backtrack_construct(node.name) @@ -2940,7 +2960,13 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None): out.append(out_rnn) out = out[0] if len(out) == 1 else _expr.Tuple(out) - func = _function.Function(analysis.free_vars(out), out) + fvars = analysis.free_vars(out) + func = _function.Function(fvars, out) + final_params = {} + for fv in fvars: + if fv.name_hint in self._params: + final_params[fv.name_hint] = self._params[fv.name_hint] + self._params = final_params return func def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): @@ -3128,22 +3154,27 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_ Converted relay expression. 
""" node_name_prefix = node.name.rsplit('/', 1)[0] + plname = find_parent_loop_name(node.name, self._while_loop_name_set) if node.op == "Merge": if _in_while_loop(self._control_flow_node_map, node_name_prefix): - op = self._backtrack_construct(node.input[0]) + op = self._licm_construct(plname, node.input[0]) if node_name_prefix not in self._loops: self._loops[node_name_prefix] = Loop(self._mod, - node_name_prefix, - self._hash2tfnode, - self._nodes, - self._while_loop_name_set) + plname, + self._lvar2expr) else: - if len(self._branches) == 0: - raise RuntimeError("Cannot find a created " - "conditional for merge node") + if node_name_prefix not in self._branches: + switch_prefix = node_name_prefix + "/Switch" + merge_idx = self._sorted_cf_node_names.index(node.name) + for i in range(merge_idx - 1, -1, -1): + cf_name = self._sorted_cf_node_names[i] + if cf_name.startswith(switch_prefix): + self._backtrack_construct(cf_name) + break + branch = self._branches[node_name_prefix] - false_br = self._backtrack_construct(node.input[0]) - true_br = self._backtrack_construct(node.input[1]) + false_br = self._licm_construct(plname, node.input[0]) + true_br = self._licm_construct(plname, node.input[1]) branch.true_branch = true_br branch.false_branch = false_br op = branch.if_node() @@ -3184,13 +3215,13 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_ break op = _expr.TupleGetItem(expr, body_pos) elif node.op == "Enter": - op = self._backtrack_construct(node.input[0]) + op = self._licm_construct(plname, node.input[0]) elif node.op == "LoopCond": - op = self._backtrack_construct(node.input[0]) + op = self._licm_construct(plname, node.input[0]) self._loops[node_name_prefix].cond = op elif node.op == "Switch": - op = self._backtrack_construct(node.input[0]) - cond = self._backtrack_construct(node.input[1]) + op = self._licm_construct(plname, node.input[0]) + cond = self._licm_construct(plname, node.input[1]) if _in_while_loop(self._control_flow_node_map, node_name_prefix): if node_name_prefix not in self._loop_var_order: self._loop_var_order[node_name_prefix] = [] @@ -3212,7 +3243,7 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_ else: self._loop_body_order[node_name_prefix].\ append(int(node.name.split("NextIteration_")[-1])) - op = self._backtrack_construct(node.input[0]) + op = self._licm_construct(plname, node.input[0]) self._loops[node_name_prefix].body.append(op) else: raise Exception("Cannot identify control flow operator: " + @@ -3353,6 +3384,55 @@ def _convert_operator(self, op_name, inputs, attrs, raise NotImplementedError("Operator {} not implemented.".format(op_name)) return sym + def _licm_construct(self, loop_name, node_name): + """Construct a node by considering whether it is + loop invariant with the given while loop. If yes, we + generate a loop Variable. Otherwise, return regular + converted relay expression. + + Parameters + ---------- + loop_name : str + TensorFlow while loop name to be checked. + + node_name : str + TensorFlow node name. + + Returns + ------- + out : relay.Expr or relay.Var + Converted relay expression or loop var. 
+ """ + actual_expr = self._backtrack_construct(node_name) + tn = node_name.split(':') + node_name = tn[0].split("^")[-1] + cloop_name = find_parent_loop_name(node_name, self._while_loop_name_set) + + if loop_name in self._while_loop_name_set and not cloop_name.startswith(loop_name): + if loop_name not in self._lvar2expr: + self._lvar2expr[loop_name] = {} + if loop_name not in self._lname_map: + self._lname_map[loop_name] = {} + + if node_name not in self._lname_map[loop_name]: + var_name = "{}_loop_var".format(node_name) + var_type = _infer_type(actual_expr, self._mod).checked_type + loop_var = tvm.relay.var(var_name, type_annotation=var_type) + try: + extra_param = _infer_value(actual_expr, self._params, self._mod) + self._params[var_name] = extra_param + except Exception: + pass + self._lvar2expr[loop_name][loop_var] = actual_expr + self._lname_map[loop_name][node_name] = loop_var + ret = loop_var + else: + ret = self._lname_map[loop_name][node_name] + else: + ret = actual_expr + + return ret + def _backtrack_construct(self, node_name): """Convert a specific tensorflow node to relay expression. @@ -3365,13 +3445,19 @@ def _backtrack_construct(self, node_name): Parameters ---------- node_name : str - Tensorflow node name. + TensorFlow node name. Returns ------- op : relay.Expr Converted relay expression """ + try: + from tensorflow.python.framework import tensor_util + except ImportError as e: + raise ImportError( + "Unable to import tensorflow which is required {}".format(e)) + input_op_name = node_name.split(':')[0].split("^")[-1] if input_op_name not in self._nodes: @@ -3387,7 +3473,47 @@ def _backtrack_construct(self, node_name): attr["_output_shapes"] = self._output_shapes[input_op_name] attr["_node_name"] = node.name attr["_target_layout"] = self._layout + inputs = [self._backtrack_construct(iname) for iname in node.input] + + plname = find_parent_loop_name(node_name, self._while_loop_name_set) + + # For TensorArrayV3 op, we need to infer shape first + if is_tensor_array_constuctor(node): + raw_elem_shape = tensor_util.TensorShapeProtoToList(attr['element_shape']) + elem_shape = [] + for dim in raw_elem_shape: + if dim < 0: + elem_shape.append(Any()) + else: + elem_shape.append(dim) + + if elem_shape: + attr["shape"] = elem_shape + if attr['identical_element_shapes'] or elem_shape: + shape_node, wnode_op = self._tensor_array_shape_nodes[node.name] + converted = self._backtrack_construct(shape_node.name) + shape = _infer_shape(converted, self._mod) + if wnode_op.startswith("TensorArraySplit"): + shape = (Any(),) + shape[1:] + elif wnode_op.startswith("TensorArrayScatter"): + shape = shape[1:] + + if node.name in self._tensor_array_shapes: + preset_shape = self._tensor_array_shapes[node.name] + shape = _get_more_static_shape(shape, preset_shape) + + if "shape" in attr: + attr["shape"] = _get_more_static_shape(shape, attr["shape"]) + else: + attr["shape"] = shape + + # LICM + if plname in self._while_loop_name_set: + for i, iname in enumerate(node.input): + actual_input = self._licm_construct(plname, iname) + inputs[i] = actual_input + op = self._convert_operator(node.op, inputs, attr, self._graph) if isinstance(op, np.ndarray): @@ -3399,8 +3525,6 @@ def _backtrack_construct(self, node_name): elif isinstance(op, (_expr.Expr, _expr.TupleGetItem)): op = [op] - node_hash = s_hash(op) if isinstance(op, _expr.Tuple) else s_hash(op[0]) - self._hash2tfnode[node_hash] = node self._nodes[input_op_name] = op out = self._nodes[input_op_name] diff --git a/python/tvm/relay/prelude.py 
b/python/tvm/relay/prelude.py index 243eace0fb94..5b2ecc27b998 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -555,21 +555,21 @@ def define_tensor_array_gather(self): self.prelude.mod[gather_var] = \ Function([tensor_array, indices], body, output_tensor_type_var(), []) - def define_tensor_get_data(self, data_shape): + def define_tensor_get_data(self): """Defines a function to get a Tensor from tensor_t with given shape. """ tensor_get_data_name = self.get_name("tensor_get_data") tensor_get_data_var = self._create_global_var(tensor_get_data_name) setattr(self.prelude, tensor_get_data_name, tensor_get_data_var) - - tensor_type_var, tensor_constructor = self._get_adt_by_shape(data_shape) + tensor_type_var = self.get_var('tensor_t') + tensor_constructor = self.get_var('tensor_constructor') t = Var('tensor', tensor_type_var()) tvar = Var('t') case =\ Clause(PatternConstructor(tensor_constructor, [PatternVar(tvar)]), tvar) self.prelude.mod[tensor_get_data_var] = \ Function([t], Match(t, [case], False), - TensorType(data_shape, self.dtype), []) + TensorType(self.shape, self.dtype), []) def register(self): """Register all tensor array ops in Prelude""" @@ -586,6 +586,7 @@ def register(self): self.define_tensor_array_concat() self.define_tensor_array_stack() self.define_tensor_array_gather() + self.define_tensor_get_data() def _get_adt_by_shape(self, shape): """Get ADT type and constructor with given shape.""" diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 9d87610b8dbf..222a38d8814e 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2088,13 +2088,19 @@ bool SplitRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK_GE(axis, 0) << "axis should be within the input dimension range."; if (const IntImmNode* sections = param->indices_or_sections.as()) { - CHECK(reporter->Assert(indexmod(data->shape[axis], sections->value) == - tir::make_zero(DataType::Int(64)))) - << "indices_or_sections need to be able to divide input.shape[axis]"; + if (!data->shape[axis].as()) { + CHECK(reporter->Assert(indexmod(data->shape[axis], sections->value) == + tir::make_zero(DataType::Int(64)))) + << "indices_or_sections need to be able to divide input.shape[axis]"; + } std::vector fields; for (int i = 0; i < sections->value; ++i) { std::vector oshape(data->shape.begin(), data->shape.end()); - oshape[axis] = indexdiv(oshape[axis], sections->value); + if (data->shape[axis].as()) { + oshape[axis] = Any(); + } else { + oshape[axis] = indexdiv(oshape[axis], sections->value); + } auto vec_type = TensorType(oshape, data->dtype); fields.push_back(vec_type); } @@ -2112,10 +2118,16 @@ bool SplitRel(const Array& types, int num_inputs, const Attrs& attrs, auto vec_type = TensorType(oshape, data->dtype); fields.push_back(vec_type); } - CHECK(reporter->Assert(begin < data->shape[axis])) - << "The sum of sections must match the input.shape[axis]"; + if (!data->shape[axis].as()) { + CHECK(reporter->Assert(begin < data->shape[axis])) + << "The sum of sections must match the input.shape[axis]"; + } std::vector oshape(data->shape.begin(), data->shape.end()); - oshape[axis] = data->shape[axis] - begin; + if (data->shape[axis].as()) { + oshape[axis] = Any(); + } else { + oshape[axis] = data->shape[axis] - begin; + } auto vec_type = TensorType(oshape, data->dtype); fields.push_back(vec_type); reporter->Assign(types[1], TupleType(Array(fields))); diff --git a/tests/python/frontend/tensorflow/test_control_flow.py 
b/tests/python/frontend/tensorflow/test_control_flow.py index 90035279bf63..3ec04bf38490 100644 --- a/tests/python/frontend/tensorflow/test_control_flow.py +++ b/tests/python/frontend/tensorflow/test_control_flow.py @@ -46,7 +46,7 @@ def check_equal(graph, tf_out, input_map=None): def test_vanilla_loop(): graph = tf.Graph() with graph.as_default(): - i = tf.constant(0) + i = tf.constant(0, name="while/constant") def c(i): return tf.less(i, 10) @@ -368,7 +368,6 @@ def condition(x, y): check_equal(graph, tf_out, {dname: np_data}) - def test_switch(): graph = tf.Graph() @@ -385,6 +384,28 @@ def test_switch(): check_equal(graph, tf_out, {dname: data_np, flag_name: False}) +def test_loop_tuple_input(): + graph = tf.Graph() + + with graph.as_default(): + data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype('float32') + dname = 'data' + data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name=dname) + split = tf.split(data, 2, axis=0) + + def body(x, y): + return x + 2, y + 1 + + start = tf.constant(0) + def condition(x, y): + return tf.less(y, 20) + + r = tf.while_loop(condition, body, loop_vars=[split[1], start]) + with tf.Session() as sess: + tf_out = sess.run(r, feed_dict={data.name: data_np}) + + check_equal(graph, tf_out, {dname: data_np}) + if __name__ == "__main__": # tf.while_loop @@ -410,3 +431,4 @@ def test_switch(): test_nested_loop_bound() test_switch() + test_loop_tuple_input() diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 07c1cd343bcd..e78f9d0cb6e7 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -21,6 +21,7 @@ This article is a test script to test tensorflow operator with Relay. """ from __future__ import print_function +import threading import numpy as np import pytest try: @@ -45,6 +46,7 @@ from tvm import te from tvm import relay import tvm.relay.testing.tf as tf_testing +from tvm.runtime.vm import VirtualMachine from packaging import version as package_version ####################################################################### @@ -98,11 +100,10 @@ def vmobj_to_list(o): def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm', out_names=None, opt_level=3, mode='graph_runtime', - cuda_layout="NCHW"): + cuda_layout="NCHW", layout=None, disabled_pass=None): """ Generic function to compile on relay and execute on tvm """ input_data = convert_to_list(input_data) input_node = convert_to_list(input_node) - layout = None if target == "cuda": layout = cuda_layout target_host = None @@ -111,7 +112,8 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, layout=layout, shape=shape_dict, outputs=out_names) - if mode in ['debug', 'vm']: + ctx = tvm.context(target, 0) + if mode == 'debug': ex = relay.create_executor(mode, mod=mod, ctx=tvm.cpu(), target="llvm") inputs = [] for param in mod['main'].params: @@ -126,11 +128,19 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, inputs.append(tvm.nd.array(params[param.name_hint])) result = ex.evaluate()(*inputs) return vmobj_to_list(result) + elif mode == 'vm': + with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass): + vm_exec = relay.vm.compile(mod, target="llvm", params=params) + vm = VirtualMachine(vm_exec) + vm.init(tvm.cpu()) + inputs = {} + for e, i in zip(input_node, input_data): + inputs[e] = i + result = vm.invoke("main", **inputs) + return vmobj_to_list(result) else: - with 
tvm.transform.PassContext(opt_level=opt_level): + with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass): graph, lib, params = relay.build(mod, target, target_host, params) - - ctx = tvm.context(target, 0) from tvm.contrib import graph_runtime m = graph_runtime.create(graph, lib, ctx) # set inputs @@ -888,10 +898,15 @@ def test_tensor_array_scatter(): def run(dtype_str, infer_shape): with tf.Graph().as_default(): dtype = tf_dtypes[dtype_str] + if infer_shape: + element_shape = tf.TensorShape([tf.Dimension(None)]) + else: + element_shape = None t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype) indices = tf.constant([2, 1, 0]) ta1 = tf.TensorArray(dtype=dtype, size=3, - infer_shape=infer_shape) + infer_shape=infer_shape, + element_shape=element_shape) ta2 = ta1.scatter(indices, t) out0 = ta2.read(0) out1 = ta2.read(1) @@ -967,8 +982,14 @@ def test_tensor_array_size(): def run(dtype_str, infer_shape): with tf.Graph().as_default(): dtype = tf_dtypes[dtype_str] + np_data = np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str) + in_data = [np_data, np_data] + t1 = tf.constant(np_data, dtype=dtype) + t2 = tf.constant(np_data, dtype=dtype) ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=infer_shape) - out = ta1.size() + ta2 = ta1.write(0, t1) + ta3 = ta2.write(1, t2) + out = ta3.size() g = tf.get_default_graph() compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug') for dtype in ["float32", "int8"]: @@ -2267,6 +2288,48 @@ def test_forward_resnetv2(): tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) +####################################################################### +# SSD +# --- + + +def _test_ssd_impl(): + '''Test SSD with backbone MobileNet V1''' + with tf.Graph().as_default(): + graph_def = tf_testing.get_workload( + "object_detection/ssd_mobilenet_v1_ppn_shared_" + "box_predictor_300x300_coco14_sync_2018_07_03.pb") + # Call the utility to import the graph definition into default graph. + graph_def = tf_testing.ProcessGraphDefParam(graph_def) + + data = np.random.uniform(0.0, 255.0, size=(1, 512, 512, 3)).astype('uint8') + in_node = "image_tensor" + out_node = ['detection_boxes', "detection_scores", "detection_classes"] + + with tf.Session() as sess: + tf_output = run_tf_graph( + sess, data, '{}:0'.format(in_node), ["{}:0".format(oname) for oname in out_node]) + # TODO(kevinthesun): enable gpu test when VM heterogeneous execution is ready. 
+ for device in ["llvm"]: + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + continue + tvm_output = run_tvm_graph(graph_def, data, in_node, len(out_node), + target=device, layout="NCHW", out_names=out_node, + mode="vm", disabled_pass=["FoldScaleAxis"]) + for i in range(len(out_node)): + tvm.testing.assert_allclose(tvm_output[i], tf_output[i], + rtol=1e-3, atol=1e-3) + +def test_forward_ssd(): + run_thread = threading.Thread(target=_test_ssd_impl, args=()) + old_stack_size = threading.stack_size(100 * 1024 * 1024) + run_thread.start() + run_thread.join() + threading.stack_size(old_stack_size) + + ####################################################################### # Placeholder # ----------- @@ -3559,7 +3622,6 @@ def test_forward_spop(): # Main # ---- if __name__ == '__main__': - # Transforms test_forward_slice() test_forward_transpose() @@ -3664,6 +3726,7 @@ def test_forward_spop(): test_forward_inception_v1() test_forward_mobilenet() test_forward_resnetv2() + test_forward_ssd() test_forward_placeholder() test_forward_ptb() diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index c9b13d26894f..ff76e1c64bcb 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -1336,7 +1336,6 @@ def run(dtype, shape): p = Prelude(mod) static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) static_tensor_array_ops.register() - static_tensor_array_ops.define_tensor_get_data(shape) np_data_list = [] ta_length = 3
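The SplitRel relaxation in src/relay/op/tensor/transform.cc above can be exercised directly from Python. A minimal sketch (shapes illustrative): splitting along a symbolic axis now type-checks, with each output piece keeping an Any extent rather than tripping the divisibility check.

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(relay.Any(), 4), dtype="float32")
    parts = relay.split(x, indices_or_sections=2, axis=0)   # axis 0 is symbolic
    mod = tvm.IRModule.from_expr(relay.Function([x], parts.astuple()))
    # With this patch the divisibility check skips the symbolic extent and
    # each output extent on axis 0 becomes Any.
    mod = relay.transform.InferType()(mod)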