From 5dd5d8a58bbc600e55e4466882f4e14631b92f09 Mon Sep 17 00:00:00 2001 From: skc7 Date: Mon, 4 Dec 2023 12:57:36 -0800 Subject: [PATCH 01/21] [Flang] Introduce lower-workdistribute pass for workdistribute lowering. Co-authors: ivanradanov, skc7 --- .../include/flang/Optimizer/OpenMP/Passes.td | 4 + flang/lib/Optimizer/OpenMP/CMakeLists.txt | 1 + .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 901 ++++++++++++++++++ flang/lib/Optimizer/Passes/Pipelines.cpp | 4 +- flang/test/Fir/basic-program.fir | 1 + .../OpenMP/lower-workdistribute-doloop.mlir | 33 + .../lower-workdistribute-fission-target.mlir | 112 +++ .../OpenMP/lower-workdistribute-fission.mlir | 71 ++ .../OpenMP/lower-workdistribute-target.mlir | 32 + .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 3 + 10 files changed, 1161 insertions(+), 1 deletion(-) create mode 100644 flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td index e2f092024c250..bfbaa5f838e90 100644 --- a/flang/include/flang/Optimizer/OpenMP/Passes.td +++ b/flang/include/flang/Optimizer/OpenMP/Passes.td @@ -93,6 +93,10 @@ def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> { let summary = "Lower workshare construct"; } +def LowerWorkdistribute : Pass<"lower-workdistribute", "::mlir::ModuleOp"> { + let summary = "Lower workdistribute construct"; +} + def GenericLoopConversionPass : Pass<"omp-generic-loop-conversion", "mlir::func::FuncOp"> { let summary = "Converts OpenMP generic `omp.loop` to semantically " diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt index b85ee7e861a4f..23a7dc8f08399 100644 --- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt +++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt @@ -8,6 +8,7 @@ add_flang_library(FlangOpenMPTransforms MapsForPrivatizedSymbols.cpp MapInfoFinalization.cpp MarkDeclareTarget.cpp + LowerWorkdistribute.cpp LowerWorkshare.cpp LowerNontemporal.cpp SimdOnly.cpp diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp new file mode 100644 index 0000000000000..0885efc716db4 --- /dev/null +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -0,0 +1,901 @@ +//===- LowerWorkshare.cpp - special cases for bufferization -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the lowering and optimisations of omp.workdistribute. 
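+//
+// As a rough illustration (hypothetical Fortran input, not part of this
+// patch), an array assignment such as
+//
+//   !$omp target teams workdistribute
+//     a = b + 1
+//   !$omp end target teams workdistribute
+//
+// reaches this pass as an unordered fir.do_loop nested inside
+// omp.target { omp.teams { omp.workdistribute { ... } } } and is rewritten
+// below into an omp.parallel / omp.distribute / omp.wsloop nest.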
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Utils.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/Utils/IndexingUtils.h>
+#include <mlir/IR/BlockSupport.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/Diagnostics.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <optional>
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKDISTRIBUTE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workdistribute"
+
+using namespace mlir;
+
+namespace {
+
+static bool isRuntimeCall(Operation *op) {
+  if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+    auto callee = callOp.getCallee();
+    if (!callee)
+      return false;
+    auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
+    if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
+      return true;
+  }
+  return false;
+}
+
+/// This is the single source of truth about whether we should parallelize an
+/// operation nested in an omp.workdistribute region.
+static bool shouldParallelize(Operation *op) {
+  if (llvm::any_of(op->getResults(),
+                   [](OpResult v) -> bool { return !v.use_empty(); }))
+    return false;
+  // We will parallelize unordered loops - these come from array syntax
+  if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
+    auto unordered = loop.getUnordered();
+    if (!unordered)
+      return false;
+    return *unordered;
+  }
+  if (isRuntimeCall(op)) {
+    return true;
+  }
+  // We cannot parallelize anything else
+  return false;
+}
+
+template <typename T>
+static T getPerfectlyNested(Operation *op) {
+  if (op->getNumRegions() != 1)
+    return nullptr;
+  auto &region = op->getRegion(0);
+  if (region.getBlocks().size() != 1)
+    return nullptr;
+  auto *block = &region.front();
+  auto *firstOp = &block->front();
+  if (auto nested = dyn_cast<T>(firstOp))
+    if (firstOp->getNextNode() == block->getTerminator())
+      return nested;
+  return nullptr;
+}
+
+/// If B() and D() are parallelizable,
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///     C()
+///     D()
+///     E()
+///   }
+/// }
+///
+/// becomes
+///
+/// A()
+/// omp.teams {
+///   omp.workdistribute {
+///     B()
+///   }
+/// }
+/// C()
+/// omp.teams {
+///   omp.workdistribute {
+///     D()
+///   }
+/// }
+/// E()
+
+static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
+  OpBuilder rewriter(workdistribute);
+  auto loc = workdistribute->getLoc();
+  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  if (!teams) {
+    emitError(loc, "workdistribute not nested in teams\n");
+    return false;
+  }
+  if (workdistribute.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "workdistribute with multiple blocks\n");
+    return false;
+  }
+  if (teams.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "teams with multiple blocks\n");
+    return false;
+  }
+
+  auto *teamsBlock = &teams.getRegion().front();
+  bool changed = false;
+  // Move the ops inside teams and before workdistribute outside.
+  IRMapping irMapping;
+  llvm::SmallVector<Operation *> teamsHoisted;
+  for (auto &op : teams.getOps()) {
+    if (&op == workdistribute) {
+      break;
+    }
+    if (shouldParallelize(&op)) {
+      emitError(loc,
+                "teams has parallelizable ops before first workdistribute\n");
+      return false;
+    } else {
+      rewriter.setInsertionPoint(teams);
+      rewriter.clone(op, irMapping);
+      teamsHoisted.push_back(&op);
+      changed = true;
+    }
+  }
+  for (auto *op : llvm::reverse(teamsHoisted)) {
+    op->replaceAllUsesWith(irMapping.lookup(op));
+    op->erase();
+  }
+
+  // While we have unhandled operations in the original workdistribute
+  auto *workdistributeBlock = &workdistribute.getRegion().front();
+  auto *terminator = workdistributeBlock->getTerminator();
+  while (&workdistributeBlock->front() != terminator) {
+    rewriter.setInsertionPoint(teams);
+    IRMapping mapping;
+    llvm::SmallVector<Operation *> hoisted;
+    Operation *parallelize = nullptr;
+    for (auto &op : workdistribute.getOps()) {
+      if (&op == terminator) {
+        break;
+      }
+      if (shouldParallelize(&op)) {
+        parallelize = &op;
+        break;
+      } else {
+        rewriter.clone(op, mapping);
+        hoisted.push_back(&op);
+        changed = true;
+      }
+    }
+
+    for (auto *op : llvm::reverse(hoisted)) {
+      op->replaceAllUsesWith(mapping.lookup(op));
+      op->erase();
+    }
+
+    if (parallelize && hoisted.empty() &&
+        parallelize->getNextNode() == terminator)
+      break;
+    if (parallelize) {
+      auto newTeams = rewriter.cloneWithoutRegions(teams);
+      auto *newTeamsBlock = rewriter.createBlock(
+          &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {});
+      for (auto arg : teamsBlock->getArguments())
+        newTeamsBlock->addArgument(arg.getType(), arg.getLoc());
+      auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
+      rewriter.create<omp::TerminatorOp>(loc);
+      rewriter.createBlock(&newWorkdistribute.getRegion(),
+                           newWorkdistribute.getRegion().begin(), {}, {});
+      auto *cloned = rewriter.clone(*parallelize);
+      parallelize->replaceAllUsesWith(cloned);
+      parallelize->erase();
+      rewriter.create<omp::TerminatorOp>(loc);
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+/// If a fir.do_loop is present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     fir.do_loop unordered {
+///       ...
+///     }
+///   }
+/// }
+///
+/// then it is lowered to
+///
+/// omp.teams {
+///   omp.parallel {
+///     omp.distribute {
+///       omp.wsloop {
+///         omp.loop_nest
+///         ...
+/// }
+/// }
+/// }
+/// }
+
+static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
+  auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc);
+  parallelOp.setComposite(composite);
+  rewriter.createBlock(&parallelOp.getRegion());
+  rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
+  return;
+}
+
+static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) {
+  mlir::omp::DistributeOperands distributeClauseOps;
+  auto distributeOp =
+      rewriter.create<mlir::omp::DistributeOp>(loc, distributeClauseOps);
+  distributeOp.setComposite(composite);
+  auto distributeBlock = rewriter.createBlock(&distributeOp.getRegion());
+  rewriter.setInsertionPointToStart(distributeBlock);
+  return;
+}
+
+static void
+genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop,
+                     mlir::omp::LoopNestOperands &loopNestClauseOps) {
+  assert(loopNestClauseOps.loopLowerBounds.empty() &&
+         "Loop nest bounds were already emitted!");
+  loopNestClauseOps.loopLowerBounds.push_back(loop.getLowerBound());
+  loopNestClauseOps.loopUpperBounds.push_back(loop.getUpperBound());
+  loopNestClauseOps.loopSteps.push_back(loop.getStep());
+  loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
+}
+
+static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
+                        const mlir::omp::LoopNestOperands &clauseOps,
+                        bool composite) {
+
+  auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
+  wsloopOp.setComposite(composite);
+  rewriter.createBlock(&wsloopOp.getRegion());
+
+  auto loopNestOp =
+      rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps);
+
+  // Clone the loop's body inside the loop nest construct using the
+  // mapped values.
+  rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(),
+                             loopNestOp.getRegion().begin());
+  Block *clonedBlock = &loopNestOp.getRegion().back();
+  mlir::Operation *terminatorOp = clonedBlock->getTerminator();
+
+  // Erase the fir.result op of the do loop and create a yield op.
+  if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) {
+    rewriter.setInsertionPoint(terminatorOp);
+    rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc());
+    terminatorOp->erase();
+  }
+  return;
+}
+
+static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
+  OpBuilder rewriter(workdistribute);
+  auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
+  auto wdLoc = workdistribute->getLoc();
+  if (doLoop && shouldParallelize(doLoop)) {
+    assert(doLoop.getReduceOperands().empty());
+    genParallelOp(wdLoc, rewriter, true);
+    genDistributeOp(wdLoc, rewriter, true);
+    mlir::omp::LoopNestOperands loopNestClauseOps;
+    genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps);
+    genWsLoopOp(rewriter, doLoop, loopNestClauseOps, true);
+    workdistribute.erase();
+    return true;
+  }
+  return false;
+}
+
+/// If A() and B() are present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///   }
+/// }
+///
+/// then it is lowered to
+///
+/// A()
+/// B()
+
+static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) {
+  auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+  if (!workdistributeOp)
+    return false;
+  // Get the block containing teamsOp (the parent block).
+  Block *parentBlock = teamsOp->getBlock();
+  Block &workdistributeBlock = *workdistributeOp.getRegion().begin();
+  auto insertPoint = Block::iterator(teamsOp);
+  // Get the range of operations to move (excluding the terminator).
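+  // Note that splice moves the ops in [begin, terminator) in place rather
+  // than cloning them, so existing uses of their results remain valid.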
+ auto workdistributeBegin = workdistributeBlock.begin(); + auto workdistributeEnd = workdistributeBlock.getTerminator()->getIterator(); + // Move the operations from workdistribute block to before teamsOp. + parentBlock->getOperations().splice(insertPoint, + workdistributeBlock.getOperations(), + workdistributeBegin, workdistributeEnd); + // Erase the now-empty workdistributeOp. + workdistributeOp.erase(); + Block &teamsBlock = *teamsOp.getRegion().begin(); + // Check if only the terminator remains and erase teams op. + if (teamsBlock.getOperations().size() == 1 && + teamsBlock.getTerminator() != nullptr) { + teamsOp.erase(); + } + return true; +} + +struct SplitTargetResult { + omp::TargetOp targetOp; + omp::TargetDataOp dataOp; +}; + +/// If multiple workdistribute are nested in a target regions, we will need to +/// split the target region, but we want to preserve the data semantics of the +/// original data region and avoid unnecessary data movement at each of the +/// subkernels - we split the target region into a target_data{target} +/// nest where only the outer one moves the data +std::optional splitTargetData(omp::TargetOp targetOp, + RewriterBase &rewriter) { + auto loc = targetOp->getLoc(); + if (targetOp.getMapVars().empty()) { + LLVM_DEBUG(llvm::dbgs() + << DEBUG_TYPE << " target region has no data maps\n"); + return std::nullopt; + } + + SmallVector mapInfos; + for (auto opr : targetOp.getMapVars()) { + auto mapInfo = cast(opr.getDefiningOp()); + mapInfos.push_back(mapInfo); + } + + rewriter.setInsertionPoint(targetOp); + SmallVector innerMapInfos; + SmallVector outerMapInfos; + + for (auto mapInfo : mapInfos) { + auto originalMapType = + (llvm::omp::OpenMPOffloadMappingFlags)(mapInfo.getMapType()); + auto originalCaptureType = mapInfo.getMapCaptureType(); + llvm::omp::OpenMPOffloadMappingFlags newMapType; + mlir::omp::VariableCaptureKind newCaptureType; + + if (originalCaptureType == mlir::omp::VariableCaptureKind::ByCopy) { + newMapType = originalMapType; + newCaptureType = originalCaptureType; + } else if (originalCaptureType == mlir::omp::VariableCaptureKind::ByRef) { + newMapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; + newCaptureType = originalCaptureType; + outerMapInfos.push_back(mapInfo); + } else { + llvm_unreachable("Unhandled case"); + } + auto innerMapInfo = cast(rewriter.clone(*mapInfo)); + innerMapInfo.setMapTypeAttr(rewriter.getIntegerAttr( + rewriter.getIntegerType(64, false), + static_cast< + std::underlying_type_t>( + newMapType))); + innerMapInfo.setMapCaptureType(newCaptureType); + innerMapInfos.push_back(innerMapInfo.getResult()); + } + + rewriter.setInsertionPoint(targetOp); + auto device = targetOp.getDevice(); + auto ifExpr = targetOp.getIfExpr(); + auto deviceAddrVars = targetOp.getHasDeviceAddrVars(); + auto devicePtrVars = targetOp.getIsDevicePtrVars(); + auto targetDataOp = rewriter.create( + loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars); + auto taregtDataBlock = rewriter.createBlock(&targetDataOp.getRegion()); + rewriter.create(loc); + rewriter.setInsertionPointToStart(taregtDataBlock); + + auto newTargetOp = rewriter.create( + targetOp.getLoc(), targetOp.getAllocateVars(), + targetOp.getAllocatorVars(), targetOp.getBareAttr(), + targetOp.getDependKindsAttr(), targetOp.getDependVars(), + targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), + targetOp.getHostEvalVars(), targetOp.getIfExpr(), + targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(), + targetOp.getInReductionSymsAttr(), 
targetOp.getIsDevicePtrVars(), + innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), + targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), + targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); + rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(), + newTargetOp.getRegion().begin()); + + rewriter.replaceOp(targetOp, newTargetOp); + return SplitTargetResult{cast(newTargetOp), targetDataOp}; +} + +static std::optional> +getNestedOpToIsolate(omp::TargetOp targetOp) { + if (targetOp.getRegion().empty()) + return std::nullopt; + auto *targetBlock = &targetOp.getRegion().front(); + for (auto &op : *targetBlock) { + bool first = &op == &*targetBlock->begin(); + bool last = op.getNextNode() == targetBlock->getTerminator(); + if (first && last) + return std::nullopt; + + if (isa(&op)) + return {{&op, first, last}}; + } + return std::nullopt; +} + +struct TempOmpVar { + omp::MapInfoOp from, to; +}; + +static bool isPtr(Type ty) { + return isa(ty) || isa(ty); +} + +static Type getPtrTypeForOmp(Type ty) { + if (isPtr(ty)) + return LLVM::LLVMPointerType::get(ty.getContext()); + else + return fir::LLVMPointerType::get(ty); +} + +static TempOmpVar allocateTempOmpVar(Location loc, Type ty, + RewriterBase &rewriter) { + MLIRContext &ctx = *ty.getContext(); + Value alloc; + Type allocType; + auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx); + if (isPtr(ty)) { + Type intTy = rewriter.getI32Type(); + auto one = rewriter.create(loc, intTy, 1); + allocType = llvmPtrTy; + alloc = rewriter.create(loc, llvmPtrTy, allocType, one); + allocType = intTy; + } else { + allocType = ty; + alloc = rewriter.create(loc, allocType); + } + auto getMapInfo = [&](uint64_t mappingFlags, const char *name) { + return rewriter.create( + loc, alloc.getType(), alloc, TypeAttr::get(allocType), + rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false), + mappingFlags), + rewriter.getAttr( + omp::VariableCaptureKind::ByRef), + /*varPtrPtr=*/Value{}, + /*members=*/SmallVector{}, + /*member_index=*/mlir::ArrayAttr{}, + /*bounds=*/ValueRange(), + /*mapperId=*/mlir::FlatSymbolRefAttr(), + /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false)); + }; + uint64_t mapFrom = + static_cast>( + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); + uint64_t mapTo = + static_cast>( + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); + auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from"); + auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to"); + return TempOmpVar{mapInfoFrom, mapInfoTo}; +}; + +static bool usedOutsideSplit(Value v, Operation *split) { + if (!split) + return false; + auto targetOp = cast(split->getParentOp()); + auto *targetBlock = &targetOp.getRegion().front(); + for (auto *user : v.getUsers()) { + while (user->getBlock() != targetBlock) { + user = user->getParentOp(); + } + if (!user->isBeforeInBlock(split)) + return true; + } + return false; +}; + +static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) { + if (isa(op)) + return true; + + llvm::SmallVector effects; + MemoryEffectOpInterface interface = dyn_cast(op); + if (!interface) { + return false; + } + interface.getEffects(effects); + if (effects.empty()) + return true; + return false; +} + +struct SplitResult { + omp::TargetOp preTargetOp; + omp::TargetOp isolatedTargetOp; + omp::TargetOp postTargetOp; +}; + +static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp, + SetVector &nonRecomputable, + SetVector &toCache, + 
SetVector &toRecompute) { + Operation *op = v.getDefiningOp(); + if (!op) { + assert(cast(v).getOwner()->getParentOp() == targetOp); + return; + } + if (nonRecomputable.contains(op)) { + toCache.insert(op); + return; + } + toRecompute.insert(op); + for (auto opr : op->getOperands()) + collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache, + toRecompute); +} + +static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter, + MLIRContext &ctx, IRMapping &mapping, + Operation *splitBefore, Block *targetBlock, + Block *newTargetBlock, + SmallVector &allocs, + SetVector &toRecompute) { + for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) { + auto originalArg = targetBlock->getArgument(i); + auto newArg = newTargetBlock->addArgument(originalArg.getType(), + originalArg.getLoc()); + mapping.map(originalArg, newArg); + } + auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx); + for (auto original : allocs) { + Value newArg = newTargetBlock->addArgument( + getPtrTypeForOmp(original.getType()), original.getLoc()); + Value restored; + if (isPtr(original.getType())) { + restored = rewriter.create(loc, llvmPtrTy, newArg); + if (!isa(original.getType())) + restored = + rewriter.create(loc, original.getType(), restored); + } else { + restored = rewriter.create(loc, newArg); + } + mapping.map(original, restored); + } + for (auto it = targetBlock->begin(); it != splitBefore->getIterator(); it++) { + if (toRecompute.contains(&*it)) + rewriter.clone(*it, mapping); + } +} + +static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter, + RewriterBase &rewriter) { + auto targetOp = cast(splitBeforeOp->getParentOp()); + MLIRContext &ctx = *targetOp.getContext(); + assert(targetOp); + auto loc = targetOp.getLoc(); + auto *targetBlock = &targetOp.getRegion().front(); + rewriter.setInsertionPoint(targetOp); + + auto preMapOperands = SmallVector(targetOp.getMapVars()); + auto postMapOperands = SmallVector(targetOp.getMapVars()); + + SmallVector requiredVals; + SetVector toCache; + SetVector toRecompute; + SetVector nonRecomputable; + SmallVector allocs; + + for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); + it++) { + for (auto res : it->getResults()) { + if (usedOutsideSplit(res, splitBeforeOp)) + requiredVals.push_back(res); + } + if (!isRecomputableAfterFission(&*it, splitBeforeOp)) + nonRecomputable.insert(&*it); + } + + for (auto requiredVal : requiredVals) + collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache, + toRecompute); + + for (Operation *op : toCache) { + for (auto res : op->getResults()) { + auto alloc = + allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter); + allocs.push_back(res); + preMapOperands.push_back(alloc.from); + postMapOperands.push_back(alloc.to); + } + } + + rewriter.setInsertionPoint(targetOp); + + auto preTargetOp = rewriter.create( + targetOp.getLoc(), targetOp.getAllocateVars(), + targetOp.getAllocatorVars(), targetOp.getBareAttr(), + targetOp.getDependKindsAttr(), targetOp.getDependVars(), + targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), + targetOp.getHostEvalVars(), targetOp.getIfExpr(), + targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(), + targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), + preMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), + targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), + targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); + auto *preTargetBlock = rewriter.createBlock( + 
&preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {}); + IRMapping preMapping; + for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) { + auto originalArg = targetBlock->getArgument(i); + auto newArg = preTargetBlock->addArgument(originalArg.getType(), + originalArg.getLoc()); + preMapping.map(originalArg, newArg); + } + for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); it++) + rewriter.clone(*it, preMapping); + + auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext()); + + for (auto original : allocs) { + Value toStore = preMapping.lookup(original); + auto newArg = preTargetBlock->addArgument( + getPtrTypeForOmp(original.getType()), original.getLoc()); + if (isPtr(original.getType())) { + if (!isa(toStore.getType())) + toStore = rewriter.create(loc, llvmPtrTy, toStore); + rewriter.create(loc, toStore, newArg); + } else { + rewriter.create(loc, toStore, newArg); + } + } + rewriter.create(loc); + + rewriter.setInsertionPoint(targetOp); + + auto isolatedTargetOp = rewriter.create( + targetOp.getLoc(), targetOp.getAllocateVars(), + targetOp.getAllocatorVars(), targetOp.getBareAttr(), + targetOp.getDependKindsAttr(), targetOp.getDependVars(), + targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), + targetOp.getHostEvalVars(), targetOp.getIfExpr(), + targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(), + targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), + postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), + targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), + targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); + + auto *isolatedTargetBlock = + rewriter.createBlock(&isolatedTargetOp.getRegion(), + isolatedTargetOp.getRegion().begin(), {}, {}); + + IRMapping isolatedMapping; + reloadCacheAndRecompute(loc, rewriter, ctx, isolatedMapping, splitBeforeOp, + targetBlock, isolatedTargetBlock, allocs, + toRecompute); + rewriter.clone(*splitBeforeOp, isolatedMapping); + rewriter.create(loc); + + omp::TargetOp postTargetOp = nullptr; + + if (splitAfter) { + rewriter.setInsertionPoint(targetOp); + postTargetOp = rewriter.create( + targetOp.getLoc(), targetOp.getAllocateVars(), + targetOp.getAllocatorVars(), targetOp.getBareAttr(), + targetOp.getDependKindsAttr(), targetOp.getDependVars(), + targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), + targetOp.getHostEvalVars(), targetOp.getIfExpr(), + targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(), + targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), + postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), + targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), + targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); + auto *postTargetBlock = rewriter.createBlock( + &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {}); + IRMapping postMapping; + reloadCacheAndRecompute(loc, rewriter, ctx, postMapping, splitBeforeOp, + targetBlock, postTargetBlock, allocs, toRecompute); + + assert(splitBeforeOp->getNumResults() == 0 || + llvm::all_of(splitBeforeOp->getResults(), + [](Value result) { return result.use_empty(); })); + + for (auto it = std::next(splitBeforeOp->getIterator()); + it != targetBlock->end(); it++) + rewriter.clone(*it, postMapping); + } + + rewriter.eraseOp(targetOp); + return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp}; +} + +static mlir::LLVM::ConstantOp +genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int 
value) { + mlir::Type i32Ty = rewriter.getI32Type(); + mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value); + return rewriter.create(loc, i32Ty, attr); +} + +static Type getOmpDeviceType(MLIRContext *c) { return IntegerType::get(c, 32); } + +static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) { + OpBuilder::InsertionGuard guard(rewriter); + Block *targetBlock = &targetOp.getRegion().front(); + assert(targetBlock == &targetOp.getRegion().back()); + IRMapping mapping; + for (auto map : + zip_equal(targetOp.getMapVars(), targetBlock->getArguments())) { + Value mapInfo = std::get<0>(map); + BlockArgument arg = std::get<1>(map); + Operation *op = mapInfo.getDefiningOp(); + assert(op); + auto mapInfoOp = cast(op); + mapping.map(arg, mapInfoOp.getVarPtr()); + } + rewriter.setInsertionPoint(targetOp); + SmallVector opsToReplace; + Value device = targetOp.getDevice(); + if (!device) { + device = genI32Constant(targetOp.getLoc(), rewriter, 0); + } + for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end()); + it != end; ++it) { + auto *op = &*it; + if (isRuntimeCall(op)) { + fir::CallOp runtimeCall = cast(op); + auto module = runtimeCall->getParentOfType(); + auto callee = + cast(module.lookupSymbol(runtimeCall.getCalleeAttr())); + std::string newCalleeName = (callee.getName() + "_omp").str(); + mlir::OpBuilder moduleBuilder(module.getBodyRegion()); + func::FuncOp newCallee = + cast_or_null(module.lookupSymbol(newCalleeName)); + if (!newCallee) { + SmallVector argTypes(callee.getFunctionType().getInputs()); + argTypes.push_back(getOmpDeviceType(rewriter.getContext())); + newCallee = moduleBuilder.create( + callee->getLoc(), newCalleeName, + FunctionType::get(rewriter.getContext(), argTypes, + callee.getFunctionType().getResults())); + if (callee.getArgAttrs()) + newCallee.setArgAttrsAttr(*callee.getArgAttrs()); + if (callee.getResAttrs()) + newCallee.setResAttrsAttr(*callee.getResAttrs()); + newCallee.setSymVisibility(callee.getSymVisibility()); + newCallee->setDiscardableAttrs(callee->getDiscardableAttrDictionary()); + } + SmallVector operands = runtimeCall.getOperands(); + operands.push_back(device); + auto tmpCall = rewriter.create( + runtimeCall.getLoc(), runtimeCall.getResultTypes(), + SymbolRefAttr::get(newCallee), operands, nullptr, nullptr, nullptr, + runtimeCall.getFastmathAttr()); + Operation *newCall = rewriter.clone(*tmpCall, mapping); + mapping.map(&*it, newCall); + rewriter.eraseOp(tmpCall); + } else { + Operation *clonedOp = rewriter.clone(*op, mapping); + if (isa(clonedOp) || isa(clonedOp)) + opsToReplace.push_back(clonedOp); + for (unsigned i = 0; i < op->getNumResults(); ++i) { + mapping.map(op->getResult(i), clonedOp->getResult(i)); + } + } + } + for (Operation *op : opsToReplace) { + if (auto allocOp = dyn_cast(op)) { + rewriter.setInsertionPoint(allocOp); + auto ompAllocmemOp = rewriter.create( + allocOp.getLoc(), rewriter.getI64Type(), device, + allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(), + allocOp.getBindcNameAttr(), allocOp.getTypeparams(), + allocOp.getShape()); + auto firConvertOp = rewriter.create( + allocOp.getLoc(), allocOp.getResult().getType(), + ompAllocmemOp.getResult()); + rewriter.replaceOp(allocOp, firConvertOp.getResult()); + } else if (auto freeOp = dyn_cast(op)) { + rewriter.setInsertionPoint(freeOp); + auto firConvertOp = rewriter.create( + freeOp.getLoc(), rewriter.getI64Type(), freeOp.getHeapref()); + rewriter.create(freeOp.getLoc(), device, + firConvertOp.getResult()); + rewriter.eraseOp(freeOp); + } + } + 
rewriter.eraseOp(targetOp); +} + +void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) { + auto tuple = getNestedOpToIsolate(targetOp); + if (!tuple) { + LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n"); + moveToHost(targetOp, rewriter); + return; + } + + Operation *toIsolate = std::get<0>(*tuple); + bool splitBefore = !std::get<1>(*tuple); + bool splitAfter = !std::get<2>(*tuple); + + if (splitBefore && splitAfter) { + auto res = isolateOp(toIsolate, splitAfter, rewriter); + moveToHost(res.preTargetOp, rewriter); + fissionTarget(res.postTargetOp, rewriter); + return; + } + if (splitBefore) { + auto res = isolateOp(toIsolate, splitAfter, rewriter); + moveToHost(res.preTargetOp, rewriter); + return; + } + if (splitAfter) { + auto res = isolateOp(toIsolate->getNextNode(), splitAfter, rewriter); + fissionTarget(res.postTargetOp, rewriter); + return; + } +} + +class LowerWorkdistributePass + : public flangomp::impl::LowerWorkdistributeBase { +public: + void runOnOperation() override { + MLIRContext &context = getContext(); + auto moduleOp = getOperation(); + bool changed = false; + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + changed |= FissionWorkdistribute(workdistribute); + }); + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + changed |= WorkdistributeDoLower(workdistribute); + }); + moduleOp->walk([&](mlir::omp::TeamsOp teams) { + changed |= TeamsWorkdistributeToSingleOp(teams); + }); + + if (changed) { + SmallVector targetOps; + moduleOp->walk( + [&](omp::TargetOp targetOp) { targetOps.push_back(targetOp); }); + IRRewriter rewriter(&context); + for (auto targetOp : targetOps) { + auto res = splitTargetData(targetOp, rewriter); + if (res) + fissionTarget(res->targetOp, rewriter); + } + } + } +}; +} // namespace diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index a83b0665eaf1f..1ecb6d383f939 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -301,8 +301,10 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, addNestedPassToAllTopLevelOperations( pm, hlfir::createInlineHLFIRAssign); pm.addPass(hlfir::createConvertHLFIRtoFIR()); - if (enableOpenMP != EnableOpenMP::None) + if (enableOpenMP != EnableOpenMP::None) { pm.addPass(flangomp::createLowerWorkshare()); + pm.addPass(flangomp::createLowerWorkdistribute()); + } if (enableOpenMP == EnableOpenMP::Simd) pm.addPass(flangomp::createSimdOnlyPass()); } diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir index 195e5ad7f9dc8..59f6c73ae84ee 100644 --- a/flang/test/Fir/basic-program.fir +++ b/flang/test/Fir/basic-program.fir @@ -69,6 +69,7 @@ func.func @_QQmain() { // PASSES-NEXT: InlineHLFIRAssign // PASSES-NEXT: ConvertHLFIRtoFIR // PASSES-NEXT: LowerWorkshare +// PASSES-NEXT: LowerWorkdistribute // PASSES-NEXT: CSE // PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd // PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir new file mode 100644 index 0000000000000..00d10d6264ec9 --- /dev/null +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir @@ -0,0 +1,33 @@ +// RUN: fir-opt --lower-workdistribute %s | FileCheck %s + +// CHECK-LABEL: func.func @x({{.*}}) +// CHECK: omp.teams { +// CHECK: omp.parallel { +// CHECK: omp.distribute { +// CHECK: omp.wsloop { +// CHECK: omp.loop_nest 
(%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step (%[[ARG2:.*]]) { +// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index +// CHECK: fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref +// CHECK: omp.yield +// CHECK: } +// CHECK: } {omp.composite} +// CHECK: } {omp.composite} +// CHECK: omp.terminator +// CHECK: } {omp.composite} +// CHECK: omp.terminator +// CHECK: } +// CHECK: return +// CHECK: } +func.func @x(%lb : index, %ub : index, %step : index, %b : i1, %addr : !fir.ref) { + omp.teams { + omp.workdistribute { + fir.do_loop %iv = %lb to %ub step %step unordered { + %zero = arith.constant 0 : index + fir.store %zero to %addr : !fir.ref + } + omp.terminator + } + omp.terminator + } + return +} diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir new file mode 100644 index 0000000000000..19bdb9ce10fbd --- /dev/null +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir @@ -0,0 +1,112 @@ +// RUN: fir-opt --lower-workdistribute %s | FileCheck %s + +// CHECK-LABEL: func.func @x( +// CHECK: %[[VAL_0:.*]] = fir.alloca index {bindc_name = "lb"} +// CHECK: fir.store %[[ARG0:.*]] to %[[VAL_0]] : !fir.ref +// CHECK: %[[VAL_1:.*]] = fir.alloca index {bindc_name = "ub"} +// CHECK: fir.store %[[ARG1:.*]] to %[[VAL_1]] : !fir.ref +// CHECK: %[[VAL_2:.*]] = fir.alloca index {bindc_name = "step"} +// CHECK: fir.store %[[ARG2:.*]] to %[[VAL_2]] : !fir.ref +// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "lb"} +// CHECK: %[[VAL_4:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "ub"} +// CHECK: %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "step"} +// CHECK: %[[VAL_6:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "addr"} +// CHECK: %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref {name = "lb"} +// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref {name = "ub"} +// CHECK: %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref {name = "step"} +// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref {name = "addr"} +// CHECK: omp.target_data map_entries(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref) { +// CHECK: %[[VAL_11:.*]] = fir.alloca index +// CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref, index) map_clauses(from) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_from"} +// CHECK: %[[VAL_13:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_to"} +// CHECK: %[[VAL_14:.*]] = fir.alloca index +// CHECK: %[[VAL_15:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref, index) map_clauses(from) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_from"} +// CHECK: %[[VAL_16:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = 
"__flang_workdistribute_to"} +// CHECK: %[[VAL_17:.*]] = fir.alloca index +// CHECK: %[[VAL_18:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref, index) map_clauses(from) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_from"} +// CHECK: %[[VAL_19:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_to"} +// CHECK: %[[VAL_20:.*]] = fir.alloca !fir.heap +// CHECK: %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref>, !fir.heap) map_clauses(from) capture(ByRef) -> !fir.ref> {name = "__flang_workdistribute_from"} +// CHECK: %[[VAL_22:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref>, !fir.heap) map_clauses(to) capture(ByRef) -> !fir.ref> {name = "__flang_workdistribute_to"} +// CHECK: %[[VAL_23:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_0]] : !fir.ref +// CHECK: %[[VAL_25:.*]] = fir.load %[[VAL_1]] : !fir.ref +// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_2]] : !fir.ref +// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_25]] : index +// CHECK: %[[VAL_28:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[VAL_29:.*]] = "fir.omp_target_allocmem"(%[[VAL_28]], %[[VAL_23]]) <{in_type = index, operandSegmentSizes = array, uniq_name = "dev_buf"}> : (i32, index) -> !fir.heap +// CHECK: fir.store %[[VAL_24]] to %[[VAL_11]] : !fir.ref +// CHECK: fir.store %[[VAL_25]] to %[[VAL_14]] : !fir.ref +// CHECK: fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref +// CHECK: fir.store %[[VAL_29]] to %[[VAL_20]] : !fir.ref> +// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_30:.*]], %[[VAL_8]] -> %[[VAL_31:.*]], %[[VAL_9]] -> %[[VAL_32:.*]], %[[VAL_10]] -> %[[VAL_33:.*]], %[[VAL_13]] -> %[[VAL_34:.*]], %[[VAL_16]] -> %[[VAL_35:.*]], %[[VAL_19]] -> %[[VAL_36:.*]], %[[VAL_22]] -> %[[VAL_37:.*]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref>) { +// CHECK: %[[VAL_38:.*]] = fir.load %[[VAL_34]] : !fir.llvm_ptr +// CHECK: %[[VAL_39:.*]] = fir.load %[[VAL_35]] : !fir.llvm_ptr +// CHECK: %[[VAL_40:.*]] = fir.load %[[VAL_36]] : !fir.llvm_ptr +// CHECK: %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.llvm_ptr> +// CHECK: %[[VAL_42:.*]] = arith.addi %[[VAL_39]], %[[VAL_39]] : index +// CHECK: omp.teams { +// CHECK: omp.parallel { +// CHECK: omp.distribute { +// CHECK: omp.wsloop { +// CHECK: omp.loop_nest (%[[VAL_43:.*]]) : index = (%[[VAL_38]]) to (%[[VAL_39]]) inclusive step (%[[VAL_40]]) { +// CHECK: fir.store %[[VAL_42]] to %[[VAL_41]] : !fir.heap +// CHECK: omp.yield +// CHECK: } +// CHECK: } {omp.composite} +// CHECK: } {omp.composite} +// CHECK: omp.terminator +// CHECK: } {omp.composite} +// CHECK: omp.terminator +// CHECK: } +// CHECK: omp.terminator +// CHECK: } +// CHECK: %[[VAL_44:.*]] = fir.load %[[VAL_11]] : !fir.ref +// CHECK: %[[VAL_45:.*]] = fir.load %[[VAL_14]] : !fir.ref +// CHECK: %[[VAL_46:.*]] = fir.load %[[VAL_17]] : !fir.ref +// CHECK: %[[VAL_47:.*]] = fir.load %[[VAL_20]] : !fir.ref> +// CHECK: %[[VAL_48:.*]] = arith.addi %[[VAL_45]], %[[VAL_45]] : index +// CHECK: fir.store %[[VAL_44]] to %[[VAL_47]] : !fir.heap +// CHECK: %[[VAL_49:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: "fir.omp_target_freemem"(%[[VAL_49]], %[[VAL_47]]) : (i32, !fir.heap) -> () +// CHECK: omp.terminator +// CHECK: } +// CHECK: return +// CHECK: } + +func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref) { + %lb_ref = fir.alloca index {bindc_name = "lb"} + fir.store %lb to %lb_ref : !fir.ref + %ub_ref = fir.alloca index 
{bindc_name = "ub"} + fir.store %ub to %ub_ref : !fir.ref + %step_ref = fir.alloca index {bindc_name = "step"} + fir.store %step to %step_ref : !fir.ref + + %lb_map = omp.map.info var_ptr(%lb_ref : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "lb"} + %ub_map = omp.map.info var_ptr(%ub_ref : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "ub"} + %step_map = omp.map.info var_ptr(%step_ref : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "step"} + %addr_map = omp.map.info var_ptr(%addr : !fir.ref, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "addr"} + + omp.target map_entries(%lb_map -> %ARG0, %ub_map -> %ARG1, %step_map -> %ARG2, %addr_map -> %ARG3 : !fir.ref, !fir.ref, !fir.ref, !fir.ref) { + %lb_val = fir.load %ARG0 : !fir.ref + %ub_val = fir.load %ARG1 : !fir.ref + %step_val = fir.load %ARG2 : !fir.ref + %one = arith.constant 1 : index + + %20 = arith.addi %ub_val, %ub_val : index + omp.teams { + omp.workdistribute { + %dev_mem = fir.allocmem index, %one {uniq_name = "dev_buf"} + fir.do_loop %iv = %lb_val to %ub_val step %step_val unordered { + fir.store %20 to %dev_mem : !fir.heap + } + fir.store %lb_val to %dev_mem : !fir.heap + fir.freemem %dev_mem : !fir.heap + omp.terminator + } + omp.terminator + } + omp.terminator + } + return +} diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir new file mode 100644 index 0000000000000..c562b7009664d --- /dev/null +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir @@ -0,0 +1,71 @@ +// RUN: fir-opt --lower-workdistribute %s | FileCheck %s + +// CHECK-LABEL: func.func @test_fission_workdistribute( +// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 9 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32 +// CHECK: fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref +// CHECK: omp.teams { +// CHECK: omp.parallel { +// CHECK: omp.distribute { +// CHECK: omp.wsloop { +// CHECK: omp.loop_nest (%[[VAL_4:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) { +// CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref>, index) -> !fir.ref +// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref +// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref>, index) -> !fir.ref +// CHECK: fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref +// CHECK: omp.yield +// CHECK: } +// CHECK: } {omp.composite} +// CHECK: } {omp.composite} +// CHECK: omp.terminator +// CHECK: } {omp.composite} +// CHECK: omp.terminator +// CHECK: } +// CHECK: fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref) -> () +// CHECK: fir.call @my_fir_parallel_runtime_func(%[[ARG3:.*]]) : (!fir.ref) -> () +// CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] { +// CHECK: %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0]], %[[VAL_8]] : (!fir.ref>, index) -> !fir.ref +// CHECK: fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref +// CHECK: } +// CHECK: %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref +// CHECK: fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref +// CHECK: return +// CHECK: } +module { +func.func @regular_side_effect_func(%arg0: !fir.ref) { + return +} +func.func @my_fir_parallel_runtime_func(%arg0: !fir.ref) attributes {fir.runtime} { + return +} +func.func @test_fission_workdistribute(%arr1: !fir.ref>, %arr2: 
!fir.ref>, %scalar_ref1: !fir.ref, %scalar_ref2: !fir.ref) { + %c0_idx = arith.constant 0 : index + %c1_idx = arith.constant 1 : index + %c9_idx = arith.constant 9 : index + %float_val = arith.constant 5.0 : f32 + omp.teams { + omp.workdistribute { + fir.store %float_val to %scalar_ref1 : !fir.ref + fir.do_loop %iv = %c0_idx to %c9_idx step %c1_idx unordered { + %elem_ptr_arr1 = fir.coordinate_of %arr1, %iv : (!fir.ref>, index) -> !fir.ref + %loaded_val_loop1 = fir.load %elem_ptr_arr1 : !fir.ref + %elem_ptr_arr2 = fir.coordinate_of %arr2, %iv : (!fir.ref>, index) -> !fir.ref + fir.store %loaded_val_loop1 to %elem_ptr_arr2 : !fir.ref + } + fir.call @regular_side_effect_func(%scalar_ref1) : (!fir.ref) -> () + fir.call @my_fir_parallel_runtime_func(%scalar_ref2) : (!fir.ref) -> () + fir.do_loop %jv = %c0_idx to %c9_idx step %c1_idx { + %elem_ptr_ordered_loop = fir.coordinate_of %arr1, %jv : (!fir.ref>, index) -> !fir.ref + fir.store %float_val to %elem_ptr_ordered_loop : !fir.ref + } + %loaded_for_hoist = fir.load %scalar_ref1 : !fir.ref + fir.store %loaded_for_hoist to %scalar_ref2 : !fir.ref + omp.terminator + } + omp.terminator + } + return +} +} diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir new file mode 100644 index 0000000000000..d96068b26ca2f --- /dev/null +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir @@ -0,0 +1,32 @@ +// RUN: fir-opt --lower-workdistribute %s | FileCheck %s + +// CHECK-LABEL: func.func @test_nested_derived_type_map_operand_and_block_addition( +// CHECK-SAME: %[[ARG0:.*]]: !fir.ref}>>) { +// CHECK: %[[VAL_0:.*]] = fir.declare %[[ARG0]] {uniq_name = "_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref}>>) -> !fir.ref}>> +// CHECK: %[[VAL_1:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref}>>) -> !fir.ref> +// CHECK: %[[VAL_2:.*]] = fir.coordinate_of %[[VAL_1]], i : (!fir.ref>) -> !fir.ref +// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "sa%[[VAL_4:.*]]%[[VAL_5:.*]]"} +// CHECK: %[[VAL_6:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref}>>) -> !fir.ref> +// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[VAL_6]], r : (!fir.ref>) -> !fir.ref +// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_7]] : !fir.ref, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "sa%[[VAL_4]]%[[VAL_9:.*]]"} +// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref, !fir.ref) -> !fir.ref}>> {name = "sa", partial_map = true} +// CHECK: omp.target map_entries(%[[VAL_10]] -> %[[VAL_11:.*]] : !fir.ref}>>) { +// CHECK: omp.terminator +// CHECK: } +// CHECK: return +// CHECK: } + +func.func @test_nested_derived_type_map_operand_and_block_addition(%arg0: !fir.ref}>>) { + %0 = fir.declare %arg0 {uniq_name = "_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref}>>) -> !fir.ref}>> + %2 = fir.coordinate_of %0, n : (!fir.ref}>>) -> !fir.ref> + %4 = fir.coordinate_of %2, i : (!fir.ref>) -> !fir.ref + %5 = omp.map.info var_ptr(%4 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "sa%n%i"} + %7 = fir.coordinate_of %0, n : (!fir.ref}>>) -> !fir.ref> + %9 = 
fir.coordinate_of %7, r : (!fir.ref>) -> !fir.ref + %10 = omp.map.info var_ptr(%9 : !fir.ref, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "sa%n%r"} + %11 = omp.map.info var_ptr(%0 : !fir.ref}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%5, %10 : [1,0], [1,1] : !fir.ref, !fir.ref) -> !fir.ref}>> {name = "sa", partial_map = true} + omp.target map_entries(%11 -> %arg1 : !fir.ref}>>) { + omp.terminator + } + return +} diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 9fcb02eb4be3d..1ece86729dbd0 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -5518,6 +5518,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, omp::LoopNestOp loopOp = castOrGetParentOfType(capturedOp); unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0; + if (targetOp.getHostEvalVars().empty()) + numLoops = 0; + Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit; llvm::SmallVector lowerBounds(numLoops), upperBounds(numLoops), steps(numLoops); From 408eca8ffcbb97ce36f060ba665fabe22c6b387c Mon Sep 17 00:00:00 2001 From: skc7 Date: Fri, 25 Jul 2025 22:20:12 +0530 Subject: [PATCH 02/21] Fix hoisting declare ops out of omp.target --- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 34 +++++++++++++++++-- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 0885efc716db4..3f78727450f31 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -434,7 +434,7 @@ std::optional splitTargetData(omp::TargetOp targetOp, rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(), newTargetOp.getRegion().begin()); - rewriter.replaceOp(targetOp, newTargetOp); + rewriter.replaceOp(targetOp, targetDataOp); return SplitTargetResult{cast(newTargetOp), targetDataOp}; } @@ -807,11 +807,30 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) { rewriter.eraseOp(tmpCall); } else { Operation *clonedOp = rewriter.clone(*op, mapping); - if (isa(clonedOp) || isa(clonedOp)) - opsToReplace.push_back(clonedOp); for (unsigned i = 0; i < op->getNumResults(); ++i) { mapping.map(op->getResult(i), clonedOp->getResult(i)); } + // fir.declare changes its type when hoisting it out of omp.target to + // omp.target_data Introduce a load, if original declareOp input is not of + // reference type, but cloned delcareOp input is reference type. 
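+      // Illustrative case (hypothetical): a ByCopy-captured scalar, e.g. an
+      // index, is a plain value inside omp.target, but the remapping above
+      // substitutes the host-side !fir.ref<index>; the load created below
+      // recovers a value of the original type for the declare's users.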
+ if (fir::DeclareOp clonedDeclareOp = dyn_cast(clonedOp)) { + auto originalDeclareOp = cast(op); + Type originalInType = originalDeclareOp.getMemref().getType(); + Type clonedInType = clonedDeclareOp.getMemref().getType(); + + fir::ReferenceType originalRefType = + dyn_cast(originalInType); + fir::ReferenceType clonedRefType = + dyn_cast(clonedInType); + if (!originalRefType && clonedRefType) { + Type clonedEleTy = clonedRefType.getElementType(); + if (clonedEleTy == originalDeclareOp.getType()) { + opsToReplace.push_back(clonedOp); + } + } + } + if (isa(clonedOp) || isa(clonedOp)) + opsToReplace.push_back(clonedOp); } } for (Operation *op : opsToReplace) { @@ -833,6 +852,15 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) { rewriter.create(freeOp.getLoc(), device, firConvertOp.getResult()); rewriter.eraseOp(freeOp); + } else if (fir::DeclareOp clonedDeclareOp = dyn_cast(op)) { + Type clonedInType = clonedDeclareOp.getMemref().getType(); + fir::ReferenceType clonedRefType = + dyn_cast(clonedInType); + Type clonedEleTy = clonedRefType.getElementType(); + rewriter.setInsertionPoint(op); + Value loadedValue = rewriter.create( + clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref()); + clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue); } } rewriter.eraseOp(targetOp); From 8c3785abb68ffbfcd42eb6e3a2dab7001c939217 Mon Sep 17 00:00:00 2001 From: skc7 Date: Sun, 3 Aug 2025 15:26:54 +0530 Subject: [PATCH 03/21] Handle case when private maps are present in omp.target --- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 133 +++++++++++++----- 1 file changed, 95 insertions(+), 38 deletions(-) diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 3f78727450f31..e61240e8aa443 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -566,22 +566,60 @@ static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp, toRecompute); } -static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter, - MLIRContext &ctx, IRMapping &mapping, - Operation *splitBefore, Block *targetBlock, - Block *newTargetBlock, - SmallVector &allocs, - SetVector &toRecompute) { - for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) { - auto originalArg = targetBlock->getArgument(i); +static void createBlockArgsAndMap(Location loc, RewriterBase &rewriter, + omp::TargetOp &targetOp, Block *targetBlock, + Block *newTargetBlock, + SmallVector &mapOperands, + SmallVector &allocs, + IRMapping &irMapping) { + // Map `map_operands` to block arguments. + unsigned originalMapVarsSize = targetOp.getMapVars().size(); + for (unsigned i = 0; i < mapOperands.size(); ++i) { + Value originalValue; + BlockArgument newArg; + // Map the new arguments from the original block. + if (i < originalMapVarsSize) { + originalValue = targetBlock->getArgument(i); + newArg = newTargetBlock->addArgument(originalValue.getType(), + originalValue.getLoc()); + } + // Map the new arguments from the `allocs`. + else { + originalValue = allocs[i - originalMapVarsSize]; + newArg = newTargetBlock->addArgument( + getPtrTypeForOmp(originalValue.getType()), originalValue.getLoc()); + } + irMapping.map(originalValue, newArg); + } + // Map `private_vars` to block arguments. 
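+  // (In the entry block of omp.target the private-variable arguments come
+  // after all of the map-variable arguments, which is why the lookups below
+  // are offset by originalMapVarsSize.)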
+ unsigned originalPrivateVarsSize = targetOp.getPrivateVars().size(); + for (unsigned i = 0; i < originalPrivateVarsSize; ++i) { + auto originalArg = targetBlock->getArgument(originalMapVarsSize + i); auto newArg = newTargetBlock->addArgument(originalArg.getType(), originalArg.getLoc()); - mapping.map(originalArg, newArg); + irMapping.map(originalArg, newArg); } - auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx); - for (auto original : allocs) { - Value newArg = newTargetBlock->addArgument( - getPtrTypeForOmp(original.getType()), original.getLoc()); + return; +} + +static void reloadCacheAndRecompute( + Location loc, RewriterBase &rewriter, Operation *splitBefore, + omp::TargetOp &targetOp, Block *targetBlock, Block *newTargetBlock, + SmallVector &mapOperands, SmallVector &allocs, + SetVector &toRecompute, IRMapping &irMapping) { + createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, newTargetBlock, + mapOperands, allocs, irMapping); + // Handle the load operations for the allocs. + rewriter.setInsertionPointToStart(newTargetBlock); + auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext()); + + unsigned originalMapVarsSize = targetOp.getMapVars().size(); + // Create Stores for allocs. + for (unsigned i = 0; i < allocs.size(); ++i) { + Value original = allocs[i]; + // Get the new block argument for this specific allocated value. + Value newArg = newTargetBlock->getArgument(originalMapVarsSize + i); + Value restored; if (isPtr(original.getType())) { restored = rewriter.create(loc, llvmPtrTy, newArg); @@ -591,18 +629,18 @@ static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter, } else { restored = rewriter.create(loc, newArg); } - mapping.map(original, restored); + irMapping.map(original, restored); } + for (auto it = targetBlock->begin(); it != splitBefore->getIterator(); it++) { if (toRecompute.contains(&*it)) - rewriter.clone(*it, mapping); + rewriter.clone(*it, irMapping); } } static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter, RewriterBase &rewriter) { auto targetOp = cast(splitBeforeOp->getParentOp()); - MLIRContext &ctx = *targetOp.getContext(); assert(targetOp); auto loc = targetOp.getLoc(); auto *targetBlock = &targetOp.getRegion().front(); @@ -657,22 +695,29 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter, auto *preTargetBlock = rewriter.createBlock( &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {}); IRMapping preMapping; - for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) { - auto originalArg = targetBlock->getArgument(i); - auto newArg = preTargetBlock->addArgument(originalArg.getType(), - originalArg.getLoc()); - preMapping.map(originalArg, newArg); - } - for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); it++) - rewriter.clone(*it, preMapping); + createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, preTargetBlock, + preMapOperands, allocs, preMapping); + + // Handle the store operations for the allocs. + rewriter.setInsertionPointToStart(preTargetBlock); auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext()); - for (auto original : allocs) { - Value toStore = preMapping.lookup(original); - auto newArg = preTargetBlock->addArgument( - getPtrTypeForOmp(original.getType()), original.getLoc()); - if (isPtr(original.getType())) { + // Clone the original operations. 
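+  // Everything up to the split point is duplicated into the pre-target
+  // region; the values that are still needed afterwards are then written to
+  // the temporary allocations so the later target regions can reload them.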
+ for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); + it++) { + rewriter.clone(*it, preMapping); + } + + unsigned originalMapVarsSize = targetOp.getMapVars().size(); + // Create Stores for allocs. + for (unsigned i = 0; i < allocs.size(); ++i) { + Value originalResult = allocs[i]; + Value toStore = preMapping.lookup(originalResult); + // Get the new block argument for this specific allocated value. + Value newArg = preTargetBlock->getArgument(originalMapVarsSize + i); + + if (isPtr(originalResult.getType())) { if (!isa(toStore.getType())) toStore = rewriter.create(loc, llvmPtrTy, toStore); rewriter.create(loc, toStore, newArg); @@ -701,9 +746,9 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter, isolatedTargetOp.getRegion().begin(), {}, {}); IRMapping isolatedMapping; - reloadCacheAndRecompute(loc, rewriter, ctx, isolatedMapping, splitBeforeOp, - targetBlock, isolatedTargetBlock, allocs, - toRecompute); + reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock, + isolatedTargetBlock, postMapOperands, allocs, + toRecompute, isolatedMapping); rewriter.clone(*splitBeforeOp, isolatedMapping); rewriter.create(loc); @@ -725,8 +770,9 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter, auto *postTargetBlock = rewriter.createBlock( &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {}); IRMapping postMapping; - reloadCacheAndRecompute(loc, rewriter, ctx, postMapping, splitBeforeOp, - targetBlock, postTargetBlock, allocs, toRecompute); + reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock, + postTargetBlock, postMapOperands, allocs, + toRecompute, postMapping); assert(splitBeforeOp->getNumResults() == 0 || llvm::all_of(splitBeforeOp->getResults(), @@ -755,15 +801,24 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) { Block *targetBlock = &targetOp.getRegion().front(); assert(targetBlock == &targetOp.getRegion().back()); IRMapping mapping; - for (auto map : - zip_equal(targetOp.getMapVars(), targetBlock->getArguments())) { - Value mapInfo = std::get<0>(map); - BlockArgument arg = std::get<1>(map); + for (unsigned i = 0; i < targetOp.getMapVars().size(); ++i) { + Value mapInfo = targetOp.getMapVars()[i]; + BlockArgument arg = targetBlock->getArguments()[i]; Operation *op = mapInfo.getDefiningOp(); assert(op); auto mapInfoOp = cast(op); + // map the block argument to the host-side variable pointer mapping.map(arg, mapInfoOp.getVarPtr()); } + unsigned mapSize = targetOp.getMapVars().size(); + for (unsigned i = 0; i < targetOp.getPrivateVars().size(); ++i) { + Value privateVar = targetOp.getPrivateVars()[i]; + // The mapping should link the device-side variable to the host-side one. + BlockArgument arg = targetBlock->getArguments()[mapSize + i]; + // Map the device-side copy (`arg`) to the host-side value (`privateVar`). + mapping.map(arg, privateVar); + } + rewriter.setInsertionPoint(targetOp); SmallVector opsToReplace; Value device = targetOp.getDevice(); @@ -813,6 +868,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) { // fir.declare changes its type when hoisting it out of omp.target to // omp.target_data Introduce a load, if original declareOp input is not of // reference type, but cloned delcareOp input is reference type. 
+    if (fir::DeclareOp clonedDeclareOp = dyn_cast(clonedOp)) {
       auto originalDeclareOp = cast(op);
       Type originalInType = originalDeclareOp.getMemref().getType();
       Type clonedInType = clonedDeclareOp.getMemref().getType();
@@ -833,6 +889,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
       opsToReplace.push_back(clonedOp);
     }
   }
+
   for (Operation *op : opsToReplace) {
     if (auto allocOp = dyn_cast(op)) {
       rewriter.setInsertionPoint(allocOp);

From a93038786d92db38d7727680f051ed649bcf88ff Mon Sep 17 00:00:00 2001
From: skc7
Date: Mon, 4 Aug 2025 11:52:27 +0530
Subject: [PATCH 04/21] Add comments/descriptions for functions.

---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 197 +++++++++++-------
 1 file changed, 121 insertions(+), 76 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index e61240e8aa443..ece64a1ba1d4d 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -1,4 +1,4 @@
-//===- LowerWorkshare.cpp - special cases for bufferization -------===//
+//===- LowerWorkdistribute.cpp --------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -8,6 +9,16 @@
 //
 // This file implements the lowering and optimisations of omp.workdistribute.
 //
+// Fortran array statements are lowered to FIR as unordered fir.do_loop ops.
+// The lower-workdistribute pass works mainly by identifying a fir.do_loop
+// unordered nested in target{teams{workdistribute{fir.do_loop unordered}}}
+// and lowering it to target{teams{parallel{wsloop{loop_nest}}}}.
+// It hoists all the other ops outside the target region.
+// Replaces heap allocation on the target with omp.target_allocmem and
+// deallocation with omp.target_freemem from the host. Also replaces the
+// runtime function "Assign" with an equivalent omp function: @_FortranAAssign
+// on the target is replaced with @_FortranAAssign_omp once hoisted out.
+//
 //===----------------------------------------------------------------------===//

 #include "flang/Optimizer/Builder/FIRBuilder.h"
@@ -49,6 +60,8 @@ using namespace mlir;

 namespace {

+// The isRuntimeCall function is a utility designed to determine
+// if a given operation is a call to a Fortran-specific runtime function.
 static bool isRuntimeCall(Operation *op) {
   if (auto callOp = dyn_cast<fir::CallOp>(op)) {
     auto callee = callOp.getCallee();
@@ -61,8 +74,8 @@ static bool isRuntimeCall(Operation *op) {
   return false;
 }

-/// This is the single source of truth about whether we should parallelize an
-/// operation nested in an omp.execute region.
+// This is the single source of truth about whether we should parallelize an
+// operation nested in an omp.execute region.
 static bool shouldParallelize(Operation *op) {
   if (llvm::any_of(op->getResults(),
                    [](OpResult v) -> bool { return !v.use_empty(); }))
@@ -74,13 +87,16 @@ static bool shouldParallelize(Operation *op) {
     return false;
     return *unordered;
   }
-  if (isRuntimeCall(op)) {
+  if (isRuntimeCall(op) &&
+      (*cast<fir::CallOp>(op).getCallee()).getRootReference().getValue() ==
+          "_FortranAAssign") {
    return true;
   }
-  // We cannot parallise anything else
+  // We cannot parallelize anything else.
   return false;
 }

+// The getPerfectlyNested function is a generic utility for finding
+// a single, "perfectly nested" operation within a parent operation.
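+// For example, getPerfectlyNested<omp::WorkdistributeOp>(teamsOp) returns the
+// nested workdistribute op only when it is the sole op in the teams region's
+// single block, immediately followed by the terminator.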
template <typename T>
 static T getPerfectlyNested(Operation *op) {
   if (op->getNumRegions() != 1)
     return nullptr;
   auto &region = op->getRegion(0);
   if (region.getBlocks().size() != 1)
     return nullptr;
   auto *block = &region.front();
   auto *firstOp = &block->front();
   if (auto nested = dyn_cast<T>(firstOp))
     if (firstOp->getNextNode() == block->getTerminator())
       return nested;
   return nullptr;
 }
@@ -96,33 +112,37 @@

-/// If B() and D() are parallelizable,
-///
-/// omp.teams {
-///   omp.workdistribute {
-///     A()
-///     B()
-///     C()
-///     D()
-///     E()
-///   }
-/// }
-///
-/// becomes
-///
-/// A()
-/// omp.teams {
-///   omp.workdistribute {
-///     B()
-///   }
-/// }
-/// C()
-/// omp.teams {
-///   omp.workdistribute {
-///     D()
-///   }
-/// }
-/// E()
+// FissionWorkdistribute method finds the parallelizable ops
+// within teams {workdistribute} region and moves them to their
+// own teams{workdistribute} region.
+//
+// If B() and D() are parallelizable,
+//
+// omp.teams {
+//   omp.workdistribute {
+//     A()
+//     B()
+//     C()
+//     D()
+//     E()
+//   }
+// }
+//
+// becomes
+//
+// A()
+// omp.teams {
+//   omp.workdistribute {
+//     B()
+//   }
+// }
+// C()
+// omp.teams {
+//   omp.workdistribute {
+//     D()
+//   }
+// }
+// E()

 static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
   OpBuilder rewriter(workdistribute);
@@ -215,29 +235,6 @@ static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
   return changed;
 }

-/// If fir.do_loop is present inside teams workdistribute
-///
-/// omp.teams {
-///   omp.workdistribute {
-///     fir.do_loop unoredered {
-///       ...
-///     }
-///   }
-/// }
-///
-/// Then, its lowered to
-///
-/// omp.teams {
-///   omp.parallel {
-///     omp.distribute {
-///       omp.wsloop {
-///         omp.loop_nest
-///         ...
-///       }
-///     }
-///   }
-/// }
-
 static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
   auto parallelOp = rewriter.create<omp::ParallelOp>(loc);
   parallelOp.setComposite(composite);
@@ -295,6 +292,33 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
   return;
 }

+// WorkdistributeDoLower method finds the fir.do_loop unordered
+// nested in teams {workdistribute{fir.do_loop unordered}} and
+// lowers it to teams {parallel {distribute {wsloop {loop_nest}}}}.
+//
+// If fir.do_loop is present inside teams workdistribute
+//
+// omp.teams {
+//   omp.workdistribute {
+//     fir.do_loop unordered {
+//       ...
+//     }
+//   }
+// }
+//
+// Then, it is lowered to
+//
+// omp.teams {
+//   omp.parallel {
+//     omp.distribute {
+//       omp.wsloop {
+//         omp.loop_nest
+//         ...
+//       }
+//     }
+//   }
+// }
+
 static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
   OpBuilder rewriter(workdistribute);
@@ -312,20 +336,23 @@ static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
   return false;
 }

-/// If A() and B () are present inside teams workdistribute
-///
-/// omp.teams {
-///   omp.workdistribute {
-///     A()
-///     B()
-///   }
-/// }
-///
-/// Then, its lowered to
-///
-/// A()
-/// B()
-///
+// TeamsWorkdistributeToSingleOp method hoists all the ops inside
+// teams {workdistribute{}} before teams op.
+//
+// If A() and B() are present inside teams workdistribute
+//
+// omp.teams {
+//   omp.workdistribute {
+//     A()
+//     B()
+//   }
+// }
+//
+// Then, it is lowered to
+//
+// A()
+// B()
+//
 static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) {
   auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
@@ -358,11 +385,11 @@ static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) {
   return true;
 }

-/// If multiple workdistribute are nested in a target regions, we will need to
-/// split the target region, but we want to preserve the data semantics of the
-/// original data region and avoid unnecessary data movement at each of the
-/// subkernels - we split the target region into a target_data{target}
-/// nest where only the outer one moves the data
+// If multiple workdistribute are nested in a target region, we will need to
+// split the target region, but we want to preserve the data semantics of the
+// original data region and avoid unnecessary data movement at each of the
+// subkernels - we split the target region into a target_data{target}
+// nest where only the outer one moves the data
 std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
                                                  RewriterBase &rewriter) {
   auto loc = targetOp->getLoc();
@@ -438,6 +465,10 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
   return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
 }

+// getNestedOpToIsolate function is designed to identify a specific teams
+// parallel op within the body of an omp::TargetOp that should be "isolated."
+// It returns a tuple of the op, whether it is the first op in the target
+// block, and whether it is the last op in the target block.
 static std::optional<std::tuple<Operation *, bool, bool>>
 getNestedOpToIsolate(omp::TargetOp targetOp) {
   if (targetOp.getRegion().empty())
@@ -638,6 +669,15 @@
   }
 }

+// isolateOp method rewrites a omp.target_data { omp.target } in to
+// omp.target_data {
+//   // preTargetOp region contains ops before splitBeforeOp.
+//   omp.target {}
+//   // isolatedTargetOp region contains splitBeforeOp,
+//   omp.target {}
+//   // postTargetOp region contains ops after splitBeforeOp.
+//   omp.target {}
+// }
 static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
                              RewriterBase &rewriter) {
   auto targetOp = cast(splitBeforeOp->getParentOp());
@@ -796,6 +836,10 @@ genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {

 static Type getOmpDeviceType(MLIRContext *c) { return IntegerType::get(c, 32); }

+// moveToHost method clones all the ops from target region outside of it.
+// It hoists runtime functions and replaces them with omp versions.
+// Also hoists and replaces fir.allocmem with omp.target_allocmem and
+// fir.freemem with omp.target_freemem
 static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
   OpBuilder::InsertionGuard guard(rewriter);
   Block *targetBlock = &targetOp.getRegion().front();
   assert(targetBlock == &targetOp.getRegion().back());
   IRMapping mapping;
@@ -815,7 +859,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) {
     Value privateVar = targetOp.getPrivateVars()[i];
     // The mapping should link the device-side variable to the host-side one.
     BlockArgument arg = targetBlock->getArguments()[mapSize + i];
-    // Map the device-side copy (`arg`) to the host-side value (`privateVar`).
+    // Map the device-side copy (arg) to the host-side value (privateVar).
mapping.map(arg, privateVar); } @@ -868,7 +912,6 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) { // fir.declare changes its type when hoisting it out of omp.target to // omp.target_data Introduce a load, if original declareOp input is not of // reference type, but cloned delcareOp input is reference type. - if (fir::DeclareOp clonedDeclareOp = dyn_cast(clonedOp)) { auto originalDeclareOp = cast(op); Type originalInType = originalDeclareOp.getMemref().getType(); @@ -890,6 +933,8 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) { } } + // Replace fir.allocmem with omp.target_allocmem, + // fir.freemem with omp.target_freemem. for (Operation *op : opsToReplace) { if (auto allocOp = dyn_cast(op)) { rewriter.setInsertionPoint(allocOp); From 85caac1fb1a15626726be755301d777f1b7b1101 Mon Sep 17 00:00:00 2001 From: skc7 Date: Thu, 28 Aug 2025 16:53:33 +0530 Subject: [PATCH 05/21] update moveToHost implementation --- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 259 +++++++++++++----- .../lower-workdistribute-fission-target.mlir | 42 +-- 2 files changed, 214 insertions(+), 87 deletions(-) diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index ece64a1ba1d4d..8ead3d57eca98 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -834,13 +834,140 @@ genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) { return rewriter.create(loc, i32Ty, attr); } -static Type getOmpDeviceType(MLIRContext *c) { return IntegerType::get(c, 32); } +static mlir::LLVM::ConstantOp +genI64Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) { + mlir::Type i64Ty = rewriter.getI64Type(); + mlir::IntegerAttr attr = rewriter.getI64IntegerAttr(value); + return rewriter.create(loc, i64Ty, attr); +} + +static Value genDescriptorGetBaseAddress(fir::FirOpBuilder &builder, + Location loc, Value boxDesc) { + Value box = boxDesc; + if (auto refBox = dyn_cast(boxDesc.getType())) { + box = fir::LoadOp::create(builder, loc, boxDesc); + } + assert(isa(box.getType()) && + "Unknown type passed to genDescriptorGetBaseAddress"); + auto i8Type = builder.getI8Type(); + auto unknownArrayType = + fir::SequenceType::get({fir::SequenceType::getUnknownExtent()}, i8Type); + auto i8BoxType = fir::BoxType::get(unknownArrayType); + auto typedBox = fir::ConvertOp::create(builder, loc, i8BoxType, box); + auto rawAddr = fir::BoxAddrOp::create(builder, loc, typedBox); + return rawAddr; +} + +static Value genDescriptorGetTotalElements(fir::FirOpBuilder &builder, + Location loc, Value boxDesc) { + Value box = boxDesc; + if (auto refBox = dyn_cast(boxDesc.getType())) { + box = fir::LoadOp::create(builder, loc, boxDesc); + } + assert(isa(box.getType()) && + "Unknown type passed to genDescriptorGetTotalElements"); + auto i64Type = builder.getI64Type(); + return fir::BoxTotalElementsOp::create(builder, loc, i64Type, box); +} + +static Value genDescriptorGetEleSize(fir::FirOpBuilder &builder, Location loc, + Value boxDesc) { + Value box = boxDesc; + if (auto refBox = dyn_cast(boxDesc.getType())) { + box = fir::LoadOp::create(builder, loc, boxDesc); + } + assert(isa(box.getType()) && + "Unknown type passed to genDescriptorGetElementSize"); + auto i64Type = builder.getI64Type(); + return fir::BoxEleSizeOp::create(builder, loc, i64Type, box); +} + +static Value genDescriptorGetDataSizeInBytes(fir::FirOpBuilder &builder, + Location loc, Value 
boxDesc) {
+  Value box = boxDesc;
+  if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+    box = fir::LoadOp::create(builder, loc, boxDesc);
+  }
+  assert(isa<fir::BoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetDataSizeInBytes");
+  Value eleSize = genDescriptorGetEleSize(builder, loc, box);
+  Value totalElements = genDescriptorGetTotalElements(builder, loc, box);
+  return mlir::arith::MulIOp::create(builder, loc, totalElements, eleSize);
+}
+
+static mlir::Value genOmpGetMappedPtrIfPresent(fir::FirOpBuilder &builder,
+                                               mlir::Location loc,
+                                               mlir::Value hostPtr,
+                                               mlir::Value deviceNum,
+                                               mlir::ModuleOp module) {
+  auto *context = builder.getContext();
+  auto voidPtrType = fir::LLVMPointerType::get(context, builder.getI8Type());
+  auto i32Type = builder.getI32Type();
+  auto funcName = "omp_get_mapped_ptr";
+  auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(funcName);
+
+  if (!funcOp) {
+    auto funcType =
+        mlir::FunctionType::get(context, {voidPtrType, i32Type}, {voidPtrType});
+
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(module.getBody());
+
+    funcOp = mlir::func::FuncOp::create(builder, loc, funcName, funcType);
+    funcOp.setPrivate();
+  }
+
+  llvm::SmallVector<mlir::Value> args;
+  args.push_back(fir::ConvertOp::create(builder, loc, voidPtrType, hostPtr));
+  args.push_back(fir::ConvertOp::create(builder, loc, i32Type, deviceNum));
+  auto callOp = fir::CallOp::create(builder, loc, funcOp, args);
+  auto mappedPtr = callOp.getResult(0);
+  auto isNull = builder.genIsNullAddr(loc, mappedPtr);
+  auto convertedHostPtr =
+      fir::ConvertOp::create(builder, loc, voidPtrType, hostPtr);
+  auto result = arith::SelectOp::create(builder, loc, isNull, convertedHostPtr,
+                                        mappedPtr);
+  return result;
+}
+
+static void genOmpTargetMemcpyCall(fir::FirOpBuilder &builder,
+                                   mlir::Location loc, mlir::Value dst,
+                                   mlir::Value src, mlir::Value length,
+                                   mlir::Value dstOffset, mlir::Value srcOffset,
+                                   mlir::Value device, mlir::ModuleOp module) {
+  auto *context = builder.getContext();
+  // int omp_target_memcpy(void *dst, const void *src, size_t length,
+  //                       size_t dst_offset, size_t src_offset,
+  //                       int dst_device, int src_device)
+  auto funcName = "omp_target_memcpy";
+  auto voidPtrType = fir::LLVMPointerType::get(context, builder.getI8Type());
+  auto sizeTType = builder.getI64Type(); // assuming size_t is 64-bit
+  auto i32Type = builder.getI32Type();
+  auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(funcName);
+
+  if (!funcOp) {
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(module.getBody());
+    llvm::SmallVector<mlir::Type> argTypes = {
+        voidPtrType, voidPtrType, sizeTType, sizeTType,
+        sizeTType,   i32Type,     i32Type};
+    auto funcType = mlir::FunctionType::get(context, argTypes, {i32Type});
+    funcOp = mlir::func::FuncOp::create(builder, loc, funcName, funcType);
+    funcOp.setPrivate();
+  }
+
+  llvm::SmallVector<mlir::Value> args{dst,       src,    length, dstOffset,
+                                      srcOffset, device, device};
+  fir::CallOp::create(builder, loc, funcOp, args);
+  return;
+}

 // moveToHost method clones all the ops from target region outside of it.
 // It hoists runtime functions and replaces them with omp versions.
// Also hoists and replaces fir.allocmem with omp.target_allocmem and // fir.freemem with omp.target_freemem -static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) { +static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, + mlir::ModuleOp module) { OpBuilder::InsertionGuard guard(rewriter); Block *targetBlock = &targetOp.getRegion().front(); assert(targetBlock == &targetOp.getRegion().back()); @@ -859,7 +986,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) { Value privateVar = targetOp.getPrivateVars()[i]; // The mapping should link the device-side variable to the host-side one. BlockArgument arg = targetBlock->getArguments()[mapSize + i]; - // Map the device-side copy (arg) to the host-side value (privateVar). + // Map the device-side copy (`arg`) to the host-side value (`privateVar`). mapping.map(arg, privateVar); } @@ -872,69 +999,43 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) { for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end()); it != end; ++it) { auto *op = &*it; - if (isRuntimeCall(op)) { - fir::CallOp runtimeCall = cast(op); - auto module = runtimeCall->getParentOfType(); - auto callee = - cast(module.lookupSymbol(runtimeCall.getCalleeAttr())); - std::string newCalleeName = (callee.getName() + "_omp").str(); - mlir::OpBuilder moduleBuilder(module.getBodyRegion()); - func::FuncOp newCallee = - cast_or_null(module.lookupSymbol(newCalleeName)); - if (!newCallee) { - SmallVector argTypes(callee.getFunctionType().getInputs()); - argTypes.push_back(getOmpDeviceType(rewriter.getContext())); - newCallee = moduleBuilder.create( - callee->getLoc(), newCalleeName, - FunctionType::get(rewriter.getContext(), argTypes, - callee.getFunctionType().getResults())); - if (callee.getArgAttrs()) - newCallee.setArgAttrsAttr(*callee.getArgAttrs()); - if (callee.getResAttrs()) - newCallee.setResAttrsAttr(*callee.getResAttrs()); - newCallee.setSymVisibility(callee.getSymVisibility()); - newCallee->setDiscardableAttrs(callee->getDiscardableAttrDictionary()); - } - SmallVector operands = runtimeCall.getOperands(); - operands.push_back(device); - auto tmpCall = rewriter.create( - runtimeCall.getLoc(), runtimeCall.getResultTypes(), - SymbolRefAttr::get(newCallee), operands, nullptr, nullptr, nullptr, - runtimeCall.getFastmathAttr()); - Operation *newCall = rewriter.clone(*tmpCall, mapping); - mapping.map(&*it, newCall); - rewriter.eraseOp(tmpCall); - } else { - Operation *clonedOp = rewriter.clone(*op, mapping); - for (unsigned i = 0; i < op->getNumResults(); ++i) { - mapping.map(op->getResult(i), clonedOp->getResult(i)); - } - // fir.declare changes its type when hoisting it out of omp.target to - // omp.target_data Introduce a load, if original declareOp input is not of - // reference type, but cloned delcareOp input is reference type. 
- if (fir::DeclareOp clonedDeclareOp = dyn_cast(clonedOp)) { - auto originalDeclareOp = cast(op); - Type originalInType = originalDeclareOp.getMemref().getType(); - Type clonedInType = clonedDeclareOp.getMemref().getType(); - - fir::ReferenceType originalRefType = - dyn_cast(originalInType); - fir::ReferenceType clonedRefType = - dyn_cast(clonedInType); - if (!originalRefType && clonedRefType) { - Type clonedEleTy = clonedRefType.getElementType(); - if (clonedEleTy == originalDeclareOp.getType()) { - opsToReplace.push_back(clonedOp); - } + Operation *clonedOp = rewriter.clone(*op, mapping); + for (unsigned i = 0; i < op->getNumResults(); ++i) { + mapping.map(op->getResult(i), clonedOp->getResult(i)); + } + // fir.declare changes its type when hoisting it out of omp.target to + // omp.target_data Introduce a load, if original declareOp input is not of + // reference type, but cloned delcareOp input is reference type. + + if (fir::DeclareOp clonedDeclareOp = dyn_cast(clonedOp)) { + auto originalDeclareOp = cast(op); + Type originalInType = originalDeclareOp.getMemref().getType(); + Type clonedInType = clonedDeclareOp.getMemref().getType(); + + fir::ReferenceType originalRefType = + dyn_cast(originalInType); + fir::ReferenceType clonedRefType = + dyn_cast(clonedInType); + if (!originalRefType && clonedRefType) { + Type clonedEleTy = clonedRefType.getElementType(); + if (clonedEleTy == originalDeclareOp.getType()) { + opsToReplace.push_back(clonedOp); } } + } if (isa(clonedOp) || isa(clonedOp)) opsToReplace.push_back(clonedOp); - } + if (isRuntimeCall(clonedOp)) { + fir::CallOp runtimeCall = cast(op); + if ((*runtimeCall.getCallee()).getRootReference().getValue() == + "_FortranAAssign") { + opsToReplace.push_back(clonedOp); + } else { + llvm_unreachable("Unhandled runtime call hoisting."); + } + } } - // Replace fir.allocmem with omp.target_allocmem, - // fir.freemem with omp.target_freemem. 
for (Operation *op : opsToReplace) { if (auto allocOp = dyn_cast(op)) { rewriter.setInsertionPoint(allocOp); @@ -963,16 +1064,40 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter) { Value loadedValue = rewriter.create( clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref()); clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue); + } else if (isRuntimeCall(op)) { + rewriter.setInsertionPoint(op); + fir::CallOp runtimeCall = cast(op); + SmallVector operands = runtimeCall.getOperands(); + mlir::Location loc = runtimeCall.getLoc(); + fir::FirOpBuilder builder{rewriter, op}; + assert(operands.size() == 4); + Value sourceFile{operands[2]}, sourceLine{operands[3]}; + + auto fromBaseAddr = + genDescriptorGetBaseAddress(builder, loc, operands[1]); + auto toBaseAddr = genDescriptorGetBaseAddress(builder, loc, operands[0]); + auto dataSizeInBytes = + genDescriptorGetDataSizeInBytes(builder, loc, operands[1]); + + Value toPtr = + genOmpGetMappedPtrIfPresent(builder, loc, toBaseAddr, device, module); + Value fromPtr = genOmpGetMappedPtrIfPresent(builder, loc, fromBaseAddr, + device, module); + Value zero = genI64Constant(loc, rewriter, 0); + genOmpTargetMemcpyCall(builder, loc, toPtr, fromPtr, dataSizeInBytes, + zero, zero, device, module); + rewriter.eraseOp(op); } } rewriter.eraseOp(targetOp); } -void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) { +void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter, + mlir::ModuleOp module) { auto tuple = getNestedOpToIsolate(targetOp); if (!tuple) { LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n"); - moveToHost(targetOp, rewriter); + moveToHost(targetOp, rewriter, module); return; } @@ -982,18 +1107,18 @@ void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter) { if (splitBefore && splitAfter) { auto res = isolateOp(toIsolate, splitAfter, rewriter); - moveToHost(res.preTargetOp, rewriter); - fissionTarget(res.postTargetOp, rewriter); + moveToHost(res.preTargetOp, rewriter, module); + fissionTarget(res.postTargetOp, rewriter, module); return; } if (splitBefore) { auto res = isolateOp(toIsolate, splitAfter, rewriter); - moveToHost(res.preTargetOp, rewriter); + moveToHost(res.preTargetOp, rewriter, module); return; } if (splitAfter) { auto res = isolateOp(toIsolate->getNextNode(), splitAfter, rewriter); - fissionTarget(res.postTargetOp, rewriter); + fissionTarget(res.postTargetOp, rewriter, module); return; } } @@ -1023,7 +1148,7 @@ class LowerWorkdistributePass for (auto targetOp : targetOps) { auto res = splitTargetData(targetOp, rewriter); if (res) - fissionTarget(res->targetOp, rewriter); + fissionTarget(res->targetOp, rewriter, moduleOp); } } } diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir index 19bdb9ce10fbd..25ef34f81b492 100644 --- a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir @@ -14,7 +14,7 @@ // CHECK: %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref {name = "lb"} // CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref {name = "ub"} // CHECK: %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref {name = "step"} -// 
CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref {name = "addr"} +// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3]] : !fir.ref, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref {name = "addr"} // CHECK: omp.target_data map_entries(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref) { // CHECK: %[[VAL_11:.*]] = fir.alloca index // CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref, index) map_clauses(from) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_from"} @@ -28,29 +28,30 @@ // CHECK: %[[VAL_20:.*]] = fir.alloca !fir.heap // CHECK: %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref>, !fir.heap) map_clauses(from) capture(ByRef) -> !fir.ref> {name = "__flang_workdistribute_from"} // CHECK: %[[VAL_22:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref>, !fir.heap) map_clauses(to) capture(ByRef) -> !fir.ref> {name = "__flang_workdistribute_to"} -// CHECK: %[[VAL_23:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_23:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_0]] : !fir.ref // CHECK: %[[VAL_25:.*]] = fir.load %[[VAL_1]] : !fir.ref // CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_2]] : !fir.ref -// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_25]] : index -// CHECK: %[[VAL_28:.*]] = llvm.mlir.constant(0 : i32) : i32 -// CHECK: %[[VAL_29:.*]] = "fir.omp_target_allocmem"(%[[VAL_28]], %[[VAL_23]]) <{in_type = index, operandSegmentSizes = array, uniq_name = "dev_buf"}> : (i32, index) -> !fir.heap +// CHECK: %[[VAL_27:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_25]], %[[VAL_25]] : index +// CHECK: %[[VAL_29:.*]] = omp.target_allocmem %[[VAL_23]] : i32, index, %[[VAL_27]] {uniq_name = "dev_buf"} +// CHECK: %[[VAL_30:.*]] = fir.convert %[[VAL_29]] : (i64) -> !fir.heap // CHECK: fir.store %[[VAL_24]] to %[[VAL_11]] : !fir.ref // CHECK: fir.store %[[VAL_25]] to %[[VAL_14]] : !fir.ref // CHECK: fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref -// CHECK: fir.store %[[VAL_29]] to %[[VAL_20]] : !fir.ref> -// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_30:.*]], %[[VAL_8]] -> %[[VAL_31:.*]], %[[VAL_9]] -> %[[VAL_32:.*]], %[[VAL_10]] -> %[[VAL_33:.*]], %[[VAL_13]] -> %[[VAL_34:.*]], %[[VAL_16]] -> %[[VAL_35:.*]], %[[VAL_19]] -> %[[VAL_36:.*]], %[[VAL_22]] -> %[[VAL_37:.*]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref>) { -// CHECK: %[[VAL_38:.*]] = fir.load %[[VAL_34]] : !fir.llvm_ptr +// CHECK: fir.store %[[VAL_30]] to %[[VAL_20]] : !fir.ref> +// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_31:.*]], %[[VAL_8]] -> %[[VAL_32:.*]], %[[VAL_9]] -> %[[VAL_33:.*]], %[[VAL_10]] -> %[[VAL_34:.*]], %[[VAL_13]] -> %[[VAL_35:.*]], %[[VAL_16]] -> %[[VAL_36:.*]], %[[VAL_19]] -> %[[VAL_37:.*]], %[[VAL_22]] -> %[[VAL_38:.*]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref>) { // CHECK: %[[VAL_39:.*]] = fir.load %[[VAL_35]] : !fir.llvm_ptr // CHECK: %[[VAL_40:.*]] = fir.load %[[VAL_36]] : !fir.llvm_ptr -// CHECK: %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.llvm_ptr> -// CHECK: %[[VAL_42:.*]] = arith.addi %[[VAL_39]], %[[VAL_39]] : index +// CHECK: %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.llvm_ptr +// CHECK: %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.llvm_ptr> +// CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_40]], %[[VAL_40]] : index // CHECK: omp.teams { // 
CHECK: omp.parallel { // CHECK: omp.distribute { // CHECK: omp.wsloop { -// CHECK: omp.loop_nest (%[[VAL_43:.*]]) : index = (%[[VAL_38]]) to (%[[VAL_39]]) inclusive step (%[[VAL_40]]) { -// CHECK: fir.store %[[VAL_42]] to %[[VAL_41]] : !fir.heap +// CHECK: omp.loop_nest (%[[VAL_44:.*]]) : index = (%[[VAL_39]]) to (%[[VAL_40]]) inclusive step (%[[VAL_41]]) { +// CHECK: fir.store %[[VAL_43]] to %[[VAL_42]] : !fir.heap // CHECK: omp.yield // CHECK: } // CHECK: } {omp.composite} @@ -61,14 +62,15 @@ // CHECK: } // CHECK: omp.terminator // CHECK: } -// CHECK: %[[VAL_44:.*]] = fir.load %[[VAL_11]] : !fir.ref -// CHECK: %[[VAL_45:.*]] = fir.load %[[VAL_14]] : !fir.ref -// CHECK: %[[VAL_46:.*]] = fir.load %[[VAL_17]] : !fir.ref -// CHECK: %[[VAL_47:.*]] = fir.load %[[VAL_20]] : !fir.ref> -// CHECK: %[[VAL_48:.*]] = arith.addi %[[VAL_45]], %[[VAL_45]] : index -// CHECK: fir.store %[[VAL_44]] to %[[VAL_47]] : !fir.heap -// CHECK: %[[VAL_49:.*]] = llvm.mlir.constant(0 : i32) : i32 -// CHECK: "fir.omp_target_freemem"(%[[VAL_49]], %[[VAL_47]]) : (i32, !fir.heap) -> () +// CHECK: %[[VAL_45:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[VAL_46:.*]] = fir.load %[[VAL_11]] : !fir.ref +// CHECK: %[[VAL_47:.*]] = fir.load %[[VAL_14]] : !fir.ref +// CHECK: %[[VAL_48:.*]] = fir.load %[[VAL_17]] : !fir.ref +// CHECK: %[[VAL_49:.*]] = fir.load %[[VAL_20]] : !fir.ref> +// CHECK: %[[VAL_50:.*]] = arith.addi %[[VAL_47]], %[[VAL_47]] : index +// CHECK: fir.store %[[VAL_46]] to %[[VAL_49]] : !fir.heap +// CHECK: %[[VAL_51:.*]] = fir.convert %[[VAL_49]] : (!fir.heap) -> i64 +// CHECK: omp.target_freemem %[[VAL_45]], %[[VAL_51]] : i32, i64 // CHECK: omp.terminator // CHECK: } // CHECK: return From bdda7576d463e609301c1f032efb50326d36024b Mon Sep 17 00:00:00 2001 From: skc7 Date: Tue, 2 Sep 2025 21:37:39 +0530 Subject: [PATCH 06/21] clang-format code --- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 8ead3d57eca98..fe07070f008c9 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -1023,17 +1023,17 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, } } } - if (isa(clonedOp) || isa(clonedOp)) + if (isa(clonedOp) || isa(clonedOp)) + opsToReplace.push_back(clonedOp); + if (isRuntimeCall(clonedOp)) { + fir::CallOp runtimeCall = cast(op); + if ((*runtimeCall.getCallee()).getRootReference().getValue() == + "_FortranAAssign") { opsToReplace.push_back(clonedOp); - if (isRuntimeCall(clonedOp)) { - fir::CallOp runtimeCall = cast(op); - if ((*runtimeCall.getCallee()).getRootReference().getValue() == - "_FortranAAssign") { - opsToReplace.push_back(clonedOp); - } else { - llvm_unreachable("Unhandled runtime call hoisting."); - } + } else { + llvm_unreachable("Unhandled runtime call hoisting."); } + } } for (Operation *op : opsToReplace) { From ac807b6f7bc360dd46c6223f0b483c468810c730 Mon Sep 17 00:00:00 2001 From: skc7 Date: Wed, 10 Sep 2025 11:41:51 +0530 Subject: [PATCH 07/21] Remove openp-to-llvm ir changes. 
Created new PR #157717

---
 .../Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 1ece86729dbd0..9fcb02eb4be3d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -5518,9 +5518,6 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
   unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;

-  if (targetOp.getHostEvalVars().empty())
-    numLoops = 0;
-
   Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
   llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
       steps(numLoops);

From fd6752dd1ca49a4c2be8986d196fda1b3b41b2c5 Mon Sep 17 00:00:00 2001
From: skc7
Date: Wed, 10 Sep 2025 15:39:22 +0530
Subject: [PATCH 08/21] Fix CI errors

---
 flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index fe07070f008c9..7fc59eee2ca2d 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -540,7 +540,7 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
   auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from");
   auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to");
   return TempOmpVar{mapInfoFrom, mapInfoTo};
-};
+}

 static bool usedOutsideSplit(Value v, Operation *split) {
   if (!split)
@@ -1071,8 +1071,6 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
       mlir::Location loc = runtimeCall.getLoc();
       fir::FirOpBuilder builder{rewriter, op};
       assert(operands.size() == 4);
-      Value sourceFile{operands[2]}, sourceLine{operands[3]};
-
       auto fromBaseAddr =
           genDescriptorGetBaseAddress(builder, loc, operands[1]);
       auto toBaseAddr = genDescriptorGetBaseAddress(builder, loc, operands[0]);

From 981607376e2c8b66dcfd3ce3d9a95d7cf61debc4 Mon Sep 17 00:00:00 2001
From: skc7
Date: Mon, 15 Sep 2025 18:48:46 +0530
Subject: [PATCH 09/21] Use host_eval on target in host execution.

---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 743 +++++++++++++-----
 .../lower-workdistribute-fission-host.mlir    | 117 +++
 .../lower-workdistribute-fission-target.mlir  |   3 +
 3 files changed, 647 insertions(+), 216 deletions(-)
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 7fc59eee2ca2d..aa0e1f3416114 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -16,8 +16,7 @@
 // It hoists all the other ops outside the target region.
 // Replaces heap allocation on the target with omp.target_allocmem and
 // deallocation with omp.target_freemem from the host. Also replaces the
-// runtime function "Assign" with an equivalent omp function: @_FortranAAssign
-// on the target is replaced with @_FortranAAssign_omp once hoisted out.
+// runtime function "Assign" with omp.target_memcpy.
//
 //===----------------------------------------------------------------------===//

 #include "flang/Optimizer/Builder/FIRBuilder.h"
@@ -75,7 +74,7 @@ static bool isRuntimeCall(Operation *op) {
 }

 // This is the single source of truth about whether we should parallelize an
-// operation nested in an omp.execute region.
+// operation nested in an omp.workdistribute region.
 static bool shouldParallelize(Operation *op) {
   if (llvm::any_of(op->getResults(),
                    [](OpResult v) -> bool { return !v.use_empty(); }))
@@ -87,6 +86,7 @@ static bool shouldParallelize(Operation *op) {
     return false;
     return *unordered;
   }
+  // True if the op is a runtime call to Assign
   if (isRuntimeCall(op) &&
       (*cast<fir::CallOp>(op).getCallee()).getRootReference().getValue() ==
           "_FortranAAssign") {
     return true;
@@ -235,6 +235,7 @@ static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
   return changed;
 }

+// Generate omp.parallel operation with an empty region.
 static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
   auto parallelOp = rewriter.create<omp::ParallelOp>(loc);
   parallelOp.setComposite(composite);
@@ -243,6 +244,7 @@ static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
   return;
 }

+// Generate omp.distribute operation with an empty region.
 static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) {
   mlir::omp::DistributeOperands distributeClauseOps;
   auto distributeOp =
@@ -253,6 +255,7 @@ static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) {
   return;
 }

+// Generate loop nest clause operands from fir.do_loop operation.
 static void
 genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop,
                      mlir::omp::LoopNestOperands &loopNestClauseOps) {
@@ -264,6 +267,7 @@ genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop,
   loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
 }

+// Generate omp.wsloop operation with an empty region and an omp.loop_nest
+// created from the bounds of the given fir.do_loop.
 static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
                         const mlir::omp::LoopNestOperands &clauseOps,
                         bool composite) {
@@ -286,7 +290,6 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
   if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) {
     rewriter.setInsertionPoint(terminatorOp);
     rewriter.create<omp::YieldOp>(doLoop->getLoc());
-    // rewriter.erase(terminatorOp);
     terminatorOp->erase();
   }
   return;
@@ -319,12 +322,22 @@

-static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
+static bool
+WorkdistributeDoLower(omp::WorkdistributeOp workdistribute,
+                      SetVector<omp::TargetOp> &targetOpsToProcess) {
   OpBuilder rewriter(workdistribute);
   auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
   auto wdLoc = workdistribute->getLoc();
   if (doLoop && shouldParallelize(doLoop)) {
     assert(doLoop.getReduceOperands().empty());
+
+    // Record the target ops to process later
+    if (auto teamsOp = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp())) {
+      auto targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp());
+      if (targetOp) {
+        targetOpsToProcess.insert(targetOp);
+      }
+    }
     genParallelOp(wdLoc, rewriter, true);
     genDistributeOp(wdLoc, rewriter, true);
     mlir::omp::LoopNestOperands loopNestClauseOps;
@@ -353,7 +366,7 @@
 // A()
 // B()
 //
+// If only the terminator remains in teams after hoisting, we erase teams op.
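+// For example, a teams region reduced to just omp.terminator after the hoist
+// is deleted outright instead of being lowered further.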
static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) {
   auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
@@ -380,25 +393,20 @@ static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) {
   return true;
 }

-struct SplitTargetResult {
-  omp::TargetOp targetOp;
-  omp::TargetDataOp dataOp;
-};
-
 // If multiple workdistribute are nested in a target region, we will need to
 // split the target region, but we want to preserve the data semantics of the
 // original data region and avoid unnecessary data movement at each of the
 // subkernels - we split the target region into a target_data{target}
 // nest where only the outer one moves the data
-std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
-                                                 RewriterBase &rewriter) {
+std::optional<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
+                                             RewriterBase &rewriter) {
   auto loc = targetOp->getLoc();
   if (targetOp.getMapVars().empty()) {
     LLVM_DEBUG(llvm::dbgs()
               << DEBUG_TYPE << " target region has no data maps\n");
     return std::nullopt;
   }
-
+  // Collect all the mapinfo ops
   SmallVector<omp::MapInfoOp> mapInfos;
   for (auto opr : targetOp.getMapVars()) {
     auto mapInfo = cast<omp::MapInfoOp>(opr.getDefiningOp());
@@ -408,14 +416,15 @@
   rewriter.setInsertionPoint(targetOp);
   SmallVector<Value> innerMapInfos;
   SmallVector<Value> outerMapInfos;
-
+  // Create new mapinfo ops for the inner target region
   for (auto mapInfo : mapInfos) {
     auto originalMapType =
         (llvm::omp::OpenMPOffloadMappingFlags)(mapInfo.getMapType());
     auto originalCaptureType = mapInfo.getMapCaptureType();
     llvm::omp::OpenMPOffloadMappingFlags newMapType;
     mlir::omp::VariableCaptureKind newCaptureType;
-
+    // For bycopy, we keep the same map type and capture type
+    // For byref, we change the map type to none and keep the capture type
     if (originalCaptureType == mlir::omp::VariableCaptureKind::ByCopy) {
       newMapType = originalMapType;
       newCaptureType = originalCaptureType;
@@ -441,12 +450,13 @@
   auto ifExpr = targetOp.getIfExpr();
   auto deviceAddrVars = targetOp.getHasDeviceAddrVars();
   auto devicePtrVars = targetOp.getIsDevicePtrVars();
+  // Create the target data op
   auto targetDataOp = rewriter.create<omp::TargetDataOp>(
       loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars);
   auto targetDataBlock = rewriter.createBlock(&targetDataOp.getRegion());
   rewriter.create<omp::TerminatorOp>(loc);
   rewriter.setInsertionPointToStart(targetDataBlock);
-
+  // Create the inner target op
   auto newTargetOp = rewriter.create<omp::TargetOp>(
       targetOp.getLoc(), targetOp.getAllocateVars(),
       targetOp.getAllocatorVars(), targetOp.getBareAttr(),
@@ -460,9 +470,8 @@
       targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
   rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
                               newTargetOp.getRegion().begin());
-
   rewriter.replaceOp(targetOp, targetDataOp);
-  return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
+  return newTargetOp;
 }

 // getNestedOpToIsolate function is designed to identify a specific teams
 // parallel op within the body of an omp::TargetOp that should be "isolated."
 // It returns a tuple of the op, whether it is the first op in the target
 // block, and whether it is the last op in the target block.
 static std::optional<std::tuple<Operation *, bool, bool>>
 getNestedOpToIsolate(omp::TargetOp targetOp) {
   if (targetOp.getRegion().empty())
@@ -480,20 +489,23 @@ getNestedOpToIsolate(omp::TargetOp targetOp) {
     if (first && last)
       return std::nullopt;

-    if (isa(&op))
+    if (isa(&op))
       return {{&op, first, last}};
   }
   return std::nullopt;
 }

+// Temporary structure to hold the two mapinfo ops.
 struct TempOmpVar {
   omp::MapInfoOp from, to;
 };

+// isPtr checks if the type is a pointer or reference type.
 static bool isPtr(Type ty) {
   return isa(ty) || isa(ty);
 }

+// getPtrTypeForOmp returns an LLVM pointer type for the given type.
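+// For example, a pointer-like value keeps the opaque !llvm.ptr type, while a
+// non-pointer value of type T is wrapped as !fir.llvm_ptr<T>.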
static Type getPtrTypeForOmp(Type ty) { if (isPtr(ty)) return LLVM::LLVMPointerType::get(ty.getContext()); @@ -501,6 +513,7 @@ static Type getPtrTypeForOmp(Type ty) { return fir::LLVMPointerType::get(ty); } +// allocateTempOmpVar allocates a temporary variable for OpenMP mapping static TempOmpVar allocateTempOmpVar(Location loc, Type ty, RewriterBase &rewriter) { MLIRContext &ctx = *ty.getContext(); @@ -542,6 +555,7 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty, return TempOmpVar{mapInfoFrom, mapInfoTo}; } +// usedOutsideSplit checks if a value is used outside the split operation. static bool usedOutsideSplit(Value v, Operation *split) { if (!split) return false; @@ -557,6 +571,7 @@ static bool usedOutsideSplit(Value v, Operation *split) { return false; }; +// isRecomputableAfterFission checks if an operation can be recomputed static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) { if (isa(op)) return true; @@ -572,12 +587,7 @@ static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) { return false; } -struct SplitResult { - omp::TargetOp preTargetOp; - omp::TargetOp isolatedTargetOp; - omp::TargetOp postTargetOp; -}; - +// collectNonRecomputableDeps collects dependencies that cannot be recomputed static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp, SetVector &nonRecomputable, SetVector &toCache, @@ -597,20 +607,40 @@ static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp, toRecompute); } +// createBlockArgsAndMap creates block arguments and maps them static void createBlockArgsAndMap(Location loc, RewriterBase &rewriter, omp::TargetOp &targetOp, Block *targetBlock, Block *newTargetBlock, + SmallVector &hostEvalVars, SmallVector &mapOperands, SmallVector &allocs, IRMapping &irMapping) { - // Map `map_operands` to block arguments. + // FIRST: Map `host_eval_vars` to block arguments + unsigned originalHostEvalVarsSize = targetOp.getHostEvalVars().size(); + for (unsigned i = 0; i < hostEvalVars.size(); ++i) { + Value originalValue; + BlockArgument newArg; + if (i < originalHostEvalVarsSize) { + originalValue = targetBlock->getArgument(i); // Host_eval args come first + newArg = newTargetBlock->addArgument(originalValue.getType(), + originalValue.getLoc()); + } else { + originalValue = hostEvalVars[i]; + newArg = newTargetBlock->addArgument(originalValue.getType(), + originalValue.getLoc()); + } + irMapping.map(originalValue, newArg); + } + + // SECOND: Map `map_operands` to block arguments unsigned originalMapVarsSize = targetOp.getMapVars().size(); for (unsigned i = 0; i < mapOperands.size(); ++i) { Value originalValue; BlockArgument newArg; // Map the new arguments from the original block. if (i < originalMapVarsSize) { - originalValue = targetBlock->getArgument(i); + originalValue = targetBlock->getArgument(originalHostEvalVarsSize + + i); // Offset by host_eval count newArg = newTargetBlock->addArgument(originalValue.getType(), originalValue.getLoc()); } @@ -622,10 +652,12 @@ static void createBlockArgsAndMap(Location loc, RewriterBase &rewriter, } irMapping.map(originalValue, newArg); } - // Map `private_vars` to block arguments. 
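+// cheaply in the post-split region instead of being cached in a temporary
+// and remapped (typically side-effect-free ops; see the checks below).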
+ + // THIRD: Map `private_vars` to block arguments (if any) unsigned originalPrivateVarsSize = targetOp.getPrivateVars().size(); for (unsigned i = 0; i < originalPrivateVarsSize; ++i) { - auto originalArg = targetBlock->getArgument(originalMapVarsSize + i); + auto originalArg = targetBlock->getArgument(originalHostEvalVarsSize + + originalMapVarsSize + i); auto newArg = newTargetBlock->addArgument(originalArg.getType(), originalArg.getLoc()); irMapping.map(originalArg, newArg); @@ -633,24 +665,25 @@ static void createBlockArgsAndMap(Location loc, RewriterBase &rewriter, return; } +// reloadCacheAndRecompute reloads cached values and recomputes operations static void reloadCacheAndRecompute( Location loc, RewriterBase &rewriter, Operation *splitBefore, omp::TargetOp &targetOp, Block *targetBlock, Block *newTargetBlock, - SmallVector &mapOperands, SmallVector &allocs, - SetVector &toRecompute, IRMapping &irMapping) { - createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, newTargetBlock, - mapOperands, allocs, irMapping); + SmallVector &hostEvalVars, SmallVector &mapOperands, + SmallVector &allocs, SetVector &toRecompute, + IRMapping &irMapping) { // Handle the load operations for the allocs. rewriter.setInsertionPointToStart(newTargetBlock); auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext()); unsigned originalMapVarsSize = targetOp.getMapVars().size(); + unsigned hostEvalVarsSize = hostEvalVars.size(); // Create Stores for allocs. for (unsigned i = 0; i < allocs.size(); ++i) { Value original = allocs[i]; // Get the new block argument for this specific allocated value. - Value newArg = newTargetBlock->getArgument(originalMapVarsSize + i); - + Value newArg = + newTargetBlock->getArgument(hostEvalVarsSize + originalMapVarsSize + i); Value restored; if (isPtr(original.getType())) { restored = rewriter.create(loc, llvmPtrTy, newArg); @@ -662,171 +695,66 @@ static void reloadCacheAndRecompute( } irMapping.map(original, restored); } - + // Clone the operations if they are in the toRecompute set. for (auto it = targetBlock->begin(); it != splitBefore->getIterator(); it++) { if (toRecompute.contains(&*it)) rewriter.clone(*it, irMapping); } } -// isolateOp method rewrites a omp.target_data { omp.target } in to -// omp.target_data { -// // preTargetOp region contains ops before splitBeforeOp. -// omp.target {} -// // isolatedTargetOp region contains splitBeforeOp, -// omp.target {} -// // postTargetOp region contains ops after splitBeforeOp. -// omp.target {} -// } -static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter, - RewriterBase &rewriter) { - auto targetOp = cast(splitBeforeOp->getParentOp()); - assert(targetOp); - auto loc = targetOp.getLoc(); - auto *targetBlock = &targetOp.getRegion().front(); - rewriter.setInsertionPoint(targetOp); - - auto preMapOperands = SmallVector(targetOp.getMapVars()); - auto postMapOperands = SmallVector(targetOp.getMapVars()); - - SmallVector requiredVals; - SetVector toCache; - SetVector toRecompute; - SetVector nonRecomputable; - SmallVector allocs; - - for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); - it++) { - for (auto res : it->getResults()) { - if (usedOutsideSplit(res, splitBeforeOp)) - requiredVals.push_back(res); +// Given a teamsOp, navigate down the nested structure to find the +// innermost LoopNestOp. 
The expected nesting is: +// teams -> parallel -> distribute -> wsloop -> loop_nest +static mlir::omp::LoopNestOp getLoopNestFromTeams(mlir::omp::TeamsOp teamsOp) { + if (teamsOp.getRegion().empty()) + return nullptr; + // Ensure the teams region has a single block. + if (teamsOp.getRegion().getBlocks().size() != 1) + return nullptr; + // Find parallel op inside teams + mlir::omp::ParallelOp parallelOp = nullptr; + for (auto &op : teamsOp.getRegion().front()) { + if (auto parallel = dyn_cast(op)) { + parallelOp = parallel; + break; } - if (!isRecomputableAfterFission(&*it, splitBeforeOp)) - nonRecomputable.insert(&*it); } + if (!parallelOp) + return nullptr; - for (auto requiredVal : requiredVals) - collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache, - toRecompute); - - for (Operation *op : toCache) { - for (auto res : op->getResults()) { - auto alloc = - allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter); - allocs.push_back(res); - preMapOperands.push_back(alloc.from); - postMapOperands.push_back(alloc.to); + // Find distribute op inside parallel + mlir::omp::DistributeOp distributeOp = nullptr; + for (auto &op : parallelOp.getRegion().front()) { + if (auto distribute = dyn_cast(op)) { + distributeOp = distribute; + break; } } + if (!distributeOp) + return nullptr; - rewriter.setInsertionPoint(targetOp); - - auto preTargetOp = rewriter.create( - targetOp.getLoc(), targetOp.getAllocateVars(), - targetOp.getAllocatorVars(), targetOp.getBareAttr(), - targetOp.getDependKindsAttr(), targetOp.getDependVars(), - targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), - targetOp.getHostEvalVars(), targetOp.getIfExpr(), - targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(), - targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), - preMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), - targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), - targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); - auto *preTargetBlock = rewriter.createBlock( - &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {}); - IRMapping preMapping; - - createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, preTargetBlock, - preMapOperands, allocs, preMapping); - - // Handle the store operations for the allocs. - rewriter.setInsertionPointToStart(preTargetBlock); - auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext()); - - // Clone the original operations. - for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); - it++) { - rewriter.clone(*it, preMapping); - } - - unsigned originalMapVarsSize = targetOp.getMapVars().size(); - // Create Stores for allocs. - for (unsigned i = 0; i < allocs.size(); ++i) { - Value originalResult = allocs[i]; - Value toStore = preMapping.lookup(originalResult); - // Get the new block argument for this specific allocated value. 
- Value newArg = preTargetBlock->getArgument(originalMapVarsSize + i); - - if (isPtr(originalResult.getType())) { - if (!isa(toStore.getType())) - toStore = rewriter.create(loc, llvmPtrTy, toStore); - rewriter.create(loc, toStore, newArg); - } else { - rewriter.create(loc, toStore, newArg); + // Find wsloop op inside distribute + mlir::omp::WsloopOp wsloopOp = nullptr; + for (auto &op : distributeOp.getRegion().front()) { + if (auto wsloop = dyn_cast(op)) { + wsloopOp = wsloop; + break; } } - rewriter.create(loc); - - rewriter.setInsertionPoint(targetOp); - - auto isolatedTargetOp = rewriter.create( - targetOp.getLoc(), targetOp.getAllocateVars(), - targetOp.getAllocatorVars(), targetOp.getBareAttr(), - targetOp.getDependKindsAttr(), targetOp.getDependVars(), - targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), - targetOp.getHostEvalVars(), targetOp.getIfExpr(), - targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(), - targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), - postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), - targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), - targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); - - auto *isolatedTargetBlock = - rewriter.createBlock(&isolatedTargetOp.getRegion(), - isolatedTargetOp.getRegion().begin(), {}, {}); - - IRMapping isolatedMapping; - reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock, - isolatedTargetBlock, postMapOperands, allocs, - toRecompute, isolatedMapping); - rewriter.clone(*splitBeforeOp, isolatedMapping); - rewriter.create(loc); - - omp::TargetOp postTargetOp = nullptr; + if (!wsloopOp) + return nullptr; - if (splitAfter) { - rewriter.setInsertionPoint(targetOp); - postTargetOp = rewriter.create( - targetOp.getLoc(), targetOp.getAllocateVars(), - targetOp.getAllocatorVars(), targetOp.getBareAttr(), - targetOp.getDependKindsAttr(), targetOp.getDependVars(), - targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), - targetOp.getHostEvalVars(), targetOp.getIfExpr(), - targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(), - targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), - postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), - targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), - targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); - auto *postTargetBlock = rewriter.createBlock( - &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {}); - IRMapping postMapping; - reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock, - postTargetBlock, postMapOperands, allocs, - toRecompute, postMapping); - - assert(splitBeforeOp->getNumResults() == 0 || - llvm::all_of(splitBeforeOp->getResults(), - [](Value result) { return result.use_empty(); })); - - for (auto it = std::next(splitBeforeOp->getIterator()); - it != targetBlock->end(); it++) - rewriter.clone(*it, postMapping); + // Find loop_nest op inside wsloop + for (auto &op : wsloopOp.getRegion().front()) { + if (auto loopNest = dyn_cast(op)) { + return loopNest; + } } - rewriter.eraseOp(targetOp); - return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp}; + return nullptr; } +// Generate LLVM constant operations for i32 and i64 types. 
static mlir::LLVM::ConstantOp genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) { mlir::Type i32Ty = rewriter.getI32Type(); @@ -834,6 +762,7 @@ genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) { return rewriter.create(loc, i32Ty, attr); } +// Generate LLVM constant operations for i64 type. static mlir::LLVM::ConstantOp genI64Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) { mlir::Type i64Ty = rewriter.getI64Type(); @@ -841,6 +770,9 @@ genI64Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) { return rewriter.create(loc, i64Ty, attr); } +// Given a box descriptor, extract the base address of the data it describes. +// If the box descriptor is a reference, load it first. +// The base address is returned as an i8* pointer. static Value genDescriptorGetBaseAddress(fir::FirOpBuilder &builder, Location loc, Value boxDesc) { Value box = boxDesc; @@ -858,6 +790,9 @@ static Value genDescriptorGetBaseAddress(fir::FirOpBuilder &builder, return rawAddr; } +// Given a box descriptor, extract the total number of elements in the array it +// describes. If the box descriptor is a reference, load it first. +// The total number of elements is returned as an i64 value. static Value genDescriptorGetTotalElements(fir::FirOpBuilder &builder, Location loc, Value boxDesc) { Value box = boxDesc; @@ -870,6 +805,9 @@ static Value genDescriptorGetTotalElements(fir::FirOpBuilder &builder, return fir::BoxTotalElementsOp::create(builder, loc, i64Type, box); } +// Given a box descriptor, extract the size of each element in the array it +// describes. If the box descriptor is a reference, load it first. +// The element size is returned as an i64 value. static Value genDescriptorGetEleSize(fir::FirOpBuilder &builder, Location loc, Value boxDesc) { Value box = boxDesc; @@ -882,6 +820,10 @@ static Value genDescriptorGetEleSize(fir::FirOpBuilder &builder, Location loc, return fir::BoxEleSizeOp::create(builder, loc, i64Type, box); } +// Given a box descriptor, compute the total size in bytes of the data it +// describes. This is done by multiplying the total number of elements by the +// size of each element. If the box descriptor is a reference, load it first. +// The total size in bytes is returned as an i64 value. static Value genDescriptorGetDataSizeInBytes(fir::FirOpBuilder &builder, Location loc, Value boxDesc) { Value box = boxDesc; @@ -895,6 +837,11 @@ static Value genDescriptorGetDataSizeInBytes(fir::FirOpBuilder &builder, return mlir::arith::MulIOp::create(builder, loc, totalElements, eleSize); } +// Generate a call to the OpenMP runtime function `omp_get_mapped_ptr` to +// retrieve the device pointer corresponding to a given host pointer and device +// number. If no mapping exists, the original host pointer is returned. +// Signature: +// void *omp_get_mapped_ptr(void *host_ptr, int device_num); static mlir::Value genOmpGetMappedPtrIfPresent(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value hostPtr, @@ -930,15 +877,18 @@ static mlir::Value genOmpGetMappedPtrIfPresent(fir::FirOpBuilder &builder, return result; } +// Generate a call to the OpenMP runtime function `omp_target_memcpy` to +// perform memory copy between host and device or between devices. 
+// Signature:
+// int omp_target_memcpy(void *dst, const void *src, size_t length,
+// size_t dst_offset, size_t src_offset,
+// int dst_device, int src_device);
 static void genOmpTargetMemcpyCall(fir::FirOpBuilder &builder,
 mlir::Location loc, mlir::Value dst,
 mlir::Value src, mlir::Value length,
 mlir::Value dstOffset, mlir::Value srcOffset,
 mlir::Value device, mlir::ModuleOp module) {
 auto *context = builder.getContext();
- // int omp_target_memcpy(void *dst, const void *src, size_t length,
- // size_t dst_offset, size_t src_offset,
- // int dst_device, int src_device)
 auto funcName = "omp_target_memcpy";
 auto voidPtrType = fir::LLVMPointerType::get(context, builder.getI8Type());
 auto sizeTType = builder.getI64Type(); // assuming size_t is 64-bit
@@ -962,30 +912,48 @@ static void genOmpTargetMemcpyCall(fir::FirOpBuilder &builder,
 return;
 }

+// Struct to hold the host_eval vars corresponding to loop bounds and steps.
+struct HostEvalVars {
+ SmallVector<Value> lbs;
+ SmallVector<Value> ubs;
+ SmallVector<Value> steps;
+};
+
 // moveToHost clones all the ops from the target region to just outside of it.
 // It hoists runtime functions and replaces them with their omp versions.
 // It also hoists fir.allocmem and fir.freemem, replacing them with
 // omp.target_allocmem and omp.target_freemem respectively.
 static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
- mlir::ModuleOp module) {
+ mlir::ModuleOp module,
+ struct HostEvalVars &hostEvalVars) {
 OpBuilder::InsertionGuard guard(rewriter);
 Block *targetBlock = &targetOp.getRegion().front();
 assert(targetBlock == &targetOp.getRegion().back());
 IRMapping mapping;
+ // Create the mapping for host_eval vars.
+ unsigned hostEvalVarCount = targetOp.getHostEvalVars().size();
+ for (unsigned i = 0; i < targetOp.getHostEvalVars().size(); ++i) {
+ Value hostEvalVar = targetOp.getHostEvalVars()[i];
+ BlockArgument arg = targetBlock->getArguments()[i];
+ mapping.map(arg, hostEvalVar);
+ }
+ // Create the mapping for map vars.
 for (unsigned i = 0; i < targetOp.getMapVars().size(); ++i) {
 Value mapInfo = targetOp.getMapVars()[i];
- BlockArgument arg = targetBlock->getArguments()[i];
+ BlockArgument arg = targetBlock->getArguments()[hostEvalVarCount + i];
 Operation *op = mapInfo.getDefiningOp();
 assert(op);
 auto mapInfoOp = cast<omp::MapInfoOp>(op);
 // map the block argument to the host-side variable pointer
 mapping.map(arg, mapInfoOp.getVarPtr());
 }
+ // Create the mapping for private vars.
 unsigned mapSize = targetOp.getMapVars().size();
 for (unsigned i = 0; i < targetOp.getPrivateVars().size(); ++i) {
 Value privateVar = targetOp.getPrivateVars()[i];
 // The mapping should link the device-side variable to the host-side one.
- BlockArgument arg = targetBlock->getArguments()[mapSize + i];
+ BlockArgument arg =
+ targetBlock->getArguments()[hostEvalVarCount + mapSize + i];
 // Map the device-side copy (`arg`) to the host-side value (`privateVar`).
 mapping.map(arg, privateVar);
 }
@@ -993,20 +961,22 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
 rewriter.setInsertionPoint(targetOp);
 SmallVector<Operation *> opsToReplace;
 Value device = targetOp.getDevice();
+
 if (!device) {
 device = genI32Constant(targetOp.getLoc(), rewriter, 0);
 }
+
 // Clone all operations.
 for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end());
 it != end; ++it) {
 auto *op = &*it;
 Operation *clonedOp = rewriter.clone(*op, mapping);
+ // Map the results of the original op to the cloned op.
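+ // (Operands of ops cloned later then resolve to these host-side clones
+ // rather than to values inside the region that is about to be erased.)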
for (unsigned i = 0; i < op->getNumResults(); ++i) {
 mapping.map(op->getResult(i), clonedOp->getResult(i));
 }
 // fir.declare changes its type when hoisting it out of omp.target to
 // omp.target_data. Introduce a load if the original declareOp input is not
 // of reference type, but the cloned declareOp input is of reference type.
-
 if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
 auto originalDeclareOp = cast<fir::DeclareOp>(op);
 Type originalInType = originalDeclareOp.getMemref().getType();
@@ -1023,8 +993,10 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
 }
 }
 }
+ // Collect the ops to be replaced.
 if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
 opsToReplace.push_back(clonedOp);
+ // Check for runtime calls to be replaced.
 if (isRuntimeCall(clonedOp)) {
 fir::CallOp runtimeCall = cast<fir::CallOp>(op);
 if ((*runtimeCall.getCallee()).getRootReference().getValue() ==
@@ -1035,7 +1007,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
 }
 }
 }
-
+ // Replace fir.allocmem with omp.target_allocmem.
 for (Operation *op : opsToReplace) {
 if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
 rewriter.setInsertionPoint(allocOp);
@@ -1048,14 +1020,20 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
 allocOp.getLoc(), allocOp.getResult().getType(),
 ompAllocmemOp.getResult());
 rewriter.replaceOp(allocOp, firConvertOp.getResult());
- } else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) {
+ }
+ // Replace fir.freemem with omp.target_freemem.
+ else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) {
 rewriter.setInsertionPoint(freeOp);
 auto firConvertOp = rewriter.create<fir::ConvertOp>(
 freeOp.getLoc(), rewriter.getI64Type(), freeOp.getHeapref());
 rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device,
 firConvertOp.getResult());
 rewriter.eraseOp(freeOp);
- } else if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(op)) {
+ }
+ // fir.declare changes its type when hoisting it out of omp.target to
+ // omp.target_data. Introduce a load if the original declareOp input is not
+ // of reference type, but the cloned declareOp input is of reference type.
+ else if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(op)) {
 Type clonedInType = clonedDeclareOp.getMemref().getType();
 fir::ReferenceType clonedRefType =
 dyn_cast<fir::ReferenceType>(clonedInType);
@@ -1064,7 +1042,9 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
 Value loadedValue = rewriter.create<fir::LoadOp>(
 clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref());
 clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue);
- } else if (isRuntimeCall(op)) {
+ }
+ // Replace runtime calls with omp versions.
+ else if (isRuntimeCall(op)) {
 rewriter.setInsertionPoint(op);
 fir::CallOp runtimeCall = cast<fir::CallOp>(op);
 SmallVector<Value> operands = runtimeCall.getOperands();
 mlir::Location loc = runtimeCall.getLoc();
 fir::FirOpBuilder builder{rewriter, op};
 assert(operands.size() == 4);
 auto fromBaseAddr =
 genDescriptorGetBaseAddress(builder, loc, operands[1]);
 auto toBaseAddr = genDescriptorGetBaseAddress(builder, loc, operands[0]);
 auto dataSizeInBytes =
 genDescriptorGetDataSizeInBytes(builder, loc, operands[1]);

 Value toPtr =
 genOmpGetMappedPtrIfPresent(builder, loc, toBaseAddr, device, module);
 Value fromPtr = genOmpGetMappedPtrIfPresent(builder, loc, fromBaseAddr,
 device, module);
 Value zero = genI64Constant(loc, rewriter, 0);
 genOmpTargetMemcpyCall(builder, loc, toPtr, fromPtr, dataSizeInBytes,
 zero, zero, device, module);
 rewriter.eraseOp(op);
 }
 }
+
+ // Update the host_eval vars to use the mapped values.
+ for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i) {
+ hostEvalVars.lbs[i] = mapping.lookup(hostEvalVars.lbs[i]);
+ hostEvalVars.ubs[i] = mapping.lookup(hostEvalVars.ubs[i]);
+ hostEvalVars.steps[i] = mapping.lookup(hostEvalVars.steps[i]);
+ }
+ // Finally erase the original targetOp.
+ rewriter.eraseOp(targetOp);
+}
+
+// Result of the isolateOp method.
+struct SplitResult {
+ omp::TargetOp preTargetOp;
+ omp::TargetOp isolatedTargetOp;
+ omp::TargetOp postTargetOp;
+};
+
+// computeAllocsCacheRecomputable computes the allocs needed to cache
+// the values that are used outside the split point.
It also computes the ops +// that need to be cached and the ops that can be recomputed after the split. +static void computeAllocsCacheRecomputable( + omp::TargetOp targetOp, Operation *splitBeforeOp, RewriterBase &rewriter, + SmallVector &preMapOperands, SmallVector &postMapOperands, + SmallVector &allocs, SmallVector &requiredVals, + SetVector &nonRecomputable, SetVector &toCache, + SetVector &toRecompute) { + auto *targetBlock = &targetOp.getRegion().front(); + // Find all values that are used outside the split point. + for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); + it++) { + // Check if any of the results are used outside the split point. + for (auto res : it->getResults()) { + if (usedOutsideSplit(res, splitBeforeOp)) + requiredVals.push_back(res); + } + // If the op is not recomputable, add it to the nonRecomputable set. + if (!isRecomputableAfterFission(&*it, splitBeforeOp)) + nonRecomputable.insert(&*it); + } + // For each required value, collect its dependencies. + for (auto requiredVal : requiredVals) + collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache, + toRecompute); + // For each op in toCache, create an alloc and update the pre and post map + // operands. + for (Operation *op : toCache) { + for (auto res : op->getResults()) { + auto alloc = + allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter); + allocs.push_back(res); + preMapOperands.push_back(alloc.from); + postMapOperands.push_back(alloc.to); + } + } +} + +// genPreTargetOp method generates the preTargetOp that contains all the ops +// before the split point. It also creates the block arguments and maps the +// values accordingly. It also creates the store operations for the allocs. +static omp::TargetOp +genPreTargetOp(omp::TargetOp targetOp, SmallVector &preMapOperands, + SmallVector &allocs, Operation *splitBeforeOp, + RewriterBase &rewriter, struct HostEvalVars &hostEvalVars, + bool isTargetDevice) { + auto loc = targetOp.getLoc(); + auto *targetBlock = &targetOp.getRegion().front(); + SmallVector preHostEvalVars{targetOp.getHostEvalVars()}; + // update the hostEvalVars of preTargetOp + omp::TargetOp preTargetOp = rewriter.create( + targetOp.getLoc(), targetOp.getAllocateVars(), + targetOp.getAllocatorVars(), targetOp.getBareAttr(), + targetOp.getDependKindsAttr(), targetOp.getDependVars(), + targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), preHostEvalVars, + targetOp.getIfExpr(), targetOp.getInReductionVars(), + targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), + targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(), + targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), + targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(), + targetOp.getPrivateMapsAttr()); + auto *preTargetBlock = rewriter.createBlock( + &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {}); + IRMapping preMapping; + // Create block arguments and map the values. + createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, preTargetBlock, + preHostEvalVars, preMapOperands, allocs, preMapping); + + // Handle the store operations for the allocs. + rewriter.setInsertionPointToStart(preTargetBlock); + auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext()); + + // Clone the original operations. 
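+ // Only the ops strictly before splitBeforeOp are cloned here; splitBeforeOp
+ // itself is emitted into the isolated target op by genIsolatedTargetOp.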
+ for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); + it++) { + rewriter.clone(*it, preMapping); + } + + unsigned originalHostEvalVarsSize = preHostEvalVars.size(); + unsigned originalMapVarsSize = targetOp.getMapVars().size(); + // Create Stores for allocs. + for (unsigned i = 0; i < allocs.size(); ++i) { + Value originalResult = allocs[i]; + Value toStore = preMapping.lookup(originalResult); + // Get the new block argument for this specific allocated value. + Value newArg = preTargetBlock->getArgument(originalHostEvalVarsSize + + originalMapVarsSize + i); + // Create the store operation. + if (isPtr(originalResult.getType())) { + if (!isa(toStore.getType())) + toStore = rewriter.create(loc, llvmPtrTy, toStore); + rewriter.create(loc, toStore, newArg); + } else { + rewriter.create(loc, toStore, newArg); + } + } + rewriter.create(loc); + + // Update hostEvalVars with the mapped values for the loop bounds if we have + // a loopNestOp and we are not generating code for the target device. + omp::LoopNestOp loopNestOp = + getLoopNestFromTeams(cast(splitBeforeOp)); + if (loopNestOp && !isTargetDevice) { + for (size_t i = 0; i < loopNestOp.getLoopLowerBounds().size(); ++i) { + Value lb = loopNestOp.getLoopLowerBounds()[i]; + Value ub = loopNestOp.getLoopUpperBounds()[i]; + Value step = loopNestOp.getLoopSteps()[i]; + + hostEvalVars.lbs.push_back(preMapping.lookup(lb)); + hostEvalVars.ubs.push_back(preMapping.lookup(ub)); + hostEvalVars.steps.push_back(preMapping.lookup(step)); + } + } + + return preTargetOp; +} + +// genIsolatedTargetOp method generates the isolatedTargetOp that contains the +// ops between the split point. It also creates the block arguments and maps +// the values accordingly. It also creates the load operations for the allocs +// and recomputes the necessary ops. 
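+//
+// For illustration (a sketch, not verbatim pass output), on the host the
+// isolated op ends up with roughly this shape:
+//
+// omp.target host_eval(%lb, %ub, %step, ...) map_entries(<cached temps>) {
+//   <loads of cached values and recomputed defs>
+//   omp.teams { ... omp.loop_nest (%i) : index = (%lb) to (%ub) ... }
+//   omp.terminator
+// }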
+static omp::TargetOp +genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector &postMapOperands, + Operation *splitBeforeOp, RewriterBase &rewriter, + SmallVector &allocs, + SetVector &toRecompute, + struct HostEvalVars &hostEvalVars, bool isTargetDevice) { + auto loc = targetOp.getLoc(); + auto *targetBlock = &targetOp.getRegion().front(); + SmallVector isolatedHostEvalVars{targetOp.getHostEvalVars()}; + // update the hostEvalVars of isolatedTargetOp + if (!hostEvalVars.lbs.empty() && !isTargetDevice) { + for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i) { + isolatedHostEvalVars.push_back(hostEvalVars.lbs[i]); + } + for (size_t i = 0; i < hostEvalVars.ubs.size(); ++i) { + isolatedHostEvalVars.push_back(hostEvalVars.ubs[i]); + } + for (size_t i = 0; i < hostEvalVars.steps.size(); ++i) { + isolatedHostEvalVars.push_back(hostEvalVars.steps[i]); + } + } + // Create the isolated target op + omp::TargetOp isolatedTargetOp = rewriter.create( + targetOp.getLoc(), targetOp.getAllocateVars(), + targetOp.getAllocatorVars(), targetOp.getBareAttr(), + targetOp.getDependKindsAttr(), targetOp.getDependVars(), + targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), + isolatedHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(), + targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), + targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(), + targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), + targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(), + targetOp.getPrivateMapsAttr()); + auto *isolatedTargetBlock = + rewriter.createBlock(&isolatedTargetOp.getRegion(), + isolatedTargetOp.getRegion().begin(), {}, {}); + IRMapping isolatedMapping; + // Create block arguments and map the values. + createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, + isolatedTargetBlock, isolatedHostEvalVars, + postMapOperands, allocs, isolatedMapping); + // Handle the load operations for the allocs and recompute ops. + reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock, + isolatedTargetBlock, isolatedHostEvalVars, + postMapOperands, allocs, toRecompute, + isolatedMapping); + + // Clone the original operations. + rewriter.clone(*splitBeforeOp, isolatedMapping); + rewriter.create(loc); + + // update the loop bounds in the isolatedTargetOp if we have host_eval vars + // and we are not generating code for the target device. 
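+ // (host_eval keeps the bounds visible to the host so the runtime can
+ // compute the kernel launch parameters, e.g. the trip count, before
+ // entering the target region.)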
+ if (!hostEvalVars.lbs.empty() && !isTargetDevice) { + omp::TeamsOp teamsOp; + for (auto &op : *isolatedTargetBlock) { + if (isa(&op)) + teamsOp = cast(&op); + } + assert(teamsOp && "No teamsOp found in isolated target region"); + // Get the loopNestOp inside the teamsOp + auto loopNestOp = getLoopNestFromTeams(teamsOp); + // Get the BlockArgs related to host_eval vars and update loop_nest bounds + // to them + unsigned originalHostEvalVarsSize = targetOp.getHostEvalVars().size(); + unsigned index = originalHostEvalVarsSize; + // Replace loop bounds with the block arguments passed down via host_eval + SmallVector lbs, ubs, steps; + + // Collect new lb/ub/step values from target block args + for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i) + lbs.push_back(isolatedTargetBlock->getArgument(index++)); + + for (size_t i = 0; i < hostEvalVars.ubs.size(); ++i) + ubs.push_back(isolatedTargetBlock->getArgument(index++)); + + for (size_t i = 0; i < hostEvalVars.steps.size(); ++i) + steps.push_back(isolatedTargetBlock->getArgument(index++)); + + // Reset the loop bounds + loopNestOp.getLoopLowerBoundsMutable().assign(lbs); + loopNestOp.getLoopUpperBoundsMutable().assign(ubs); + loopNestOp.getLoopStepsMutable().assign(steps); + } + + return isolatedTargetOp; +} + +// genPostTargetOp method generates the postTargetOp that contains all the ops +// after the split point. It also creates the block arguments and maps the +// values accordingly. It also creates the load operations for the allocs +// and recomputes the necessary ops. +static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp, + Operation *splitBeforeOp, + SmallVector &postMapOperands, + RewriterBase &rewriter, + SmallVector &allocs, + SetVector &toRecompute) { + auto loc = targetOp.getLoc(); + auto *targetBlock = &targetOp.getRegion().front(); + SmallVector postHostEvalVars{targetOp.getHostEvalVars()}; + // Create the post target op + omp::TargetOp postTargetOp = rewriter.create( + targetOp.getLoc(), targetOp.getAllocateVars(), + targetOp.getAllocatorVars(), targetOp.getBareAttr(), + targetOp.getDependKindsAttr(), targetOp.getDependVars(), + targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), postHostEvalVars, + targetOp.getIfExpr(), targetOp.getInReductionVars(), + targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), + targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(), + targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), + targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(), + targetOp.getPrivateMapsAttr()); + // Create the block for postTargetOp + auto *postTargetBlock = rewriter.createBlock( + &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {}); + IRMapping postMapping; + // Create block arguments and map the values. + createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, postTargetBlock, + postHostEvalVars, postMapOperands, allocs, postMapping); + // Handle the load operations for the allocs and recompute ops. + reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock, + postTargetBlock, postHostEvalVars, postMapOperands, + allocs, toRecompute, postMapping); + assert(splitBeforeOp->getNumResults() == 0 || + llvm::all_of(splitBeforeOp->getResults(), + [](Value result) { return result.use_empty(); })); + // Clone the original operations after the split point. 
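+ // splitBeforeOp itself is skipped; the assert above guarantees that its
+ // results are no longer used at this point.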
+ for (auto it = std::next(splitBeforeOp->getIterator());
+ it != targetBlock->end(); it++)
+ rewriter.clone(*it, postMapping);
+ return postTargetOp;
+}
+
+// isolateOp rewrites an omp.target_data { omp.target } into
+// omp.target_data {
+// // preTargetOp region contains the ops before splitBeforeOp.
+// omp.target {}
+// // isolatedTargetOp region contains splitBeforeOp.
+// omp.target {}
+// // postTargetOp region contains the ops after splitBeforeOp.
+// omp.target {}
+// }
+// It also handles the mapping of variables and the caching/recomputing
+// of values as needed.
+static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
+ RewriterBase &rewriter, mlir::ModuleOp module,
+ bool isTargetDevice) {
+ auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
+ assert(targetOp);
+ rewriter.setInsertionPoint(targetOp);
+
+ // Prepare the map operands for preTargetOp and postTargetOp.
+ auto preMapOperands = SmallVector<Value>(targetOp.getMapVars());
+ auto postMapOperands = SmallVector<Value>(targetOp.getMapVars());
+
+ // Vectors to hold the analysis results.
+ SmallVector<Value> requiredVals;
+ SetVector<Operation *> toCache;
+ SetVector<Operation *> toRecompute;
+ SetVector<Operation *> nonRecomputable;
+ SmallVector<Value> allocs;
+ struct HostEvalVars hostEvalVars;
+
+ // Analyze the ops in the target region to determine which ops need to be
+ // cached and which ops need to be recomputed.
+ computeAllocsCacheRecomputable(
+ targetOp, splitBeforeOp, rewriter, preMapOperands, postMapOperands,
+ allocs, requiredVals, nonRecomputable, toCache, toRecompute);
+
+ rewriter.setInsertionPoint(targetOp);
+
+ // Generate the preTargetOp that contains all the ops before splitBeforeOp.
+ auto preTargetOp =
+ genPreTargetOp(targetOp, preMapOperands, allocs, splitBeforeOp, rewriter,
+ hostEvalVars, isTargetDevice);
+
+ // Move the ops of preTargetOp to the host.
+ moveToHost(preTargetOp, rewriter, module, hostEvalVars);
+ rewriter.setInsertionPoint(targetOp);
+
+ // Generate the isolatedTargetOp.
+ omp::TargetOp isolatedTargetOp =
+ genIsolatedTargetOp(targetOp, postMapOperands, splitBeforeOp, rewriter,
+ allocs, toRecompute, hostEvalVars, isTargetDevice);
+
+ omp::TargetOp postTargetOp = nullptr;
+ // Generate the postTargetOp that contains all the ops after splitBeforeOp.
+ if (splitAfter) {
+ rewriter.setInsertionPoint(targetOp);
+ postTargetOp = genPostTargetOp(targetOp, splitBeforeOp, postMapOperands,
+ rewriter, allocs, toRecompute);
+ }
+ // Finally erase the original targetOp.
 rewriter.eraseOp(targetOp);
 return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp};
 }

-void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter,
- mlir::ModuleOp module) {
+// Recursively fission target ops until no more nested ops can be isolated.
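+// For example (an outline; names are illustrative), a region such as
+//   omp.target { A(); <loop nest 1>; B(); <loop nest 2>; C() }
+// is handled as: A() is moved to the host, <loop nest 1> is isolated into
+// its own omp.target, and fissionTarget recurses on the remainder
+// omp.target { B(); <loop nest 2>; C() }.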
+static void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter, + mlir::ModuleOp module, bool isTargetDevice) { auto tuple = getNestedOpToIsolate(targetOp); if (!tuple) { LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n"); - moveToHost(targetOp, rewriter, module); + struct HostEvalVars hostEvalVars; + moveToHost(targetOp, rewriter, module, hostEvalVars); return; } - Operation *toIsolate = std::get<0>(*tuple); bool splitBefore = !std::get<1>(*tuple); bool splitAfter = !std::get<2>(*tuple); if (splitBefore && splitAfter) { - auto res = isolateOp(toIsolate, splitAfter, rewriter); - moveToHost(res.preTargetOp, rewriter, module); - fissionTarget(res.postTargetOp, rewriter, module); - return; - } - if (splitBefore) { - auto res = isolateOp(toIsolate, splitAfter, rewriter); - moveToHost(res.preTargetOp, rewriter, module); - return; - } - if (splitAfter) { - auto res = isolateOp(toIsolate->getNextNode(), splitAfter, rewriter); - fissionTarget(res.postTargetOp, rewriter, module); + auto res = + isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice); + fissionTarget(res.postTargetOp, rewriter, module, isTargetDevice); return; + } else { + llvm::errs() << "Unhandled case in fissionTarget\n"; + llvm::report_fatal_error("Unhandled case in fissionTarget"); } } +// Pass to lower omp.workdistribute ops. class LowerWorkdistributePass : public flangomp::impl::LowerWorkdistributeBase { public: @@ -1128,25 +1438,26 @@ class LowerWorkdistributePass MLIRContext &context = getContext(); auto moduleOp = getOperation(); bool changed = false; + SetVector targetOpsToProcess; moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { changed |= FissionWorkdistribute(workdistribute); }); moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { - changed |= WorkdistributeDoLower(workdistribute); + changed |= WorkdistributeDoLower(workdistribute, targetOpsToProcess); }); moduleOp->walk([&](mlir::omp::TeamsOp teams) { changed |= TeamsWorkdistributeToSingleOp(teams); }); if (changed) { - SmallVector targetOps; - moduleOp->walk( - [&](omp::TargetOp targetOp) { targetOps.push_back(targetOp); }); + bool isTargetDevice = + llvm::cast(*moduleOp) + .getIsTargetDevice(); IRRewriter rewriter(&context); - for (auto targetOp : targetOps) { + for (auto targetOp : targetOpsToProcess) { auto res = splitTargetData(targetOp, rewriter); if (res) - fissionTarget(res->targetOp, rewriter, moduleOp); + fissionTarget(*res, rewriter, moduleOp, isTargetDevice); } } } diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir new file mode 100644 index 0000000000000..b4c9598a78f0e --- /dev/null +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir @@ -0,0 +1,117 @@ +// RUN: fir-opt --lower-workdistribute %s | FileCheck %s +// Test lowering of workdistribute after fission on host device. 
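+// In outline, the single omp.target below is split: the loop nest is
+// isolated into its own omp.target (with the bounds passed via host_eval),
+// while the surrounding allocation, store and free are run on the host via
+// omp.target_allocmem / omp.target_freemem, all inside an omp.target_data
+// region.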
+ +// CHECK-LABEL: func.func @x( +// CHECK: %[[VAL_0:.*]] = fir.alloca index {bindc_name = "lb"} +// CHECK: fir.store %[[ARG0:.*]] to %[[VAL_0]] : !fir.ref +// CHECK: %[[VAL_1:.*]] = fir.alloca index {bindc_name = "ub"} +// CHECK: fir.store %[[ARG1:.*]] to %[[VAL_1]] : !fir.ref +// CHECK: %[[VAL_2:.*]] = fir.alloca index {bindc_name = "step"} +// CHECK: fir.store %[[ARG2:.*]] to %[[VAL_2]] : !fir.ref +// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "lb"} +// CHECK: %[[VAL_4:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "ub"} +// CHECK: %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "step"} +// CHECK: %[[VAL_6:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "addr"} +// CHECK: %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref {name = "lb"} +// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref {name = "ub"} +// CHECK: %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref {name = "step"} +// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3]] : !fir.ref, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref {name = "addr"} +// CHECK: omp.target_data map_entries(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref) { +// CHECK: %[[VAL_11:.*]] = fir.alloca index +// CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref, index) map_clauses(from) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_from"} +// CHECK: %[[VAL_13:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_to"} +// CHECK: %[[VAL_14:.*]] = fir.alloca index +// CHECK: %[[VAL_15:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref, index) map_clauses(from) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_from"} +// CHECK: %[[VAL_16:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_to"} +// CHECK: %[[VAL_17:.*]] = fir.alloca index +// CHECK: %[[VAL_18:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref, index) map_clauses(from) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_from"} +// CHECK: %[[VAL_19:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "__flang_workdistribute_to"} +// CHECK: %[[VAL_20:.*]] = fir.alloca !fir.heap +// CHECK: %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref>, !fir.heap) map_clauses(from) capture(ByRef) -> !fir.ref> {name = "__flang_workdistribute_from"} +// CHECK: %[[VAL_22:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref>, !fir.heap) map_clauses(to) capture(ByRef) -> !fir.ref> {name = "__flang_workdistribute_to"} +// CHECK: %[[VAL_23:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_0]] : !fir.ref +// CHECK: %[[VAL_25:.*]] = fir.load %[[VAL_1]] : !fir.ref +// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_2]] : !fir.ref +// CHECK: %[[VAL_27:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_28:.*]] = arith.addi 
%[[VAL_25]], %[[VAL_25]] : index +// CHECK: %[[VAL_29:.*]] = omp.target_allocmem %[[VAL_23]] : i32, index, %[[VAL_27]] {uniq_name = "dev_buf"} +// CHECK: %[[VAL_30:.*]] = fir.convert %[[VAL_29]] : (i64) -> !fir.heap +// CHECK: fir.store %[[VAL_24]] to %[[VAL_11]] : !fir.ref +// CHECK: fir.store %[[VAL_25]] to %[[VAL_14]] : !fir.ref +// CHECK: fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref +// CHECK: fir.store %[[VAL_30]] to %[[VAL_20]] : !fir.ref> +// CHECK: omp.target host_eval(%[[VAL_24]] -> %[[VAL_31:.*]], %[[VAL_25]] -> %[[VAL_32:.*]], %[[VAL_26]] -> %[[VAL_33:.*]] : index, index, index) map_entries(%[[VAL_7]] -> %[[VAL_34:.*]], %[[VAL_8]] -> %[[VAL_35:.*]], %[[VAL_9]] -> %[[VAL_36:.*]], %[[VAL_10]] -> %[[VAL_37:.*]], %[[VAL_13]] -> %[[VAL_38:.*]], %[[VAL_16]] -> %[[VAL_39:.*]], %[[VAL_19]] -> %[[VAL_40:.*]], %[[VAL_22]] -> %[[VAL_41:.*]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref>) { +// CHECK: %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.llvm_ptr +// CHECK: %[[VAL_43:.*]] = fir.load %[[VAL_39]] : !fir.llvm_ptr +// CHECK: %[[VAL_44:.*]] = fir.load %[[VAL_40]] : !fir.llvm_ptr +// CHECK: %[[VAL_45:.*]] = fir.load %[[VAL_41]] : !fir.llvm_ptr> +// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_43]], %[[VAL_43]] : index +// CHECK: omp.teams { +// CHECK: omp.parallel { +// CHECK: omp.distribute { +// CHECK: omp.wsloop { +// CHECK: omp.loop_nest (%[[VAL_47:.*]]) : index = (%[[VAL_31]]) to (%[[VAL_32]]) inclusive step (%[[VAL_33]]) { +// CHECK: fir.store %[[VAL_46]] to %[[VAL_45]] : !fir.heap +// CHECK: omp.yield +// CHECK: } +// CHECK: } {omp.composite} +// CHECK: } {omp.composite} +// CHECK: omp.terminator +// CHECK: } {omp.composite} +// CHECK: omp.terminator +// CHECK: } +// CHECK: omp.terminator +// CHECK: } +// CHECK: %[[VAL_48:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[VAL_49:.*]] = fir.load %[[VAL_11]] : !fir.ref +// CHECK: %[[VAL_50:.*]] = fir.load %[[VAL_14]] : !fir.ref +// CHECK: %[[VAL_51:.*]] = fir.load %[[VAL_17]] : !fir.ref +// CHECK: %[[VAL_52:.*]] = fir.load %[[VAL_20]] : !fir.ref> +// CHECK: %[[VAL_53:.*]] = arith.addi %[[VAL_50]], %[[VAL_50]] : index +// CHECK: fir.store %[[VAL_49]] to %[[VAL_52]] : !fir.heap +// CHECK: %[[VAL_54:.*]] = fir.convert %[[VAL_52]] : (!fir.heap) -> i64 +// CHECK: omp.target_freemem %[[VAL_48]], %[[VAL_54]] : i32, i64 +// CHECK: omp.terminator +// CHECK: } +// CHECK: return +// CHECK: } + +module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu", omp.is_gpu = false, omp.is_target_device = false} { +func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref) { + %lb_ref = fir.alloca index {bindc_name = "lb"} + fir.store %lb to %lb_ref : !fir.ref + %ub_ref = fir.alloca index {bindc_name = "ub"} + fir.store %ub to %ub_ref : !fir.ref + %step_ref = fir.alloca index {bindc_name = "step"} + fir.store %step to %step_ref : !fir.ref + + %lb_map = omp.map.info var_ptr(%lb_ref : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "lb"} + %ub_map = omp.map.info var_ptr(%ub_ref : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "ub"} + %step_map = omp.map.info var_ptr(%step_ref : !fir.ref, index) map_clauses(to) capture(ByRef) -> !fir.ref {name = "step"} + %addr_map = omp.map.info var_ptr(%addr : !fir.ref, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "addr"} + + omp.target map_entries(%lb_map -> %ARG0, %ub_map -> %ARG1, %step_map -> %ARG2, %addr_map -> %ARG3 : !fir.ref, !fir.ref, !fir.ref, !fir.ref) { + %lb_val = fir.load %ARG0 : !fir.ref + 
%ub_val = fir.load %ARG1 : !fir.ref + %step_val = fir.load %ARG2 : !fir.ref + %one = arith.constant 1 : index + + %20 = arith.addi %ub_val, %ub_val : index + omp.teams { + omp.workdistribute { + %dev_mem = fir.allocmem index, %one {uniq_name = "dev_buf"} + fir.do_loop %iv = %lb_val to %ub_val step %step_val unordered { + fir.store %20 to %dev_mem : !fir.heap + } + fir.store %lb_val to %dev_mem : !fir.heap + fir.freemem %dev_mem : !fir.heap + omp.terminator + } + omp.terminator + } + omp.terminator + } + return +} +} diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir index 25ef34f81b492..6e82efb308328 100644 --- a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir @@ -1,4 +1,5 @@ // RUN: fir-opt --lower-workdistribute %s | FileCheck %s +// Test lowering of workdistribute after fission on host device. // CHECK-LABEL: func.func @x( // CHECK: %[[VAL_0:.*]] = fir.alloca index {bindc_name = "lb"} @@ -76,6 +77,7 @@ // CHECK: return // CHECK: } +module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} { func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref) { %lb_ref = fir.alloca index {bindc_name = "lb"} fir.store %lb to %lb_ref : !fir.ref @@ -112,3 +114,4 @@ func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref) { } return } +} From 693dc506b73196cfcbec83622aa38ca3e8b24b11 Mon Sep 17 00:00:00 2001 From: skc7 Date: Mon, 15 Sep 2025 18:57:33 +0530 Subject: [PATCH 10/21] Fix CI errors --- flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index aa0e1f3416114..e839db17150f9 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -569,7 +569,7 @@ static bool usedOutsideSplit(Value v, Operation *split) { return true; } return false; -}; +} // isRecomputableAfterFission checks if an operation can be recomputed static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) { From 5c43c0e68fe8042445d31e933d6fa75c0e33b964 Mon Sep 17 00:00:00 2001 From: skc7 Date: Tue, 16 Sep 2025 23:10:32 +0530 Subject: [PATCH 11/21] Handle lowering of scalar assignments to arrays --- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 338 ++++++++++++++++-- 1 file changed, 302 insertions(+), 36 deletions(-) diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index e839db17150f9..4a91b074b7cd1 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -76,6 +76,17 @@ static bool isRuntimeCall(Operation *op) { // This is the single source of truth about whether we should parallelize an // operation nested in an omp.workdistribute region. static bool shouldParallelize(Operation *op) { + // True if the op is a runtime call to Assign + if (isRuntimeCall(op)) { + fir::CallOp runtimeCall = cast(op); + if ((*runtimeCall.getCallee()).getRootReference().getValue() == + "_FortranAAssign") { + return true; + } + } + // We cannot parallelize ops with side effects. 
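+ // (the _FortranAAssign call handled above is the one exception)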
+ // Parallelizable operations should not produce + // values that other operations depend on if (llvm::any_of(op->getResults(), [](OpResult v) -> bool { return !v.use_empty(); })) return false; @@ -86,11 +97,6 @@ static bool shouldParallelize(Operation *op) { return false; return *unordered; } - // True if the op is a runtime call to Assign - if (isRuntimeCall(op) && - (op->getName().getStringRef() == "_FortranAAssign")) { - return true; - } // We cannot parallise anything else. return false; } @@ -268,6 +274,7 @@ genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop, } // Generate omp.wsloop operation with an empty region and +// clone the body of fir.do_loop operation inside the loop nest region. static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop, const mlir::omp::LoopNestOperands &clauseOps, bool composite) { @@ -349,6 +356,221 @@ WorkdistributeDoLower(omp::WorkdistributeOp workdistribute, return false; } +// Check if the enclosed type in fir.ref is fir.box and fir.box encloses array +static bool isEnclosedTypeRefToBoxArray(Type type) { + // Step 1: Check if it's a reference type + if (auto refType = dyn_cast(type)) { + // Step 2: Get the referenced type (should be fir.box) + auto referencedType = refType.getEleTy(); + + // Step 3: Check if referenced type is a box + if (auto boxType = dyn_cast(referencedType)) { + // Step 4: Get the boxed type and check if it's an array + auto boxedType = boxType.getEleTy(); + + // Step 5: Check if boxed type is a sequence (array) + return isa(boxedType); + } + } + return false; +} + +// Check if the enclosed type in fir.box is scalar (not array) +static bool isEnclosedTypeBoxScalar(Type type) { + // Step 1: Check if it's a box type + if (auto boxType = dyn_cast(type)) { + // Step 2: Get the boxed type + auto boxedType = boxType.getEleTy(); + // Step 3: Check if boxed type is NOT a sequence (array) + return !isa(boxedType); + } + return false; +} + +// Check if the FortranAAssign call has src as scalar and dest as array +static bool isFortranAssignSrcScalarAndDestArray(fir::CallOp callOp) { + if (callOp.getNumOperands() < 2) + return false; + auto srcArg = callOp.getOperand(1); + auto destArg = callOp.getOperand(0); + // Both operands should be fir.convert ops + auto srcConvert = srcArg.getDefiningOp(); + auto destConvert = destArg.getDefiningOp(); + if (!srcConvert || !destConvert) { + emitError(callOp->getLoc(), + "Unimplemented: FortranAssign to OpenMP lowering\n"); + return false; + } + // Get the original types before conversion + auto srcOrigType = srcConvert.getValue().getType(); + auto destOrigType = destConvert.getValue().getType(); + + // Check if src is scalar and dest is array + bool srcIsScalar = isEnclosedTypeBoxScalar(srcOrigType); + bool destIsArray = isEnclosedTypeRefToBoxArray(destOrigType); + return srcIsScalar && destIsArray; +} + +// Convert a flat index to multi-dimensional indices for an array box +// Example: 2D array with shape (2,4) +// Col 1 Col 2 Col 3 Col 4 +// Row 1: (1,1) (1,2) (1,3) (1,4) +// Row 2: (2,1) (2,2) (2,3) (2,4) +// +// extents: (2,4) +// +// flatIdx: 0 1 2 3 4 5 6 7 +// Indices: (1,1) (1,2) (1,3) (1,4) (2,1) (2,2) (2,3) (2,4) +static SmallVector convertFlatToMultiDim(OpBuilder &builder, + Location loc, Value flatIdx, + Value arrayBox) { + // Get array type and rank + auto boxType = cast(arrayBox.getType()); + auto seqType = cast(boxType.getEleTy()); + int rank = seqType.getDimension(); + + // Get all extents + SmallVector extents; + // Get extents for each dimension + for (int i = 
0; i < rank; ++i) { + auto dimIdx = arith::ConstantIndexOp::create(builder, loc, i); + auto boxDims = fir::BoxDimsOp::create(builder, loc, arrayBox, dimIdx); + extents.push_back(boxDims.getResult(1)); + } + + // Convert flat index to multi-dimensional indices + SmallVector indices(rank); + Value temp = flatIdx; + auto c1 = builder.create(loc, 1); + + // Work backwards through dimensions (row-major order) + for (int i = rank - 1; i >= 0; --i) { + Value zeroBasedIdx = builder.create(loc, temp, extents[i]); + // Convert to one-based index + indices[i] = builder.create(loc, zeroBasedIdx, c1); + if (i > 0) { + temp = builder.create(loc, temp, extents[i]); + } + } + + return indices; +} + +// Calculate the total number of elements in the array box +// (totalElems = extent(1) * extent(2) * ... * extent(n)) +static Value CalculateTotalElements(OpBuilder &builder, Location loc, + Value arrayBox) { + auto boxType = cast(arrayBox.getType()); + auto seqType = cast(boxType.getEleTy()); + int rank = seqType.getDimension(); + + Value totalElems = nullptr; + for (int i = 0; i < rank; ++i) { + auto dimIdx = arith::ConstantIndexOp::create(builder, loc, i); + auto boxDims = fir::BoxDimsOp::create(builder, loc, arrayBox, dimIdx); + Value extent = boxDims.getResult(1); + if (i == 0) { + totalElems = extent; + } else { + totalElems = builder.create(loc, totalElems, extent); + } + } + return totalElems; +} + +// Replace the FortranAAssign runtime call with an unordered do loop +static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc, + omp::TeamsOp teamsOp, + omp::WorkdistributeOp workdistribute, + fir::CallOp callOp) { + auto destConvert = callOp.getOperand(0).getDefiningOp(); + auto srcConvert = callOp.getOperand(1).getDefiningOp(); + + Value destBox = destConvert.getValue(); + Value srcBox = srcConvert.getValue(); + + builder.setInsertionPoint(teamsOp); + // Load destination array box and source scalar + auto arrayBox = builder.create(loc, destBox); + auto scalarValue = builder.create(loc, srcBox); + auto scalar = builder.create(loc, scalarValue); + + // Calculate total number of elements (flattened) + auto c0 = builder.create(loc, 0); + auto c1 = builder.create(loc, 1); + Value totalElems = CalculateTotalElements(builder, loc, arrayBox); + + auto *workdistributeBlock = &workdistribute.getRegion().front(); + builder.setInsertionPointToStart(workdistributeBlock); + // Create single unordered loop for flattened array + auto doLoop = fir::DoLoopOp::create(builder, loc, c0, totalElems, c1, true); + Block *loopBlock = &doLoop.getRegion().front(); + builder.setInsertionPointToStart(doLoop.getBody()); + + auto flatIdx = loopBlock->getArgument(0); + SmallVector indices = + convertFlatToMultiDim(builder, loc, flatIdx, arrayBox); + // Use fir.array_coor for linear addressing + auto elemPtr = fir::ArrayCoorOp::create( + builder, loc, fir::ReferenceType::get(scalar.getType()), arrayBox, + nullptr, nullptr, ValueRange{indices}, ValueRange{}); + + builder.create(loc, scalar, elemPtr); +} + +// WorkdistributeRuntimeCallLower method finds the runtime calls +// nested in teams {workdistribute{}} and +// lowers FortranAAssign to unordered do loop if src is scalar and dest is +// array. Other runtime calls are not handled currently. 
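+//
+// For example (a rough sketch), with scalar s and 2-d array a, the call
+//   fir.call @_FortranAAssign(<box of a>, <box of s>, ...)
+// inside teams { workdistribute { } } becomes
+//   fir.do_loop %i = %c0 to %total step %c1 unordered {
+//     <delinearize %i into (%row, %col)>
+//     %addr = fir.array_coor <a>, %row, %col
+//     fir.store <s> to %addr
+//   }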
+static bool +WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute, + SetVector &targetOpsToProcess) { + OpBuilder rewriter(workdistribute); + auto loc = workdistribute->getLoc(); + auto teams = dyn_cast(workdistribute->getParentOp()); + if (!teams) { + emitError(loc, "workdistribute not nested in teams\n"); + return false; + } + if (workdistribute.getRegion().getBlocks().size() != 1) { + emitError(loc, "workdistribute with multiple blocks\n"); + return false; + } + if (teams.getRegion().getBlocks().size() != 1) { + emitError(loc, "teams with multiple blocks\n"); + return false; + } + auto *workdistributeBlock = &workdistribute.getRegion().front(); + auto *terminator = workdistributeBlock->getTerminator(); + bool changed = false; + omp::TargetOp targetOp; + // Get the target op parent of teams + if (auto teamsOp = dyn_cast(workdistribute->getParentOp())) { + targetOp = dyn_cast(teamsOp->getParentOp()); + } + for (auto &op : workdistribute.getOps()) { + if (&op == terminator) { + break; + } + if (isRuntimeCall(&op)) { + rewriter.setInsertionPoint(&op); + fir::CallOp runtimeCall = cast(op); + if ((*runtimeCall.getCallee()).getRootReference().getValue() == + "_FortranAAssign") { + if (isFortranAssignSrcScalarAndDestArray(runtimeCall) && targetOp) { + // Record the target ops to process later + targetOpsToProcess.insert(targetOp); + replaceWithUnorderedDoLoop(rewriter, loc, teams, workdistribute, + runtimeCall); + op.erase(); + return true; + } + } + } + } + return changed; +} + // TeamsWorkdistributeToSingleOp method hoists all the ops inside // teams {workdistribute{}} before teams op. // @@ -367,13 +589,24 @@ WorkdistributeDoLower(omp::WorkdistributeOp workdistribute, // B() // // If only the terminator remains in teams after hoisting, we erase teams op. -static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) { +static bool +TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp, + SetVector &targetOpsToProcess) { auto workdistributeOp = getPerfectlyNested(teamsOp); if (!workdistributeOp) return false; // Get the block containing teamsOp (the parent block). Block *parentBlock = teamsOp->getBlock(); Block &workdistributeBlock = *workdistributeOp.getRegion().begin(); + // Record the target ops to process later + for (auto &op : workdistributeBlock.getOperations()) { + if (shouldParallelize(&op)) { + auto targetOp = dyn_cast(teamsOp->getParentOp()); + if (targetOp) { + targetOpsToProcess.insert(targetOp); + } + } + } auto insertPoint = Block::iterator(teamsOp); // Get the range of operations to move (excluding the terminator). auto workdistributeBegin = workdistributeBlock.begin(); @@ -762,14 +995,6 @@ genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) { return rewriter.create(loc, i32Ty, attr); } -// Generate LLVM constant operations for i64 type. -static mlir::LLVM::ConstantOp -genI64Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) { - mlir::Type i64Ty = rewriter.getI64Type(); - mlir::IntegerAttr attr = rewriter.getI64IntegerAttr(value); - return rewriter.create(loc, i64Ty, attr); -} - // Given a box descriptor, extract the base address of the data it describes. // If the box descriptor is a reference, load it first. // The base address is returned as an i8* pointer. @@ -912,6 +1137,46 @@ static void genOmpTargetMemcpyCall(fir::FirOpBuilder &builder, return; } +// Generate code to replace a Fortran array assignment call with OpenMP +// runtime calls to perform the equivalent operation on the device. 
+// This involves extracting the source and destination pointers from the +// Fortran array descriptors, retrieving their mapped device pointers (if any), +// and invoking `omp_target_memcpy` to copy the data on the device. +static void genFortranAssignOmpReplacement(fir::FirOpBuilder &builder, + mlir::Location loc, + fir::CallOp callOp, + mlir::Value device, + mlir::ModuleOp module) { + assert(callOp.getNumResults() == 0 && + "Expected _FortranAAssign to have no results"); + assert(callOp.getNumOperands() >= 2 && + "Expected _FortranAAssign to have at least two operands"); + + // Extract the source and destination pointers from the call operands. + mlir::Value dest = callOp.getOperand(0); + mlir::Value src = callOp.getOperand(1); + + // Get the base addresses of the source and destination arrays. + mlir::Value srcBase = genDescriptorGetBaseAddress(builder, loc, src); + mlir::Value destBase = genDescriptorGetBaseAddress(builder, loc, dest); + + // Get the total size in bytes of the data to be copied. + mlir::Value dataSize = genDescriptorGetDataSizeInBytes(builder, loc, src); + + // Retrieve the mapped device pointers for source and destination. + // If no mapping exists, the original host pointer is used. + Value destPtr = + genOmpGetMappedPtrIfPresent(builder, loc, destBase, device, module); + Value srcPtr = + genOmpGetMappedPtrIfPresent(builder, loc, srcBase, device, module); + Value zero = builder.create(loc, builder.getI64Type(), + builder.getI64IntegerAttr(0)); + // Generate the call to omp_target_memcpy to perform the data copy on the + // device. + genOmpTargetMemcpyCall(builder, loc, destPtr, srcPtr, dataSize, zero, zero, + device, module); +} + // Struct to hold the host eval vars corresponding to loop bounds and steps struct HostEvalVars { SmallVector lbs; @@ -1045,26 +1310,21 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, } // Replace runtime calls with omp versions. 
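// For a hoisted _FortranAAssign call, the replacement is in outline:
//   dst = omp_get_mapped_ptr(dst_base, device)  (falls back to the host ptr)
//   src = omp_get_mapped_ptr(src_base, device)
//   omp_target_memcpy(dst, src, size_in_bytes, 0, 0, device, device)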
else if (isRuntimeCall(op)) { - rewriter.setInsertionPoint(op); fir::CallOp runtimeCall = cast(op); - SmallVector operands = runtimeCall.getOperands(); - mlir::Location loc = runtimeCall.getLoc(); - fir::FirOpBuilder builder{rewriter, op}; - assert(operands.size() == 4); - auto fromBaseAddr = - genDescriptorGetBaseAddress(builder, loc, operands[1]); - auto toBaseAddr = genDescriptorGetBaseAddress(builder, loc, operands[0]); - auto dataSizeInBytes = - genDescriptorGetDataSizeInBytes(builder, loc, operands[1]); - - Value toPtr = - genOmpGetMappedPtrIfPresent(builder, loc, toBaseAddr, device, module); - Value fromPtr = genOmpGetMappedPtrIfPresent(builder, loc, fromBaseAddr, - device, module); - Value zero = genI64Constant(loc, rewriter, 0); - genOmpTargetMemcpyCall(builder, loc, toPtr, fromPtr, dataSizeInBytes, - zero, zero, device, module); - rewriter.eraseOp(op); + if ((*runtimeCall.getCallee()).getRootReference().getValue() == + "_FortranAAssign") { + rewriter.setInsertionPoint(op); + fir::FirOpBuilder builder{rewriter, op}; + + mlir::Location loc = runtimeCall.getLoc(); + genFortranAssignOmpReplacement(builder, loc, runtimeCall, device, + module); + rewriter.eraseOp(op); + } else { + llvm_unreachable("Unhandled runtime call hoisting."); + } + } else { + llvm_unreachable("Unhandled op hoisting."); } } @@ -1424,8 +1684,11 @@ static void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter, isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice); fissionTarget(res.postTargetOp, rewriter, module, isTargetDevice); return; + } + if (splitBefore) { + isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice); + return; } else { - llvm::errs() << "Unhandled case in fissionTarget\n"; llvm::report_fatal_error("Unhandled case in fissionTarget"); } } @@ -1442,13 +1705,16 @@ class LowerWorkdistributePass moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { changed |= FissionWorkdistribute(workdistribute); }); + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + changed |= + WorkdistributeRuntimeCallLower(workdistribute, targetOpsToProcess); + }); moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { changed |= WorkdistributeDoLower(workdistribute, targetOpsToProcess); }); moduleOp->walk([&](mlir::omp::TeamsOp teams) { - changed |= TeamsWorkdistributeToSingleOp(teams); + changed |= TeamsWorkdistributeToSingleOp(teams, targetOpsToProcess); }); - if (changed) { bool isTargetDevice = llvm::cast(*moduleOp) From 2b010b811b0f190b8bc2c8a13c7833ff86ef99ab Mon Sep 17 00:00:00 2001 From: skc7 Date: Thu, 18 Sep 2025 11:58:26 +0530 Subject: [PATCH 12/21] Add tests for scalar assign --- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 27 ++--- .../Lower/OpenMP/workdistribute-saxpy-1d.f90 | 23 ++++ .../Lower/OpenMP/workdistribute-saxpy-2d.f90 | 26 +++++ .../Lower/OpenMP/workdistribute-saxpy-3d.f90 | 27 +++++ ...workdistribute-saxpy-and-scalar-assign.f90 | 33 ++++++ .../OpenMP/workdistribute-saxpy-two-2d.f90 | 38 ++++++ .../OpenMP/workdistribute-scalar-assign.f90 | 20 ++++ ...-workdistribute-runtime-assign-scalar.mlir | 108 ++++++++++++++++++ 8 files changed, 289 insertions(+), 13 deletions(-) create mode 100644 flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90 create mode 100644 flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90 create mode 100644 flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90 create mode 100644 flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90 create mode 100644 
flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90 create mode 100644 flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-runtime-assign-scalar.mlir diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 4a91b074b7cd1..88836c9323cef 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -358,17 +358,15 @@ WorkdistributeDoLower(omp::WorkdistributeOp workdistribute, // Check if the enclosed type in fir.ref is fir.box and fir.box encloses array static bool isEnclosedTypeRefToBoxArray(Type type) { - // Step 1: Check if it's a reference type + // Check if it's a reference type if (auto refType = dyn_cast(type)) { - // Step 2: Get the referenced type (should be fir.box) + // Get the referenced type (should be fir.box) auto referencedType = refType.getEleTy(); - - // Step 3: Check if referenced type is a box + // Check if referenced type is a box if (auto boxType = dyn_cast(referencedType)) { - // Step 4: Get the boxed type and check if it's an array + // Get the boxed type and check if it's an array auto boxedType = boxType.getEleTy(); - - // Step 5: Check if boxed type is a sequence (array) + // Check if boxed type is a sequence (array) return isa(boxedType); } } @@ -377,11 +375,11 @@ static bool isEnclosedTypeRefToBoxArray(Type type) { // Check if the enclosed type in fir.box is scalar (not array) static bool isEnclosedTypeBoxScalar(Type type) { - // Step 1: Check if it's a box type + // Check if it's a box type if (auto boxType = dyn_cast(type)) { - // Step 2: Get the boxed type + // Get the boxed type auto boxedType = boxType.getEleTy(); - // Step 3: Check if boxed type is NOT a sequence (array) + // Check if boxed type is NOT a sequence (array) return !isa(boxedType); } return false; @@ -743,7 +741,7 @@ static Type getPtrTypeForOmp(Type ty) { if (isPtr(ty)) return LLVM::LLVMPointerType::get(ty.getContext()); else - return fir::LLVMPointerType::get(ty); + return fir::ReferenceType::get(ty); } // allocateTempOmpVar allocates a temporary variable for OpenMP mapping @@ -806,6 +804,8 @@ static bool usedOutsideSplit(Value v, Operation *split) { // isRecomputableAfterFission checks if an operation can be recomputed static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) { + // If the op has side effects, it cannot be recomputed. + // We consider fir.declare as having no side effects. if (isa(op)) return true; @@ -1161,7 +1161,7 @@ static void genFortranAssignOmpReplacement(fir::FirOpBuilder &builder, mlir::Value destBase = genDescriptorGetBaseAddress(builder, loc, dest); // Get the total size in bytes of the data to be copied. - mlir::Value dataSize = genDescriptorGetDataSizeInBytes(builder, loc, src); + mlir::Value srcDataSize = genDescriptorGetDataSizeInBytes(builder, loc, src); // Retrieve the mapped device pointers for source and destination. // If no mapping exists, the original host pointer is used. @@ -1171,9 +1171,10 @@ static void genFortranAssignOmpReplacement(fir::FirOpBuilder &builder, genOmpGetMappedPtrIfPresent(builder, loc, srcBase, device, module); Value zero = builder.create(loc, builder.getI64Type(), builder.getI64IntegerAttr(0)); + // Generate the call to omp_target_memcpy to perform the data copy on the // device. 
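 // The copy length is taken from the source descriptor (srcDataSize), i.e.
 // the total number of elements times the element size.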
- genOmpTargetMemcpyCall(builder, loc, destPtr, srcPtr, dataSize, zero, zero, + genOmpTargetMemcpyCall(builder, loc, destPtr, srcPtr, srcDataSize, zero, zero, device, module); } diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90 new file mode 100644 index 0000000000000..95c3f37f4720e --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90 @@ -0,0 +1,23 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + + !$omp target teams workdistribute + y = a * x + y + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + + diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90 new file mode 100644 index 0000000000000..70e82780edb1a --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90 @@ -0,0 +1,26 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute(a, x, y, rows, cols) + use iso_fortran_env + implicit none + + integer, intent(in) :: rows, cols + real(kind=real32) :: a + real(kind=real32), dimension(rows, cols) :: x, y + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + + !$omp target teams workdistribute + y = a * x + y + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + + diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90 new file mode 100644 index 0000000000000..d6fa300eaff99 --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90 @@ -0,0 +1,27 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute(a, x, y, rows, cols, depth) + use iso_fortran_env + implicit none + + integer, intent(in) :: rows, cols, depth + real(kind=real32) :: a + real(kind=real32), dimension(rows, cols, depth) :: x, y + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + ! CHECK: fir.do_loop + + !$omp target teams workdistribute + y = a * x + y + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + + diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90 new file mode 100644 index 0000000000000..6b6dc3e3a184f --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90 @@ -0,0 +1,33 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! 
CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + !$omp target teams workdistribute + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + + y = a * x + y + + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + + y = 2.0_real32 + + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + + diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90 new file mode 100644 index 0000000000000..2229ccf34e920 --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90 @@ -0,0 +1,38 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute(a, x, y, rows, cols) + use iso_fortran_env + implicit none + + integer, intent(in) :: rows, cols + real(kind=real32) :: a + real(kind=real32), dimension(rows, cols) :: x, y + + !$omp target teams workdistribute + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + + y = a * x + y + + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + + y = a * y + x + + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + + diff --git a/flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90 b/flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90 new file mode 100644 index 0000000000000..af94559dfa8cf --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90 @@ -0,0 +1,20 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute_scalar_assign +subroutine target_teams_workdistribute_scalar_assign() + integer :: aa(10) + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + + !$omp target teams workdistribute + aa = 20 + !$omp end target teams workdistribute + +end subroutine target_teams_workdistribute_scalar_assign + diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-runtime-assign-scalar.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-runtime-assign-scalar.mlir new file mode 100644 index 0000000000000..03d5d71df0a82 --- /dev/null +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-runtime-assign-scalar.mlir @@ -0,0 +1,108 @@ +// RUN: fir-opt --lower-workdistribute %s | FileCheck %s + +// Test lowering of workdistribute for a scalar assignment within a target teams workdistribute region. +// The test checks that the scalar assignment is correctly lowered to wsloop and loop_nest operations. 
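// (Illustrative note: with extents (e0, e1) taken from fir.box_dims in the
// checks below, the flat loop index is delinearized as
//   idx1 = flat mod e1 + 1
//   idx0 = (flat div e1) mod e0 + 1
// before fir.array_coor addresses element (idx0, idx1).)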
+ +// Example Fortran code: +// !$omp target teams workdistribute +// y = 3.0_real32 +// !$omp end target teams workdistribute + + +// CHECK-LABEL: func.func @x( +// CHECK: omp.target {{.*}} { +// CHECK: omp.teams { +// CHECK: omp.parallel { +// CHECK: omp.distribute { +// CHECK: omp.wsloop { +// CHECK: omp.loop_nest (%[[VAL_73:.*]]) : index = (%[[VAL_66:.*]]) to (%[[VAL_72:.*]]) inclusive step (%[[VAL_67:.*]]) { +// CHECK: %[[VAL_74:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_75:.*]]:3 = fir.box_dims %[[VAL_64:.*]], %[[VAL_74]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[VAL_76:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_77:.*]]:3 = fir.box_dims %[[VAL_64]], %[[VAL_76]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[VAL_78:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_79:.*]] = arith.remsi %[[VAL_73]], %[[VAL_77]]#1 : index +// CHECK: %[[VAL_80:.*]] = arith.addi %[[VAL_79]], %[[VAL_78]] : index +// CHECK: %[[VAL_81:.*]] = arith.divsi %[[VAL_73]], %[[VAL_77]]#1 : index +// CHECK: %[[VAL_82:.*]] = arith.remsi %[[VAL_81]], %[[VAL_75]]#1 : index +// CHECK: %[[VAL_83:.*]] = arith.addi %[[VAL_82]], %[[VAL_78]] : index +// CHECK: %[[VAL_84:.*]] = fir.array_coor %[[VAL_64]] %[[VAL_83]], %[[VAL_80]] : (!fir.box>, index, index) -> !fir.ref +// CHECK: fir.store %[[VAL_65:.*]] to %[[VAL_84]] : !fir.ref +// CHECK: omp.yield +// CHECK: } +// CHECK: } {omp.composite} +// CHECK: } {omp.composite} +// CHECK: omp.terminator +// CHECK: } {omp.composite} +// CHECK: omp.terminator +// CHECK: } +// CHECK: omp.terminator +// CHECK: } +// CHECK: omp.terminator +// CHECK: } +// CHECK: return +// CHECK: } +// CHECK: func.func private @_FortranAAssign(!fir.ref>, !fir.box, !fir.ref, i32) attributes {fir.runtime} + +module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} { +func.func @x(%arr : !fir.ref>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c78 = arith.constant 78 : index + %cst = arith.constant 3.000000e+00 : f32 + %0 = fir.alloca i32 + %1 = fir.alloca i32 + %c10 = arith.constant 10 : index + %c20 = arith.constant 20 : index + %194 = arith.subi %c10, %c1 : index + %195 = omp.map.bounds lower_bound(%c0 : index) upper_bound(%194 : index) extent(%c10 : index) stride(%c1 : index) start_idx(%c1 : index) + %196 = arith.subi %c20, %c1 : index + %197 = omp.map.bounds lower_bound(%c0 : index) upper_bound(%196 : index) extent(%c20 : index) stride(%c1 : index) start_idx(%c1 : index) + %198 = omp.map.info var_ptr(%arr : !fir.ref>, f32) map_clauses(implicit, tofrom) capture(ByRef) bounds(%195, %197) -> !fir.ref> {name = "y"} + %199 = omp.map.info var_ptr(%1 : !fir.ref, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref {name = ""} + %200 = omp.map.info var_ptr(%0 : !fir.ref, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref {name = ""} + omp.target map_entries(%198 -> %arg5, %199 -> %arg6, %200 -> %arg7 : !fir.ref>, !fir.ref, !fir.ref) { + %c0_0 = arith.constant 0 : index + %201 = fir.load %arg7 : !fir.ref + %202 = fir.load %arg6 : !fir.ref + %203 = fir.convert %202 : (i32) -> i64 + %204 = fir.convert %201 : (i32) -> i64 + %205 = fir.convert %204 : (i64) -> index + %206 = arith.cmpi sgt, %205, %c0_0 : index + %207 = fir.convert %203 : (i64) -> index + %208 = arith.cmpi sgt, %207, %c0_0 : index + %209 = arith.select %208, %207, %c0_0 : index + %210 = arith.select %206, %205, %c0_0 : index + %211 = fir.shape %210, %209 : (index, index) 
-> !fir.shape<2> + %212 = fir.declare %arg5(%211) {uniq_name = "_QFFaxpy_array_workdistributeEy"} : (!fir.ref>, !fir.shape<2>) -> !fir.ref> + %213 = fir.embox %212(%211) : (!fir.ref>, !fir.shape<2>) -> !fir.box> + omp.teams { + %214 = fir.alloca !fir.box> {pinned} + omp.workdistribute { + %215 = fir.alloca f32 + %216 = fir.embox %215 : (!fir.ref) -> !fir.box + %217 = fir.shape %210, %209 : (index, index) -> !fir.shape<2> + %218 = fir.embox %212(%217) : (!fir.ref>, !fir.shape<2>) -> !fir.box> + fir.store %218 to %214 : !fir.ref>> + %219 = fir.address_of(@_QQclXf9c642d28e5bba1f07fa9a090b72f4fc) : !fir.ref> + %c39_i32 = arith.constant 39 : i32 + %220 = fir.convert %214 : (!fir.ref>>) -> !fir.ref> + %221 = fir.convert %216 : (!fir.box) -> !fir.box + %222 = fir.convert %219 : (!fir.ref>) -> !fir.ref + fir.call @_FortranAAssign(%220, %221, %222, %c39_i32) : (!fir.ref>, !fir.box, !fir.ref, i32) -> () + omp.terminator + } + omp.terminator + } + omp.terminator + } + return +} + +func.func private @_FortranAAssign(!fir.ref>, !fir.box, !fir.ref, i32) attributes {fir.runtime} + +fir.global linkonce @_QQclXf9c642d28e5bba1f07fa9a090b72f4fc constant : !fir.char<1,78> { + %0 = fir.string_lit "File: /work/github/skc7/llvm-project/build_fomp_reldebinfo/saxpy_tests/\00"(78) : !fir.char<1,78> + fir.has_value %0 : !fir.char<1,78> +} +} From c27ad3d9105840bf962bb38f8dba9a0689de956d Mon Sep 17 00:00:00 2001 From: skc7 Date: Fri, 19 Sep 2025 18:50:33 +0530 Subject: [PATCH 13/21] Fix Scalar assign bug. And Fix CI tests --- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 51 +++++++++++++++---- .../lower-workdistribute-fission-host.mlir | 8 +-- .../lower-workdistribute-fission-target.mlir | 9 ++-- 3 files changed, 49 insertions(+), 19 deletions(-) diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 88836c9323cef..ff62457e0a7da 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -487,11 +487,30 @@ static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc, Value destBox = destConvert.getValue(); Value srcBox = srcConvert.getValue(); + // get defining alloca op of destBox and srcBox + auto destAlloca = destBox.getDefiningOp(); + + if (!destAlloca) { + emitError(loc, "Unimplemented: FortranAssign to OpenMP lowering\n"); + return; + } + + // get the store op that stores to the alloca + for (auto user : destAlloca->getUsers()) { + if (auto storeOp = dyn_cast(user)) { + destBox = storeOp.getValue(); + break; + } + } + builder.setInsertionPoint(teamsOp); - // Load destination array box and source scalar - auto arrayBox = builder.create(loc, destBox); + // Load destination array box (if it's a reference) + Value arrayBox = destBox; + if (isa(destBox.getType())) + arrayBox = builder.create(loc, destBox); + auto scalarValue = builder.create(loc, srcBox); - auto scalar = builder.create(loc, scalarValue); + Value scalar = builder.create(loc, scalarValue); // Calculate total number of elements (flattened) auto c0 = builder.create(loc, 0); @@ -543,9 +562,8 @@ WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute, bool changed = false; omp::TargetOp targetOp; // Get the target op parent of teams - if (auto teamsOp = dyn_cast(workdistribute->getParentOp())) { - targetOp = dyn_cast(teamsOp->getParentOp()); - } + targetOp = dyn_cast(teams->getParentOp()); + SmallVector opsToErase; for (auto &op : workdistribute.getOps()) { if (&op == terminator) { break; @@ -560,12 
+578,15 @@ WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute, targetOpsToProcess.insert(targetOp); replaceWithUnorderedDoLoop(rewriter, loc, teams, workdistribute, runtimeCall); - op.erase(); - return true; + opsToErase.push_back(&op); + changed = true; } } } } + for (auto *op : opsToErase) { + op->erase(); + } return changed; } @@ -911,7 +932,7 @@ static void reloadCacheAndRecompute( unsigned originalMapVarsSize = targetOp.getMapVars().size(); unsigned hostEvalVarsSize = hostEvalVars.size(); - // Create Stores for allocs. + // Create load operations for each allocated variable. for (unsigned i = 0; i < allocs.size(); ++i) { Value original = allocs[i]; // Get the new block argument for this specific allocated value. @@ -1196,6 +1217,12 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, Block *targetBlock = &targetOp.getRegion().front(); assert(targetBlock == &targetOp.getRegion().back()); IRMapping mapping; + + auto targetDataOp = cast(targetOp->getParentOp()); + if (!targetDataOp) { + llvm_unreachable("Expected target op to be inside target_data op"); + return; + } // create mapping for host_eval_vars unsigned hostEvalVarCount = targetOp.getHostEvalVars().size(); for (unsigned i = 0; i < targetOp.getHostEvalVars().size(); ++i) { @@ -1361,12 +1388,14 @@ static void computeAllocsCacheRecomputable( it++) { // Check if any of the results are used outside the split point. for (auto res : it->getResults()) { - if (usedOutsideSplit(res, splitBeforeOp)) + if (usedOutsideSplit(res, splitBeforeOp)) { requiredVals.push_back(res); + } } // If the op is not recomputable, add it to the nonRecomputable set. - if (!isRecomputableAfterFission(&*it, splitBeforeOp)) + if (!isRecomputableAfterFission(&*it, splitBeforeOp)) { nonRecomputable.insert(&*it); + } } // For each required value, collect its dependencies. 
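  // (Illustrative, hypothetical IR: a pure op such as
  //   %sz = arith.muli %n, %c4 : index
  // can simply be recomputed after the split point, whereas an op with
  // memory effects, e.g.
  //   %buf = fir.allocmem !fir.array<?xi32>, %n
  // must instead have its result cached in a mapped temporary.)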
for (auto requiredVal : requiredVals) diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir index b4c9598a78f0e..04e60ca8bbf37 100644 --- a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir @@ -42,10 +42,10 @@ // CHECK: fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref // CHECK: fir.store %[[VAL_30]] to %[[VAL_20]] : !fir.ref> // CHECK: omp.target host_eval(%[[VAL_24]] -> %[[VAL_31:.*]], %[[VAL_25]] -> %[[VAL_32:.*]], %[[VAL_26]] -> %[[VAL_33:.*]] : index, index, index) map_entries(%[[VAL_7]] -> %[[VAL_34:.*]], %[[VAL_8]] -> %[[VAL_35:.*]], %[[VAL_9]] -> %[[VAL_36:.*]], %[[VAL_10]] -> %[[VAL_37:.*]], %[[VAL_13]] -> %[[VAL_38:.*]], %[[VAL_16]] -> %[[VAL_39:.*]], %[[VAL_19]] -> %[[VAL_40:.*]], %[[VAL_22]] -> %[[VAL_41:.*]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref>) { -// CHECK: %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.llvm_ptr -// CHECK: %[[VAL_43:.*]] = fir.load %[[VAL_39]] : !fir.llvm_ptr -// CHECK: %[[VAL_44:.*]] = fir.load %[[VAL_40]] : !fir.llvm_ptr -// CHECK: %[[VAL_45:.*]] = fir.load %[[VAL_41]] : !fir.llvm_ptr> +// CHECK: %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.ref +// CHECK: %[[VAL_43:.*]] = fir.load %[[VAL_39]] : !fir.ref +// CHECK: %[[VAL_44:.*]] = fir.load %[[VAL_40]] : !fir.ref +// CHECK: %[[VAL_45:.*]] = fir.load %[[VAL_41]] : !fir.ref> // CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_43]], %[[VAL_43]] : index // CHECK: omp.teams { // CHECK: omp.parallel { diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir index 6e82efb308328..062eb701b52ef 100644 --- a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir @@ -42,10 +42,10 @@ // CHECK: fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref // CHECK: fir.store %[[VAL_30]] to %[[VAL_20]] : !fir.ref> // CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_31:.*]], %[[VAL_8]] -> %[[VAL_32:.*]], %[[VAL_9]] -> %[[VAL_33:.*]], %[[VAL_10]] -> %[[VAL_34:.*]], %[[VAL_13]] -> %[[VAL_35:.*]], %[[VAL_16]] -> %[[VAL_36:.*]], %[[VAL_19]] -> %[[VAL_37:.*]], %[[VAL_22]] -> %[[VAL_38:.*]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref, !fir.ref>) { -// CHECK: %[[VAL_39:.*]] = fir.load %[[VAL_35]] : !fir.llvm_ptr -// CHECK: %[[VAL_40:.*]] = fir.load %[[VAL_36]] : !fir.llvm_ptr -// CHECK: %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.llvm_ptr -// CHECK: %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.llvm_ptr> +// CHECK: %[[VAL_39:.*]] = fir.load %[[VAL_35]] : !fir.ref +// CHECK: %[[VAL_40:.*]] = fir.load %[[VAL_36]] : !fir.ref +// CHECK: %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.ref +// CHECK: %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.ref> // CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_40]], %[[VAL_40]] : index // CHECK: omp.teams { // CHECK: omp.parallel { @@ -77,6 +77,7 @@ // CHECK: return // CHECK: } + module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} { func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref) { %lb_ref = fir.alloca index {bindc_name = "lb"} From de0f9af7058837458665539748a9822e5fcde370 Mon Sep 17 00:00:00 2001 From: skc7 Date: Wed, 24 Sep 2025 19:13:38 +0530 Subject: [PATCH 14/21] Comments fix and new test. 
--- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 50 ++++++++++++------- .../workdistribute-target-teams-clauses.f90 | 32 ++++++++++++ 2 files changed, 63 insertions(+), 19 deletions(-) create mode 100644 flang/test/Lower/OpenMP/workdistribute-target-teams-clauses.f90 diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index ff62457e0a7da..7bba699e6ff2e 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -12,11 +12,11 @@ // Fortran array statements are lowered to fir as fir.do_loop unordered. // lower-workdistribute pass works mainly on identifying fir.do_loop unordered // that is nested in target{teams{workdistribute{fir.do_loop unordered}}} and -// lowers it to target{teams{parallel{wsloop{loop_nest}}}}. +// lowers it to target{teams{parallel{distribute{wsloop{loop_nest}}}}}. // It hoists all the other ops outside target region. // Relaces heap allocation on target with omp.target_allocmem and // deallocation with omp.target_freemem from host. Also replaces -// runtime function "Assign" with omp.target_memcpy. +// runtime function "Assign" with omp_target_memcpy. // //===----------------------------------------------------------------------===// @@ -319,13 +319,14 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop, // Then, its lowered to // // omp.teams { -// omp.parallel { -// omp.distribute { -// omp.wsloop { -// omp.loop_nest -// ... -// } -// } +// omp.parallel { +// omp.distribute { +// omp.wsloop { +// omp.loop_nest +// ... +// } +// } +// } // } // } @@ -345,6 +346,7 @@ WorkdistributeDoLower(omp::WorkdistributeOp workdistribute, targetOpsToProcess.insert(targetOp); } } + // Generate the nested parallel, distribute, wsloop and loop_nest ops. genParallelOp(wdLoc, rewriter, true); genDistributeOp(wdLoc, rewriter, true); mlir::omp::LoopNestOperands loopNestClauseOps; @@ -584,6 +586,7 @@ WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute, } } } + // Erase the runtime calls that have been replaced. for (auto *op : opsToErase) { op->erase(); } @@ -772,6 +775,7 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty, Value alloc; Type allocType; auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx); + // Get the appropriate type for allocation if (isPtr(ty)) { Type intTy = rewriter.getI32Type(); auto one = rewriter.create(loc, intTy, 1); @@ -782,6 +786,7 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty, allocType = ty; alloc = rewriter.create(loc, allocType); } + // Lambda to create mapinfo ops auto getMapInfo = [&](uint64_t mappingFlags, const char *name) { return rewriter.create( loc, alloc.getType(), alloc, TypeAttr::get(allocType), @@ -796,6 +801,7 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty, /*mapperId=*/mlir::FlatSymbolRefAttr(), /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false)); }; + // Create mapinfo ops. uint64_t mapFrom = static_cast>( llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); @@ -847,14 +853,17 @@ static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp, SetVector &toCache, SetVector &toRecompute) { Operation *op = v.getDefiningOp(); + // If v is a block argument, it must be from the targetOp. if (!op) { assert(cast(v).getOwner()->getParentOp() == targetOp); return; } + // If the op is in the nonRecomputable set, add it to toCache and return. 
if (nonRecomputable.contains(op)) { toCache.insert(op); return; } + // Add the op to toRecompute. toRecompute.insert(op); for (auto opr : op->getOperands()) collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache, @@ -939,6 +948,8 @@ static void reloadCacheAndRecompute( Value newArg = newTargetBlock->getArgument(hostEvalVarsSize + originalMapVarsSize + i); Value restored; + // If the original value is a pointer or reference, load and convert if + // necessary. if (isPtr(original.getType())) { restored = rewriter.create(loc, llvmPtrTy, newArg); if (!isa(original.getType())) @@ -967,6 +978,7 @@ static mlir::omp::LoopNestOp getLoopNestFromTeams(mlir::omp::TeamsOp teamsOp) { return nullptr; // Find parallel op inside teams mlir::omp::ParallelOp parallelOp = nullptr; + // Look for the parallel op in the teams region for (auto &op : teamsOp.getRegion().front()) { if (auto parallel = dyn_cast(op)) { parallelOp = parallel; @@ -1218,6 +1230,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, assert(targetBlock == &targetOp.getRegion().back()); IRMapping mapping; + // Get the parent target_data op auto targetDataOp = cast(targetOp->getParentOp()); if (!targetDataOp) { llvm_unreachable("Expected target op to be inside target_data op"); @@ -1255,6 +1268,7 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, SmallVector opsToReplace; Value device = targetOp.getDevice(); + // If device is not specified, default to device 0. if (!device) { device = genI32Constant(targetOp.getLoc(), rewriter, 0); } @@ -1508,15 +1522,12 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector &postMapOperands, SmallVector isolatedHostEvalVars{targetOp.getHostEvalVars()}; // update the hostEvalVars of isolatedTargetOp if (!hostEvalVars.lbs.empty() && !isTargetDevice) { - for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i) { - isolatedHostEvalVars.push_back(hostEvalVars.lbs[i]); - } - for (size_t i = 0; i < hostEvalVars.ubs.size(); ++i) { - isolatedHostEvalVars.push_back(hostEvalVars.ubs[i]); - } - for (size_t i = 0; i < hostEvalVars.steps.size(); ++i) { - isolatedHostEvalVars.push_back(hostEvalVars.steps[i]); - } + isolatedHostEvalVars.append(hostEvalVars.lbs.begin(), + hostEvalVars.lbs.end()); + isolatedHostEvalVars.append(hostEvalVars.ubs.begin(), + hostEvalVars.ubs.end()); + isolatedHostEvalVars.append(hostEvalVars.steps.begin(), + hostEvalVars.steps.end()); } // Create the isolated target op omp::TargetOp isolatedTargetOp = rewriter.create( @@ -1708,13 +1719,14 @@ static void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter, Operation *toIsolate = std::get<0>(*tuple); bool splitBefore = !std::get<1>(*tuple); bool splitAfter = !std::get<2>(*tuple); - + // Recursively isolate the target op. if (splitBefore && splitAfter) { auto res = isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice); fissionTarget(res.postTargetOp, rewriter, module, isTargetDevice); return; } + // Isolate only before the op. if (splitBefore) { isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice); return; diff --git a/flang/test/Lower/OpenMP/workdistribute-target-teams-clauses.f90 b/flang/test/Lower/OpenMP/workdistribute-target-teams-clauses.f90 new file mode 100644 index 0000000000000..4a08e53bc316a --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-target-teams-clauses.f90 @@ -0,0 +1,32 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +! 
CHECK: omp.target_data map_entries({{.*}}) +! CHECK: omp.target thread_limit({{.*}}) host_eval({{.*}}) map_entries({{.*}}) +! CHECK: omp.teams num_teams({{.*}}) +! CHECK: omp.parallel +! CHECK: omp.distribute +! CHECK: omp.wsloop +! CHECK: omp.loop_nest + +subroutine target_teams_workdistribute() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + integer :: i + + a = 2.0_real32 + x = [(real(i, real32), i = 1, 10)] + y = [(real(i * 0.5, real32), i = 1, 10)] + + !$omp target teams workdistribute & + !$omp& num_teams(4) & + !$omp& thread_limit(8) & + !$omp& default(shared) & + !$omp& private(i) & + !$omp& map(to: x) & + !$omp& map(tofrom: y) + y = a * x + y + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute From d7e8d905969610d9ef92dac924dd543112242a9e Mon Sep 17 00:00:00 2001 From: skc7 Date: Tue, 30 Sep 2025 16:55:32 +0530 Subject: [PATCH 15/21] Add verifier. --- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 68 +++++++++++++++---- 1 file changed, 56 insertions(+), 12 deletions(-) diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 7bba699e6ff2e..0776ff8e9a4a3 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -118,6 +118,53 @@ static T getPerfectlyNested(Operation *op) { return nullptr; } +// VerifyTargetTeamsWorkdistribute method verifies that +// omp.target { teams { workdistribute { ... } } } is well formed +// and fails for function calls that don't have lowering implemented yet. +static bool +VerifyTargetTeamsWorkdistribute(omp::WorkdistributeOp workdistribute) { + OpBuilder rewriter(workdistribute); + auto teams = dyn_cast(workdistribute->getParentOp()); + if (!teams) { + workdistribute.emitError() << "workdistribute not nested in teams\n"; + return false; + } + if (workdistribute.getRegion().getBlocks().size() != 1) { + workdistribute.emitError() << "workdistribute with multiple blocks\n"; + return false; + } + if (teams.getRegion().getBlocks().size() != 1) { + workdistribute.emitError() << "teams with multiple blocks\n"; + return false; + } + omp::TargetOp targetOp = dyn_cast(teams->getParentOp()); + // return if not omp.target + if (!targetOp) + return true; + + for (auto &op : workdistribute.getOps()) { + if (auto callOp = dyn_cast(op)) { + if (isRuntimeCall(&op)) { + auto funcName = (*callOp.getCallee()).getRootReference().getValue(); + // _FortranAAssign is handled. Other runtime calls are not supported + // in omp.workdistribute yet. + if (funcName == "_FortranAAssign") + continue; + else + workdistribute.emitError() + << "Runtime call " << funcName + << " lowering not supported for workdistribute yet."; + return false; + } else { + workdistribute.emitError() << "Non-runtime fir.call lowering not " + "supported in workdistribute yet."; + return false; + } + } + } + return true; +} + // FissionWorkdistribute method finds the parallelizable ops // within teams {workdistribute} region and moves them to their // own teams{workdistribute} region. 
@@ -154,18 +201,10 @@ static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) { OpBuilder rewriter(workdistribute); auto loc = workdistribute->getLoc(); auto teams = dyn_cast(workdistribute->getParentOp()); - if (!teams) { - emitError(loc, "workdistribute not nested in teams\n"); - return false; - } - if (workdistribute.getRegion().getBlocks().size() != 1) { - emitError(loc, "workdistribute with multiple blocks\n"); - return false; - } - if (teams.getRegion().getBlocks().size() != 1) { - emitError(loc, "teams with multiple blocks\n"); - return false; - } + + omp::TargetOp targetOp; + // Get the target op parent of teams + targetOp = dyn_cast(teams->getParentOp()); auto *teamsBlock = &teams.getRegion().front(); bool changed = false; @@ -1744,6 +1783,11 @@ class LowerWorkdistributePass auto moduleOp = getOperation(); bool changed = false; SetVector targetOpsToProcess; + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + bool res = VerifyTargetTeamsWorkdistribute(workdistribute); + if (!res) + signalPassFailure(); + }); moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { changed |= FissionWorkdistribute(workdistribute); }); From baff55b76fafe5fe6c1260594ff7d436a3adfc88 Mon Sep 17 00:00:00 2001 From: skc7 Date: Wed, 8 Oct 2025 17:00:30 +0530 Subject: [PATCH 16/21] Fix error reporting. Fix comments --- .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 574 +++++++++--------- .../OpenMP/lower-workdistribute-target.mlir | 32 - 2 files changed, 290 insertions(+), 316 deletions(-) delete mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 0776ff8e9a4a3..7a41fbca679b0 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -59,8 +59,11 @@ using namespace mlir; namespace { -// The isRuntimeCall function is a utility designed to determine -// if a given operation is a call to a Fortran-specific runtime function. +/// This string is used to identify the Fortran-specific runtime FortranAAssign. +static constexpr llvm::StringRef FortranAssignStr = "_FortranAAssign"; + +/// The isRuntimeCall function is a utility designed to determine +/// if a given operation is a call to a Fortran-specific runtime function. static bool isRuntimeCall(Operation *op) { if (auto callOp = dyn_cast(op)) { auto callee = callOp.getCallee(); @@ -73,14 +76,14 @@ static bool isRuntimeCall(Operation *op) { return false; } -// This is the single source of truth about whether we should parallelize an -// operation nested in an omp.workdistribute region. +/// This is the single source of truth about whether we should parallelize an +/// operation nested in an omp.workdistribute region. static bool shouldParallelize(Operation *op) { // True if the op is a runtime call to Assign if (isRuntimeCall(op)) { fir::CallOp runtimeCall = cast(op); - if ((*runtimeCall.getCallee()).getRootReference().getValue() == - "_FortranAAssign") { + auto funcName = (*runtimeCall.getCallee()).getRootReference().getValue(); + if (funcName == FortranAssignStr) { return true; } } @@ -97,12 +100,12 @@ static bool shouldParallelize(Operation *op) { return false; return *unordered; } - // We cannot parallise anything else. + // We cannot parallelize anything else. return false; } -// The getPerfectlyNested function is a generic utility for finding -// a single, "perfectly nested" operation within a parent operation. 
+/// The getPerfectlyNested function is a generic utility for finding +/// a single, "perfectly nested" operation within a parent operation. template static T getPerfectlyNested(Operation *op) { if (op->getNumRegions() != 1) @@ -118,24 +121,25 @@ static T getPerfectlyNested(Operation *op) { return nullptr; } -// VerifyTargetTeamsWorkdistribute method verifies that -// omp.target { teams { workdistribute { ... } } } is well formed -// and fails for function calls that don't have lowering implemented yet. -static bool -VerifyTargetTeamsWorkdistribute(omp::WorkdistributeOp workdistribute) { +/// verifyTargetTeamsWorkdistribute method verifies that +/// omp.target { teams { workdistribute { ... } } } is well formed +/// and fails for function calls that don't have lowering implemented yet. +static FailureOr +verifyTargetTeamsWorkdistribute(omp::WorkdistributeOp workdistribute) { OpBuilder rewriter(workdistribute); + auto loc = workdistribute->getLoc(); auto teams = dyn_cast(workdistribute->getParentOp()); if (!teams) { - workdistribute.emitError() << "workdistribute not nested in teams\n"; - return false; + emitError(loc, "workdistribute not nested in teams\n"); + return failure(); } if (workdistribute.getRegion().getBlocks().size() != 1) { - workdistribute.emitError() << "workdistribute with multiple blocks\n"; - return false; + emitError(loc, "workdistribute with multiple blocks\n"); + return failure(); } if (teams.getRegion().getBlocks().size() != 1) { - workdistribute.emitError() << "teams with multiple blocks\n"; - return false; + emitError(loc, "teams with multiple blocks\n"); + return failure(); } omp::TargetOp targetOp = dyn_cast(teams->getParentOp()); // return if not omp.target @@ -148,64 +152,55 @@ VerifyTargetTeamsWorkdistribute(omp::WorkdistributeOp workdistribute) { auto funcName = (*callOp.getCallee()).getRootReference().getValue(); // _FortranAAssign is handled. Other runtime calls are not supported // in omp.workdistribute yet. - if (funcName == "_FortranAAssign") + if (funcName == FortranAssignStr) continue; - else - workdistribute.emitError() - << "Runtime call " << funcName - << " lowering not supported for workdistribute yet."; - return false; - } else { - workdistribute.emitError() << "Non-runtime fir.call lowering not " - "supported in workdistribute yet."; - return false; + else { + emitError(loc, "Runtime call " + funcName + + " lowering not supported for workdistribute yet."); + return failure(); + } } } } return true; } -// FissionWorkdistribute method finds the parallelizable ops -// within teams {workdistribute} region and moves them to their -// own teams{workdistribute} region. -// -// If B() and D() are parallelizable, -// -// omp.teams { -// omp.workdistribute { -// A() -// B() -// C() -// D() -// E() -// } -// } -// -// becomes -// -// A() -// omp.teams { -// omp.workdistribute { -// B() -// } -// } -// C() -// omp.teams { -// omp.workdistribute { -// D() -// } -// } -// E() - -static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) { +/// fissionWorkdistribute method finds the parallelizable ops +/// within teams {workdistribute} region and moves them to their +/// own teams{workdistribute} region. 
+/// +/// If B() and D() are parallelizable, +/// +/// omp.teams { +/// omp.workdistribute { +/// A() +/// B() +/// C() +/// D() +/// E() +/// } +/// } +/// +/// becomes +/// +/// A() +/// omp.teams { +/// omp.workdistribute { +/// B() +/// } +/// } +/// C() +/// omp.teams { +/// omp.workdistribute { +/// D() +/// } +/// } +/// E() +static FailureOr +fissionWorkdistribute(omp::WorkdistributeOp workdistribute) { OpBuilder rewriter(workdistribute); auto loc = workdistribute->getLoc(); auto teams = dyn_cast(workdistribute->getParentOp()); - - omp::TargetOp targetOp; - // Get the target op parent of teams - targetOp = dyn_cast(teams->getParentOp()); - auto *teamsBlock = &teams.getRegion().front(); bool changed = false; // Move the ops inside teams and before workdistribute outside. @@ -217,7 +212,7 @@ static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) { } if (shouldParallelize(&op)) { emitError(loc, "teams has parallelize ops before first workdistribute\n"); - return false; + return failure(); } else { rewriter.setInsertionPoint(teams); rewriter.clone(op, irMapping); @@ -280,7 +275,7 @@ static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) { return changed; } -// Generate omp.parallel operation with an empty region. +/// Generate omp.parallel operation with an empty region. static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) { auto parallelOp = rewriter.create(loc); parallelOp.setComposite(composite); @@ -289,7 +284,7 @@ static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) { return; } -// Generate omp.distribute operation with an empty region. +/// Generate omp.distribute operation with an empty region. static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) { mlir::omp::DistributeOperands distributeClauseOps; auto distributeOp = @@ -300,7 +295,7 @@ static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) { return; } -// Generate loop nest clause operands from fir.do_loop operation. +/// Generate loop nest clause operands from fir.do_loop operation. static void genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop, mlir::omp::LoopNestOperands &loopNestClauseOps) { @@ -312,8 +307,8 @@ genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop, loopNestClauseOps.loopInclusive = rewriter.getUnitAttr(); } -// Generate omp.wsloop operation with an empty region and -// clone the body of fir.do_loop operation inside the loop nest region. +/// Generate omp.wsloop operation with an empty region and +/// clone the body of fir.do_loop operation inside the loop nest region. static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop, const mlir::omp::LoopNestOperands &clauseOps, bool composite) { @@ -341,36 +336,35 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop, return; } -// WorkdistributeDoLower method finds the fir.do_loop unoredered -// nested in teams {workdistribute{fir.do_loop unoredered}} and -// lowers it to teams {parallel { distribute {wsloop {loop_nest}}}}. -// -// If fir.do_loop is present inside teams workdistribute -// -// omp.teams { -// omp.workdistribute { -// fir.do_loop unoredered { -// ... -// } -// } -// } -// -// Then, its lowered to -// -// omp.teams { -// omp.parallel { -// omp.distribute { -// omp.wsloop { -// omp.loop_nest -// ... 
-// }
-// }
-// }
-// }
-// }
+/// workdistributeDoLower method finds the fir.do_loop unordered
+/// nested in teams {workdistribute{fir.do_loop unordered}} and
+/// lowers it to teams {parallel {distribute {wsloop {loop_nest}}}}.
+///
+/// If fir.do_loop is present inside teams workdistribute
+///
+/// omp.teams {
+/// omp.workdistribute {
+/// fir.do_loop unordered {
+/// ...
+/// }
+/// }
+/// }
+///
+/// Then, it's lowered to
+///
+/// omp.teams {
+/// omp.parallel {
+/// omp.distribute {
+/// omp.wsloop {
+/// omp.loop_nest
+/// ...
+/// }
+/// }
+/// }
+/// }
 
 static bool
-WorkdistributeDoLower(omp::WorkdistributeOp workdistribute,
+workdistributeDoLower(omp::WorkdistributeOp workdistribute,
                       SetVector &targetOpsToProcess) {
   OpBuilder rewriter(workdistribute);
   auto doLoop = getPerfectlyNested(workdistribute);
@@ -397,7 +391,7 @@ WorkdistributeDoLower(omp::WorkdistributeOp workdistribute,
   return false;
 }
 
-// Check if the enclosed type in fir.ref is fir.box and fir.box encloses array
+/// Check if the enclosed type in fir.ref is fir.box and fir.box encloses array
 static bool isEnclosedTypeRefToBoxArray(Type type) {
   // Check if it's a reference type
   if (auto refType = dyn_cast(type)) {
@@ -414,7 +408,7 @@ static bool isEnclosedTypeRefToBoxArray(Type type) {
   return false;
 }
 
-// Check if the enclosed type in fir.box is scalar (not array)
+/// Check if the enclosed type in fir.box is scalar (not array)
 static bool isEnclosedTypeBoxScalar(Type type) {
   // Check if it's a box type
   if (auto boxType = dyn_cast(type)) {
@@ -426,7 +420,7 @@ static bool isEnclosedTypeBoxScalar(Type type) {
   return false;
 }
 
-// Check if the FortranAAssign call has src as scalar and dest as array
+/// Check if the FortranAAssign call has src as scalar and dest as array
 static bool isFortranAssignSrcScalarAndDestArray(fir::CallOp callOp) {
   if (callOp.getNumOperands() < 2)
     return false;
@@ -450,16 +444,16 @@ static bool isFortranAssignSrcScalarAndDestArray(fir::CallOp callOp) {
   return srcIsScalar && destIsArray;
 }
 
-// Convert a flat index to multi-dimensional indices for an array box
-// Example: 2D array with shape (2,4)
-// Col 1 Col 2 Col 3 Col 4
-// Row 1: (1,1) (1,2) (1,3) (1,4)
-// Row 2: (2,1) (2,2) (2,3) (2,4)
-//
-// extents: (2,4)
-//
-// flatIdx: 0 1 2 3 4 5 6 7
-// Indices: (1,1) (1,2) (1,3) (1,4) (2,1) (2,2) (2,3) (2,4)
+/// Convert a flat index to multi-dimensional indices for an array box
+/// Example: 2D array with shape (2,4)
+/// Col 1 Col 2 Col 3 Col 4
+/// Row 1: (1,1) (1,2) (1,3) (1,4)
+/// Row 2: (2,1) (2,2) (2,3) (2,4)
+///
+/// extents: (2,4)
+///
+/// flatIdx: 0 1 2 3 4 5 6 7
+/// Indices: (1,1) (1,2) (1,3) (1,4) (2,1) (2,2) (2,3) (2,4)
 static SmallVector convertFlatToMultiDim(OpBuilder &builder,
                                                 Location loc, Value flatIdx,
                                                 Value arrayBox) {
@@ -495,8 +489,8 @@ static SmallVector convertFlatToMultiDim(OpBuilder &builder,
   return indices;
 }
 
-// Calculate the total number of elements in the array box
-// (totalElems = extent(1) * extent(2) * ... * extent(n))
+/// Calculate the total number of elements in the array box
+/// (totalElems = extent(1) * extent(2) * ...
* extent(n))
 static Value CalculateTotalElements(OpBuilder &builder, Location loc,
                                     Value arrayBox) {
   auto boxType = cast(arrayBox.getType());
@@ -517,7 +511,7 @@ static Value CalculateTotalElements(OpBuilder &builder, Location loc,
   return totalElems;
 }
 
-// Replace the FortranAAssign runtime call with an unordered do loop
+/// Replace the FortranAAssign runtime call with an unordered do loop
 static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc,
                                        omp::TeamsOp teamsOp,
                                        omp::WorkdistributeOp workdistribute,
@@ -576,27 +570,27 @@ static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc,
   builder.create(loc, scalar, elemPtr);
 }
 
-// WorkdistributeRuntimeCallLower method finds the runtime calls
-// nested in teams {workdistribute{}} and
-// lowers FortranAAssign to unordered do loop if src is scalar and dest is
-// array. Other runtime calls are not handled currently.
-static bool
-WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
+/// workdistributeRuntimeCallLower method finds the runtime calls
+/// nested in teams {workdistribute{}} and
+/// lowers FortranAAssign to unordered do loop if src is scalar and dest is
+/// array. Other runtime calls are not handled currently.
+static FailureOr
+workdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
                                SetVector &targetOpsToProcess) {
   OpBuilder rewriter(workdistribute);
   auto loc = workdistribute->getLoc();
   auto teams = dyn_cast(workdistribute->getParentOp());
   if (!teams) {
     emitError(loc, "workdistribute not nested in teams\n");
-    return false;
+    return failure();
   }
   if (workdistribute.getRegion().getBlocks().size() != 1) {
     emitError(loc, "workdistribute with multiple blocks\n");
-    return false;
+    return failure();
   }
   if (teams.getRegion().getBlocks().size() != 1) {
     emitError(loc, "teams with multiple blocks\n");
-    return false;
+    return failure();
   }
   auto *workdistributeBlock = &workdistribute.getRegion().front();
   auto *terminator = workdistributeBlock->getTerminator();
@@ -612,8 +606,8 @@ WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
     if (isRuntimeCall(&op)) {
       rewriter.setInsertionPoint(&op);
       fir::CallOp runtimeCall = cast(op);
-      if ((*runtimeCall.getCallee()).getRootReference().getValue() ==
-          "_FortranAAssign") {
+      auto funcName = (*runtimeCall.getCallee()).getRootReference().getValue();
+      if (funcName == FortranAssignStr) {
        if (isFortranAssignSrcScalarAndDestArray(runtimeCall) && targetOp) {
          // Record the target ops to process later
          targetOpsToProcess.insert(targetOp);
@@ -632,26 +626,26 @@ WorkdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
   return changed;
 }
 
-// TeamsWorkdistributeToSingleOp method hoists all the ops inside
-// teams {workdistribute{}} before teams op.
-//
-// If A() and B () are present inside teams workdistribute
-//
-// omp.teams {
-// omp.workdistribute {
-// A()
-// B()
-// }
-// }
-//
-// Then, its lowered to
-//
-// A()
-// B()
-//
-// If only the terminator remains in teams after hoisting, we erase teams op.
+/// teamsWorkdistributeToSingleOp method hoists all the ops inside
+/// teams {workdistribute{}} before teams op.
+///
+/// If A() and B() are present inside teams workdistribute
+///
+/// omp.teams {
+/// omp.workdistribute {
+/// A()
+/// B()
+/// }
+/// }
+///
+/// Then, it's lowered to
+///
+/// A()
+/// B()
+///
+/// If only the terminator remains in teams after hoisting, we erase teams op.
 static bool
-TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp,
+teamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp,
                               SetVector &targetOpsToProcess) {
   auto workdistributeOp = getPerfectlyNested(teamsOp);
   if (!workdistributeOp)
@@ -687,18 +681,17 @@ TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp,
   return true;
 }
 
-// If multiple workdistribute are nested in a target regions, we will need to
-// split the target region, but we want to preserve the data semantics of the
-// original data region and avoid unnecessary data movement at each of the
-// subkernels - we split the target region into a target_data{target}
-// nest where only the outer one moves the data
-std::optional splitTargetData(omp::TargetOp targetOp,
-                                            RewriterBase &rewriter) {
+/// If multiple workdistribute are nested in a target region, we will need to
+/// split the target region, but we want to preserve the data semantics of the
+/// original data region and avoid unnecessary data movement at each of the
+/// subkernels - we split the target region into a target_data{target}
+/// nest where only the outer one moves the data
+FailureOr splitTargetData(omp::TargetOp targetOp,
+                                        RewriterBase &rewriter) {
   auto loc = targetOp->getLoc();
   if (targetOp.getMapVars().empty()) {
-    LLVM_DEBUG(llvm::dbgs()
-               << DEBUG_TYPE << " target region has no data maps\n");
-    return std::nullopt;
+    emitError(loc, "Target region has no data maps\n");
+    return failure();
   }
   // Collect all the mapinfo ops
   SmallVector mapInfos;
@@ -727,7 +720,8 @@ std::optional splitTargetData(omp::TargetOp targetOp,
       newCaptureType = originalCaptureType;
       outerMapInfos.push_back(mapInfo);
     } else {
-      llvm_unreachable("Unhandled case");
+      emitError(targetOp->getLoc(), "Unhandled case");
+      return failure();
     }
     auto innerMapInfo = cast(rewriter.clone(*mapInfo));
     innerMapInfo.setMapTypeAttr(rewriter.getIntegerAttr(
@@ -768,10 +762,10 @@ std::optional splitTargetData(omp::TargetOp targetOp,
   return newTargetOp;
 }
 
-// getNestedOpToIsolate function is designed to identify a specific teams
-// parallel op within the body of an omp::TargetOp that should be "isolated."
-// This returns a tuple of op, if its first op in targetBlock, or if the op is
-// last op in the tragte block.
+/// getNestedOpToIsolate function is designed to identify a specific teams
+/// parallel op within the body of an omp::TargetOp that should be "isolated."
+/// This returns a tuple of the op, whether it is the first op in the target
+/// block, and whether it is the last op in the target block.
 static std::optional>
 getNestedOpToIsolate(omp::TargetOp targetOp) {
   if (targetOp.getRegion().empty())
@@ -789,17 +783,17 @@ getNestedOpToIsolate(omp::TargetOp targetOp) {
   return std::nullopt;
 }
 
-// Temporary structure to hold the two mapinfo ops
+/// Temporary structure to hold the two mapinfo ops
 struct TempOmpVar {
   omp::MapInfoOp from, to;
 };
 
-// isPtr checks if the type is a pointer or reference type.
+/// isPtr checks if the type is a pointer or reference type.
 static bool isPtr(Type ty) {
   return isa(ty) || isa(ty);
 }
 
-// getPtrTypeForOmp returns an LLVM pointer type for the given type.
+/// getPtrTypeForOmp returns an LLVM pointer type for the given type.
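// (Illustrative, assumed mapping: a pointer-like type such as !fir.ref<i32>
// or !llvm.ptr is passed through as !llvm.ptr, while a by-value type such as
// 'index' becomes !fir.ref<index> so the value can be stored to and reloaded
// from the mapped temporary.)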
static Type getPtrTypeForOmp(Type ty) { if (isPtr(ty)) return LLVM::LLVMPointerType::get(ty.getContext()); @@ -807,7 +801,7 @@ static Type getPtrTypeForOmp(Type ty) { return fir::ReferenceType::get(ty); } -// allocateTempOmpVar allocates a temporary variable for OpenMP mapping +/// allocateTempOmpVar allocates a temporary variable for OpenMP mapping static TempOmpVar allocateTempOmpVar(Location loc, Type ty, RewriterBase &rewriter) { MLIRContext &ctx = *ty.getContext(); @@ -868,25 +862,14 @@ static bool usedOutsideSplit(Value v, Operation *split) { return false; } -// isRecomputableAfterFission checks if an operation can be recomputed +/// isRecomputableAfterFission checks if an operation can be recomputed static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) { // If the op has side effects, it cannot be recomputed. // We consider fir.declare as having no side effects. - if (isa(op)) - return true; - - llvm::SmallVector effects; - MemoryEffectOpInterface interface = dyn_cast(op); - if (!interface) { - return false; - } - interface.getEffects(effects); - if (effects.empty()) - return true; - return false; + return isa(op) || isMemoryEffectFree(op); } -// collectNonRecomputableDeps collects dependencies that cannot be recomputed +/// collectNonRecomputableDeps collects dependencies that cannot be recomputed static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp, SetVector &nonRecomputable, SetVector &toCache, @@ -909,7 +892,7 @@ static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp, toRecompute); } -// createBlockArgsAndMap creates block arguments and maps them +/// createBlockArgsAndMap creates block arguments and maps them static void createBlockArgsAndMap(Location loc, RewriterBase &rewriter, omp::TargetOp &targetOp, Block *targetBlock, Block *newTargetBlock, @@ -967,7 +950,7 @@ static void createBlockArgsAndMap(Location loc, RewriterBase &rewriter, return; } -// reloadCacheAndRecompute reloads cached values and recomputes operations +/// reloadCacheAndRecompute reloads cached values and recomputes operations static void reloadCacheAndRecompute( Location loc, RewriterBase &rewriter, Operation *splitBefore, omp::TargetOp &targetOp, Block *targetBlock, Block *newTargetBlock, @@ -1006,9 +989,9 @@ static void reloadCacheAndRecompute( } } -// Given a teamsOp, navigate down the nested structure to find the -// innermost LoopNestOp. The expected nesting is: -// teams -> parallel -> distribute -> wsloop -> loop_nest +/// Given a teamsOp, navigate down the nested structure to find the +/// innermost LoopNestOp. The expected nesting is: +/// teams -> parallel -> distribute -> wsloop -> loop_nest static mlir::omp::LoopNestOp getLoopNestFromTeams(mlir::omp::TeamsOp teamsOp) { if (teamsOp.getRegion().empty()) return nullptr; @@ -1059,7 +1042,7 @@ static mlir::omp::LoopNestOp getLoopNestFromTeams(mlir::omp::TeamsOp teamsOp) { return nullptr; } -// Generate LLVM constant operations for i32 and i64 types. +/// Generate LLVM constant operations for i32 and i64 types. static mlir::LLVM::ConstantOp genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) { mlir::Type i32Ty = rewriter.getI32Type(); @@ -1067,9 +1050,9 @@ genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) { return rewriter.create(loc, i32Ty, attr); } -// Given a box descriptor, extract the base address of the data it describes. -// If the box descriptor is a reference, load it first. -// The base address is returned as an i8* pointer. 
+/// Given a box descriptor, extract the base address of the data it describes. +/// If the box descriptor is a reference, load it first. +/// The base address is returned as an i8* pointer. static Value genDescriptorGetBaseAddress(fir::FirOpBuilder &builder, Location loc, Value boxDesc) { Value box = boxDesc; @@ -1087,9 +1070,9 @@ static Value genDescriptorGetBaseAddress(fir::FirOpBuilder &builder, return rawAddr; } -// Given a box descriptor, extract the total number of elements in the array it -// describes. If the box descriptor is a reference, load it first. -// The total number of elements is returned as an i64 value. +/// Given a box descriptor, extract the total number of elements in the array it +/// describes. If the box descriptor is a reference, load it first. +/// The total number of elements is returned as an i64 value. static Value genDescriptorGetTotalElements(fir::FirOpBuilder &builder, Location loc, Value boxDesc) { Value box = boxDesc; @@ -1102,9 +1085,9 @@ static Value genDescriptorGetTotalElements(fir::FirOpBuilder &builder, return fir::BoxTotalElementsOp::create(builder, loc, i64Type, box); } -// Given a box descriptor, extract the size of each element in the array it -// describes. If the box descriptor is a reference, load it first. -// The element size is returned as an i64 value. +/// Given a box descriptor, extract the size of each element in the array it +/// describes. If the box descriptor is a reference, load it first. +/// The element size is returned as an i64 value. static Value genDescriptorGetEleSize(fir::FirOpBuilder &builder, Location loc, Value boxDesc) { Value box = boxDesc; @@ -1117,10 +1100,10 @@ static Value genDescriptorGetEleSize(fir::FirOpBuilder &builder, Location loc, return fir::BoxEleSizeOp::create(builder, loc, i64Type, box); } -// Given a box descriptor, compute the total size in bytes of the data it -// describes. This is done by multiplying the total number of elements by the -// size of each element. If the box descriptor is a reference, load it first. -// The total size in bytes is returned as an i64 value. +/// Given a box descriptor, compute the total size in bytes of the data it +/// describes. This is done by multiplying the total number of elements by the +/// size of each element. If the box descriptor is a reference, load it first. +/// The total size in bytes is returned as an i64 value. static Value genDescriptorGetDataSizeInBytes(fir::FirOpBuilder &builder, Location loc, Value boxDesc) { Value box = boxDesc; @@ -1134,11 +1117,11 @@ static Value genDescriptorGetDataSizeInBytes(fir::FirOpBuilder &builder, return mlir::arith::MulIOp::create(builder, loc, totalElements, eleSize); } -// Generate a call to the OpenMP runtime function `omp_get_mapped_ptr` to -// retrieve the device pointer corresponding to a given host pointer and device -// number. If no mapping exists, the original host pointer is returned. -// Signature: -// void *omp_get_mapped_ptr(void *host_ptr, int device_num); +/// Generate a call to the OpenMP runtime function `omp_get_mapped_ptr` to +/// retrieve the device pointer corresponding to a given host pointer and device +/// number. If no mapping exists, the original host pointer is returned. 
+/// Signature: +/// void *omp_get_mapped_ptr(void *host_ptr, int device_num); static mlir::Value genOmpGetMappedPtrIfPresent(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value hostPtr, @@ -1174,12 +1157,12 @@ static mlir::Value genOmpGetMappedPtrIfPresent(fir::FirOpBuilder &builder, return result; } -// Generate a call to the OpenMP runtime function `omp_target_memcpy` to -// perform memory copy between host and device or between devices. -// Signature: -// int omp_target_memcpy(void *dst, const void *src, size_t length, -// size_t dst_offset, size_t src_offset, -// int dst_device, int src_device); +/// Generate a call to the OpenMP runtime function `omp_target_memcpy` to +/// perform memory copy between host and device or between devices. +/// Signature: +/// int omp_target_memcpy(void *dst, const void *src, size_t length, +/// size_t dst_offset, size_t src_offset, +/// int dst_device, int src_device); static void genOmpTargetMemcpyCall(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value dst, mlir::Value src, mlir::Value length, @@ -1209,11 +1192,11 @@ static void genOmpTargetMemcpyCall(fir::FirOpBuilder &builder, return; } -// Generate code to replace a Fortran array assignment call with OpenMP -// runtime calls to perform the equivalent operation on the device. -// This involves extracting the source and destination pointers from the -// Fortran array descriptors, retrieving their mapped device pointers (if any), -// and invoking `omp_target_memcpy` to copy the data on the device. +/// Generate code to replace a Fortran array assignment call with OpenMP +/// runtime calls to perform the equivalent operation on the device. +/// This involves extracting the source and destination pointers from the +/// Fortran array descriptors, retrieving their mapped device pointers (if any), +/// and invoking `omp_target_memcpy` to copy the data on the device. static void genFortranAssignOmpReplacement(fir::FirOpBuilder &builder, mlir::Location loc, fir::CallOp callOp, @@ -1250,20 +1233,20 @@ static void genFortranAssignOmpReplacement(fir::FirOpBuilder &builder, device, module); } -// Struct to hold the host eval vars corresponding to loop bounds and steps +/// Struct to hold the host eval vars corresponding to loop bounds and steps struct HostEvalVars { SmallVector lbs; SmallVector ubs; SmallVector steps; }; -// moveToHost method clones all the ops from target region outside of it. -// It hoists runtime functions and replaces them with omp vesions. -// Also hoists and replaces fir.allocmem with omp.target_allocmem and -// fir.freemem with omp.target_freemem -static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, - mlir::ModuleOp module, - struct HostEvalVars &hostEvalVars) { +/// moveToHost method clones all the ops from target region outside of it. +/// It hoists runtime function "_FortranAAssign" and replaces it with omp +/// version. 
Also hoists and replaces fir.allocmem with omp.target_allocmem and +/// fir.freemem with omp.target_freemem +static LogicalResult moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, + mlir::ModuleOp module, + struct HostEvalVars &hostEvalVars) { OpBuilder::InsertionGuard guard(rewriter); Block *targetBlock = &targetOp.getRegion().front(); assert(targetBlock == &targetOp.getRegion().back()); @@ -1272,8 +1255,9 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, // Get the parent target_data op auto targetDataOp = cast(targetOp->getParentOp()); if (!targetDataOp) { - llvm_unreachable("Expected target op to be inside target_data op"); - return; + emitError(targetOp->getLoc(), + "Expected target op to be inside target_data op"); + return failure(); } // create mapping for host_eval_vars unsigned hostEvalVarCount = targetOp.getHostEvalVars().size(); @@ -1345,11 +1329,12 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, // Check for runtime calls to be replaced. if (isRuntimeCall(clonedOp)) { fir::CallOp runtimeCall = cast(op); - if ((*runtimeCall.getCallee()).getRootReference().getValue() == - "_FortranAAssign") { + auto funcName = (*runtimeCall.getCallee()).getRootReference().getValue(); + if (funcName == FortranAssignStr) { opsToReplace.push_back(clonedOp); } else { - llvm_unreachable("Unhandled runtime call hoisting."); + emitError(runtimeCall->getLoc(), "Unhandled runtime call hoisting."); + return failure(); } } } @@ -1392,8 +1377,8 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, // Replace runtime calls with omp versions. else if (isRuntimeCall(op)) { fir::CallOp runtimeCall = cast(op); - if ((*runtimeCall.getCallee()).getRootReference().getValue() == - "_FortranAAssign") { + auto funcName = (*runtimeCall.getCallee()).getRootReference().getValue(); + if (funcName == FortranAssignStr) { rewriter.setInsertionPoint(op); fir::FirOpBuilder builder{rewriter, op}; @@ -1402,10 +1387,12 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, module); rewriter.eraseOp(op); } else { - llvm_unreachable("Unhandled runtime call hoisting."); + emitError(runtimeCall->getLoc(), "Unhandled runtime call hoisting."); + return failure(); } } else { - llvm_unreachable("Unhandled op hoisting."); + emitError(op->getLoc(), "Unhandled op hoisting."); + return failure(); } } @@ -1417,18 +1404,19 @@ static void moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, } // Finally erase the original targetOp. rewriter.eraseOp(targetOp); + return success(); } -// Result of isolateOp method +/// Result of isolateOp method struct SplitResult { omp::TargetOp preTargetOp; omp::TargetOp isolatedTargetOp; omp::TargetOp postTargetOp; }; -// computeAllocsCacheRecomputable method computes the allocs needed to cache -// the values that are used outside the split point. It also computes the ops -// that need to be cached and the ops that can be recomputed after the split. +/// computeAllocsCacheRecomputable method computes the allocs needed to cache +/// the values that are used outside the split point. It also computes the ops +/// that need to be cached and the ops that can be recomputed after the split. 
 static void computeAllocsCacheRecomputable(
     omp::TargetOp targetOp, Operation *splitBeforeOp, RewriterBase &rewriter,
     SmallVector &preMapOperands, SmallVector &postMapOperands,
@@ -1467,9 +1455,9 @@ static void computeAllocsCacheRecomputable(
   }
 }
 
-// genPreTargetOp method generates the preTargetOp that contains all the ops
-// before the split point. It also creates the block arguments and maps the
-// values accordingly. It also creates the store operations for the allocs.
+/// genPreTargetOp method generates the preTargetOp that contains all the ops
+/// before the split point. It also creates the block arguments and maps the
+/// values accordingly. It also creates the store operations for the allocs.
 static omp::TargetOp
 genPreTargetOp(omp::TargetOp targetOp, SmallVector &preMapOperands,
                SmallVector &allocs, Operation *splitBeforeOp,
@@ -1546,10 +1534,10 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector &preMapOperands,
   return preTargetOp;
 }
 
-// genIsolatedTargetOp method generates the isolatedTargetOp that contains the
-// ops between the split point. It also creates the block arguments and maps
-// the values accordingly. It also creates the load operations for the allocs
-// and recomputes the necessary ops.
+/// genIsolatedTargetOp method generates the isolatedTargetOp that contains
+/// the op at the split point. It also creates the block arguments and maps
+/// the values accordingly. It also creates the load operations for the allocs
+/// and recomputes the necessary ops.
 static omp::TargetOp
 genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector &postMapOperands,
                     Operation *splitBeforeOp, RewriterBase &rewriter,
@@ -1635,10 +1623,10 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector &postMapOperands,
   return isolatedTargetOp;
 }
 
-// genPostTargetOp method generates the postTargetOp that contains all the ops
-// after the split point. It also creates the block arguments and maps the
-// values accordingly. It also creates the load operations for the allocs
-// and recomputes the necessary ops.
+/// genPostTargetOp method generates the postTargetOp that contains all the ops
+/// after the split point. It also creates the block arguments and maps the
+/// values accordingly. It also creates the load operations for the allocs
+/// and recomputes the necessary ops.
 static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
                                      Operation *splitBeforeOp,
                                      SmallVector &postMapOperands,
@@ -1681,20 +1669,21 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
   return postTargetOp;
 }
 
-// isolateOp method rewrites a omp.target_data { omp.target } in to
-// omp.target_data {
-//   // preTargetOp region contains ops before splitBeforeOp.
-//   omp.target {}
-//   // isolatedTargetOp region contains splitBeforeOp,
-//   omp.target {}
-//   // postTargetOp region contains ops after splitBeforeOp.
-//   omp.target {}
-// }
-// It also handles the mapping of variables and the caching/recomputing
-// of values as needed.
-static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
-                             RewriterBase &rewriter, mlir::ModuleOp module,
-                             bool isTargetDevice) {
+/// isolateOp method rewrites an omp.target_data { omp.target } into
+/// omp.target_data {
+///   // preTargetOp region contains ops before splitBeforeOp.
+///   omp.target {}
+///   // isolatedTargetOp region contains splitBeforeOp,
+///   omp.target {}
+///   // postTargetOp region contains ops after splitBeforeOp.
+///   omp.target {}
+/// }
+/// It also handles the mapping of variables and the caching/recomputing
+/// of values as needed.
+static FailureOr isolateOp(Operation *splitBeforeOp,
+                           bool splitAfter, RewriterBase &rewriter,
+                           mlir::ModuleOp module,
+                           bool isTargetDevice) {
   auto targetOp = cast(splitBeforeOp->getParentOp());
   assert(targetOp);
   rewriter.setInsertionPoint(targetOp);
@@ -1725,7 +1714,9 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
                  hostEvalVars, isTargetDevice);
 
   // Move the ops of preTarget to host.
-  moveToHost(preTargetOp, rewriter, module, hostEvalVars);
+  auto res = moveToHost(preTargetOp, rewriter, module, hostEvalVars);
+  if (failed(res))
+    return failure();
 
   rewriter.setInsertionPoint(targetOp);
   // Generate the isolatedTargetOp
@@ -1745,15 +1736,15 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
   return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp};
 }
 
-// Recursively fission target ops until no more nested ops can be isolated.
-static void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter,
-                          mlir::ModuleOp module, bool isTargetDevice) {
+/// Recursively fission target ops until no more nested ops can be isolated.
+static LogicalResult fissionTarget(omp::TargetOp targetOp,
+                                   RewriterBase &rewriter,
+                                   mlir::ModuleOp module, bool isTargetDevice) {
   auto tuple = getNestedOpToIsolate(targetOp);
   if (!tuple) {
     LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n");
     struct HostEvalVars hostEvalVars;
-    moveToHost(targetOp, rewriter, module, hostEvalVars);
-    return;
+    return moveToHost(targetOp, rewriter, module, hostEvalVars);
   }
 
   Operation *toIsolate = std::get<0>(*tuple);
   bool splitBefore = !std::get<1>(*tuple);
@@ -1762,19 +1753,24 @@ static void fissionTarget(omp::TargetOp targetOp, RewriterBase &rewriter,
   if (splitBefore && splitAfter) {
     auto res =
         isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice);
-    fissionTarget(res.postTargetOp, rewriter, module, isTargetDevice);
-    return;
+    if (failed(res))
+      return failure();
+    return fissionTarget((*res).postTargetOp, rewriter, module, isTargetDevice);
   }
   // Isolate only before the op.
   if (splitBefore) {
-    isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice);
-    return;
+    auto res =
+        isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice);
+    if (failed(res))
+      return failure();
   } else {
-    llvm::report_fatal_error("Unhandled case in fissionTarget");
+    emitError(toIsolate->getLoc(), "Unhandled case in fissionTarget");
+    return failure();
   }
+  return success();
 }
 
-// Pass to lower omp.workdistribute ops.
+/// Pass to lower omp.workdistribute ops.
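[Editor's illustration] The pass driver that follows initially calls signalPassFailure from inside module walks and keeps iterating; a later patch in this series reworks it around MLIR's WalkResult so the first failure stops the traversal. A self-contained sketch of that walk-interrupt idiom, with a hypothetical `processAllOps` helper that is not part of the patch:

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Visitors.h"

using namespace mlir;

// Run `action` on every op in the module; the first failure interrupts the
// walk, and the interruption is converted into an overall LogicalResult.
static LogicalResult
processAllOps(ModuleOp module,
              function_ref<LogicalResult(Operation *)> action) {
  WalkResult result = module->walk([&](Operation *op) {
    if (failed(action(op)))
      return WalkResult::interrupt(); // stop traversing on the first failure
    return WalkResult::advance();     // otherwise keep going
  });
  // wasInterrupted() is what the reworked pass body checks before calling
  // signalPassFailure().
  return failure(result.wasInterrupted());
}
```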
 class LowerWorkdistributePass
     : public flangomp::impl::LowerWorkdistributeBase {
 public:
@@ -1784,22 +1780,28 @@ class LowerWorkdistributePass
     bool changed = false;
     SetVector targetOpsToProcess;
     moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
-      bool res = VerifyTargetTeamsWorkdistribute(workdistribute);
-      if (!res)
+      auto res = verifyTargetTeamsWorkdistribute(workdistribute);
+      if (failed(res))
        signalPassFailure();
    });
    moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
-      changed |= FissionWorkdistribute(workdistribute);
+      auto res = fissionWorkdistribute(workdistribute);
+      if (failed(res))
+        signalPassFailure();
+      changed |= *res;
    });
    moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
-      changed |=
-          WorkdistributeRuntimeCallLower(workdistribute, targetOpsToProcess);
+      auto res =
+          workdistributeRuntimeCallLower(workdistribute, targetOpsToProcess);
+      if (failed(res))
+        signalPassFailure();
+      changed |= *res;
    });
    moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
-      changed |= WorkdistributeDoLower(workdistribute, targetOpsToProcess);
+      changed |= workdistributeDoLower(workdistribute, targetOpsToProcess);
    });
    moduleOp->walk([&](mlir::omp::TeamsOp teams) {
-      changed |= TeamsWorkdistributeToSingleOp(teams, targetOpsToProcess);
+      changed |= teamsWorkdistributeToSingleOp(teams, targetOpsToProcess);
    });
    if (changed) {
      bool isTargetDevice =
@@ -1808,8 +1810,12 @@ class LowerWorkdistributePass
       IRRewriter rewriter(&context);
       for (auto targetOp : targetOpsToProcess) {
         auto res = splitTargetData(targetOp, rewriter);
-        if (res)
-          fissionTarget(*res, rewriter, moduleOp, isTargetDevice);
+        if (failed(res))
+          signalPassFailure();
+        if (*res) {
+          if (failed(fissionTarget(*res, rewriter, moduleOp, isTargetDevice)))
+            signalPassFailure();
+        }
       }
     }
   }
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir
deleted file mode 100644
index d96068b26ca2f..0000000000000
--- a/flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir
+++ /dev/null
@@ -1,32 +0,0 @@
-// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
-
-// CHECK-LABEL: func.func @test_nested_derived_type_map_operand_and_block_addition(
-// CHECK-SAME: %[[ARG0:.*]]: !fir.ref}>>) {
-// CHECK: %[[VAL_0:.*]] = fir.declare %[[ARG0]] {uniq_name = "_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref}>>) -> !fir.ref}>>
-// CHECK: %[[VAL_1:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref}>>) -> !fir.ref>
-// CHECK: %[[VAL_2:.*]] = fir.coordinate_of %[[VAL_1]], i : (!fir.ref>) -> !fir.ref
-// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "sa%[[VAL_4:.*]]%[[VAL_5:.*]]"}
-// CHECK: %[[VAL_6:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref}>>) -> !fir.ref>
-// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[VAL_6]], r : (!fir.ref>) -> !fir.ref
-// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_7]] : !fir.ref, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "sa%[[VAL_4]]%[[VAL_9:.*]]"}
-// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref, !fir.ref) -> !fir.ref}>> {name = "sa", partial_map = true}
-// CHECK: omp.target map_entries(%[[VAL_10]] -> %[[VAL_11:.*]] : !fir.ref}>>) {
-// CHECK: omp.terminator
-// CHECK: }
-// CHECK: return
-// CHECK: }
-
-func.func @test_nested_derived_type_map_operand_and_block_addition(%arg0: !fir.ref}>>) {
-  %0 = fir.declare %arg0 {uniq_name = "_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref}>>) -> !fir.ref}>>
-  %2 = fir.coordinate_of %0, n : (!fir.ref}>>) -> !fir.ref>
-  %4 = fir.coordinate_of %2, i : (!fir.ref>) -> !fir.ref
-  %5 = omp.map.info var_ptr(%4 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "sa%n%i"}
-  %7 = fir.coordinate_of %0, n : (!fir.ref}>>) -> !fir.ref>
-  %9 = fir.coordinate_of %7, r : (!fir.ref>) -> !fir.ref
-  %10 = omp.map.info var_ptr(%9 : !fir.ref, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "sa%n%r"}
-  %11 = omp.map.info var_ptr(%0 : !fir.ref}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%5, %10 : [1,0], [1,1] : !fir.ref, !fir.ref) -> !fir.ref}>> {name = "sa", partial_map = true}
-  omp.target map_entries(%11 -> %arg1 : !fir.ref}>>) {
-    omp.terminator
-  }
-  return
-}

From 3cbdf5d3b6cba9fd270e070b6a360eb24eefd037 Mon Sep 17 00:00:00 2001
From: skc7
Date: Thu, 9 Oct 2025 15:20:13 +0530
Subject: [PATCH 17/21] Add verification for nested teams and workdistribute.

---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 82 +++++++++++++------
 .../Lower/OpenMP/workdistribute-multiple.f90  | 20 +++++
 ...workdistribute-teams-unsupported-after.f90 | 22 +++++
 ...orkdistribute-teams-unsupported-before.f90 | 22 +++++
 4 files changed, 123 insertions(+), 23 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-multiple.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-teams-unsupported-after.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-teams-unsupported-before.f90

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 7a41fbca679b0..090d9a0e3b985 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -78,6 +78,7 @@ static bool isRuntimeCall(Operation *op) {
 
 /// This is the single source of truth about whether we should parallelize an
 /// operation nested in an omp.workdistribute region.
+/// Parallelize here refers to dividing into units of work.
 static bool shouldParallelize(Operation *op) {
   // True if the op is a runtime call to Assign
   if (isRuntimeCall(op)) {
@@ -124,7 +125,7 @@ static T getPerfectlyNested(Operation *op) {
 /// verifyTargetTeamsWorkdistribute method verifies that
 /// omp.target { teams { workdistribute { ... } } } is well formed
 /// and fails for function calls that don't have lowering implemented yet.
-static FailureOr
+static LogicalResult
 verifyTargetTeamsWorkdistribute(omp::WorkdistributeOp workdistribute) {
   OpBuilder rewriter(workdistribute);
   auto loc = workdistribute->getLoc();
@@ -141,10 +142,30 @@ verifyTargetTeamsWorkdistribute(omp::WorkdistributeOp workdistribute) {
     emitError(loc, "teams with multiple blocks\n");
     return failure();
   }
+
+  bool foundWorkdistribute = false;
+  for (auto &op : teams.getOps()) {
+    if (isa(op)) {
+      if (foundWorkdistribute) {
+        emitError(loc, "teams has multiple workdistribute ops.\n");
+        return failure();
+      }
+      foundWorkdistribute = true;
+      continue;
+    }
+    // Identify any omp dialect ops present before/after workdistribute.
+    if (op.getDialect() && isa(op.getDialect()) &&
+        !isa(op)) {
+      emitError(loc, "teams has omp ops other than workdistribute. Lowering "
+                     "not implemented yet.\n");
+      return failure();
+    }
+  }
+
   omp::TargetOp targetOp = dyn_cast(teams->getParentOp());
   // return if not omp.target
   if (!targetOp)
-    return true;
+    return success();
 
   for (auto &op : workdistribute.getOps()) {
     if (auto callOp = dyn_cast(op)) {
@@ -162,7 +183,7 @@ verifyTargetTeamsWorkdistribute(omp::WorkdistributeOp workdistribute) {
       }
     }
   }
-  return true;
+  return success();
 }
 
 /// fissionWorkdistribute method finds the parallelizable ops
@@ -1779,27 +1800,42 @@ class LowerWorkdistributePass
     auto moduleOp = getOperation();
     bool changed = false;
     SetVector targetOpsToProcess;
-    moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
-      auto res = verifyTargetTeamsWorkdistribute(workdistribute);
-      if (failed(res))
-        signalPassFailure();
-    });
-    moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
-      auto res = fissionWorkdistribute(workdistribute);
-      if (failed(res))
-        signalPassFailure();
-      changed |= *res;
-    });
-    moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
-      auto res =
-          workdistributeRuntimeCallLower(workdistribute, targetOpsToProcess);
-      if (failed(res))
-        signalPassFailure();
-      changed |= *res;
-    });
+    auto verify =
+        moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+          if (failed(verifyTargetTeamsWorkdistribute(workdistribute)))
+            return WalkResult::interrupt();
+          return WalkResult::advance();
+        });
+    if (verify.wasInterrupted())
+      return signalPassFailure();
+
+    auto fission =
+        moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+          auto res = fissionWorkdistribute(workdistribute);
+          if (failed(res))
+            return WalkResult::interrupt();
+          changed |= *res;
+          return WalkResult::advance();
+        });
+    if (fission.wasInterrupted())
+      return signalPassFailure();
+
+    auto rtCallLower =
+        moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+          auto res = workdistributeRuntimeCallLower(workdistribute,
+                                                    targetOpsToProcess);
+          if (failed(res))
+            return WalkResult::interrupt();
+          changed |= *res;
+          return WalkResult::advance();
+        });
+    if (rtCallLower.wasInterrupted())
+      return signalPassFailure();
+
     moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
       changed |= workdistributeDoLower(workdistribute, targetOpsToProcess);
     });
+
     moduleOp->walk([&](mlir::omp::TeamsOp teams) {
       changed |= teamsWorkdistributeToSingleOp(teams, targetOpsToProcess);
     });
     if (changed) {
       bool isTargetDevice =
@@ -1811,10 +1847,10 @@ class LowerWorkdistributePass
       IRRewriter rewriter(&context);
       for (auto targetOp : targetOpsToProcess) {
         auto res = splitTargetData(targetOp, rewriter);
         if (failed(res))
-          signalPassFailure();
+          return signalPassFailure();
         if (*res) {
           if (failed(fissionTarget(*res, rewriter, moduleOp, isTargetDevice)))
-            signalPassFailure();
+            return signalPassFailure();
         }
       }
     }
diff --git a/flang/test/Lower/OpenMP/workdistribute-multiple.f90 b/flang/test/Lower/OpenMP/workdistribute-multiple.f90
new file mode 100644
index 0000000000000..cf1d9dd294cea
--- /dev/null
+++ b/flang/test/Lower/OpenMP/workdistribute-multiple.f90
@@ -0,0 +1,20 @@
+! RUN: not %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - 2>&1 | FileCheck %s
+
+! CHECK: error: teams has multiple workdistribute ops.
+! CHECK-LABEL: func @_QPteams_workdistribute_1
+subroutine teams_workdistribute_1()
+  use iso_fortran_env
+  real(kind=real32) :: a
+  real(kind=real32), dimension(10) :: x
+  real(kind=real32), dimension(10) :: y
+  !$omp teams
+
+  !$omp workdistribute
+  y = a * x + y
+  !$omp end workdistribute
+
+  !$omp workdistribute
+  y = a * y + x
+  !$omp end workdistribute
+  !$omp end teams
+end subroutine teams_workdistribute_1
diff --git a/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-after.f90 b/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-after.f90
new file mode 100644
index 0000000000000..f9c5a771f401d
--- /dev/null
+++ b/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-after.f90
@@ -0,0 +1,22 @@
+! RUN: not %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - 2>&1 | FileCheck %s
+
+! CHECK: error: teams has omp ops other than workdistribute. Lowering not implemented yet.
+! CHECK-LABEL: func @_QPteams_workdistribute_1
+subroutine teams_workdistribute_1()
+  use iso_fortran_env
+  real(kind=real32) :: a
+  real(kind=real32), dimension(10) :: x
+  real(kind=real32), dimension(10) :: y
+  !$omp teams
+
+  !$omp workdistribute
+  y = a * x + y
+  !$omp end workdistribute
+
+  !$omp distribute
+  do i = 1, 10
+    x(i) = real(i, kind=real32)
+  end do
+  !$omp end distribute
+  !$omp end teams
+end subroutine teams_workdistribute_1
diff --git a/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-before.f90 b/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-before.f90
new file mode 100644
index 0000000000000..3ef7f90087944
--- /dev/null
+++ b/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-before.f90
@@ -0,0 +1,22 @@
+! RUN: not %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - 2>&1 | FileCheck %s
+
+! CHECK: error: teams has omp ops other than workdistribute. Lowering not implemented yet.
+! CHECK-LABEL: func @_QPteams_workdistribute_1
+subroutine teams_workdistribute_1()
+  use iso_fortran_env
+  real(kind=real32) :: a
+  real(kind=real32), dimension(10) :: x
+  real(kind=real32), dimension(10) :: y
+  !$omp teams
+
+  !$omp distribute
+  do i = 1, 10
+    x(i) = real(i, kind=real32)
+  end do
+  !$omp end distribute
+
+  !$omp workdistribute
+  y = a * x + y
+  !$omp end workdistribute
+  !$omp end teams
+end subroutine teams_workdistribute_1

From 6af7be30bc917b02c4b53b6ce18d2c162dea86f3 Mon Sep 17 00:00:00 2001
From: skc7
Date: Thu, 9 Oct 2025 15:44:52 +0530
Subject: [PATCH 18/21] Add corresponding tests for teams workdistribute as well

---
 .../Lower/OpenMP/workdistribute-saxpy-1d.f90  | 16 ++++++++++
 .../Lower/OpenMP/workdistribute-saxpy-2d.f90  | 19 ++++++++++++
 .../Lower/OpenMP/workdistribute-saxpy-3d.f90  | 20 +++++++++++++
 ...workdistribute-saxpy-and-scalar-assign.f90 | 20 +++++++++++++
 .../OpenMP/workdistribute-saxpy-two-2d.f90    | 30 +++++++++++++++++++
 .../OpenMP/workdistribute-scalar-assign.f90   |  9 ++++++
 6 files changed, 114 insertions(+)

diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90
index 95c3f37f4720e..b2dbc0f15121e 100644
--- a/flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90
+++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90
@@ -20,4 +20,20 @@ subroutine target_teams_workdistribute()
 
   !$omp end target teams workdistribute
 end subroutine target_teams_workdistribute
 
+! CHECK-LABEL: func @_QPteams_workdistribute
+subroutine teams_workdistribute()
+  use iso_fortran_env
+  real(kind=real32) :: a
+  real(kind=real32), dimension(10) :: x
+  real(kind=real32), dimension(10) :: y
+  ! CHECK: omp.teams
+  ! CHECK: omp.parallel
+  ! CHECK: omp.distribute
+  ! CHECK: omp.wsloop
+  ! CHECK: omp.loop_nest
+
+  !$omp teams workdistribute
+  y = a * x + y
+  !$omp end teams workdistribute
+end subroutine teams_workdistribute
diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90
index 70e82780edb1a..09e1211541edb 100644
--- a/flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90
+++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90
@@ -23,4 +23,23 @@ subroutine target_teams_workdistribute(a, x, y, rows, cols)
 
   !$omp end target teams workdistribute
 end subroutine target_teams_workdistribute
 
+! CHECK-LABEL: func @_QPteams_workdistribute
+subroutine teams_workdistribute(a, x, y, rows, cols)
+  use iso_fortran_env
+  implicit none
+
+  integer, intent(in) :: rows, cols
+  real(kind=real32) :: a
+  real(kind=real32), dimension(rows, cols) :: x, y
+
+  ! CHECK: omp.teams
+  ! CHECK: omp.parallel
+  ! CHECK: omp.distribute
+  ! CHECK: omp.wsloop
+  ! CHECK: omp.loop_nest
+  ! CHECK: fir.do_loop
+  !$omp teams workdistribute
+  y = a * x + y
+  !$omp end teams workdistribute
+end subroutine teams_workdistribute
diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90
index d6fa300eaff99..cf5d0234edb39 100644
--- a/flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90
+++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90
@@ -24,4 +24,24 @@ subroutine target_teams_workdistribute(a, x, y, rows, cols, depth)
 
   !$omp end target teams workdistribute
 end subroutine target_teams_workdistribute
 
+! CHECK-LABEL: func @_QPteams_workdistribute
+subroutine teams_workdistribute(a, x, y, rows, cols, depth)
+  use iso_fortran_env
+  implicit none
+
+  integer, intent(in) :: rows, cols, depth
+  real(kind=real32) :: a
+  real(kind=real32), dimension(rows, cols, depth) :: x, y
+  ! CHECK: omp.teams
+  ! CHECK: omp.parallel
+  ! CHECK: omp.distribute
+  ! CHECK: omp.wsloop
+  ! CHECK: omp.loop_nest
+  ! CHECK: fir.do_loop
+  ! CHECK: fir.do_loop
+
+  !$omp teams workdistribute
+  y = a * x + y
+  !$omp end teams workdistribute
+end subroutine teams_workdistribute
diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90
index 6b6dc3e3a184f..516c4603bd5da 100644
--- a/flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90
+++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90
@@ -30,4 +30,24 @@ subroutine target_teams_workdistribute()
 
   !$omp end target teams workdistribute
 end subroutine target_teams_workdistribute
 
+! CHECK-LABEL: func @_QPteams_workdistribute
+subroutine teams_workdistribute()
+  use iso_fortran_env
+  real(kind=real32) :: a
+  real(kind=real32), dimension(10) :: x
+  real(kind=real32), dimension(10) :: y
+  !$omp teams workdistribute
+
+  ! CHECK: omp.teams
+  ! CHECK: omp.parallel
+  ! CHECK: omp.distribute
+  ! CHECK: omp.wsloop
+  ! CHECK: omp.loop_nest
+
+  y = a * x + y
+
+  ! CHECK: fir.call @_FortranAAssign
+  y = 2.0_real32
+  !$omp end teams workdistribute
+end subroutine teams_workdistribute
diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90
index 2229ccf34e920..4aeb2e89140cc 100644
--- a/flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90
+++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90
@@ -35,4 +35,34 @@ subroutine target_teams_workdistribute(a, x, y, rows, cols)
 
   !$omp end target teams workdistribute
 end subroutine target_teams_workdistribute
 
+! CHECK-LABEL: func @_QPteams_workdistribute
+subroutine teams_workdistribute(a, x, y, rows, cols)
+  use iso_fortran_env
+  implicit none
+
+  integer, intent(in) :: rows, cols
+  real(kind=real32) :: a
+  real(kind=real32), dimension(rows, cols) :: x, y
+
+  !$omp teams workdistribute
+
+  ! CHECK: omp.teams
+  ! CHECK: omp.parallel
+  ! CHECK: omp.distribute
+  ! CHECK: omp.wsloop
+  ! CHECK: omp.loop_nest
+  ! CHECK: fir.do_loop
+
+  y = a * x + y
+
+  ! CHECK: omp.teams
+  ! CHECK: omp.parallel
+  ! CHECK: omp.distribute
+  ! CHECK: omp.wsloop
+  ! CHECK: omp.loop_nest
+  ! CHECK: fir.do_loop
+
+  y = a * y + x
+  !$omp end teams workdistribute
+end subroutine teams_workdistribute
diff --git a/flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90 b/flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90
index af94559dfa8cf..3062b3598b8ae 100644
--- a/flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90
+++ b/flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90
@@ -18,3 +18,12 @@ subroutine target_teams_workdistribute_scalar_assign()
 
 end subroutine target_teams_workdistribute_scalar_assign
 
+! CHECK-LABEL: func @_QPteams_workdistribute_scalar_assign
+subroutine teams_workdistribute_scalar_assign()
+  integer :: aa(10)
+  ! CHECK: fir.call @_FortranAAssign
+  !$omp teams workdistribute
+  aa = 20
+  !$omp end teams workdistribute
+
+end subroutine teams_workdistribute_scalar_assign

From 5eaee02f06d0363a2f0afef92034f84f31867e9d Mon Sep 17 00:00:00 2001
From: ksankisa_amdeng
Date: Fri, 17 Oct 2025 22:20:47 +0530
Subject: [PATCH 19/21] Fix typo and runtimeCall funcName

---
 flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 090d9a0e3b985..78a53a866da95 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -83,7 +83,7 @@ static bool shouldParallelize(Operation *op) {
   // True if the op is a runtime call to Assign
   if (isRuntimeCall(op)) {
     fir::CallOp runtimeCall = cast(op);
-    auto funcName = (*runtimeCall.getCallee()).getRootReference().getValue();
+    auto funcName = runtimeCall.getCallee()->getRootReference().getValue();
     if (funcName == FortranAssignStr) {
       return true;
     }
@@ -354,7 +354,6 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
     rewriter.create(doLoop->getLoc());
     terminatorOp->erase();
   }
-  return;
 }
 
 /// workdistributeDoLower method finds the fir.do_loop unoredered
@@ -621,13 +620,10 @@ workdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
   targetOp = dyn_cast(teams->getParentOp());
   SmallVector opsToErase;
   for (auto &op : workdistribute.getOps()) {
-    if (&op == terminator) {
-      break;
-    }
     if (isRuntimeCall(&op)) {
       rewriter.setInsertionPoint(&op);
       fir::CallOp runtimeCall = cast(op);
-      auto funcName = (*runtimeCall.getCallee()).getRootReference().getValue();
+      auto funcName = runtimeCall.getCallee()->getRootReference().getValue();
       if (funcName == FortranAssignStr) {
         if (isFortranAssignSrcScalarAndDestArray(runtimeCall) && targetOp) {
           // Record the target ops to process later
@@ -786,7 +782,7 @@ FailureOr splitTargetData(omp::TargetOp targetOp,
 /// getNestedOpToIsolate function is designed to identify a specific teams
 /// parallel op within the body of an omp::TargetOp that should be "isolated."
 /// This returns a tuple of op, if its first op in targetBlock, or if the op is
-/// last op in the tragte block.
+/// last op in the target block.
 static std::optional>
 getNestedOpToIsolate(omp::TargetOp targetOp) {
   if (targetOp.getRegion().empty())
@@ -1350,7 +1346,7 @@ static LogicalResult moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
     // Check for runtime calls to be replaced.
     if (isRuntimeCall(clonedOp)) {
       fir::CallOp runtimeCall = cast(op);
-      auto funcName = (*runtimeCall.getCallee()).getRootReference().getValue();
+      auto funcName = runtimeCall.getCallee()->getRootReference().getValue();
       if (funcName == FortranAssignStr) {
         opsToReplace.push_back(clonedOp);
       } else {
@@ -1398,7 +1394,7 @@ static LogicalResult moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
     // Replace runtime calls with omp versions.
     else if (isRuntimeCall(op)) {
       fir::CallOp runtimeCall = cast(op);
-      auto funcName = (*runtimeCall.getCallee()).getRootReference().getValue();
+      auto funcName = runtimeCall.getCallee()->getRootReference().getValue();
       if (funcName == FortranAssignStr) {
         rewriter.setInsertionPoint(op);
         fir::FirOpBuilder builder{rewriter, op};

From 790e017ac74d10810b649fcc38abdb5aa9314699 Mon Sep 17 00:00:00 2001
From: ksankisa_amdeng
Date: Fri, 17 Oct 2025 22:48:18 +0530
Subject: [PATCH 20/21] fix CI error

---
 flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 78a53a866da95..c34f2e1c38dff 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -613,7 +613,6 @@ workdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
     return failure();
   }
   auto *workdistributeBlock = &workdistribute.getRegion().front();
-  auto *terminator = workdistributeBlock->getTerminator();
   bool changed = false;
   omp::TargetOp targetOp;
   // Get the target op parent of teams

From f4e1c3685306b9faa17370926fb3ece21b2695f4 Mon Sep 17 00:00:00 2001
From: skc7
Date: Fri, 17 Oct 2025 23:21:58 +0530
Subject: [PATCH 21/21] fix unused vars CI error

---
 flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index c34f2e1c38dff..9278e17e74d1b 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -612,11 +612,9 @@ workdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
     emitError(loc, "teams with multiple blocks\n");
     return failure();
   }
-  auto *workdistributeBlock = &workdistribute.getRegion().front();
   bool changed = false;
-  omp::TargetOp targetOp;
   // Get the target op parent of teams
-  targetOp = dyn_cast(teams->getParentOp());
+  omp::TargetOp targetOp = dyn_cast(teams->getParentOp());
   SmallVector opsToErase;
   for (auto &op : workdistribute.getOps()) {
     if (isRuntimeCall(&op)) {