1414#include " flang/Optimizer/Dialect/FIRDialect.h"
1515#include " flang/Optimizer/Dialect/FIROps.h"
1616#include " flang/Optimizer/Dialect/FIRType.h"
17- #include " flang/Optimizer/Transforms/Passes.h"
1817#include " flang/Optimizer/HLFIR/Passes.h"
1918#include " flang/Optimizer/OpenMP/Utils.h"
19+ #include " flang/Optimizer/Transforms/Passes.h"
2020#include " mlir/Analysis/SliceAnalysis.h"
2121#include " mlir/Dialect/OpenMP/OpenMPDialect.h"
2222#include " mlir/IR/Builders.h"
2323#include " mlir/IR/Value.h"
2424#include " mlir/Transforms/DialectConversion.h"
2525#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
26+ #include " mlir/Transforms/RegionUtils.h"
2627#include < mlir/Dialect/Arith/IR/Arith.h>
2728#include < mlir/Dialect/LLVMIR/LLVMTypes.h>
2829#include < mlir/Dialect/Utils/IndexingUtils.h>
3334#include < mlir/IR/PatternMatch.h>
3435#include < mlir/Interfaces/SideEffectInterfaces.h>
3536#include < mlir/Support/LLVM.h>
36- #include " mlir/Transforms/RegionUtils.h"
3737#include < optional>
3838#include < variant>
3939
@@ -66,30 +66,30 @@ static T getPerfectlyNested(Operation *op) {
6666// / This is the single source of truth about whether we should parallelize an
6767// / operation nested in an omp.workdistribute region.
6868static bool shouldParallelize (Operation *op) {
69- // Currently we cannot parallelize operations with results that have uses
70- if (llvm::any_of (op->getResults (),
71- [](OpResult v) -> bool { return !v.use_empty (); }))
69+ // Currently we cannot parallelize operations with results that have uses
70+ if (llvm::any_of (op->getResults (),
71+ [](OpResult v) -> bool { return !v.use_empty (); }))
72+ return false ;
73+ // We will parallelize unordered loops - these come from array syntax
74+ if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
75+ auto unordered = loop.getUnordered ();
76+ if (!unordered)
7277 return false ;
73- // We will parallelize unordered loops - these come from array syntax
74- if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
75- auto unordered = loop.getUnordered ();
76- if (!unordered)
77- return false ;
78- return *unordered;
79- }
80- if (auto callOp = dyn_cast<fir::CallOp>(op)) {
81- auto callee = callOp.getCallee ();
82- if (!callee)
83- return false ;
84- auto *func = op->getParentOfType <ModuleOp>().lookupSymbol (*callee);
85- // TODO need to insert a check here whether it is a call we can actually
86- // parallelize currently
87- if (func->getAttr (fir::FIROpsDialect::getFirRuntimeAttrName ()))
88- return true ;
78+ return *unordered;
79+ }
80+ if (auto callOp = dyn_cast<fir::CallOp>(op)) {
81+ auto callee = callOp.getCallee ();
82+ if (!callee)
8983 return false ;
90- }
91- // We cannot parallise anything else
84+ auto *func = op->getParentOfType <ModuleOp>().lookupSymbol (*callee);
85+ // TODO need to insert a check here whether it is a call we can actually
86+ // parallelize currently
87+ if (func->getAttr (fir::FIROpsDialect::getFirRuntimeAttrName ()))
88+ return true ;
9289 return false ;
90+ }
91+ // We cannot parallise anything else
92+ return false ;
9393}
9494
9595// / If B() and D() are parallelizable,
@@ -120,12 +120,10 @@ static bool shouldParallelize(Operation *op) {
120120// / }
121121// / E()
122122
123- struct FissionWorkdistribute
124- : public OpRewritePattern<omp::WorkdistributeOp> {
123+ struct FissionWorkdistribute : public OpRewritePattern <omp::WorkdistributeOp> {
125124 using OpRewritePattern::OpRewritePattern;
126- LogicalResult
127- matchAndRewrite (omp::WorkdistributeOp workdistribute,
128- PatternRewriter &rewriter) const override {
125+ LogicalResult matchAndRewrite (omp::WorkdistributeOp workdistribute,
126+ PatternRewriter &rewriter) const override {
129127 auto loc = workdistribute->getLoc ();
130128 auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp ());
131129 if (!teams) {
@@ -185,7 +183,7 @@ struct FissionWorkdistribute
185183 auto newWorkdistribute = rewriter.create <omp::WorkdistributeOp>(loc);
186184 rewriter.create <omp::TerminatorOp>(loc);
187185 rewriter.createBlock (&newWorkdistribute.getRegion (),
188- newWorkdistribute.getRegion ().begin (), {}, {});
186+ newWorkdistribute.getRegion ().begin (), {}, {});
189187 auto *cloned = rewriter.clone (*parallelize);
190188 rewriter.replaceOp (parallelize, cloned);
191189 rewriter.create <omp::TerminatorOp>(loc);
@@ -197,8 +195,7 @@ struct FissionWorkdistribute
197195};
198196
199197static void
200- genLoopNestClauseOps (mlir::Location loc,
201- mlir::PatternRewriter &rewriter,
198+ genLoopNestClauseOps (mlir::Location loc, mlir::PatternRewriter &rewriter,
202199 fir::DoLoopOp loop,
203200 mlir::omp::LoopNestOperands &loopNestClauseOps) {
204201 assert (loopNestClauseOps.loopLowerBounds .empty () &&
@@ -209,10 +206,8 @@ genLoopNestClauseOps(mlir::Location loc,
209206 loopNestClauseOps.loopInclusive = rewriter.getUnitAttr ();
210207}
211208
212- static void
213- genWsLoopOp (mlir::PatternRewriter &rewriter,
214- fir::DoLoopOp doLoop,
215- const mlir::omp::LoopNestOperands &clauseOps) {
209+ static void genWsLoopOp (mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop,
210+ const mlir::omp::LoopNestOperands &clauseOps) {
216211
217212 auto wsloopOp = rewriter.create <mlir::omp::WsloopOp>(doLoop.getLoc ());
218213 rewriter.createBlock (&wsloopOp.getRegion ());
@@ -236,7 +231,7 @@ genWsLoopOp(mlir::PatternRewriter &rewriter,
236231 return ;
237232}
238233
239- // / If fir.do_loop id present inside teams workdistribute
234+ // / If fir.do_loop is present inside teams workdistribute
240235// /
241236// / omp.teams {
242237// / omp.workdistribute {
@@ -246,7 +241,7 @@ genWsLoopOp(mlir::PatternRewriter &rewriter,
246241// / }
247242// / }
248243// /
249- // / Then, its lowered to
244+ // / Then, its lowered to
250245// /
251246// / omp.teams {
252247// / omp.workdistribute {
@@ -277,7 +272,8 @@ struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {
277272
278273 auto parallelOp = rewriter.create <mlir::omp::ParallelOp>(teamsLoc);
279274 rewriter.createBlock (¶llelOp.getRegion ());
280- rewriter.setInsertionPoint (rewriter.create <mlir::omp::TerminatorOp>(doLoop.getLoc ()));
275+ rewriter.setInsertionPoint (
276+ rewriter.create <mlir::omp::TerminatorOp>(doLoop.getLoc ()));
281277
282278 mlir::omp::LoopNestOperands loopNestClauseOps;
283279 genLoopNestClauseOps (doLoop.getLoc (), rewriter, doLoop,
@@ -292,7 +288,6 @@ struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {
292288 }
293289};
294290
295-
296291// / If A() and B () are present inside teams workdistribute
297292// /
298293// / omp.teams {
@@ -311,17 +306,17 @@ struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {
311306struct TeamsWorkdistributeToSingle : public OpRewritePattern <omp::TeamsOp> {
312307 using OpRewritePattern::OpRewritePattern;
313308 LogicalResult matchAndRewrite (omp::TeamsOp teamsOp,
314- PatternRewriter &rewriter) const override {
315- auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
316- if (!workdistributeOp) {
317- LLVM_DEBUG (llvm::dbgs () << DEBUG_TYPE << " No workdistribute nested\n " );
318- return failure ();
319- }
320- Block *workdistributeBlock = &workdistributeOp.getRegion ().front ();
321- rewriter.eraseOp (workdistributeBlock->getTerminator ());
322- rewriter.inlineBlockBefore (workdistributeBlock, teamsOp);
323- rewriter.eraseOp (teamsOp);
324- return success ();
309+ PatternRewriter &rewriter) const override {
310+ auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
311+ if (!workdistributeOp) {
312+ LLVM_DEBUG (llvm::dbgs () << DEBUG_TYPE << " No workdistribute nested\n " );
313+ return failure ();
314+ }
315+ Block *workdistributeBlock = &workdistributeOp.getRegion ().front ();
316+ rewriter.eraseOp (workdistributeBlock->getTerminator ());
317+ rewriter.inlineBlockBefore (workdistributeBlock, teamsOp);
318+ rewriter.eraseOp (teamsOp);
319+ return success ();
325320 }
326321};
327322
@@ -332,26 +327,27 @@ class LowerWorkdistributePass
332327 MLIRContext &context = getContext ();
333328 GreedyRewriteConfig config;
334329 // prevent the pattern driver form merging blocks
335- config.setRegionSimplificationLevel (
336- GreedySimplifyRegionLevel::Disabled);
337-
330+ config.setRegionSimplificationLevel (GreedySimplifyRegionLevel::Disabled);
331+
338332 Operation *op = getOperation ();
339333 {
340334 RewritePatternSet patterns (&context);
341- patterns.insert <FissionWorkdistribute, TeamsWorkdistributeLowering>(&context);
335+ patterns.insert <FissionWorkdistribute, TeamsWorkdistributeLowering>(
336+ &context);
342337 if (failed (applyPatternsGreedily (op, std::move (patterns), config))) {
343338 emitError (op->getLoc (), DEBUG_TYPE " pass failed\n " );
344339 signalPassFailure ();
345340 }
346341 }
347342 {
348343 RewritePatternSet patterns (&context);
349- patterns.insert <TeamsWorkdistributeLowering, TeamsWorkdistributeToSingle>(&context);
344+ patterns.insert <TeamsWorkdistributeLowering, TeamsWorkdistributeToSingle>(
345+ &context);
350346 if (failed (applyPatternsGreedily (op, std::move (patterns), config))) {
351347 emitError (op->getLoc (), DEBUG_TYPE " pass failed\n " );
352348 signalPassFailure ();
353349 }
354350 }
355351 }
356352};
357- }
353+ } // namespace
0 commit comments