diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index d67cacb960bf..bc7a4c4cb046 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -174,6 +174,23 @@ def _slice(self, node: fx.Node) -> relax.Var: stride = [node.args[4] if len(node.args) > 4 else 1] return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride)) + ########## Creation ########## + + def _one_hot(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + num_classes = node.args[1] if len(node.args) > 1 else node.kwargs.get("num_classes") + if num_classes is None: + raise ValueError("num_classes not found in node.args or node.kwargs") + + on_value = node.args[2] if len(node.args) > 2 else node.kwargs.get("on_value", 1) + off_value = node.args[3] if len(node.args) > 3 else node.kwargs.get("off_value", 0) + axis = node.args[4] if len(node.args) > 4 else node.kwargs.get("axis", -1) + + on_value = relax.PrimValue(on_value) + off_value = relax.PrimValue(off_value) + + return self.block_builder.emit(relax.op.one_hot(x, on_value, off_value, num_classes, axis)) + ########## Others ########## def create_convert_map( @@ -331,8 +348,10 @@ def create_convert_map( "contiguous.default": lambda node: self.env[node.args[0]], # no-op "clone.default": lambda node: self.env[node.args[0]], "empty.memory_format": self._empty, + "empty_like.default": self._empty_like, "fill.Scalar": self._fill, "new_ones.default": self._new_ones, + "one_hot.default": self._one_hot, # other "getitem": self._getitem, } diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index e2be933050d3..1b4e8025395e 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3425,6 +3425,74 @@ def main( tvm.ir.assert_structural_equal(mod, Expected) +def test_empty_like(): + class EmptyLike(Module): + def forward(self, data): + return torch.empty_like(data) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5,), dtype="float32"), + ) -> R.Tuple(R.Tensor((5,), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((5,), dtype="float32") = R.zeros_like(inp_0, dtype="void") + gv: R.Tuple(R.Tensor((5,), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(5, dtype=torch.float32),) + + verify_model(EmptyLike(), example_args, {}, Expected) + + +def test_one_hot(): + class OneHot(Module): + def forward(self, indices): + return torch.nn.functional.one_hot(indices, num_classes=10) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5,), dtype="int64"), + ) -> R.Tuple(R.Tensor((5, 10), dtype="int64")): + with R.dataflow(): + lv: R.Tensor((5, 10), dtype="int64") = R.one_hot( + inp_0, R.prim_value(1), R.prim_value(0), depth=10, axis=-1 + ) + gv: R.Tuple(R.Tensor((5, 10), dtype="int64")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randint(0, 10, (5,), dtype=torch.int64),) + + verify_model(OneHot(), example_args, {}, Expected) + + +def test_select(): + class Select(Module): + def forward(self, input): + return torch.select(input, 0, 1) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((2, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((3,), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3,), dtype="float32") = R.take(inp_0, R.const(1, "int64"), axis=0) + gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(2, 3, dtype=torch.float32),) + + verify_model(Select(), example_args, {}, Expected) + + def test_gather(): class Gather0(Module): def forward(self, data, indices):