diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 58f28fb1b306..42288cf562cd 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -27,9 +27,6 @@ from tvm.script import relax as R from tvm.script import tir as T from tvm.relax.frontend.torch import from_exported_program -from packaging import version - -torch_version = torch.__version__ def verify_model(torch_model, example_args, binding, expected, dynamic_shapes=None): @@ -56,10 +53,17 @@ def verify_model(torch_model, example_args, binding, expected, dynamic_shapes=No (torch.erf, R.erf), (torch.exp, R.exp), (torch.floor, R.floor), + (torch.ops.aten.gelu, R.nn.gelu), (torch.log, R.log), (torch.neg, R.negative), + (torch.relu, R.nn.relu), + (torch.relu_, R.nn.relu), (torch.round, R.round), (torch.rsqrt, R.rsqrt), + (torch.selu, R.nn.selu), + (torch.sigmoid, R.sigmoid), + (torch.ops.aten.silu, R.nn.silu), + (torch.ops.aten.silu_, R.nn.silu), (torch.sin, R.sin), (torch.sinh, R.sinh), (torch.sign, R.sign), @@ -314,35 +318,6 @@ def main( verify_model(Elu(), example_args, {}, expected_elu) verify_model(Elu2(), example_args, {}, expected_elu) - # gelu - class Gelu(Module): - def __init__(self): - super().__init__() - self.gelu = torch.nn.GELU() - - def forward(self, input): - return self.gelu(input) - - class Gelu2(Module): - def forward(self, input): - return torch.nn.functional.gelu(input) - - @tvm.script.ir_module - class expected_gelu: - @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.gelu(input_1) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - verify_model(Gelu(), example_args, {}, expected_gelu) - verify_model(Gelu2(), example_args, {}, expected_gelu) - # hardsigmoid class Hardsigmoid(torch.nn.Module): def __init__(self): @@ -413,15 +388,6 @@ def main( verify_model(Hardswish2(), example_args, {}, expected1) verify_model(Hardswish3(), example_args, {}, expected1) - # hardtanh - test_hardtanh() - - # leakyrelu - test_leakyrelu() - - # softplus - test_softplus() - # log2 class Log2(Module): def forward(self, x): @@ -487,9 +453,6 @@ def main( verify_model(Log1p(), example_args, {}, Expected_log1p) - # log_softmax - test_logsoftmax() - # reciprocal class Reciprocal(Module): def forward(self, input): @@ -511,140 +474,6 @@ def main( verify_model(Reciprocal(), example_args, {}, expected_reciprocal) - # relu - class ReLU0(Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - - def forward(self, input): - return self.relu(input) - - class ReLU1(Module): - def forward(self, input): - return torch.nn.functional.relu(input) - - class ReLU2(Module): - def forward(self, input): - return torch.ops.aten.relu_(input) - - @tvm.script.ir_module - class expected_relu: - @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - verify_model(ReLU0(), example_args, {}, expected_relu) - verify_model(ReLU1(), example_args, {}, expected_relu) - verify_model(ReLU2(), example_args, {}, expected_relu) - - # selu - class Selu1(Module): - def __init__(self): - super().__init__() - self.selu = torch.nn.SELU() - - def forward(self, input): - return self.selu(input) - - class Selu2(Module): - def forward(self, input): - return torch.nn.functional.selu(input) - - @tvm.script.ir_module - class expected_selu: - @R.function - def main( - input: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.selu(input) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - verify_model(Selu1(), example_args, {}, expected_selu) - verify_model(Selu2(), example_args, {}, expected_selu) - - # sigmoid - class Sigmoid(Module): - def __init__(self): - super().__init__() - self.sigmoid = torch.nn.Sigmoid() - - def forward(self, input): - return self.sigmoid(input) - - class Sigmoid2(Module): - def forward(self, input): - return torch.sigmoid(input) - - @tvm.script.ir_module - class expected_sigmoid: - @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sigmoid(input_1) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - verify_model(Sigmoid(), example_args, {}, expected_sigmoid) - verify_model(Sigmoid2(), example_args, {}, expected_sigmoid) - - # silu - class SiLU(Module): - def __init__(self): - super().__init__() - self.silu = torch.nn.SiLU() - - def forward(self, input): - return self.silu(input) - - class SiLU2(Module): - def forward(self, input): - return torch.nn.functional.silu(input) - - class SiLU3(Module): - def forward(self, input): - return torch.ops.aten.silu_(input) - - @tvm.script.ir_module - class expected_silu: - @R.function - def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 - with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.silu(input_1) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) - R.output(gv) - return gv - - verify_model(SiLU(), example_args, {}, expected_silu) - verify_model(SiLU2(), example_args, {}, expected_silu) - verify_model(SiLU3(), example_args, {}, expected_silu) - - # softmax - test_softmax() - - # softshrink - test_softshrink() - - # tril, triu - test_tril_triu() - def test_hardtanh(): class Hardtanh(torch.nn.Module): @@ -1044,7 +873,6 @@ def test_binary3(): torch.randn(10, 10, dtype=torch.float32), torch.randn(10, 10, dtype=torch.float32), ) - example_args2 = (torch.randn(10, 10, dtype=torch.float32),) # Max class Max1(Module):