diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 875ec3b83ea8..0f97092946bf 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -202,6 +202,13 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var: ########## Manipulation ########## + def _narrow(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] + start = node.args[2] + length = node.args[3] + return self.block_builder.emit(relax.op.strided_slice(x, [dim], [start], [length])) + def _select(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dim = node.args[1] @@ -390,6 +397,7 @@ def create_convert_map( "where.self": self._where, # tensor manipulation "argsort.default": self._argsort, + "broadcast_to.default": self._broadcast_to, "cat.default": self._cat, "chunk.default": self._chunk, "clamp.Tensor": self._clamp, @@ -402,6 +410,7 @@ def create_convert_map( "flatten.using_ints": self._flatten, "flip.default": self._flip, "gather.default": self._gather, + "narrow.default": self._narrow, "permute.default": self._permute, "repeat.default": self._repeat, "select.int": self._select, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 42288cf562cd..284544be5079 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3856,5 +3856,55 @@ def main( verify_model(DynamicModel(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes) +def test_broadcast_to(): + class BroadcastTo(Module): + def forward(self, x): + return torch.broadcast_to(x, (5, 3)) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((5, 1), dtype="float32") + ) -> R.Tuple(R.Tensor((5, 3), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((5, 3), dtype="float32") = R.broadcast_to(x, R.shape([5, 3])) + gv: R.Tuple(R.Tensor((5, 3), dtype="float32")) = (lv,) + R.output(gv) + + return gv + + example_args = (torch.randn(5, 1, dtype=torch.float32),) + verify_model(BroadcastTo(), example_args, {}, Expected) + + +def test_narrow(): + class Narrow(Module): + def forward(self, x): + return torch.narrow(x, 1, 0, 2) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((5, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((5, 2), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((5, 2), dtype="float32") = R.strided_slice( + x, + (R.prim_value(1),), + (R.prim_value(0),), + (R.prim_value(2),), + assume_inbound=False, + ) + gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,) + R.output(gv) + + return gv + + example_args = (torch.randn(5, 3, dtype=torch.float32),) + verify_model(Narrow(), example_args, {}, Expected) + + if __name__ == "__main__": tvm.testing.main()