@@ -572,8 +572,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
572572 llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
573573
574574 ClauseProcessor cp (converter, semaCtx, clauseList);
575- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
576- ifClauseOperand);
575+ cp.processIf (clause::If::DirectiveNameModifier::Parallel, ifClauseOperand);
577576 cp.processNumThreads (stmtCtx, numThreadsClauseOperand);
578577 cp.processProcBind (procBindKindAttr);
579578 cp.processDefault ();
@@ -676,8 +675,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
676675 dependOperands;
677676
678677 ClauseProcessor cp (converter, semaCtx, clauseList);
679- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Task,
680- ifClauseOperand);
678+ cp.processIf (clause::If::DirectiveNameModifier::Task, ifClauseOperand);
681679 cp.processAllocate (allocatorOperands, allocateOperands);
682680 cp.processDefault ();
683681 cp.processFinal (stmtCtx, finalClauseOperand);
@@ -738,7 +736,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
738736 llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
739737
740738 ClauseProcessor cp (converter, semaCtx, clauseList);
741- cp.processIf (Fortran::parser::OmpIfClause ::DirectiveNameModifier::TargetData,
739+ cp.processIf (clause::If ::DirectiveNameModifier::TargetData,
742740 ifClauseOperand);
743741 cp.processDevice (stmtCtx, deviceOperand);
744742 cp.processUseDevicePtr (devicePtrOperands, useDeviceTypes, useDeviceLocs,
@@ -770,19 +768,16 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
770768 llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
771769 llvm::SmallVector<mlir::Attribute> dependTypeOperands;
772770
773- Fortran::parser::OmpIfClause ::DirectiveNameModifier directiveName;
771+ clause::If ::DirectiveNameModifier directiveName;
774772 llvm::omp::Directive directive;
775773 if constexpr (std::is_same_v<OpTy, mlir::omp::EnterDataOp>) {
776- directiveName =
777- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData;
774+ directiveName = clause::If::DirectiveNameModifier::TargetEnterData;
778775 directive = llvm::omp::Directive::OMPD_target_enter_data;
779776 } else if constexpr (std::is_same_v<OpTy, mlir::omp::ExitDataOp>) {
780- directiveName =
781- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData;
777+ directiveName = clause::If::DirectiveNameModifier::TargetExitData;
782778 directive = llvm::omp::Directive::OMPD_target_exit_data;
783779 } else if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
784- directiveName =
785- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetUpdate;
780+ directiveName = clause::If::DirectiveNameModifier::TargetUpdate;
786781 directive = llvm::omp::Directive::OMPD_target_update;
787782 } else {
788783 return nullptr ;
@@ -984,8 +979,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
984979 llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
985980
986981 ClauseProcessor cp (converter, semaCtx, clauseList);
987- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
988- ifClauseOperand);
982+ cp.processIf (clause::If::DirectiveNameModifier::Target, ifClauseOperand);
989983 cp.processDevice (stmtCtx, deviceOperand);
990984 cp.processThreadLimit (stmtCtx, threadLimitOperand);
991985 cp.processDepend (dependTypeOperands, dependOperands);
@@ -1102,8 +1096,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
11021096 llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
11031097
11041098 ClauseProcessor cp (converter, semaCtx, clauseList);
1105- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
1106- ifClauseOperand);
1099+ cp.processIf (clause::If::DirectiveNameModifier::Teams, ifClauseOperand);
11071100 cp.processAllocate (allocatorOperands, allocateOperands);
11081101 cp.processDefault ();
11091102 cp.processNumTeams (stmtCtx, numTeamsClauseOperand);
@@ -1142,8 +1135,9 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
11421135
11431136 if (const auto *objectList{
11441137 Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u )}) {
1138+ ObjectList objects{makeList (*objectList, semaCtx)};
11451139 // Case: declare target(func, var1, var2)
1146- gatherFuncAndVarSyms (*objectList , mlir::omp::DeclareTargetCaptureClause::to,
1140+ gatherFuncAndVarSyms (objects , mlir::omp::DeclareTargetCaptureClause::to,
11471141 symbolAndClause);
11481142 } else if (const auto *clauseList{
11491143 Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
@@ -1257,7 +1251,7 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter,
12571251 if (const auto &ompObjectList =
12581252 std::get<std::optional<Fortran::parser::OmpObjectList>>(
12591253 flushConstruct.t ))
1260- genObjectList (*ompObjectList, converter, operandRange);
1254+ genObjectList2 (*ompObjectList, converter, operandRange);
12611255 const auto &memOrderClause =
12621256 std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
12631257 flushConstruct.t );
@@ -1419,8 +1413,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
14191413 loopVarTypeSize);
14201414 cp.processScheduleChunk (stmtCtx, scheduleChunkClauseOperand);
14211415 cp.processReduction (loc, reductionVars, reductionDeclSymbols);
1422- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd,
1423- ifClauseOperand);
1416+ cp.processIf (clause::If::DirectiveNameModifier::Simd, ifClauseOperand);
14241417 cp.processSimdlen (simdlenClauseOperand);
14251418 cp.processSafelen (safelenClauseOperand);
14261419 cp.processTODO <Fortran::parser::OmpClause::Aligned,
@@ -2223,106 +2216,99 @@ void Fortran::lower::genOpenMPReduction(
22232216 const Fortran::parser::OmpClauseList &clauseList) {
22242217 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
22252218
2226- for (const Fortran::parser::OmpClause &clause : clauseList.v ) {
2219+ List<Clause> clauses{makeList (clauseList, semaCtx)};
2220+
2221+ for (const Clause &clause : clauses) {
22272222 if (const auto &reductionClause =
2228- std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u )) {
2229- const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
2230- reductionClause->v .t )};
2231- const auto &objectList{
2232- std::get<Fortran::parser::OmpObjectList>(reductionClause->v .t )};
2223+ std::get_if<clause::Reduction>(&clause.u )) {
2224+ const auto &redOperator{
2225+ std::get<clause::ReductionOperator>(reductionClause->t )};
2226+ const auto &objects{std::get<ObjectList>(reductionClause->t )};
22332227 if (const auto *reductionOp =
2234- std::get_if<Fortran::parser ::DefinedOperator>(&redOperator.u )) {
2228+ std::get_if<clause ::DefinedOperator>(&redOperator.u )) {
22352229 const auto &intrinsicOp{
2236- std::get<Fortran::parser ::DefinedOperator::IntrinsicOperator>(
2230+ std::get<clause ::DefinedOperator::IntrinsicOperator>(
22372231 reductionOp->u )};
22382232
22392233 switch (intrinsicOp) {
2240- case Fortran::parser ::DefinedOperator::IntrinsicOperator::Add:
2241- case Fortran::parser ::DefinedOperator::IntrinsicOperator::Multiply:
2242- case Fortran::parser ::DefinedOperator::IntrinsicOperator::AND:
2243- case Fortran::parser ::DefinedOperator::IntrinsicOperator::EQV:
2244- case Fortran::parser ::DefinedOperator::IntrinsicOperator::OR:
2245- case Fortran::parser ::DefinedOperator::IntrinsicOperator::NEQV:
2234+ case clause ::DefinedOperator::IntrinsicOperator::Add:
2235+ case clause ::DefinedOperator::IntrinsicOperator::Multiply:
2236+ case clause ::DefinedOperator::IntrinsicOperator::AND:
2237+ case clause ::DefinedOperator::IntrinsicOperator::EQV:
2238+ case clause ::DefinedOperator::IntrinsicOperator::OR:
2239+ case clause ::DefinedOperator::IntrinsicOperator::NEQV:
22462240 break ;
22472241 default :
22482242 continue ;
22492243 }
2250- for (const Fortran::parser::OmpObject &ompObject : objectList.v ) {
2251- if (const auto *name{
2252- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
2253- if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
2254- mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
2255- if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
2256- reductionVal = declOp.getBase ();
2257- mlir::Type reductionType =
2258- reductionVal.getType ().cast <fir::ReferenceType>().getEleTy ();
2259- if (!reductionType.isa <fir::LogicalType>()) {
2260- if (!reductionType.isIntOrIndexOrFloat ())
2261- continue ;
2262- }
2263- for (mlir::OpOperand &reductionValUse : reductionVal.getUses ()) {
2264- if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
2265- reductionValUse.getOwner ())) {
2266- mlir::Value loadVal = loadOp.getRes ();
2267- if (reductionType.isa <fir::LogicalType>()) {
2268- mlir::Operation *reductionOp = findReductionChain (loadVal);
2269- fir::ConvertOp convertOp =
2270- getConvertFromReductionOp (reductionOp, loadVal);
2271- updateReduction (reductionOp, firOpBuilder, loadVal,
2272- reductionVal, &convertOp);
2273- removeStoreOp (reductionOp, reductionVal);
2274- } else if (mlir::Operation *reductionOp =
2275- findReductionChain (loadVal, &reductionVal)) {
2276- updateReduction (reductionOp, firOpBuilder, loadVal,
2277- reductionVal);
2278- }
2244+ for (const Object &object : objects) {
2245+ if (const Fortran::semantics::Symbol *symbol = object.id ()) {
2246+ mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
2247+ if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
2248+ reductionVal = declOp.getBase ();
2249+ mlir::Type reductionType =
2250+ reductionVal.getType ().cast <fir::ReferenceType>().getEleTy ();
2251+ if (!reductionType.isa <fir::LogicalType>()) {
2252+ if (!reductionType.isIntOrIndexOrFloat ())
2253+ continue ;
2254+ }
2255+ for (mlir::OpOperand &reductionValUse : reductionVal.getUses ()) {
2256+ if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner ())) {
2257+ mlir::Value loadVal = loadOp.getRes ();
2258+ if (reductionType.isa <fir::LogicalType>()) {
2259+ mlir::Operation *reductionOp = findReductionChain (loadVal);
2260+ fir::ConvertOp convertOp =
2261+ getConvertFromReductionOp (reductionOp, loadVal);
2262+ updateReduction (reductionOp, firOpBuilder, loadVal,
2263+ reductionVal, &convertOp);
2264+ removeStoreOp (reductionOp, reductionVal);
2265+ } else if (mlir::Operation *reductionOp =
2266+ findReductionChain (loadVal, &reductionVal)) {
2267+ updateReduction (reductionOp, firOpBuilder, loadVal,
2268+ reductionVal);
22792269 }
22802270 }
22812271 }
22822272 }
22832273 }
22842274 } else if (const auto *reductionIntrinsic =
2285- std::get_if<Fortran::parser ::ProcedureDesignator>(
2275+ std::get_if<clause ::ProcedureDesignator>(
22862276 &redOperator.u )) {
22872277 if (!ReductionProcessor::supportedIntrinsicProcReduction (
22882278 *reductionIntrinsic))
22892279 continue ;
22902280 ReductionProcessor::ReductionIdentifier redId =
22912281 ReductionProcessor::getReductionType (*reductionIntrinsic);
2292- for (const Fortran::parser::OmpObject &ompObject : objectList.v ) {
2293- if (const auto *name{
2294- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
2295- if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
2296- mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
2297- if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
2298- reductionVal = declOp.getBase ();
2299- for (const mlir::OpOperand &reductionValUse :
2300- reductionVal.getUses ()) {
2301- if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
2302- reductionValUse.getOwner ())) {
2303- mlir::Value loadVal = loadOp.getRes ();
2304- // Max is lowered as a compare -> select.
2305- // Match the pattern here.
2306- mlir::Operation *reductionOp =
2307- findReductionChain (loadVal, &reductionVal);
2308- if (reductionOp == nullptr )
2309- continue ;
2310-
2311- if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
2312- redId == ReductionProcessor::ReductionIdentifier::MIN) {
2313- assert (mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
2314- " Selection Op not found in reduction intrinsic" );
2315- mlir::Operation *compareOp =
2316- getCompareFromReductionOp (reductionOp, loadVal);
2317- updateReduction (compareOp, firOpBuilder, loadVal,
2318- reductionVal);
2319- }
2320- if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
2321- redId == ReductionProcessor::ReductionIdentifier::IEOR ||
2322- redId == ReductionProcessor::ReductionIdentifier::IAND) {
2323- updateReduction (reductionOp, firOpBuilder, loadVal,
2324- reductionVal);
2325- }
2282+ for (const Object &object : objects) {
2283+ if (const Fortran::semantics::Symbol *symbol = object.id ()) {
2284+ mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
2285+ if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
2286+ reductionVal = declOp.getBase ();
2287+ for (const mlir::OpOperand &reductionValUse :
2288+ reductionVal.getUses ()) {
2289+ if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner ())) {
2290+ mlir::Value loadVal = loadOp.getRes ();
2291+ // Max is lowered as a compare -> select.
2292+ // Match the pattern here.
2293+ mlir::Operation *reductionOp =
2294+ findReductionChain (loadVal, &reductionVal);
2295+ if (reductionOp == nullptr )
2296+ continue ;
2297+
2298+ if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
2299+ redId == ReductionProcessor::ReductionIdentifier::MIN) {
2300+ assert (mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
2301+ " Selection Op not found in reduction intrinsic" );
2302+ mlir::Operation *compareOp =
2303+ getCompareFromReductionOp (reductionOp, loadVal);
2304+ updateReduction (compareOp, firOpBuilder, loadVal,
2305+ reductionVal);
2306+ }
2307+ if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
2308+ redId == ReductionProcessor::ReductionIdentifier::IEOR ||
2309+ redId == ReductionProcessor::ReductionIdentifier::IAND) {
2310+ updateReduction (reductionOp, firOpBuilder, loadVal,
2311+ reductionVal);
23262312 }
23272313 }
23282314 }
0 commit comments