Skip to content

Commit ded0a19

Browse files
committed
[Flang][OpenMP] Add Semantics support for Nested OpenMPLoopConstructs
In OpenMP Version 5.1, the tile and unroll directives were added. When using these directives, it is possible to nest them within other OpenMP Loop Constructs. This patch enables the semantics to allow for this behaviour on these specific directives. Any nested loops will be stored within the initial Loop Construct until reaching the DoConstruct itself. Relevant tests have been added, and previous behaviour has been retained with no changes. See also, #110008
1 parent 779f724 commit ded0a19

File tree

11 files changed

+373
-65
lines changed

11 files changed

+373
-65
lines changed

flang/include/flang/Parser/parse-tree.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5025,7 +5025,7 @@ struct OpenMPLoopConstruct {
50255025
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
50265026
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
50275027
: t({std::move(a), std::nullopt, std::nullopt}) {}
5028-
std::tuple<OmpBeginLoopDirective, std::optional<DoConstruct>,
5028+
std::tuple<OmpBeginLoopDirective, std::optional<std::variant<DoConstruct, common::Indirection<OpenMPLoopConstruct>>>,
50295029
std::optional<OmpEndLoopDirective>>
50305030
t;
50315031
};

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4231,6 +4231,12 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
42314231
mlir::Location currentLocation =
42324232
converter.genLocation(beginLoopDirective.source);
42334233

4234+
auto &optLoopCons = std::get<1>(loopConstruct.t);
4235+
if(optLoopCons.has_value())
4236+
if(auto *ompNestedLoopCons{std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(&*optLoopCons)}) {
4237+
genOMP(converter, symTable, semaCtx, eval, ompNestedLoopCons->value());
4238+
}
4239+
42344240
llvm::omp::Directive directive =
42354241
std::get<parser::OmpLoopDirective>(beginLoopDirective.t).v;
42364242
const parser::CharBlock &source =

