1212// Fortran array statements are lowered to fir as fir.do_loop unordered.
1313// lower-workdistribute pass works mainly on identifying fir.do_loop unordered
1414// that is nested in target{teams{workdistribute{fir.do_loop unordered}}} and
15- // lowers it to target{teams{parallel{wsloop{loop_nest}}}}.
15+ // lowers it to target{teams{parallel{distribute{ wsloop{loop_nest} }}}}.
1616// It hoists all the other ops outside target region.
1717// Replaces heap allocation on target with omp.target_allocmem and
1818// deallocation with omp.target_freemem from host. Also replaces
19- // runtime function "Assign" with omp.target_memcpy .
19+ // runtime function "Assign" with omp_target_memcpy .
2020//
2121// ===----------------------------------------------------------------------===//
2222
@@ -319,13 +319,14 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
319319// Then, its lowered to
320320//
321321// omp.teams {
322- // omp.parallel {
323- // omp.distribute {
324- // omp.wsloop {
325- // omp.loop_nest
326- // ...
327- // }
328- // }
322+ // omp.parallel {
323+ // omp.distribute {
324+ // omp.wsloop {
325+ // omp.loop_nest
326+ // ...
327+ // }
328+ // }
329+ // }
329330// }
330331// }
331332
@@ -345,6 +346,7 @@ WorkdistributeDoLower(omp::WorkdistributeOp workdistribute,
345346 targetOpsToProcess.insert (targetOp);
346347 }
347348 }
349+ // Generate the nested parallel, distribute, wsloop and loop_nest ops.
348350 genParallelOp (wdLoc, rewriter, true );
349351 genDistributeOp (wdLoc, rewriter, true );
350352 mlir::omp::LoopNestOperands loopNestClauseOps;
@@ -584,6 +586,7 @@ WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
584586 }
585587 }
586588 }
589+ // Erase the runtime calls that have been replaced.
587590 for (auto *op : opsToErase) {
588591 op->erase ();
589592 }
@@ -772,6 +775,7 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
772775 Value alloc;
773776 Type allocType;
774777 auto llvmPtrTy = LLVM::LLVMPointerType::get (&ctx);
778+ // Get the appropriate type for allocation
775779 if (isPtr (ty)) {
776780 Type intTy = rewriter.getI32Type ();
777781 auto one = rewriter.create <LLVM::ConstantOp>(loc, intTy, 1 );
@@ -782,6 +786,7 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
782786 allocType = ty;
783787 alloc = rewriter.create <fir::AllocaOp>(loc, allocType);
784788 }
789+ // Lambda to create mapinfo ops
785790 auto getMapInfo = [&](uint64_t mappingFlags, const char *name) {
786791 return rewriter.create <omp::MapInfoOp>(
787792 loc, alloc.getType (), alloc, TypeAttr::get (allocType),
@@ -796,6 +801,7 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
796801 /* mapperId=*/ mlir::FlatSymbolRefAttr (),
797802 /* name=*/ rewriter.getStringAttr (name), rewriter.getBoolAttr (false ));
798803 };
804+ // Create mapinfo ops.
799805 uint64_t mapFrom =
800806 static_cast <std::underlying_type_t <llvm::omp::OpenMPOffloadMappingFlags>>(
801807 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
@@ -847,14 +853,17 @@ static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp,
847853 SetVector<Operation *> &toCache,
848854 SetVector<Operation *> &toRecompute) {
849855 Operation *op = v.getDefiningOp ();
856+ // If v is a block argument, it must be from the targetOp.
850857 if (!op) {
851858 assert (cast<BlockArgument>(v).getOwner ()->getParentOp () == targetOp);
852859 return ;
853860 }
861+ // If the op is in the nonRecomputable set, add it to toCache and return.
854862 if (nonRecomputable.contains (op)) {
855863 toCache.insert (op);
856864 return ;
857865 }
866+ // Add the op to toRecompute.
858867 toRecompute.insert (op);
859868 for (auto opr : op->getOperands ())
860869 collectNonRecomputableDeps (opr, targetOp, nonRecomputable, toCache,
@@ -939,6 +948,8 @@ static void reloadCacheAndRecompute(
939948 Value newArg =
940949 newTargetBlock->getArgument (hostEvalVarsSize + originalMapVarsSize + i);
941950 Value restored;
951+ // If the original value is a pointer or reference, load and convert if
952+ // necessary.
942953 if (isPtr (original.getType ())) {
943954 restored = rewriter.create <LLVM::LoadOp>(loc, llvmPtrTy, newArg);
944955 if (!isa<LLVM::LLVMPointerType>(original.getType ()))
@@ -967,6 +978,7 @@ static mlir::omp::LoopNestOp getLoopNestFromTeams(mlir::omp::TeamsOp teamsOp) {
967978 return nullptr ;
968979 // Find parallel op inside teams
969980 mlir::omp::ParallelOp parallelOp = nullptr ;
981+ // Look for the parallel op in the teams region
970982 for (auto &op : teamsOp.getRegion ().front ()) {
971983 if (auto parallel = dyn_cast<mlir::omp::ParallelOp>(op)) {
972984 parallelOp = parallel;
@@ -1218,6 +1230,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
12181230 assert (targetBlock == &targetOp.getRegion ().back ());
12191231 IRMapping mapping;
12201232
1233+ // Get the parent target_data op
12211234 auto targetDataOp = cast<omp::TargetDataOp>(targetOp->getParentOp ());
12221235 if (!targetDataOp) {
12231236 llvm_unreachable (" Expected target op to be inside target_data op" );
@@ -1255,6 +1268,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
12551268 SmallVector<Operation *> opsToReplace;
12561269 Value device = targetOp.getDevice ();
12571270
1271+ // If device is not specified, default to device 0.
12581272 if (!device) {
12591273 device = genI32Constant (targetOp.getLoc (), rewriter, 0 );
12601274 }
@@ -1508,15 +1522,12 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
15081522 SmallVector<Value> isolatedHostEvalVars{targetOp.getHostEvalVars ()};
15091523 // update the hostEvalVars of isolatedTargetOp
15101524 if (!hostEvalVars.lbs .empty () && !isTargetDevice) {
1511- for (size_t i = 0 ; i < hostEvalVars.lbs .size (); ++i) {
1512- isolatedHostEvalVars.push_back (hostEvalVars.lbs [i]);
1513- }
1514- for (size_t i = 0 ; i < hostEvalVars.ubs .size (); ++i) {
1515- isolatedHostEvalVars.push_back (hostEvalVars.ubs [i]);
1516- }
1517- for (size_t i = 0 ; i < hostEvalVars.steps .size (); ++i) {
1518- isolatedHostEvalVars.push_back (hostEvalVars.steps [i]);
1519- }
1525+ isolatedHostEvalVars.append (hostEvalVars.lbs .begin (),
1526+ hostEvalVars.lbs .end ());
1527+ isolatedHostEvalVars.append (hostEvalVars.ubs .begin (),
1528+ hostEvalVars.ubs .end ());
1529+ isolatedHostEvalVars.append (hostEvalVars.steps .begin (),
1530+ hostEvalVars.steps .end ());
15201531 }
15211532 // Create the isolated target op
15221533 omp::TargetOp isolatedTargetOp = rewriter.create <omp::TargetOp>(
@@ -1708,13 +1719,14 @@ static void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter,
17081719 Operation *toIsolate = std::get<0 >(*tuple);
17091720 bool splitBefore = !std::get<1 >(*tuple);
17101721 bool splitAfter = !std::get<2 >(*tuple);
1711-
1722+ // Recursively isolate the target op.
17121723 if (splitBefore && splitAfter) {
17131724 auto res =
17141725 isolateOp (toIsolate, splitAfter, rewriter, module , isTargetDevice);
17151726 fissionTarget (res.postTargetOp , rewriter, module , isTargetDevice);
17161727 return ;
17171728 }
1729+ // Isolate only before the op.
17181730 if (splitBefore) {
17191731 isolateOp (toIsolate, splitAfter, rewriter, module , isTargetDevice);
17201732 return ;
0 commit comments