diff --git a/tests/python/relay/test_pass_fast_math.py b/tests/python/relay/test_pass_fast_math.py index 93ad034be2ef..da5eaf415257 100644 --- a/tests/python/relay/test_pass_fast_math.py +++ b/tests/python/relay/test_pass_fast_math.py @@ -47,6 +47,21 @@ def test_tanh(): fast_mod = relay.optimize(mod, target='llvm', params=None) assert "fast_tanh" in fast_mod[0].astext() +def test_erf(): + x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32") + y = relay.erf(x) + func = relay.Function([x], y) + mod = tvm.IRModule.from_expr(func) + + fast_mod = FastMath()(mod) + assert "fast_erf" in fast_mod.astext() + + # Check that FastMath option works for relay.build. + with tvm.transform.PassContext(opt_level=3, required_pass=['FastMath']): + fast_mod = relay.optimize(mod, target='llvm', params=None) + assert "fast_erf" in fast_mod[0].astext() + if __name__ == "__main__": test_exp() test_tanh() + test_erf()