diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index d52b3d598f89..a41b9b6d4f9a 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -185,6 +185,39 @@ def convert(node: fx.Node) -> relax.Var: return convert + ########## Binary Ops ########## + + def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node) -> relax.Var: + def promote_binary_op_args(lhs, rhs): + if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): + return lhs, rhs + elif isinstance(lhs, relax.Expr): + assert isinstance(lhs.struct_info, relax.TensorStructInfo) + return lhs, relax.const(rhs, lhs.struct_info.dtype) + elif isinstance(rhs, relax.Expr): + assert isinstance(rhs.struct_info, relax.TensorStructInfo) + return relax.const(lhs, rhs.struct_info.dtype), rhs + else: + assert False + + def call_binary_op(op, lhs, rhs): + lhs, rhs = promote_binary_op_args(lhs, rhs) + return self.block_builder.emit(op(lhs, rhs)) + + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return call_binary_op(relax_op, lhs, rhs) + elif isinstance(lhs, relax.expr.Constant): + return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype)) + elif isinstance(rhs, relax.expr.Constant): + return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs) + return intrinsic_op(lhs, rhs) + + return convert + ########## Neural Network ########## def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: @@ -283,6 +316,35 @@ def _max_pool2d(self, node: fx.Node) -> relax.Var: return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + ########## Statistical ########## + + def _mean(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) + return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim)) + + def _sum(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False + if len(args) == 1: + return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) + return self.block_builder.emit(relax.op.sum(args[0], args[1])) + + ########## Search ########## + + def _argmax_argmin(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node): + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) + return self.block_builder.emit(op(x, dim, keepdim)) + + return convert + ########## Manipulation ########## def _reshape(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 1ceddad7d79f..11594690cdc2 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -19,6 +19,7 @@ # pylint: disable=import-outside-toplevel """PyTorch ExportedProgram of Relax.""" from collections import ChainMap, OrderedDict +from functools import partial from typing import Callable, Dict, List, Tuple import torch @@ -76,6 +77,8 @@ def _hardtanh(self, node: fx.Node) -> relax.Expr: def create_convert_map( self, ) -> Dict[str, Callable[[fx.Node], relax.Var]]: + import operator + return { # unary "acos.default": self._unary_op(relax.op.acos), @@ -109,11 +112,33 @@ def create_convert_map( "tanh.default": self._unary_op(relax.op.tanh), "tril.default": self._tril_triu(relax.op.tril), "triu.default": self._tril_triu(relax.op.triu), + # binary + "add.Tensor": self._binary_op(relax.op.add, operator.add), + "div.Tensor": self._binary_op(relax.op.divide, operator.truediv), + "eq.Scalar": self._binary_op(relax.op.equal, operator.eq), + "eq.Tensor": self._binary_op(relax.op.equal, operator.eq), + "floor_divide.default": self._binary_op(relax.op.floor_divide, operator.floordiv), + "lt.Scalar": self._binary_op(relax.op.less, operator.lt), + "lt.Tensor": self._binary_op(relax.op.less, operator.lt), + "matmul.default": self._binary_op( + partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul + ), + "max.other": self._binary_op(relax.op.maximum, max), + "mul.Tensor": self._binary_op(relax.op.multiply, operator.mul), + "pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow), + "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow), + "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub), # neural network "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "conv2d.default": self._conv2d, "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, + # statistical + "mean.dim": self._mean, + "sum.dim_IntList": self._sum, + # search + "argmax.default": self._argmax_argmin(relax.op.argmax), + "argmin.default": self._argmax_argmin(relax.op.argmin), # tensor manipulation "view.default": self._reshape, } diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 6f7c6fa2c575..dc6ebc2eb34f 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -96,39 +96,6 @@ def convert(node: fx.Node) -> relax.Var: return convert - ########## Binary Ops ########## - - def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable: - from torch import fx - - def convert(node: fx.Node) -> relax.Var: - def promote_binary_op_args(lhs, rhs): - if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): - return lhs, rhs - elif isinstance(lhs, relax.Expr): - assert isinstance(lhs.struct_info, relax.TensorStructInfo) - return lhs, relax.const(rhs, lhs.struct_info.dtype) - elif isinstance(rhs, relax.Expr): - assert isinstance(rhs.struct_info, relax.TensorStructInfo) - return relax.const(lhs, rhs.struct_info.dtype), rhs - else: - assert False - - def call_binary_op(op, lhs, rhs): - lhs, rhs = promote_binary_op_args(lhs, rhs) - return self.block_builder.emit(op(lhs, rhs)) - - lhs, rhs = self.retrieve_args(node) - if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): - return call_binary_op(relax_op, lhs, rhs) - elif isinstance(lhs, relax.expr.Constant): - return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype)) - elif isinstance(rhs, relax.expr.Constant): - return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs) - return intrinsic_op(lhs, rhs) - - return convert - ########## Neural Network ########## def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var: @@ -794,35 +761,6 @@ def _unbind(self, node: fx.Node) -> relax.Var: ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) return self.block_builder.emit(relax.Tuple(ret)) - ########## Statistical ########## - - def _mean(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) - keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) - return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim)) - - def _sum(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False - if len(args) == 1: - return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) - return self.block_builder.emit(relax.op.sum(args[0], args[1])) - - ########## Search ########## - - def _argmax_argmin(self, op: Callable) -> Callable: - from torch import fx - - def convert(node: fx.Node): - x = self.env[node.args[0]] - dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) - keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) - return self.block_builder.emit(op(x, dim, keepdim)) - - return convert - ########## Manipulation ########## def _cat(self, node: fx.Node) -> relax.Var: diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 6c17d96004b6..25e6dbfae308 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -790,6 +790,372 @@ def main( verify_model(Triu(), example_args, {}, expected_triu) +def test_binary(): + example_args1 = ( + torch.randn(10, 10, dtype=torch.float32), + torch.randn(10, 10, dtype=torch.float32), + ) + example_args2 = (torch.randn(10, 10, dtype=torch.float32),) + + # Add + class Add1(Module): + def forward(self, lhs, rhs): + return lhs + rhs + + @tvm.script.ir_module + class expected_add1: + @R.function + def main( + lhs: R.Tensor((10, 10), dtype="float32"), + rhs: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.add(lhs, rhs) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class Add2(Module): + def forward(self, lhs): + return lhs + 1.0 + + @tvm.script.ir_module + class expected_add2: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.add(lhs_1, R.const(1.0)) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Add1(), example_args1, {}, expected_add1) + verify_model(Add2(), example_args2, {}, expected_add2) + + # True div + class TrueDiv1(Module): + def forward(self, lhs, rhs): + return lhs / rhs + + @tvm.script.ir_module + class expected_truediv1: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + rhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.divide(lhs_1, rhs_1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class TrueDiv2(Module): + def forward(self, lhs): + return lhs / 1.0 + + @tvm.script.ir_module + class expected_truediv2: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.divide(lhs_1, R.const(1.0)) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(TrueDiv1(), example_args1, {}, expected_truediv1) + verify_model(TrueDiv2(), example_args2, {}, expected_truediv2) + + # EQ + class EQ1(Module): + def forward(self, lhs, rhs): + return lhs == rhs + + @tvm.script.ir_module + class expected_eq1: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + rhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="bool") = R.equal(lhs_1, rhs_1) + gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,) + R.output(gv) + return gv + + class EQ2(Module): + def forward(self, lhs): + return lhs == 1.0 + + @tvm.script.ir_module + class expected_eq2: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="bool") = R.equal(lhs_1, R.const(1.0)) + gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,) + R.output(gv) + return gv + + verify_model(EQ1(), example_args1, {}, expected_eq1) + verify_model(EQ2(), example_args2, {}, expected_eq2) + + # Floor div + class FloorDiv1(Module): + def forward(self, lhs, rhs): + return lhs // rhs + + @tvm.script.ir_module + class expected_floordiv1: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + rhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.floor_divide(lhs_1, rhs_1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class FloorDiv2(Module): + def forward(self, lhs): + return lhs // 1.0 + + @tvm.script.ir_module + class expected_floordiv2: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.floor_divide(lhs_1, R.const(1.0)) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(FloorDiv1(), example_args1, {}, expected_floordiv1) + verify_model(FloorDiv2(), example_args2, {}, expected_floordiv2) + + # LT + class LT1(Module): + def forward(self, lhs, rhs): + return lhs < rhs + + @tvm.script.ir_module + class expected_lt1: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + rhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="bool") = R.less(lhs_1, rhs_1) + gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,) + R.output(gv) + return gv + + class LT2(Module): + def forward(self, lhs): + return lhs < 1.0 + + @tvm.script.ir_module + class expected_lt2: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="bool") = R.less(lhs_1, R.const(1.0)) + gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,) + R.output(gv) + return gv + + verify_model(LT1(), example_args1, {}, expected_lt1) + verify_model(LT2(), example_args2, {}, expected_lt2) + + # MatMul + class MatMul1(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + @tvm.script.ir_module + class expected_matmul1: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32"), + input_2: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul( + input_1, input_2, out_dtype="float32" + ) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(MatMul1(), example_args1, {}, expected_matmul1) + + # Max + class Max1(Module): + def forward(self, x, y): + return torch.max(x, y) + + @I.ir_module + class expected_max1: + @R.function + def main( + inp_0: R.Tensor((10, 10), dtype="float32"), + inp_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.maximum(inp_0, inp_1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Max1(), example_args1, {}, expected_max1) + + # Mul + class Mul1(Module): + def forward(self, lhs, rhs): + return lhs * rhs + + @tvm.script.ir_module + class expected_mul1: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + rhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.multiply(lhs_1, rhs_1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class Mul2(Module): + def forward(self, lhs): + return lhs * 1.0 + + @tvm.script.ir_module + class expected_mul2: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.multiply(lhs_1, R.const(1.0)) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Mul1(), example_args1, {}, expected_mul1) + verify_model(Mul2(), example_args2, {}, expected_mul2) + + # Power + class Power1(Module): + def forward(self, lhs, rhs): + return lhs**rhs + + @tvm.script.ir_module + class expected_power1: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + rhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.power(lhs_1, rhs_1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class Power2(Module): + def forward(self, lhs): + return lhs**1.0 + + @tvm.script.ir_module + class expected_power2: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.power(lhs_1, R.const(1.0)) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Power1(), example_args1, {}, expected_power1) + verify_model(Power2(), example_args2, {}, expected_power2) + + # Sub + class Sub1(Module): + def forward(self, lhs, rhs): + return lhs - rhs + + @tvm.script.ir_module + class expected_sub1: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + rhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.subtract(lhs_1, rhs_1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class Sub2(Module): + def forward(self, lhs): + return lhs - 1.0 + + @tvm.script.ir_module + class expected_sub2: + @R.function + def main( + lhs_1: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.subtract(lhs_1, R.const(1.0)) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(Sub1(), example_args1, {}, expected_sub1) + verify_model(Sub2(), example_args2, {}, expected_sub2) + + def test_adaptive_avgpool2d(): class AdaptiveAvgPool2d0(Module): def __init__(self): @@ -1094,6 +1460,152 @@ def main( verify_model(MaxPool2d3(), example_args, {}, expected3) +def test_mean(): + class Mean(Module): + def forward(self, input): + return input.mean(-1) + + class MeanKeepDim(Module): + def forward(self, input: torch.Tensor): + return input.mean(-1, keepdim=True) + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tuple(R.Tensor((256,), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((256,), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=False) + gv: R.Tuple(R.Tensor((256,), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @I.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tuple(R.Tensor((256, 1), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((256, 1), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=True) + gv: R.Tuple(R.Tensor((256, 1), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(256, 256, dtype=torch.float32),) + verify_model(Mean(), example_args, {}, Expected1) + verify_model(MeanKeepDim(), example_args, {}, Expected2) + + +def test_sum(): + class Sum(Module): + def forward(self, x): + return torch.sum(x, (2, 1)) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4), dtype="float32") = R.sum(inp_0, axis=[2, 1], keepdims=False) + gv: R.Tuple(R.Tensor((1, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) + verify_model(Sum(), example_args, {}, expected1) + + +def test_argmax_argmin(): + example_args = (torch.randn(256, 256, dtype=torch.float32),) + + class Argmax1(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input): + return torch.argmax(input, dim=-1) + + class Argmax2(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input): + return torch.argmax(input, dim=-1, keepdim=True) + + @tvm.script.ir_module + class expected_argmax1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tuple(R.Tensor((256,), dtype="int64")): + with R.dataflow(): + lv: R.Tensor((256,), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=False) + gv: R.Tuple(R.Tensor((256,), dtype="int64")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_argmax2: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tuple(R.Tensor((256, 1), dtype="int64")): + with R.dataflow(): + lv: R.Tensor((256, 1), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=True) + gv: R.Tuple(R.Tensor((256, 1), dtype="int64")) = (lv,) + R.output(gv) + return gv + + verify_model(Argmax1(), example_args, {}, expected_argmax1) + verify_model(Argmax2(), example_args, {}, expected_argmax2) + + class Argmin1(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input): + return torch.argmin(input) + + class Argmin2(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input): + return torch.argmin(input, keepdim=True) + + @tvm.script.ir_module + class expected_argmin1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tuple(R.Tensor((), dtype="int64")): + with R.dataflow(): + lv: R.Tensor((), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=False) + gv: R.Tuple(R.Tensor((), dtype="int64")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_argmin2: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 1), dtype="int64")): + with R.dataflow(): + lv: R.Tensor((1, 1), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=True) + gv: R.Tuple(R.Tensor((1, 1), dtype="int64")) = (lv,) + R.output(gv) + return gv + + verify_model(Argmin1(), example_args, {}, expected_argmin1) + verify_model(Argmin2(), example_args, {}, expected_argmin2) + + def test_view(): class View(Module): def forward(self, x):