Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][pass] Add composite pass utility #87166

Merged
merged 11 commits into from
Apr 2, 2024
Merged

Conversation

Hardcode84
Copy link
Contributor

Composite pass allows to run sequence of passes in the loop until fixed point or maximum number of iterations is reached. The usual candidates are canonicalize+CSE as canonicalize can open more opportunities for CSE and vice-versa.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Mar 30, 2024
@llvmbot
Copy link
Member

llvmbot commented Mar 30, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Ivan Butygin (Hardcode84)

Changes

Composite pass allows to run sequence of passes in the loop until fixed point or maximum number of iterations is reached. The usual candidates are canonicalize+CSE as canonicalize can open more opportunities for CSE and vice-versa.


Full diff: https://github.com/llvm/llvm-project/pull/87166.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Transforms/Passes.h (+7)
  • (modified) mlir/lib/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Transforms/CompositePass.cpp (+81)
  • (added) mlir/test/Transforms/composite-pass.mlir (+25)
  • (modified) mlir/test/lib/Transforms/CMakeLists.txt (+1)
  • (added) mlir/test/lib/Transforms/TestCompositePass.cpp (+30)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 11f5b23e62c663..0cf45d8d40a93d 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -130,6 +130,13 @@ createSymbolPrivatizePass(ArrayRef<std::string> excludeSymbols = {});
 /// their producers.
 std::unique_ptr<Pass> createTopologicalSortPass();
 
