diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index affbd81e1c28..7660c1f5756c 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -966,6 +966,12 @@ def _argsort(self, node: fx.Node) -> relax.Var: descending = node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False) return self.block_builder.emit(relax.op.argsort(x, dim, descending)) + def _broadcast_to(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + shape = args[1] if len(args) > 1 else args[0] + return self.block_builder.emit(relax.op.broadcast_to(x, shape)) + def _cat(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index a151a57ae659..2f2126cc43ac 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -422,6 +422,13 @@ def _flatten_module(self, node: fx.Node) -> relax.Var: end_dim = module.end_dim return self._flatten_impl(x, start_dim, end_dim) + def _narrow(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] + start = node.args[2] + length = node.args[3] + return self.block_builder.emit(relax.op.strided_slice(x, [dim], [start], [length])) + def _numel(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] shape = self.shape_of(x) @@ -755,6 +762,7 @@ def create_convert_map( "where": self._where, # tensor manipulation "argsort": self._argsort, + "broadcast_to": self._broadcast_to, "cat": self._cat, "chunk": self._chunk, "concat": self._cat, @@ -766,6 +774,7 @@ def create_convert_map( "flatten": self._flatten, "flip": self._flip, "gather": self._gather, + "narrow": self._narrow, "numel": self._numel, "permute": self._permute, "repeat": self._repeat, diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 2c5560b577c4..9505356fcefd 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4430,5 +4430,48 @@ def main( verify_model(Topk(), [([5, 3], "float32")], {}, Expected) +def test_broadcast_to(): + class BroadcastTo(Module): + def forward(self, x): + return torch.broadcast_to(x, (5, 3)) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5, 1), dtype="float32"), + ) -> R.Tensor((5, 3), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((5, 3), dtype="float32") = R.broadcast_to(inp_0, (5, 3)) + gv: R.Tensor((5, 3), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(BroadcastTo(), [([5, 1], "float32")], {}, Expected) + + +def test_narrow(): + class Narrow(Module): + def forward(self, x): + return torch.narrow(x, 1, 0, 2) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((5, 3), dtype="float32"), + ) -> R.Tensor((5, 2), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((5, 2), dtype="float32") = R.strided_slice( + inp_0, axes=[1], begin=[0], end=[2] + ) + gv: R.Tensor((5, 2), dtype="float32") = lv + R.output(gv) + + return gv + + verify_model(Narrow(), [([5, 3], "float32")], {}, Expected) + + if __name__ == "__main__": tvm.testing.main()