From 599c7d8d2f4ec041124454e2be464a75225433b9 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 27 Feb 2025 12:12:59 +0800 Subject: [PATCH 1/3] Update exported_program_translator.py --- .../torch/exported_program_translator.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 0acc6ec1a019..c8d9d12505c6 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -204,16 +204,33 @@ def create_convert_map( "eq.Scalar": self._binary_op(relax.op.equal, operator.eq), "eq.Tensor": self._binary_op(relax.op.equal, operator.eq), "floor_divide.default": self._binary_op(relax.op.floor_divide, operator.floordiv), + "ge.Scalar": self._binary_op(relax.op.greater_equal, operator.ge), + "ge.Tensor": self._binary_op(relax.op.greater_equal, operator.ge), + "gt.Scalar": self._binary_op(relax.op.greater, operator.gt), + "gt.Tensor": self._binary_op(relax.op.greater, operator.gt), + "le.Scalar": self._binary_op(relax.op.less_equal, operator.le), + "le.Tensor": self._binary_op(relax.op.less_equal, operator.le), "lt.Scalar": self._binary_op(relax.op.less, operator.lt), "lt.Tensor": self._binary_op(relax.op.less, operator.lt), "matmul.default": self._binary_op( partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul ), "max.other": self._binary_op(relax.op.maximum, max), + "min.other": self._binary_op(relax.op.minimum, min), + "remainder.Tensor": self._binary_op(relax.op.mod, operator.mod), + "remainder.Scalar": self._binary_op(relax.op.mod, operator.mod), "mul.Tensor": self._binary_op(relax.op.multiply, operator.mul), + "ne.Tensor": self._binary_op(relax.op.not_equal, operator.ne), + "ne.Scalar": self._binary_op(relax.op.not_equal, operator.ne), "pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow), "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow), "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub), + "__and__.Tensor": self._binary_op(relax.op.bitwise_and, operator.and_), + "__and__.Scalar": self._binary_op(relax.op.bitwise_and, operator.and_), + "__or__.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_), + "__or__.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_), + "__xor__.Tensor": self._binary_op(relax.op.bitwise_xor, operator.xor), + "__xor__.Scalar": self._binary_op(relax.op.bitwise_xor, operator.xor), # neural network "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, From 7ec45806d91e86df65725a178fd4b43b9112a175 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 27 Feb 2025 12:16:43 +0800 Subject: [PATCH 2/3] Update test_frontend_from_exported_program.py --- .../test_frontend_from_exported_program.py | 358 +++++------------- 1 file changed, 85 insertions(+), 273 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 52cdc12bb781..514f119189d1 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -542,233 +542,142 @@ def main( verify_model(Triu(), example_args, {}, expected_triu) -def test_binary(): +operator_binary_1 = [ + (operator.add, R.add), + (operator.sub, R.subtract), + (operator.mul, R.multiply), + (operator.truediv, R.divide), + (operator.floordiv, R.floor_divide), + (operator.pow, R.power), + (operator.mod, R.mod), + (operator.and_, R.bitwise_and), + (operator.or_, R.bitwise_or), + (operator.xor, R.bitwise_xor), +] + + +@pytest.mark.parametrize("op, relax_op", operator_binary_1) +def test_binary1(op, relax_op): example_args1 = ( torch.randn(10, 10, dtype=torch.float32), torch.randn(10, 10, dtype=torch.float32), ) example_args2 = (torch.randn(10, 10, dtype=torch.float32),) - # Add - class Add1(Module): + class Binary1(Module): + def __init__(self, op): + super().__init__() + self.op = op + def forward(self, lhs, rhs): - return lhs + rhs + return self.op(lhs, rhs) @tvm.script.ir_module - class expected_add1: + class expected_binary1: @R.function def main( lhs: R.Tensor((10, 10), dtype="float32"), rhs: R.Tensor((10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): - # block 0 with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.add(lhs, rhs) + lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs, rhs) gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) R.output(gv) return gv - class Add2(Module): - def forward(self, lhs): - return lhs + 1.0 - - @tvm.script.ir_module - class expected_add2: - @R.function - def main( - lhs_1: R.Tensor((10, 10), dtype="float32"), - ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.add(lhs_1, R.const(1.0)) - gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - verify_model(Add1(), example_args1, {}, expected_add1) - verify_model(Add2(), example_args2, {}, expected_add2) - - # True div - class TrueDiv1(Module): - def forward(self, lhs, rhs): - return lhs / rhs - - @tvm.script.ir_module - class expected_truediv1: - @R.function - def main( - lhs_1: R.Tensor((10, 10), dtype="float32"), - rhs_1: R.Tensor((10, 10), dtype="float32"), - ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.divide(lhs_1, rhs_1) - gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv + class Binary2(Module): + def __init__(self, op): + super().__init__() + self.op = op - class TrueDiv2(Module): def forward(self, lhs): - return lhs / 1.0 + return self.op(lhs, 1.0) @tvm.script.ir_module - class expected_truediv2: + class expected_binary2: @R.function def main( - lhs_1: R.Tensor((10, 10), dtype="float32"), + lhs: R.Tensor((10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): - # block 0 with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.divide(lhs_1, R.const(1.0)) + lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs, R.const(1.0)) gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) R.output(gv) return gv - verify_model(TrueDiv1(), example_args1, {}, expected_truediv1) - verify_model(TrueDiv2(), example_args2, {}, expected_truediv2) - - # EQ - class EQ1(Module): - def forward(self, lhs, rhs): - return lhs == rhs + verify_model(Binary1(op), example_args1, {}, expected_binary1) + verify_model(Binary2(op), example_args2, {}, expected_binary2) - @tvm.script.ir_module - class expected_eq1: - @R.function - def main( - lhs_1: R.Tensor((10, 10), dtype="float32"), - rhs_1: R.Tensor((10, 10), dtype="float32"), - ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")): - # block 0 - with R.dataflow(): - lv: R.Tensor((10, 10), dtype="bool") = R.equal(lhs_1, rhs_1) - gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,) - R.output(gv) - return gv - - class EQ2(Module): - def forward(self, lhs): - return lhs == 1.0 - @tvm.script.ir_module - class expected_eq2: - @R.function - def main( - lhs_1: R.Tensor((10, 10), dtype="float32"), - ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")): - # block 0 - with R.dataflow(): - lv: R.Tensor((10, 10), dtype="bool") = R.equal(lhs_1, R.const(1.0)) - gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,) - R.output(gv) - return gv - - verify_model(EQ1(), example_args1, {}, expected_eq1) - verify_model(EQ2(), example_args2, {}, expected_eq2) - - # Floor div - class FloorDiv1(Module): - def forward(self, lhs, rhs): - return lhs // rhs +operator_binary_2 = [ + (operator.eq, R.equal), + (operator.ne, R.not_equal), + (operator.lt, R.less), + (operator.le, R.less_equal), + (operator.gt, R.greater), + (operator.ge, R.greater_equal), +] - @tvm.script.ir_module - class expected_floordiv1: - @R.function - def main( - lhs_1: R.Tensor((10, 10), dtype="float32"), - rhs_1: R.Tensor((10, 10), dtype="float32"), - ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.floor_divide(lhs_1, rhs_1) - gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - class FloorDiv2(Module): - def forward(self, lhs): - return lhs // 1.0 - - @tvm.script.ir_module - class expected_floordiv2: - @R.function - def main( - lhs_1: R.Tensor((10, 10), dtype="float32"), - ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.floor_divide(lhs_1, R.const(1.0)) - gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv +@pytest.mark.parametrize("op, relax_op", operator_binary_2) +def test_binary2(op, relax_op): + example_args1 = ( + torch.randn(10, 10, dtype=torch.float32), + torch.randn(10, 10, dtype=torch.float32), + ) + example_args2 = (torch.randn(10, 10, dtype=torch.float32),) - verify_model(FloorDiv1(), example_args1, {}, expected_floordiv1) - verify_model(FloorDiv2(), example_args2, {}, expected_floordiv2) + class Binary1(Module): + def __init__(self, op): + super().__init__() + self.op = op - # LT - class LT1(Module): def forward(self, lhs, rhs): - return lhs < rhs + return self.op(lhs, rhs) @tvm.script.ir_module - class expected_lt1: + class expected_binary1: @R.function def main( - lhs_1: R.Tensor((10, 10), dtype="float32"), - rhs_1: R.Tensor((10, 10), dtype="float32"), + lhs: R.Tensor((10, 10), dtype="float32"), + rhs: R.Tensor((10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")): - # block 0 with R.dataflow(): - lv: R.Tensor((10, 10), dtype="bool") = R.less(lhs_1, rhs_1) + lv: R.Tensor((10, 10), dtype="bool") = relax_op(lhs, rhs) gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,) R.output(gv) return gv - class LT2(Module): + class Binary2(Module): + def __init__(self, op): + super().__init__() + self.op = op + def forward(self, lhs): - return lhs < 1.0 + return self.op(lhs, 1.0) @tvm.script.ir_module - class expected_lt2: + class expected_binary2: @R.function def main( - lhs_1: R.Tensor((10, 10), dtype="float32"), + lhs: R.Tensor((10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")): - # block 0 with R.dataflow(): - lv: R.Tensor((10, 10), dtype="bool") = R.less(lhs_1, R.const(1.0)) + lv: R.Tensor((10, 10), dtype="bool") = relax_op(lhs, R.const(1.0)) gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,) R.output(gv) return gv - verify_model(LT1(), example_args1, {}, expected_lt1) - verify_model(LT2(), example_args2, {}, expected_lt2) - - # MatMul - class MatMul1(Module): - def __init__(self): - super().__init__() + verify_model(Binary1(op), example_args1, {}, expected_binary1) + verify_model(Binary2(op), example_args2, {}, expected_binary2) - def forward(self, x, y): - return torch.matmul(x, y) - @tvm.script.ir_module - class expected_matmul1: - @R.function - def main( - input_1: R.Tensor((10, 10), dtype="float32"), - input_2: R.Tensor((10, 10), dtype="float32"), - ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.matmul( - input_1, input_2, out_dtype="float32" - ) - gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - verify_model(MatMul1(), example_args1, {}, expected_matmul1) +def test_binary3(): + example_args1 = ( + torch.randn(10, 10, dtype=torch.float32), + torch.randn(10, 10, dtype=torch.float32), + ) + example_args2 = (torch.randn(10, 10, dtype=torch.float32),) # Max class Max1(Module): @@ -790,122 +699,25 @@ def main( verify_model(Max1(), example_args1, {}, expected_max1) - # Mul - class Mul1(Module): - def forward(self, lhs, rhs): - return lhs * rhs - - @tvm.script.ir_module - class expected_mul1: - @R.function - def main( - lhs_1: R.Tensor((10, 10), dtype="float32"), - rhs_1: R.Tensor((10, 10), dtype="float32"), - ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.multiply(lhs_1, rhs_1) - gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - class Mul2(Module): - def forward(self, lhs): - return lhs * 1.0 - - @tvm.script.ir_module - class expected_mul2: - @R.function - def main( - lhs_1: R.Tensor((10, 10), dtype="float32"), - ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.multiply(lhs_1, R.const(1.0)) - gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - verify_model(Mul1(), example_args1, {}, expected_mul1) - verify_model(Mul2(), example_args2, {}, expected_mul2) - - # Power - class Power1(Module): - def forward(self, lhs, rhs): - return lhs**rhs - - @tvm.script.ir_module - class expected_power1: - @R.function - def main( - lhs_1: R.Tensor((10, 10), dtype="float32"), - rhs_1: R.Tensor((10, 10), dtype="float32"), - ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.power(lhs_1, rhs_1) - gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - class Power2(Module): - def forward(self, lhs): - return lhs**1.0 - - @tvm.script.ir_module - class expected_power2: - @R.function - def main( - lhs_1: R.Tensor((10, 10), dtype="float32"), - ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.power(lhs_1, R.const(1.0)) - gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - verify_model(Power1(), example_args1, {}, expected_power1) - verify_model(Power2(), example_args2, {}, expected_power2) - - # Sub - class Sub1(Module): - def forward(self, lhs, rhs): - return lhs - rhs - - @tvm.script.ir_module - class expected_sub1: - @R.function - def main( - lhs_1: R.Tensor((10, 10), dtype="float32"), - rhs_1: R.Tensor((10, 10), dtype="float32"), - ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.subtract(lhs_1, rhs_1) - gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - class Sub2(Module): - def forward(self, lhs): - return lhs - 1.0 + # Min + class Min1(Module): + def forward(self, x, y): + return torch.min(x, y) - @tvm.script.ir_module - class expected_sub2: + @I.ir_module + class expected_min1: @R.function def main( - lhs_1: R.Tensor((10, 10), dtype="float32"), + inp_0: R.Tensor((10, 10), dtype="float32"), + inp_1: R.Tensor((10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): - # block 0 with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.subtract(lhs_1, R.const(1.0)) + lv: R.Tensor((10, 10), dtype="float32") = R.minimum(inp_0, inp_1) gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) R.output(gv) return gv - verify_model(Sub1(), example_args1, {}, expected_sub1) - verify_model(Sub2(), example_args2, {}, expected_sub2) + verify_model(Min1(), example_args1, {}, expected_min1) @pytest.mark.skipif( From 58b15b76e1c5c6016db5bc2eabddd89dc175d880 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Fri, 28 Feb 2025 10:21:00 +0800 Subject: [PATCH 3/3] Update test_frontend_from_exported_program.py --- tests/python/relax/test_frontend_from_exported_program.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 514f119189d1..8ca335c2fe7a 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import operator import pytest import torch from torch.nn import Module