From ff185c63b37ed91f96254de81632fe54613ae45f Mon Sep 17 00:00:00 2001
From: Trevor Gross <tmgross@umich.edu>
Date: Thu, 26 Dec 2024 09:13:58 +0000
Subject: [PATCH] Replace string function name matching with enums where
 possible

---
 crates/libm-test/src/gen/random.rs  |   5 +-
 crates/libm-test/src/precision.rs   | 100 ++++++++++++----------------
 crates/libm-test/src/test_traits.rs |   3 -
 3 files changed, 46 insertions(+), 62 deletions(-)

diff --git a/crates/libm-test/src/gen/random.rs b/crates/libm-test/src/gen/random.rs
index e347b3c6..527cd135 100644
--- a/crates/libm-test/src/gen/random.rs
+++ b/crates/libm-test/src/gen/random.rs
@@ -7,7 +7,7 @@ use rand::{Rng, SeedableRng};
 use rand_chacha::ChaCha8Rng;
 
 use super::CachedInput;
-use crate::{CheckCtx, GenerateInput};
+use crate::{BaseName, CheckCtx, GenerateInput};
 
 const SEED: [u8; 32] = *b"3.141592653589793238462643383279";
 
@@ -110,7 +110,6 @@ pub fn get_test_cases<RustArgs>(ctx: &CheckCtx) -> impl Iterator<Item = RustArgs
 where
     CachedInput: GenerateInput<RustArgs>,
 {
-    let inputs =
-        if ctx.fn_name == "jn" || ctx.fn_name == "jnf" { &TEST_CASES_JN } else { &TEST_CASES };
+    let inputs = if ctx.base_name == BaseName::Jn { &TEST_CASES_JN } else { &TEST_CASES };
     inputs.get_cases()
 }
diff --git a/crates/libm-test/src/precision.rs b/crates/libm-test/src/precision.rs
index cf911543..c7f9d9e3 100644
--- a/crates/libm-test/src/precision.rs
+++ b/crates/libm-test/src/precision.rs
@@ -6,7 +6,7 @@ use core::f32;
 use CheckBasis::{Mpfr, Musl};
 use Identifier as Id;
 
-use crate::{CheckBasis, CheckCtx, Float, Identifier, Int, TestResult};
+use crate::{BaseName, CheckBasis, CheckCtx, Float, Identifier, Int, TestResult};
 
 /// Type implementing [`IgnoreCase`].
 pub struct SpecialCase;
