diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 13f41fc87001..4e724da366da 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -495,3 +495,35 @@ def reshape_like_shape_func(attrs, inputs, _): Shape function for reshape_like op. """ return [_reshape_like_shape_func(inputs[1])] + +@script +def _tile_shape_func(data, reps, ndim, tndim, rndim): + out = output_tensor((tndim,), "int64") + + if ndim == rndim: + for i in const_range(tndim): + out[i] = data[i] * int64(reps[i]) + elif ndim > rndim: + ngap = ndim - rndim + for i in const_range(ndim): + if i < ngap: + out[i] = data[i] + else: + out[i] = data[i] * int64(reps[i - ngap]) + else: + rgap = rndim - ndim + for i in const_range(rndim): + if i < rgap: + out[i] = int64(reps[i]) + else: + out[i] = int64(reps[i]) * data[i - rgap] + return out + +@_reg.register_shape_func("tile", False) +def tile_shape_func(attrs, inputs, _): + reps = get_const_tuple(attrs.reps) + ndim = inputs[0].shape[0].value + rndim = len(reps) + tndim = ndim if ndim > rndim else rndim + return [_tile_shape_func(inputs[0], convert(reps), convert(ndim), + convert(tndim), convert(rndim))] diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 203a0411d3c4..944d0dbe9852 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1393,28 +1393,39 @@ bool TileRel(const Array& types, reps_shape.reserve(tndim); if (ndim == rndim) { for (size_t i = 0; i < tndim; ++i) { - data_shape.emplace_back(data->shape[i]); - reps_shape.emplace_back(reps[i]); + data_shape.emplace_back(data->shape[i]); + reps_shape.emplace_back(reps[i]); } } else if (ndim > rndim) { - for (size_t i = 0; i < ndim; ++i) - data_shape.emplace_back(data->shape[i]); - for (size_t i = 0; i < (ndim - rndim); ++i) - reps_shape.emplace_back(1); - for (size_t i = 0; i < rndim; ++i) - reps_shape.emplace_back(reps[i]); + for (size_t i = 0; i < ndim; ++i) { + data_shape.emplace_back(data->shape[i]); + } + for (size_t i = 0; i < (ndim - rndim); ++i) { + reps_shape.emplace_back(1); + } + for (size_t i = 0; i < rndim; ++i) { + reps_shape.emplace_back(reps[i]); + } } else { - for (size_t i = 0; i < rndim; ++i) - reps_shape.emplace_back(reps[i]); - for (size_t i = 0; i < (rndim - ndim); ++i) - data_shape.emplace_back(1); - for (size_t i = 0; i < ndim; ++i) - data_shape.emplace_back(data->shape[i]); + for (size_t i = 0; i < rndim; ++i) { + reps_shape.emplace_back(reps[i]); + } + for (size_t i = 0; i < (rndim - ndim); ++i) { + data_shape.emplace_back(1); + } + for (size_t i = 0; i < ndim; ++i) { + data_shape.emplace_back(data->shape[i]); + } } std::vector oshape; oshape.reserve(tndim); for (size_t i = 0; i < tndim; ++i) { - oshape.emplace_back(data_shape[i] * reps_shape[i]); + // Save Any if it is dynamic shape + if (!data_shape[i].as()) { + oshape.emplace_back(Any::make()); + } else { + oshape.emplace_back(data_shape[i] * reps_shape[i]); + } } reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); return true; diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 75be88cbcb19..8001c8ee6b34 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -166,6 +166,25 @@ def test_any_take(): verify_any_take(any_dims(2), any_dims(3), None, (4, 5), (2, 3, 4)) verify_any_take(any_dims(2), any_dims(4), -1, (4, 5), (2, 3, 4, 5)) +def verify_any_tile(dshape, reps, np_dshape, np_reps): + mod = relay.Module() + x = relay.var("x", shape=dshape, dtype="float32") + y = relay.tile(x, reps=reps) + mod["main"] = relay.Function([x], y) + x_data = np.random.uniform(size=np_dshape).astype("float32") + ref_res = np.tile(x_data, reps=np_reps) + + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + res = ex.evaluate()(x_data) + tvm.testing.assert_allclose(res.asnumpy(), ref_res, rtol=1e-5) + +def test_any_tile(): + verify_any_tile(any_dims(3), (3, 2, 1), (2, 3, 4), (3, 2, 1)) + verify_any_tile(any_dims(3), (1, 2), (2, 3, 4), (1, 2)) + verify_any_tile(any_dims(2), (3, 2, 1), (2, 3), (3, 2, 1)) + verify_any_tile(any_dims(3), (1,), (2, 3, 4), (1,)) + def test_any_shape_of(): x = relay.var('x', shape=any_dims(2), dtype='float32') y = relay.shape_of(x) @@ -558,6 +577,7 @@ def _body(i, st): test_any_concat() test_any_reshape() test_any_take() + test_any_tile() test_any_shape_of() test_any_reduce() test_any_layout_transform()