-
Notifications
You must be signed in to change notification settings - Fork 12.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][openacc] Use original input for base address with optional #80931
Conversation
@llvm/pr-subscribers-openacc @llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) Changes: In #80317 the data op generation was updated to correctly use the #0 result from the hlfir.declare op. In the case of optionals that are not descriptors, it is preferable to use the original input for the varPtr value of the OpenACC data op. Full diff: https://github.com/llvm/llvm-project/pull/80931.diff 3 Files Affected:
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index bd880376517dd..8d560db34e05b 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -52,10 +52,13 @@ namespace lower {
/// operations.
struct AddrAndBoundsInfo {
explicit AddrAndBoundsInfo() {}
- explicit AddrAndBoundsInfo(mlir::Value addr) : addr(addr) {}
- explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value isPresent)
- : addr(addr), isPresent(isPresent) {}
+ explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput)
+ : addr(addr), rawInput(rawInput) {}
+ explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput,
+ mlir::Value isPresent)
+ : addr(addr), rawInput(rawInput), isPresent(isPresent) {}
mlir::Value addr = nullptr;
+ mlir::Value rawInput = nullptr;
mlir::Value isPresent = nullptr;
};
@@ -615,20 +618,30 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
fir::FirOpBuilder &builder,
Fortran::lower::SymbolRef sym, mlir::Location loc) {
mlir::Value symAddr = converter.getSymbolAddress(sym);
+ mlir::Value rawInput = symAddr;
if (auto declareOp =
- mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp()))
+ mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
symAddr = declareOp.getResults()[0];
+ rawInput = declareOp.getResults()[1];
+ }
// TODO: Might need revisiting to handle for non-shared clauses
if (!symAddr) {
if (const auto *details =
- sym->detailsIf<Fortran::semantics::HostAssocDetails>())
+ sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
symAddr = converter.getSymbolAddress(details->symbol());
+ rawInput = symAddr;
+ }
}
if (!symAddr)
llvm::report_fatal_error("could not retrieve symbol address");
+ mlir::Value isPresent;
+ if (Fortran::semantics::IsOptional(sym))
+ isPresent =
+ builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
+
if (auto boxTy =
fir::unwrapRefType(symAddr.getType()).dyn_cast<fir::BaseBoxType>()) {
if (boxTy.getEleTy().isa<fir::RecordType>())
@@ -638,8 +651,6 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
// `fir.ref<fir.class<T>>` type.
if (symAddr.getType().isa<fir::ReferenceType>()) {
if (Fortran::semantics::IsOptional(sym)) {
- mlir::Value isPresent =
- builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), symAddr);
mlir::Value addr =
builder.genIfOp(loc, {boxTy}, isPresent, /*withElseRegion=*/true)
.genThen([&]() {
@@ -652,14 +663,13 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
})
.getResults()[0];
- return AddrAndBoundsInfo(addr, isPresent);
+ return AddrAndBoundsInfo(addr, rawInput, isPresent);
}
mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
- return AddrAndBoundsInfo(addr);
- ;
+ return AddrAndBoundsInfo(addr, rawInput, isPresent);
}
}
- return AddrAndBoundsInfo(symAddr);
+ return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
}
template <typename BoundsOp, typename BoundsType>
@@ -807,7 +817,7 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
Fortran::lower::StatementContext &stmtCtx,
const std::list<Fortran::parser::SectionSubscript> &subscripts,
std::stringstream &asFortran, fir::ExtendedValue &dataExv,
- bool dataExvIsAssumedSize, mlir::Value baseAddr,
+ bool dataExvIsAssumedSize, AddrAndBoundsInfo &info,
bool treatIndexAsSection = false) {
int dimension = 0;
mlir::Type idxTy = builder.getIndexType();
@@ -831,11 +841,30 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value stride = one;
bool strideInBytes = false;
- if (fir::unwrapRefType(baseAddr.getType()).isa<fir::BaseBoxType>()) {
- mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension);
- auto dimInfo = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy,
- baseAddr, d);
- stride = dimInfo.getByteStride();
+ if (fir::unwrapRefType(info.addr.getType()).isa<fir::BaseBoxType>()) {
+ if (info.isPresent) {
+ stride =
+ builder
+ .genIfOp(loc, idxTy, info.isPresent, /*withElseRegion=*/true)
+ .genThen([&]() {
+ mlir::Value d =
+ builder.createIntegerConstant(loc, idxTy, dimension);
+ auto dimInfo = builder.create<fir::BoxDimsOp>(
+ loc, idxTy, idxTy, idxTy, info.addr, d);
+ builder.create<fir::ResultOp>(loc, dimInfo.getByteStride());
+ })
+ .genElse([&] {
+ mlir::Value zero =
+ builder.createIntegerConstant(loc, idxTy, 0);
+ builder.create<fir::ResultOp>(loc, zero);
+ })
+ .getResults()[0];
+ } else {
+ mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension);
+ auto dimInfo = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy,
+ idxTy, info.addr, d);
+ stride = dimInfo.getByteStride();
+ }
strideInBytes = true;
}
@@ -919,7 +948,26 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
}
}
- extent = fir::factory::readExtent(builder, loc, dataExv, dimension);
+ if (info.isPresent &&
+ fir::unwrapRefType(info.addr.getType()).isa<fir::BaseBoxType>()) {
+ extent =
+ builder
+ .genIfOp(loc, idxTy, info.isPresent, /*withElseRegion=*/true)
+ .genThen([&]() {
+ mlir::Value ext = fir::factory::readExtent(
+ builder, loc, dataExv, dimension);
+ builder.create<fir::ResultOp>(loc, ext);
+ })
+ .genElse([&] {
+ mlir::Value zero =
+ builder.createIntegerConstant(loc, idxTy, 0);
+ builder.create<fir::ResultOp>(loc, zero);
+ })
+ .getResults()[0];
+ } else {
+ extent = fir::factory::readExtent(builder, loc, dataExv, dimension);
+ }
+
if (dataExvIsAssumedSize && dimension + 1 == dataExvRank) {
extent = zero;
if (ubound && lbound) {
@@ -976,6 +1024,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
dataExv = converter.genExprAddr(operandLocation, *exprBase,
stmtCtx);
info.addr = fir::getBase(dataExv);
+ info.rawInput = info.addr;
asFortran << (*exprBase).AsFortran();
} else {
const Fortran::parser::Name &name =
@@ -993,7 +1042,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
bounds = genBoundsOps<BoundsOp, BoundsType>(
builder, operandLocation, converter, stmtCtx,
arrayElement->subscripts, asFortran, dataExv,
- dataExvIsAssumedSize, info.addr, treatIndexAsSection);
+ dataExvIsAssumedSize, info, treatIndexAsSection);
}
asFortran << ')';
} else if (auto structComp = Fortran::parser::Unwrap<
@@ -1001,6 +1050,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
fir::ExtendedValue compExv =
converter.genExprAddr(operandLocation, *expr, stmtCtx);
info.addr = fir::getBase(compExv);
+ info.rawInput = info.addr;
if (fir::unwrapRefType(info.addr.getType())
.isa<fir::SequenceType>())
bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
@@ -1012,7 +1062,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
*Fortran::parser::GetLastName(*structComp).symbol);
if (isOptional)
info.isPresent = builder.create<fir::IsPresentOp>(
- operandLocation, builder.getI1Type(), info.addr);
+ operandLocation, builder.getI1Type(), info.rawInput);
if (auto loadOp = mlir::dyn_cast_or_null<fir::LoadOp>(
info.addr.getDefiningOp())) {
@@ -1020,6 +1070,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
fir::isPointerType(loadOp.getType()))
info.addr = builder.create<fir::BoxAddrOp>(operandLocation,
info.addr);
+ info.rawInput = info.addr;
}
// If the component is an allocatable or pointer the result of
@@ -1029,6 +1080,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
if (auto boxAddrOp = mlir::dyn_cast_or_null<fir::BoxAddrOp>(
info.addr.getDefiningOp())) {
info.addr = boxAddrOp.getVal();
+ info.rawInput = info.addr;
bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
builder, operandLocation, converter, compExv, info);
}
@@ -1043,6 +1095,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
fir::ExtendedValue compExv =
converter.genExprAddr(operandLocation, *expr, stmtCtx);
info.addr = fir::getBase(compExv);
+ info.rawInput = info.addr;
asFortran << (*expr).AsFortran();
} else if (const auto *dataRef{
std::get_if<Fortran::parser::DataRef>(
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 43f54c6d2a71b..6ae270f63f5cf 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -67,9 +67,12 @@ static Op createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value varPtrPtr;
if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
if (isPresent) {
+ mlir::Type ifRetTy = boxTy.getEleTy();
+ if (!fir::isa_ref_type(ifRetTy))
+ ifRetTy = fir::ReferenceType::get(ifRetTy);
baseAddr =
builder
- .genIfOp(loc, {boxTy.getEleTy()}, isPresent,
+ .genIfOp(loc, {ifRetTy}, isPresent,
/*withElseRegion=*/true)
.genThen([&]() {
mlir::Value boxAddr =
@@ -78,7 +81,7 @@ static Op createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
})
.genElse([&] {
mlir::Value absent =
- builder.create<fir::AbsentOp>(loc, boxTy.getEleTy());
+ builder.create<fir::AbsentOp>(loc, ifRetTy);
builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
})
.getResults()[0];
@@ -295,9 +298,16 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
asFortran, bounds,
/*treatIndexAsSection=*/true);
- Op op = createDataEntryOp<Op>(
- builder, operandLocation, info.addr, asFortran, bounds, structured,
- implicit, dataClause, info.addr.getType(), info.isPresent);
+ // If the input value is optional and is not a descriptor, we use the
+ // rawInput directly.
+ mlir::Value baseAddr =
+ ((info.addr.getType() != fir::unwrapRefType(info.rawInput.getType())) &&
+ info.isPresent)
+ ? info.rawInput
+ : info.addr;
+ Op op = createDataEntryOp<Op>(builder, operandLocation, baseAddr, asFortran,
+ bounds, structured, implicit, dataClause,
+ baseAddr.getType(), info.isPresent);
dataOperands.push_back(op.getAccPtr());
}
}
diff --git a/flang/test/Lower/OpenACC/acc-bounds.f90 b/flang/test/Lower/OpenACC/acc-bounds.f90
index bd96bc8bcba35..df97cbcd187d2 100644
--- a/flang/test/Lower/OpenACC/acc-bounds.f90
+++ b/flang/test/Lower/OpenACC/acc-bounds.f90
@@ -126,8 +126,8 @@ subroutine acc_optional_data(a)
! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data(
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {fir.bindc_name = "a", fir.optional}) {
-! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %arg0 {fortran_attrs = #fir.var_attrs<optional, pointer>, uniq_name = "_QMopenacc_boundsFacc_optional_dataEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
-! CHECK: %[[IS_PRESENT:.*]] = fir.is_present %[[ARG0_DECL]]#0 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> i1
+! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %[[ARG0]] {fortran_attrs = #fir.var_attrs<optional, pointer>, uniq_name = "_QMopenacc_boundsFacc_optional_dataEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
+! CHECK: %[[IS_PRESENT:.*]] = fir.is_present %[[ARG0_DECL]]#1 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> i1
! CHECK: %[[BOX:.*]] = fir.if %[[IS_PRESENT]] -> (!fir.box<!fir.ptr<!fir.array<?xf32>>>) {
! CHECK: %[[LOAD:.*]] = fir.load %[[ARG0_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
! CHECK: fir.result %[[LOAD]] : !fir.box<!fir.ptr<!fir.array<?xf32>>>
@@ -153,4 +153,38 @@ subroutine acc_optional_data(a)
! CHECK: %[[ATTACH:.*]] = acc.attach varPtr(%[[BOX_ADDR]] : !fir.ptr<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.ptr<!fir.array<?xf32>> {name = "a"}
! CHECK: acc.data dataOperands(%[[ATTACH]] : !fir.ptr<!fir.array<?xf32>>)
+ subroutine acc_optional_data2(a, n)
+ integer :: n
+ real, optional :: a(n)
+ !$acc data no_create(a)
+ !$acc end data
+ end subroutine
+
+! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data2(
+! CHECK-SAME: %[[A:.*]]: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "a", fir.optional}, %[[N:.*]]: !fir.ref<i32> {fir.bindc_name = "n"}) {
+! CHECK: %[[DECL_A:.*]]:2 = hlfir.declare %[[A]](%{{.*}}) {fortran_attrs = #fir.var_attrs<optional>, uniq_name = "_QMopenacc_boundsFacc_optional_data2Ea"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+! CHECK: %[[NO_CREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%10) -> !fir.ref<!fir.array<?xf32>> {name = "a"}
+! CHECK: acc.data dataOperands(%[[NO_CREATE]] : !fir.ref<!fir.array<?xf32>>) {
+
+ subroutine acc_optional_data3(a, n)
+ integer :: n
+ real, optional :: a(n)
+ !$acc data no_create(a(1:n))
+ !$acc end data
+ end subroutine
+
+! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data3(
+! CHECK-SAME: %[[A:.*]]: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "a", fir.optional}, %[[N:.*]]: !fir.ref<i32> {fir.bindc_name = "n"}) {
+! CHECK: %[[DECL_A:.*]]:2 = hlfir.declare %[[A]](%{{.*}}) {fortran_attrs = #fir.var_attrs<optional>, uniq_name = "_QMopenacc_boundsFacc_optional_data3Ea"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+! CHECK: %[[PRES:.*]] = fir.is_present %[[DECL_A]]#1 : (!fir.ref<!fir.array<?xf32>>) -> i1
+! CHECK: %[[STRIDE:.*]] = fir.if %[[PRES]] -> (index) {
+! CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[DECL_A]]#0, %c0{{.*}} : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+! CHECK: fir.result %[[DIMS]]#2 : index
+! CHECK: } else {
+! CHECK: fir.result %c0{{.*}} : index
+! CHECK: }
+! CHECK: %[[BOUNDS:.*]] = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[STRIDE]] : index) startIdx(%c1 : index) {strideInBytes = true}
+! CHECK: %[[NOCREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%14) -> !fir.ref<!fir.array<?xf32>> {name = "a(1:n)"}
+! CHECK: acc.data dataOperands(%[[NOCREATE]] : !fir.ref<!fir.array<?xf32>>) {
+
end module
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, Valentin!
It looks good to me, though, I think we will need to handle optional arguments differently in more complex cases. If I am reading OpenACC 3.3 2.17.1 right, the following data clause does not have any effect:
type(t), optional :: x(:)
!$acc enter data copyin(x(10)%member(10)%member)
So it looks like we need to conditionalize the designator lowering with the isPresent check earlier. It is not for this PR.
Yes, it does not have any effect if x is absent. We will need to revisit this to handle all these cases correctly. Lowering became complex because we have an assumption that the bounds are reachable from the data op. Maybe we will need to change that somehow. |
I think we may do something like this:
It becomes hard to reason about the provenance of the variable in the data clause, though. So maybe we can try to put the whole |
This is what we have done in some cases, but I need to generalize this to all the paths.
This was my original idea, but the operations that have dataClauseOperands expect the definingOp to be a DataClauseOp, so this would currently fail the verifier. I need to think about the representation again so that we can fit the optional case better. |
In #80317 the data op generation was updated to correctly use the #0 result from the hlfir.declare op. In the case of optionals that are not descriptors, it is preferable to use the original input for the varPtr value of the OpenACC data op.
This patch also makes sure that the descriptor value of an optional is only accessed when it is present.