From 9115212b95524650fb2f5cdd6b82cbc1f8753478 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Sun, 5 Apr 2020 13:42:28 -0700 Subject: [PATCH] [Relay][ADT]Static Tensor Array (#5103) * Add other static tensor array ops * Add tensor array get data * Minor refactor * Fix pylint * Update docstring * Make get data more generic * Improve test * Improve split test * Improve get data * Minor fix * Further improvement for static shape * Improve shape parsing * Unify get_static_name --- python/tvm/relay/prelude.py | 548 +++++++++++++++++++++++++++++++++ tests/python/relay/test_adt.py | 405 +++++++++++++++++++++++- 2 files changed, 952 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 0e64a2fd1248..47c3ba7b43b0 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -27,6 +27,545 @@ from . import op +def _get_name_static(canonical, dtype, shape): + """Get name for static shape tensor array op corresponding + to the canonical name""" + shape_str = '_'.join([str(dim) for dim in shape]) + if len(shape_str) == 0: + shape_str = "scalar" + if canonical == 'tensor_t': + return 'static_tensor_{}_{}_t'.format(dtype, shape_str) + return "{}_{}_{}".format(canonical, dtype, shape_str) + +class StaticTensorArrayOps(object): + """Contains tensor array related ops for fixed rank tensor array""" + + def __init__(self, prelude, dtype, shape): + """Create tensor array ops registry""" + self.prelude = prelude + self.dtype = dtype + self.shape = shape + + def get_name(self, canonical): + """Get name corresponding to the canonical name""" + return _get_name_static(canonical, self.dtype, self.shape) + + def get_var(self, canonical): + """Get var corresponding to the canonical name""" + name = self.get_name(canonical) + return getattr(self.prelude, name) + + def define_tensor_adt(self): + """Defines the static tensor ADT, which is the container for tensors + with fixed shapes.""" + tensor_type_name = self.get_name('tensor_t') + # Skip register if tensor type is already registered. + global_type_names = set() + for g_ty_var in self.prelude.mod.get_global_type_vars(): + global_type_names.add(g_ty_var.name_hint) + if tensor_type_name in global_type_names: + return + + tensor_type_var = GlobalTypeVar(tensor_type_name) + setattr(self.prelude, tensor_type_name, tensor_type_var) + tensor_type = TensorType(self.shape, self.dtype) + tensor_constructor_name = self.get_name('tensor_constructor') + + tensor_nil_name = self.get_name('tensor_nil') + tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var) + tensor_case = Constructor(tensor_constructor_name, [tensor_type], tensor_type_var) + + setattr(self.prelude, tensor_nil_name, tensor_nil_case) + setattr(self.prelude, tensor_constructor_name, tensor_case) + self.prelude.mod[tensor_type_var] = TypeData(tensor_type_var, + [], + [tensor_nil_case, tensor_case]) + + def define_tensor_array(self): + """Defines a function to create a tensor array with size n. + tensor_array(n) : Tensor[(), int32] -> list[tensor_t] + """ + tensor_array_constructor_name = self.get_name("tensor_array") + tensor_array_constructor_var = self._create_global_var(tensor_array_constructor_name) + setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var) + tensor_nil_var = self.get_var('tensor_nil') + tensor_type_var = self.get_var('tensor_t') + n = Var("x", scalar_type('int32')) + body = If(equal(n, const(0)), + self.prelude.nil(), + self.prelude.cons(tensor_nil_var(), + tensor_array_constructor_var(subtract(n, const(1))))) + self.prelude.mod[tensor_array_constructor_var] = \ + Function([n], body, self.prelude.l(tensor_type_var()), []) + + def define_tensor_take(self): + """Defines a function to return a range of tensor_t on axis 0. + tensor_take(t, lower, upper) : + tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t + """ + # We don't register take for scalar tensor. + ndim = len(self.shape) + if ndim == 0: + return + + take_name = self.get_name("tensor_take") + take_var = self._create_global_var(take_name) + setattr(self.prelude, take_name, take_var) + origin_tensor_constructor = self.get_var('tensor_constructor') + + output_shape = [Any(),] + list(self.shape[1:]) + tensor_type_var, tensor_constructor = \ + self._get_adt_by_shape(output_shape) + + t = Var('tensor', self.get_var('tensor_t')()) + lower = Var('lower', scalar_type('int32')) + upper = Var('upper', scalar_type('int32')) + tvar = Var('t') + case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(tvar)]), + tensor_constructor(op.take(tvar, + op.arange(lower, upper, dtype='int32'), + axis=0))) + self.prelude.mod[take_var] = \ + Function([t, lower, upper], + Match(t, [case], False), tensor_type_var(), []) + + def define_tensor_concatenate(self): + """Defines a function to concatenate two tensor_t on axis 0. + tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t + """ + # We don't register concatenate for scalar tensor. + ndim = len(self.shape) + if ndim == 0: + return + + concat_name = self.get_name("tensor_concatenate") + concat_var = self._create_global_var(concat_name) + setattr(self.prelude, concat_name, concat_var) + output_shape = [Any(),] + list(self.shape[1:]) + tensor_type_var, tensor_constructor = \ + self._get_adt_by_shape(output_shape) + + origin_tensor_constructor = self.get_var('tensor_constructor') + origin_tensor_type_var = self.get_var('tensor_t') + x = Var("x", origin_tensor_type_var()) + y = Var("y", origin_tensor_type_var()) + t1 = Var("t1") + t2 = Var("t2") + + case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t1)]), + Match(y, + [Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t2)]), + tensor_constructor(op.concatenate([t1, t2], axis=0)))], + False)) + + self.prelude.mod[concat_var] = \ + Function([x, y], Match(x, [case], False), tensor_type_var(), []) + + + def define_tensor_expand_dims(self): + """Defines a function to grow a tensor_t's rank by adding one dimension in front + of the original tensor_t. + tensor_expand_dims(t) : tensor_t -> tensor_t + """ + expand_dims_name = self.get_name("tensor_expand_dims") + expand_dims_var = self._create_global_var(expand_dims_name) + setattr(self.prelude, expand_dims_name, expand_dims_var) + origin_tensor_type_var = self.get_var('tensor_t') + origin_tensor_constructor = self.get_var('tensor_constructor') + x = Var("x", origin_tensor_type_var()) + + # Note: we set the added axis to be Any() instead of 1 due to + # in stack op, we need to recursively concatenate. + tensor_type_var, tensor_constructor = \ + self._get_adt_by_shape([Any(),] + list(self.shape)) + t = Var("t") + case = Clause(PatternConstructor(origin_tensor_constructor, [PatternVar(t)]), + tensor_constructor(op.expand_dims(t, 0, 1))) + + self.prelude.mod[expand_dims_var] = \ + Function([x], Match(x, [case], False), tensor_type_var(), []) + + def define_tensor_array_read(self): + """Defines a function to get the nth element of a list. Assume the list has at least one + element. + tensor_array_read(ta, n) : list[static_tensor_t] -> Tensor[(), int32] -> + Tensor[self.shape, self.dtype] + """ + read_name = self.get_name("tensor_array_read") + read_var = self._create_global_var(read_name) + setattr(self.prelude, read_name, read_var) + tensor_type_var = self.get_var('tensor_t') + + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + n = Var("x", scalar_type('int32')) + self.prelude.mod[read_var] = \ + Function([tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), []) + + def define_tensor_array_write(self): + """Defines a function to update a tensor array at index n with value v. + tensor_array_write(ta, n, v) : + list[static_tensor_t] -> Tensor[(), int32] -> Tensor[self.shape, self.dtype] -> + list[static_tensor_t] + """ + write_name = self.get_name("tensor_array_write") + write_var = self._create_global_var(write_name) + setattr(self.prelude, write_name, write_var) + tensor_type_var = self.get_var('tensor_t') + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + n = Var("x", scalar_type('int32')) + v = Var("v", tensor_type_var()) + self.prelude.mod[write_var] = \ + Function([tensor_array, n, v], self.prelude.update(tensor_array, n, v), + self.prelude.l(tensor_type_var()), []) + + def define_tensor_array_unstack(self): + """Defines a function to unstack the values of a tensor_t in a tensor array. + tensor_array_unstack_tensor(t) : tensor_t -> list[tensor_t] + """ + ndim = len(self.shape) + # We don't register unstack for scalar tensor array + if ndim == 0: + return + + helper_name = self.get_name("tensor_array_unstack_helper") + helper_var = self._create_global_var(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor = Var("t", TensorType(self.shape, self.dtype)) + up = Var("up", scalar_type('int32')) + i = Var("i", scalar_type('int32')) + tensor_var = Var("tensor", TensorType(self.shape, self.dtype)) + + reduced_tensor_type_var, tensor_constructor = \ + self._get_adt_by_shape(self.shape[1:]) + helper_body = \ + If(equal(i, up), + self.prelude.nil(), + self.prelude.cons(tensor_constructor(op.take(tensor, i, axis=0)), + helper_var(add(i, const(1)), up, tensor))) + self.prelude.mod[helper_var] = \ + Function([i, up, tensor], helper_body, self.prelude.l(reduced_tensor_type_var()), []) + + unstack_name = self.get_name("tensor_array_unstack") + unstack_var = self._create_global_var(unstack_name) + setattr(self.prelude, unstack_name, unstack_var) + shape = op.shape_of(tensor_var) + unstack_length = op.take(shape, const(0)) + self.prelude.mod[unstack_var] = \ + Function([tensor_var], helper_var(const(0), unstack_length, tensor_var), + self.prelude.l(reduced_tensor_type_var()), []) + + def define_tensor_array_scatter(self, indices_shape=None, force_update=False): + """Defines a function to scatter the values of a tensor_t in indices of a tensor array. + tensor_array_scatter(ta, indices, value) : + list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t] + + Set static indices shape by specifying indices_shape. + Set force_update to get static indices shape operator. + """ + # When this operator has already been registered, only update + # when force_update is set. This should be used only when we need to + # redefine this op for static indices shape. + tensor_array_scatter_name = self.get_name("tensor_array_scatter") + if hasattr(self.prelude, tensor_array_scatter_name) and not force_update: + return + + tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper") + tensor_array_scatter_helper_var = \ + self._create_global_var(tensor_array_scatter_helper_name) + tensor_type_var = self.get_var('tensor_t') + ta = Var("ta", self.prelude.l(tensor_type_var())) + current = Var("current", scalar_type('int32')) + limit = Var("limit", scalar_type('int32')) + indices_ = Var('indices_', TensorType(indices_shape or [Any()], 'int32')) + values_ = Var('values_', self.prelude.l(tensor_type_var())) + write_var = self.get_var('tensor_array_write') + read_var = self.get_var('tensor_array_read') + helper_body = If(equal(current, limit), + ta, + tensor_array_scatter_helper_var( + write_var(ta, op.take(indices_, current), + read_var(values_, current)), + add(current, const(1)), + limit, indices_, values_)) + self.prelude.mod[tensor_array_scatter_helper_var] = \ + Function([ta, current, limit, indices_, values_], + helper_body, self.prelude.l(tensor_type_var()), []) + + tensor_array_scatter_var = self._create_global_var(tensor_array_scatter_name) + setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var) + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + + indices = Var('indices', TensorType(indices_shape or [Any()], 'int32')) + values = Var('values', self.prelude.l(tensor_type_var())) + if indices_shape is None: + indices_shape = op.shape_of(indices) + limit = op.take(indices_shape, const(0)) + else: + limit = const(indices_shape[0]) + + body = tensor_array_scatter_helper_var(tensor_array, const(0), limit, indices, values) + self.prelude.mod[tensor_array_scatter_var] = \ + Function([tensor_array, indices, values], body, + self.prelude.l(tensor_type_var()), []) + + def define_tensor_array_split(self, + value_shape=None, + lengths_shape=None, + force_update=False): + """Defines a function to split the values of a tensor_t into a tensor array. + tensor_array_split(ta, value, lengths) : + list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t] + + Set static value and lengths shapes by specifying value_shape and lengths_shape. + Set force_update to get static value and lengths shape operator. + """ + # Skip scalar case + ndim = len(self.shape) + if ndim == 0: + return + + # When this operator has already been registered, only update + # when force_update is set. This should be used only when we need to + # redefine this op for static value/indices shape. + split_name = self.get_name("tensor_array_split") + if hasattr(self.prelude, split_name) and not force_update: + return + + tensor_type_var = self.get_var('tensor_t') + tensor_array_split_helper_name = self.get_name("ta_split_helper") + tensor_array_split_helper_var = \ + self._create_global_var(tensor_array_split_helper_name) + setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var) + output_shape = [Any(),] + list(self.shape[1:]) + output_tensor_type_var, _ = self._get_adt_by_shape(output_shape) + + if value_shape is None: + value_type_var = tensor_type_var + take_var = self.get_var('tensor_take') + else: + value_type_var, _ = self._get_adt_by_shape(value_shape) + # Also get static shape take operator + origin_shape = list(self.shape) + self.shape = value_shape + self.define_tensor_take() + take_var = self.get_var('tensor_take') + self.shape = origin_shape + + + ta1 = Var("tensor_array", self.prelude.l(output_tensor_type_var())) + value1 = Var('value1', value_type_var()) + offset1 = Var('offset1', scalar_type('int32')) + current1 = Var('current1', scalar_type('int32')) + limit1 = Var('limit1', scalar_type('int32')) + lengths1 = Var('lengths', TensorType(lengths_shape or [Any()], 'int32')) + + # Register write for output shape + origin_shape = list(self.shape) + self.shape = output_shape + self.define_tensor_array_write() + write_var = self.get_var('tensor_array_write') + self.shape = origin_shape + helper1_body = If(equal(current1, limit1), + ta1, + write_var( + tensor_array_split_helper_var( + ta1, + value1, + add(offset1, op.take(lengths1, current1)), + add(current1, const(1)), + limit1, + lengths1 + ), + current1, + take_var(value1, + offset1, + add(op.take(lengths1, current1), offset1)))) + self.prelude.mod[tensor_array_split_helper_var] = \ + Function([ta1, value1, offset1, current1, limit1, lengths1], + helper1_body, self.prelude.l(output_tensor_type_var()), []) + split_var = self._create_global_var(split_name) + setattr(self.prelude, split_name, split_var) + tensor_array = Var("tensor_array", self.prelude.l(output_tensor_type_var())) + + value = Var('value', value_type_var()) + lengths = Var('lengths', TensorType(lengths_shape or [Any()], 'int32')) + if lengths_shape is None: + lengths_shape = op.shape_of(lengths) + lengths_limit = op.take(lengths_shape, const(0)) + else: + lengths_limit = const(lengths_shape[0]) + body = tensor_array_split_helper_var( + tensor_array, + value, + const(0), + const(0), + lengths_limit, + lengths) + + self.prelude.mod[split_var] = \ + Function([tensor_array, value, lengths], body, + self.prelude.l(output_tensor_type_var()), []) + + def define_tensor_array_concat(self): + """Defines a function to return the values in the tensor array as concatenated tensor_t. + tensor_array_concat(ta) : list[tensor_t] -> tensor_t + """ + # We don't register concat for scalar tensor array. + ndim = len(self.shape) + if ndim == 0: + return + + concat_name = self.get_name("tensor_array_concat") + concat_var = self._create_global_var(concat_name) + setattr(self.prelude, concat_name, concat_var) + + output_shape = [Any(),] + list(self.shape[1:]) + tensor_type_var, _ = self._get_adt_by_shape(output_shape) + + # Register tensor concatenate and get tensor_nil var for output shape + origin_shape = self.shape + self.shape = output_shape + self.define_tensor_concatenate() + tensor_concat_var = self.get_var('tensor_concatenate') + tensor_nil_var = self.get_var('tensor_nil') + self.shape = origin_shape + + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + hd = Var("hd") + tl = Var("tl") + nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var()) + cons_case = Clause(PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]), + Match(tl, [ + Clause(PatternConstructor(self.prelude.nil), hd), + Clause(PatternWildcard(), + tensor_concat_var(hd, concat_var(tl))) + ], False)) + self.prelude.mod[concat_var] = \ + Function([tensor_array], + Match(tensor_array, [nil_case, cons_case], False), tensor_type_var(), []) + + def define_tensor_array_stack(self): + """Defines a function to get the values in the tensor array as a stack tensor_t. + tensor_array_stack(l) : list[tensor_t] -> tensor_t + """ + stack_name = self.get_name("tensor_array_stack") + stack_var = self._create_global_var(stack_name) + setattr(self.prelude, stack_name, stack_var) + tensor_type_var = self.get_var('tensor_t') + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + expand_dims_var = self.get_var('tensor_expand_dims') + + # Register tensor_concatenate for output_shape + origin_shape = self.shape + output_shape = [Any(),] + list(self.shape) + self.shape = output_shape + self.define_tensor_concatenate() + concat_var = self.get_var('tensor_concatenate') + self.shape = origin_shape + + tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array) + tensors = self.prelude.foldl(concat_var, + self.prelude.hd(tensor_array_expand_dims), + self.prelude.tl(tensor_array_expand_dims)) + output_tensor_type_var, _ = self._get_adt_by_shape(output_shape) + self.prelude.mod[stack_var] = Function([tensor_array], tensors, + output_tensor_type_var(), []) + + def define_tensor_array_gather(self): + """Defines a function to return the selected values in a tensor array as tensor_t. + tensor_array_gather(ta, indices) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t + """ + helper_name = self.get_name("tensor_array_gather_helper") + helper_var = self._create_global_var(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor_type_var = self.get_var('tensor_t') + output_shape = [Any(),] + list(self.shape) + output_tensor_type_var, _ = self._get_adt_by_shape(output_shape) + stack_var = self.get_var('tensor_array_stack') + read_var = self.get_var('tensor_array_read') + ta = Var("ta", self.prelude.l(tensor_type_var())) + accu = Var("accu", self.prelude.l(tensor_type_var())) + current = Var("current", scalar_type('int32')) + limit = Var("limit", scalar_type('int32')) + indices_ = Var('indices_', TensorType([Any()], 'int32')) + helper_body = \ + If(equal(current, const(0)), + stack_var(accu), + helper_var( + ta, + self.prelude.cons( + read_var( + ta, op.take(indices_, subtract(current, const(1)))), accu), + subtract(current, const(1)), + limit, indices_)) + self.prelude.mod[helper_var] = \ + Function([ta, accu, current, limit, indices_], + helper_body, output_tensor_type_var(), []) + gather_name = self.get_name("tensor_array_gather") + gather_var = self._create_global_var(gather_name) + setattr(self.prelude, gather_name, gather_var) + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + indices = Var('indices', TensorType([Any()], 'int32')) + indices_shape = op.shape_of(indices) + limit = op.take(indices_shape, const(0)) + body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices) + self.prelude.mod[gather_var] = \ + Function([tensor_array, indices], body, output_tensor_type_var(), []) + + def define_tensor_get_data(self, data_shape): + """Defines a function to get a Tensor from tensor_t with given shape. + """ + tensor_get_data_name = self.get_name("tensor_get_data") + tensor_get_data_var = self._create_global_var(tensor_get_data_name) + setattr(self.prelude, tensor_get_data_name, tensor_get_data_var) + + tensor_type_var, tensor_constructor = self._get_adt_by_shape(data_shape) + t = Var('tensor', tensor_type_var()) + tvar = Var('t') + case =\ + Clause(PatternConstructor(tensor_constructor, [PatternVar(tvar)]), tvar) + self.prelude.mod[tensor_get_data_var] = \ + Function([t], Match(t, [case], False), + TensorType(data_shape, self.dtype), []) + + def register(self): + """Register all tensor array ops in Prelude""" + self.define_tensor_adt() + self.define_tensor_take() + self.define_tensor_concatenate() + self.define_tensor_expand_dims() + self.define_tensor_array() + self.define_tensor_array_read() + self.define_tensor_array_write() + self.define_tensor_array_unstack() + self.define_tensor_array_scatter() + self.define_tensor_array_split() + self.define_tensor_array_concat() + self.define_tensor_array_stack() + self.define_tensor_array_gather() + + def _get_adt_by_shape(self, shape): + """Get ADT type and constructor with given shape.""" + origin_shape = self.shape + self.shape = shape + self.define_tensor_adt() + tensor_type_var = self.get_var("tensor_t") + tensor_constructor = self.get_var("tensor_constructor") + self.shape = origin_shape + return tensor_type_var, tensor_constructor + + def _create_global_var(self, name): + """Create a GlobalVar if doesn't exist in prelude.""" + global_var_name_set = set() + for g_var_name in self.prelude.mod.get_global_vars(): + global_var_name_set.add(g_var_name.name_hint) + if name not in global_var_name_set: + gvar = GlobalVar(name) + else: + gvar = self.prelude.mod.get_global_var(name) + + return gvar + class TensorArrayOps(object): """Contains tensor array related ops""" @@ -666,6 +1205,15 @@ def get_var(self, canonical, dtype): name = self.get_name(canonical, dtype) return getattr(self, name) + def get_name_static(self, canonical, dtype, shape): + """Get name corresponding to the canonical name""" + return _get_name_static(canonical, dtype, shape) + + def get_var_static(self, canonical, dtype, shape): + """Get var corresponding to the canonical name""" + name = self.get_name_static(canonical, dtype, shape) + return getattr(self, name) + def load_prelude(self): """Parses the Prelude from Relay's text format into a module.""" # TODO(@jroesch): we should remove this helper when we port over prelude diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index deeb7330f9da..c9b13d26894f 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -19,7 +19,7 @@ from tvm import relay from tvm.relay.backend.interpreter import ConstructorValue from tvm.relay import create_executor -from tvm.relay.prelude import Prelude +from tvm.relay.prelude import Prelude, StaticTensorArrayOps from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr import numpy as np @@ -980,6 +980,395 @@ def run(dtype): run('float32') run('int32') +def test_static_tensor_take(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + take = p.get_var_static('tensor_take', dtype, shape) + tensor_constructor = p.get_var_static('tensor_constructor', dtype, shape) + v = relay.var('v') + lower = relay.var('lower') + upper = relay.var('upper') + mod["main"] = relay.Function([v, lower, upper], take(tensor_constructor(v), lower, upper)) + v_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + expected = [np.take(v_data, range(2, 5), axis=0)] + check_tensor_array(mod, expected, *(v_data, 2, 5), dtype=dtype) + expected = [np.take(v_data, range(0, 9), axis=0)] + check_tensor_array(mod, expected, *(v_data, 0, 9), dtype=dtype) + run('float32', [10, 10]) + run('int32', [15, 11]) + + +def test_static_tensor_concatenate(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + concat = p.get_var_static('tensor_concatenate', dtype, shape) + tensor = p.get_var_static('tensor_constructor', dtype, shape) + v1 = relay.var('v1') + v2 = relay.var('v2') + mod["main"] = relay.Function([v1, v2], concat(tensor(v1), + tensor(v2))) + v1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + v2_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + expected = [np.concatenate((v1_data, v2_data))] + check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype) + run('float32', [5,]) + run('int32', [2, 3]) + + +def test_static_tensor_expand_dims(): + def run(dtype, shape): + x = relay.var('x') + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + expand_dims_func = p.get_var_static('tensor_expand_dims', dtype, shape) + tensor = p.get_var_static('tensor_constructor', dtype, shape) + mod["main"] = relay.Function([x], expand_dims_func(tensor(x))) + x_np = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + expected = [np.expand_dims(x_np, axis=0)] + check_tensor_array(mod, expected, x_np) + run('float32', []) + run('int32', [2,]) + + +def test_static_tensor_array_constructor(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + tensor_constructor = p.get_name_static('tensor_constructor', dtype, shape) + assert tensor_constructor != None + run('float32', [1, 1]) + + +def test_static_tensor_array_read(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + np_data_list = [] + ta_length = 3 + for _ in range(ta_length): + np_data_list.append(np.random.uniform(0, 10, size=shape).astype(dtype)) + + v0 = relay.var('v0') + v1 = relay.var('v1') + v2 = relay.var('v2') + n = relay.var('n') + tensor = p.get_var_static('tensor_constructor', dtype, shape) + tensor_array = p.get_var_static('tensor_array', dtype, shape) + init_tensor_array = tensor_array(relay.const(ta_length)) + read_func = p.get_var_static('tensor_array_read', dtype, shape) + write_func = p.get_var_static('tensor_array_write', dtype, shape) + tensor_array0 = write_func(init_tensor_array, relay.const(0), + tensor(v0)) + tensor_array1 = write_func(tensor_array0, relay.const(1), + tensor(v1)) + tensor_array2 = write_func(tensor_array1, relay.const(2), + tensor(v2)) + + mod["main"] = relay.Function([v0, v1, v2, n], read_func(tensor_array2, n)) + expected = [np_data_list[0]] + check_tensor_array(mod, expected, *list(np_data_list + [0]), dtype=dtype) + expected = [np_data_list[1]] + check_tensor_array(mod, expected, *list(np_data_list + [1]), dtype=dtype) + expected = [np_data_list[2]] + check_tensor_array(mod, expected, *list(np_data_list + [2]), dtype=dtype) + run('float32', []) + run('int32', [2, 3]) + + +def test_static_tensor_array_write(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + ta_length = 2 + np_data_list = [np.random.uniform(0, 10, size=shape).astype(dtype) for _ in range(ta_length)] + + v0 = relay.var('v0') + v1 = relay.var('v1') + tensor_array = p.get_var_static('tensor_array', dtype, shape) + init_tensor_array = tensor_array(relay.const(ta_length)) + write_func = p.get_var_static('tensor_array_write', dtype, shape) + tensor = p.get_var_static('tensor_constructor', dtype, shape) + tensor_array0 = write_func(init_tensor_array, relay.const(0), + tensor(v0)) + tensor_array1 = write_func(tensor_array0, relay.const(1), tensor(v1)) + mod["main"] = relay.Function([v0, v1], tensor_array1) + expected = np_data_list + check_tensor_array(mod, expected, *np_data_list, dtype=dtype) + run('float32', []) + run('int32', [2, 3]) + + +def test_static_tensor_array_unstack(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + unstack_tensor = p.get_var_static('tensor_array_unstack', dtype, shape) + v = relay.var('v') + mod["main"] = relay.Function([v], unstack_tensor(v)) + t = np.random.uniform(low=0, high=10, size=shape).astype(dtype) + *expected, = t + check_tensor_array(mod, expected, t, dtype=dtype) + run('float32', [4]) + run('int32', [2, 3]) + + +def test_static_tensor_array_scatter(): + def run(dtype, shape, indices_shape=None): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + if indices_shape is not None: + static_tensor_array_ops.define_tensor_array_scatter(indices_shape, True) + + # tensor array + v1 = relay.var('v1') + v2 = relay.var('v2') + v3 = relay.var('v2') + tensor_array = p.get_var_static('tensor_array', dtype, shape) + tensor_array0 = tensor_array(relay.const(3)) + write_func = p.get_var_static('tensor_array_write', dtype, shape) + scatter_func = p.get_var_static('tensor_array_scatter', dtype, shape) + tensor = p.get_var_static('tensor_constructor', dtype, shape) + tensor_array1 = write_func(tensor_array0, relay.const(0), tensor(v1)) + tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2)) + tensor_array1 = write_func(tensor_array1, relay.const(2), tensor(v3)) + + # indices array + index = relay.var('index') + + # values array + value_0 = relay.var('value_0') + value_1 = relay.var('value_1') + values_array = tensor_array(relay.const(2)) + values_array = write_func(values_array, relay.const(0), + tensor(value_0)) + values_array = write_func(values_array, relay.const(1), + tensor(value_1)) + + # create the scatter function + tensor_array_scatter = scatter_func(tensor_array1, index, values_array) + mod["main"] = relay.Function([v1, v2, v3, index, value_0, value_1], + tensor_array_scatter) + + # initialize and check + v1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + v2_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + v3_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + index_data = np.array([0, 1], dtype="int32") + val1_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + val2_data = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + expected = [val1_data, val2_data, v3_data] + check_tensor_array(mod, expected, *(v1_data, v2_data, v3_data, + index_data, val1_data, + val2_data), dtype=dtype) + run('float32', [2, 3]) + run('int32', [2, 3]) + run('float32', [2, 3], [2,]) + + +def test_static_tensor_array_split(): + def run(dtype, shape, value_shape=None, lengths_shape=None): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + if value_shape is not None or lengths_shape is not None: + static_tensor_array_ops.define_tensor_array_split(value_shape, lengths_shape, True) + + # tensor array + v1 = relay.var('v1') + v2 = relay.var('v2') + v3 = relay.var('v2') + + adt_shape = [relay.Any(),] + shape[1:] + origin_shape = static_tensor_array_ops.shape + static_tensor_array_ops.shape = adt_shape + static_tensor_array_ops.define_tensor_array() + tensor_array = p.get_var_static('tensor_array', dtype, adt_shape) + static_tensor_array_ops.shape = origin_shape + tensor_array1 = tensor_array(relay.const(3)) + write_func = p.get_var_static('tensor_array_write', dtype, adt_shape) + split_func = p.get_var_static('tensor_array_split', dtype, shape) + tensor = p.get_var_static('tensor_constructor', dtype, adt_shape) + tensor_array1 = write_func(tensor_array1, relay.const(0), tensor(v1)) + tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2)) + tensor_array1 = write_func(tensor_array1, relay.const(2), tensor(v3)) + + # value tensor + value = relay.var('value') + + # lengths tensor + ta_len = relay.var('length') + + # create the split function + if value_shape is None: + tensor1 = p.get_var_static('tensor_constructor', dtype, shape) + else: + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, value_shape) + static_tensor_array_ops.register() + tensor1 = p.get_var_static('tensor_constructor', dtype, value_shape) + tensor_array_split = split_func(tensor_array1, tensor1(value), ta_len) + mod["main"] = relay.Function([v1, v2, v3, value, ta_len], + tensor_array_split) + + # initialize and check + v1_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype) + v2_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype) + v3_data = np.random.uniform(low=0.0, high=8.0, size=[2, 3]).astype(dtype) + value_data = np.random.uniform(low=0.0, high=8.0, + size=value_shape or shape).astype(dtype) + length_data = np.array([2, 2], dtype="int32") + expected = np.concatenate([value_data, v3_data]) + expected = np.split(expected, indices_or_sections=[2, 4]) + check_tensor_array(mod, expected, *(v1_data, v2_data, v3_data, + value_data, length_data), + dtype=dtype) + + run('float32', [4, 3]) + run('int32', [4, 3]) + run('int32', [relay.Any(), 3], [4, 3], [2,]) + + +def test_static_tensor_array_concat(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + v1 = relay.var('v1') + v2 = relay.var('v2') + tensor_array = p.get_var_static('tensor_array', dtype, shape) + tensor_array1 = tensor_array(relay.const(2)) + write_func = p.get_var_static('tensor_array_write', dtype, shape) + concat_func = p.get_var_static('tensor_array_concat', dtype, shape) + tensor = p.get_var_static('tensor_constructor', dtype, shape) + tensor_array1 = write_func(tensor_array1, relay.const(0), tensor(v1)) + tensor_array1 = write_func(tensor_array1, relay.const(1), tensor(v2)) + tensor_array_concat = concat_func(tensor_array1) + mod["main"] = relay.Function([v1, v2], tensor_array_concat) + v1_data = np.random.uniform(low=0.0, high=8.0, size=(2, 3)).astype(dtype) + v2_data = np.random.uniform(low=0.0, high=8.0, size=(1, 3)).astype(dtype) + expected = [np.concatenate((v1_data, v2_data), axis=0)] + check_tensor_array(mod, expected, *(v1_data, v2_data), dtype=dtype) + run('float32', [relay.Any(), 3]) + run('int32', [relay.Any(), 3]) + + +def test_static_tensor_array_gather(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + tensor_array = p.get_var_static('tensor_array', dtype, shape) + tensor = p.get_var_static('tensor_constructor', dtype, shape) + write = p.get_var_static('tensor_array_write', dtype, shape) + gather = p.get_var_static('tensor_array_gather', dtype, shape) + v = relay.var('v') + indice = relay.var('indice') + init_tensor_array = tensor_array(relay.const(3)) + tensor_array1 = write(init_tensor_array, relay.const(0), tensor(v)) + tensor_array2 = write(tensor_array1, relay.const(1), tensor(v)) + tensor_array3 = write(tensor_array2, relay.const(2), tensor(v)) + out = gather(tensor_array3, indice) + mod["main"] = relay.Function([v, indice], out) + t = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + indice_data = np.array([0, 2], dtype="int32") + expected = [np.stack([t, t])] + check_tensor_array(mod, expected, *(t, indice_data), dtype=dtype) + run('float32', []) + run('int32', [2, 3]) + + +def test_static_tensor_array_stack(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + + tensor_array = p.get_var_static('tensor_array', dtype, shape) + tensor = p.get_var_static('tensor_constructor', dtype, shape) + write = p.get_var_static('tensor_array_write', dtype, shape) + stack = p.get_var_static('tensor_array_stack', dtype, shape) + v = relay.var('v') + init_tensor_array = tensor_array(relay.const(3)) + tensor_array1 = write(init_tensor_array, relay.const(0), tensor(v)) + tensor_array2 = write(tensor_array1, relay.const(1), tensor(v)) + tensor_array3 = write(tensor_array2, relay.const(2), tensor(v)) + tensor_array4 = stack(tensor_array3) + mod["main"] = relay.Function([v], tensor_array4) + t = np.random.uniform(low=0.0, high=8.0, size=shape).astype(dtype) + expected = [np.stack([t, t, t])] + check_tensor_array(mod, expected, t, dtype=dtype) + run('float32', []) + run('int32', [2, 3]) + + +def test_static_tensor_get_data(): + def run(dtype, shape): + mod = tvm.IRModule() + p = Prelude(mod) + static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) + static_tensor_array_ops.register() + static_tensor_array_ops.define_tensor_get_data(shape) + + np_data_list = [] + ta_length = 3 + for _ in range(ta_length): + np_data_list.append(np.random.uniform(0, 10, size=shape).astype(dtype)) + + v0 = relay.var('v0') + v1 = relay.var('v1') + v2 = relay.var('v2') + n = relay.var('n') + tensor = p.get_var_static('tensor_constructor', dtype, shape) + tensor_array = p.get_var_static('tensor_array', dtype, shape) + init_tensor_array = tensor_array(relay.const(ta_length)) + read_func = p.get_var_static('tensor_array_read', dtype, shape) + write_func = p.get_var_static('tensor_array_write', dtype, shape) + get_data_func = p.get_var_static('tensor_get_data', dtype, shape) + tensor_array0 = write_func(init_tensor_array, relay.const(0), + tensor(v0)) + tensor_array1 = write_func(tensor_array0, relay.const(1), + tensor(v1)) + tensor_array2 = write_func(tensor_array1, relay.const(2), + tensor(v2)) + + mod["main"] = relay.Function([v0, v1, v2, n], get_data_func(read_func(tensor_array2, n))) + expected = [np_data_list[0]] + check_tensor_array(mod, expected, *list(np_data_list + [0]), dtype=dtype) + expected = [np_data_list[1]] + check_tensor_array(mod, expected, *list(np_data_list + [1]), dtype=dtype) + expected = [np_data_list[2]] + check_tensor_array(mod, expected, *list(np_data_list + [2]), dtype=dtype) + run('float32', []) + run('int32', [2, 3]) if __name__ == "__main__": test_nat_constructor() @@ -1016,3 +1405,17 @@ def run(dtype): test_tensor_array_concat() test_tensor_array_scatter() test_tensor_array_split() + + test_static_tensor_take() + test_static_tensor_concatenate() + test_static_tensor_expand_dims() + test_static_tensor_array_constructor() + test_static_tensor_array_read() + test_static_tensor_array_write() + test_static_tensor_array_unstack() + test_static_tensor_array_scatter() + test_static_tensor_array_split() + test_static_tensor_array_concat() + test_static_tensor_array_stack() + test_static_tensor_array_gather() + test_static_tensor_get_data()