1313#include " Utils.h"
1414
1515#include " ClauseFinder.h"
16+ #include " flang/Evaluate/fold.h"
1617#include " flang/Lower/OpenMP/Clauses.h"
1718#include < flang/Lower/AbstractConverter.h>
1819#include < flang/Lower/ConvertType.h>
2425#include < flang/Parser/parse-tree.h>
2526#include < flang/Parser/tools.h>
2627#include < flang/Semantics/tools.h>
28+ #include < flang/Semantics/type.h>
2729#include < flang/Utils/OpenMP.h>
2830#include < llvm/Support/CommandLine.h>
2931
3032#include < iterator>
3133
34+ template <typename T>
35+ Fortran::semantics::MaybeIntExpr
36+ EvaluateIntExpr (Fortran::semantics::SemanticsContext &context, const T &expr) {
37+ if (Fortran::semantics::MaybeExpr maybeExpr{
38+ Fold (context.foldingContext (), AnalyzeExpr (context, expr))}) {
39+ if (auto *intExpr{
40+ Fortran::evaluate::UnwrapExpr<Fortran::semantics::SomeIntExpr>(
41+ *maybeExpr)}) {
42+ return std::move (*intExpr);
43+ }
44+ }
45+ return std::nullopt ;
46+ }
47+
48+ template <typename T>
49+ std::optional<std::int64_t >
50+ EvaluateInt64 (Fortran::semantics::SemanticsContext &context, const T &expr) {
51+ return Fortran::evaluate::ToInt64 (EvaluateIntExpr (context, expr));
52+ }
53+
3254llvm::cl::opt<bool > treatIndexAsSection (
3355 " openmp-treat-index-as-section" ,
3456 llvm::cl::desc (" In the OpenMP data clauses treat `a(N)` as `a(N:N)`." ),
@@ -577,12 +599,64 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
577599 }
578600}
579601
580- bool collectLoopRelatedInfo (
602+ // Helper function that finds the sizes clause in a inner OMPD_tile directive
603+ // and passes the sizes clause to the callback function if found.
604+ static void processTileSizesFromOpenMPConstruct (
605+ const parser::OpenMPConstruct *ompCons,
606+ std::function<void (const parser::OmpClause::Sizes *)> processFun) {
607+ if (!ompCons)
608+ return ;
609+ if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u )}) {
610+ const auto &nestedOptional =
611+ std::get<std::optional<parser::NestedConstruct>>(ompLoop->t );
612+ assert (nestedOptional.has_value () &&
613+ " Expected a DoConstruct or OpenMPLoopConstruct" );
614+ const auto *innerConstruct =
615+ std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
616+ &(nestedOptional.value ()));
617+ if (innerConstruct) {
618+ const auto &innerLoopDirective = innerConstruct->value ();
619+ const auto &innerBegin =
620+ std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t );
621+ const auto &innerDirective =
622+ std::get<parser::OmpLoopDirective>(innerBegin.t ).v ;
623+
624+ if (innerDirective == llvm::omp::Directive::OMPD_tile) {
625+ // Get the size values from parse tree and convert to a vector.
626+ const auto &innerClauseList{
627+ std::get<parser::OmpClauseList>(innerBegin.t )};
628+ for (const auto &clause : innerClauseList.v ) {
629+ if (const auto tclause{
630+ std::get_if<parser::OmpClause::Sizes>(&clause.u )}) {
631+ processFun (tclause);
632+ break ;
633+ }
634+ }
635+ }
636+ }
637+ }
638+ }
639+
640+ // / Populates the sizes vector with values if the given OpenMPConstruct
641+ // / contains a loop construct with an inner tiling construct.
642+ void collectTileSizesFromOpenMPConstruct (
643+ const parser::OpenMPConstruct *ompCons,
644+ llvm::SmallVectorImpl<int64_t > &tileSizes,
645+ Fortran::semantics::SemanticsContext &semaCtx) {
646+ processTileSizesFromOpenMPConstruct (
647+ ompCons, [&](const parser::OmpClause::Sizes *tclause) {
648+ for (auto &tval : tclause->v )
649+ if (const auto v{EvaluateInt64 (semaCtx, tval)})
650+ tileSizes.push_back (*v);
651+ });
652+ }
653+
654+ int64_t collectLoopRelatedInfo (
581655 lower::AbstractConverter &converter, mlir::Location currentLocation,
582656 lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
583657 mlir::omp::LoopRelatedClauseOps &result,
584658 llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
585- bool found = false ;
659+ int64_t numCollapse = 1 ;
586660 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
587661
588662 // Collect the loops to collapse.
@@ -595,9 +669,19 @@ bool collectLoopRelatedInfo(
595669 if (auto *clause =
596670 ClauseFinder::findUniqueClause<omp::clause::Collapse>(clauses)) {
597671 collapseValue = evaluate::ToInt64 (clause->v ).value ();
598- found = true ;
672+ numCollapse = collapseValue;
673+ }
674+
675+ // Collect sizes from tile directive if present.
676+ std::int64_t sizesLengthValue = 0l ;
677+ if (auto *ompCons{eval.getIf <parser::OpenMPConstruct>()}) {
678+ processTileSizesFromOpenMPConstruct (
679+ ompCons, [&](const parser::OmpClause::Sizes *tclause) {
680+ sizesLengthValue = tclause->v .size ();
681+ });
599682 }
600683
684+ collapseValue = std::max (collapseValue, sizesLengthValue);
601685 std::size_t loopVarTypeSize = 0 ;
602686 do {
603687 lower::pft::Evaluation *doLoop =
@@ -631,7 +715,7 @@ bool collectLoopRelatedInfo(
631715
632716 convertLoopBounds (converter, currentLocation, result, loopVarTypeSize);
633717
634- return found ;
718+ return numCollapse ;
635719}
636720
637721} // namespace omp
0 commit comments