From 1a6bacf6caca4e41401db2993cfc7e6055de93ec Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 9 Mar 2025 20:19:16 +0800 Subject: [PATCH 1/5] Update fx_translator.py --- .../tvm/relax/frontend/torch/fx_translator.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index ef98d3c02501..513d045d2833 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -409,6 +409,11 @@ def _flatten_module(self, node: fx.Node) -> relax.Var: end_dim = module.end_dim return self._flatten_impl(x, start_dim, end_dim) + def _numel(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + shape = self.shape_of(x) + return relax.const(reduce(lambda x, y: x * y, [s.value for s in shape]), "int32") + def _size(self, node: fx.Node) -> relax.Expr: x = self.env[node.args[0]] shape = self.shape_of(x) @@ -511,6 +516,20 @@ def _ones(self, node: fx.Node) -> relax.Var: ) ) + def _one_hot(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + num_classes = node.args[1] if len(node.args) > 1 else node.kwargs.get("num_classes") + if num_classes is None: + raise ValueError("num_classes not found in node.args or node.kwargs") + on_value = node.args[2] if len(node.args) > 2 else node.kwargs.get("on_value", 1) + off_value = node.args[3] if len(node.args) > 3 else node.kwargs.get("off_value", 0) + axis = node.args[4] if len(node.args) > 4 else node.kwargs.get("axis", -1) + on_value = relax.PrimValue(on_value) + off_value = relax.PrimValue(off_value) + return self.block_builder.emit( + relax.op.one_hot(x, on_value, off_value, num_classes, axis) + ) + def _tensor(self, node: fx.Node) -> relax.Var: dtype = node.kwargs.get("dtype", None) if isinstance(node.args[0], float): @@ -735,6 +754,7 @@ def create_convert_map( "flatten": self._flatten, "flip": self._flip, "gather": self._gather, + "numel": self._numel, "permute": self._permute, "repeat": self._repeat, "reshape": self._reshape, @@ -753,6 +773,7 @@ def create_convert_map( # tensor creation "arange": self._arange, "empty": self._empty, + "empty_like": self._empty_like, "fill_": self._inplace_fill, "full": self._full, "index_select": self._index_select, @@ -761,6 +782,7 @@ def create_convert_map( "masked_scatter": self._masked_scatter, "new_ones": self._new_ones, "ones": self._ones, + "one_hot": self._one_hot, "tensor": self._tensor, # datatype "astype": self._type, From 36567aee998071d5521e5f4311176d6958f741de Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 9 Mar 2025 20:20:39 +0800 Subject: [PATCH 2/5] Update base_fx_graph_translator.py --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 003ceebec6ff..a9f54d91e3ce 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1018,6 +1018,10 @@ def _empty(self, node: fx.Node) -> relax.Var: dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env) return self.block_builder.emit(relax.op.zeros(node.args[0], dtype)) + def _empty_like(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + return self.block_builder.emit(relax.op.zeros_like(x)) + def _fill(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] From ee276701292a7f912ac641840ffd9acf70125cbd Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 9 Mar 2025 20:27:12 +0800 Subject: [PATCH 3/5] Update test_frontend_from_fx.py --- tests/python/relax/test_frontend_from_fx.py | 61 +++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 0b4b34e0c9bb..be1c39bd2b65 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4037,5 +4037,66 @@ def main( verify_model(Take(), [([5], "float32"), ([3], "int32")], {}, Expected) +def test_one_hot(): + class OneHot(Module): + def forward(self, indices): + return torch.nn.functional.one_hot(indices, num_classes=10) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5,), dtype="int32"), + ) -> R.Tensor((5, 10), dtype="int64"): + with R.dataflow(): + lv: R.Tensor((5, 10), dtype="int64") = R.one_hot(inp_0, R.prim_value(1), R.prim_value(0), depth=10, + axis=-1) + gv: R.Tensor((5, 10), dtype="int64") = lv + R.output(gv) + + return gv + + verify_model(OneHot(), [([5], "int32")], {}, Expected) + + +def test_empty_like(): + class EmptyLike(Module): + def forward(self, data): + return torch.empty_like(data) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5,), dtype="float32"), + ) -> R.Tensor((5,), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((5,), dtype="float32") = R.zeros_like(inp_0) + gv: R.Tensor((5,), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(EmptyLike(), [([5], "float32")], {}, Expected) + + +def test_numel(): + class Numel(Module): + def forward(self, data): + return torch.numel(data) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5, 3), dtype="float32"), + ) -> R.Tensor((), dtype="int32"): + with R.dataflow(): + gv: R.Tensor((), dtype="int32") = R.const(15, "int32") + R.output(gv) + return gv + + verify_model(Numel(), [([5, 3], "float32")], {}, Expected) + + if __name__ == "__main__": tvm.testing.main() From 6abcfcc0df142e17aeced49b74089cb94339e5e9 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 9 Mar 2025 23:19:11 +0800 Subject: [PATCH 4/5] Update test_frontend_from_fx.py --- tests/python/relax/test_frontend_from_fx.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index be1c39bd2b65..020fc8f5b3c2 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4046,11 +4046,12 @@ def forward(self, indices): class Expected: @R.function def main( - inp_0: R.Tensor((5,), dtype="int32"), + inp_0: R.Tensor((5,), dtype="int32"), ) -> R.Tensor((5, 10), dtype="int64"): with R.dataflow(): - lv: R.Tensor((5, 10), dtype="int64") = R.one_hot(inp_0, R.prim_value(1), R.prim_value(0), depth=10, - axis=-1) + lv: R.Tensor((5, 10), dtype="int64") = R.one_hot( + inp_0, R.prim_value(1), R.prim_value(0), depth=10, axis=-1 + ) gv: R.Tensor((5, 10), dtype="int64") = lv R.output(gv) From d8ef728c13897b234247cbab9ec4672b5f3c41f8 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 9 Mar 2025 23:24:08 +0800 Subject: [PATCH 5/5] lint --- python/tvm/relax/frontend/torch/fx_translator.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 513d045d2833..29d959818f21 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -526,9 +526,7 @@ def _one_hot(self, node: fx.Node) -> relax.Var: axis = node.args[4] if len(node.args) > 4 else node.kwargs.get("axis", -1) on_value = relax.PrimValue(on_value) off_value = relax.PrimValue(off_value) - return self.block_builder.emit( - relax.op.one_hot(x, on_value, off_value, num_classes, axis) - ) + return self.block_builder.emit(relax.op.one_hot(x, on_value, off_value, num_classes, axis)) def _tensor(self, node: fx.Node) -> relax.Var: dtype = node.kwargs.get("dtype", None)