diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 7bcd20c462bd..1c676d02675b 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -156,33 +156,44 @@ def create_convert_map( return { # unary + "abs.default": self._unary_op(relax.op.abs), "acos.default": self._unary_op(relax.op.acos), "acosh.default": self._unary_op(relax.op.acosh), "asin.default": self._unary_op(relax.op.asin), "asinh.default": self._unary_op(relax.op.asinh), "atan.default": self._unary_op(relax.op.atan), "atanh.default": self._unary_op(relax.op.atanh), + "bitwise_not.default": self._unary_op(relax.op.bitwise_not), + "ceil.default": self._unary_op(relax.op.ceil), "clamp.default": self._clamp, "cos.default": self._unary_op(relax.op.cos), "cosh.default": self._unary_op(relax.op.cosh), "dropout.default": lambda node: self.env[node.args[0]], + "erf.default": self._unary_op(relax.op.erf), "exp.default": self._unary_op(relax.op.exp), + "floor.default": self._unary_op(relax.op.floor), "gelu.default": self._gelu, "hardsigmoid.default": self._hardsigmoid, "hardswish.default": self._hardswish, "hardtanh.default": self._hardtanh, + "isfinite.default": self._unary_op(relax.op.isfinite), + "isinf.default": self._unary_op(relax.op.isinf), + "isnan.default": self._unary_op(relax.op.isnan), "leaky_relu.default": self._leakyrelu, + "log.default": self._unary_op(relax.op.log), "log_softmax.int": self._log_softmax, "neg.default": self._unary_op(relax.op.negative), "relu.default": self._unary_op(relax.op.nn.relu), "round.default": self._round, "rsqrt.default": self._unary_op(relax.op.rsqrt), "sigmoid.default": self._unary_op(relax.op.sigmoid), + "sign.default": self._unary_op(relax.op.sign), "silu.default": self._unary_op(relax.op.nn.silu), "sin.default": self._unary_op(relax.op.sin), "sinh.default": self._unary_op(relax.op.sinh), "softmax.int": self._softmax, "sqrt.default": self._unary_op(relax.op.sqrt), + "square.default": self._unary_op(relax.op.square), "tan.default": self._unary_op(relax.op.tan), "tanh.default": self._unary_op(relax.op.tanh), "tril.default": self._tril_triu(relax.op.tril), diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 0d8425fc7f30..33379e74ac24 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest import torch from torch.nn import Module from torch.export import export @@ -36,235 +37,241 @@ def verify_model(torch_model, example_args, binding, expected): tvm.ir.assert_structural_equal(mod, expected) -def test_unary(): +operator_basic_unary = [ + (torch.abs, R.abs), + (torch.acos, R.acos), + (torch.acosh, R.acosh), + (torch.asin, R.asin), + (torch.asinh, R.asinh), + (torch.atan, R.atan), + (torch.atanh, R.atanh), + (torch.bitwise_not, R.bitwise_not), + (torch.ceil, R.ceil), + (torch.cos, R.cos), + (torch.cosh, R.cosh), + (torch.erf, R.erf), + (torch.exp, R.exp), + (torch.floor, R.floor), + (torch.log, R.log), + (torch.neg, R.negative), + (torch.round, R.round), + (torch.rsqrt, R.rsqrt), + (torch.sin, R.sin), + (torch.sinh, R.sinh), + (torch.sign, R.sign), + (torch.sqrt, R.sqrt), + (torch.square, R.square), + (torch.tan, R.tan), + (torch.tanh, R.tanh), +] + + +@pytest.mark.parametrize("pytorch_op, relax_op", operator_basic_unary) +def test_basic_unary_ops(pytorch_op, relax_op): example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) - # acos - class Acos(Module): + class UnaryOp(Module): def forward(self, input): - return torch.acos(input) + return pytorch_op(input) @tvm.script.ir_module - class expected_acos: + class expected: @R.function def main( input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acos(input_1) + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = relax_op(input_1) gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) R.output(gv) return gv - verify_model(Acos(), example_args, {}, expected_acos) + verify_model(UnaryOp(), example_args, {}, expected) - # acosh - class Acosh(Module): - def forward(self, input): - return torch.acosh(input) - @tvm.script.ir_module - class expected_acosh: - @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acosh(input_1) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv +operator_bool_unary = [ + (torch.isfinite, R.isfinite), + (torch.isinf, R.isinf), + (torch.isnan, R.isnan), +] + - verify_model(Acosh(), example_args, {}, expected_acosh) +@pytest.mark.parametrize("pytorch_op, relax_op", operator_bool_unary) +def test_bool_unary_ops(pytorch_op, relax_op): + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) - # asin - class Asin(Module): + class UnaryOp(Module): def forward(self, input): - return torch.asin(input) + return pytorch_op(input) @tvm.script.ir_module - class expected_asin: + class expected: @R.function def main( input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")): with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asin(input_1) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = relax_op(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")) = (lv,) R.output(gv) return gv - verify_model(Asin(), example_args, {}, expected_asin) + verify_model(UnaryOp(), example_args, {}, expected) - # asinh - class Asinh(Module): - def forward(self, input): - return torch.asinh(input) - @tvm.script.ir_module - class expected_asinh: - @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asinh(input_1) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - verify_model(Asinh(), example_args, {}, expected_asinh) +def test_extended_unary_ops(): + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) - # atan - class Atan(Module): + # clamp + class Clamp(Module): def forward(self, input): - return torch.atan(input) + return torch.clamp(input, min=0.1, max=0.5) @tvm.script.ir_module - class expected_atan: + class expected_clamp: @R.function def main( input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atan(input_1) + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.1, 0.5) gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) R.output(gv) return gv - verify_model(Atan(), example_args, {}, expected_atan) - - # atanh - class Atanh(Module): - def forward(self, input): - return torch.atanh(input) + verify_model(Clamp(), example_args, {}, expected_clamp) - @tvm.script.ir_module - class expected_atanh: - @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atanh(input_1) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv + # dropout + class Dropout1(Module): + def __init__(self): + super().__init__() + self.dropout = torch.nn.Dropout(0.5) - verify_model(Atanh(), example_args, {}, expected_atanh) + def forward(self, input): + return self.dropout(input) - # cos - class Cos(Module): + class Dropout2(Module): def forward(self, input): - return torch.cos(input) + return torch.dropout(input, 0.5, train=True) @tvm.script.ir_module - class expected_cos: + class expected_dropout: @R.function def main( input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cos(input_1) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (input_1,) R.output(gv) return gv - verify_model(Cos(), example_args, {}, expected_cos) + verify_model(Dropout1(), example_args, {}, expected_dropout) + verify_model(Dropout2(), example_args, {}, expected_dropout) + + # gelu + class Gelu(Module): + def __init__(self): + super().__init__() + self.gelu = torch.nn.GELU() + + def forward(self, input): + return self.gelu(input) - # cosh - class Cosh(Module): + class Gelu2(Module): def forward(self, input): - return torch.cosh(input) + return torch.nn.functional.gelu(input) @tvm.script.ir_module - class expected_cosh: + class expected_gelu: @R.function def main( input_1: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): # block 0 with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cosh(input_1) + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.gelu(input_1) gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) R.output(gv) return gv - verify_model(Cosh(), example_args, {}, expected_cosh) + verify_model(Gelu(), example_args, {}, expected_gelu) + verify_model(Gelu2(), example_args, {}, expected_gelu) - # dropout - class Dropout1(Module): + # hardsigmoid + class Hardsigmoid(torch.nn.Module): def __init__(self): super().__init__() - self.dropout = torch.nn.Dropout(0.5) + self.hs = torch.nn.Hardsigmoid() def forward(self, input): - return self.dropout(input) + return self.hs(input) - class Dropout2(Module): + class Hardsigmoid2(torch.nn.Module): def forward(self, input): - return torch.dropout(input, 0.5, train=True) + return torch.nn.functional.hardsigmoid(input) @tvm.script.ir_module - class expected_dropout: + class expected_hardsigmoid: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 with R.dataflow(): - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (input_1,) + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv1, R.const(6, "float32") + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv2,) R.output(gv) return gv - verify_model(Dropout1(), example_args, {}, expected_dropout) - verify_model(Dropout2(), example_args, {}, expected_dropout) + verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid) + verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid) + + # hardwish + class Hardswish(torch.nn.Module): + def __init__(self): + super().__init__() + self.hs = torch.nn.Hardswish() - # exp - class Exp(Module): def forward(self, input): - return torch.exp(input) + return self.hs(input) + + class Hardswish2(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.hardswish(input) @tvm.script.ir_module - class expected_exp: + class expected1: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( + lv1, R.const(6, "float32") + ) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(inp_0, lv2) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,) R.output(gv) return gv - verify_model(Exp(), example_args, {}, expected_exp) + verify_model(Hardswish(), example_args, {}, expected1) + verify_model(Hardswish2(), example_args, {}, expected1) - # neg - class Neg(Module): - def forward(self, input): - return -input + # hardtanh + test_hardtanh() - @I.ir_module - class expected_neg: - @R.function - def main( - inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.negative(inp_0) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv + # leakyrelu + test_leakyrelu() - verify_model(Neg(), example_args, {}, expected_neg) + # log_softmax + test_logsoftmax() # relu class ReLU0(Module): @@ -295,26 +302,6 @@ def main( verify_model(ReLU0(), example_args, {}, expected_relu) verify_model(ReLU1(), example_args, {}, expected_relu) - # rsqrt - class Rsqrt(Module): - def forward(self, input): - return torch.rsqrt(input) - - @I.ir_module - class expected_rsqrt: - @R.function - def main( - inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.rsqrt(inp_0) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) - verify_model(Rsqrt(), example_args, {}, expected_rsqrt) - # sigmoid class Sigmoid(Module): def __init__(self): @@ -373,227 +360,11 @@ def main( verify_model(SiLU(), example_args, {}, expected_silu) verify_model(SiLU2(), example_args, {}, expected_silu) - # sin - class Sin(Module): - def forward(self, input: torch.Tensor): - return torch.sin(input) - - @tvm.script.ir_module - class expected_sin: - @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sin(input_1) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - verify_model(Sin(), example_args, {}, expected_sin) - - # sinh - class Sinh(Module): - def forward(self, input): - return torch.sinh(input) + # softmax + test_softmax() - @tvm.script.ir_module - class expected_sinh: - @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sinh(input_1) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - verify_model(Sinh(), example_args, {}, expected_sinh) - - # sqrt - class Sqrt(Module): - def forward(self, input): - return torch.sqrt(input) - - @tvm.script.ir_module - class expected_sqrt: - @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sqrt(input_1) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - verify_model(Sqrt(), example_args, {}, expected_sqrt) - - # tan - class Tan(Module): - def forward(self, input): - return torch.tan(input) - - @tvm.script.ir_module - class expected_tan: - @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tan(input_1) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - verify_model(Tan(), example_args, {}, expected_tan) - - # tanh - class Tanh(Module): - def forward(self, input): - return torch.tanh(input) - - @tvm.script.ir_module - class expected_tanh: - @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tanh(input_1) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - verify_model(Tanh(), example_args, {}, expected_tanh) - - -def test_clamp(): - class Clamp(Module): - def forward(self, input): - return torch.clamp(input, min=0.1, max=0.5) - - @tvm.script.ir_module - class expected_clamp: - @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.1, 0.5) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) - verify_model(Clamp(), example_args, {}, expected_clamp) - - -def test_gelu(): - class Gelu(Module): - def __init__(self): - super().__init__() - self.gelu = torch.nn.GELU() - - def forward(self, input): - return self.gelu(input) - - class Gelu2(Module): - def forward(self, input): - return torch.nn.functional.gelu(input) - - @tvm.script.ir_module - class expected_gelu: - @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.gelu(input_1) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) - verify_model(Gelu(), example_args, {}, expected_gelu) - verify_model(Gelu2(), example_args, {}, expected_gelu) - - -def test_hardsigmoid(): - class Hardsigmoid(torch.nn.Module): - def __init__(self): - super().__init__() - self.hs = torch.nn.Hardsigmoid() - - def forward(self, input): - return self.hs(input) - - class Hardsigmoid2(torch.nn.Module): - def forward(self, input): - return torch.nn.functional.hardsigmoid(input) - - @tvm.script.ir_module - class expected_hardsigmoid: - @R.function - def main( - inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) - lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) - lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( - lv1, R.const(6, "float32") - ) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv2,) - R.output(gv) - return gv - - example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) - verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid) - verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid) - - -def test_hardswish(): - class Hardswish(torch.nn.Module): - def __init__(self): - super().__init__() - self.hs = torch.nn.Hardswish() - - def forward(self, input): - return self.hs(input) - - class Hardswish2(torch.nn.Module): - def forward(self, input): - return torch.nn.functional.hardswish(input) - - @tvm.script.ir_module - class expected1: - @R.function - def main( - inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) - lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) - lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( - lv1, R.const(6, "float32") - ) - lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(inp_0, lv2) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,) - R.output(gv) - return gv - - example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) - verify_model(Hardswish(), example_args, {}, expected1) - verify_model(Hardswish2(), example_args, {}, expected1) + # tril, triu + test_tril_triu() def test_hardtanh(): @@ -695,28 +466,6 @@ def main( verify_model(LogSoftmax2(), example_args, {}, expected1) -def test_round(): - class Round(Module): - def forward(self, input): - return torch.round(input) - - @tvm.script.ir_module - class expected: - @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.round(input_1) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) - verify_model(Round(), example_args, {}, expected) - - def test_softmax(): class Softmax(Module): def __init__(self):