From 481c2dc85209fa3d104c020b0d8d8e4ce7ed20c1 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Fri, 23 Aug 2024 07:16:44 +0900
Subject: [PATCH] [Relax][PyTorch] Add support for torch.tile (#17291)

* add test

* add support for torch.tile
---
 .../tvm/relax/frontend/torch/fx_translator.py |  9 ++++
 tests/python/relax/test_frontend_from_fx.py   | 42 +++++++++++++++++++
 2 files changed, 51 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 093f3ae4cf7a..35131d324076 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -612,6 +612,14 @@ def _squeeze(self, node: fx.node.Node) -> relax.Var:
             dim = None
         return self.block_builder.emit(relax.op.squeeze(x, dim))
 
+    def _tile(self, node: fx.node.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        if isinstance(args[1], (torch.Size, tuple, list)):
+            return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1])))
+        return self.block_builder.emit(relax.op.tile(args[0], args[1:]))
+
     def _cumsum(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
 
@@ -1450,6 +1458,7 @@ def create_convert_map(self):
             "permute": self._permute,
             "reshape": self._reshape,
             "split": self._split,
+            "tile": self._tile,
             "cumsum": self._cumsum,
             "chunk": self._chunk,
             "transpose": self._transpose,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 1a2cc5da6242..6be3e7b23e9d 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3126,6 +3126,48 @@ def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype=
     verify_model(Reshape(), input_info, {}, expected1)
 
 
+def test_tile():
+    input_info = [([1, 3], "float32")]
+
+    class Tile1(Module):
+        def forward(self, x):
+            return x.tile((2,))
+
+    class Tile2(Module):
+        def forward(self, x):
+            return x.tile(4, 2)
+
+    class Tile3(Module):
+        def forward(self, x):
+            return torch.tile(x, (4, 2))
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((1, 6), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 6), dtype="float32") = R.tile(x, [2])
+                gv: R.Tensor((1, 6), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((4, 6), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2])
+                gv: R.Tensor((4, 6), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Tile1(), input_info, {}, expected1)
+    verify_model(Tile2(), input_info, {}, expected2)
+    verify_model(Tile3(), input_info, {}, expected2)
+
+
 def test_transpose():
     input_info = [([1, 2, 3, 4], "float32")]
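
Note (not part of the patch): the new _tile converter covers both call styles torch.fx records
for tile: a single sequence of repeats, as in x.tile((4, 2)) or torch.tile(x, (4, 2)), which
takes the tuple(args[1]) branch, and bare integer arguments, as in x.tile(4, 2), which falls
through to args[1:]. Below is a minimal end-to-end sketch of exercising the converter; it
assumes the from_fx entry point that verify_model uses in these tests, and the module and
variable names are hypothetical.

    # Trace a module that calls torch.tile and convert the FX graph to Relax.
    import torch
    from torch import fx, nn

    from tvm.relax.frontend.torch import from_fx


    class TileModel(nn.Module):  # hypothetical example module
        def forward(self, x):
            return torch.tile(x, (4, 2))  # repeats given as a single tuple


    graph_model = fx.symbolic_trace(TileModel())
    input_info = [([1, 3], "float32")]  # one (shape, dtype) pair per input, as in the tests
    mod = from_fx(graph_model, input_info)
    mod.show()  # the printed IRModule should contain R.tile(x, [4, 2])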