diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td index e95af629ef32f..f643674f1d5d6 100644 --- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td +++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td @@ -161,10 +161,11 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> { let arguments = (ins Arg:$src, Arg:$dst, + Optional:$shape, cuf_DataTransferKindAttr:$transfer_kind); let assemblyFormat = [{ - $src `to` $dst attr-dict `:` type(operands) + $src `to` $dst (`,` $shape^ `:` type($shape) )? attr-dict `:` type($src) `,` type($dst) }]; let hasVerifier = 1; diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index ccbb481f472d8..24cd6b22b8925 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -4272,18 +4272,19 @@ class FirConverter : public Fortran::lower::AbstractConverter { base = convertOp.getValue(); // Special case if the rhs is a constant. if (matchPattern(base.getDefiningOp(), mlir::m_Constant())) { - builder.create(loc, base, lhsVal, - transferKindAttr); + builder.create( + loc, base, lhsVal, /*shape=*/mlir::Value{}, transferKindAttr); } else { auto associate = hlfir::genAssociateExpr( loc, builder, rhs, rhs.getType(), ".cuf_host_tmp"); builder.create(loc, associate.getBase(), lhsVal, + /*shape=*/mlir::Value{}, transferKindAttr); builder.create(loc, associate); } } else { - builder.create(loc, rhsVal, lhsVal, - transferKindAttr); + builder.create( + loc, rhsVal, lhsVal, /*shape=*/mlir::Value{}, transferKindAttr); } return; } @@ -4293,6 +4294,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { auto transferKindAttr = cuf::DataTransferKindAttr::get( builder.getContext(), cuf::DataTransferKind::DeviceHost); builder.create(loc, rhsVal, lhsVal, + /*shape=*/mlir::Value{}, transferKindAttr); return; } @@ -4303,6 +4305,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { auto transferKindAttr = cuf::DataTransferKindAttr::get( builder.getContext(), cuf::DataTransferKind::DeviceDevice); builder.create(loc, rhsVal, lhsVal, + /*shape=*/mlir::Value{}, transferKindAttr); return; } @@ -4346,8 +4349,8 @@ class FirConverter : public Fortran::lower::AbstractConverter { addSymbol(sym, hlfir::translateToExtendedValue(loc, builder, temp).first, /*forced=*/true); - builder.create(loc, addr, temp, - transferKindAttr); + builder.create( + loc, addr, temp, /*shape=*/mlir::Value{}, transferKindAttr); ++nbDeviceResidentObject; } } diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp index f7b36b208a7de..3b4ad95cafe6b 100644 --- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp +++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp @@ -99,6 +99,11 @@ llvm::LogicalResult cuf::AllocateOp::verify() { llvm::LogicalResult cuf::DataTransferOp::verify() { mlir::Type srcTy = getSrc().getType(); mlir::Type dstTy = getDst().getType(); + if (getShape()) { + if (!fir::isa_ref_type(srcTy) || !fir::isa_ref_type(dstTy)) + return emitOpError() + << "shape can only be specified on data transfer with references"; + } if ((fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy)) || (fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy)) || (fir::isa_ref_type(srcTy) && fir::isa_box_type(dstTy)) || diff --git a/flang/test/Fir/cuf-invalid.fir b/flang/test/Fir/cuf-invalid.fir index 06e08d14b2435..e9aeaa281e2a8 100644 --- a/flang/test/Fir/cuf-invalid.fir +++ b/flang/test/Fir/cuf-invalid.fir @@ -94,3 +94,34 @@ func.func @_QPsub1() { cuf.free %0 : !fir.ref {data_attr = #cuf.cuda} return } + +// ----- + +func.func @_QPsub1(%arg0: !fir.ref> {cuf.data_attr = #cuf.cuda, fir.bindc_name = "adev"}, %arg1: !fir.ref> {fir.bindc_name = "ahost"}, %arg2: !fir.ref {fir.bindc_name = "n"}, %arg3: !fir.ref {fir.bindc_name = "m"}) { + %0 = fir.dummy_scope : !fir.dscope + %1:2 = hlfir.declare %arg2 dummy_scope %0 {uniq_name = "_QFsub1En"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) + %2:2 = hlfir.declare %arg3 dummy_scope %0 {uniq_name = "_QFsub1Em"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) + %3 = fir.load %1#0 : !fir.ref + %4 = fir.load %2#0 : !fir.ref + %5 = arith.muli %3, %4 : i32 + %6 = fir.convert %5 : (i32) -> i64 + %7 = fir.convert %6 : (i64) -> index + %c0 = arith.constant 0 : index + %8 = arith.cmpi sgt, %7, %c0 : index + %9 = arith.select %8, %7, %c0 : index + %10 = fir.shape %9 : (index) -> !fir.shape<1> + %11:2 = hlfir.declare %arg0(%10) dummy_scope %0 {data_attr = #cuf.cuda, uniq_name = "_QFsub1Eadev"} : (!fir.ref>, !fir.shape<1>, !fir.dscope) -> (!fir.box>, !fir.ref>) + %12 = fir.load %1#0 : !fir.ref + %13 = fir.load %2#0 : !fir.ref + %14 = arith.muli %12, %13 : i32 + %15 = fir.convert %14 : (i32) -> i64 + %16 = fir.convert %15 : (i64) -> index + %c0_0 = arith.constant 0 : index + %17 = arith.cmpi sgt, %16, %c0_0 : index + %18 = arith.select %17, %16, %c0_0 : index + %19 = fir.shape %18 : (index) -> !fir.shape<1> + %20:2 = hlfir.declare %arg1(%19) dummy_scope %0 {uniq_name = "_QFsub1Eahost"} : (!fir.ref>, !fir.shape<1>, !fir.dscope) -> (!fir.box>, !fir.ref>) + // expected-error@+1{{'cuf.data_transfer' op shape can only be specified on data transfer with references}} + cuf.data_transfer %20#0 to %11#0, %19 : !fir.shape<1> {transfer_kind = #cuf.cuda_transfer} : !fir.box>, !fir.box> + return +}