Skip to content

Conversation

@ergawy
Copy link
Member

@ergawy ergawy commented Apr 30, 2025

Corresponding RFC can be found here.

ergawy added 5 commits April 30, 2025 02:27
Adds support for lowering `do concurrent` nests from PFT to the new
`fir.do_concurrent` MLIR op as well as its special terminator
`fir.do_concurrent.loop` which models the actual loop nest.

To that end, this PR emits the allocations for the iteration variables
within the block of the `fir.do_concurrent` op and creates a region for
the `fir.do_concurrent.loop` op that accepts arguments equal in number
to the number of the input `do concurrent` iteration ranges.

For example, given the following input:
```fortran
   do concurrent(i=1:10, j=11:20)
   end do
```
the changes in this PR emit the following MLIR:
```mlir
    fir.do_concurrent {
      %22 = fir.alloca i32 {bindc_name = "i"}
      %23:2 = hlfir.declare %22 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
      %24 = fir.alloca i32 {bindc_name = "j"}
      %25:2 = hlfir.declare %24 {uniq_name = "_QFsub1Ej"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
      fir.do_concurrent.loop (%arg1, %arg2) = (%18, %20) to (%19, %21) step (%c1, %c1_0) {
        %26 = fir.convert %arg1 : (index) -> i32
        fir.store %26 to %23#0 : !fir.ref<i32>
        %27 = fir.convert %arg2 : (index) -> i32
        fir.store %27 to %25#0 : !fir.ref<i32>
      }
    }
```
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Apr 30, 2025
@llvmbot
Copy link
Member

llvmbot commented Apr 30, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Kareem Ergawy (ergawy)

Changes

Patch is 42.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/137929.diff

12 Files Affected:

  • (modified) flang/include/flang/Lower/AbstractConverter.h (+3)
  • (modified) flang/include/flang/Optimizer/Dialect/CMakeLists.txt (+2-2)
  • (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+31-2)
  • (modified) flang/lib/Lower/Bridge.cpp (+52-10)
  • (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+37-14)
  • (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.h (+8-2)
  • (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+95-18)
  • (modified) flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp (+63-1)
  • (modified) flang/test/Fir/do_concurrent.fir (+77)
  • (modified) flang/test/Fir/invalid.fir (+5-5)
  • (added) flang/test/Lower/do_concurrent_delayed_locality.f90 (+58)
  • (modified) flang/test/Transforms/do_concurrent-to-do_loop-unodered.fir (+62)
diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index 1d1323642bf9c..81c220e29e164 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -348,6 +348,9 @@ class AbstractConverter {
   virtual Fortran::lower::SymbolBox
   lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) = 0;
 
+  virtual Fortran::lower::SymbolBox
+  shallowLookupSymbol(const Fortran::semantics::Symbol &sym) = 0;
+
   /// Return the mlir::SymbolTable associated to the ModuleOp.
   /// Look-ups are faster using it than using module.lookup<>,
   /// but the module op should be queried in case of failure
diff --git a/flang/include/flang/Optimizer/Dialect/CMakeLists.txt b/flang/include/flang/Optimizer/Dialect/CMakeLists.txt
index adefcfea0b5dc..f0927d555190f 100644
--- a/flang/include/flang/Optimizer/Dialect/CMakeLists.txt
+++ b/flang/include/flang/Optimizer/Dialect/CMakeLists.txt
@@ -17,8 +17,8 @@ mlir_tablegen(FIRAttr.cpp.inc -gen-attrdef-defs)
 set(LLVM_TARGET_DEFINITIONS FIROps.td)
 mlir_tablegen(FIROps.h.inc -gen-op-decls)
 mlir_tablegen(FIROps.cpp.inc -gen-op-defs)
-mlir_tablegen(FIROpsTypes.h.inc --gen-typedef-decls)
-mlir_tablegen(FIROpsTypes.cpp.inc --gen-typedef-defs)
+mlir_tablegen(FIROpsTypes.h.inc --gen-typedef-decls -typedefs-dialect=fir)
+mlir_tablegen(FIROpsTypes.cpp.inc --gen-typedef-defs -typedefs-dialect=fir)
 add_public_tablegen_target(FIROpsIncGen)
 
 set(LLVM_TARGET_DEFINITIONS FortranVariableInterface.td)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index f9dc2e51a396c..4dce413b775fe 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -17,6 +17,7 @@
 include "mlir/Dialect/Arith/IR/ArithBase.td"
 include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td"
+include "mlir/Dialect/OpenMP/OpenMPClauses.td"
 include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td"
 include "flang/Optimizer/Dialect/FIRDialect.td"
 include "flang/Optimizer/Dialect/FIRTypes.td"
@@ -3570,7 +3571,7 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
       LLVM.
   }];
 
-  let arguments = (ins
+  defvar opArgs = (ins
     Variadic<Index>:$lowerBound,
     Variadic<Index>:$upperBound,
     Variadic<Index>:$step,
@@ -3579,17 +3580,45 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
     OptionalAttr<LoopAnnotationAttr>:$loopAnnotation
   );
 
+  let arguments = !con(opArgs, OpenMP_PrivateClause.arguments);
+
   let regions = (region SizedRegion<1>:$region);
 
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 
-  let extraClassDeclaration = [{
+  defvar opExtraClassDeclaration = [{
+    unsigned getNumInductionVars() { return getLowerBound().size(); }
+
+    unsigned getNumPrivateOperands() { return getPrivateVars().size(); }
+
+    mlir::Block::BlockArgListType getInductionVars() {
+      return getBody()->getArguments().slice(0, getNumInductionVars());
+    }
+
+    mlir::Block::BlockArgListType getRegionPrivateArgs() {
+      return getBody()->getArguments().slice(getNumInductionVars(),
+                                             getNumPrivateOperands());
+    }
+
+    /// Number of operands controlling the loop
+    unsigned getNumControlOperands() { return getLowerBound().size() * 3; }
+
     // Get Number of reduction operands
     unsigned getNumReduceOperands() {
       return getReduceOperands().size();
     }
+
+    mlir::Operation::operand_range getPrivateOperands() {
+      return getOperands()
+          .slice(getNumControlOperands() + getNumReduceOperands(),
+                 getNumPrivateOperands());
+    }
   }];
+
+  let extraClassDeclaration =
+    !strconcat(opExtraClassDeclaration, "\n",
+               OpenMP_PrivateClause.extraClassDeclaration);
 }
 
 #endif
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index a84a9c4afb441..cc292d610dcb9 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -12,6 +12,8 @@
 
 #include "flang/Lower/Bridge.h"
 
+#include "OpenMP/DataSharingProcessor.h"
+#include "OpenMP/Utils.h"
 #include "flang/Lower/Allocatable.h"
 #include "flang/Lower/CallInterface.h"
 #include "flang/Lower/Coarray.h"
@@ -1144,6 +1146,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     return name;
   }
 
+  /// Find the symbol in the inner-most level of the local map or return null.
+  Fortran::lower::SymbolBox
+  shallowLookupSymbol(const Fortran::semantics::Symbol &sym) override {
+    if (Fortran::lower::SymbolBox v = localSymbols.shallowLookupSymbol(sym))
+      return v;
+    return {};
+  }
+
 private:
   FirConverter() = delete;
   FirConverter(const FirConverter &) = delete;
@@ -1218,14 +1228,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     return {};
   }
 
-  /// Find the symbol in the inner-most level of the local map or return null.
-  Fortran::lower::SymbolBox
-  shallowLookupSymbol(const Fortran::semantics::Symbol &sym) {
-    if (Fortran::lower::SymbolBox v = localSymbols.shallowLookupSymbol(sym))
-      return v;
-    return {};
-  }
-
   /// Find the symbol in one level up of symbol map such as for host-association
   /// in OpenMP code or return null.
   Fortran::lower::SymbolBox
@@ -2028,9 +2030,31 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   void handleLocalitySpecs(const IncrementLoopInfo &info) {
     Fortran::semantics::SemanticsContext &semanticsContext =
         bridge.getSemanticsContext();
-    for (const Fortran::semantics::Symbol *sym : info.localSymList)
+
+    Fortran::lower::omp::DataSharingProcessor dsp(
+        *this, semanticsContext, getEval(),
+        /*useDelayedPrivatization=*/true, localSymbols);
+    mlir::omp::PrivateClauseOps privateClauseOps;
+    auto doConcurrentLoopOp =
+        mlir::dyn_cast_if_present<fir::DoConcurrentLoopOp>(info.loopOp);
+    bool useDelayedPriv =
+        enableDelayedPrivatizationStaging && doConcurrentLoopOp;
+
+    for (const Fortran::semantics::Symbol *sym : info.localSymList) {
+      if (useDelayedPriv) {
+        dsp.doPrivatize(sym, &privateClauseOps);
+        continue;
+      }
+
       createHostAssociateVarClone(*sym, /*skipDefaultInit=*/false);
+    }
+
     for (const Fortran::semantics::Symbol *sym : info.localInitSymList) {
+      if (useDelayedPriv) {
+        dsp.doPrivatize(sym, &privateClauseOps);
+        continue;
+      }
+
       createHostAssociateVarClone(*sym, /*skipDefaultInit=*/true);
       const auto *hostDetails =
           sym->detailsIf<Fortran::semantics::HostAssocDetails>();
@@ -2049,6 +2073,24 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           sym->detailsIf<Fortran::semantics::HostAssocDetails>();
       copySymbolBinding(hostDetails->symbol(), *sym);
     }
+
+    if (useDelayedPriv) {
+      doConcurrentLoopOp.getPrivateVarsMutable().assign(
+          privateClauseOps.privateVars);
+      doConcurrentLoopOp.setPrivateSymsAttr(
+          builder->getArrayAttr(privateClauseOps.privateSyms));
+
+      for (auto [sym, privateVar] : llvm::zip_equal(
+               dsp.getAllSymbolsToPrivatize(), privateClauseOps.privateVars)) {
+        auto arg = doConcurrentLoopOp.getRegion().begin()->addArgument(
+            privateVar.getType(), doConcurrentLoopOp.getLoc());
+        bindSymbol(*sym, hlfir::translateToExtendedValue(
+                             privateVar.getLoc(), *builder, hlfir::Entity{arg},
+                             /*contiguousHint=*/true)
+                             .first);
+      }
+    }
+
     // Note that allocatable, types with ultimate components, and type
     // requiring finalization are forbidden in LOCAL/LOCAL_INIT (F2023 C1130),
     // so no clean-up needs to be generated for these entities.
@@ -2459,7 +2501,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           nestReduceAttrs.empty()
               ? nullptr
               : mlir::ArrayAttr::get(builder->getContext(), nestReduceAttrs),
-          nullptr);
+          nullptr, /*private_vars=*/std::nullopt, /*private_syms=*/nullptr);
 
       llvm::SmallVector<mlir::Type> loopBlockArgTypes(
           incrementLoopNestInfo.size(), builder->getIndexType());
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index b88454c45da85..bf130d592bf29 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -53,6 +53,15 @@ DataSharingProcessor::DataSharingProcessor(
   });
 }
 
