diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 7b9587b67561..902222c6181c 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -340,11 +340,15 @@ def create_convert_map( ), "max.other": self._binary_op(relax.op.maximum, max), "min.other": self._binary_op(relax.op.minimum, min), + "max.default": self._unary_op(relax.op.max), + "min.default": self._unary_op(relax.op.min), "remainder.Tensor": self._binary_op(relax.op.mod, operator.mod), "remainder.Scalar": self._binary_op(relax.op.mod, operator.mod), "mul.Tensor": self._binary_op(relax.op.multiply, operator.mul), + "mul_.Tensor": self._binary_op(relax.op.multiply, operator.mul), "ne.Tensor": self._binary_op(relax.op.not_equal, operator.ne), "ne.Scalar": self._binary_op(relax.op.not_equal, operator.ne), + "pow.Scalar": self._binary_op(relax.op.power, operator.pow), "pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow), "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow), "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub), diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 284544be5079..435631e1bc2f 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -474,6 +474,44 @@ def main( verify_model(Reciprocal(), example_args, {}, expected_reciprocal) + # Returns the maximum value of all elements in the input tensor. + class MaxModel(Module): + def forward(self, input): + return torch.max(input) + + @tvm.script.ir_module + class expected_max: + @R.function + def main( + input: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.max(input, axis=None, keepdims=False) + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(MaxModel(), example_args, {}, expected_max) + + # Returns the minimum value of all elements in the input tensor. + class MinModel(Module): + def forward(self, input): + return torch.min(input) + + @tvm.script.ir_module + class expected_min: + @R.function + def main( + input: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.min(input, axis=None, keepdims=False) + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(MinModel(), example_args, {}, expected_min) + def test_hardtanh(): class Hardtanh(torch.nn.Module): @@ -742,6 +780,7 @@ def main( (torch.ops.aten.add_, R.add), (operator.sub, R.subtract), (operator.mul, R.multiply), + (torch.ops.aten.mul_, R.multiply), (operator.truediv, R.divide), (operator.floordiv, R.floor_divide), (operator.pow, R.power),