Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][pass] Add composite pass utility #87166

Merged
merged 11 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlir/include/mlir/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_SYMBOLDCE
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
#define GEN_PASS_DECL_TOPOLOGICALSORT
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we just replace all this list with #define GEN_PASS_DECL? Are there passes we don't want here?

Copy link
Contributor Author

@Hardcode84 Hardcode84 Apr 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are ViewOpGraph and LocationSnapshot passes which are not listed here but listed in separate header files, so if I addGEN_PASS_DECL here it fails with struct redefinition errors.

#include "mlir/Transforms/Passes.h.inc"

/// Creates an instance of the Canonicalizer pass, configured with default
Expand Down Expand Up @@ -130,6 +131,12 @@ createSymbolPrivatizePass(ArrayRef<std::string> excludeSymbols = {});
/// their producers.
std::unique_ptr<Pass> createTopologicalSortPass();

/// Create composite pass, which runs provided set of passes until fixed point
/// or maximum number of iterations reached.
std::unique_ptr<Pass> createCompositeFixedPointPass(
std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
int maxIterations = 10);

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
Expand Down
17 changes: 17 additions & 0 deletions mlir/include/mlir/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -552,4 +552,21 @@ def TopologicalSort : Pass<"topological-sort"> {
let constructor = "mlir::createTopologicalSortPass()";
}

def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
let summary = "Composite fixed point pass";
let description = [{
Composite pass runs provided set of passes until fixed point or maximum
number of iterations reached.
}];

let options = [
Option<"name", "name", "std::string", /*default=*/"\"CompositeFixedPointPass\"",
"Composite pass display name">,
Option<"pipelineStr", "pipeline", "std::string", /*default=*/"",
"Composite pass inner pipeline">,
Option<"maxIter", "max-iterations", "int", /*default=*/"10",
"Maximum number of iterations if inner pipeline">,
];
}

#endif // MLIR_TRANSFORMS_PASSES
1 change: 1 addition & 0 deletions mlir/lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_subdirectory(Utils)

add_mlir_library(MLIRTransforms
Canonicalizer.cpp
CompositePass.cpp
ControlFlowSink.cpp
CSE.cpp
GenerateRuntimeVerification.cpp
Expand Down
105 changes: 105 additions & 0 deletions mlir/lib/Transforms/CompositePass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
//===- CompositePass.cpp - Composite pass code ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// CompositePass allows to run set of passes until fixed point is reached.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Passes.h"

#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"

namespace mlir {
#define GEN_PASS_DEF_COMPOSITEFIXEDPOINTPASS
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
struct CompositeFixedPointPass final
: public impl::CompositeFixedPointPassBase<CompositeFixedPointPass> {
using CompositeFixedPointPassBase::CompositeFixedPointPassBase;

CompositeFixedPointPass(
std::string name_, llvm::function_ref<void(OpPassManager &)> populateFunc,
int maxIterations) {
name = std::move(name_);
maxIter = maxIterations;
populateFunc(dynamicPM);

llvm::raw_string_ostream os(pipelineStr);
dynamicPM.printAsTextualPipeline(os);
}

LogicalResult initializeOptions(
StringRef options,
function_ref<LogicalResult(const Twine &)> errorHandler) override {
if (failed(CompositeFixedPointPassBase::initializeOptions(options,
errorHandler)))
return failure();

if (failed(parsePassPipeline(pipelineStr, dynamicPM)))
return errorHandler("Failed to parse composite pass pipeline");

return success();
}

LogicalResult initialize(MLIRContext *context) override {
if (maxIter <= 0)
return emitError(UnknownLoc::get(context))
<< "Invalid maxIterations value: " << maxIter << "\n";

return success();
}

void getDependentDialects(DialectRegistry &registry) const override {
dynamicPM.getDependentDialects(registry);
}

void runOnOperation() override {
auto op = getOperation();
OperationFingerPrint fp(op);

int currentIter = 0;
int maxIterVal = maxIter;
while (true) {
if (failed(runPipeline(dynamicPM, op)))
return signalPassFailure();

if (currentIter++ >= maxIterVal) {
op->emitWarning("Composite pass \"" + llvm::Twine(name) +
"\"+ didn't converge in " + llvm::Twine(maxIterVal) +
" iterations");
break;
}

OperationFingerPrint newFp(op);
if (newFp == fp)
break;

fp = newFp;
}
}

protected:
llvm::StringRef getName() const override { return name; }

private:
OpPassManager dynamicPM;
};
} // namespace

std::unique_ptr<Pass> mlir::createCompositeFixedPointPass(
std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
int maxIterations) {

return std::make_unique<CompositeFixedPointPass>(std::move(name),
populateFunc, maxIterations);
}
26 changes: 26 additions & 0 deletions mlir/test/Transforms/composite-pass.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// RUN: mlir-opt %s --log-actions-to=- --test-composite-fixed-point-pass -split-input-file | FileCheck %s
// RUN: mlir-opt %s --log-actions-to=- --composite-fixed-point-pass='name=TestCompositePass pipeline=any(canonicalize,cse)' -split-input-file | FileCheck %s

// CHECK-LABEL: running `TestCompositePass`
// CHECK: running `Canonicalizer`
// CHECK: running `CSE`
// CHECK-NOT: running `Canonicalizer`
// CHECK-NOT: running `CSE`
func.func @test() {
return
}

// -----

// CHECK-LABEL: running `TestCompositePass`
// CHECK: running `Canonicalizer`
// CHECK: running `CSE`
// CHECK: running `Canonicalizer`
// CHECK: running `CSE`
// CHECK-NOT: running `Canonicalizer`
// CHECK-NOT: running `CSE`
func.func @test() {
// this constant will be canonicalized away, causing another pass iteration
%0 = arith.constant 1.5 : f32
return
}
1 change: 1 addition & 0 deletions mlir/test/lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ endif()
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestTransforms
TestCommutativityUtils.cpp
TestCompositePass.cpp
TestConstantFold.cpp
TestControlFlowSink.cpp
TestInlining.cpp
Expand Down
38 changes: 38 additions & 0 deletions mlir/test/lib/Transforms/TestCompositePass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===------ TestCompositePass.cpp --- composite test pass -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to test the composite pass utility.
//
//===----------------------------------------------------------------------===//

#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"

namespace mlir {
namespace test {
void registerTestCompositePass() {
registerPassPipeline(
"test-composite-fixed-point-pass", "Test composite pass",
[](OpPassManager &pm, StringRef optionsStr,
function_ref<LogicalResult(const Twine &)> errorHandler) {
if (!optionsStr.empty())
return failure();

pm.addPass(createCompositeFixedPointPass(
"TestCompositePass", [](OpPassManager &p) {
p.addPass(createCanonicalizerPass());
p.addPass(createCSEPass());
}));
return success();
},
[](function_ref<void(const detail::PassOptions &)>) {});
}
} // namespace test
} // namespace mlir
2 changes: 2 additions & 0 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ void registerTosaTestQuantUtilAPIPass();
void registerVectorizerTestPass();

namespace test {
void registerTestCompositePass();
void registerCommutativityUtils();
void registerConvertCallOpPass();
void registerInliner();
Expand Down Expand Up @@ -195,6 +196,7 @@ void registerTestPasses() {
registerVectorizerTestPass();
registerTosaTestQuantUtilAPIPass();

mlir::test::registerTestCompositePass();
mlir::test::registerCommutativityUtils();
mlir::test::registerConvertCallOpPass();
mlir::test::registerInliner();
Expand Down
Loading