From c6461fa902104af52277da0d25b31f9621726e2f Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Fri, 10 Apr 2020 18:43:23 -0700 Subject: [PATCH] [Frontend][TensorFlow]Improve TensorFlow Static Shape Tensor Array (#5243) * Support TF Frontend Static TensorArray * Fix pylint * Fix lint * Move get_tensor_array_shape into prelude * Fix lint * Fix common --- python/tvm/relay/frontend/common.py | 41 ++- python/tvm/relay/frontend/tensorflow.py | 311 ++++++++++++++++-- python/tvm/relay/prelude.py | 49 ++- .../frontend/tensorflow/test_forward.py | 144 +++++--- topi/python/topi/util.py | 4 +- 5 files changed, 450 insertions(+), 99 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 5465e50bba435..e86890f3639a5 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -456,22 +456,20 @@ def get_name(node): def infer_type(node, mod=None): """A method to infer the type of an intermediate node in the relay graph.""" - new_mod = IRModule.from_expr(node) - if mod is not None: - new_mod.update(mod) - new_mod = _transform.InferType()(new_mod) - entry = new_mod["main"] - return entry if isinstance(node, _function.Function) else entry.body + if isinstance(mod, IRModule): + mod["main"] = _function.Function([], node) + mod = _transform.InferType()(mod) + entry = mod["main"] + ret = entry.body + else: + new_mod = IRModule.from_expr(node) + if mod is not None: + new_mod.update(mod) + new_mod = _transform.InferType()(new_mod) + entry = new_mod["main"] + ret = entry if isinstance(node, _function.Function) else entry.body -def infer_shape(inputs, mod=None): - """A method to get the output type of an intermediate node in the graph.""" - out_type = infer_type(inputs, mod=mod) - checked_type = out_type.checked_type - if hasattr(checked_type, 'shape'): - # Regular operator that outputs tensors - return get_const_tuple(out_type.checked_type.shape) - # The return type is not a tensor, for example List - return checked_type + return ret def infer_channels(inputs, transpose=False): """A hack for getting 'channels' or 'units' since caffe2 does not provide @@ -483,6 +481,17 @@ def infer_channels(inputs, transpose=False): return channels +def infer_shape(inputs, mod=None): + """A method to get the output type of an intermediate node in the graph.""" + out_type = infer_type(inputs, mod=mod) + checked_type = out_type.checked_type + if hasattr(checked_type, 'shape'): + # Regular operator that outputs tensors + return get_const_tuple(checked_type.shape) + # The return type is not a tensor, for example List + return checked_type + + def infer_value(input_val, params, mod=None): """A hack for getting the value of an expression by evaluating a portion of the relay graph. 
This is often needed for functions that @@ -505,7 +514,7 @@ def infer_value(input_val, params, mod=None): return m.get_output(0) except Exception: if isinstance(mod, IRModule): - mod["main"] = _expr.Function(analysis.free_vars(input_val), input_val) + mod["main"] = _function.Function(analysis.free_vars(input_val), input_val) else: mod = IRModule.from_expr(input_val) exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm") diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 8a72423815c66..120631ea31dcc 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -26,13 +26,14 @@ import tvm from tvm.ir import IRModule -from tvm.relay.prelude import Prelude +from tvm.relay.prelude import Prelude, StaticTensorArrayOps, get_tensor_array_shape from tvm.ir import structural_hash as s_hash from .. import analysis from .. import expr as _expr from .. import function as _function from .. import op as _op +from ..ty import Any from ..expr_functor import ExprMutator, ExprVisitor from .common import AttrCvt, get_relay_op from .common import infer_type as _infer_type @@ -259,8 +260,6 @@ def _impl(inputs, attr, params, mod): if opname == 'conv_transpose' and attr['data_format'] == 'NHWC': # transform to NCHW for TVM backend compatible and set 'flip_layout' # to have output flip back to NHWC - tmp_shape = _infer_shape(inputs[2], mod) - tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] inputs[2] = _op.transpose(inputs[2], axes=(0, 3, 1, 2)) attr['strides'][1], attr['strides'][2], attr['strides'][3] = \ attr['strides'][3], attr['strides'][1], attr['strides'][2] @@ -789,25 +788,152 @@ def _impl(inputs, attr, params, mod): def _tensor_array(): def _impl(inputs, attr, params, prelude): + try: + from tensorflow.python.framework import tensor_util + except ImportError as e: + raise ImportError( + "Unable to import tensorflow which is required {}".format(e)) + dtype_str = attr.get('dtype').name - tensor_array_constructor = prelude.get_var('tensor_array', dtype_str) - return tensor_array_constructor(_op.take(inputs[0], tvm.relay.const(0))) + assert not attr["dynamic_size"], "Dynamic size tensor array is " \ + "not supported in TVM yet." + + raw_elem_shape = tensor_util.TensorShapeProtoToList(attr['element_shape']) + elem_shape = [] + for dim in raw_elem_shape: + if dim < 0: + elem_shape.append(Any()) + else: + elem_shape.append(dim) + + if elem_shape: + # Element shape is specified. + # Directly create static tensor array with given shape. + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + elem_shape) + static_tensor_array_ops.register() + tensor_array_constructor = prelude.get_var_static('tensor_array', + dtype_str, + elem_shape) + tensor_array = tensor_array_constructor(inputs[0]) + _static_tensor_array_map[tensor_array] = tensor_array + elif attr['identical_element_shapes']: + # identical_element_shapes is set but element shape is not given. + # We create a static tensor array with dummy shape and record it in + # _static_tensor_array_map. Later when creating other tensor array ops + # which uses this tensor array, we reconstruct this tensor array with + # actual shape. 
+ dummy_shape = () + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + dummy_shape) + static_tensor_array_ops.register() + tensor_array_constructor = prelude.get_var_static('tensor_array', + dtype_str, + dummy_shape) + tensor_array = tensor_array_constructor(inputs[0]) + _static_tensor_array_map[tensor_array] = None + else: + tensor_array_constructor = prelude.get_var('tensor_array', dtype_str) + tensor_array = tensor_array_constructor(inputs[0]) + return tensor_array return _impl def _tensor_array_scatter(): def _impl(inputs, attr, params, prelude): dtype_str = attr.get('T').name - values_rank = len(inputs[2].type_annotation.shape) - unstack_name = "tensor_array_unstack_tensor{}".format(values_rank) - unstack_function = prelude.get_var(unstack_name, dtype_str) - values = unstack_function(inputs[2]) - tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str) - return tensor_array_scatter_func(inputs[0], inputs[1], values) + input_ta = inputs[0] + input_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) + values_shape = _infer_shape(inputs[2], prelude.mod) + input_t_shape = values_shape[1:] + indices_shape = _infer_shape(inputs[1], prelude.mod) + + if input_shape is None: + values_rank = len(values_shape) + unstack_name = "tensor_array_unstack_tensor{}".format(values_rank) + unstack_function = prelude.get_var(unstack_name, dtype_str) + values = unstack_function(inputs[2]) + tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str) + else: + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + input_t_shape) + static_tensor_array_ops.register() + # For scatter operation, it is possible to write to a newly create + # tensor array. We need to check and recreate its input tensor array. 
+ if input_ta in _static_tensor_array_map and \ + _static_tensor_array_map[input_ta] is None: + ta_constructor = prelude.get_var_static('tensor_array', + dtype_str, + input_t_shape) + new_ta = ta_constructor(input_ta.args[0]) + _static_tensor_array_map[input_ta] = new_ta + input_ta = new_ta + + # Register static indices shape + if isinstance(indices_shape[0], int): + static_tensor_array_ops.define_tensor_array_scatter(indices_shape, True) + tensor_array_scatter_func = prelude.get_var_static('tensor_array_scatter', + dtype_str, + input_t_shape) + + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + values_shape) + static_tensor_array_ops.register() + unstack_function = prelude.get_var_static('tensor_array_unstack', + dtype_str, + values_shape) + values = unstack_function(inputs[2]) + ret = tensor_array_scatter_func(input_ta, inputs[1], values) + return ret return _impl def _tensor_array_gather(): def _impl(inputs, attr, params, prelude): - return prelude.tensor_array_gather(inputs[2], inputs[1]) + dtype_str = attr.get('dtype').name + input_shape = get_tensor_array_shape(inputs[2], dtype_str, prelude) + indices_shape = _infer_shape(inputs[1], prelude.mod) + + if input_shape is None: + gather_func = prelude.get_var('tensor_array_gather', dtype_str) + out = gather_func(inputs[2], inputs[1]) + else: + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + input_shape) + static_tensor_array_ops.register() + if not isinstance(indices_shape[0], int): + gather_function = prelude.get_var_static('tensor_array_gather', + dtype_str, + input_shape) + out_tensor_t = gather_function(inputs[2], inputs[1]) + + # Output shape is (indices_shape[0],) + input_shape + static_tensor_array_ops.define_tensor_get_data((indices_shape[0],) + input_shape) + get_data_func = prelude.get_var_static('tensor_get_data', + dtype_str, + input_shape) + out = get_data_func(out_tensor_t) + else: + # For fixed length indices, directly generate static shape output + read_func = prelude.get_var_static('tensor_array_read', + dtype_str, + input_shape) + static_tensor_array_ops.define_tensor_get_data(input_shape) + get_data_func = prelude.get_var_static('tensor_get_data', + dtype_str, + input_shape) + tensor_list = [] + for i in range(indices_shape[0]): + index = _op.take(inputs[1], tvm.relay.const(i)) + out_tensor = get_data_func(read_func(inputs[2], index)) + tensor_list.append(_op.expand_dims(out_tensor, axis=0)) + + out = _op.concatenate(tensor_list, axis=0) + + return out return _impl def _tensor_array_size(): @@ -817,37 +943,163 @@ def _impl(inputs, attr, params, prelude): def _tensor_array_write(): def _impl(inputs, attr, params, prelude): - input_rank = len(inputs[2].type_annotation.shape) - dtype = attr.get('T').name + dtype_str = attr.get('T').name + input_ta = inputs[3] + input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) + input_t_shape = _infer_shape(inputs[2], prelude.mod) + input_rank = len(input_t_shape) + + if input_ta_shape is None: + tensor_name = 'tensor{}'.format(input_rank) + tensor_func = prelude.get_var(tensor_name, dtype_str) + v = tensor_func(inputs[2]) + write_func = prelude.get_var('tensor_array_write', dtype_str) + else: + # For write operation, it is possible to write to a newly create + # tensor array. We need to check and recreate its input tensor array. 
+ if input_ta in _static_tensor_array_map and \ + _static_tensor_array_map[input_ta] is None: + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + input_t_shape) + static_tensor_array_ops.register() + ta_constructor = prelude.get_var_static('tensor_array', + dtype_str, + input_t_shape) + new_ta = ta_constructor(input_ta.args[0]) + _static_tensor_array_map[input_ta] = new_ta + input_ta = new_ta + input_ta_shape = input_t_shape + else: + input_ta_rank = len(input_ta_shape) + assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \ + format(input_ta_rank, input_rank) + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + input_ta_shape) + static_tensor_array_ops.register() - tensor_name = 'tensor{}'.format(input_rank) - tensor_func = prelude.get_var(tensor_name, dtype) - v = tensor_func(inputs[2]) - write_func = prelude.get_var('tensor_array_write', dtype) + tensor_func = prelude.get_var_static("tensor_constructor", + dtype_str, + input_ta_shape) + v = tensor_func(inputs[2]) + write_func = prelude.get_var_static('tensor_array_write', + dtype_str, + input_ta_shape) - return write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v) + return write_func(input_ta, _op.take(inputs[1], tvm.relay.const(0)), v) return _impl def _tensor_array_read(): def _impl(inputs, attr, params, prelude): - read_func = prelude.get_var('tensor_array_read', attr.get('dtype').name) - return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) + dtype_str = attr['dtype'].name + input_shape = get_tensor_array_shape(inputs[2], dtype_str, prelude) + + if input_shape is None: + read_func = prelude.get_var('tensor_array_read', dtype_str) + out = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) + else: + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + input_shape) + static_tensor_array_ops.register() + static_tensor_array_ops.define_tensor_get_data(input_shape) + read_func = prelude.get_var_static("tensor_array_read", dtype_str, input_shape) + out_tensor = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) + get_data_func = prelude.get_var_static('tensor_get_data', + dtype_str, + input_shape) + out = get_data_func(out_tensor) + + return out return _impl def _tensor_array_split(): def _impl(inputs, attr, params, prelude): - input_rank = len(inputs[1].type_annotation.shape) dtype_str = attr.get('T').name - v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1]) + input_ta = inputs[0] + input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) + input_t_shape = _infer_shape(inputs[1], prelude.mod) + input_rank = len(input_t_shape) lengths = _op.cast(inputs[2], 'int32') - split_var = prelude.get_var('tensor_array_split', dtype_str) - return split_var(inputs[0], v, lengths) + lengths_shape = _infer_shape(lengths, prelude.mod) + value_shape = _infer_shape(inputs[1], prelude.mod) + + if input_ta_shape is None: + v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1]) + split_func = prelude.get_var('tensor_array_split', dtype_str) + else: + # For split operation, it is possible to write to a newly create + # tensor array. We need to check and recreate its input tensor array. 
+ if input_ta in _static_tensor_array_map and \ + _static_tensor_array_map[input_ta] is None: + input_ta_shape = (Any(),) + input_t_shape[1:] + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + input_ta_shape) + static_tensor_array_ops.register() + ta_constructor = prelude.get_var_static('tensor_array', + dtype_str, + input_ta_shape) + new_ta = ta_constructor(input_ta.args[0]) + _static_tensor_array_map[input_ta] = new_ta + input_ta = new_ta + else: + input_ta_rank = len(input_ta_shape) + assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \ + format(input_ta_rank, input_rank) + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + input_ta_shape) + static_tensor_array_ops.register() + + # Check static value/indices shape + if isinstance(value_shape[0], int) or isinstance(lengths_shape[0], int): + static_tensor_array_ops.define_tensor_array_split(value_shape, + lengths_shape, + True) + + tensor_func_name = prelude.get_name_static("tensor_constructor", + dtype_str, + value_shape) + if not hasattr(prelude, tensor_func_name): + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + value_shape) + static_tensor_array_ops.register() + tensor_func = prelude.get_var_static("tensor_constructor", + dtype_str, + value_shape) + v = tensor_func(inputs[1]) + split_func = prelude.get_var_static('tensor_array_split', + dtype_str, + input_ta_shape) + + return split_func(input_ta, v, lengths) return _impl def _tensor_array_concat(): def _impl(inputs, attr, params, prelude): - concat_func = prelude.get_var('tensor_array_concat', attr['dtype'].name) - return concat_func(inputs[1]) + dtype_str = attr['dtype'].name + input_shape = get_tensor_array_shape(inputs[1], dtype_str, prelude) + + if input_shape is None: + concat_func = prelude.get_var('tensor_array_concat', dtype_str) + out = concat_func(inputs[1]) + else: + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + input_shape) + static_tensor_array_ops.register() + concat_func = prelude.get_var_static("tensor_array_concat", dtype_str, input_shape) + out_tensor = concat_func(inputs[1]) + static_tensor_array_ops.define_tensor_get_data((Any(),) + input_shape[1:]) + get_data_func = prelude.get_var_static('tensor_get_data', + dtype_str, + input_shape) + out = get_data_func(out_tensor) + + return out return _impl def _tile(): @@ -1370,7 +1622,7 @@ def _impl(inputs, attr, params, mod): return AttrCvt( op_name="arange", - ignores=['Tidx'], + ignores=['Tidx', '_class'], extras={'start': start, 'stop': limit, 'step': delta, @@ -2084,6 +2336,9 @@ def _get_abs_layer_name(node): # 1.x. _control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond'] +# A map to record tensor array with fixed rank shape +_static_tensor_array_map = {} + class RewriteSubgraph(ExprMutator): """ A helper class to rewrite expr in while loop function to variable diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 47c3ba7b43b07..243eace0fb946 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -16,7 +16,7 @@ # under the License. 
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """A prelude containing useful global functions and ADT definitions.""" -from tvm.ir import IRModule +from tvm.ir import IRModule, TypeCall from .ty import GlobalTypeVar, TensorType, Any, scalar_type from .expr import Var, GlobalVar, If, const @@ -24,8 +24,51 @@ from .op.tensor import add, subtract, equal from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard -from . import op - +from . import op, transform + + +def get_tensor_array_shape(expr, dtype, prelude): + """Get the static shape of a tensor array if it has fixed rank shape. + + By design, static ADT tensor in TVM has type name in the format + of static_tensor_dim0_dim1_..._dimN_t. + + Parameters + ---------- + expr : Relay Expr + Input expression. + + dtype : str + Data type. + + prelude : Prelude + Tensor array prelude + + Returns + ------- + shape : tuple of (int, Any) or None + The output shape. None if input tensor array + has dynamic shape. + """ + mod = prelude.mod + mod["main"] = Function([], expr) + mod = transform.InferType()(mod) + checked_type = mod["main"].body.checked_type + assert isinstance(checked_type, TypeCall), "Input must be a tensor array." + ta_type_str = checked_type.args[0].func.name_hint + static_ta_ty_start = "static_tensor_{}".format(dtype) + if ta_type_str.startswith(static_ta_ty_start): + shape_str = ta_type_str.replace("{}_".format(static_ta_ty_start), '') \ + .replace("_t", '') + shape = [] + if "scalar" not in shape_str: + for dim_str in shape_str.split("_"): + if dim_str == "?": + shape.append(Any()) + else: + shape.append(int(dim_str)) + return tuple(shape) + return None def _get_name_static(canonical, dtype, shape): """Get name for static shape tensor array op corresponding diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index fdb8912b641b2..bc884bbbfa9bc 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -839,63 +839,75 @@ def test_forward_squeeze(): _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1]) -def test_tensor_array_constructor(): - def run(dtype_str): +####################################################################### +# TensorArray +# ----------- +def test_tensor_array_write_read(): + def run(dtype_str, infer_shape, element_shape): with tf.Graph().as_default(): dtype = tf_dtypes[dtype_str] - t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype) - t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype) - ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False) - ta2 = ta1.write(0, t) + np_data = np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str) + in_data = [np_data, np_data] + t1 = tf.constant(np_data, dtype=dtype) + t2 = tf.constant(np_data, dtype=dtype) + ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=infer_shape, + element_shape=element_shape) + ta2 = ta1.write(0, t1) ta3 = ta2.write(1, t2) out = ta3.read(0) g = tf.get_default_graph() - compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug') - for dtype in tf_dtypes.keys(): - run(dtype) + compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='vm') + + for dtype in ["float32", "int8"]: + run(dtype, False, None) + run(dtype, False, tf.TensorShape([None, 2])) + run(dtype, True, None) def test_tensor_array_scatter(): - def run(dtype_str): + def 
run(dtype_str, infer_shape): with tf.Graph().as_default(): dtype = tf_dtypes[dtype_str] t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype) indices = tf.constant([2, 1, 0]) ta1 = tf.TensorArray(dtype=dtype, size=3, - infer_shape=False, dynamic_size=False) + infer_shape=infer_shape) ta2 = ta1.scatter(indices, t) out0 = ta2.read(0) out1 = ta2.read(1) out2 = ta2.read(2) g = tf.get_default_graph() - compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug') - compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug') - compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug') - for dtype in tf_dtypes.keys(): - run(dtype) - -# TODO(wweic): Fix gather issue with PartialEvaluate -# def test_tensor_array_gather(): -# with tf.Graph().as_default(): -# dtype = 'float32' -# t = tf.constant([[1.0], [2.0], [3.0]]) -# scatter_indices = tf.constant([2, 1, 0]) -# gather_indices = tf.constant([1, 2]) -# ta1 = tf.TensorArray(dtype=tf.float32, size=3, infer_shape=False, dynamic_size=False) -# ta2 = ta1.scatter(scatter_indices, t) -# t1 = ta2.gather(gather_indices) -# g = tf.get_default_graph() -# compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug') + compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='vm') + compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='vm') + compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='vm') + for dtype in ["float32", "int8"]: + run(dtype, False) + run(dtype, True) + + +def test_tensor_array_gather(): + def run(dtype_str, infer_shape): + with tf.Graph().as_default(): + dtype = tf_dtypes[dtype_str] + t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str)) + scatter_indices = tf.constant([2, 1, 0]) + gather_indices = tf.constant([1, 2]) + ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=infer_shape) + ta2 = ta1.scatter(scatter_indices, t) + t1 = ta2.gather(gather_indices) + g = tf.get_default_graph() + compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='vm') + for dtype in ["float32", "int8"]: + run(dtype, True) def test_tensor_array_split(): - def run(dtype_str): + def run(dtype_str, infer_shape): with tf.Graph().as_default(): dtype = tf_dtypes[dtype_str] t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype) split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32) - ta1 = tf.TensorArray(dtype=dtype, size=4, - infer_shape=False, dynamic_size=False) + ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=infer_shape) ta2 = ta1.split(t, split_length) out0 = ta2.read(0) out1 = ta2.read(1) @@ -906,56 +918,76 @@ def run(dtype_str): compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug') compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug') compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug') - for dtype in tf_dtypes.keys(): - run(dtype) + for dtype in ["float32", "int8"]: + run(dtype, False) + run(dtype, True) def test_tensor_array_concat(): - def run(dtype_str): + def run(dtype_str, infer_shape): with tf.Graph().as_default(): dtype = tf_dtypes[dtype_str] t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype) split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32) ta1 = tf.TensorArray(dtype=dtype, size=4, - infer_shape=False, dynamic_size=False) + infer_shape=infer_shape) ta2 = ta1.split(t, split_length) t = ta2.concat() out = tf.identity(t) compare_tf_with_tvm([], [], ['Identity:0'], mode='debug') - 
for dtype in tf_dtypes.keys(): - run(dtype) + for dtype in ["float32", "int8"]: + run(dtype, False) + run(dtype, True) def test_tensor_array_size(): - def run(dtype_str): + def run(dtype_str, infer_shape): with tf.Graph().as_default(): dtype = tf_dtypes[dtype_str] - ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False) + ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=infer_shape) out = ta1.size() g = tf.get_default_graph() compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug') - for dtype in tf_dtypes.keys(): - run(dtype) + for dtype in ["float32", "int8"]: + run(dtype, False) + run(dtype, True) + + +def test_tensor_array_stack(): + def run(dtype_str, infer_shape): + with tf.Graph().as_default(): + dtype = tf_dtypes[dtype_str] + t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str)) + scatter_indices = tf.constant([2, 1, 0]) + ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=infer_shape) + ta2 = ta1.scatter(scatter_indices, t) + t1 = ta2.stack() + print(t1) + g = tf.get_default_graph() + + compare_tf_with_tvm([], [], ['TensorArrayStack/TensorArrayGatherV3:0'], mode='vm') + for dtype in ["float32", "int8"]: + run(dtype, True) + def test_tensor_array_unstack(): - def run(dtype_str, input_shape): + def run(dtype_str, input_shape, infer_shape): with tf.Graph().as_default(): dtype = tf_dtypes[dtype_str] t = tf.constant(np.random.choice([0, 1, 2, 3], size=input_shape).astype(dtype.name)) - ta1 = tf.TensorArray(dtype=dtype, infer_shape=False, size=input_shape[0]) + ta1 = tf.TensorArray(dtype=dtype, infer_shape=infer_shape, size=input_shape[0]) ta2 = ta1.unstack(t) out0 = ta2.size() out1 = ta2.read(0) compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug') compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug') - for dtype in tf_dtypes.keys(): - run(dtype, (5,)) - run(dtype, (5, 5)) - run(dtype, (5, 5, 5)) - run(dtype, (5, 5, 5, 5)) - run(dtype, (5, 5, 5, 5, 5)) - run(dtype, (5, 5, 5, 5, 5, 5)) + for dtype in ["float32", "int8"]: + run(dtype, (5,), False) + run(dtype, (5, 5), True) + run(dtype, (5, 5, 5), False) + run(dtype, (5, 5, 5, 5), True) + ####################################################################### # ConcatV2 @@ -3241,6 +3273,16 @@ def test_forward_isfinite(): test_forward_reduce() test_forward_mean() + # TensorArray + test_tensor_array_write_read() + test_tensor_array_concat() + test_tensor_array_scatter() + test_tensor_array_gather() + test_tensor_array_size() + test_tensor_array_split() + test_tensor_array_stack() + test_tensor_array_unstack() + # General test_forward_multi_input() test_forward_multi_output() diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index 681535761f83a..50a6a36edc46a 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -166,12 +166,14 @@ def get_const_tuple(in_tuple): """ ret = [] for elem in in_tuple: - if isinstance(elem, tvm.tir.Var): + if isinstance(elem, (tvm.tir.Var, tvm.tir.expr.Any)): ret.append(elem) elif not isinstance(elem, (tvm.tir.IntImm, int)): elem = tvm.tir.ir_pass.Simplify(elem) if not isinstance(elem, tvm.tir.IntImm): ret.append(elem) + else: + ret.append(get_const_int(elem)) else: ret.append(get_const_int(elem)) return tuple(ret)
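
For reference, a minimal sketch (not part of the patch) of how the static tensor array path added above can be exercised outside the test harness. It assumes TF 1.x and the relay.frontend.from_tensorflow API of this TVM version; with a fully specified element_shape the converter is expected to pick the static ADT (a type named along the lines of static_tensor_float32_2_2_t, per the naming scheme documented in prelude.py above), and the converted module can be run with the Relay VM, which is roughly what compare_tf_with_tvm(..., mode='vm') does in the tests.

    import numpy as np
    import tensorflow as tf
    import tvm
    from tvm import relay

    with tf.Graph().as_default():
        data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype="float32")
        t = tf.constant(data)
        # A fully known element_shape lets the frontend build a static tensor array
        # instead of the dynamic (Any-shaped) ADT from the generic prelude.
        ta = tf.TensorArray(dtype=tf.float32, size=2, infer_shape=True,
                            element_shape=tf.TensorShape([2, 2]))
        ta = ta.write(0, t)
        ta = ta.write(1, t)
        out = ta.read(0)  # expected node name: TensorArrayReadV3
        graph_def = tf.get_default_graph().as_graph_def()

    # Convert the TF graph; the tensor array calls in mod["main"] should reference
    # the static_tensor_* type rather than the dynamic tensor ADT.
    mod, params = relay.frontend.from_tensorflow(graph_def, outputs=["TensorArrayReadV3"])
    print(mod["main"])

    # Run with the Relay VM (same mode the new tests use) and check the result.
    ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
    result = ex.evaluate()(**params)
    np.testing.assert_allclose(result.asnumpy(), data)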