diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index ae4c918900ec..733a5d6b1a87 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1416,6 +1416,13 @@ def _empty_like(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] return self.block_builder.emit(relax.op.zeros_like(x)) + def _eye(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + n = args[0] + m = args[1] if len(args) > 1 else n + dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) + return self.block_builder.emit(relax.op.eye(n, m, dtype=dtype)) + def _fill(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 932607287571..af1393329e1f 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -453,6 +453,8 @@ def create_convert_map( "clone.default": lambda node: self.env[node.args[0]], "empty.memory_format": self._empty, "empty_like.default": self._empty_like, + "eye.default": self._eye, + "eye.m": self._eye, "fill.Scalar": self._fill, "full.default": self._full, "full_like.default": self._full_like, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 80c0bd5fb4f5..ce68089048a1 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4377,5 +4377,45 @@ def main( verify_model(Narrow(), example_args, {}, Expected) +def test_eye(): + class Eye1(Module): + def forward(self, input): + return torch.eye(3, 5, dtype=torch.float32) + + @tvm.script.ir_module + class Expected1: + @R.function + def main( + input: R.Tensor((3, 5), dtype="float32") + ) -> R.Tuple(R.Tensor((3, 5), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3, 5), dtype="float32") = R.eye(3, 5, dtype="float32") + gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class Eye2(Module): + def forward(self, input): + return torch.eye(5, dtype=torch.float32) + + @tvm.script.ir_module + class Expected2: + @R.function + def main( + input: R.Tensor((5,), dtype="float32") + ) -> R.Tuple(R.Tensor((5, 5), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((5, 5), dtype="float32") = R.eye(5, dtype="float32") + gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args1 = (torch.randn(3, 5, dtype=torch.float32),) + verify_model(Eye1(), example_args1, {}, Expected1) + + example_args2 = (torch.randn(5, dtype=torch.float32),) + verify_model(Eye2(), example_args2, {}, Expected2) + + if __name__ == "__main__": tvm.testing.main()