From 58e8147d1690485ed0a6fcb59c7b6ea4b8cd2936 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?= =?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?= =?UTF-8?q?=E3=83=B3=29?= Date: Thu, 8 Feb 2024 08:49:11 -0800 Subject: [PATCH] [flang][openacc] Use original input for base address with optional (#80931) In #80317 the data op generation was updated to use correctly the #0 result from the hlfir.delcare op. In case of optional that are not descriptor, it is preferable to use the original input for the varPtr value of the OpenACC data op. This patch also make sure that the descriptor value of optional is only accessed when present. --- flang/lib/Lower/DirectivesCommon.h | 93 +++++++++++++++++++------ flang/lib/Lower/OpenACC.cpp | 20 ++++-- flang/test/Lower/OpenACC/acc-bounds.f90 | 38 +++++++++- 3 files changed, 124 insertions(+), 27 deletions(-) diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h index bd880376517dd8..8d560db34e05bf 100644 --- a/flang/lib/Lower/DirectivesCommon.h +++ b/flang/lib/Lower/DirectivesCommon.h @@ -52,10 +52,13 @@ namespace lower { /// operations. struct AddrAndBoundsInfo { explicit AddrAndBoundsInfo() {} - explicit AddrAndBoundsInfo(mlir::Value addr) : addr(addr) {} - explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value isPresent) - : addr(addr), isPresent(isPresent) {} + explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput) + : addr(addr), rawInput(rawInput) {} + explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput, + mlir::Value isPresent) + : addr(addr), rawInput(rawInput), isPresent(isPresent) {} mlir::Value addr = nullptr; + mlir::Value rawInput = nullptr; mlir::Value isPresent = nullptr; }; @@ -615,20 +618,30 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &builder, Fortran::lower::SymbolRef sym, mlir::Location loc) { mlir::Value symAddr = converter.getSymbolAddress(sym); + mlir::Value rawInput = symAddr; if (auto declareOp = - mlir::dyn_cast_or_null(symAddr.getDefiningOp())) + mlir::dyn_cast_or_null(symAddr.getDefiningOp())) { symAddr = declareOp.getResults()[0]; + rawInput = declareOp.getResults()[1]; + } // TODO: Might need revisiting to handle for non-shared clauses if (!symAddr) { if (const auto *details = - sym->detailsIf()) + sym->detailsIf()) { symAddr = converter.getSymbolAddress(details->symbol()); + rawInput = symAddr; + } } if (!symAddr) llvm::report_fatal_error("could not retrieve symbol address"); + mlir::Value isPresent; + if (Fortran::semantics::IsOptional(sym)) + isPresent = + builder.create(loc, builder.getI1Type(), rawInput); + if (auto boxTy = fir::unwrapRefType(symAddr.getType()).dyn_cast()) { if (boxTy.getEleTy().isa()) @@ -638,8 +651,6 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter, // `fir.ref>` type. if (symAddr.getType().isa()) { if (Fortran::semantics::IsOptional(sym)) { - mlir::Value isPresent = - builder.create(loc, builder.getI1Type(), symAddr); mlir::Value addr = builder.genIfOp(loc, {boxTy}, isPresent, /*withElseRegion=*/true) .genThen([&]() { @@ -652,14 +663,13 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter, builder.create(loc, mlir::ValueRange{absent}); }) .getResults()[0]; - return AddrAndBoundsInfo(addr, isPresent); + return AddrAndBoundsInfo(addr, rawInput, isPresent); } mlir::Value addr = builder.create(loc, symAddr); - return AddrAndBoundsInfo(addr); - ; + return AddrAndBoundsInfo(addr, rawInput, isPresent); } } - return AddrAndBoundsInfo(symAddr); + return AddrAndBoundsInfo(symAddr, rawInput, isPresent); } template @@ -807,7 +817,7 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, Fortran::lower::StatementContext &stmtCtx, const std::list &subscripts, std::stringstream &asFortran, fir::ExtendedValue &dataExv, - bool dataExvIsAssumedSize, mlir::Value baseAddr, + bool dataExvIsAssumedSize, AddrAndBoundsInfo &info, bool treatIndexAsSection = false) { int dimension = 0; mlir::Type idxTy = builder.getIndexType(); @@ -831,11 +841,30 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value stride = one; bool strideInBytes = false; - if (fir::unwrapRefType(baseAddr.getType()).isa()) { - mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension); - auto dimInfo = builder.create(loc, idxTy, idxTy, idxTy, - baseAddr, d); - stride = dimInfo.getByteStride(); + if (fir::unwrapRefType(info.addr.getType()).isa()) { + if (info.isPresent) { + stride = + builder + .genIfOp(loc, idxTy, info.isPresent, /*withElseRegion=*/true) + .genThen([&]() { + mlir::Value d = + builder.createIntegerConstant(loc, idxTy, dimension); + auto dimInfo = builder.create( + loc, idxTy, idxTy, idxTy, info.addr, d); + builder.create(loc, dimInfo.getByteStride()); + }) + .genElse([&] { + mlir::Value zero = + builder.createIntegerConstant(loc, idxTy, 0); + builder.create(loc, zero); + }) + .getResults()[0]; + } else { + mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension); + auto dimInfo = builder.create(loc, idxTy, idxTy, + idxTy, info.addr, d); + stride = dimInfo.getByteStride(); + } strideInBytes = true; } @@ -919,7 +948,26 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, } } - extent = fir::factory::readExtent(builder, loc, dataExv, dimension); + if (info.isPresent && + fir::unwrapRefType(info.addr.getType()).isa()) { + extent = + builder + .genIfOp(loc, idxTy, info.isPresent, /*withElseRegion=*/true) + .genThen([&]() { + mlir::Value ext = fir::factory::readExtent( + builder, loc, dataExv, dimension); + builder.create(loc, ext); + }) + .genElse([&] { + mlir::Value zero = + builder.createIntegerConstant(loc, idxTy, 0); + builder.create(loc, zero); + }) + .getResults()[0]; + } else { + extent = fir::factory::readExtent(builder, loc, dataExv, dimension); + } + if (dataExvIsAssumedSize && dimension + 1 == dataExvRank) { extent = zero; if (ubound && lbound) { @@ -976,6 +1024,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds( dataExv = converter.genExprAddr(operandLocation, *exprBase, stmtCtx); info.addr = fir::getBase(dataExv); + info.rawInput = info.addr; asFortran << (*exprBase).AsFortran(); } else { const Fortran::parser::Name &name = @@ -993,7 +1042,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds( bounds = genBoundsOps( builder, operandLocation, converter, stmtCtx, arrayElement->subscripts, asFortran, dataExv, - dataExvIsAssumedSize, info.addr, treatIndexAsSection); + dataExvIsAssumedSize, info, treatIndexAsSection); } asFortran << ')'; } else if (auto structComp = Fortran::parser::Unwrap< @@ -1001,6 +1050,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds( fir::ExtendedValue compExv = converter.genExprAddr(operandLocation, *expr, stmtCtx); info.addr = fir::getBase(compExv); + info.rawInput = info.addr; if (fir::unwrapRefType(info.addr.getType()) .isa()) bounds = genBaseBoundsOps( @@ -1012,7 +1062,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds( *Fortran::parser::GetLastName(*structComp).symbol); if (isOptional) info.isPresent = builder.create( - operandLocation, builder.getI1Type(), info.addr); + operandLocation, builder.getI1Type(), info.rawInput); if (auto loadOp = mlir::dyn_cast_or_null( info.addr.getDefiningOp())) { @@ -1020,6 +1070,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds( fir::isPointerType(loadOp.getType())) info.addr = builder.create(operandLocation, info.addr); + info.rawInput = info.addr; } // If the component is an allocatable or pointer the result of @@ -1029,6 +1080,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds( if (auto boxAddrOp = mlir::dyn_cast_or_null( info.addr.getDefiningOp())) { info.addr = boxAddrOp.getVal(); + info.rawInput = info.addr; bounds = genBoundsOpsFromBox( builder, operandLocation, converter, compExv, info); } @@ -1043,6 +1095,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds( fir::ExtendedValue compExv = converter.genExprAddr(operandLocation, *expr, stmtCtx); info.addr = fir::getBase(compExv); + info.rawInput = info.addr; asFortran << (*expr).AsFortran(); } else if (const auto *dataRef{ std::get_if( diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 43f54c6d2a71bb..6ae270f63f5cf4 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -67,9 +67,12 @@ static Op createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value varPtrPtr; if (auto boxTy = baseAddr.getType().dyn_cast()) { if (isPresent) { + mlir::Type ifRetTy = boxTy.getEleTy(); + if (!fir::isa_ref_type(ifRetTy)) + ifRetTy = fir::ReferenceType::get(ifRetTy); baseAddr = builder - .genIfOp(loc, {boxTy.getEleTy()}, isPresent, + .genIfOp(loc, {ifRetTy}, isPresent, /*withElseRegion=*/true) .genThen([&]() { mlir::Value boxAddr = @@ -78,7 +81,7 @@ static Op createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc, }) .genElse([&] { mlir::Value absent = - builder.create(loc, boxTy.getEleTy()); + builder.create(loc, ifRetTy); builder.create(loc, mlir::ValueRange{absent}); }) .getResults()[0]; @@ -295,9 +298,16 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList, asFortran, bounds, /*treatIndexAsSection=*/true); - Op op = createDataEntryOp( - builder, operandLocation, info.addr, asFortran, bounds, structured, - implicit, dataClause, info.addr.getType(), info.isPresent); + // If the input value is optional and is not a descriptor, we use the + // rawInput directly. + mlir::Value baseAddr = + ((info.addr.getType() != fir::unwrapRefType(info.rawInput.getType())) && + info.isPresent) + ? info.rawInput + : info.addr; + Op op = createDataEntryOp(builder, operandLocation, baseAddr, asFortran, + bounds, structured, implicit, dataClause, + baseAddr.getType(), info.isPresent); dataOperands.push_back(op.getAccPtr()); } } diff --git a/flang/test/Lower/OpenACC/acc-bounds.f90 b/flang/test/Lower/OpenACC/acc-bounds.f90 index bd96bc8bcba359..df97cbcd187d2b 100644 --- a/flang/test/Lower/OpenACC/acc-bounds.f90 +++ b/flang/test/Lower/OpenACC/acc-bounds.f90 @@ -126,8 +126,8 @@ subroutine acc_optional_data(a) ! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data( ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref>>> {fir.bindc_name = "a", fir.optional}) { -! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %arg0 {fortran_attrs = #fir.var_attrs, uniq_name = "_QMopenacc_boundsFacc_optional_dataEa"} : (!fir.ref>>>) -> (!fir.ref>>>, !fir.ref>>>) -! CHECK: %[[IS_PRESENT:.*]] = fir.is_present %[[ARG0_DECL]]#0 : (!fir.ref>>>) -> i1 +! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %[[ARG0]] {fortran_attrs = #fir.var_attrs, uniq_name = "_QMopenacc_boundsFacc_optional_dataEa"} : (!fir.ref>>>) -> (!fir.ref>>>, !fir.ref>>>) +! CHECK: %[[IS_PRESENT:.*]] = fir.is_present %[[ARG0_DECL]]#1 : (!fir.ref>>>) -> i1 ! CHECK: %[[BOX:.*]] = fir.if %[[IS_PRESENT]] -> (!fir.box>>) { ! CHECK: %[[LOAD:.*]] = fir.load %[[ARG0_DECL]]#0 : !fir.ref>>> ! CHECK: fir.result %[[LOAD]] : !fir.box>> @@ -153,4 +153,38 @@ subroutine acc_optional_data(a) ! CHECK: %[[ATTACH:.*]] = acc.attach varPtr(%[[BOX_ADDR]] : !fir.ptr>) bounds(%[[BOUND]]) -> !fir.ptr> {name = "a"} ! CHECK: acc.data dataOperands(%[[ATTACH]] : !fir.ptr>) + subroutine acc_optional_data2(a, n) + integer :: n + real, optional :: a(n) + !$acc data no_create(a) + !$acc end data + end subroutine + +! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data2( +! CHECK-SAME: %[[A:.*]]: !fir.ref> {fir.bindc_name = "a", fir.optional}, %[[N:.*]]: !fir.ref {fir.bindc_name = "n"}) { +! CHECK: %[[DECL_A:.*]]:2 = hlfir.declare %[[A]](%{{.*}}) {fortran_attrs = #fir.var_attrs, uniq_name = "_QMopenacc_boundsFacc_optional_data2Ea"} : (!fir.ref>, !fir.shape<1>) -> (!fir.box>, !fir.ref>) +! CHECK: %[[NO_CREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref>) bounds(%10) -> !fir.ref> {name = "a"} +! CHECK: acc.data dataOperands(%[[NO_CREATE]] : !fir.ref>) { + + subroutine acc_optional_data3(a, n) + integer :: n + real, optional :: a(n) + !$acc data no_create(a(1:n)) + !$acc end data + end subroutine + +! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data3( +! CHECK-SAME: %[[A:.*]]: !fir.ref> {fir.bindc_name = "a", fir.optional}, %[[N:.*]]: !fir.ref {fir.bindc_name = "n"}) { +! CHECK: %[[DECL_A:.*]]:2 = hlfir.declare %[[A]](%{{.*}}) {fortran_attrs = #fir.var_attrs, uniq_name = "_QMopenacc_boundsFacc_optional_data3Ea"} : (!fir.ref>, !fir.shape<1>) -> (!fir.box>, !fir.ref>) +! CHECK: %[[PRES:.*]] = fir.is_present %[[DECL_A]]#1 : (!fir.ref>) -> i1 +! CHECK: %[[STRIDE:.*]] = fir.if %[[PRES]] -> (index) { +! CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[DECL_A]]#0, %c0{{.*}} : (!fir.box>, index) -> (index, index, index) +! CHECK: fir.result %[[DIMS]]#2 : index +! CHECK: } else { +! CHECK: fir.result %c0{{.*}} : index +! CHECK: } +! CHECK: %[[BOUNDS:.*]] = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[STRIDE]] : index) startIdx(%c1 : index) {strideInBytes = true} +! CHECK: %[[NOCREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref>) bounds(%14) -> !fir.ref> {name = "a(1:n)"} +! CHECK: acc.data dataOperands(%[[NOCREATE]] : !fir.ref>) { + end module