diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 4c9480b58748..099049ffc92c 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1138,6 +1138,13 @@ def _reshape(self, node: fx.Node) -> relax.Var:
         dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
         return self.block_builder.emit(relax.op.reshape(x, dims))

+    def _reshape_as(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        other = args[1]
+        dims = self.shape_of(other)
+        return self.block_builder.emit(relax.op.reshape(x, dims))
+
     def _scatter(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         if len(node.args) == 1:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c82a5e2b1100..1c7b37fb975f 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -293,6 +293,7 @@ def create_convert_map(
             "isinf.default": self._unary_op(relax.op.isinf),
             "isnan.default": self._unary_op(relax.op.isnan),
             "leaky_relu.default": self._leakyrelu,
+            "leaky_relu_.default": self._leakyrelu,
             "log.default": self._unary_op(relax.op.log),
             "log2.default": self._log2,
             "log10.default": self._log10,
@@ -431,6 +432,7 @@ def create_convert_map(
             ),
             "view.default": self._reshape,
             "reshape.default": self._reshape,
+            "reshape_as.default": self._reshape_as,
             # tensor creation
             "_to_copy.default": self._to_copy,
             "arange.default": self._arange,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 26d3d3f7bde2..fe93fa833315 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -567,6 +567,10 @@ class LeakyReLU1(Module):
         def forward(self, input):
             return torch.nn.functional.leaky_relu(input, 0.02)

+    class LeakyReLU2(Module):
+        def forward(self, input):
+            return torch.ops.aten.leaky_relu_(input, 0.02)
+
     @tvm.script.ir_module
     class expected:
         @R.function
@@ -583,6 +587,7 @@ def main(
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
     verify_model(LeakyReLU0(), example_args, {}, expected)
     verify_model(LeakyReLU1(), example_args, {}, expected)
+    verify_model(LeakyReLU2(), example_args, {}, expected)


 def test_logaddexp():
@@ -2736,6 +2741,32 @@ def main(
     verify_model(Reshape(), example_args, {}, expected1)


+def test_reshape_as():
+    class ReshapeAs(Module):
+        def forward(self, x: torch.Tensor, y: torch.Tensor):
+            return x.reshape_as(y)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32"),
+            y: R.Tensor((2, 12), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12))
+                gv: R.Tuple(R.Tensor((2, 12), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(1, 2, 3, 4, dtype=torch.float32),
+        torch.randn(2, 12, dtype=torch.float32),
+    )
+    verify_model(ReshapeAs(), example_args, {}, expected1)
+
+
 def test_select_slice():
     class Slice1(Module):
         def forward(self, x):
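
For reviewers, a minimal end-to-end sketch of how the new reshape_as mapping is exercised outside the test harness (assumes a TVM build with the Relax PyTorch frontend; ReshapeAsDemo is an illustrative name, not part of this patch):

import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class ReshapeAsDemo(torch.nn.Module):
    def forward(self, x, y):
        # Exported as aten.reshape_as.default; the new converter reads
        # shape_of(y) and emits relax.op.reshape(x, that shape).
        return x.reshape_as(y)

example_args = (torch.randn(1, 2, 3, 4), torch.randn(2, 12))
mod = from_exported_program(export(ReshapeAsDemo(), example_args))
mod.show()  # main should contain R.reshape(x, (2, 12)) in a dataflow block

The leaky_relu_.default entry works the same way: the in-place aten variant is routed to the existing _leakyrelu converter, which is safe because the translator emits purely functional Relax IR.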