Skip to content

Commit 2aa3450

Browse files
authored
[flang][OpenACC] remap component references in structured constructs (#171501)
OpenACC data clauses of structured constructs may contain component references (`obj%comp`, or `obj%array(i:j:k)`, ...). This changes allows using the ACC dialect data operation result for such clauses every time the component is referred to inside the scope of the construct. The bulk of the change is to add the ability to map `evaluate::Component` to mlir values in the symbol map used in lowering. This is done by adding the `ComponentMap` helper class to the lowering symbol map, and using it to override `evaluate::Component` reference lowering in expression lowering (ConvertExprToHLFIR.cpp). Some changes are made in Lower/Support/Utils.h in order to set-up/expose the hashing/equality helpers needed to use `evaluate::Component` in llvm::DenseMap. In OpenACC.cpp, `genPrivatizationRecipes` and `genDataOperandOperations` are merged to unify the processing of Designator in data clauses. New code is added to unwrap the rightmost `evaluate::Component`, if any, and remap it. Note that when the right most part is an array reference on a component (`obj%array(i:j:k)`), the whole component `obj%array(` is remapped, and the array reference is dealt with a bound operand like for whole object array. After this patch, all designators in data clauses on structured constructs will be remapped, except for array reference in private/first private/reduction (the result type of the related operation needs to be changed and the recipe adapted to "offset back"), component reference in reduction (this patch is adding a TODO for it), and device_ptr (previously bypassed remapping because of issues with descriptors, should be OK to lift it in a further patch). Note that this patch assumes that it is illegal to have code with intermediate variable indexing in the component reference (e.g. `array(i)%comp`) where the value of the index would be changed inside the region (e.g., `i` is assigned a new value), or where the component would be used with different indices meant to have the same value as the one used in the clause (e.g. `array(i)%comp` where `j` is meant to be the same as `i`). I will try to add a warning in semantics for such questionable/risky usages.
1 parent f0d7d83 commit 2aa3450

File tree

11 files changed

+645
-282
lines changed

11 files changed

+645
-282
lines changed

flang/include/flang/Lower/AbstractConverter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "flang/Lower/LoweringOptions.h"
1717
#include "flang/Lower/PFTDefs.h"
1818
#include "flang/Lower/Support/Utils.h"
19+
#include "flang/Lower/SymbolMap.h"
1920
#include "flang/Optimizer/Builder/BoxValue.h"
2021
#include "flang/Optimizer/Dialect/FIRAttr.h"
2122
#include "flang/Semantics/symbol.h"

flang/include/flang/Lower/Support/Utils.h

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#define FORTRAN_LOWER_SUPPORT_UTILS_H
1515

1616
#include "flang/Common/indirection.h"
17-
#include "flang/Lower/IterationSpace.h"
1817
#include "flang/Parser/char-block.h"
1918
#include "flang/Semantics/tools.h"
2019
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -23,10 +22,25 @@
2322
#include "llvm/ADT/SmallSet.h"
2423
#include "llvm/ADT/StringRef.h"
2524

25+
namespace Fortran::evaluate {
26+
class Component;
27+
class ArrayRef;
28+
} // namespace Fortran::evaluate
29+
2630
namespace Fortran::lower {
2731
using SomeExpr = Fortran::evaluate::Expr<Fortran::evaluate::SomeType>;
32+
using ExplicitSpaceArrayBases =
33+
std::variant<const semantics::Symbol *, const evaluate::Component *,
34+
const evaluate::ArrayRef *>;
35+
// FIXME: needed for privatizeSymbol that does not belong to this header.
36+
class AbstractConverter;
37+
class SymMap;
2838
} // end namespace Fortran::lower
2939

40+
namespace fir {
41+
class FirOpBuilder;
42+
}
43+
3044
//===----------------------------------------------------------------------===//
3145
// Small inline helper functions to deal with repetitive, clumsy conversions.
3246
//===----------------------------------------------------------------------===//
@@ -89,12 +103,15 @@ A flatZip(const A &container1, const A &container2) {
89103

90104
namespace Fortran::lower {
91105
unsigned getHashValue(const Fortran::lower::SomeExpr *x);
92-
unsigned getHashValue(const Fortran::lower::ExplicitIterSpace::ArrayBases &x);
106+
unsigned getHashValue(const Fortran::lower::ExplicitSpaceArrayBases &x);
107+
unsigned getHashValue(const Fortran::evaluate::Component *x);
93108

94109
bool isEqual(const Fortran::lower::SomeExpr *x,
95110
const Fortran::lower::SomeExpr *y);
96-
bool isEqual(const Fortran::lower::ExplicitIterSpace::ArrayBases &x,
97-
const Fortran::lower::ExplicitIterSpace::ArrayBases &y);
111+
bool isEqual(const Fortran::lower::ExplicitSpaceArrayBases &x,
112+
const Fortran::lower::ExplicitSpaceArrayBases &y);
113+
bool isEqual(const Fortran::evaluate::Component *x,
114+
const Fortran::evaluate::Component *y);
98115

99116
template <typename OpType, typename OperandsStructType>
100117
void privatizeSymbol(
@@ -125,6 +142,24 @@ struct DenseMapInfo<const Fortran::lower::SomeExpr *> {
125142
return Fortran::lower::isEqual(lhs, rhs);
126143
}
127144
};
145+
146+
// DenseMapInfo for pointers to Fortran::evaluate::Component.
147+
template <>
148+
struct DenseMapInfo<const Fortran::evaluate::Component *> {
149+
static inline const Fortran::evaluate::Component *getEmptyKey() {
150+
return reinterpret_cast<Fortran::evaluate::Component *>(~0);
151+
}
152+
static inline const Fortran::evaluate::Component *getTombstoneKey() {
153+
return reinterpret_cast<Fortran::evaluate::Component *>(~0 - 1);
154+
}
155+
static unsigned getHashValue(const Fortran::evaluate::Component *v) {
156+
return Fortran::lower::getHashValue(v);
157+
}
158+
static bool isEqual(const Fortran::evaluate::Component *lhs,
159+
const Fortran::evaluate::Component *rhs) {
160+
return Fortran::lower::isEqual(lhs, rhs);
161+
}
162+
};
128163
} // namespace llvm
129164

130165
#endif // FORTRAN_LOWER_SUPPORT_UTILS_H

flang/include/flang/Lower/SymbolMap.h

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define FORTRAN_LOWER_SYMBOLMAP_H
1515

1616
#include "flang/Common/reference.h"
17+
#include "flang/Lower/Support/Utils.h"
1718
#include "flang/Optimizer/Builder/BoxValue.h"
1819
#include "flang/Optimizer/Dialect/FIRType.h"
1920
#include "flang/Optimizer/Dialect/FortranVariableInterface.h"
@@ -134,6 +135,41 @@ struct SymbolBox : public fir::details::matcher<SymbolBox> {
134135
VT box;
135136
};
136137

138+
/// Helper class to map `Fortran::evaluate::Component` references to IR values.
139+
/// This is used when the evaluation of a component reference must be
140+
/// overridden with a pre-computed address.
141+
class ComponentMap {
142+
public:
143+
void insert(const Fortran::evaluate::Component &component,
144+
fir::FortranVariableOpInterface definingOp) {
145+
auto iter = componentMap.find(&component);
146+
if (iter != componentMap.end()) {
147+
iter->second = definingOp;
148+
return;
149+
}
150+
componentStorage.push_back(
151+
std::make_unique<Fortran::evaluate::Component>(component));
152+
componentMap.insert({componentStorage.back().get(), definingOp});
153+
}
154+
155+
std::optional<fir::FortranVariableOpInterface>
156+
lookup(const Fortran::evaluate::Component *component) const {
157+
auto iter = componentMap.find(component);
158+
if (iter != componentMap.end())
159+
return iter->second;
160+
return std::nullopt;
161+
}
162+
163+
LLVM_DUMP_METHOD void dump() const;
164+
165+
private:
166+
llvm::DenseMap<const Fortran::evaluate::Component *,
167+
fir::FortranVariableOpInterface>
168+
componentMap;
169+
llvm::SmallVector<std::unique_ptr<Fortran::evaluate::Component>>
170+
componentStorage;
171+
};
172+
137173
//===----------------------------------------------------------------------===//
138174
// Map of symbol information
139175
//===----------------------------------------------------------------------===//
@@ -156,12 +192,15 @@ class SymMap {
156192
void pushScope() {
157193
symbolMapStack.emplace_back();
158194
storageMapStack.emplace_back();
195+
componentMapStack.emplace_back();
159196
}
160197
void popScope() {
161198
symbolMapStack.pop_back();
162199
assert(symbolMapStack.size() >= 1);
163200
storageMapStack.pop_back();
164201
assert(storageMapStack.size() >= 1);
202+
componentMapStack.pop_back();
203+
assert(componentMapStack.size() >= 1);
165204
}
166205

167206
/// Add an extended value to the symbol table.
@@ -301,6 +340,8 @@ class SymMap {
301340
impliedDoStack.clear();
302341
storageMapStack.clear();
303342
storageMapStack.emplace_back();
343+
componentMapStack.clear();
344+
componentMapStack.emplace_back();
304345
}
305346

306347
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
@@ -329,6 +370,33 @@ class SymMap {
329370
return std::nullopt;
330371
}
331372

373+
/// Register a mapping from a front-end component reference to the FIR
374+
/// variable that should be used to implement it. This is used to override
375+
/// the default lowering of component references in specific contexts.
376+
void addComponentOverride(const Fortran::evaluate::Component &component,
377+
fir::FortranVariableOpInterface definingOp) {
378+
assert(!componentMapStack.empty() && "component map stack is empty");
379+
if (!componentMapStack.back())
380+
componentMapStack.back() = std::make_unique<ComponentMap>();
381+
componentMapStack.back().value()->insert(component, definingOp);
382+
}
383+
384+
/// Lookup an overridden FIR variable definition for a given component
385+
/// reference, if any.
386+
std::optional<fir::FortranVariableOpInterface>
387+
lookupComponentOverride(const Fortran::evaluate::Component &component) const {
388+
for (auto jmap = componentMapStack.rbegin(),
389+
jend = componentMapStack.rend();
390+
jmap != jend; ++jmap) {
391+
if (*jmap) {
392+
auto iter = (**jmap)->lookup(&component);
393+
if (iter != std::nullopt)
394+
return iter;
395+
}
396+
}
397+
return std::nullopt;
398+
}
399+
332400
/// Register the symbol's storage at the innermost level
333401
/// of the symbol table. If the storage is already registered,
334402
/// it will be replaced.
@@ -360,6 +428,12 @@ class SymMap {
360428
// A stack of maps between the symbols and their storage descriptors.
361429
llvm::SmallVector<llvm::DenseMap<const semantics::Symbol *, StorageDesc>>
362430
storageMapStack;
431+
432+
// A stack of maps from front-end component references to the FIR variables
433+
// that should be used to implement them. This allows overriding component
434+
// references in specific lowering contexts.
435+
llvm::SmallVector<std::optional<std::unique_ptr<ComponentMap>>>
436+
componentMapStack;
363437
};
364438

365439
/// RAII wrapper for SymMap.

flang/lib/Lower/ConvertExprToHLFIR.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,8 @@ class HlfirDesignatorBuilder {
368368

369369
fir::FortranVariableOpInterface
370370
gen(const Fortran::evaluate::Component &component) {
371+
if (auto remapped = symMap.lookupComponentOverride(component))
372+
return *remapped;
371373
if (Fortran::semantics::IsAllocatableOrPointer(component.GetLastSymbol()))
372374
return genWholeAllocatableOrPointerComponent(component);
373375
PartInfo partInfo;
@@ -477,6 +479,8 @@ class HlfirDesignatorBuilder {
477479

478480
fir::FortranVariableOpInterface genWholeAllocatableOrPointerComponent(
479481
const Fortran::evaluate::Component &component) {
482+
if (auto remapped = symMap.lookupComponentOverride(component))
483+
return *remapped;
480484
// Generate whole allocatable or pointer component reference. The
481485
// hlfir.designate result will be a pointer/allocatable.
482486
PartInfo partInfo;
@@ -539,7 +543,8 @@ class HlfirDesignatorBuilder {
539543
// components need special care to deal with the array%array_comp(indices)
540544
// case.
541545
if (Fortran::semantics::IsAllocatableOrObjectPointer(
542-
&component->GetLastSymbol()))
546+
&component->GetLastSymbol()) ||
547+
symMap.lookupComponentOverride(*component))
543548
baseType = visit(*component, partInfo);
544549
else
545550
baseType = hlfir::getFortranElementOrSequenceType(
@@ -686,6 +691,15 @@ class HlfirDesignatorBuilder {
686691
partInfo.typeParams);
687692
return partInfo.base->getElementOrSequenceType();
688693
}
694+
if (auto remapped = symMap.lookupComponentOverride(component)) {
695+
// Do not generate field for the designate if the component
696+
// is overridden, the override value is already addressing
697+
// the component.
698+
partInfo.base = *remapped;
699+
hlfir::genLengthParameters(loc, getBuilder(), *partInfo.base,
700+
partInfo.typeParams);
701+
return partInfo.base->getElementOrSequenceType();
702+
}
689703
// This function must be called from contexts where the component is not the
690704
// base of an ArrayRef. In these cases, the component cannot be an array
691705
// if the base is an array. The code below determines the shape of the

0 commit comments

Comments
 (0)