Skip to content

Conversation

@AlexVlx
Copy link
Contributor

@AlexVlx AlexVlx commented Dec 11, 2025

This changes the extension parsing mechanism underpinning --spirv-ext to be more explicit about what it is doing and not rely on a sort. More specifically, we partition extensions into enabled (prefixed with +) and others, and then individually handle the resulting ranges.

@llvmbot
Copy link
Member

llvmbot commented Dec 11, 2025

@llvm/pr-subscribers-backend-spir-v

Author: Alex Voicu (AlexVlx)

Changes

This changes the extension parsing mechanism underpinning --spirv-ext to be more explicit about what it is doing and not rely on a sort. More specifically, we partition extensions into enabled (prefixed with +) and others, and then individually handle the resulting ranges.


Full diff: https://github.com/llvm/llvm-project/pull/171826.diff

1 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+30-33)
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 42edad255ce82..04c54f9b0e53d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -17,7 +17,9 @@
 #include "llvm/TargetParser/Triple.h"
 
 #include <functional>
+#include <iterator>
 #include <map>
+#include <set>
 #include <string>
 #include <utility>
 #include <vector>
@@ -26,7 +28,7 @@
 
 using namespace llvm;
 
-static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
+static const std::map<StringRef, SPIRV::Extension::Extension>
     SPIRVExtensionMap = {
         {"SPV_EXT_shader_atomic_float_add",
          SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_add},
@@ -181,57 +183,52 @@ bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   std::set<SPIRV::Extension::Extension> &Vals) {
   SmallVector<StringRef, 10> Tokens;
   ArgValue.split(Tokens, ",", -1, false);
-  llvm::sort(Tokens, [](auto &&LHS, auto &&RHS) {
-    // We want to ensure that we handle "all" first, to ensure that any
-    // subsequent disablement actually behaves as expected i.e. given
-    // --spv-ext=all,-foo, we first enable all and then disable foo; this should
-    // be revisited and simplified.
-    if (LHS == "all")
-      return true;
-    if (RHS == "all")
-      return false;
-    return !(RHS < LHS);
-  });
 
   std::set<SPIRV::Extension::Extension> EnabledExtensions;
 
-  for (const auto &Token : Tokens) {
-    if (Token == "all") {
-      for (const auto &[ExtensionName, ExtensionEnum] : SPIRVExtensionMap)
-        EnabledExtensions.insert(ExtensionEnum);
+  auto M = partition(Tokens, [](auto &&T) { return T.starts_with('+'); });
+
+  if (std::any_of(M, Tokens.end(), [](auto &&T) { return T == "all"; }))
+    copy(make_second_range(SPIRVExtensionMap), std::inserter(Vals, Vals.end()));
+
+  for (auto &&Token : make_range(Tokens.begin(), M)) {
+    StringRef ExtensionName = Token.substr(1);
+    auto NameValuePair = SPIRVExtensionMap.find(ExtensionName);
 
+    if (NameValuePair == SPIRVExtensionMap.end())
+      return O.error("Unknown SPIR-V extension: " + Token.str());
+
+    EnabledExtensions.insert(NameValuePair->second);
+  }
+
+  for (auto &&Token : make_range(M, Tokens.end())) {
+    if (Token == "all")
       continue;
-    }
 
     if (Token.size() == 3 && Token.upper() == "KHR") {
       for (const auto &[ExtensionName, ExtensionEnum] : SPIRVExtensionMap)
         if (StringRef(ExtensionName).starts_with("SPV_KHR_"))
-          EnabledExtensions.insert(ExtensionEnum);
+          Vals.insert(ExtensionEnum);
       continue;
     }
 
     if (Token.empty() || (!Token.starts_with("+") && !Token.starts_with("-")))
-      return O.error("Invalid extension list format: " + Token.str());
+      return O.error("Invalid extension list format: " + Token);
 
-    StringRef ExtensionName = Token.substr(1);
-    auto NameValuePair = SPIRVExtensionMap.find(ExtensionName);
+    auto NameValuePair = SPIRVExtensionMap.find(Token.substr(1));
 
-    if (NameValuePair == SPIRVExtensionMap.end())
+    if (NameValuePair == SPIRVExtensionMap.cend())
       return O.error("Unknown SPIR-V extension: " + Token.str());
+    if (EnabledExtensions.count(NameValuePair->second))
+      return O.error(
+          "Extension cannot be allowed and disallowed at the same time: " +
+          NameValuePair->first);
 
-    if (Token.starts_with("+")) {
-      EnabledExtensions.insert(NameValuePair->second);
-    } else if (EnabledExtensions.count(NameValuePair->second)) {
-      if (llvm::is_contained(Tokens, "+" + ExtensionName.str()))
-        return O.error(
-            "Extension cannot be allowed and disallowed at the same time: " +
-            ExtensionName.str());
-
-      EnabledExtensions.erase(NameValuePair->second);
-    }
+    Vals.erase(NameValuePair->second);
   }
 
-  Vals = std::move(EnabledExtensions);
+  Vals.insert(EnabledExtensions.cbegin(), EnabledExtensions.cend());
+
   return false;
 }
 

Copy link
Collaborator

@mikaelholmen mikaelholmen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have verified that the 143 lit tests that previously failed with expensive checks when built with gcc now pass with this patch, since the bad sort comparison has been removed so that looks good.
But I can't really review the actual code changes since I don't know this at all.

@AlexVlx AlexVlx merged commit 0fff58a into llvm:main Dec 15, 2025
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants