diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index d49cfa6893e6..dffe2b60eb31 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -660,23 +660,29 @@ def create_convert_map( "triu": self._tril_triu(relax.op.triu), # binary "add": self._binary_op(relax.op.add, operator.add), + "and_": self._binary_op(relax.op.bitwise_and, operator.and_), "eq": self._binary_op(relax.op.equal, operator.eq), "floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv), "ge": self._binary_op(relax.op.greater_equal, operator.ge), "gt": self._binary_op(relax.op.greater, operator.gt), "iadd": self._binary_op(relax.op.add, operator.add), "le": self._binary_op(relax.op.less_equal, operator.le), + "lshift": self._binary_op(relax.op.left_shift, operator.lshift), "lt": self._binary_op(relax.op.less, operator.lt), "matmul": self._binary_op( partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul ), "max": self._binary_op(relax.op.maximum, max), + "min": self._binary_op(relax.op.minimum, min), "mod": self._binary_op(relax.op.mod, operator.mod), "mul": self._binary_op(relax.op.multiply, operator.mul), "ne": self._binary_op(relax.op.not_equal, operator.ne), "pow": self._binary_op(relax.op.power, operator.pow), + "or_": self._binary_op(relax.op.bitwise_or, operator.or_), + "rshift": self._binary_op(relax.op.right_shift, operator.rshift), "sub": self._binary_op(relax.op.subtract, operator.sub), "truediv": self._binary_op(relax.op.divide, operator.truediv), + "xor": self._binary_op(relax.op.bitwise_xor, operator.xor), # neural network "adaptive_avg_pool2d": self._adaptive_avg_pool2d, "addmm": self._addmm, diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 371343b60a46..8b4ea5c8cc98 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -1485,6 +1485,8 @@ def main( def test_binary(): input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")] input_info2 = [([1, 3, 10, 10], "float32")] + input_info3 = [([1, 3, 10, 10], "int32"), ([1, 3, 10, 10], "int32")] + input_info4 = [([1, 3, 10, 10], "int32")] # Add class Add1(Module): @@ -1962,6 +1964,211 @@ def main( verify_model(Ne1(), input_info1, {}, expected23) verify_model(Ne2(), input_info2, {}, expected24) + # Lshift + class LShift1(Module): + def forward(self, lhs, rhs): + return lhs << rhs + + @tvm.script.ir_module + class expected25: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="int32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.left_shift(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv + R.output(gv) + + return gv + + class LShift2(Module): + def forward(self, lhs): + return lhs << 1 + + @tvm.script.ir_module + class expected26: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="int32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.left_shift(lhs_1, R.const(1)) + gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv + R.output(gv) + + return gv + + verify_model(LShift1(), input_info3, {}, expected25) + verify_model(LShift2(), input_info4, {}, expected26) + + # Rshift + class RShift1(Module): + def forward(self, lhs, rhs): + return lhs >> rhs + + @tvm.script.ir_module + class expected27: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="int32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.right_shift(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv + R.output(gv) + + return gv + + class RShift2(Module): + def forward(self, lhs): + return lhs >> 1 + + @tvm.script.ir_module + class expected28: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="int32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.right_shift(lhs_1, R.const(1)) + gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv + R.output(gv) + + return gv + + verify_model(RShift1(), input_info3, {}, expected27) + verify_model(RShift2(), input_info4, {}, expected28) + + # Bitwise and + class BitwiseAnd1(Module): + def forward(self, lhs, rhs): + return lhs & rhs + + @tvm.script.ir_module + class expected29: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="int32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_and(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv + R.output(gv) + + return gv + + class BitwiseAnd2(Module): + def forward(self, lhs): + return lhs & 1 + + @tvm.script.ir_module + class expected30: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="int32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_and(lhs_1, R.const(1)) + gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv + R.output(gv) + + return gv + + verify_model(BitwiseAnd1(), input_info3, {}, expected29) + verify_model(BitwiseAnd2(), input_info4, {}, expected30) + + # Bitwise or + class BitwiseOr1(Module): + def forward(self, lhs, rhs): + return lhs | rhs + + @tvm.script.ir_module + class expected31: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="int32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_or(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv + R.output(gv) + + return gv + + class BitwiseOr2(Module): + def forward(self, lhs): + return lhs | 1 + + @tvm.script.ir_module + class expected32: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="int32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_or(lhs_1, R.const(1)) + gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv + R.output(gv) + + return gv + + verify_model(BitwiseOr1(), input_info3, {}, expected31) + verify_model(BitwiseOr2(), input_info4, {}, expected32) + + # Bitwise xor + class BitwiseXor1(Module): + def forward(self, lhs, rhs): + return lhs ^ rhs + + @tvm.script.ir_module + class expected33: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="int32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_xor(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv + R.output(gv) + + return gv + + class BitwiseXor2(Module): + def forward(self, lhs): + return lhs ^ 1 + + @tvm.script.ir_module + class expected34: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="int32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_xor(lhs_1, R.const(1)) + gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv + R.output(gv) + + return gv + + verify_model(BitwiseXor1(), input_info3, {}, expected33) + verify_model(BitwiseXor2(), input_info4, {}, expected34) + def test_size(): input_info = [([1, 3, 10, 10], "float32")] @@ -3745,6 +3952,27 @@ def main( verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")], {}, Expected1) +def test_min(): + class Min(Module): + def forward(self, x, y): + return torch.min(x, y) + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32"), + inp_1: R.Tensor((256, 256), dtype="float32"), + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = R.minimum(inp_0, inp_1) + gv: R.Tensor((256, 256), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Min(), [([256, 256], "float32"), ([256, 256], "float32")], {}, Expected1) + + def test_attention(): @I.ir_module class Expected1: