-
Notifications
You must be signed in to change notification settings - Fork 12.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][pass] Add composite pass utility #87166
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Ivan Butygin (Hardcode84) ChangesComposite pass allows to run sequence of passes in the loop until fixed point or maximum number of iterations is reached. The usual candidates are canonicalize+CSE as canonicalize can open more opportunities for CSE and vice-versa. Full diff: https://github.com/llvm/llvm-project/pull/87166.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 11f5b23e62c663..0cf45d8d40a93d 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -130,6 +130,13 @@ createSymbolPrivatizePass(ArrayRef<std::string> excludeSymbols = {});
/// their producers.
std::unique_ptr<Pass> createTopologicalSortPass();
+/// Create composite pass, which runs selected set of passes until fixed point
+/// or maximum number of iterations reached.
+std::unique_ptr<Pass>
+createCompositePass(std::string name, std::string argument,
+ llvm::function_ref<void(OpPassManager &)> populateFunc,
+ unsigned maxIterations = 10);
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 6c32ecf8a2a2f1..90c0298fb5e46a 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(Utils)
add_mlir_library(MLIRTransforms
Canonicalizer.cpp
+ CompositePass.cpp
ControlFlowSink.cpp
CSE.cpp
GenerateRuntimeVerification.cpp
diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
new file mode 100644
index 00000000000000..3b9700f1f05176
--- /dev/null
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -0,0 +1,81 @@
+//===- CompositePass.cpp - Composite pass code ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// CompositePass allows to run set of passes until fixed point is reached.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+using namespace mlir;
+
+namespace {
+struct CompositePass final
+ : public PassWrapper<CompositePass, OperationPass<void>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CompositePass)
+
+ CompositePass(std::string name_, std::string argument_,
+ llvm::function_ref<void(OpPassManager &)> populateFunc,
+ unsigned maxIterations)
+ : name(std::move(name_)), argument(std::move(argument_)),
+ dynamicPM(std::make_shared<OpPassManager>()), maxIters(maxIterations) {
+ populateFunc(*dynamicPM);
+ }
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ dynamicPM->getDependentDialects(registry);
+ }
+
+ void runOnOperation() override {
+ auto op = getOperation();
+ OperationFingerPrint fp(op);
+
+ unsigned currentIter = 0;
+ while (true) {
+ if (failed(runPipeline(*dynamicPM, op)))
+ return signalPassFailure();
+
+ if (currentIter++ >= maxIters) {
+ op->emitWarning("Composite pass \"" + llvm::Twine(name) +
+ "\"+ didn't converge in " + llvm::Twine(maxIters) +
+ " iterations");
+ break;
+ }
+
+ OperationFingerPrint newFp(op);
+ if (newFp == fp)
+ break;
+
+ fp = newFp;
+ }
+ }
+
+protected:
+ llvm::StringRef getName() const override { return name; }
+
+ llvm::StringRef getArgument() const override { return argument; }
+
+private:
+ std::string name;
+ std::string argument;
+ std::shared_ptr<OpPassManager> dynamicPM;
+ unsigned maxIters;
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createCompositePass(
+ std::string name, std::string argument,
+ llvm::function_ref<void(OpPassManager &)> populateFunc,
+ unsigned maxIterations) {
+
+ return std::make_unique<CompositePass>(std::move(name), std::move(argument),
+ populateFunc, maxIterations);
+}
diff --git a/mlir/test/Transforms/composite-pass.mlir b/mlir/test/Transforms/composite-pass.mlir
new file mode 100644
index 00000000000000..4bf83d3a79754a
--- /dev/null
+++ b/mlir/test/Transforms/composite-pass.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s --log-actions-to=- --test-composite-pass -split-input-file | FileCheck %s
+
+// CHECK-LABEL: running `TestCompositePass`
+// CHECK: running `Canonicalizer`
+// CHECK: running `CSE`
+// CHECK-NOT: running `Canonicalizer`
+// CHECK-NOT: running `CSE`
+func.func @test() {
+ return
+}
+
+// -----
+
+// CHECK-LABEL: running `TestCompositePass`
+// CHECK: running `Canonicalizer`
+// CHECK: running `CSE`
+// CHECK: running `Canonicalizer`
+// CHECK: running `CSE`
+// CHECK-NOT: running `Canonicalizer`
+// CHECK-NOT: running `CSE`
+func.func @test() {
+// this constant will be canonicalized away, causing another pass iteration
+ %0 = arith.constant 1.5 : f32
+ return
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 2a3a8608db5442..a849b7ebd29e23 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ endif()
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestTransforms
TestCommutativityUtils.cpp
+ TestCompositePass.cpp
TestConstantFold.cpp
TestControlFlowSink.cpp
TestInlining.cpp
diff --git a/mlir/test/lib/Transforms/TestCompositePass.cpp b/mlir/test/lib/Transforms/TestCompositePass.cpp
new file mode 100644
index 00000000000000..64299685b3286e
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestCompositePass.cpp
@@ -0,0 +1,30 @@
+//===------ TestCompositePass.cpp --- composite test pass -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the composite pass utility.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace test {
+void registerTestCompositePass() {
+ registerPass([]() -> std::unique_ptr<Pass> {
+ return createCompositePass("TestCompositePass", "test-composite-pass",
+ [](OpPassManager &p) {
+ p.addPass(createCanonicalizerPass());
+ p.addPass(createCSEPass());
+ });
+ });
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 82b3881792bf3f..6ce9f3041d6f48 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -68,6 +68,7 @@ void registerTosaTestQuantUtilAPIPass();
void registerVectorizerTestPass();
namespace test {
+void registerTestCompositePass();
void registerCommutativityUtils();
void registerConvertCallOpPass();
void registerInliner();
@@ -195,6 +196,7 @@ void registerTestPasses() {
registerVectorizerTestPass();
registerTosaTestQuantUtilAPIPass();
+ mlir::test::registerTestCompositePass();
mlir::test::registerCommutativityUtils();
mlir::test::registerConvertCallOpPass();
mlir::test::registerInliner();
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, that seems like a nice addition.
@@ -130,6 +130,13 @@ createSymbolPrivatizePass(ArrayRef<std::string> excludeSymbols = {}); | |||
/// their producers. | |||
std::unique_ptr<Pass> createTopologicalSortPass(); | |||
|
|||
/// Create composite pass, which runs selected set of passes until fixed point |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I would include "fixed point" in the name, "composite" is very generic.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed to CompositeFixedPointPass
, suggestions for better naming are welcome :)
std::string pipeline; | ||
llvm::raw_string_ostream os(pipeline); | ||
dynamicPM->printAsTextualPipeline(os); | ||
os.flush(); | ||
pipelineStr = pipeline; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std::string pipeline; | |
llvm::raw_string_ostream os(pipeline); | |
dynamicPM->printAsTextualPipeline(os); | |
os.flush(); | |
pipelineStr = pipeline; | |
llvm::raw_string_ostream os(pipelineStr); | |
dynamicPM->printAsTextualPipeline(os); |
Wouldn't this do the same thing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
|
||
llvm::raw_string_ostream os(pipelineStr); | ||
dynamicPM.printAsTextualPipeline(os); | ||
os.flush(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
os.flush(); |
Nit: the flush is implicit when the stream goes out of scope the next line.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
return failure(); | ||
|
||
if (failed(parsePassPipeline(pipelineStr, dynamicPM))) { | ||
llvm::errs() << "Failed to parse composite pass pipeline\n"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We shouldn't use llvm::errs(): you have access to a MLIRContext and can emit an error here.
|
||
LogicalResult initialize(MLIRContext * /*context*/) override { | ||
if (maxIter <= 0) { | ||
llvm::errs() << "Invalid maxIterations value: " << maxIter << "\n"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(same)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Switched to emitError
here, but I don't think we have access to context in initializeOptions
LogicalResult result = pass->initializeOptions(options); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also, I can't move parsePassPipeline
to initialize
because getDependentDialects
is called before initialize
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There isn’t a getContext() on the pass itself? (On a phone, I can’t look for a suggestion now)
in any case llvm::errs isn’t in scope of any pass or library code, so we need to figure out something else.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think context even exists as this point, but the calling function provides errorHandler
callback:
function_ref<LogicalResult(const Twine &)> errorHandler) { |
We probably need to change initializeOptions
signature to accept error callback as well.
Also, default pass implementation calls parseFromString
which also diretly writes to llvm::errs()
llvm-project/mlir/lib/Pass/Pass.cpp
Line 53 in 6318dd8
LogicalResult Pass::initializeOptions(StringRef options) { |
llvm-project/mlir/lib/Pass/PassRegistry.cpp
Line 283 in 6318dd8
LogicalResult detail::PassOptions::parseFromString(StringRef options) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, default pass implementation calls parseFromString which also diretly writes to llvm::errs()
That's a problem as well, it was maybe not considered as critical because string parsing is more of a command-line integration point in general, but technically it still should work through injection (the kind of handler you're hinting at).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also agree the right fix is to pass in the errorHandler, this is quite visible here:
llvm-project/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
Lines 335 to 336 in 6318dd8
if (clientAPI != "vulkan" && clientAPI != "opencl") | |
return failure(); |
where no diagnostic is emitted at all on the failure!
There are only a couple of occurrences of Pass::initializeOptions
override, can you send a separate PR ahead of this one to add the error handler?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I will look into updating initializeOptions
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rebased and updated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG with the remaining small comments
…87289) There is no good way to report detailed errors from inside `Pass::initializeOptions` function as context may not be available at this point and writing directly to `llvm::errs()` is not composable. See #87166 (comment) * Add error handler callback to `Pass::initializeOptions` * Update `PassOptions::parseFromString` to support custom error stream instead of using `llvm::errs()` directly. * Update default `Pass::initializeOptions` implementation to propagate error string from `parseFromString` to new error handler. * Update `MapMemRefStorageClassPass` to report error details using new API.
Composite pass allows to run sequence of passes in the loop until fixed point or maximum number of iterations is reached. The usual candidates are canonicalize+CSE as canonicalize can open more opportunities for CSE and vice-versa.
89aa740
to
bb49a5f
Compare
@@ -43,6 +43,7 @@ class GreedyRewriteConfig; | |||
#define GEN_PASS_DECL_SYMBOLDCE | |||
#define GEN_PASS_DECL_SYMBOLPRIVATIZE | |||
#define GEN_PASS_DECL_TOPOLOGICALSORT | |||
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we just replace all this list with #define GEN_PASS_DECL
? Are there passes we don't want here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are ViewOpGraph
and LocationSnapshot
passes which are not listed here but listed in separate header files, so if I addGEN_PASS_DECL
here it fails with struct redefinition errors.
Composite pass allows to run sequence of passes in the loop until fixed point or maximum number of iterations is reached. The usual candidates are canonicalize+CSE as canonicalize can open more opportunities for CSE and vice-versa.