Skip to content
5 changes: 4 additions & 1 deletion flang/include/flang/Parser/parse-tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ struct AccEndCombinedDirective;
struct OpenACCDeclarativeConstruct;
struct OpenACCRoutineConstruct;
struct OpenMPConstruct;
struct OpenMPLoopConstruct;
struct OpenMPDeclarativeConstruct;
struct OmpEndLoopDirective;
struct OmpMemoryOrderClause;
Expand Down Expand Up @@ -5021,11 +5022,13 @@ struct OpenMPBlockConstruct {
};

// OpenMP directives enclosing do loop
using NestedConstruct =
std::variant<DoConstruct, common::Indirection<OpenMPLoopConstruct>>;
struct OpenMPLoopConstruct {
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
: t({std::move(a), std::nullopt, std::nullopt}) {}
std::tuple<OmpBeginLoopDirective, std::optional<DoConstruct>,
std::tuple<OmpBeginLoopDirective, std::optional<NestedConstruct>,
std::optional<OmpEndLoopDirective>>
t;
};
Expand Down
10 changes: 10 additions & 0 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4231,6 +4231,16 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
mlir::Location currentLocation =
converter.genLocation(beginLoopDirective.source);

auto &optLoopCons =
std::get<std::optional<parser::NestedConstruct>>(loopConstruct.t);
if (optLoopCons.has_value()) {
if (auto *ompNestedLoopCons{
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
&*optLoopCons)}) {
genOMP(converter, symTable, semaCtx, eval, ompNestedLoopCons->value());
}
}

