diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 890f925079e0..cca03e95e62c 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -959,6 +959,12 @@ def _where(self, node: fx.Node) -> relax.Var: ########## Manipulation ########## + def _argsort(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) + descending = node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False) + return self.block_builder.emit(relax.op.argsort(x, dim, descending)) + def _cat(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) @@ -1071,6 +1077,12 @@ def _scatter(self, node: fx.Node) -> relax.Var: raise Exception("Unexpected args " + str(node.args)) return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim)) + def _sort(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) + descending = node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False) + return self.block_builder.emit(relax.op.sort(x, dim, descending)) + def _split(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] split_size = node.args[1] @@ -1121,6 +1133,22 @@ def _tile(self, node: fx.Node) -> relax.Var: dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] return self.block_builder.emit(relax.op.tile(x, dims)) + def _topk(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + k = args[1] if len(args) > 1 else node.kwargs.get("k", 1) + dim = args[2] if len(args) > 2 else node.kwargs.get("dim", -1) + largest = args[3] if len(args) > 3 else node.kwargs.get("largest", True) + _sorted = args[4] if len(args) > 4 else node.kwargs.get("_sorted", True) + + if not _sorted: + msg = "Currently supports only sorted output for topk operator." + raise AssertionError(msg) + + return self.block_builder.emit( + relax.op.topk(x, k=k, axis=dim, largest=largest, ret_type="both", dtype="int64") + ) + def _transpose(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) full_idx = list(range(len(self.shape_of(args[0])))) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 3ddf919c2ed1..f803b453ca2b 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -785,6 +785,7 @@ def create_convert_map( "argmin": self._argmax_argmin(relax.op.argmin), "where": self._where, # tensor manipulation + "argsort": self._argsort, "cat": self._cat, "chunk": self._chunk, "concat": self._cat, @@ -803,11 +804,13 @@ def create_convert_map( "scatter": self._scatter, "select": self._select, "size": self._size, + "sort": self._sort, "split": self._split, "squeeze": self._squeeze, "stack": self._stack, "take": self._take, "tile": self._tile, + "topk": self._topk, "transpose": self._transpose, "unsqueeze": lambda node: self.block_builder.emit( relax.op.expand_dims(self.env[node.args[0]], node.args[1]) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index d913baf13a0d..2c5560b577c4 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4368,5 +4368,67 @@ def main( ) +def test_argsort(): + class Argsort(Module): + def forward(self, x): + return torch.argsort(x, dim=1, descending=True) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5, 3), dtype="float32"), + ) -> R.Tensor((5, 3), dtype="int32"): + with R.dataflow(): + lv: R.Tensor((5, 3), dtype="int32") = R.argsort(inp_0, axis=1, descending=True) + gv: R.Tensor((5, 3), dtype="int32") = lv + R.output(gv) + return gv + + verify_model(Argsort(), [([5, 3], "float32")], {}, Expected) + + +def test_sort(): + class Sort(Module): + def forward(self, x): + return torch.sort(x, dim=1, descending=True) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5, 3), dtype="float32"), + ) -> R.Tensor((5, 3), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((5, 3), dtype="float32") = R.sort(inp_0, axis=1, descending=True) + gv: R.Tensor((5, 3), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sort(), [([5, 3], "float32")], {}, Expected) + + +def test_topk(): + class Topk(Module): + def forward(self, x): + return torch.topk(x, k=2, dim=1, largest=True, sorted=True) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")): + with R.dataflow(): + lv: R.Tuple( + R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64") + ) = R.topk(inp_0, k=2, axis=1, ret_type="both", largest=True, dtype="int64") + gv: R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")) = lv + R.output(gv) + return gv + + verify_model(Topk(), [([5, 3], "float32")], {}, Expected) + + if __name__ == "__main__": tvm.testing.main()