+/// Create composite pass, which runs selected set of passes until fixed point
+/// or maximum number of iterations reached.
+std::unique_ptr<Pass>
+createCompositePass(std::string name, std::string argument,
+                    llvm::function_ref<void(OpPassManager &)> populateFunc,
+                    unsigned maxIterations = 10);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 6c32ecf8a2a2f1..90c0298fb5e46a 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(Utils)
 
 add_mlir_library(MLIRTransforms
   Canonicalizer.cpp
+  CompositePass.cpp
   ControlFlowSink.cpp
   CSE.cpp
   GenerateRuntimeVerification.cpp
diff --git a/mlir/lib/Transforms/CompositePass.cpp b/mlir/lib/Transforms/CompositePass.cpp
new file mode 100644
index 00000000000000..3b9700f1f05176
--- /dev/null
+++ b/mlir/lib/Transforms/CompositePass.cpp
@@ -0,0 +1,81 @@
+//===- CompositePass.cpp - Composite pass code ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// CompositePass allows to run set of passes until fixed point is reached.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+using namespace mlir;
+
+namespace {
+struct CompositePass final
+    : public PassWrapper<CompositePass, OperationPass<void>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CompositePass)
+
+  CompositePass(std::string name_, std::string argument_,
+                llvm::function_ref<void(OpPassManager &)> populateFunc,
+                unsigned maxIterations)
+      : name(std::move(name_)), argument(std::move(argument_)),
+        dynamicPM(std::make_shared<OpPassManager>()), maxIters(maxIterations) {
+    populateFunc(*dynamicPM);
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    dynamicPM->getDependentDialects(registry);
+  }
+
+  void runOnOperation() override {
+    auto op = getOperation();
+    OperationFingerPrint fp(op);
+
+    unsigned currentIter = 0;
+    while (true) {
+      if (failed(runPipeline(*dynamicPM, op)))
+        return signalPassFailure();
+
+      if (currentIter++ >= maxIters) {
+        op->emitWarning("Composite pass \"" + llvm::Twine(name) +
+                        "\"+ didn't converge in " + llvm::Twine(maxIters) +
+                        " iterations");
+        break;
+      }
+
+      OperationFingerPrint newFp(op);
+      if (newFp == fp)
+        break;
+
+      fp = newFp;
+    }
+  }
+
+protected:
+  llvm::StringRef getName() const override { return name; }
+
+  llvm::StringRef getArgument() const override { return argument; }
+
+private:
+  std::string name;
+  std::string argument;
+  std::shared_ptr<OpPassManager> dynamicPM;
+  unsigned maxIters;
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createCompositePass(
+    std::string name, std::string argument,
+    llvm::function_ref<void(OpPassManager &)> populateFunc,
+    unsigned maxIterations) {
+
+  return std::make_unique<CompositePass>(std::move(name), std::move(argument),
+                                         populateFunc, maxIterations);
+}
diff --git a/mlir/test/Transforms/composite-pass.mlir b/mlir/test/Transforms/composite-pass.mlir
new file mode 100644
index 00000000000000..4bf83d3a79754a
--- /dev/null
+++ b/mlir/test/Transforms/composite-pass.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s --log-actions-to=- --test-composite-pass -split-input-file | FileCheck %s
+
+// CHECK-LABEL: running `TestCompositePass`
+//       CHECK: running `Canonicalizer`
+//       CHECK: running `CSE`
+//   CHECK-NOT: running `Canonicalizer`
+//   CHECK-NOT: running `CSE`
+func.func @test() {
+  return
+}
+
+// -----
+
+// CHECK-LABEL: running `TestCompositePass`
+//       CHECK: running `Canonicalizer`
+//       CHECK: running `CSE`
+//       CHECK: running `Canonicalizer`
+//       CHECK: running `CSE`
+//   CHECK-NOT: running `Canonicalizer`
+//   CHECK-NOT: running `CSE`
+func.func @test() {
+// this constant will be canonicalized away, causing another pass iteration
+  %0 = arith.constant 1.5 : f32
+  return
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 2a3a8608db5442..a849b7ebd29e23 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ endif()
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestTransforms
   TestCommutativityUtils.cpp
+  TestCompositePass.cpp
   TestConstantFold.cpp
   TestControlFlowSink.cpp
   TestInlining.cpp
diff --git a/mlir/test/lib/Transforms/TestCompositePass.cpp b/mlir/test/lib/Transforms/TestCompositePass.cpp
new file mode 100644
index 00000000000000..64299685b3286e
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestCompositePass.cpp
@@ -0,0 +1,30 @@
+//===------ TestCompositePass.cpp --- composite test pass -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the composite pass utility.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace test {
+void registerTestCompositePass() {
+  registerPass([]() -> std::unique_ptr<Pass> {
+    return createCompositePass("TestCompositePass", "test-composite-pass",
+                               [](OpPassManager &p) {
+                                 p.addPass(createCanonicalizerPass());
+                                 p.addPass(createCSEPass());
+                               });
+  });
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 82b3881792bf3f..6ce9f3041d6f48 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -68,6 +68,7 @@ void registerTosaTestQuantUtilAPIPass();
 void registerVectorizerTestPass();
 
 namespace test {
+void registerTestCompositePass();
 void registerCommutativityUtils();
 void registerConvertCallOpPass();
 void registerInliner();
@@ -195,6 +196,7 @@ void registerTestPasses() {
   registerVectorizerTestPass();
   registerTosaTestQuantUtilAPIPass();
 
+  mlir::test::registerTestCompositePass();
   mlir::test::registerCommutativityUtils();
   mlir::test::registerConvertCallOpPass();
   mlir::test::registerInliner();

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, that seems like a nice addition.

@@ -130,6 +130,13 @@ createSymbolPrivatizePass(ArrayRef<std::string> excludeSymbols = {});
/// their producers.
std::unique_ptr<Pass> createTopologicalSortPass();

/// Create composite pass, which runs selected set of passes until fixed point
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I would include "fixed point" in the name, "composite" is very generic.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to CompositeFixedPointPass, suggestions for better naming are welcome :)

mlir/lib/Transforms/CompositePass.cpp Outdated Show resolved Hide resolved
Comment on lines 37 to 41
std::string pipeline;
llvm::raw_string_ostream os(pipeline);
dynamicPM->printAsTextualPipeline(os);
os.flush();
pipelineStr = pipeline;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::string pipeline;
llvm::raw_string_ostream os(pipeline);
dynamicPM->printAsTextualPipeline(os);
os.flush();
pipelineStr = pipeline;
llvm::raw_string_ostream os(pipelineStr);
dynamicPM->printAsTextualPipeline(os);

Wouldn't this do the same thing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed


llvm::raw_string_ostream os(pipelineStr);
dynamicPM.printAsTextualPipeline(os);
os.flush();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
os.flush();

Nit: the flush is implicit when the stream goes out of scope the next line.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

return failure();

if (failed(parsePassPipeline(pipelineStr, dynamicPM))) {
llvm::errs() << "Failed to parse composite pass pipeline\n";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't use llvm::errs(): you have access to a MLIRContext and can emit an error here.


LogicalResult initialize(MLIRContext * /*context*/) override {
if (maxIter <= 0) {
llvm::errs() << "Invalid maxIterations value: " << maxIter << "\n";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(same)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switched to emitError here, but I don't think we have access to context in initializeOptions

LogicalResult result = pass->initializeOptions(options);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, I can't move parsePassPipeline to initialize because getDependentDialects is called before initialize

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There isn’t a getContext() on the pass itself? (On a phone, I can’t look for a suggestion now)

in any case llvm::errs isn’t in scope of any pass or library code, so we need to figure out something else.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think context even exists as this point, but the calling function provides errorHandler callback:

function_ref<LogicalResult(const Twine &)> errorHandler) {

We probably need to change initializeOptions signature to accept error callback as well.

Also, default pass implementation calls parseFromString which also diretly writes to llvm::errs()

LogicalResult Pass::initializeOptions(StringRef options) {

LogicalResult detail::PassOptions::parseFromString(StringRef options) {

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, default pass implementation calls parseFromString which also diretly writes to llvm::errs()

That's a problem as well, it was maybe not considered as critical because string parsing is more of a command-line integration point in general, but technically it still should work through injection (the kind of handler you're hinting at).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also agree the right fix is to pass in the errorHandler, this is quite visible here:

if (clientAPI != "vulkan" && clientAPI != "opencl")
return failure();

where no diagnostic is emitted at all on the failure!

There are only a couple of occurrences of Pass::initializeOptions override, can you send a separate PR ahead of this one to add the error handler?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I will look into updating initializeOptions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rebased and updated

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG with the remaining small comments

Hardcode84 added a commit that referenced this pull request Apr 1, 2024
…87289)

There is no good way to report detailed errors from inside
`Pass::initializeOptions` function as context may not be available at
this point and writing directly to `llvm::errs()` is not composable.

See
#87166 (comment)

* Add error handler callback to `Pass::initializeOptions`
* Update `PassOptions::parseFromString` to support custom error stream
instead of using `llvm::errs()` directly.
* Update default `Pass::initializeOptions` implementation to propagate
error string from `parseFromString` to new error handler.
* Update `MapMemRefStorageClassPass` to report error details using new
API.
Composite pass allows to run sequence of passes in the loop until fixed point or maximum number of iterations is reached.
The usual candidates are canonicalize+CSE as canonicalize can open more opportunities for CSE and vice-versa.
@@ -43,6 +43,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_SYMBOLDCE
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
#define GEN_PASS_DECL_TOPOLOGICALSORT
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we just replace all this list with #define GEN_PASS_DECL? Are there passes we don't want here?

Copy link
Contributor Author

@Hardcode84 Hardcode84 Apr 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are ViewOpGraph and LocationSnapshot passes which are not listed here but listed in separate header files, so if I addGEN_PASS_DECL here it fails with struct redefinition errors.

@Hardcode84 Hardcode84 merged commit 5b66b6a into llvm:main Apr 2, 2024
4 checks passed
@Hardcode84 Hardcode84 deleted the composite-pass branch April 2, 2024 10:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants