@@ -574,8 +574,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
574574 llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
575575
576576 ClauseProcessor cp (converter, semaCtx, clauseList);
577- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
578- ifClauseOperand);
577+ cp.processIf (clause::If::DirectiveNameModifier::Parallel, ifClauseOperand);
579578 cp.processNumThreads (stmtCtx, numThreadsClauseOperand);
580579 cp.processProcBind (procBindKindAttr);
581580 cp.processDefault ();
@@ -751,8 +750,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
751750 dependOperands;
752751
753752 ClauseProcessor cp (converter, semaCtx, clauseList);
754- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Task,
755- ifClauseOperand);
753+ cp.processIf (clause::If::DirectiveNameModifier::Task, ifClauseOperand);
756754 cp.processAllocate (allocatorOperands, allocateOperands);
757755 cp.processDefault ();
758756 cp.processFinal (stmtCtx, finalClauseOperand);
@@ -865,8 +863,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
865863 llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
866864
867865 ClauseProcessor cp (converter, semaCtx, clauseList);
868- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData,
869- ifClauseOperand);
866+ cp.processIf (clause::If::DirectiveNameModifier::TargetData, ifClauseOperand);
870867 cp.processDevice (stmtCtx, deviceOperand);
871868 cp.processUseDevicePtr (devicePtrOperands, useDeviceTypes, useDeviceLocs,
872869 useDeviceSymbols);
@@ -911,20 +908,17 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
911908 llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
912909 llvm::SmallVector<mlir::Attribute> dependTypeOperands;
913910
914- Fortran::parser::OmpIfClause ::DirectiveNameModifier directiveName;
911+ clause::If ::DirectiveNameModifier directiveName;
915912 // GCC 9.3.0 emits a (probably) bogus warning about an unused variable.
916913 [[maybe_unused]] llvm::omp::Directive directive;
917914 if constexpr (std::is_same_v<OpTy, mlir::omp::EnterDataOp>) {
918- directiveName =
919- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData;
915+ directiveName = clause::If::DirectiveNameModifier::TargetEnterData;
920916 directive = llvm::omp::Directive::OMPD_target_enter_data;
921917 } else if constexpr (std::is_same_v<OpTy, mlir::omp::ExitDataOp>) {
922- directiveName =
923- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData;
918+ directiveName = clause::If::DirectiveNameModifier::TargetExitData;
924919 directive = llvm::omp::Directive::OMPD_target_exit_data;
925920 } else if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
926- directiveName =
927- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetUpdate;
921+ directiveName = clause::If::DirectiveNameModifier::TargetUpdate;
928922 directive = llvm::omp::Directive::OMPD_target_update;
929923 } else {
930924 return nullptr ;
@@ -1126,8 +1120,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
11261120 llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
11271121
11281122 ClauseProcessor cp (converter, semaCtx, clauseList);
1129- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
1130- ifClauseOperand);
1123+ cp.processIf (clause::If::DirectiveNameModifier::Target, ifClauseOperand);
11311124 cp.processDevice (stmtCtx, deviceOperand);
11321125 cp.processThreadLimit (stmtCtx, threadLimitOperand);
11331126 cp.processDepend (dependTypeOperands, dependOperands);
@@ -1258,8 +1251,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
12581251 llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
12591252
12601253 ClauseProcessor cp (converter, semaCtx, clauseList);
1261- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
1262- ifClauseOperand);
1254+ cp.processIf (clause::If::DirectiveNameModifier::Teams, ifClauseOperand);
12631255 cp.processAllocate (allocatorOperands, allocateOperands);
12641256 cp.processDefault ();
12651257 cp.processNumTeams (stmtCtx, numTeamsClauseOperand);
@@ -1298,8 +1290,9 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
12981290
12991291 if (const auto *objectList{
13001292 Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u )}) {
1293+ ObjectList objects{makeList (*objectList, semaCtx)};
13011294 // Case: declare target(func, var1, var2)
1302- gatherFuncAndVarSyms (*objectList , mlir::omp::DeclareTargetCaptureClause::to,
1295+ gatherFuncAndVarSyms (objects , mlir::omp::DeclareTargetCaptureClause::to,
13031296 symbolAndClause);
13041297 } else if (const auto *clauseList{
13051298 Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
@@ -1438,7 +1431,7 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter,
14381431 if (const auto &ompObjectList =
14391432 std::get<std::optional<Fortran::parser::OmpObjectList>>(
14401433 flushConstruct.t ))
1441- genObjectList (*ompObjectList, converter, operandRange);
1434+ genObjectList2 (*ompObjectList, converter, operandRange);
14421435 const auto &memOrderClause =
14431436 std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
14441437 flushConstruct.t );
@@ -1600,8 +1593,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
16001593 loopVarTypeSize);
16011594 cp.processScheduleChunk (stmtCtx, scheduleChunkClauseOperand);
16021595 cp.processReduction (loc, reductionVars, reductionDeclSymbols);
1603- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd,
1604- ifClauseOperand);
1596+ cp.processIf (clause::If::DirectiveNameModifier::Simd, ifClauseOperand);
16051597 cp.processSimdlen (simdlenClauseOperand);
16061598 cp.processSafelen (safelenClauseOperand);
16071599 cp.processTODO <Fortran::parser::OmpClause::Aligned,
@@ -2419,106 +2411,100 @@ void Fortran::lower::genOpenMPReduction(
24192411 const Fortran::parser::OmpClauseList &clauseList) {
24202412 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
24212413
2422- for (const Fortran::parser::OmpClause &clause : clauseList.v ) {
2414+ List<Clause> clauses{makeList (clauseList, semaCtx)};
2415+
2416+ for (const Clause &clause : clauses) {
24232417 if (const auto &reductionClause =
2424- std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u )) {
2425- const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
2426- reductionClause->v .t )};
2427- const auto &objectList{
2428- std::get<Fortran::parser::OmpObjectList>(reductionClause->v .t )};
2418+ std::get_if<clause::Reduction>(&clause.u )) {
2419+ const auto &redOperator{
2420+ std::get<clause::ReductionOperator>(reductionClause->t )};
2421+ const auto &objects{std::get<ObjectList>(reductionClause->t )};
24292422 if (const auto *reductionOp =
2430- std::get_if<Fortran::parser ::DefinedOperator>(&redOperator.u )) {
2423+ std::get_if<clause ::DefinedOperator>(&redOperator.u )) {
24312424 const auto &intrinsicOp{
2432- std::get<Fortran::parser ::DefinedOperator::IntrinsicOperator>(
2425+ std::get<clause ::DefinedOperator::IntrinsicOperator>(
24332426 reductionOp->u )};
24342427
24352428 switch (intrinsicOp) {
2436- case Fortran::parser ::DefinedOperator::IntrinsicOperator::Add:
2437- case Fortran::parser ::DefinedOperator::IntrinsicOperator::Multiply:
2438- case Fortran::parser ::DefinedOperator::IntrinsicOperator::AND:
2439- case Fortran::parser ::DefinedOperator::IntrinsicOperator::EQV:
2440- case Fortran::parser ::DefinedOperator::IntrinsicOperator::OR:
2441- case Fortran::parser ::DefinedOperator::IntrinsicOperator::NEQV:
2429+ case clause ::DefinedOperator::IntrinsicOperator::Add:
2430+ case clause ::DefinedOperator::IntrinsicOperator::Multiply:
2431+ case clause ::DefinedOperator::IntrinsicOperator::AND:
2432+ case clause ::DefinedOperator::IntrinsicOperator::EQV:
2433+ case clause ::DefinedOperator::IntrinsicOperator::OR:
2434+ case clause ::DefinedOperator::IntrinsicOperator::NEQV:
24422435 break ;
24432436 default :
24442437 continue ;
24452438 }
2446- for (const Fortran::parser::OmpObject &ompObject : objectList.v ) {
2447- if (const auto *name{
2448- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
2449- if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
2450- mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
2451- if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
2452- reductionVal = declOp.getBase ();
2453- mlir::Type reductionType =
2454- reductionVal.getType ().cast <fir::ReferenceType>().getEleTy ();
2455- if (!reductionType.isa <fir::LogicalType>()) {
2456- if (!reductionType.isIntOrIndexOrFloat ())
2457- continue ;
2458- }
2459- for (mlir::OpOperand &reductionValUse : reductionVal.getUses ()) {
2460- if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
2461- reductionValUse.getOwner ())) {
2462- mlir::Value loadVal = loadOp.getRes ();
2463- if (reductionType.isa <fir::LogicalType>()) {
2464- mlir::Operation *reductionOp = findReductionChain (loadVal);
2465- fir::ConvertOp convertOp =
2466- getConvertFromReductionOp (reductionOp, loadVal);
2467- updateReduction (reductionOp, firOpBuilder, loadVal,
2468- reductionVal, &convertOp);
2469- removeStoreOp (reductionOp, reductionVal);
2470- } else if (mlir::Operation *reductionOp =
2471- findReductionChain (loadVal, &reductionVal)) {
2472- updateReduction (reductionOp, firOpBuilder, loadVal,
2473- reductionVal);
2474- }
2439+ for (const Object &object : objects) {
2440+ if (const Fortran::semantics::Symbol *symbol = object.id ()) {
2441+ mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
2442+ if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
2443+ reductionVal = declOp.getBase ();
2444+ mlir::Type reductionType =
2445+ reductionVal.getType ().cast <fir::ReferenceType>().getEleTy ();
2446+ if (!reductionType.isa <fir::LogicalType>()) {
2447+ if (!reductionType.isIntOrIndexOrFloat ())
2448+ continue ;
2449+ }
2450+ for (mlir::OpOperand &reductionValUse : reductionVal.getUses ()) {
2451+ if (auto loadOp =
2452+ mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner ())) {
2453+ mlir::Value loadVal = loadOp.getRes ();
2454+ if (reductionType.isa <fir::LogicalType>()) {
2455+ mlir::Operation *reductionOp = findReductionChain (loadVal);
2456+ fir::ConvertOp convertOp =
2457+ getConvertFromReductionOp (reductionOp, loadVal);
2458+ updateReduction (reductionOp, firOpBuilder, loadVal,
2459+ reductionVal, &convertOp);
2460+ removeStoreOp (reductionOp, reductionVal);
2461+ } else if (mlir::Operation *reductionOp =
2462+ findReductionChain (loadVal, &reductionVal)) {
2463+ updateReduction (reductionOp, firOpBuilder, loadVal,
2464+ reductionVal);
24752465 }
24762466 }
24772467 }
24782468 }
24792469 }
24802470 } else if (const auto *reductionIntrinsic =
2481- std::get_if<Fortran::parser::ProcedureDesignator>(
2482- &redOperator.u )) {
2471+ std::get_if<clause::ProcedureDesignator>(&redOperator.u )) {
24832472 if (!ReductionProcessor::supportedIntrinsicProcReduction (
24842473 *reductionIntrinsic))
24852474 continue ;
24862475 ReductionProcessor::ReductionIdentifier redId =
24872476 ReductionProcessor::getReductionType (*reductionIntrinsic);
2488- for (const Fortran::parser::OmpObject &ompObject : objectList.v ) {
2489- if (const auto *name{
2490- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
2491- if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
2492- mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
2493- if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
2494- reductionVal = declOp.getBase ();
2495- for (const mlir::OpOperand &reductionValUse :
2496- reductionVal.getUses ()) {
2497- if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
2498- reductionValUse.getOwner ())) {
2499- mlir::Value loadVal = loadOp.getRes ();
2500- // Max is lowered as a compare -> select.
2501- // Match the pattern here.
2502- mlir::Operation *reductionOp =
2503- findReductionChain (loadVal, &reductionVal);
2504- if (reductionOp == nullptr )
2505- continue ;
2506-
2507- if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
2508- redId == ReductionProcessor::ReductionIdentifier::MIN) {
2509- assert (mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
2510- " Selection Op not found in reduction intrinsic" );
2511- mlir::Operation *compareOp =
2512- getCompareFromReductionOp (reductionOp, loadVal);
2513- updateReduction (compareOp, firOpBuilder, loadVal,
2514- reductionVal);
2515- }
2516- if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
2517- redId == ReductionProcessor::ReductionIdentifier::IEOR ||
2518- redId == ReductionProcessor::ReductionIdentifier::IAND) {
2519- updateReduction (reductionOp, firOpBuilder, loadVal,
2520- reductionVal);
2521- }
2477+ for (const Object &object : objects) {
2478+ if (const Fortran::semantics::Symbol *symbol = object.id ()) {
2479+ mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
2480+ if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
2481+ reductionVal = declOp.getBase ();
2482+ for (const mlir::OpOperand &reductionValUse :
2483+ reductionVal.getUses ()) {
2484+ if (auto loadOp =
2485+ mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner ())) {
2486+ mlir::Value loadVal = loadOp.getRes ();
2487+ // Max is lowered as a compare -> select.
2488+ // Match the pattern here.
2489+ mlir::Operation *reductionOp =
2490+ findReductionChain (loadVal, &reductionVal);
2491+ if (reductionOp == nullptr )
2492+ continue ;
2493+
2494+ if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
2495+ redId == ReductionProcessor::ReductionIdentifier::MIN) {
2496+ assert (mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
2497+ " Selection Op not found in reduction intrinsic" );
2498+ mlir::Operation *compareOp =
2499+ getCompareFromReductionOp (reductionOp, loadVal);
2500+ updateReduction (compareOp, firOpBuilder, loadVal,
2501+ reductionVal);
2502+ }
2503+ if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
2504+ redId == ReductionProcessor::ReductionIdentifier::IEOR ||
2505+ redId == ReductionProcessor::ReductionIdentifier::IAND) {
2506+ updateReduction (reductionOp, firOpBuilder, loadVal,
2507+ reductionVal);
25222508 }
25232509 }
25242510 }
0 commit comments