+DataSharingProcessor::DataSharingProcessor(lower::AbstractConverter &converter,
+                                           semantics::SemanticsContext &semaCtx,
+                                           lower::pft::Evaluation &eval,
+                                           bool useDelayedPrivatization,
+                                           lower::SymMap &symTable)
+    : DataSharingProcessor(converter, semaCtx, {}, eval,
+                           /*shouldCollectPreDeterminedSymols=*/false,
+                           useDelayedPrivatization, symTable) {}
+
 void DataSharingProcessor::processStep1(
     mlir::omp::PrivateClauseOps *clauseOps) {
   collectSymbolsForPrivatization();
@@ -172,7 +181,8 @@ void DataSharingProcessor::cloneSymbol(const semantics::Symbol *sym) {
 
 void DataSharingProcessor::copyFirstPrivateSymbol(
     const semantics::Symbol *sym, mlir::OpBuilder::InsertPoint *copyAssignIP) {
-  if (sym->test(semantics::Symbol::Flag::OmpFirstPrivate))
+  if (sym->test(semantics::Symbol::Flag::OmpFirstPrivate) ||
+      sym->test(semantics::Symbol::Flag::LocalityLocalInit))
     converter.copyHostAssociateVar(*sym, copyAssignIP);
 }
 
@@ -504,22 +514,29 @@ void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) {
     }
 }
 
-void DataSharingProcessor::doPrivatize(const semantics::Symbol *sym,
+void DataSharingProcessor::doPrivatize(const semantics::Symbol *symToPrivatize,
                                        mlir::omp::PrivateClauseOps *clauseOps) {
   if (!useDelayedPrivatization) {
-    cloneSymbol(sym);
-    copyFirstPrivateSymbol(sym);
+    cloneSymbol(symToPrivatize);
+    copyFirstPrivateSymbol(symToPrivatize);
     return;
   }
 
-  lower::SymbolBox hsb = converter.lookupOneLevelUpSymbol(*sym);
+  const semantics::Symbol *sym = symToPrivatize->HasLocalLocality()
+                                     ? &symToPrivatize->GetUltimate()
+                                     : symToPrivatize;
+  lower::SymbolBox hsb = symToPrivatize->HasLocalLocality()
+                             ? converter.shallowLookupSymbol(*sym)
+                             : converter.lookupOneLevelUpSymbol(*sym);
   assert(hsb && "Host symbol box not found");
   hlfir::Entity entity{hsb.getAddr()};
   bool cannotHaveNonDefaultLowerBounds = !entity.mayHaveNonDefaultLowerBounds();
 
   mlir::Location symLoc = hsb.getAddr().getLoc();
   std::string privatizerName = sym->name().ToString() + ".privatizer";
-  bool isFirstPrivate = sym->test(semantics::Symbol::Flag::OmpFirstPrivate);
+  bool isFirstPrivate =
+      symToPrivatize->test(semantics::Symbol::Flag::OmpFirstPrivate) ||
+      symToPrivatize->test(semantics::Symbol::Flag::LocalityLocalInit);
 
   mlir::Value privVal = hsb.getAddr();
   mlir::Type allocType = privVal.getType();
@@ -613,27 +630,30 @@ void DataSharingProcessor::doPrivatize(const semantics::Symbol *sym,
           &copyRegion, /*insertPt=*/{}, {argType, argType}, {symLoc, symLoc});
       firOpBuilder.setInsertionPointToEnd(copyEntryBlock);
 
-      auto addSymbol = [&](unsigned argIdx, bool force = false) {
+      auto addSymbol = [&](unsigned argIdx, const semantics::Symbol *symToMap,
+                           bool force = false) {
         symExV.match(
             [&](const fir::MutableBoxValue &box) {
               symTable.addSymbol(
-                  *sym, fir::substBase(box, copyRegion.getArgument(argIdx)),
-                  force);
+                  *symToMap,
+                  fir::substBase(box, copyRegion.getArgument(argIdx)), force);
             },
             [&](const auto &box) {
-              symTable.addSymbol(*sym, copyRegion.getArgument(argIdx), force);
+              symTable.addSymbol(*symToMap, copyRegion.getArgument(argIdx),
+                                 force);
             });
       };
 
-      addSymbol(0, true);
+      addSymbol(0, sym, true);
       lower::SymMapScope innerScope(symTable);
-      addSymbol(1);
+      addSymbol(1, symToPrivatize);
 
       auto ip = firOpBuilder.saveInsertionPoint();
-      copyFirstPrivateSymbol(sym, &ip);
+      copyFirstPrivateSymbol(symToPrivatize, &ip);
 
       firOpBuilder.create<mlir::omp::YieldOp>(
-          hsb.getAddr().getLoc(), symTable.shallowLookupSymbol(*sym).getAddr());
+          hsb.getAddr().getLoc(),
+          symTable.shallowLookupSymbol(*symToPrivatize).getAddr());
     }
 
     return result;
