From 5d3e7125c8638bc45a78cd90eb685d5786c65200 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 13 Mar 2025 18:38:39 +0800 Subject: [PATCH 1/4] Update exported_program_translator.py --- .../torch/exported_program_translator.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 0016046b0ecf..118a391ef62f 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -174,6 +174,25 @@ def _slice(self, node: fx.Node) -> relax.Var: stride = [node.args[4] if len(node.args) > 4 else 1] return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride)) + ########## Creation ########## + + def _one_hot(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + num_classes = node.args[1] if len(node.args) > 1 else node.kwargs.get("num_classes") + if num_classes is None: + raise ValueError("num_classes not found in node.args or node.kwargs") + + on_value = node.args[2] if len(node.args) > 2 else node.kwargs.get("on_value", 1) + off_value = node.args[3] if len(node.args) > 3 else node.kwargs.get("off_value", 0) + axis = node.args[4] if len(node.args) > 4 else node.kwargs.get("axis", -1) + + on_value = relax.PrimValue(on_value) + off_value = relax.PrimValue(off_value) + + return self.block_builder.emit( + relax.op.one_hot(x, on_value, off_value, num_classes, axis) + ) + ########## Others ########## def create_convert_map( @@ -328,8 +347,10 @@ def create_convert_map( "contiguous.default": lambda node: self.env[node.args[0]], # no-op "clone.default": lambda node: self.env[node.args[0]], "empty.memory_format": self._empty, + "empty_like.default": self._empty_like, "fill.Scalar": self._fill, "new_ones.default": self._new_ones, + "one_hot.default": self._one_hot, # other "getitem": self._getitem, } From 4ee72d28498c8691605006aa1cfee8ced4e8acc9 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 13 Mar 2025 18:40:45 +0800 Subject: [PATCH 2/4] Update test_frontend_from_exported_program.py --- .../test_frontend_from_exported_program.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index e18986187d1f..d7087e812f70 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3425,5 +3425,73 @@ def main( tvm.ir.assert_structural_equal(mod, Expected) +def test_empty_like(): + class EmptyLike(Module): + def forward(self, data): + return torch.empty_like(data) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5,), dtype="float32"), + ) -> R.Tuple(R.Tensor((5,), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((5,), dtype="float32") = R.zeros_like(inp_0, dtype="void") + gv: R.Tuple(R.Tensor((5,), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(5, dtype=torch.float32),) + + verify_model(EmptyLike(), example_args, {}, Expected) + + +def test_one_hot(): + class OneHot(Module): + def forward(self, indices): + return torch.nn.functional.one_hot(indices, num_classes=10) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5,), dtype="int64"), + ) -> R.Tuple(R.Tensor((5, 10), dtype="int64")): + with R.dataflow(): + lv: R.Tensor((5, 10), dtype="int64") = R.one_hot( + inp_0, R.prim_value(1), R.prim_value(0), depth=10, axis=-1 + ) + gv: R.Tuple(R.Tensor((5, 10), dtype="int64")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randint(0, 10, (5,), dtype=torch.int64),) + + verify_model(OneHot(), example_args, {}, Expected) + + +def test_select(): + class Select(Module): + def forward(self, input): + return torch.select(input, 0, 1) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((2, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((3,), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3,), dtype="float32") = R.take(inp_0, R.const(1, "int64"), axis=0) + gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(2, 3, dtype=torch.float32),) + + verify_model(Select(), example_args, {}, Expected) + + if __name__ == "__main__": tvm.testing.main() From ab5e077952585a7fe3e5344d5f275fd9609d45dc Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sat, 15 Mar 2025 22:19:30 +0800 Subject: [PATCH 3/4] Update test_frontend_from_exported_program.py --- tests/python/relax/test_frontend_from_exported_program.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 340c19cccca6..1b4e8025395e 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3434,7 +3434,7 @@ def forward(self, data): class Expected: @R.function def main( - inp_0: R.Tensor((5,), dtype="float32"), + inp_0: R.Tensor((5,), dtype="float32"), ) -> R.Tuple(R.Tensor((5,), dtype="float32")): with R.dataflow(): lv: R.Tensor((5,), dtype="float32") = R.zeros_like(inp_0, dtype="void") @@ -3475,7 +3475,7 @@ def test_select(): class Select(Module): def forward(self, input): return torch.select(input, 0, 1) - + @tvm.script.ir_module class Expected: @R.function @@ -3491,7 +3491,7 @@ def main( example_args = (torch.randn(2, 3, dtype=torch.float32),) verify_model(Select(), example_args, {}, Expected) - + def test_gather(): class Gather0(Module): From 87eafc69cd982ca34b3e02ed2e36659927d49067 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sat, 15 Mar 2025 22:20:58 +0800 Subject: [PATCH 4/4] Update exported_program_translator.py --- .../tvm/relax/frontend/torch/exported_program_translator.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index c98edf40de45..bc7a4c4cb046 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -189,10 +189,8 @@ def _one_hot(self, node: fx.Node) -> relax.Var: on_value = relax.PrimValue(on_value) off_value = relax.PrimValue(off_value) - return self.block_builder.emit( - relax.op.one_hot(x, on_value, off_value, num_classes, axis) - ) - + return self.block_builder.emit(relax.op.one_hot(x, on_value, off_value, num_classes, axis)) + ########## Others ########## def create_convert_map(