@@ -487,11 +487,30 @@ static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc,
487487 Value destBox = destConvert.getValue ();
488488 Value srcBox = srcConvert.getValue ();
489489
490+ // get defining alloca op of destBox and srcBox
491+ auto destAlloca = destBox.getDefiningOp <fir::AllocaOp>();
492+
493+ if (!destAlloca) {
494+ emitError (loc, " Unimplemented: FortranAssign to OpenMP lowering\n " );
495+ return ;
496+ }
497+
498+ // get the store op that stores to the alloca
499+ for (auto user : destAlloca->getUsers ()) {
500+ if (auto storeOp = dyn_cast<fir::StoreOp>(user)) {
501+ destBox = storeOp.getValue ();
502+ break ;
503+ }
504+ }
505+
490506 builder.setInsertionPoint (teamsOp);
491- // Load destination array box and source scalar
492- auto arrayBox = builder.create <fir::LoadOp>(loc, destBox);
507+ // Load destination array box (if it's a reference)
508+ Value arrayBox = destBox;
509+ if (isa<fir::ReferenceType>(destBox.getType ()))
510+ arrayBox = builder.create <fir::LoadOp>(loc, destBox);
511+
493512 auto scalarValue = builder.create <fir::BoxAddrOp>(loc, srcBox);
494- auto scalar = builder.create <fir::LoadOp>(loc, scalarValue);
513+ Value scalar = builder.create <fir::LoadOp>(loc, scalarValue);
495514
496515 // Calculate total number of elements (flattened)
497516 auto c0 = builder.create <arith::ConstantIndexOp>(loc, 0 );
@@ -543,9 +562,8 @@ WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
543562 bool changed = false ;
544563 omp::TargetOp targetOp;
545564 // Get the target op parent of teams
546- if (auto teamsOp = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp ())) {
547- targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp ());
548- }
565+ targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp ());
566+ SmallVector<Operation *> opsToErase;
549567 for (auto &op : workdistribute.getOps ()) {
550568 if (&op == terminator) {
551569 break ;
@@ -560,12 +578,15 @@ WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
560578 targetOpsToProcess.insert (targetOp);
561579 replaceWithUnorderedDoLoop (rewriter, loc, teams, workdistribute,
562580 runtimeCall);
563- op. erase ( );
564- return true ;
581+ opsToErase. push_back (&op );
582+ changed = true ;
565583 }
566584 }
567585 }
568586 }
587+ for (auto *op : opsToErase) {
588+ op->erase ();
589+ }
569590 return changed;
570591}
571592
@@ -911,7 +932,7 @@ static void reloadCacheAndRecompute(
911932
912933 unsigned originalMapVarsSize = targetOp.getMapVars ().size ();
913934 unsigned hostEvalVarsSize = hostEvalVars.size ();
914- // Create Stores for allocs .
935+ // Create load operations for each allocated variable .
915936 for (unsigned i = 0 ; i < allocs.size (); ++i) {
916937 Value original = allocs[i];
917938 // Get the new block argument for this specific allocated value.
@@ -1196,6 +1217,12 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
11961217 Block *targetBlock = &targetOp.getRegion ().front ();
11971218 assert (targetBlock == &targetOp.getRegion ().back ());
11981219 IRMapping mapping;
1220+
1221+ auto targetDataOp = cast<omp::TargetDataOp>(targetOp->getParentOp ());
1222+ if (!targetDataOp) {
1223+ llvm_unreachable (" Expected target op to be inside target_data op" );
1224+ return ;
1225+ }
11991226 // create mapping for host_eval_vars
12001227 unsigned hostEvalVarCount = targetOp.getHostEvalVars ().size ();
12011228 for (unsigned i = 0 ; i < targetOp.getHostEvalVars ().size (); ++i) {
@@ -1361,12 +1388,14 @@ static void computeAllocsCacheRecomputable(
13611388 it++) {
13621389 // Check if any of the results are used outside the split point.
13631390 for (auto res : it->getResults ()) {
1364- if (usedOutsideSplit (res, splitBeforeOp))
1391+ if (usedOutsideSplit (res, splitBeforeOp)) {
13651392 requiredVals.push_back (res);
1393+ }
13661394 }
13671395 // If the op is not recomputable, add it to the nonRecomputable set.
1368- if (!isRecomputableAfterFission (&*it, splitBeforeOp))
1396+ if (!isRecomputableAfterFission (&*it, splitBeforeOp)) {
13691397 nonRecomputable.insert (&*it);
1398+ }
13701399 }
13711400 // For each required value, collect its dependencies.
13721401 for (auto requiredVal : requiredVals)
0 commit comments