Skip to content

Commit 7af61d5

Browse files
authored
[flang][cuda] Add shape to cuf.data_transfer operation (#104631)
When doing data transfer with dynamic sized array, we are currently generating a data transfer between two descriptors. If the shape values can be provided, we can keep the data transfer between two references. This patch adds the shape operands to the operation. This will be exploited in lowering in a follow up patch.
1 parent 7106643 commit 7af61d5

File tree

4 files changed

+47
-7
lines changed

4 files changed

+47
-7
lines changed

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,11 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
161161

162162
let arguments = (ins Arg<AnyType, "", [MemRead]>:$src,
163163
Arg<AnyRefOrBoxType, "", [MemWrite]>:$dst,
164+
Optional<fir_ShapeType>:$shape,
164165
cuf_DataTransferKindAttr:$transfer_kind);
165166

166167
let assemblyFormat = [{
167-
$src `to` $dst attr-dict `:` type(operands)
168+
$src `to` $dst (`,` $shape^ `:` type($shape) )? attr-dict `:` type($src) `,` type($dst)
168169
}];
169170

170171
let hasVerifier = 1;

flang/lib/Lower/Bridge.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4272,18 +4272,19 @@ class FirConverter : public Fortran::lower::AbstractConverter {
42724272
base = convertOp.getValue();
42734273
// Special case if the rhs is a constant.
42744274
if (matchPattern(base.getDefiningOp(), mlir::m_Constant())) {
4275-
builder.create<cuf::DataTransferOp>(loc, base, lhsVal,
4276-
transferKindAttr);
4275+
builder.create<cuf::DataTransferOp>(
4276+
loc, base, lhsVal, /*shape=*/mlir::Value{}, transferKindAttr);
42774277
} else {
42784278
auto associate = hlfir::genAssociateExpr(
42794279
loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
42804280
builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
4281+
/*shape=*/mlir::Value{},
42814282
transferKindAttr);
42824283
builder.create<hlfir::EndAssociateOp>(loc, associate);
42834284
}
42844285
} else {
4285-
builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
4286-
transferKindAttr);
4286+
builder.create<cuf::DataTransferOp>(
4287+
loc, rhsVal, lhsVal, /*shape=*/mlir::Value{}, transferKindAttr);
42874288
}
42884289
return;
42894290
}
@@ -4293,6 +4294,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
42934294
auto transferKindAttr = cuf::DataTransferKindAttr::get(
42944295
builder.getContext(), cuf::DataTransferKind::DeviceHost);
42954296
builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
4297+
/*shape=*/mlir::Value{},
42964298
transferKindAttr);
42974299
return;
42984300
}
@@ -4303,6 +4305,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
43034305
auto transferKindAttr = cuf::DataTransferKindAttr::get(
43044306
builder.getContext(), cuf::DataTransferKind::DeviceDevice);
43054307
builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
4308+
/*shape=*/mlir::Value{},
43064309
transferKindAttr);
43074310
return;
43084311
}
@@ -4346,8 +4349,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
43464349
addSymbol(sym,
43474350
hlfir::translateToExtendedValue(loc, builder, temp).first,
43484351
/*forced=*/true);
4349-
builder.create<cuf::DataTransferOp>(loc, addr, temp,
4350-
transferKindAttr);
4352+
builder.create<cuf::DataTransferOp>(
4353+
loc, addr, temp, /*shape=*/mlir::Value{}, transferKindAttr);
43514354
++nbDeviceResidentObject;
43524355
}
43534356
}

flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ llvm::LogicalResult cuf::AllocateOp::verify() {
9999
llvm::LogicalResult cuf::DataTransferOp::verify() {
100100
mlir::Type srcTy = getSrc().getType();
101101
mlir::Type dstTy = getDst().getType();
102+
if (getShape()) {
103+
if (!fir::isa_ref_type(srcTy) || !fir::isa_ref_type(dstTy))
104+
return emitOpError()
105+
<< "shape can only be specified on data transfer with references";
106+
}
102107
if ((fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy)) ||
103108
(fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy)) ||
104109
(fir::isa_ref_type(srcTy) && fir::isa_box_type(dstTy)) ||

flang/test/Fir/cuf-invalid.fir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,34 @@ func.func @_QPsub1() {
9494
cuf.free %0 : !fir.ref<f32> {data_attr = #cuf.cuda<constant>}
9595
return
9696
}
97+
98+
// -----
99+
100+
func.func @_QPsub1(%arg0: !fir.ref<!fir.array<?xf32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "adev"}, %arg1: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "ahost"}, %arg2: !fir.ref<i32> {fir.bindc_name = "n"}, %arg3: !fir.ref<i32> {fir.bindc_name = "m"}) {
101+
%0 = fir.dummy_scope : !fir.dscope
102+
%1:2 = hlfir.declare %arg2 dummy_scope %0 {uniq_name = "_QFsub1En"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
103+
%2:2 = hlfir.declare %arg3 dummy_scope %0 {uniq_name = "_QFsub1Em"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
104+
%3 = fir.load %1#0 : !fir.ref<i32>
105+
%4 = fir.load %2#0 : !fir.ref<i32>
106+
%5 = arith.muli %3, %4 : i32
107+
%6 = fir.convert %5 : (i32) -> i64
108+
%7 = fir.convert %6 : (i64) -> index
109+
%c0 = arith.constant 0 : index
110+
%8 = arith.cmpi sgt, %7, %c0 : index
111+
%9 = arith.select %8, %7, %c0 : index
112+
%10 = fir.shape %9 : (index) -> !fir.shape<1>
113+
%11:2 = hlfir.declare %arg0(%10) dummy_scope %0 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eadev"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
114+
%12 = fir.load %1#0 : !fir.ref<i32>
115+
%13 = fir.load %2#0 : !fir.ref<i32>
116+
%14 = arith.muli %12, %13 : i32
117+
%15 = fir.convert %14 : (i32) -> i64
118+
%16 = fir.convert %15 : (i64) -> index
119+
%c0_0 = arith.constant 0 : index
120+
%17 = arith.cmpi sgt, %16, %c0_0 : index
121+
%18 = arith.select %17, %16, %c0_0 : index
122+
%19 = fir.shape %18 : (index) -> !fir.shape<1>
123+
%20:2 = hlfir.declare %arg1(%19) dummy_scope %0 {uniq_name = "_QFsub1Eahost"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
124+
// expected-error@+1{{'cuf.data_transfer' op shape can only be specified on data transfer with references}}
125+
cuf.data_transfer %20#0 to %11#0, %19 : !fir.shape<1> {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>
126+
return
127+
}

0 commit comments

Comments
 (0)