flang/lib/Parser/unparse.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2926,7 +2926,7 @@ class UnparseVisitor {
29262926
Walk(std::get<OmpBeginLoopDirective>(x.t));
29272927
Put("\n");
29282928
EndOpenMP();
2929-
Walk(std::get<std::optional<DoConstruct>>(x.t));
2929+
Walk(std::get<std::optional<std::variant<DoConstruct, common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t));
29302930
Walk(std::get<std::optional<OmpEndLoopDirective>>(x.t));
29312931
}
29322932
void Unparse(const BasedPointer &x) {

flang/lib/Semantics/canonicalize-omp.cpp

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "canonicalize-omp.h"
1010
#include "flang/Parser/parse-tree-visitor.h"
11+
#include "flang/Parser/parse-tree.h"
1112

1213
// After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
1314
// Constructs more structured which provide explicit scopes for later
@@ -106,6 +107,12 @@ class CanonicalizationOfOmp {
106107
return nullptr;
107108
}
108109

110+
void missingDoConstruct(parser::OmpLoopDirective &dir) {
111+
messages_.Say(dir.source,
112+
"A DO loop must follow the %s directive"_err_en_US,
113+
parser::ToUpperCaseLetters(dir.source.ToString()));
114+
}
115+
109116
void RewriteOpenMPLoopConstruct(parser::OpenMPLoopConstruct &x,
110117
parser::Block &block, parser::Block::iterator it) {
111118
// Check the sequence of DoConstruct and OmpEndLoopDirective
@@ -135,31 +142,62 @@ class CanonicalizationOfOmp {
135142
if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
136143
if (doCons->GetLoopControl()) {
137144
// move DoConstruct
138-
std::get<std::optional<parser::DoConstruct>>(x.t) =
145+
std::get<std::optional<std::variant<parser::DoConstruct, common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t) =
139146
std::move(*doCons);
140147
nextIt = block.erase(nextIt);
141148
// try to match OmpEndLoopDirective
142-
if (nextIt != block.end()) {
143-
if (auto *endDir{
144-
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
145-
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
146-
std::move(*endDir);
147-
block.erase(nextIt);
148-
}
149+
if (auto *endDir{
150+
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
151+
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
152+
std::move(*endDir);
153+
nextIt = block.erase(nextIt);
149154
}
150155
} else {
151156
messages_.Say(dir.source,
152157
"DO loop after the %s directive must have loop control"_err_en_US,
153158
parser::ToUpperCaseLetters(dir.source.ToString()));
154159
}
160+
} else if (auto *ompLoopCons{
161+
GetOmpIf<parser::OpenMPLoopConstruct>(*nextIt)}) {
162+
// We should allow UNROLL and TILE constructs to be inserted between an OpenMP Loop Construct and the DO loop itself
163+
auto &beginDirective =
164+
std::get<parser::OmpBeginLoopDirective>(ompLoopCons->t);
165+
auto &beginLoopDirective =
166+
std::get<parser::OmpLoopDirective>(beginDirective.t);
167+
// iterate through the remaining block items to find the end directive for the unroll/tile directive.
168+
parser::Block::iterator endIt;
169+
endIt = nextIt;
170+
while(endIt != block.end()) {
171+
if (auto *endDir{
172+
GetConstructIf<parser::OmpEndLoopDirective>(*endIt)}) {
173+
auto &endLoopDirective = std::get<parser::OmpLoopDirective>(endDir->t);
174+
if(endLoopDirective.v == dir.v) {
175+
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
176+
std::move(*endDir);
177+
endIt = block.erase(endIt);
178+
continue;
179+
}
180+
}
181+
++endIt;
182+
}
183+
if ((beginLoopDirective.v == llvm::omp::Directive::OMPD_unroll ||
184+
beginLoopDirective.v == llvm::omp::Directive::OMPD_tile)) {
185+
RewriteOpenMPLoopConstruct(*ompLoopCons, block, nextIt);
186+
auto &ompLoop = std::get<std::optional<std::variant<parser::DoConstruct, common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t);
187+
ompLoop = std::optional<std::variant<parser::DoConstruct, common::Indirection<parser::OpenMPLoopConstruct>>>{
188+
std::variant<parser::DoConstruct, common::Indirection<parser::OpenMPLoopConstruct>>{
189+
common::Indirection{std::move(*ompLoopCons)}}};
190+
nextIt = block.erase(nextIt);
191+
}
155192
} else {
156-
messages_.Say(dir.source,
157-
"A DO loop must follow the %s directive"_err_en_US,
158-
parser::ToUpperCaseLetters(dir.source.ToString()));
193+
missingDoConstruct(dir);
159194
}
160195
// If we get here, we either found a loop, or issued an error message.
161196
return;
162197
}
198+
if (nextIt == block.end()) {
199+
missingDoConstruct(dir);
200+
}
163201
}
164202

165203
void RewriteOmpAllocations(parser::ExecutionPart &body) {

flang/lib/Semantics/check-omp-structure.cpp

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -762,10 +762,13 @@ void OmpStructureChecker::Enter(const parser::OpenMPLoopConstruct &x) {
762762
}
763763
SetLoopInfo(x);
764764

765-
if (const auto &doConstruct{
766-
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
767-
const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
768-
CheckNoBranching(doBlock, beginDir.v, beginDir.source);
765+
auto &optLoopCons = std::get<1>(x.t);
766+
if(optLoopCons.has_value()) {
767+
if (const auto &doConstruct{
768+
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
769+
const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
770+
CheckNoBranching(doBlock, beginDir.v, beginDir.source);
771+
}
769772
}
770773
CheckLoopItrVariableIsInt(x);
771774
CheckAssociatedLoopConstraints(x);
@@ -779,19 +782,28 @@ void OmpStructureChecker::Enter(const parser::OpenMPLoopConstruct &x) {
779782
(beginDir.v == llvm::omp::Directive::OMPD_distribute_simd)) {
780783
CheckDistLinear(x);
781784
}
785+
if (beginDir.v == llvm::omp::Directive::OMPD_tile) {
786+
const auto &clauses{std::get<parser::OmpClauseList>(beginLoopDir.t)};
787+
for (auto &clause : clauses.v) {
788+
789+
}
790+
}
782791
}
783792
const parser::Name OmpStructureChecker::GetLoopIndex(
784793
const parser::DoConstruct *x) {
785794
using Bounds = parser::LoopControl::Bounds;
786795
return std::get<Bounds>(x->GetLoopControl()->u).name.thing;
787796
}
788797
void OmpStructureChecker::SetLoopInfo(const parser::OpenMPLoopConstruct &x) {
789-
if (const auto &loopConstruct{
790-
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
791-
const parser::DoConstruct *loop{&*loopConstruct};
792-
if (loop && loop->IsDoNormal()) {
793-
const parser::Name &itrVal{GetLoopIndex(loop)};
794-
SetLoopIv(itrVal.symbol);
798+
auto &optLoopCons = std::get<1>(x.t);
799+
if (optLoopCons.has_value()) {
800+
if (const auto &loopConstruct{
801+
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
802+
const parser::DoConstruct *loop{&*loopConstruct};
803+
if (loop && loop->IsDoNormal()) {
804+
const parser::Name &itrVal{GetLoopIndex(loop)};
805+
SetLoopIv(itrVal.symbol);
806+
}
795807
}
796808
}
797809
}
@@ -857,8 +869,10 @@ void OmpStructureChecker::CheckIteratorModifier(const parser::OmpIterator &x) {
857869

858870
void OmpStructureChecker::CheckLoopItrVariableIsInt(
859871
const parser::OpenMPLoopConstruct &x) {
860-
if (const auto &loopConstruct{
861-
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
872+
auto &optLoopCons = std::get<1>(x.t);
873+
if (optLoopCons.has_value()) {
874+
if (const auto &loopConstruct{
875+
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
862876

863877
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
864878
if (loop->IsDoNormal()) {
@@ -878,6 +892,7 @@ void OmpStructureChecker::CheckLoopItrVariableIsInt(
878892
const auto it{block.begin()};
879893
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
880894
: nullptr;
895+
}
881896
}
882897
}
883898
}
@@ -1077,8 +1092,10 @@ void OmpStructureChecker::CheckDistLinear(
10771092

10781093
// Match the loop index variables with the collected symbols from linear
10791094
// clauses.
1095+
auto &optLoopCons = std::get<1>(x.t);
1096+
if (optLoopCons.has_value()) {
10801097
if (const auto &loopConstruct{
1081-
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
1098+
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
10821099
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
10831100
if (loop->IsDoNormal()) {
10841101
const parser::Name &itrVal{GetLoopIndex(loop)};
@@ -1096,6 +1113,7 @@ void OmpStructureChecker::CheckDistLinear(
10961113
const auto it{block.begin()};
10971114
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
10981115
: nullptr;
1116+
}
10991117
}
11001118
}
11011119

flang/lib/Semantics/resolve-directives.cpp

Lines changed: 58 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1796,10 +1796,13 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPLoopConstruct &x) {
17961796
SetContextAssociatedLoopLevel(GetAssociatedLoopLevelFromClauses(clauseList));
17971797

17981798
if (beginDir.v == llvm::omp::Directive::OMPD_do) {
1799-
if (const auto &doConstruct{
1800-
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
1801-
if (doConstruct.value().IsDoWhile()) {
1802-
return true;
1799+
auto &optLoopCons = std::get<1>(x.t);
1800+
if (optLoopCons.has_value()) {
1801+
if (const auto &doConstruct{
1802+
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
1803+
if (doConstruct->IsDoWhile()) {
1804+
return true;
1805+
}
18031806
}
18041807
}
18051808
}
@@ -1962,48 +1965,64 @@ void OmpAttributeVisitor::PrivatizeAssociatedLoopIndexAndCheckLoopLevel(
19621965
bool hasCollapseClause{
19631966
clause ? (clause->Id() == llvm::omp::OMPC_collapse) : false};
19641967

1965-
const auto &outer{std::get<std::optional<parser::DoConstruct>>(x.t)};
1966-
if (outer.has_value()) {
1967-
for (const parser::DoConstruct *loop{&*outer}; loop && level > 0; --level) {
1968-
if (loop->IsDoConcurrent()) {
1969-
// DO CONCURRENT is explicitly allowed for the LOOP construct so long as
1970-
// there isn't a COLLAPSE clause
1971-
if (isLoopConstruct) {
1972-
if (hasCollapseClause) {
1973-
// hasCollapseClause implies clause != nullptr
1974-
context_.Say(clause->source,
1975-
"DO CONCURRENT loops cannot be used with the COLLAPSE clause."_err_en_US);
1968+
auto &optLoopCons = std::get<1>(x.t);
1969+
if (optLoopCons.has_value()) {
1970+
if (const auto &outer{std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
1971+
for (const parser::DoConstruct *loop{&*outer}; loop && level > 0; --level) {
1972+
if (loop->IsDoConcurrent()) {
1973+
// DO CONCURRENT is explicitly allowed for the LOOP construct so long as
1974+
// there isn't a COLLAPSE clause
1975+
if (isLoopConstruct) {
1976+
if (hasCollapseClause) {
1977+
// hasCollapseClause implies clause != nullptr
1978+
context_.Say(clause->source,
1979+
"DO CONCURRENT loops cannot be used with the COLLAPSE clause."_err_en_US);
1980+
}
1981+
} else {
1982+
auto &stmt =
1983+
std::get<parser::Statement<parser::NonLabelDoStmt>>(loop->t);
1984+
context_.Say(stmt.source,
1985+
"DO CONCURRENT loops cannot form part of a loop nest."_err_en_US);
19761986
}
1977-
} else {
1978-
auto &stmt =
1979-
std::get<parser::Statement<parser::NonLabelDoStmt>>(loop->t);
1980-
context_.Say(stmt.source,
1981-
"DO CONCURRENT loops cannot form part of a loop nest."_err_en_US);
1982-
}
1983-
}
1984-
// go through all the nested do-loops and resolve index variables
1985-
const parser::Name *iv{GetLoopIndex(*loop)};
1986-
if (iv) {
1987-
if (auto *symbol{ResolveOmp(*iv, ivDSA, currScope())}) {
1988-
SetSymbolDSA(*symbol, {Symbol::Flag::OmpPreDetermined, ivDSA});
1989-
iv->symbol = symbol; // adjust the symbol within region
1990-
AddToContextObjectWithDSA(*symbol, ivDSA);
19911987
}
1988+
// go through all the nested do-loops and resolve index variables
1989+
const parser::Name *iv{GetLoopIndex(*loop)};
1990+
if (iv) {
1991+
if (auto *symbol{ResolveOmp(*iv, ivDSA, currScope())}) {
1992+
SetSymbolDSA(*symbol, {Symbol::Flag::OmpPreDetermined, ivDSA});
1993+
iv->symbol = symbol; // adjust the symbol within region
1994+
AddToContextObjectWithDSA(*symbol, ivDSA);
1995+
}
19921996

1993-
const auto &block{std::get<parser::Block>(loop->t)};
1994-
const auto it{block.begin()};
1995-
loop = it != block.end() ? GetDoConstructIf(*it) : nullptr;
1997+
const auto &block{std::get<parser::Block>(loop->t)};
1998+
const auto it{block.begin()};
1999+
loop = it != block.end() ? GetDoConstructIf(*it) : nullptr;
2000+
}
2001+
}
2002+
CheckAssocLoopLevel(level, GetAssociatedClause());
2003+
} else if (const auto &loop{std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(&*optLoopCons)}) {
2004+
auto &beginDirective =
2005+
std::get<parser::OmpBeginLoopDirective>(loop->value().t);
2006+
auto &beginLoopDirective =
2007+
std::get<parser::OmpLoopDirective>(beginDirective.t);
2008+
if ((beginLoopDirective.v != llvm::omp::Directive::OMPD_unroll &&
2009+
beginLoopDirective.v != llvm::omp::Directive::OMPD_tile)) {
2010+
context_.Say(GetContext().directiveSource,
2011+
"Only UNROLL or TILE constructs are allowed between an OpenMP Loop Construct and a DO construct"_err_en_US,
2012+
parser::ToUpperCaseLetters(llvm::omp::getOpenMPDirectiveName(GetContext().directive, version).str()));
2013+
} else {
2014+
PrivatizeAssociatedLoopIndexAndCheckLoopLevel(loop->value());
19962015
}
2016+
} else {
2017+
context_.Say(GetContext().directiveSource,
2018+
"A DO loop must follow the %s directive"_err_en_US,
2019+
parser::ToUpperCaseLetters(
2020+
llvm::omp::getOpenMPDirectiveName(GetContext().directive, version)
2021+
.str()));
19972022
}
1998-
CheckAssocLoopLevel(level, GetAssociatedClause());
1999-
} else {
2000-
context_.Say(GetContext().directiveSource,
2001-
"A DO loop must follow the %s directive"_err_en_US,
2002-
parser::ToUpperCaseLetters(
2003-
llvm::omp::getOpenMPDirectiveName(GetContext().directive, version)
2004-
.str()));
20052023
}
20062024
}
2025+
20072026
void OmpAttributeVisitor::CheckAssocLoopLevel(
20082027
std::int64_t level, const parser::OmpClause *clause) {
20092028
if (clause && level != 0) {
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
! Test to ensure TODO message is emitted for tile OpenMP 5.1 Directives when they are nested.
2+
3+
!RUN: not %flang -fopenmp -fopenmp-version=51 %s 2<&1 | FileCheck %s
4+
5+
subroutine loop_transformation_construct
6+
implicit none
7+
integer :: I = 10
8+
integer :: x
9+
integer :: y(I)
10+
11+
!$omp do
12+
!$omp tile
13+
do i = 1, I
14+
y(i) = y(i) * 5
15+
end do
16+
!$omp end tile
17+
!$omp end do
18+
end subroutine
19+
20+
!CHECK: not yet implemented: Unhandled loop directive (tile)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
! Test to ensure TODO message is emitted for unroll OpenMP 5.1 Directives when they are nested.
2+
3+
!RUN: not %flang -fopenmp -fopenmp-version=51 %s 2<&1 | FileCheck %s
4+
5+
program loop_transformation_construct
6+
implicit none
7+
integer, parameter :: I = 10
8+
integer :: x
9+
integer :: y(I)
10+
11+
!$omp do
12+
!$omp unroll
13+
do x = 1, I
14+
y(x) = y(x) * 5
15+
end do
16+
!$omp end unroll
17+
!$omp end do
18+
end program loop_transformation_construct
19+
20+
!CHECK: not yet implemented: Unhandled loop directive (unroll)

0 commit comments

Comments
 (0)