diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index eefa547efdee..e519c9eef397 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -30,6 +30,8 @@ #include #include +#include + #include "../intrin_rule.h" namespace tvm { @@ -175,7 +177,15 @@ TVM_REGISTER_OP("tir.asin") PrimExpr term7 = term5 * x2 * make_const(x.dtype(), 25) / make_const(x.dtype(), 112); PrimExpr term9 = term7 * x2 * make_const(x.dtype(), 1225) / make_const(x.dtype(), 3456); PrimExpr term11 = term9 * x2 * make_const(x.dtype(), 3969) / make_const(x.dtype(), 28160); - return term1 + term3 + term5 + term7 + term9 + term11; + PrimExpr series = term1 + term3 + term5 + term7 + term9 + term11; + /* --- domain limit check --- */ + PrimExpr lower = make_const(x.dtype(), -1.0); + PrimExpr upper = make_const(x.dtype(), 1.0); + PrimExpr out_range = tir::Or(x upper); + // Use a quiet NaN constant + PrimExpr nan_const = make_const(x.dtype(), std::numeric_limits::quiet_NaN()); + // select: if out of [-1,1] → NaN, else → series + return tir::Select(out_range, nan_const, series); }); TVM_REGISTER_OP("tir.acos") diff --git a/tests/python/tir-base/test_tir_intrin.py b/tests/python/tir-base/test_tir_intrin.py index c73e9c36879c..3e731f55fb0e 100644 --- a/tests/python/tir-base/test_tir_intrin.py +++ b/tests/python/tir-base/test_tir_intrin.py @@ -100,6 +100,23 @@ def run_test(tvm_intrin, np_func, atol=1e-5, rtol=1e-5): func(a, b) tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=atol, rtol=rtol) + # Out‐of‐bounds test for asin/acos + name = tvm_intrin.__name__ + if name in ("asin", "acos"): + # generate some values outside [-1, 1] + n = 8 + out_np = np.concatenate( + [ + np.random.uniform(1.1, 2.0, size=n // 2), + np.random.uniform(-2.0, -1.1, size=n // 2), + ] + ).astype(A.dtype) + a2 = tvm.nd.array(out_np, dev) + b2 = tvm.nd.array(np.empty_like(out_np), dev) + func(a2, b2) + # all outputs should be NaN + assert np.all(np.isnan(b2.numpy())) + for func in test_funcs: atol = rtol = 1e-3 if func[0].__name__ in ["asin", "acos", "atan"] else 1e-5 run_test(*func, atol, rtol)