Skip to content

Commit

Permalink
[Relax][PyTorch] Add support for torch.tile (#17291)
Browse files Browse the repository at this point in the history
* add test

* add support for torch.tile
  • Loading branch information
mshr-h authored Aug 22, 2024
1 parent 8db545d commit 481c2dc
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,14 @@ def _squeeze(self, node: fx.node.Node) -> relax.Var:
dim = None
return self.block_builder.emit(relax.op.squeeze(x, dim))

def _tile(self, node: fx.node.Node) -> relax.Var:
import torch # type: ignore

args = self.retrieve_args(node)
if isinstance(args[1], (torch.Size, tuple, list)):
return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1])))
return self.block_builder.emit(relax.op.tile(args[0], args[1:]))

def _cumsum(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]

Expand Down Expand Up @@ -1450,6 +1458,7 @@ def create_convert_map(self):
"permute": self._permute,
"reshape": self._reshape,
"split": self._split,
"tile": self._tile,
"cumsum": self._cumsum,
"chunk": self._chunk,
"transpose": self._transpose,
Expand Down
42 changes: 42 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3126,6 +3126,48 @@ def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype=
verify_model(Reshape(), input_info, {}, expected1)


def test_tile():
input_info = [([1, 3], "float32")]

class Tile1(Module):
def forward(self, x):
return x.tile((2,))

class Tile2(Module):
def forward(self, x):
return x.tile(4, 2)

class Tile3(Module):
def forward(self, x):
return torch.tile(x, (4, 2))

@tvm.script.ir_module
class expected1:
@R.function
def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((1, 6), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 6), dtype="float32") = R.tile(x, [2])
gv: R.Tensor((1, 6), dtype="float32") = lv
R.output(gv)
return gv

@tvm.script.ir_module
class expected2:
@R.function
def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((4, 6), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2])
gv: R.Tensor((4, 6), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Tile1(), input_info, {}, expected1)
verify_model(Tile2(), input_info, {}, expected2)
verify_model(Tile3(), input_info, {}, expected2)


def test_transpose():
input_info = [([1, 2, 3, 4], "float32")]

Expand Down

0 comments on commit 481c2dc

Please sign in to comment.