3434#include " plier/rewrites/force_inline.hpp"
3535#include " plier/rewrites/index_type_propagation.hpp"
3636#include " plier/rewrites/loop_rewrites.hpp"
37+ #include " plier/rewrites/memory_rewrites.hpp"
3738#include " plier/transforms/loop_utils.hpp"
3839
3940#include " base_pipeline.hpp"
4445
4546namespace
4647{
48+ void applyOptimizations (mlir::FuncOp op, const mlir::FrozenRewritePatternList& patterns, llvm::function_ref<mlir::LogicalResult(mlir::FuncOp)> additionalOpts = nullptr)
49+ {
50+ bool repeat = false ;
51+ do
52+ {
53+ repeat = false ;
54+ (void )mlir::applyPatternsAndFoldGreedily (op, patterns);
55+ if (mlir::succeeded (plier::applyCSE (op.getRegion (), false )))
56+ {
57+ repeat = true ;
58+ }
59+ if (mlir::succeeded (plier::promoteLoads (op.getRegion ())))
60+ {
61+ repeat = true ;
62+ }
63+ if (additionalOpts && mlir::succeeded (additionalOpts (op)))
64+ {
65+ repeat = true ;
66+ }
67+ }
68+ while (repeat);
69+ }
70+
4771enum class ArrayLayout
4872{
4973 C,
@@ -900,12 +924,12 @@ void LowerLinalgPass::runOnOperation()
900924}
901925
902926struct PostPlierToLinalgPass :
903- public mlir::PassWrapper<PostPlierToLinalgPass, mlir::OperationPass<mlir::ModuleOp> >
927+ public mlir::PassWrapper<PostPlierToLinalgPass, mlir::FunctionPass >
904928{
905- void runOnOperation () override ;
929+ void runOnFunction () override ;
906930};
907931
908- void PostPlierToLinalgPass::runOnOperation ()
932+ void PostPlierToLinalgPass::runOnFunction ()
909933{
910934 mlir::OwningRewritePatternList patterns;
911935
@@ -916,7 +940,7 @@ void PostPlierToLinalgPass::runOnOperation()
916940 SimplifyExpandDims
917941 >(&getContext ());
918942
919- ( void ) mlir::applyPatternsAndFoldGreedily ( getOperation (), std::move (patterns));
943+ applyOptimizations ( getFunction (), std::move (patterns));
920944}
921945
922946struct TensorFusionPass :
@@ -1016,12 +1040,12 @@ void RetainArgsPass::runOnFunction()
10161040}
10171041
10181042struct PostLinalgOptPass :
1019- public mlir::PassWrapper<PostLinalgOptPass, mlir::OperationPass<mlir::ModuleOp> >
1043+ public mlir::PassWrapper<PostLinalgOptPass, mlir::FunctionPass >
10201044{
1021- void runOnOperation () override ;
1045+ void runOnFunction () override ;
10221046};
10231047
1024- void PostLinalgOptPass::runOnOperation ()
1048+ void PostLinalgOptPass::runOnFunction ()
10251049{
10261050 mlir::OwningRewritePatternList patterns;
10271051
@@ -1032,37 +1056,19 @@ void PostLinalgOptPass::runOnOperation()
10321056 plier::CanonicalizeReduction
10331057 >(&context);
10341058
1035- mlir::FrozenRewritePatternList frozenPatterns (std::move (patterns));
1036-
1037- while (true )
1059+ applyOptimizations (getFunction (), std::move (patterns), [](mlir::FuncOp op)
10381060 {
1039- (void )mlir::applyPatternsAndFoldGreedily (getOperation (), frozenPatterns);
1040- bool rerun = false ;
1041- for (auto & op : getOperation ().getRegion ().front ())
1042- {
1043- if (auto func = mlir::dyn_cast<mlir::FuncOp>(op))
1044- {
1045- if (mlir::succeeded (plier::naivelyFuseParallelOps (func.getRegion ())))
1046- {
1047- rerun = true ;
1048- }
1049- }
1050- }
1051- if (!rerun)
1052- {
1053- break ;
1054- }
1055- }
1056-
1061+ return plier::naivelyFuseParallelOps (op.getRegion ());
1062+ });
10571063}
10581064
10591065struct PromoteParallelPass :
1060- public mlir::PassWrapper<PromoteParallelPass, mlir::OperationPass<mlir::ModuleOp> >
1066+ public mlir::PassWrapper<PromoteParallelPass, mlir::FunctionPass >
10611067{
1062- void runOnOperation () override ;
1068+ void runOnFunction () override ;
10631069};
10641070
1065- void PromoteParallelPass::runOnOperation ()
1071+ void PromoteParallelPass::runOnFunction ()
10661072{
10671073 mlir::OwningRewritePatternList patterns;
10681074
@@ -1074,13 +1080,13 @@ void PromoteParallelPass::runOnOperation()
10741080 plier::PromoteToParallel // TODO
10751081 >(&context);
10761082
1077- ( void ) mlir::applyPatternsAndFoldGreedily ( getOperation (), std::move (patterns));
1083+ applyOptimizations ( getFunction (), std::move (patterns));
10781084}
10791085
10801086void populate_plier_to_linalg_gen_pipeline (mlir::OpPassManager& pm)
10811087{
10821088 pm.addPass (std::make_unique<PlierToLinalgPass>());
1083- pm.addPass (std::make_unique<PostPlierToLinalgPass>());
1089+ pm.addNestedPass <mlir::FuncOp> (std::make_unique<PostPlierToLinalgPass>());
10841090 pm.addPass (mlir::createSymbolDCEPass ());
10851091}
10861092
@@ -1105,9 +1111,9 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm)
11051111 pm.addPass (mlir::createCopyRemovalPass ());
11061112
11071113 pm.addPass (std::make_unique<LowerLinalgPass>());
1108- pm.addPass (std::make_unique<PostLinalgOptPass>());
1114+ pm.addNestedPass <mlir::FuncOp> (std::make_unique<PostLinalgOptPass>());
11091115 pm.addPass (mlir::createSymbolDCEPass ());
1110- pm.addPass (std::make_unique<PromoteParallelPass>());
1116+ pm.addNestedPass <mlir::FuncOp> (std::make_unique<PromoteParallelPass>());
11111117}
11121118}
11131119
0 commit comments