@@ -645,6 +665,9 @@ void DataSharingProcessor::doPrivatize(const semantics::Symbol *sym,
   }
 
   symToPrivatizer[sym] = privatizerOp;
+
+  if (symToPrivatize->HasLocalLocality())
+    allPrivatizedSymbols.insert(symToPrivatize);
 }
 
 } // namespace omp
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
index 54a42fd199831..f5fef9f6dfe85 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.h
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
@@ -105,8 +105,6 @@ class DataSharingProcessor {
   void collectImplicitSymbols();
   void collectPreDeterminedSymbols();
   void privatize(mlir::omp::PrivateClauseOps *clauseOps);
-  void doPrivatize(const semantics::Symbol *sym,
-                   mlir::omp::PrivateClauseOps *clauseOps);
   void copyLastPrivatize(mlir::Operation *op);
   void insertLastPrivateCompare(mlir::Operation *op);
   void cloneSymbol(const semantics::Symbol *sym);
@@ -125,6 +123,11 @@ class DataSharingProcessor {
                        bool shouldCollectPreDeterminedSymbols,
                        bool useDelayedPrivatization, lower::SymMap &symTable);
 
+  DataSharingProcessor(lower::AbstractConverter &converter,
+                       semantics::SemanticsContext &semaCtx,
+                       lower::pft::Evaluation &eval,
+                       bool useDelayedPrivatization, lower::SymMap &symTable);
+
   // Privatisation is split into two steps.
   // Step1 performs cloning of all privatisation clauses and copying for
   // firstprivates. Step1 is performed at the place where process/processStep1
@@ -151,6 +154,9 @@ class DataSharingProcessor {
                ? allPrivatizedSymbols.getArrayRef()
                : llvm::ArrayRef<const semantics::Symbol *>();
   }
+
+  void doPrivatize(const semantics::Symbol *sym,
+                   mlir::omp::PrivateClauseOps *clauseOps);
 };
 
 } // namespace omp
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 21cedb1030896..603e3ff5cdbfd 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4886,21 +4886,25 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
                                                  mlir::OperationState &result) {
   auto &builder = parser.getBuilder();
   // Parse an opening `(` followed by induction variables followed by `)`
-  llvm::SmallVector<mlir::OpAsmParser::Argument, 4> ivs;
-  if (parser.parseArgumentList(ivs, mlir::OpAsmParser::Delimiter::Paren))
+  llvm::SmallVector<mlir::OpAsmParser::Argument, 4> regionArgs;
+
+  if (parser.parseArgumentList(regionArgs, mlir::OpAsmParser::Delimiter::Paren))
     return mlir::failure();
 
+  llvm::SmallVector<mlir::Type> argTypes(regionArgs.size(),
+                                         builder.getIndexType());
+
   // Parse loop bounds.
   llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> lower;
   if (parser.parseEqual() ||
-      parser.parseOperandList(lower, ivs.size(),
+      parser.parseOperandList(lower, regionArgs.size(),
                               mlir::OpAsmParser::Delimiter::Paren) ||
       parser.resolveOperands(lower, builder.getIndexType(), result.operands))
     return mlir::failure();
 
   llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> upper;
   if (parser.parseKeyword("to") ||
-      parser.parseOperandList(upper, ivs.size(),
+      parser.parseOperandList(upper, regionArgs.size(),
                               mlir::OpAsmParser::Delimiter::Paren) ||
       parser.resolveOperands(upper, builder.getIndexType(), result.operands))
     return mlir::failure();
@@ -4908,7 +4912,7 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
   // Parse step values.
   llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> steps;
   if (parser.parseKeyword("step") ||
-      parser.parseOperandList(steps, ivs.size(),
+      parser.parseOperandList(steps, regionArgs.size(),
                               mlir::OpAsmParser::Delimiter::Paren) ||
       parser.resolveOperands(steps, builder.getIndexType(), result.operands))
     return mlir::failure();
@@ -4939,12 +4943,55 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
                         builder.getArrayAttr(arrayAttr));
   }
 
-  // Now parse the body.
-  mlir::Region *body = result.addRegion();
-  for (auto &iv : ivs)
-    iv.type = builder.getIndexType();
-  if (parser.parseRegion(*body, ivs))
-    return mlir::failure();
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> privateOperands;
+  if (succeeded(parser.parseOptionalKeyword("private"))) {
+    std::size_t oldArgTypesSize = argTypes.size();
+    if (failed(parser.parseLParen()))
+      return mlir::failure();
+
+    llvm::SmallVector<mlir::SymbolRefAttr> privateSymbolVec;
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (failed(parser.parseAttribute(privateSymbolVec.emplace_back())))
+            return mlir::failure();
+
+          if (parser.parseOperand(privateOperands.emplace_back()) ||
+              parser.parseArrow() ||
+              parser.parseArgument(regionArgs.emplace_back()))
+            return mlir::failure();
+
+          return mlir::success();
+        })))
+      return mlir::failure();
+
+    if (failed(parser.parseColon()))
+      return mlir::failure();
+
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (failed(parser.parseType(argTypes.emplace_back())))
+            return mlir::failure();
+
+          return mlir::success();
+        })))
+      return mlir::failure();
+
+    if (regionArgs.size() != argTypes.size())
+      return parser.emitError(parser.getNameLoc(),
+                              "mismatch in number of private arg and types");
+
+    if (failed(parser.parseRParen()))
+      return mlir::failure();
+
+    for (auto operandType : llvm::zip_equal(
+             privateOperands, llvm::drop_begin(argTypes, oldArgTypesSize)))
+      if (parser.resolveOperand(std::get<0>(operandType),
+                                std::get<1>(operandType), result.operands))
+        return mlir::failure();
+
+    llvm::SmallVector<mlir::Attribute> symbolAttrs(privateSymbolVec.begin(),
+                                                   privateSymbolVec.end());
+    result.addAttribute(getPrivateSymsAttrName(result.name),
+                        builder.getArrayAttr(symbolAttrs));
+  }
 
   // Set `operandSegmentSizes` attribute.
   result.addAttribute(DoConcurrentLoopOp::getOperandSegmentSizeAttr(),
@@ -4952,7 +4999,16 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
                           {static_cast<int32_t>(lower.size()),
                            static_cast<int32_t>(upper.size()),
                            static_cast<int32_t>(steps.size()),
-                           static_cast<int32_t>(reduceOperands.size())}));
+                           static_cast<int32_t>(reduceOperands.size()),
+                           static_cast<int32_t>(privateOperands.size())}));
+
+  // Now parse the body.
+  for (auto [arg, type] : llvm::zip_equal(regionArgs, argTypes))
+    arg.type = type;
+
+  mlir::Region *body = result.addRegion();
+  if (parser.parseRegion(*body, regionArgs))
+    ret...
[truncated]

@ergawy ergawy force-pushed the users/ergawy/pft_to_do_concurrent_3 branch 3 times, most recently from 4374004 to 1211438 Compare May 5, 2025 11:13
@ergawy ergawy closed this May 5, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

flang:fir-hlfir flang Flang issues not falling into any other category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants