This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit 3a0c708

Author: Ivan Butygin
Refactor linalg optimizations flow (#210)
1 parent: a9cc160

File tree

1 file changed (+41, -35 lines)

mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp

Lines changed: 41 additions & 35 deletions
@@ -34,6 +34,7 @@
 #include "plier/rewrites/force_inline.hpp"
 #include "plier/rewrites/index_type_propagation.hpp"
 #include "plier/rewrites/loop_rewrites.hpp"
+#include "plier/rewrites/memory_rewrites.hpp"
 #include "plier/transforms/loop_utils.hpp"

 #include "base_pipeline.hpp"
@@ -44,6 +45,29 @@

 namespace
 {
+void applyOptimizations(mlir::FuncOp op, const mlir::FrozenRewritePatternList& patterns, llvm::function_ref<mlir::LogicalResult(mlir::FuncOp)> additionalOpts = nullptr)
+{
+    bool repeat = false;
+    do
+    {
+        repeat = false;
+        (void)mlir::applyPatternsAndFoldGreedily(op, patterns);
+        if (mlir::succeeded(plier::applyCSE(op.getRegion(), false)))
+        {
+            repeat = true;
+        }
+        if (mlir::succeeded(plier::promoteLoads(op.getRegion())))
+        {
+            repeat = true;
+        }
+        if (additionalOpts && mlir::succeeded(additionalOpts(op)))
+        {
+            repeat = true;
+        }
+    }
+    while(repeat);
+}
+
 enum class ArrayLayout
 {
     C,
@@ -900,12 +924,12 @@ void LowerLinalgPass::runOnOperation()
 }

 struct PostPlierToLinalgPass :
-    public mlir::PassWrapper<PostPlierToLinalgPass, mlir::OperationPass<mlir::ModuleOp>>
+    public mlir::PassWrapper<PostPlierToLinalgPass, mlir::FunctionPass>
 {
-    void runOnOperation() override;
+    void runOnFunction() override;
 };

-void PostPlierToLinalgPass::runOnOperation()
+void PostPlierToLinalgPass::runOnFunction()
 {
     mlir::OwningRewritePatternList patterns;

@@ -916,7 +940,7 @@ void PostPlierToLinalgPass::runOnOperation()
     SimplifyExpandDims
     >(&getContext());

-    (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    applyOptimizations(getFunction(), std::move(patterns));
 }

 struct TensorFusionPass :
@@ -1016,12 +1040,12 @@ void RetainArgsPass::runOnFunction()
 }

 struct PostLinalgOptPass :
-    public mlir::PassWrapper<PostLinalgOptPass, mlir::OperationPass<mlir::ModuleOp>>
+    public mlir::PassWrapper<PostLinalgOptPass, mlir::FunctionPass>
 {
-    void runOnOperation() override;
+    void runOnFunction() override;
 };

-void PostLinalgOptPass::runOnOperation()
+void PostLinalgOptPass::runOnFunction()
 {
     mlir::OwningRewritePatternList patterns;

@@ -1032,37 +1056,19 @@ void PostLinalgOptPass::runOnOperation()
     plier::CanonicalizeReduction
     >(&context);

-    mlir::FrozenRewritePatternList frozenPatterns(std::move(patterns));
-
-    while (true)
+    applyOptimizations(getFunction(), std::move(patterns), [](mlir::FuncOp op)
     {
-        (void)mlir::applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
-        bool rerun = false;
-        for (auto& op : getOperation().getRegion().front())
-        {
-            if (auto func = mlir::dyn_cast<mlir::FuncOp>(op))
-            {
-                if (mlir::succeeded(plier::naivelyFuseParallelOps(func.getRegion())))
-                {
-                    rerun = true;
-                }
-            }
-        }
-        if (!rerun)
-        {
-            break;
-        }
-    }
-
+        return plier::naivelyFuseParallelOps(op.getRegion());
+    });
 }

 struct PromoteParallelPass :
-    public mlir::PassWrapper<PromoteParallelPass, mlir::OperationPass<mlir::ModuleOp>>
+    public mlir::PassWrapper<PromoteParallelPass, mlir::FunctionPass>
 {
-    void runOnOperation() override;
+    void runOnFunction() override;
 };

-void PromoteParallelPass::runOnOperation()
+void PromoteParallelPass::runOnFunction()
 {
     mlir::OwningRewritePatternList patterns;

@@ -1074,13 +1080,13 @@ void PromoteParallelPass::runOnOperation()
     plier::PromoteToParallel // TODO
     >(&context);

-    (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    applyOptimizations(getFunction(), std::move(patterns));
 }

 void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm)
 {
     pm.addPass(std::make_unique<PlierToLinalgPass>());
-    pm.addPass(std::make_unique<PostPlierToLinalgPass>());
+    pm.addNestedPass<mlir::FuncOp>(std::make_unique<PostPlierToLinalgPass>());
     pm.addPass(mlir::createSymbolDCEPass());
 }

@@ -1105,9 +1111,9 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm)
     pm.addPass(mlir::createCopyRemovalPass());

     pm.addPass(std::make_unique<LowerLinalgPass>());
-    pm.addPass(std::make_unique<PostLinalgOptPass>());
+    pm.addNestedPass<mlir::FuncOp>(std::make_unique<PostLinalgOptPass>());
     pm.addPass(mlir::createSymbolDCEPass());
-    pm.addPass(std::make_unique<PromoteParallelPass>());
+    pm.addNestedPass<mlir::FuncOp>(std::make_unique<PromoteParallelPass>());
 }
 }
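For context on the refactor: the new applyOptimizations helper is a fixed-point driver. Each sweep runs the greedy pattern rewriter, plier::applyCSE, plier::promoteLoads, and an optional per-pass callback, and the whole sweep repeats whenever any step reports a change, since one rewrite can expose opportunities for the others. The sketch below shows the same driver shape in isolation; it is a minimal illustration only, and runToFixedPoint, Optimization, and the toy callbacks are hypothetical names, not APIs from this repository or from MLIR.

#include <functional>
#include <iostream>
#include <vector>

// Hypothetical stand-in for mlir::LogicalResult: true means "this step changed something".
using Optimization = std::function<bool()>;

// Minimal sketch of the fixed-point loop used by applyOptimizations above:
// keep sweeping the whole list until one full sweep makes no changes.
void runToFixedPoint(const std::vector<Optimization>& opts)
{
    bool repeat = false;
    do
    {
        repeat = false;
        for (const auto& opt : opts)
        {
            if (opt())         // analogous to mlir::succeeded(...)
            {
                repeat = true; // a change may enable further rewrites, so sweep again
            }
        }
    } while (repeat);
}

int main()
{
    // Toy "optimizations": each makes progress a few times, then stabilizes.
    int a = 3, b = 2;
    runToFixedPoint({
        [&] { return a > 0 ? (--a, true) : false; },
        [&] { return b > 0 ? (--b, true) : false; },
    });
    std::cout << a << " " << b << "\n"; // prints "0 0": both callbacks reached a fixed point
    return 0;
}

In the commit itself this repeat logic now runs per function: each affected pass becomes a mlir::FunctionPass scheduled with pm.addNestedPass<mlir::FuncOp>, so it converges locally instead of looping over the whole module as the old PostLinalgOptPass did.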
0 commit comments