diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h index bd880376517dd8..8d560db34e05bf 100644 --- a/flang/lib/Lower/DirectivesCommon.h +++ b/flang/lib/Lower/DirectivesCommon.h @@ -52,10 +52,13 @@ namespace lower { /// operations. struct AddrAndBoundsInfo { explicit AddrAndBoundsInfo() {} - explicit AddrAndBoundsInfo(mlir::Value addr) : addr(addr) {} - explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value isPresent) - : addr(addr), isPresent(isPresent) {} + explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput) + : addr(addr), rawInput(rawInput) {} + explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput, + mlir::Value isPresent) + : addr(addr), rawInput(rawInput), isPresent(isPresent) {} mlir::Value addr = nullptr; + mlir::Value rawInput = nullptr; mlir::Value isPresent = nullptr; }; @@ -615,20 +618,30 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &builder, Fortran::lower::SymbolRef sym, mlir::Location loc) { mlir::Value symAddr = converter.getSymbolAddress(sym); + mlir::Value rawInput = symAddr; if (auto declareOp = - mlir::dyn_cast_or_null(symAddr.getDefiningOp())) + mlir::dyn_cast_or_null(symAddr.getDefiningOp())) { symAddr = declareOp.getResults()[0]; + rawInput = declareOp.getResults()[1]; + } // TODO: Might need revisiting to handle for non-shared clauses if (!symAddr) { if (const auto *details = - sym->detailsIf()) + sym->detailsIf()) { symAddr = converter.getSymbolAddress(details->symbol()); + rawInput = symAddr; + } } if (!symAddr) llvm::report_fatal_error("could not retrieve symbol address"); + mlir::Value isPresent; + if (Fortran::semantics::IsOptional(sym)) + isPresent = + builder.create(loc, builder.getI1Type(), rawInput); + if (auto boxTy = fir::unwrapRefType(symAddr.getType()).dyn_cast()) { if (boxTy.getEleTy().isa()) @@ -638,8 +651,6 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter, // `fir.ref>` type. if (symAddr.getType().isa()) { if (Fortran::semantics::IsOptional(sym)) { - mlir::Value isPresent = - builder.create(loc, builder.getI1Type(), symAddr); mlir::Value addr = builder.genIfOp(loc, {boxTy}, isPresent, /*withElseRegion=*/true) .genThen([&]() { @@ -652,14 +663,13 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter, builder.create(loc, mlir::ValueRange{absent}); }) .getResults()[0]; - return AddrAndBoundsInfo(addr, isPresent); + return AddrAndBoundsInfo(addr, rawInput, isPresent); } mlir::Value addr = builder.create(loc, symAddr); - return AddrAndBoundsInfo(addr); - ; + return AddrAndBoundsInfo(addr, rawInput, isPresent); } } - return AddrAndBoundsInfo(symAddr); + return AddrAndBoundsInfo(symAddr, rawInput, isPresent); } template @@ -807,7 +817,7 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, Fortran::lower::StatementContext &stmtCtx, const std::list &subscripts, std::stringstream &asFortran, fir::ExtendedValue &dataExv, - bool dataExvIsAssumedSize, mlir::Value baseAddr, + bool dataExvIsAssumedSize, AddrAndBoundsInfo &info, bool treatIndexAsSection = false) { int dimension = 0; mlir::Type idxTy = builder.getIndexType(); @@ -831,11 +841,30 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value stride = one; bool strideInBytes = false; - if (fir::unwrapRefType(baseAddr.getType()).isa()) { - mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension); - auto dimInfo = builder.create(loc, idxTy, idxTy, idxTy, - baseAddr, d); - stride = dimInfo.getByteStride(); + if (fir::unwrapRefType(info.addr.getType()).isa()) { + if (info.isPresent) { + stride = + builder + .genIfOp(loc, idxTy, info.isPresent, /*withElseRegion=*/true) + .genThen([&]() { + mlir::Value d = + builder.createIntegerConstant(loc, idxTy, dimension); + auto dimInfo = builder.create( + loc, idxTy, idxTy, idxTy, info.addr, d); + builder.create(loc, dimInfo.getByteStride()); + }) + .genElse([&] { + mlir::Value zero = + builder.createIntegerConstant(loc, idxTy, 0); + builder.create(loc, zero); + }) + .getResults()[0]; + } else { + mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension); + auto dimInfo = builder.create(loc, idxTy, idxTy, + idxTy, info.addr, d); + stride = dimInfo.getByteStride(); + } strideInBytes = true; } @@ -919,7 +948,26 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc, } } - extent = fir::factory::readExtent(builder, loc, dataExv, dimension); + if (info.isPresent && + fir::unwrapRefType(info.addr.getType()).isa()) { + extent = + builder + .genIfOp(loc, idxTy, info.isPresent, /*withElseRegion=*/true) + .genThen([&]() { + mlir::Value ext = fir::factory::readExtent( + builder, loc, dataExv, dimension); + builder.create(loc, ext); + }) + .genElse([&] { + mlir::Value zero = + builder.createIntegerConstant(loc, idxTy, 0); + builder.create(loc, zero); + }) + .getResults()[0]; + } else { + extent = fir::factory::readExtent(builder, loc, dataExv, dimension); + } + if (dataExvIsAssumedSize && dimension + 1 == dataExvRank) { extent = zero; if (ubound && lbound) { @@ -976,6 +1024,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds( dataExv = converter.genExprAddr(operandLocation, *exprBase, stmtCtx); info.addr = fir::getBase(dataExv); + info.rawInput = info.addr; asFortran << (*exprBase).AsFortran(); } else { const Fortran::parser::Name &name = @@ -993,7 +1042,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds( bounds = genBoundsOps( builder, operandLocation, converter, stmtCtx, arrayElement->subscripts, asFortran, dataExv, - dataExvIsAssumedSize, info.addr, treatIndexAsSection); + dataExvIsAssumedSize, info, treatIndexAsSection); } asFortran << ')'; } else if (auto structComp = Fortran::parser::Unwrap< @@ -1001,6 +1050,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds( fir::ExtendedValue compExv = converter.genExprAddr(operandLocation, *expr, stmtCtx); info.addr = fir::getBase(compExv); + info.rawInput = info.addr; if (fir::unwrapRefType(info.addr.getType()) .isa()) bounds = genBaseBoundsOps( @@ -1012,7 +1062,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds( *Fortran::parser::GetLastName(*structComp).symbol); if (isOptional) info.isPresent = builder.create( - operandLocation, builder.getI1Type(), info.addr); + operandLocation, builder.getI1Type(), info.rawInput); if (auto loadOp = mlir::dyn_cast_or_null( info.addr.getDefiningOp())) { @@ -1020,6 +1070,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds( fir::isPointerType(loadOp.getType())) info.addr = builder.create(operandLocation, info.addr); + info.rawInput = info.addr; } // If the component is an allocatable or pointer the result of @@ -1029,6 +1080,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds( if (auto boxAddrOp = mlir::dyn_cast_or_null( info.addr.getDefiningOp())) { info.addr = boxAddrOp.getVal(); + info.rawInput = info.addr; bounds = genBoundsOpsFromBox( builder, operandLocation, converter, compExv, info); } @@ -1043,6 +1095,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds( fir::ExtendedValue compExv = converter.genExprAddr(operandLocation, *expr, stmtCtx); info.addr = fir::getBase(compExv); + info.rawInput = info.addr; asFortran << (*expr).AsFortran(); } else if (const auto *dataRef{ std::get_if( diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 43f54c6d2a71bb..6ae270f63f5cf4 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -67,9 +67,12 @@ static Op createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value varPtrPtr; if (auto boxTy = baseAddr.getType().dyn_cast()) { if (isPresent) { + mlir::Type ifRetTy = boxTy.getEleTy(); + if (!fir::isa_ref_type(ifRetTy)) + ifRetTy = fir::ReferenceType::get(ifRetTy); baseAddr = builder - .genIfOp(loc, {boxTy.getEleTy()}, isPresent, + .genIfOp(loc, {ifRetTy}, isPresent, /*withElseRegion=*/true) .genThen([&]() { mlir::Value boxAddr = @@ -78,7 +81,7 @@ static Op createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc, }) .genElse([&] { mlir::Value absent = - builder.create(loc, boxTy.getEleTy()); + builder.create(loc, ifRetTy); builder.create(loc, mlir::ValueRange{absent}); }) .getResults()[0]; @@ -295,9 +298,16 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList, asFortran, bounds, /*treatIndexAsSection=*/true); - Op op = createDataEntryOp( - builder, operandLocation, info.addr, asFortran, bounds, structured, - implicit, dataClause, info.addr.getType(), info.isPresent); + // If the input value is optional and is not a descriptor, we use the + // rawInput directly. + mlir::Value baseAddr = + ((info.addr.getType() != fir::unwrapRefType(info.rawInput.getType())) && + info.isPresent) + ? info.rawInput + : info.addr; + Op op = createDataEntryOp(builder, operandLocation, baseAddr, asFortran, + bounds, structured, implicit, dataClause, + baseAddr.getType(), info.isPresent); dataOperands.push_back(op.getAccPtr()); } } diff --git a/flang/test/Lower/OpenACC/acc-bounds.f90 b/flang/test/Lower/OpenACC/acc-bounds.f90 index bd96bc8bcba359..df97cbcd187d2b 100644 --- a/flang/test/Lower/OpenACC/acc-bounds.f90 +++ b/flang/test/Lower/OpenACC/acc-bounds.f90 @@ -126,8 +126,8 @@ subroutine acc_optional_data(a) ! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data( ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref>>> {fir.bindc_name = "a", fir.optional}) { -! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %arg0 {fortran_attrs = #fir.var_attrs, uniq_name = "_QMopenacc_boundsFacc_optional_dataEa"} : (!fir.ref>>>) -> (!fir.ref>>>, !fir.ref>>>) -! CHECK: %[[IS_PRESENT:.*]] = fir.is_present %[[ARG0_DECL]]#0 : (!fir.ref>>>) -> i1 +! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %[[ARG0]] {fortran_attrs = #fir.var_attrs, uniq_name = "_QMopenacc_boundsFacc_optional_dataEa"} : (!fir.ref>>>) -> (!fir.ref>>>, !fir.ref>>>) +! CHECK: %[[IS_PRESENT:.*]] = fir.is_present %[[ARG0_DECL]]#1 : (!fir.ref>>>) -> i1 ! CHECK: %[[BOX:.*]] = fir.if %[[IS_PRESENT]] -> (!fir.box>>) { ! CHECK: %[[LOAD:.*]] = fir.load %[[ARG0_DECL]]#0 : !fir.ref>>> ! CHECK: fir.result %[[LOAD]] : !fir.box>> @@ -153,4 +153,38 @@ subroutine acc_optional_data(a) ! CHECK: %[[ATTACH:.*]] = acc.attach varPtr(%[[BOX_ADDR]] : !fir.ptr>) bounds(%[[BOUND]]) -> !fir.ptr> {name = "a"} ! CHECK: acc.data dataOperands(%[[ATTACH]] : !fir.ptr>) + subroutine acc_optional_data2(a, n) + integer :: n + real, optional :: a(n) + !$acc data no_create(a) + !$acc end data + end subroutine + +! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data2( +! CHECK-SAME: %[[A:.*]]: !fir.ref> {fir.bindc_name = "a", fir.optional}, %[[N:.*]]: !fir.ref {fir.bindc_name = "n"}) { +! CHECK: %[[DECL_A:.*]]:2 = hlfir.declare %[[A]](%{{.*}}) {fortran_attrs = #fir.var_attrs, uniq_name = "_QMopenacc_boundsFacc_optional_data2Ea"} : (!fir.ref>, !fir.shape<1>) -> (!fir.box>, !fir.ref>) +! CHECK: %[[NO_CREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref>) bounds(%10) -> !fir.ref> {name = "a"} +! CHECK: acc.data dataOperands(%[[NO_CREATE]] : !fir.ref>) { + + subroutine acc_optional_data3(a, n) + integer :: n + real, optional :: a(n) + !$acc data no_create(a(1:n)) + !$acc end data + end subroutine + +! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data3( +! CHECK-SAME: %[[A:.*]]: !fir.ref> {fir.bindc_name = "a", fir.optional}, %[[N:.*]]: !fir.ref {fir.bindc_name = "n"}) { +! CHECK: %[[DECL_A:.*]]:2 = hlfir.declare %[[A]](%{{.*}}) {fortran_attrs = #fir.var_attrs, uniq_name = "_QMopenacc_boundsFacc_optional_data3Ea"} : (!fir.ref>, !fir.shape<1>) -> (!fir.box>, !fir.ref>) +! CHECK: %[[PRES:.*]] = fir.is_present %[[DECL_A]]#1 : (!fir.ref>) -> i1 +! CHECK: %[[STRIDE:.*]] = fir.if %[[PRES]] -> (index) { +! CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[DECL_A]]#0, %c0{{.*}} : (!fir.box>, index) -> (index, index, index) +! CHECK: fir.result %[[DIMS]]#2 : index +! CHECK: } else { +! CHECK: fir.result %c0{{.*}} : index +! CHECK: } +! CHECK: %[[BOUNDS:.*]] = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[STRIDE]] : index) startIdx(%c1 : index) {strideInBytes = true} +! CHECK: %[[NOCREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref>) bounds(%14) -> !fir.ref> {name = "a(1:n)"} +! CHECK: acc.data dataOperands(%[[NOCREATE]] : !fir.ref>) { + end module