diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 3e81ff1f0bfe..a2f50bf9a98d 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1427,6 +1427,15 @@ def _fill(self, node: fx.Node) -> relax.Var: value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + def _inplace_fill(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) + filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + self.env[node.args[0]] = filled + return filled + def _full(self, node: fx.Node) -> relax.Var: import torch @@ -1640,6 +1649,10 @@ def _zeros_inplace(self, node: fx.Node) -> relax.Var: self.env[node.args[0]] = output return output + def _zeros_like(self, node: fx.node) -> relax.Var: + x = self.env[node.args[0]] + return self.block_builder.emit(relax.op.zeros_like(x)) + @abc.abstractmethod def create_convert_map( self, diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index a3ab575c4b78..df37a5b45085 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -471,6 +471,7 @@ def create_convert_map( "eye.default": self._eye, "eye.m": self._eye, "fill.Scalar": self._fill, + "fill_.Scalar": self._inplace_fill, "full.default": self._full, "full_like.default": self._full_like, "index_select.default": self._index_select, @@ -485,6 +486,7 @@ def create_convert_map( ), "zero_.default": self._zeros_inplace, "zeros.default": self._zeros, + "zeros_like.default": self._zeros_like, # datatype "to.dtype": self._to, "to.dtype_layout": self._to, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 18dba2d988f2..492416e97f7c 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -515,15 +515,6 @@ def _size(self, node: fx.Node) -> relax.Expr: ########## Creation ########## - def _inplace_fill(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - dtype = x.struct_info.dtype - value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) - filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) - self.env[node.args[0]] = filled - return filled - def _inplace_copy(self, node: fx.Node) -> relax.Var: src = self.env[node.args[1]] self.env[node.args[0]] = src @@ -828,6 +819,7 @@ def create_convert_map( "clone": lambda node: self.env[node.args[0]], "empty": self._empty, "empty_like": self._empty_like, + "fill": self._fill, "fill_": self._inplace_fill, "full": self._full, "index_select": self._index_select, @@ -842,6 +834,7 @@ def create_convert_map( ), "tensor": self._tensor, "zero_": self._zeros_inplace, + "zeros_like": self._zeros_like, "copy_": self._inplace_copy, # datatype "astype": self._type, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index e3b6f4ad9c17..d1f0e5767aa3 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3647,6 +3647,30 @@ def main( verify_model(Fill(), example_args, {}, Expected) +def test_fill_inplace(): + class FillInplace(Module): + def forward(self, input: torch.Tensor): + input.fill_(42.0) + return input + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((2, 3), dtype="float32") = R.full( + R.shape([2, 3]), R.const(42.0, "float32"), dtype="float32" + ) + gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(2, 3, dtype=torch.float32),) + verify_model(FillInplace(), example_args, {}, Expected) + + def test_masked_fill(): class Masked_Fill(Module): def forward(self, input: torch.Tensor, mask: torch.Tensor): @@ -4014,6 +4038,27 @@ def main( verify_model(Zeros(), example_args, {}, Expected) +def test_zeros_like(): + class ZerosLike(Module): + def forward(self, input): + return torch.zeros_like(input) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((128, 128), dtype="float32") + ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(input, dtype="void") + gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.rand(128, 128, dtype=torch.float32),) + verify_model(ZerosLike(), example_args, {}, Expected) + + def test_type_as(): class TypeAs(Module): def forward(self, input, other): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 4003202d4f55..c6f4c40522f2 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4717,6 +4717,26 @@ def main( verify_model(ZeroInplace(), [([128, 128], "float32")], {}, Expected) +def test_zeros_like(): + class ZerosLike(Module): + def forward(self, data): + return torch.zeros_like(data) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((128, 128), dtype="float32") + ) -> R.Tensor((128, 128), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(inp_0, dtype="void") + gv: R.Tensor((128, 128), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(ZerosLike(), [([128, 128], "float32")], {}, Expected) + + def test_type_as(): class TypeAs(Module): def forward(self, data, other):