diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py index f7ffe895b48b5..8279b1249cedc 100644 --- a/python/tvm/relay/op/dyn/_transform.py +++ b/python/tvm/relay/op/dyn/_transform.py @@ -87,20 +87,37 @@ def dynamic_reshape_shape_func(attrs, inputs, out_ndims): @script -def _tile_shape_func(data, reps, ndim): - out = output_tensor((ndim,), "int64") +def _tile_shape_func(data, reps, ndim, tndim, rndim): + out = output_tensor((tndim,), "int64") - for i in const_range(ndim): - out[i] = data.shape[i] * int64(reps[i]) + if ndim == rndim: + for i in const_range(tndim): + out[i] = int64(data.shape[i] * reps[i]) + elif ndim > rndim: + ngap = ndim - rndim + for i in const_range(ndim): + if i < ngap: + out[i] = int64(data.shape[i]) + else: + out[i] = int64(data.shape[i] * reps[i - ngap]) + else: + rgap = rndim - ndim + for i in const_range(rndim): + if i < rgap: + out[i] = int64(reps[i]) + else: + out[i] = int64(reps[i] * data.shape[i - rgap]) return out @_reg.register_shape_func("dyn.tile", True) def tile_shape_func(attrs, inputs, _): """ - Shape function for tile op. + Shape function for dyn.tile op. 
""" + reps = inputs[1] ndim = len(inputs[0].shape) - rdim = inputs[1].shape[0].value - assert ndim == rdim, "tile data and reps ranks don't match" - return [_tile_shape_func(inputs[0], inputs[1], convert(ndim))] + rndim = inputs[1].shape[0].value + tndim = ndim if ndim > rndim else rndim + return [_tile_shape_func(inputs[0], reps, convert(ndim), + convert(tndim), convert(rndim))] diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index e64d09f3399e5..0b8a15676cc33 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -149,14 +149,14 @@ bool TileRel(const Array& types, int num_inputs, const Attrs& attrs, << "tile: expect input type to be TensorType but get " << types[1]; return false; } - const size_t ndim = data->shape.size(); const IntImmNode* reps_shape = reps->shape[0].as(); CHECK(reps_shape) << "Parameter reps must have static shape"; - // check dimension match - CHECK_EQ(ndim, reps_shape->value) << "tile: the shape of reps must match the rank of data"; + const size_t ndim = data->shape.size(); + const size_t rndim = reps_shape->value; + size_t tndim = (ndim > rndim) ? 
ndim : rndim; std::vector<IndexExpr> oshape; - oshape.reserve(ndim); - for (size_t i = 0; i < ndim; ++i) { + oshape.reserve(tndim); + for (size_t i = 0; i < tndim; ++i) { oshape.emplace_back(Any()); } reporter->Assign(types[2], TensorType(oshape, data->dtype)); @@ -167,7 +167,8 @@ Array<te::Tensor> TileCompute(const Attrs& attrs, const Array<te::Tensor>& input const Type& out_type) { CHECK_EQ(inputs.size(), 2); const auto* out_ttype = out_type.as<TensorTypeNode>(); - return {topi::dyn_tile(inputs[0], out_ttype->shape)}; + size_t rndim = inputs[1]->shape[0].as<IntImmNode>()->value; + return {topi::dyn_tile(inputs[0], out_ttype->shape, rndim)}; } Expr MakeTile(Expr data, Expr reps) { diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py index 176641583026e..2f473c9de070b 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level3.py +++ b/tests/python/relay/dyn/test_dynamic_op_level3.py @@ -73,7 +73,7 @@ def verify_reshape(shape, newshape, oshape): def test_dyn_tile(): def verify_tile(dshape, reps): x = relay.var("x", relay.TensorType(dshape, "float32")) - r = relay.var("reps", relay.TensorType((len(dshape), ), "float32")) + r = relay.var("reps", relay.TensorType((len(reps), ), "float32")) z = relay.tile(x, r) func = relay.Function([x, r], z) @@ -81,8 +81,8 @@ def verify_tile(dshape, reps): ref_res = np.tile(x_data, reps=reps) verify_func(func, [x_data, np.array(reps).astype("float32")], ref_res) verify_tile((2, 3, 4), (3, 2, 1)) - verify_tile((2, 3, 4), (1, 2, 1)) - verify_tile((1, 2, 3), (3, 2, 1)) + verify_tile((2, 3, 4), (1, 2)) + verify_tile((2, 3), (3, 2, 1)) if __name__ == "__main__": test_dyn_reshape() diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 75bf7495a843e..4586fb2d1a706 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1021,13 +1021,14 @@ inline Tensor tile(const Tensor& x, Array<Integer> reps, std::string name = "T_t * * \param x The input tensor * \param new_shape The shape of the output after
tiling + * \param rdim The rank of the reps, provided by caller * \param name The name of the operation * \param tag The tag to mark the operation * * \return A Tensor whose op member is the tile operation */ -inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, std::string name = "T_tile", - std::string tag = kBroadcast) { +inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, size_t rdim, + std::string name = "T_tile", std::string tag = kBroadcast) { size_t ndim = x->shape.size(); if (is_empty_shape(new_shape)) { return compute( @@ -1037,8 +1038,14 @@ inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, std::string n ew_shape, [&](const Array<Var>& indices) { Array<PrimExpr> idx; - for (size_t i = 0; i < ndim; ++i) { - idx.push_back(indexmod(indices[i], x->shape[i])); + if (ndim >= rdim) { + for (size_t i = 0; i < ndim; ++i) { + idx.push_back(indexmod(indices[i], x->shape[i])); + } + } else { + for (size_t i = 0; i < ndim; ++i) { + idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i])); + } } return x(idx); },