@@ -106,25 +106,26 @@ impl MaybeOverride<(f32,)> for SpecialCase {
         ctx: &CheckCtx,
     ) -> Option<TestResult> {
         if ctx.basis == CheckBasis::Musl {
-            if ctx.fn_name == "expm1f" && input.0 > 80.0 && actual.is_infinite() {
+            if ctx.base_name == BaseName::Expm1 && input.0 > 80.0 && actual.is_infinite() {
                 // we return infinity but the number is representable
                 return XFAIL;
             }
 
-            if ctx.fn_name == "sinhf" && input.0.abs() > 80.0 && actual.is_nan() {
+            if ctx.base_name == BaseName::Sinh && input.0.abs() > 80.0 && actual.is_nan() {
                 // we return some NaN that should be real values or infinite
                 // doesn't seem to happen on x86
                 return XFAIL;
             }
         }
 
-        if ctx.fn_name == "acoshf" && input.0 < -1.0 {
+        if ctx.base_name == BaseName::Acosh && input.0 < -1.0 {
             // acoshf is undefined for x <= 1.0, but we return a random result at lower
             // values.
             return XFAIL;
         }
 
-        if ctx.fn_name == "lgammaf" || ctx.fn_name == "lgammaf_r" && input.0 < 0.0 {
+        if ctx.base_name == BaseName::Lgamma || ctx.base_name == BaseName::LgammaR && input.0 < 0.0
+        {
             // loggamma should not be defined for x < 0, yet we both return results
             return XFAIL;
         }
@@ -141,7 +142,7 @@ impl MaybeOverride<(f32,)> for SpecialCase {
         // On MPFR for lgammaf_r, we set -1 as the integer result for negative infinity but MPFR
         // sets +1
         if ctx.basis == CheckBasis::Mpfr
-            && ctx.fn_name == "lgammaf_r"
+            && ctx.base_name == BaseName::LgammaR
             && input.0 == f32::NEG_INFINITY
             && actual.abs() == expected.abs()
         {
@@ -161,13 +162,13 @@ impl MaybeOverride<(f64,)> for SpecialCase {
         ctx: &CheckCtx,
     ) -> Option<TestResult> {
         if ctx.basis == CheckBasis::Musl {
-            if cfg!(target_arch = "x86") && ctx.fn_name == "acosh" && input.0 < 1.0 {
+            if cfg!(target_arch = "x86") && ctx.base_name == BaseName::Acosh && input.0 < 1.0 {
                 // The function is undefined, both implementations return random results
                 return SKIP;
             }
 
             if cfg!(x86_no_sse)
-                && ctx.fn_name == "ceil"
+                && ctx.base_name == BaseName::Ceil
                 && input.0 < 0.0
                 && input.0 > -1.0
                 && expected == F::ZERO
@@ -178,13 +179,14 @@ impl MaybeOverride<(f64,)> for SpecialCase {
             }
         }
 
-        if ctx.fn_name == "acosh" && input.0 < 1.0 {
+        if ctx.base_name == BaseName::Acosh && input.0 < 1.0 {
             // The function is undefined for the inputs, musl and our libm both return
             // random results.
             return XFAIL;
         }
 
-        if ctx.fn_name == "lgamma" || ctx.fn_name == "lgamma_r" && input.0 < 0.0 {
+        if ctx.base_name == BaseName::Lgamma || ctx.base_name == BaseName::LgammaR && input.0 < 0.0
+        {
             // loggamma should not be defined for x < 0, yet we both return results
             return XFAIL;
         }
@@ -201,7 +203,7 @@ impl MaybeOverride<(f64,)> for SpecialCase {
         // On MPFR for lgamma_r, we set -1 as the integer result for negative infinity but MPFR
         // sets +1
         if ctx.basis == CheckBasis::Mpfr
-            && ctx.fn_name == "lgamma_r"
+            && ctx.base_name == BaseName::LgammaR
             && input.0 == f64::NEG_INFINITY
             && actual.abs() == expected.abs()
         {
@@ -214,7 +216,7 @@ impl MaybeOverride<(f64,)> for SpecialCase {
 
 /// Check NaN bits if the function requires it
 fn maybe_check_nan_bits<F: Float>(actual: F, expected: F, ctx: &CheckCtx) -> Option<TestResult> {
-    if !(ctx.base_name_str == "fabs" || ctx.base_name_str == "copysign") {
+    if !(ctx.base_name == BaseName::Fabs || ctx.base_name == BaseName::Copysign) {
         return None;
     }
 
@@ -270,24 +272,16 @@ fn maybe_skip_binop_nan<F1: Float, F2: Float>(
     expected: F2,
     ctx: &CheckCtx,
 ) -> Option<TestResult> {
-    match ctx.basis {
-        CheckBasis::Musl => {
-            if (ctx.base_name_str == "fmax" || ctx.base_name_str == "fmin")
-                && (input.0.is_nan() || input.1.is_nan())
-                && expected.is_nan()
-            {
-                XFAIL
-            } else {
-                None
-            }
-        }
-        CheckBasis::Mpfr => {
-            if ctx.base_name_str == "copysign" && input.1.is_nan() {
-                SKIP
-            } else {
-                None
-            }
+    match (&ctx.basis, ctx.base_name) {
+        (Musl, BaseName::Fmin | BaseName::Fmax)
+            if (input.0.is_nan() || input.1.is_nan()) && expected.is_nan() =>
+        {
+            XFAIL
         }
+
+        (Mpfr, BaseName::Copysign) if input.1.is_nan() => SKIP,
+
+        _ => None,
     }
 }
 
@@ -299,20 +293,17 @@ impl MaybeOverride<(i32, f32)> for SpecialCase {
         ulp: &mut u32,
         ctx: &CheckCtx,
     ) -> Option<TestResult> {
-        match ctx.basis {
-            CheckBasis::Musl => bessel_prec_dropoff(input, ulp, ctx),
-            CheckBasis::Mpfr => {
-                // We return +0.0, MPFR returns -0.0
-                if ctx.fn_name == "jnf"
-                    && input.1 == f32::NEG_INFINITY
-                    && actual == F::ZERO
-                    && expected == F::ZERO
-                {
-                    XFAIL
-                } else {
-                    None
-                }
+        match (&ctx.basis, ctx.base_name) {
+            (Musl, _) => bessel_prec_dropoff(input, ulp, ctx),
+
+            // We return +0.0, MPFR returns -0.0
+            (Mpfr, BaseName::Jn)
+                if input.1 == f32::NEG_INFINITY && actual == F::ZERO && expected == F::ZERO =>
+            {
+                XFAIL
             }
+
+            _ => None,
         }
     }
 }
@@ -324,20 +315,17 @@ impl MaybeOverride<(i32, f64)> for SpecialCase {
         ulp: &mut u32,
         ctx: &CheckCtx,
     ) -> Option<TestResult> {
-        match ctx.basis {
-            CheckBasis::Musl => bessel_prec_dropoff(input, ulp, ctx),
-            CheckBasis::Mpfr => {
-                // We return +0.0, MPFR returns -0.0
-                if ctx.fn_name == "jn"
-                    && input.1 == f64::NEG_INFINITY
-                    && actual == F::ZERO
-                    && expected == F::ZERO
-                {
-                    XFAIL
-                } else {
-                    bessel_prec_dropoff(input, ulp, ctx)
-                }
+        match (&ctx.basis, ctx.base_name) {
+            (Musl, _) => bessel_prec_dropoff(input, ulp, ctx),
+
+            // We return +0.0, MPFR returns -0.0
+            (Mpfr, BaseName::Jn)
+                if input.1 == f64::NEG_INFINITY && actual == F::ZERO && expected == F::ZERO =>
+            {
+                XFAIL
             }
+
+            _ => None,
         }
     }
 }
@@ -348,7 +336,7 @@ fn bessel_prec_dropoff<F: Float>(
     ulp: &mut u32,
     ctx: &CheckCtx,
 ) -> Option<TestResult> {
-    if ctx.base_name_str == "jn" {
+    if ctx.base_name == BaseName::Jn {
         if input.0 > 4000 {
             return XFAIL;
         } else if input.0 > 2000 {
diff --git a/crates/libm-test/src/test_traits.rs b/crates/libm-test/src/test_traits.rs
index b8e0aa10..ca933bbd 100644
--- a/crates/libm-test/src/test_traits.rs
+++ b/crates/libm-test/src/test_traits.rs
@@ -22,8 +22,6 @@ pub struct CheckCtx {
     pub base_name: BaseName,
     /// Function name.
     pub fn_name: &'static str,
-    /// Return the unsuffixed version of the function name.
-    pub base_name_str: &'static str,
     /// Source of truth for tests.
     pub basis: CheckBasis,
 }
@@ -36,7 +34,6 @@ impl CheckCtx {
             fn_ident,
             fn_name: fn_ident.as_str(),
             base_name: fn_ident.base_name(),
-            base_name_str: fn_ident.base_name().as_str(),
             basis,
         };
         ret.ulp = crate::default_ulp(&ret);