diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 3a61861f455d86..ef9b6515f2a24c 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -1654,6 +1654,19 @@ convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder, return success(); } +/* Returns false as soon as some use of a reduction block argument of teamsOp is owned by an operation for which getParentOfType yields null, and true otherwise (including when there are no reduction block arguments or no uses). The ancestor type queried is not visible in this text; per the comment in the body it is the distribute op - TODO confirm. */ static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp) { + auto iface = + llvm::cast(teamsOp.getOperation()); + // Check that all uses of the reduction block arg have a distribute op parent. + for (auto ra : iface.getReductionBlockArgs()) + for (auto &use : ra.getUses()) { + auto useOp = use.getOwner(); + if (!useOp->getParentOfType()) + return false; + } + return true; +} + // Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder static LogicalResult convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, @@ -1662,32 +1675,39 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, if (failed(checkImplementationStatus(*op))) return failure(); - llvm::ArrayRef isByRef = getIsByRef(op.getReductionByref()); - assert(isByRef.size() == op.getNumReductionVars()); - + DenseMap reductionVariableMap; + unsigned numReductionVars = op.getNumReductionVars(); SmallVector reductionDecls; - collectReductionDecls(op, reductionDecls); + SmallVector privateReductionVariables(numReductionVars); + llvm::ArrayRef isByRef; llvm::OpenMPIRBuilder::InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); - SmallVector privateReductionVariables( - op.getNumReductionVars()); - DenseMap reductionVariableMap; + // Only do teams reduction if there is no distribute op that captures the + // reduction instead. 
+ bool doTeamsReduction = !teamsReductionContainedInDistribute(op); + if (doTeamsReduction) { + isByRef = getIsByRef(op.getReductionByref()); - MutableArrayRef reductionArgs = - llvm::cast(*op).getReductionBlockArgs(); + assert(isByRef.size() == op.getNumReductionVars()); - if (failed(allocAndInitializeReductionVars( - op, reductionArgs, builder, moduleTranslation, allocaIP, - reductionDecls, privateReductionVariables, reductionVariableMap, - isByRef))) - return failure(); + MutableArrayRef reductionArgs = + llvm::cast(*op).getReductionBlockArgs(); - // Store the mapping between reduction variables and their private copies on - // ModuleTranslation stack. It can be then recovered when translating - // omp.reduce operations in a separate call. - LLVM::ModuleTranslation::SaveStack mappingGuard( - moduleTranslation, reductionVariableMap); + collectReductionDecls(op, reductionDecls); + + if (failed(allocAndInitializeReductionVars( + op, reductionArgs, builder, moduleTranslation, allocaIP, + reductionDecls, privateReductionVariables, reductionVariableMap, + isByRef))) + return failure(); + + // Store the mapping between reduction variables and their private copies on + // ModuleTranslation stack. It can be then recovered when translating + // omp.reduce operations in a separate call. + /* NOTE(review): mappingGuard is declared inside this if-block, so it goes out of scope at the closing brace, before bodyCB is defined; convertOmpDistribute declares its guard at function scope instead - confirm the narrower scope here is intended. */ LLVM::ModuleTranslation::SaveStack mappingGuard( + moduleTranslation, reductionVariableMap); + } auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) { LLVM::ModuleTranslation::SaveStack frame( @@ -1723,13 +1743,13 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, return failure(); builder.restoreIP(*afterIP); - - // Process the reductions if required. - return createReductionsAndCleanup( - op, builder, moduleTranslation, allocaIP, reductionDecls, - privateReductionVariables, isByRef, - /*isNoWait*/ false, /*isTeamsReduction*/ true); - + if (doTeamsReduction) { + // Process the reductions if required. 
+ return createReductionsAndCleanup( + op, builder, moduleTranslation, allocaIP, reductionDecls, + privateReductionVariables, isByRef, + /*isNoWait*/ false, /*isTeamsReduction*/ true); + } return success(); } @@ -3815,6 +3835,43 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, if (failed(checkImplementationStatus(opInst))) return failure(); + // Process teams op reduction in distribute if the reduction is contained in + // the distribute op. + omp::TeamsOp teamsOp = opInst.getParentOfType(); + bool doDistributeReduction = + teamsOp ? teamsReductionContainedInDistribute(teamsOp) : false; + + DenseMap reductionVariableMap; + unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0; + SmallVector reductionDecls; + SmallVector privateReductionVariables(numReductionVars); + llvm::ArrayRef isByRef; + + if (doDistributeReduction) { + isByRef = getIsByRef(teamsOp.getReductionByref()); + assert(isByRef.size() == teamsOp.getNumReductionVars()); + + collectReductionDecls(teamsOp, reductionDecls); + /* NOTE(review): this allocaIP is local to the if-block; the createReductionsAndCleanup call in the later hunk uses allocaIP outside this block, so it must bind to a declaration in context not shown here - confirm. */ llvm::OpenMPIRBuilder::InsertPointTy allocaIP = + findAllocaInsertPoint(builder, moduleTranslation); + + MutableArrayRef reductionArgs = + llvm::cast(*teamsOp) + .getReductionBlockArgs(); + + if (failed(allocAndInitializeReductionVars( + teamsOp, reductionArgs, builder, moduleTranslation, allocaIP, + reductionDecls, privateReductionVariables, reductionVariableMap, + isByRef))) + return failure(); + } + + // Store the mapping between reduction variables and their private copies on + // ModuleTranslation stack. It can be then recovered when translating + // omp.reduce operations in a separate call. 
+ LLVM::ModuleTranslation::SaveStack mappingGuard( + moduleTranslation, reductionVariableMap); + auto loopOp = cast(distributeOp.getWrappedLoop()); SmallVector loopWrappers; @@ -3861,6 +3918,13 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, return opInst.emitError(llvm::toString(afterIP.takeError())); builder.restoreIP(*afterIP); + if (doDistributeReduction) { + // Process the reductions if required. + return createReductionsAndCleanup( + teamsOp, builder, moduleTranslation, allocaIP, reductionDecls, + privateReductionVariables, isByRef, + /*isNoWait*/ false, /*isTeamsReduction*/ true); + } return success(); }