-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[NFC][SPIRV] Re-work extension parsing #171826
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-backend-spir-v Author: Alex Voicu (AlexVlx) ChangesThis changes the extension parsing mechanism underpinning Full diff: https://github.com/llvm/llvm-project/pull/171826.diff 1 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 42edad255ce82..04c54f9b0e53d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -17,7 +17,9 @@
#include "llvm/TargetParser/Triple.h"
#include <functional>
+#include <iterator>
#include <map>
+#include <set>
#include <string>
#include <utility>
#include <vector>
@@ -26,7 +28,7 @@
using namespace llvm;
-static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
+static const std::map<StringRef, SPIRV::Extension::Extension>
SPIRVExtensionMap = {
{"SPV_EXT_shader_atomic_float_add",
SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_add},
@@ -181,57 +183,52 @@ bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
std::set<SPIRV::Extension::Extension> &Vals) {
SmallVector<StringRef, 10> Tokens;
ArgValue.split(Tokens, ",", -1, false);
- llvm::sort(Tokens, [](auto &&LHS, auto &&RHS) {
- // We want to ensure that we handle "all" first, to ensure that any
- // subsequent disablement actually behaves as expected i.e. given
- // --spv-ext=all,-foo, we first enable all and then disable foo; this should
- // be revisited and simplified.
- if (LHS == "all")
- return true;
- if (RHS == "all")
- return false;
- return !(RHS < LHS);
- });
std::set<SPIRV::Extension::Extension> EnabledExtensions;
- for (const auto &Token : Tokens) {
- if (Token == "all") {
- for (const auto &[ExtensionName, ExtensionEnum] : SPIRVExtensionMap)
- EnabledExtensions.insert(ExtensionEnum);
+ auto M = partition(Tokens, [](auto &&T) { return T.starts_with('+'); });
+
+ if (std::any_of(M, Tokens.end(), [](auto &&T) { return T == "all"; }))
+ copy(make_second_range(SPIRVExtensionMap), std::inserter(Vals, Vals.end()));
+
+ for (auto &&Token : make_range(Tokens.begin(), M)) {
+ StringRef ExtensionName = Token.substr(1);
+ auto NameValuePair = SPIRVExtensionMap.find(ExtensionName);
+ if (NameValuePair == SPIRVExtensionMap.end())
+ return O.error("Unknown SPIR-V extension: " + Token.str());
+
+ EnabledExtensions.insert(NameValuePair->second);
+ }
+
+ for (auto &&Token : make_range(M, Tokens.end())) {
+ if (Token == "all")
continue;
- }
if (Token.size() == 3 && Token.upper() == "KHR") {
for (const auto &[ExtensionName, ExtensionEnum] : SPIRVExtensionMap)
if (StringRef(ExtensionName).starts_with("SPV_KHR_"))
- EnabledExtensions.insert(ExtensionEnum);
+ Vals.insert(ExtensionEnum);
continue;
}
if (Token.empty() || (!Token.starts_with("+") && !Token.starts_with("-")))
- return O.error("Invalid extension list format: " + Token.str());
+ return O.error("Invalid extension list format: " + Token);
- StringRef ExtensionName = Token.substr(1);
- auto NameValuePair = SPIRVExtensionMap.find(ExtensionName);
+ auto NameValuePair = SPIRVExtensionMap.find(Token.substr(1));
- if (NameValuePair == SPIRVExtensionMap.end())
+ if (NameValuePair == SPIRVExtensionMap.cend())
return O.error("Unknown SPIR-V extension: " + Token.str());
+ if (EnabledExtensions.count(NameValuePair->second))
+ return O.error(
+ "Extension cannot be allowed and disallowed at the same time: " +
+ NameValuePair->first);
- if (Token.starts_with("+")) {
- EnabledExtensions.insert(NameValuePair->second);
- } else if (EnabledExtensions.count(NameValuePair->second)) {
- if (llvm::is_contained(Tokens, "+" + ExtensionName.str()))
- return O.error(
- "Extension cannot be allowed and disallowed at the same time: " +
- ExtensionName.str());
-
- EnabledExtensions.erase(NameValuePair->second);
- }
+ Vals.erase(NameValuePair->second);
}
- Vals = std::move(EnabledExtensions);
+ Vals.insert(EnabledExtensions.cbegin(), EnabledExtensions.cend());
+
return false;
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have verified that the 143 lit tests that previously failed with expensive checks when built with gcc now pass with this patch, since the bad sort comparison has been removed so that looks good.
But I can't really review the actual code changes since I don't know this at all.
This changes the extension parsing mechanism underpinning
--spirv-extto be more explicit about what it is doing and not rely on a sort. More specifically, we partition extensions into enabled (prefixed with+) and others, and then individually handle the resulting ranges.