Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ConstantFPRange][UnitTest] Ignore NaN payloads when enumerating values in a range #111083

Merged
merged 1 commit into from
Oct 4, 2024

Conversation

dtcxzyw
Copy link
Member

@dtcxzyw dtcxzyw commented Oct 4, 2024

NaN payloads can be ignored because they are unrelated with ConstantFPRange (except the conversion from ConstantFPRange to KnownBits). This patch just enumerates +/-[S/Q]NaN to avoid enumerating 32 NaN values in all ranges which contain NaN values.
Addresses comment #110082 (comment). This patch reduces the execution time for unittests from 30.37s to 10.59s with an optimized build.

@dtcxzyw dtcxzyw requested a review from arsenm October 4, 2024 02:29
@llvmbot llvmbot added the llvm:ir label Oct 4, 2024
@llvmbot
Copy link
Member

llvmbot commented Oct 4, 2024

@llvm/pr-subscribers-llvm-ir

Author: Yingwei Zheng (dtcxzyw)

Changes

NaN payloads can be ignored because they are unrelated with ConstantFPRange (except the conversion from ConstantFPRange to KnownBits). This patch just enumerates +/-[S/Q]NaN to avoid enumerating 32 NaN values in all ranges which contain NaN values.
Addresses comment #110082 (comment).


Full diff: https://github.com/llvm/llvm-project/pull/111083.diff

1 Files Affected:

  • (modified) llvm/unittests/IR/ConstantFPRangeTest.cpp (+85-24)
diff --git a/llvm/unittests/IR/ConstantFPRangeTest.cpp b/llvm/unittests/IR/ConstantFPRangeTest.cpp
index 17a08207fe1ba0..158d08f9b77a0a 100644
--- a/llvm/unittests/IR/ConstantFPRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantFPRangeTest.cpp
@@ -150,26 +150,80 @@ static void EnumerateTwoInterestingConstantFPRanges(Fn TestFn,
 
 template <typename Fn>
 static void EnumerateValuesInConstantFPRange(const ConstantFPRange &CR,
-                                             Fn TestFn) {
+                                             Fn TestFn, bool IgnoreNaNPayload) {
   const fltSemantics &Sem = CR.getSemantics();
-  unsigned Bits = APFloat::semanticsSizeInBits(Sem);
-  assert(Bits < 32 && "Too many bits");
-  for (unsigned I = 0, E = (1U << Bits) - 1; I != E; ++I) {
-    APFloat V(Sem, APInt(Bits, I));
-    if (CR.contains(V))
-      TestFn(V);
+  if (IgnoreNaNPayload) {
+    if (CR.containsSNaN()) {
+      TestFn(APFloat::getSNaN(Sem, false));
+      TestFn(APFloat::getSNaN(Sem, true));
+    }
+    if (CR.containsQNaN()) {
+      TestFn(APFloat::getQNaN(Sem, false));
+      TestFn(APFloat::getQNaN(Sem, true));
+    }
+    if (CR.isNaNOnly())
+      return;
+    APFloat Lower = CR.getLower();
+    const APFloat &Upper = CR.getUpper();
+    auto Next = [&](APFloat &V) {
+      if (V.bitwiseIsEqual(Upper))
+        return false;
+      strictNext(V);
+      return true;
+    };
+    do
+      TestFn(Lower);
+    while (Next(Lower));
+  } else {
+    unsigned Bits = APFloat::semanticsSizeInBits(Sem);
+    assert(Bits < 32 && "Too many bits");
+    for (unsigned I = 0, E = (1U << Bits) - 1; I != E; ++I) {
+      APFloat V(Sem, APInt(Bits, I));
+      if (CR.contains(V))
+        TestFn(V);
+    }
   }
 }
 
 template <typename Fn>
-static bool AnyOfValueInConstantFPRange(const ConstantFPRange &CR, Fn TestFn) {
+static bool AnyOfValueInConstantFPRange(const ConstantFPRange &CR, Fn TestFn,
+                                        bool IgnoreNaNPayload) {
   const fltSemantics &Sem = CR.getSemantics();
-  unsigned Bits = APFloat::semanticsSizeInBits(Sem);
-  assert(Bits < 32 && "Too many bits");
-  for (unsigned I = 0, E = (1U << Bits) - 1; I != E; ++I) {
-    APFloat V(Sem, APInt(Bits, I));
-    if (CR.contains(V) && TestFn(V))
+  if (IgnoreNaNPayload) {
+    if (CR.containsSNaN()) {
+      if (TestFn(APFloat::getSNaN(Sem, false)))
+        return true;
+      if (TestFn(APFloat::getSNaN(Sem, true)))
+        return true;
+    }
+    if (CR.containsQNaN()) {
+      if (TestFn(APFloat::getQNaN(Sem, false)))
+        return true;
+      if (TestFn(APFloat::getQNaN(Sem, true)))
+        return true;
+    }
+    if (CR.isNaNOnly())
+      return false;
+    APFloat Lower = CR.getLower();
+    const APFloat &Upper = CR.getUpper();
+    auto Next = [&](APFloat &V) {
+      if (V.bitwiseIsEqual(Upper))
+        return false;
+      strictNext(V);
       return true;
+    };
+    do {
+      if (TestFn(Lower))
+        return true;
+    } while (Next(Lower));
+  } else {
+    unsigned Bits = APFloat::semanticsSizeInBits(Sem);
+    assert(Bits < 32 && "Too many bits");
+    for (unsigned I = 0, E = (1U << Bits) - 1; I != E; ++I) {
+      APFloat V(Sem, APInt(Bits, I));
+      if (CR.contains(V) && TestFn(V))
+        return true;
+    }
   }
   return false;
 }
@@ -385,13 +439,16 @@ TEST_F(ConstantFPRangeTest, FPClassify) {
       [](const ConstantFPRange &CR) {
         unsigned Mask = fcNone;
         bool HasPos = false, HasNeg = false;
-        EnumerateValuesInConstantFPRange(CR, [&](const APFloat &V) {
-          Mask |= V.classify();
-          if (V.isNegative())
-            HasNeg = true;
-          else
-            HasPos = true;
-        });
+        EnumerateValuesInConstantFPRange(
+            CR,
+            [&](const APFloat &V) {
+              Mask |= V.classify();
+              if (V.isNegative())
+                HasNeg = true;
+              else
+                HasPos = true;
+            },
+            /*IgnoreNaNPayload=*/true);
 
         std::optional<bool> SignBit = std::nullopt;
         if (HasPos != HasNeg)
@@ -453,11 +510,15 @@ TEST_F(ConstantFPRangeTest, makeAllowedFCmpRegion) {
           EnumerateValuesInConstantFPRange(
               ConstantFPRange::getFull(CR.getSemantics()),
               [&](const APFloat &V) {
-                if (AnyOfValueInConstantFPRange(CR, [&](const APFloat &U) {
-                      return FCmpInst::compare(V, U, Pred);
-                    }))
+                if (AnyOfValueInConstantFPRange(
+                        CR,
+                        [&](const APFloat &U) {
+                          return FCmpInst::compare(V, U, Pred);
+                        },
+                        /*IgnoreNaNPayload=*/true))
                   Optimal = Optimal.unionWith(ConstantFPRange(V));
-              });
+              },
+              /*IgnoreNaNPayload=*/true);
 
           EXPECT_TRUE(Res.contains(Optimal))
               << "Wrong result for makeAllowedFCmpRegion(" << Pred << ", " << CR

@dtcxzyw
Copy link
Member Author

dtcxzyw commented Oct 4, 2024

cc @vporpo @aeubanks

@arsenm arsenm added the floating-point Floating-point math label Oct 4, 2024
@dtcxzyw dtcxzyw merged commit 856774d into llvm:main Oct 4, 2024
12 checks passed
@dtcxzyw dtcxzyw deleted the cfr-enum-opt branch October 4, 2024 08:24
xgupta pushed a commit to xgupta/llvm-project that referenced this pull request Oct 4, 2024
…es in a range (llvm#111083)

NaN payloads can be ignored because they are unrelated with
ConstantFPRange (except the conversion from ConstantFPRange to
KnownBits). This patch just enumerates `+/-[S/Q]NaN` to avoid
enumerating 32 NaN values in all ranges which contain NaN values.
Addresses comment
llvm#110082 (comment).
This patch reduces the execution time for unittests from 30.37s to
10.59s with an optimized build.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
floating-point Floating-point math llvm:ir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants