diff --git a/flang/lib/Lower/IO.cpp b/flang/lib/Lower/IO.cpp index 07c9e6a1726bf..13d612354da84 100644 --- a/flang/lib/Lower/IO.cpp +++ b/flang/lib/Lower/IO.cpp @@ -609,11 +609,22 @@ static void genNamelistIO(Fortran::lower::AbstractConverter &converter, ok = builder.create(loc, funcOp, args).getResult(0); } +/// Is \p type a derived type or an array of derived type? +static bool containsDerivedType(mlir::Type type) { + mlir::Type argTy = fir::unwrapPassByRefType(fir::unwrapRefType(type)); + if (mlir::isa(argTy)) + return true; + if (auto seqTy = mlir::dyn_cast(argTy)) + if (mlir::isa(seqTy.getEleTy())) + return true; + return false; +} + /// Get the output function to call for a value of the given type. static mlir::func::FuncOp getOutputFunc(mlir::Location loc, fir::FirOpBuilder &builder, mlir::Type type, bool isFormatted) { - if (mlir::isa(fir::unwrapPassByRefType(type))) + if (containsDerivedType(type)) return fir::runtime::getIORuntimeFunc(loc, builder); if (!isFormatted) @@ -710,7 +721,7 @@ static void genOutputItemList( if (mlir::isa(argType)) { mlir::Value box = fir::getBase(converter.genExprBox(loc, *expr, stmtCtx)); outputFuncArgs.push_back(builder.createConvert(loc, argType, box)); - if (mlir::isa(fir::unwrapPassByRefType(itemTy))) + if (containsDerivedType(itemTy)) outputFuncArgs.push_back(getNonTbpDefinedIoTableAddr(converter)); } else if (helper.isCharacterScalar(itemTy)) { fir::ExtendedValue exv = converter.genExprAddr(loc, expr, stmtCtx); @@ -745,7 +756,7 @@ static void genOutputItemList( static mlir::func::FuncOp getInputFunc(mlir::Location loc, fir::FirOpBuilder &builder, mlir::Type type, bool isFormatted) { - if (mlir::isa(fir::unwrapPassByRefType(type))) + if (containsDerivedType(type)) return fir::runtime::getIORuntimeFunc(loc, builder); if (!isFormatted) @@ -817,7 +828,7 @@ createIoRuntimeCallForItem(Fortran::lower::AbstractConverter &converter, auto boxTy = mlir::dyn_cast(box.getType()); assert(boxTy && "must be previously emboxed"); inputFuncArgs.push_back(builder.createConvert(loc, argType, box)); - if (mlir::isa(fir::unwrapPassByRefType(boxTy))) + if (containsDerivedType(boxTy)) inputFuncArgs.push_back(getNonTbpDefinedIoTableAddr(converter)); } else { mlir::Value itemAddr = fir::getBase(item); diff --git a/flang/test/Lower/io-derived-type.f90 b/flang/test/Lower/io-derived-type.f90 index ecbbc22d24b1e..316a2cdb5b14f 100644 --- a/flang/test/Lower/io-derived-type.f90 +++ b/flang/test/Lower/io-derived-type.f90 @@ -101,6 +101,7 @@ program p use m character*3 ccc(4) namelist /nnn/ jjj, ccc + type(t) :: y(5) ! CHECK: fir.call @_QMmPtest1 call test1 @@ -115,6 +116,16 @@ program p ! CHECK: %[[V_100:[0-9]+]] = fir.convert %[[V_99]] : (!fir.ref, !fir.ref, i32, i1>>>, i1>>) -> !fir.ref ! CHECK: %[[V_101:[0-9]+]] = fir.call @_FortranAioOutputDerivedType(%{{.*}}, %[[V_98]], %[[V_100]]) fastmath : (!fir.ref, !fir.box, !fir.ref) -> i1 print *, 'main, should call wft: ', t(4) + + ! CHECK: %[[V_33:[0-9]+]] = fir.shape %c2{{.*}} : (index) -> !fir.shape<1> + ! CHECK: %[[V_34:[0-9]+]] = hlfir.designate %7#0 (%c2{{.*}}:%c3{{.*}}:%c1{{.*}}) shape %[[V_33]] : (!fir.ref>>, index, index, index, !fir.shape<1>) -> !fir.ref>> + ! CHECK: %[[V_35:[0-9]+]] = fir.shape %c2{{.*}} : (index) -> !fir.shape<1> + ! CHECK: %[[V_36:[0-9]+]] = fir.embox %[[V_34]](%[[V_35]]) : (!fir.ref>>, !fir.shape<1>) -> !fir.box>> + ! CHECK: %[[V_37:[0-9]+]] = fir.convert %[[V_36]] : (!fir.box>>) -> !fir.box + ! CHECK: %[[V_38:[0-9]+]] = fir.address_of(@_QQF.nonTbpDefinedIoTable) : !fir.ref, !fir.ref, i32, i1>>>, i1>> + ! CHECK: %[[V_39:[0-9]+]] = fir.convert %[[V_38]] : (!fir.ref, !fir.ref, i32, i1>>>, i1>>) -> !fir.ref + ! CHECK: %[[V_40:[0-9]+]] = fir.call @_FortranAioOutputDerivedType(%{{.*}}, %[[V_37]], %[[V_39]]) fastmath : (!fir.ref, !fir.box, !fir.ref) -> i1 + print *, y(2:3) end ! CHECK: fir.global linkonce @_QQMmFtest1.nonTbpDefinedIoTable.list constant : !fir.array<1xtuple, !fir.ref, i32, i1>>