llvm::omp::Directive directive =
std::get<parser::OmpLoopDirective>(beginLoopDirective.t).v;
const parser::CharBlock &source =
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Parser/unparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2926,7 +2926,8 @@ class UnparseVisitor {
Walk(std::get<OmpBeginLoopDirective>(x.t));
Put("\n");
EndOpenMP();
Walk(std::get<std::optional<DoConstruct>>(x.t));
Walk(std::get<std::optional<std::variant<DoConstruct,
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t));
Walk(std::get<std::optional<OmpEndLoopDirective>>(x.t));
}
void Unparse(const BasedPointer &x) {
Expand Down
97 changes: 86 additions & 11 deletions flang/lib/Semantics/canonicalize-omp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "canonicalize-omp.h"
#include "flang/Parser/parse-tree-visitor.h"
#include "flang/Parser/parse-tree.h"

// After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
// Constructs more structured which provide explicit scopes for later
Expand Down Expand Up @@ -125,6 +126,16 @@ class CanonicalizationOfOmp {
parser::Block::iterator nextIt;
auto &beginDir{std::get<parser::OmpBeginLoopDirective>(x.t)};
auto &dir{std::get<parser::OmpLoopDirective>(beginDir.t)};
auto missingDoConstruct = [](auto &dir, auto &messages) {
messages.Say(dir.source,
"A DO loop must follow the %s directive"_err_en_US,
parser::ToUpperCaseLetters(dir.source.ToString()));
};
auto tileUnrollError = [](auto &dir, auto &messages) {
messages.Say(dir.source,
"If a loop construct has been fully unrolled, it cannot then be tiled"_err_en_US,
parser::ToUpperCaseLetters(dir.source.ToString()));
};

nextIt = it;
while (++nextIt != block.end()) {
Expand All @@ -135,31 +146,95 @@ class CanonicalizationOfOmp {
if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
if (doCons->GetLoopControl()) {
// move DoConstruct
std::get<std::optional<parser::DoConstruct>>(x.t) =
std::get<std::optional<std::variant<parser::DoConstruct,
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t) =
std::move(*doCons);
nextIt = block.erase(nextIt);
// try to match OmpEndLoopDirective
if (nextIt != block.end()) {
if (auto *endDir{
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
std::move(*endDir);
block.erase(nextIt);
}
if (auto *endDir{
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
std::move(*endDir);
nextIt = block.erase(nextIt);
}
} else {
messages_.Say(dir.source,
"DO loop after the %s directive must have loop control"_err_en_US,
parser::ToUpperCaseLetters(dir.source.ToString()));
}
} else if (auto *ompLoopCons{
GetOmpIf<parser::OpenMPLoopConstruct>(*nextIt)}) {
// We should allow UNROLL and TILE constructs to be inserted between an
// OpenMP Loop Construct and the DO loop itself
auto &nestedBeginDirective =
std::get<parser::OmpBeginLoopDirective>(ompLoopCons->t);
auto &nestedBeginLoopDirective =
std::get<parser::OmpLoopDirective>(nestedBeginDirective.t);
if ((nestedBeginLoopDirective.v == llvm::omp::Directive::OMPD_unroll ||
nestedBeginLoopDirective.v ==
llvm::omp::Directive::OMPD_tile) &&
!(nestedBeginLoopDirective.v == llvm::omp::Directive::OMPD_unroll &&
dir.v == llvm::omp::Directive::OMPD_tile)) {
// iterate through the remaining block items to find the end directive
// for the unroll/tile directive.
parser::Block::iterator endIt;
endIt = nextIt;
while (endIt != block.end()) {
if (auto *endDir{
GetConstructIf<parser::OmpEndLoopDirective>(*endIt)}) {
auto &endLoopDirective =
std::get<parser::OmpLoopDirective>(endDir->t);
if (endLoopDirective.v == dir.v) {
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
std::move(*endDir);
endIt = block.erase(endIt);
continue;
}
}
++endIt;
}
RewriteOpenMPLoopConstruct(*ompLoopCons, block, nextIt);
auto &ompLoop = std::get<std::optional<parser::NestedConstruct>>(x.t);
ompLoop =
std::optional<parser::NestedConstruct>{parser::NestedConstruct{
common::Indirection{std::move(*ompLoopCons)}}};
nextIt = block.erase(nextIt);
} else if (nestedBeginLoopDirective.v ==
llvm::omp::Directive::OMPD_unroll &&
dir.v == llvm::omp::Directive::OMPD_tile) {
// if a loop has been unrolled, the user can not then tile that loop
// as it has been unrolled
parser::OmpClauseList &unrollClauseList{
std::get<parser::OmpClauseList>(nestedBeginDirective.t)};
if (unrollClauseList.v.empty()) {
// if the clause list is empty for an unroll construct, we assume
// the loop is being fully unrolled
tileUnrollError(dir, messages_);
} else {
// parse the clauses for the unroll directive to find the full
// clause
for (auto clause{unrollClauseList.v.begin()};
clause != unrollClauseList.v.end(); ++clause) {
if (clause->Id() == llvm::omp::OMPC_full) {
tileUnrollError(dir, messages_);
}
}
}
} else {
messages_.Say(nestedBeginLoopDirective.source,
"Only Loop Transformation Constructs or Loop Nests can be nested within Loop Constructs"_err_en_US,
parser::ToUpperCaseLetters(
nestedBeginLoopDirective.source.ToString()));
}
} else {
messages_.Say(dir.source,
"A DO loop must follow the %s directive"_err_en_US,
parser::ToUpperCaseLetters(dir.source.ToString()));
missingDoConstruct(dir, messages_);
}
// If we get here, we either found a loop, or issued an error message.
return;
}
if (nextIt == block.end()) {
missingDoConstruct(dir, messages_);
}
}

void RewriteOmpAllocations(parser::ExecutionPart &body) {
Expand Down
100 changes: 56 additions & 44 deletions flang/lib/Semantics/check-omp-structure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,10 +762,13 @@ void OmpStructureChecker::Enter(const parser::OpenMPLoopConstruct &x) {
}
SetLoopInfo(x);

if (const auto &doConstruct{
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
CheckNoBranching(doBlock, beginDir.v, beginDir.source);
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
if (optLoopCons.has_value()) {
if (const auto &doConstruct{
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
CheckNoBranching(doBlock, beginDir.v, beginDir.source);
}
}
CheckLoopItrVariableIsInt(x);
CheckAssociatedLoopConstraints(x);
Expand All @@ -786,12 +789,15 @@ const parser::Name OmpStructureChecker::GetLoopIndex(
return std::get<Bounds>(x->GetLoopControl()->u).name.thing;
}
void OmpStructureChecker::SetLoopInfo(const parser::OpenMPLoopConstruct &x) {
if (const auto &loopConstruct{
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
const parser::DoConstruct *loop{&*loopConstruct};
if (loop && loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
SetLoopIv(itrVal.symbol);
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
if (optLoopCons.has_value()) {
if (const auto &loopConstruct{
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
const parser::DoConstruct *loop{&*loopConstruct};
if (loop && loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
SetLoopIv(itrVal.symbol);
}
}
}
}
Expand Down Expand Up @@ -857,27 +863,30 @@ void OmpStructureChecker::CheckIteratorModifier(const parser::OmpIterator &x) {

void OmpStructureChecker::CheckLoopItrVariableIsInt(
const parser::OpenMPLoopConstruct &x) {
if (const auto &loopConstruct{
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
if (optLoopCons.has_value()) {
if (const auto &loopConstruct{
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {

for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
if (loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
if (itrVal.symbol) {
const auto *type{itrVal.symbol->GetType()};
if (!type->IsNumeric(TypeCategory::Integer)) {
context_.Say(itrVal.source,
"The DO loop iteration"
" variable must be of the type integer."_err_en_US,
itrVal.ToString());
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
if (loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
if (itrVal.symbol) {
const auto *type{itrVal.symbol->GetType()};
if (!type->IsNumeric(TypeCategory::Integer)) {
context_.Say(itrVal.source,
"The DO loop iteration"
" variable must be of the type integer."_err_en_US,
itrVal.ToString());
}
}
}
// Get the next DoConstruct if block is not empty.
const auto &block{std::get<parser::Block>(loop->t)};
const auto it{block.begin()};
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
: nullptr;
}
// Get the next DoConstruct if block is not empty.
const auto &block{std::get<parser::Block>(loop->t)};
const auto it{block.begin()};
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
: nullptr;
}
}
}
Expand Down Expand Up @@ -1077,25 +1086,28 @@ void OmpStructureChecker::CheckDistLinear(

// Match the loop index variables with the collected symbols from linear
// clauses.
if (const auto &loopConstruct{
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
if (loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
if (itrVal.symbol) {
// Remove the symbol from the collected set
indexVars.erase(&itrVal.symbol->GetUltimate());
}
collapseVal--;
if (collapseVal == 0) {
break;
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
if (optLoopCons.has_value()) {
if (const auto &loopConstruct{
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
if (loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
if (itrVal.symbol) {
// Remove the symbol from the collected set
indexVars.erase(&itrVal.symbol->GetUltimate());
}
collapseVal--;
if (collapseVal == 0) {
break;
}
}
// Get the next DoConstruct if block is not empty.
const auto &block{std::get<parser::Block>(loop->t)};
const auto it{block.begin()};
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
: nullptr;
}
// Get the next DoConstruct if block is not empty.
const auto &block{std::get<parser::Block>(loop->t)};
const auto it{block.begin()};
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
: nullptr;
}
}

Expand Down
Loading