diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index cc9217c9f5f8..c05858fd887e 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -39,8 +39,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
     def _hardtanh(self, node: fx.Node) -> relax.Expr:
         args = self.retrieve_args(node)
         x = args[0]
-        min_val = node.args[1] if len(args) > 1 else node.kwargs("min_val", -1.0)
-        max_val = node.args[2] if len(args) > 2 else node.kwargs("max_val", 1.0)
+        min_val = node.args[1] if len(args) > 1 else node.kwargs.get("min_val", -1.0)
+        max_val = node.args[2] if len(args) > 2 else node.kwargs.get("max_val", 1.0)
         return self.block_builder.emit(relax.op.clip(x, min_val, max_val))

     def _log2(self, node: fx.Node) -> relax.Var:
@@ -216,6 +216,19 @@ def _slice(self, node: fx.Node) -> relax.Var:
         stride = [node.args[4] if len(node.args) > 4 else 1]
         return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))

+    def _unflatten(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = node.args[1]
+        sizes = node.args[2]
+
+        x_shape = list(self.shape_of(x))
+        if dim < 0:
+            dim += len(x_shape)
+
+        new_shape = x_shape[:dim] + sizes + x_shape[dim + 1 :]
+        return self.block_builder.emit(relax.op.reshape(x, new_shape))
+
     ########## Creation ##########

     def _one_hot(self, node: fx.Node) -> relax.Var:
@@ -258,6 +271,7 @@ def create_convert_map(
             "cos.default": self._unary_op(relax.op.cos),
             "cosh.default": self._unary_op(relax.op.cosh),
             "dropout.default": lambda node: self.env[node.args[0]],
+            "dropout_.default": lambda node: self.env[node.args[0]],
             "elu.default": self._elu,
             "erf.default": self._unary_op(relax.op.erf),
             "exp.default": self._unary_op(relax.op.exp),
@@ -265,7 +279,9 @@ def create_convert_map(
             "gelu.default": self._gelu,
             "hardsigmoid.default": self._hardsigmoid,
             "hardswish.default": self._hardswish,
+            "hardswish_.default": self._hardswish,
             "hardtanh.default": self._hardtanh,
+            "hardtanh_.default": self._hardtanh,
             "isfinite.default": self._unary_op(relax.op.isfinite),
             "isinf.default": self._unary_op(relax.op.isinf),
             "isnan.default": self._unary_op(relax.op.isnan),
@@ -278,12 +294,14 @@ def create_convert_map(
             "neg.default": self._unary_op(relax.op.negative),
             "reciprocal.default": self._reciprocal,
             "relu.default": self._unary_op(relax.op.nn.relu),
+            "relu_.default": self._unary_op(relax.op.nn.relu),
             "round.default": self._round,
             "rsqrt.default": self._unary_op(relax.op.rsqrt),
             "selu.default": self._unary_op(relax.op.nn.selu),
             "sigmoid.default": self._unary_op(relax.op.sigmoid),
             "sign.default": self._unary_op(relax.op.sign),
             "silu.default": self._unary_op(relax.op.nn.silu),
+            "silu_.default": self._unary_op(relax.op.nn.silu),
             "sin.default": self._unary_op(relax.op.sin),
             "sinh.default": self._unary_op(relax.op.sinh),
             "softmax.int": self._softmax,
@@ -296,6 +314,7 @@ def create_convert_map(
             "triu.default": self._tril_triu(relax.op.triu),
             # binary
             "add.Tensor": self._binary_op(relax.op.add, operator.add),
+            "add_.Tensor": self._binary_op(relax.op.add, operator.add),
             "div.Tensor": self._binary_op(relax.op.divide, operator.truediv),
             "eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
             "eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
@@ -393,6 +412,7 @@ def create_convert_map(
             "tile.default": self._tile,
             "topk.default": self._topk,
             "transpose.int": self._transpose,
+            "unflatten.int": self._unflatten,
             "unsqueeze.default": lambda node: self.block_builder.emit(
                 relax.op.expand_dims(self.env[node.args[0]], node.args[1])
             ),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 081b82b3c563..dd4ead9e593e 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -254,6 +254,10 @@ class Dropout2(Module):
         def forward(self, input):
             return torch.dropout(input, 0.5, train=True)

+    class Dropout3(Module):
+        def forward(self, input):
+            return torch.ops.aten.dropout_(input, 0.5, train=True)
+
     @tvm.script.ir_module
     class expected_dropout:
         @R.function
@@ -268,6 +272,7 @@ def main(

     verify_model(Dropout1(), example_args, {}, expected_dropout)
     verify_model(Dropout2(), example_args, {}, expected_dropout)
+    verify_model(Dropout3(), example_args, {}, expected_dropout)

     # elu
     class Elu(Module):
@@ -383,6 +388,10 @@ class Hardswish2(torch.nn.Module):
         def forward(self, input):
             return torch.nn.functional.hardswish(input)

+    class Hardswish3(torch.nn.Module):
+        def forward(self, input):
+            return torch.ops.aten.hardswish_(input)
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -402,6 +411,7 @@ def main(

     verify_model(Hardswish(), example_args, {}, expected1)
     verify_model(Hardswish2(), example_args, {}, expected1)
+    verify_model(Hardswish3(), example_args, {}, expected1)

     # hardtanh
     test_hardtanh()
@@ -511,6 +521,10 @@ class ReLU1(Module):
         def forward(self, input):
             return torch.nn.functional.relu(input)

+    class ReLU2(Module):
+        def forward(self, input):
+            return torch.ops.aten.relu_(input)
+
     @tvm.script.ir_module
     class expected_relu:
         @R.function
@@ -526,6 +540,7 @@ def main(

     verify_model(ReLU0(), example_args, {}, expected_relu)
     verify_model(ReLU1(), example_args, {}, expected_relu)
+    verify_model(ReLU2(), example_args, {}, expected_relu)

     # selu
     class Selu1(Module):
@@ -597,6 +612,10 @@ class SiLU2(Module):
         def forward(self, input):
             return torch.nn.functional.silu(input)

+    class SiLU3(Module):
+        def forward(self, input):
+            return torch.ops.aten.silu_(input)
+
     @tvm.script.ir_module
     class expected_silu:
         @R.function
@@ -612,6 +631,7 @@ def main(

     verify_model(SiLU(), example_args, {}, expected_silu)
     verify_model(SiLU2(), example_args, {}, expected_silu)
+    verify_model(SiLU3(), example_args, {}, expected_silu)

     # softmax
     test_softmax()
@@ -636,6 +656,10 @@ class Hardtanh2(torch.nn.Module):
         def forward(self, input):
             return torch.nn.functional.hardtanh(input)

+    class Hardtanh3(torch.nn.Module):
+        def forward(self, input):
+            return torch.ops.aten.hardtanh_(input)
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -653,6 +677,7 @@ def main(
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
     verify_model(Hardtanh(), example_args, {}, expected1)
     verify_model(Hardtanh2(), example_args, {}, expected1)
+    verify_model(Hardtanh3(), example_args, {}, expected1)


 def test_leakyrelu():
@@ -845,6 +870,7 @@ def main(

     operator_binary_1 = [
         (operator.add, R.add),
+        (torch.ops.aten.add_, R.add),
         (operator.sub, R.subtract),
         (operator.mul, R.multiply),
         (operator.truediv, R.divide),
@@ -3603,6 +3629,33 @@ def main(
     verify_model(Select(), example_args, {}, Expected)


+def test_unflatten():
+    class Unflatten(Module):
+        def forward(self, input):
+            return torch.ops.aten.unflatten(input, 1, (3, 5))
+
+    class Unflatten1(Module):
+        def forward(self, input):
+            return torch.ops.aten.unflatten(input, -2, (3, 5))
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 15, 7), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 3, 5, 7), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 3, 5, 7), dtype="float32") = R.reshape(inp_0, [2, 3, 5, 7])
+                gv: R.Tuple(R.Tensor((2, 3, 5, 7), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 15, 7, dtype=torch.float32),)
+
+    verify_model(Unflatten(), example_args, {}, Expected)
+    verify_model(Unflatten1(), example_args, {}, Expected)
+
+
 def test_gather():
     class Gather0(Module):
         def forward(self, data, indices):
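For reviewers who want to exercise the new unflatten conversion outside the test suite, here is a minimal end-to-end sketch. It is not part of this change; it assumes the public torch.export.export entry point and tvm.relax.frontend.torch.from_exported_program, and the UnflattenModel name plus the final mod.show() call are illustrative only.

import torch
from torch.export import export

from tvm.relax.frontend.torch import from_exported_program


class UnflattenModel(torch.nn.Module):
    def forward(self, x):
        # Split dim 1 (size 15) into (3, 5); the importer lowers this
        # aten.unflatten call into a single relax.op.reshape.
        return torch.ops.aten.unflatten(x, 1, [3, 5])


example_args = (torch.randn(2, 15, 7, dtype=torch.float32),)
exported_program = export(UnflattenModel(), example_args)
mod = from_exported_program(exported_program)
mod.show()  # the printed IRModule should contain R.reshape(..., [2, 3, 5, 7])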