diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py index 81c6e5e9229d..8279b1249ced 100644 --- a/python/tvm/relay/op/dyn/_transform.py +++ b/python/tvm/relay/op/dyn/_transform.py @@ -17,10 +17,13 @@ """Backend compiler related feature registration""" # pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-local-variables, too-many-arguments from __future__ import absolute_import + +from tvm.runtime import convert from tvm.te.hybrid import script from .. import op as _reg _reg.register_injective_schedule("dyn.reshape") +_reg.register_broadcast_schedule("dyn.tile") @script def _reshape_shape_func_input_data(data, newshape, ndim): @@ -81,3 +84,40 @@ def _reshape_shape_func_input_data(data, newshape, ndim): @_reg.register_shape_func("dyn.reshape", True) def dynamic_reshape_shape_func(attrs, inputs, out_ndims): return [_reshape_shape_func_input_data(*inputs, out_ndims[0])] + + +@script +def _tile_shape_func(data, reps, ndim, tndim, rndim): + out = output_tensor((tndim,), "int64") + + if ndim == rndim: + for i in const_range(tndim): + out[i] = int64(data.shape[i] * reps[i]) + elif ndim > rndim: + ngap = ndim - rndim + for i in const_range(ndim): + if i < ngap: + out[i] = int64(data.shape[i]) + else: + out[i] = int64(data.shape[i] * reps[i - ngap]) + else: + rgap = rndim - ndim + for i in const_range(rndim): + if i < rgap: + out[i] = int64(reps[i]) + else: + out[i] = int64(reps[i] * data.shape[i - rgap]) + return out + + +@_reg.register_shape_func("dyn.tile", True) +def tile_shape_func(attrs, inputs, _): + """ + Shape function for dyn.tile op. + """ + reps = inputs[1] + ndim = len(inputs[0].shape) + rndim = inputs[1].shape[0].value + tndim = ndim if ndim > rndim else rndim + return [_tile_shape_func(inputs[0], reps, convert(ndim), + convert(tndim), convert(rndim))] diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 188cd5cbcd85..173db64de258 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -496,7 +496,7 @@ def tile(data, reps): data : relay.Expr The input data to the operator. - reps : tuple of int + reps : tuple of int or relay.Expr The number of times repeating the tensor data. Returns @@ -524,7 +524,8 @@ def tile(data, reps): data is promoted to be d-dimensional by prepending new axes. If data.ndim >= d, reps is promoted to a.ndim by pre-pending 1's to it. """ - + if isinstance(reps, Expr): + return _dyn_make.tile(data, reps) return _make.tile(data, reps) diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index 18eaa67242f5..0b8a15676cc3 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -29,6 +29,8 @@ #include #include +#include + namespace tvm { namespace relay { namespace dyn { @@ -128,6 +130,71 @@ RELAY_REGISTER_OP("dyn.reshape") .set_attr("FTVMCompute", ReshapeCompute) .set_attr("TOpPattern", kInjective); +// tile operator +// TVM_REGISTER_NODE_TYPE(TileAttrs); + +bool TileRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, reps, result] + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* reps = types[1].as(); + if (data == nullptr) { + CHECK(types[0].as()) + << "tile: expect input type to be TensorType but get " << types[0]; + return false; + } + if (reps == nullptr) { + CHECK(types[1].as()) + << "tile: expect input type to be TensorType but get " << types[1]; + return false; + } + const IntImmNode* reps_shape = reps->shape[0].as(); + CHECK(reps_shape) << "Parameter reps must have static shape"; + const size_t ndim = data->shape.size(); + const size_t rndim = reps_shape->value; + size_t tndim = (ndim > rndim) ? ndim : rndim; + std::vector oshape; + oshape.reserve(tndim); + for (size_t i = 0; i < tndim; ++i) { + oshape.emplace_back(Any()); + } + reporter->Assign(types[2], TensorType(oshape, data->dtype)); + return true; +} + +Array TileCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { + CHECK_EQ(inputs.size(), 2); + const auto* out_ttype = out_type.as(); + size_t rndim = inputs[1]->shape[0].as()->value; + return {topi::dyn_tile(inputs[0], out_ttype->shape, rndim)}; +} + +Expr MakeTile(Expr data, Expr reps) { + auto attrs = make_object(); + static const Op& op = Op::Get("dyn.tile"); + return Call(op, {data, reps}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.dyn._make.tile").set_body_typed(MakeTile); + +RELAY_REGISTER_OP("dyn.tile") + .describe(R"code(Repeat the whole array multiple times. + +- **data**: The input data to the operator. +- **reps**: The number of times to repeat the operator. + +)code" TVM_ADD_FILELINE) + .set_num_inputs(2) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("reps", "Tensor", "The number of times to repeat the input on each axis.") + .set_support_level(3) + .add_type_rel("DynamicTile", TileRel) + .set_attr("FTVMCompute", TileCompute) + .set_attr("TOpPattern", kInjective); + } // namespace dyn } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 7b3f1957811b..d09230ac30d7 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -32,7 +32,8 @@ namespace relay { class DynamicToStaticMutator : public MixedModeMutator { public: - DynamicToStaticMutator() : dyn_reshape_op_(Op::Get("dyn.reshape")) {} + DynamicToStaticMutator() + : dyn_reshape_op_(Op::Get("dyn.reshape")), dyn_tile_op_(Op::Get("dyn.tile")) {} private: Expr Rewrite_(const CallNode* pre, const Expr& post) override { @@ -46,6 +47,14 @@ class DynamicToStaticMutator : public MixedModeMutator { static const Op& reshape = Op::Get("reshape"); return Call(reshape, {call_node->args[0]}, Attrs(attrs), {}); } + } else if (call_node->op == dyn_tile_op_) { + if (const ConstantNode* reps = call_node->args[1].as()) { + auto attrs = make_object(); + CHECK_EQ(reps->data->ndim, 1); + attrs->reps = ToVector(reps->data); + static const Op& op = Op::Get("tile"); + return Call(op, {call_node->args[0]}, Attrs(attrs), {}); + } } return post; } @@ -58,6 +67,7 @@ class DynamicToStaticMutator : public MixedModeMutator { } const Op& dyn_reshape_op_; + const Op& dyn_tile_op_; }; Expr DynamicToStatic(Function f, IRModule m) { diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py index 29168b6d0d1a..2f473c9de070 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level3.py +++ b/tests/python/relay/dyn/test_dynamic_op_level3.py @@ -70,6 +70,21 @@ def verify_reshape(shape, newshape, oshape): verify_reshape((2, 3, 4), (8, 3), (8, 3)) verify_reshape((4, 7), (2, 7, 2), (2, 7, 2)) +def test_dyn_tile(): + def verify_tile(dshape, reps): + x = relay.var("x", relay.TensorType(dshape, "float32")) + r = relay.var("reps", relay.TensorType((len(reps), ), "float32")) + z = relay.tile(x, r) + + func = relay.Function([x, r], z) + x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32") + ref_res = np.tile(x_data, reps=reps) + verify_func(func, [x_data, np.array(reps).astype("float32")], ref_res) + verify_tile((2, 3, 4), (3, 2, 1)) + verify_tile((2, 3, 4), (1, 2)) + verify_tile((2, 3), (3, 2, 1)) + if __name__ == "__main__": test_dyn_reshape() test_dyn_shape_reshape() + test_dyn_tile() diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index 052d95cef0a7..3415ce01d5fd 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -108,8 +108,30 @@ def verify_reshape(shape, newshape): verify_reshape((2, 3, 4), (8, 3)) verify_reshape((4, 7), (2, 7, 2)) +def test_dynamic_to_static_tile(): + def verify_tile(shape, reps, oshape): + x = relay.var("x", relay.TensorType(shape, "float32")) + y = relay.var("y", relay.TensorType(reps, "float32")) + z = relay.tile(x, relay.shape_of(y)) + func = run_infer_type(relay.Function([x, y], z)) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + + zz = func2.body + assert isinstance(zz, relay.Call) + assert zz.op == relay.op.get("tile") + assert zz.checked_type == relay.ty.TensorType(oshape, "float32") + + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + y_data = np.random.uniform(low=-1, high=1, size=reps).astype("float32") + ref_res = np.tile(x_data, reps) + verify_func(func2, [x_data, y_data], ref_res) + + verify_tile((2, 3, 4), (2, 1, 5), (4, 3, 20)) + verify_tile((4, 7), (4, 2), (16, 14)) + if __name__=="__main__": test_dynamic_to_static_reshape() test_dynamic_to_static_double_reshape() test_dynamic_to_static_quad_reshape() + test_dynamic_to_static_tile() diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 0da1d719ccc5..b5fc02ae7f3c 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1016,6 +1016,43 @@ inline Tensor tile(const Tensor& x, Array reps, std::string name = "T_t } } +/*! + * \brief Creates an operation to tile elements of an array + * + * \param x The input tensor + * \param new_shape The shape of the output after tiling + * \param rdim The rank of the reps, provided by caller + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the tile operation + */ +inline Tensor dyn_tile(const Tensor& x, Array new_shape, size_t rdim, + std::string name = "T_tile", std::string tag = kBroadcast) { + size_t ndim = x->shape.size(); + if (is_empty_shape(new_shape)) { + return compute( + new_shape, [&](const Array& indices) { return tvm::cast(x->dtype, 0); }, name, tag); + } else { + return compute( + new_shape, + [&](const Array& indices) { + Array idx; + if (ndim >= rdim) { + for (size_t i = 0; i < ndim; ++i) { + idx.push_back(indexmod(indices[i], x->shape[i])); + } + } else { + for (size_t i = 0; i < ndim; ++i) { + idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i])); + } + } + return x(idx); + }, + name, tag); + } +} + /*! * \brief Gather values along given axis from given indices. *