diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 86f5de5f367c..7919e288c2c8 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -307,6 +307,7 @@ def create_convert_map( "log2.default": self._log2, "log10.default": self._log10, "log1p.default": self._log1p, + "logical_not.default": self._unary_op(relax.op.logical_not), "log_softmax.int": self._log_softmax, "neg.default": self._unary_op(relax.op.negative), "pad.default": self._pad, @@ -481,6 +482,7 @@ def create_convert_map( "lift_fresh_copy.default": self._to_copy, "linspace.default": self._linspace, "masked_fill.Scalar": self._masked_fill, + "masked_fill_.Scalar": self._inplace_masked_fill, "new_ones.default": self._new_ones, "one_hot.default": self._one_hot, "ones.default": self._ones, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index dd1869a23c63..b9385d1cc20f 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3727,6 +3727,30 @@ def main( verify_model(Masked_Fill(), example_args, {}, Expected) +def test_masked_fill_inplace(): + class Masked_Fill_Inplace(Module): + def forward(self, input: torch.Tensor, mask: torch.Tensor): + return input.masked_fill_(mask, 1.5) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((128, 128), dtype="float32"), mask: R.Tensor((128, 128), dtype="bool") + ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((128, 128), dtype="float32") = R.full_like( + input, R.const(1.5, "float32"), dtype="void" + ) + lv1: R.Tensor((128, 128), dtype="float32") = R.where(mask, lv, input) + gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(128, 128, dtype=torch.float32), torch.rand(128, 128) < 0.5) + verify_model(Masked_Fill_Inplace(), example_args, {}, Expected) + + def test_new_ones(): class NewOnes(Module): def forward(self, x): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index f60f158cbfa4..fdec5ed19ccd 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3243,6 +3243,31 @@ def main(inp_0: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype verify_model(InplaceFill(), [([10, 10], "float32")], {}, Expected) +def test_masked_fill_inplace(): + class Masked_Fill_Inplace(Module): + def forward(self, input: torch.Tensor, mask: torch.Tensor): + input.masked_fill_(mask, 1.5) + return input + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((10, 10), dtype="float32"), mask: R.Tensor((10, 10), dtype="bool") + ) -> R.Tensor((10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.full_like( + input, R.const(1.5, "float32"), dtype="void" + ) + lv1: R.Tensor((10, 10), dtype="float32") = R.where(mask, lv, input) + gv: R.Tensor((10, 10), dtype="float32") = lv1 + R.output(gv) + return gv + + input_info = [((10, 10), "float32"), ((10, 10), "bool")] + verify_model(Masked_Fill_Inplace(), input_info, {}, Expected) + + def test_arange(): import numpy as np