diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 3dcfe0fd775dc5..352ca66e8735b6 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -103,21 +103,6 @@ static fir::GlobalOp globalInitialization( return global; } -static mlir::Operation *getCompareFromReductionOp(mlir::Operation *reductionOp, - mlir::Value loadVal) { - for (mlir::Value reductionOperand : reductionOp->getOperands()) { - if (mlir::Operation *compareOp = reductionOperand.getDefiningOp()) { - if (compareOp->getOperand(0) == loadVal || - compareOp->getOperand(1) == loadVal) - assert((mlir::isa(compareOp) || - mlir::isa(compareOp)) && - "Expected comparison not found in reduction intrinsic"); - return compareOp; - } - } - return nullptr; -} - // Get the extended value for \p val by extracting additional variable // information from \p base. static fir::ExtendedValue getExtendedValue(fir::ExtendedValue base, @@ -237,213 +222,6 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter, return storeOp; } -static mlir::Operation * -findReductionChain(mlir::Value loadVal, mlir::Value *reductionVal = nullptr) { - for (mlir::OpOperand &loadOperand : loadVal.getUses()) { - if (mlir::Operation *reductionOp = loadOperand.getOwner()) { - if (auto convertOp = mlir::dyn_cast(reductionOp)) { - for (mlir::OpOperand &convertOperand : convertOp.getRes().getUses()) { - if (mlir::Operation *reductionOp = convertOperand.getOwner()) - return reductionOp; - } - } - for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) { - if (auto store = - mlir::dyn_cast(reductionOperand.getOwner())) { - if (store.getMemref() == *reductionVal) { - store.erase(); - return reductionOp; - } - } - if (auto assign = - mlir::dyn_cast(reductionOperand.getOwner())) { - if (assign.getLhs() == *reductionVal) { - assign.erase(); - return reductionOp; - } - } - } - } - } - return nullptr; -} - -// for a logical operator 'op' reduction X = X op Y -// This function returns the operation responsible for converting Y from -// fir.logical<4> to i1 -static fir::ConvertOp getConvertFromReductionOp(mlir::Operation *reductionOp, - mlir::Value loadVal) { - for (mlir::Value reductionOperand : reductionOp->getOperands()) { - if (auto convertOp = - mlir::dyn_cast(reductionOperand.getDefiningOp())) { - if (convertOp.getOperand() == loadVal) - continue; - return convertOp; - } - } - return nullptr; -} - -static void updateReduction(mlir::Operation *op, - fir::FirOpBuilder &firOpBuilder, - mlir::Value loadVal, mlir::Value reductionVal, - fir::ConvertOp *convertOp = nullptr) { - mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint(); - firOpBuilder.setInsertionPoint(op); - - mlir::Value reductionOp; - if (convertOp) - reductionOp = convertOp->getOperand(); - else if (op->getOperand(0) == loadVal) - reductionOp = op->getOperand(1); - else - reductionOp = op->getOperand(0); - - firOpBuilder.create(op->getLoc(), reductionOp, - reductionVal); - firOpBuilder.restoreInsertionPoint(insertPtDel); -} - -static void removeStoreOp(mlir::Operation *reductionOp, mlir::Value symVal) { - for (mlir::Operation *reductionOpUse : reductionOp->getUsers()) { - if (auto convertReduction = - mlir::dyn_cast(reductionOpUse)) { - for (mlir::Operation *convertReductionUse : - convertReduction.getRes().getUsers()) { - if (auto storeOp = mlir::dyn_cast(convertReductionUse)) { - if (storeOp.getMemref() == symVal) - storeOp.erase(); - } - if (auto assignOp = - mlir::dyn_cast(convertReductionUse)) { - if (assignOp.getLhs() == symVal) - assignOp.erase(); - } - } - } - } -} - -// Generate an OpenMP reduction operation. -// TODO: Currently assumes it is either an integer addition/multiplication -// reduction, or a logical and reduction. Generalize this for various reduction -// operation types. -// TODO: Generate the reduction operation during lowering instead of creating -// and removing operations since this is not a robust approach. Also, removing -// ops in the builder (instead of a rewriter) is probably not the best approach. -static void -genOpenMPReduction(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauseList) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - - List clauses{makeClauses(clauseList, semaCtx)}; - - for (const Clause &clause : clauses) { - if (const auto &reductionClause = - std::get_if(&clause.u)) { - const auto &redOperatorList{ - std::get( - reductionClause->t)}; - assert(redOperatorList.size() == 1 && "Expecting single operator"); - const auto &redOperator = redOperatorList.front(); - const auto &objects{std::get(reductionClause->t)}; - if (const auto *reductionOp = - std::get_if(&redOperator.u)) { - const auto &intrinsicOp{ - std::get( - reductionOp->u)}; - - switch (intrinsicOp) { - case clause::DefinedOperator::IntrinsicOperator::Add: - case clause::DefinedOperator::IntrinsicOperator::Multiply: - case clause::DefinedOperator::IntrinsicOperator::AND: - case clause::DefinedOperator::IntrinsicOperator::EQV: - case clause::DefinedOperator::IntrinsicOperator::OR: - case clause::DefinedOperator::IntrinsicOperator::NEQV: - break; - default: - continue; - } - for (const Object &object : objects) { - if (const Fortran::semantics::Symbol *symbol = object.id()) { - mlir::Value reductionVal = converter.getSymbolAddress(*symbol); - if (auto declOp = reductionVal.getDefiningOp()) - reductionVal = declOp.getBase(); - mlir::Type reductionType = - reductionVal.getType().cast().getEleTy(); - if (!reductionType.isa()) { - if (!reductionType.isIntOrIndexOrFloat()) - continue; - } - for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) { - if (auto loadOp = - mlir::dyn_cast(reductionValUse.getOwner())) { - mlir::Value loadVal = loadOp.getRes(); - if (reductionType.isa()) { - mlir::Operation *reductionOp = findReductionChain(loadVal); - fir::ConvertOp convertOp = - getConvertFromReductionOp(reductionOp, loadVal); - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal, &convertOp); - removeStoreOp(reductionOp, reductionVal); - } else if (mlir::Operation *reductionOp = - findReductionChain(loadVal, &reductionVal)) { - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal); - } - } - } - } - } - } else if (const auto *reductionIntrinsic = - std::get_if(&redOperator.u)) { - if (!ReductionProcessor::supportedIntrinsicProcReduction( - *reductionIntrinsic)) - continue; - ReductionProcessor::ReductionIdentifier redId = - ReductionProcessor::getReductionType(*reductionIntrinsic); - for (const Object &object : objects) { - if (const Fortran::semantics::Symbol *symbol = object.id()) { - mlir::Value reductionVal = converter.getSymbolAddress(*symbol); - if (auto declOp = reductionVal.getDefiningOp()) - reductionVal = declOp.getBase(); - for (const mlir::OpOperand &reductionValUse : - reductionVal.getUses()) { - if (auto loadOp = - mlir::dyn_cast(reductionValUse.getOwner())) { - mlir::Value loadVal = loadOp.getRes(); - // Max is lowered as a compare -> select. - // Match the pattern here. - mlir::Operation *reductionOp = - findReductionChain(loadVal, &reductionVal); - if (reductionOp == nullptr) - continue; - - if (redId == ReductionProcessor::ReductionIdentifier::MAX || - redId == ReductionProcessor::ReductionIdentifier::MIN) { - assert(mlir::isa(reductionOp) && - "Selection Op not found in reduction intrinsic"); - mlir::Operation *compareOp = - getCompareFromReductionOp(reductionOp, loadVal); - updateReduction(compareOp, firOpBuilder, loadVal, - reductionVal); - } - if (redId == ReductionProcessor::ReductionIdentifier::IOR || - redId == ReductionProcessor::ReductionIdentifier::IEOR || - redId == ReductionProcessor::ReductionIdentifier::IAND) { - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal); - } - } - } - } - } - } - } - } -} - struct OpWithBodyGenInfo { /// A type for a code-gen callback function. This takes as argument the op for /// which the code is being generated and returns the arguments of the op's @@ -2197,7 +1975,6 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, // 2.9.3.1 SIMD construct createSimdLoop(converter, semaCtx, eval, ompDirective, loopOpClauseList, currentLocation); - genOpenMPReduction(converter, semaCtx, loopOpClauseList); } else { createWsloop(converter, semaCtx, eval, ompDirective, loopOpClauseList, endClauseList, currentLocation);